Posted to commits@sedona.apache.org by ji...@apache.org on 2021/03/26 00:56:32 UTC

[incubator-sedona] branch master updated: [SEDONA-26] Add broadcast join support (#515)

This is an automated email from the ASF dual-hosted git repository.

jiayu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-sedona.git


The following commit(s) were added to refs/heads/master by this push:
     new 369b51a  [SEDONA-26] Add broadcast join support (#515)
369b51a is described below

commit 369b51a68a7983b0c6fa79c313a29bdc1ff6be11
Author: Adam Binford <ad...@gmail.com>
AuthorDate: Thu Mar 25 20:56:26 2021 -0400

    [SEDONA-26] Add broadcast join support (#515)
    
    * Initial broadcast join working in Spark 3
    
    * Add distance join, extra condition, and spark 2 support for broadcast join
    
    * Add some documentation for broadcast joins
    
    * Fix small issue with broadcast distance joins with expression for radius
    
    * Simplify tests
    
    * Use custom enum for Spark 3.1 compatibility and try to add a python test
    
    * Simplify join detection with ST_Predicate parent class and detect both sides of an And expression
    
    Co-authored-by: Adam Binford <ad...@maxar.com>
---
 .gitignore                                         |   3 +
 docs/api/sql/GeoSparkSQL-Optimizer.md              |  37 ++++
 docs/api/sql/GeoSparkSQL-Overview.md               |   5 +
 python-adapter/.gitignore                          |   4 +
 python/tests/sql/test_predicate_join.py            |  47 +++++
 spark-version-converter.py                         |   3 +-
 .../sedona/sql/utils/SedonaSQLRegistrator.scala    |   8 +-
 .../sql/sedona_sql/expressions/Predicates.scala    |  16 +-
 .../strategy/join/BroadcastIndexJoinExec.scala     | 130 +++++++++++++
 .../strategy/join/DistanceJoinExec.scala           |   2 +-
 .../strategy/join/JoinQueryDetector.scala          | 195 +++++++++++++------
 .../strategy/join/SpatialIndexExec.scala           |  62 +++++++
 ...anceJoinExec.scala => TraitJoinQueryBase.scala} |  53 +++---
 .../strategy/join/TraitJoinQueryExec.scala         |  45 +----
 .../sql/sedona_sql/strategy/join/enums.scala}      |  26 +--
 .../sedona/sql/BroadcastIndexJoinSuite.scala       | 206 +++++++++++++++++++++
 .../org/apache/sedona/sql/TestBaseScala.scala      |  10 +-
 .../apache/sedona/sql/predicateJoinTestScala.scala |   5 -
 18 files changed, 697 insertions(+), 160 deletions(-)

diff --git a/.gitignore b/.gitignore
index bbeea00..4ced234 100644
--- a/.gitignore
+++ b/.gitignore
@@ -10,3 +10,6 @@
 /conf/
 /log/
 /site/
+/.bloop/
+/.metals/
+/.vscode/
\ No newline at end of file
diff --git a/docs/api/sql/GeoSparkSQL-Optimizer.md b/docs/api/sql/GeoSparkSQL-Optimizer.md
index 222a8de..92a66f2 100644
--- a/docs/api/sql/GeoSparkSQL-Optimizer.md
+++ b/docs/api/sql/GeoSparkSQL-Optimizer.md
@@ -72,6 +72,43 @@ DistanceJoin pointshape1#12: geometry, pointshape2#33: geometry, 2.0, true
 !!!warning
 	Sedona doesn't control the distance's unit (degree or meter). It is same with the geometry. To change the geometry's unit, please transform the coordinate reference system. See [ST_Transform](GeoSparkSQL-Function.md#st_transform).
 
+## Broadcast join
+Introduction: Perform a range join or distance join but broadcast one of the sides of the join. This maintains the partitioning of the non-broadcast side and doesn't require a shuffle.
+
+```Scala
+pointDf.alias("pointDf").join(broadcast(polygonDf).alias("polygonDf"), expr("ST_Contains(polygonDf.polygonshape, pointDf.pointshape)"))
+```
+
+Spark SQL Physical plan:
+```
+== Physical Plan ==
+BroadcastIndexJoin pointshape#52: geometry, BuildRight, BuildRight, false ST_Contains(polygonshape#30, pointshape#52)
+:- Project [st_point(cast(_c0#48 as decimal(24,20)), cast(_c1#49 as decimal(24,20))) AS pointshape#52]
+:  +- FileScan csv
++- SpatialIndex polygonshape#30: geometry, QUADTREE, [id=#62]
+   +- Project [st_polygonfromenvelope(cast(_c0#22 as decimal(24,20)), cast(_c1#23 as decimal(24,20)), cast(_c2#24 as decimal(24,20)), cast(_c3#25 as decimal(24,20))) AS polygonshape#30]
+      +- FileScan csv
+```
+
+This also works for distance joins:
+
+```Scala
+pointDf1.alias("pointDf1").join(broadcast(pointDf2).alias("pointDf2"), expr("ST_Distance(pointDf1.pointshape, pointDf2.pointshape) <= 2"))
+```
+
+Spark SQL Physical plan:
+```
+== Physical Plan ==
+BroadcastIndexJoin pointshape#52: geometry, BuildRight, BuildLeft, true, 2.0 ST_Distance(pointshape#52, pointshape#415) <= 2.0
+:- Project [st_point(cast(_c0#48 as decimal(24,20)), cast(_c1#49 as decimal(24,20))) AS pointshape#52]
+:  +- FileScan csv
++- SpatialIndex pointshape#415: geometry, QUADTREE, [id=#1068]
+   +- Project [st_point(cast(_c0#48 as decimal(24,20)), cast(_c1#49 as decimal(24,20))) AS pointshape#415]
+      +- FileScan csv
+```
+
+Note: If the distance is an expression, it is only evaluated on the first argument to ST_Distance (`pointDf1` above).
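+
+For illustration, a sketch where the radius is read from a hypothetical `radius` column on `pointDf1` (the pattern exercised by the new `BroadcastIndexJoinSuite` test):
+
+```Scala
+pointDf1.alias("pointDf1").join(broadcast(pointDf2).alias("pointDf2"), expr("ST_Distance(pointDf1.pointshape, pointDf2.pointshape) <= pointDf1.radius"))
+```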
+
 ## Predicate pushdown
 
 Introduction: Given a join query and a predicate in the same WHERE clause, first executes the Predicate as a filter, then executes the join query*
diff --git a/docs/api/sql/GeoSparkSQL-Overview.md b/docs/api/sql/GeoSparkSQL-Overview.md
index cf596be..15edc9f 100644
--- a/docs/api/sql/GeoSparkSQL-Overview.md
+++ b/docs/api/sql/GeoSparkSQL-Overview.md
@@ -6,6 +6,11 @@ SedonaSQL supports SQL/MM Part3 Spatial SQL Standard. It includes four kinds of
 var myDataFrame = sparkSession.sql("YOUR_SQL")
 ```
 
+Alternatively, `expr` and `selectExpr` can be used:
+```Scala
+myDataFrame.withColumn("geometry", expr("ST_*")).selectExpr("ST_*")
+```
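+
+For instance, a sketch assuming a CSV-backed DataFrame `df` whose first two columns hold coordinates:
+
+```Scala
+df.selectExpr("ST_Point(cast(_c0 as Decimal(24,20)), cast(_c1 as Decimal(24,20))) as pointshape")
+```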
+
 * Constructor: Construct a Geometry given an input string or coordinates
 	* Example: ST_GeomFromWKT (string). Create a Geometry from a WKT String.
 	* Documentation: [Here](../GeoSparkSQL-Constructor)
diff --git a/python-adapter/.gitignore b/python-adapter/.gitignore
index b83d222..804f11a 100644
--- a/python-adapter/.gitignore
+++ b/python-adapter/.gitignore
@@ -1 +1,5 @@
 /target/
+bin
+/.settings
+/.classpath
+/.project
\ No newline at end of file
diff --git a/python/tests/sql/test_predicate_join.py b/python/tests/sql/test_predicate_join.py
index dfce133..8aee350 100644
--- a/python/tests/sql/test_predicate_join.py
+++ b/python/tests/sql/test_predicate_join.py
@@ -16,6 +16,7 @@
 #  under the License.
 
 from pyspark import Row
+from pyspark.sql.functions import broadcast, expr
 from pyspark.sql.types import StructType, StringType, IntegerType, StructField, DoubleType
 
 from tests import csv_polygon_input_location, csv_point_input_location, overlap_polygon_input_location, \
@@ -449,3 +450,49 @@ class TestPredicateJoin(TestBase):
         equal_join_df.explain()
         equal_join_df.show(3)
         assert equal_join_df.count() == 0, f"Expected 0 but got {equal_join_df.count()}"
+
+    def test_st_contains_in_broadcast_join(self):
+        polygon_csv_df = self.spark.read.format("csv").\
+                option("delimiter", ",").\
+                option("header", "false").load(
+            csv_polygon_input_location
+        )
+        polygon_csv_df.createOrReplaceTempView("polygontable")
+        polygon_csv_df.show()
+
+        polygon_df = self.spark.sql(
+            "select ST_PolygonFromEnvelope(cast(polygontable._c0 as Decimal(24,20)),cast(polygontable._c1 as Decimal(24,20)), cast(polygontable._c2 as Decimal(24,20)), cast(polygontable._c3 as Decimal(24,20))) as polygonshape from polygontable")
+        polygon_df = polygon_df.repartition(7)
+        polygon_df.createOrReplaceTempView("polygondf")
+        polygon_df.show()
+
+        point_csv_df = self.spark.read.format("csv").\
+            option("delimiter", ",").\
+            option("header", "false").load(
+            csv_point_input_location
+        )
+        point_csv_df.createOrReplaceTempView("pointtable")
+        point_csv_df.show()
+
+        point_df = self.spark.sql(
+            "select ST_Point(cast(pointtable._c0 as Decimal(24,20)),cast(pointtable._c1 as Decimal(24,20))) as pointshape from pointtable")
+        point_df = point_df.repartition(9)
+        point_df.createOrReplaceTempView("pointdf")
+        point_df.show()
+
+        range_join_df = self.spark.sql(
+            "select /*+ BROADCAST(polygondf) */ * from polygondf, pointdf where ST_Contains(polygondf.polygonshape,pointdf.pointshape) ")
+
+        range_join_df.explain()
+        range_join_df.show(3)
+        assert range_join_df.rdd.getNumPartitions() == 9
+        assert range_join_df.count() == 1000
+
+        range_join_df = point_df.alias("pointdf").join(broadcast(polygon_df).alias("polygondf"), on=expr("ST_Contains(polygondf.polygonshape, pointdf.pointshape)"))
+
+        range_join_df.explain()
+        range_join_df.show(3)
+        assert range_join_df.rdd.getNumPartitions() == 9
+        assert range_join_df.count() == 1000
+
+        
diff --git a/spark-version-converter.py b/spark-version-converter.py
index 8e5bca1..c04ffe2 100644
--- a/spark-version-converter.py
+++ b/spark-version-converter.py
@@ -23,7 +23,8 @@ spark2_anchor = 'SPARK2 anchor'
 spark3_anchor = 'SPARK3 anchor'
 files = ['sql/src/main/scala/org/apache/sedona/sql/UDF/UdfRegistrator.scala',
          'sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/TraitJoinQueryExec.scala',
-         'sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/JoinQueryDetector.scala']
+         'sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/JoinQueryDetector.scala',
+         'sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/BroadcastIndexJoinExec.scala']
 
 def switch_version(line):
     if line[:2] == '//':
diff --git a/sql/src/main/scala/org/apache/sedona/sql/utils/SedonaSQLRegistrator.scala b/sql/src/main/scala/org/apache/sedona/sql/utils/SedonaSQLRegistrator.scala
index 12870c8..3e60627 100644
--- a/sql/src/main/scala/org/apache/sedona/sql/utils/SedonaSQLRegistrator.scala
+++ b/sql/src/main/scala/org/apache/sedona/sql/utils/SedonaSQLRegistrator.scala
@@ -20,18 +20,16 @@ package org.apache.sedona.sql.utils
 
 import org.apache.sedona.sql.UDF.UdfRegistrator
 import org.apache.sedona.sql.UDT.UdtRegistrator
-import org.apache.spark.sql.sedona_sql.strategy.join.JoinQueryDetector
 import org.apache.spark.sql.{SQLContext, SparkSession}
+import org.apache.spark.sql.sedona_sql.strategy.join.JoinQueryDetector
 
 object SedonaSQLRegistrator {
   def registerAll(sqlContext: SQLContext): Unit = {
-    sqlContext.experimental.extraStrategies = JoinQueryDetector :: Nil
-    UdtRegistrator.registerAll()
-    UdfRegistrator.registerAll(sqlContext)
+    registerAll(sqlContext.sparkSession)
   }
 
   def registerAll(sparkSession: SparkSession): Unit = {
-    sparkSession.experimental.extraStrategies = JoinQueryDetector :: Nil
+    sparkSession.experimental.extraStrategies = new JoinQueryDetector(sparkSession) :: Nil
     UdtRegistrator.registerAll()
     UdfRegistrator.registerAll(sparkSession)
   }
diff --git a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Predicates.scala b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Predicates.scala
index 4e71b86..959b8de 100644
--- a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Predicates.scala
+++ b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Predicates.scala
@@ -25,13 +25,15 @@ import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
 import org.apache.spark.sql.catalyst.util.ArrayData
 import org.apache.spark.sql.types.BooleanType
 
+abstract class ST_Predicate extends Expression
+
 /**
   * Test if leftGeometry full contains rightGeometry
   *
   * @param inputExpressions
   */
 case class ST_Contains(inputExpressions: Seq[Expression])
-  extends Expression with CodegenFallback {
+  extends ST_Predicate with CodegenFallback {
 
   // This is a binary expression
   assert(inputExpressions.length == 2)
@@ -62,7 +64,7 @@ case class ST_Contains(inputExpressions: Seq[Expression])
   * @param inputExpressions
   */
 case class ST_Intersects(inputExpressions: Seq[Expression])
-  extends Expression with CodegenFallback {
+  extends ST_Predicate with CodegenFallback {
   override def nullable: Boolean = false
 
   // This is a binary expression
@@ -92,7 +94,7 @@ case class ST_Intersects(inputExpressions: Seq[Expression])
   * @param inputExpressions
   */
 case class ST_Within(inputExpressions: Seq[Expression])
-  extends Expression with CodegenFallback {
+  extends ST_Predicate with CodegenFallback {
   override def nullable: Boolean = false
 
   // This is a binary expression
@@ -123,7 +125,7 @@ case class ST_Within(inputExpressions: Seq[Expression])
   * @param inputExpressions
   */
 case class ST_Crosses(inputExpressions: Seq[Expression])
-  extends Expression with CodegenFallback {
+  extends ST_Predicate with CodegenFallback {
   override def nullable: Boolean = false
 
   override def toString: String = s" **${ST_Crosses.getClass.getName}**  "
@@ -153,7 +155,7 @@ case class ST_Crosses(inputExpressions: Seq[Expression])
   * @param inputExpressions
   */
 case class ST_Overlaps(inputExpressions: Seq[Expression])
-  extends Expression with CodegenFallback {
+  extends ST_Predicate with CodegenFallback {
   override def nullable: Boolean = false
 
   // This is a binary expression
@@ -183,7 +185,7 @@ case class ST_Overlaps(inputExpressions: Seq[Expression])
   * @param inputExpressions
   */
 case class ST_Touches(inputExpressions: Seq[Expression])
-  extends Expression with CodegenFallback {
+  extends ST_Predicate with CodegenFallback {
   override def nullable: Boolean = false
 
   // This is a binary expression
@@ -213,7 +215,7 @@ case class ST_Touches(inputExpressions: Seq[Expression])
   * @param inputExpressions
   */
 case class ST_Equals(inputExpressions: Seq[Expression])
-  extends Expression with CodegenFallback {
+  extends ST_Predicate with CodegenFallback {
   override def nullable: Boolean = false
 
   // This is a binary expression
diff --git a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/BroadcastIndexJoinExec.scala b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/BroadcastIndexJoinExec.scala
new file mode 100644
index 0000000..0d9ec4f
--- /dev/null
+++ b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/BroadcastIndexJoinExec.scala
@@ -0,0 +1,130 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.spark.sql.sedona_sql.strategy.join
+
+import collection.JavaConverters._
+
+import org.apache.sedona.core.spatialRDD.SpatialRDD
+import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.internal.Logging
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{Attribute, BindReferences, Expression, Predicate, UnsafeRow}
+import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeRowJoiner
+import org.apache.spark.sql.catalyst.plans.physical.Partitioning
+import org.apache.spark.sql.execution.{BinaryExecNode, SparkPlan}
+import org.locationtech.jts.geom.Geometry
+import org.locationtech.jts.index.SpatialIndex
+
+case class BroadcastIndexJoinExec(left: SparkPlan,
+                                  right: SparkPlan,
+                                  streamShape: Expression,
+                                  indexBuildSide: JoinSide,
+                                  windowJoinSide: JoinSide,
+                                  intersects: Boolean,
+                                  extraCondition: Option[Expression] = None,
+                                  radius: Option[Expression] = None)
+  extends BinaryExecNode
+    with TraitJoinQueryBase
+    with Logging {
+
+  override def output: Seq[Attribute] = left.output ++ right.output
+
+  // Using lazy val to avoid serialization
+  @transient private lazy val boundCondition: (InternalRow => Boolean) = extraCondition match {
+    case Some(condition) =>
+      Predicate.create(condition, output).eval _ // SPARK3 anchor
+//      newPredicate(condition, output).eval _ // SPARK2 anchor
+    case None =>
+      (r: InternalRow) => true
+  }
+
+  private val (streamed, broadcast) = indexBuildSide match {
+    case LeftSide => (right, left.asInstanceOf[SpatialIndexExec])
+    case RightSide => (left, right.asInstanceOf[SpatialIndexExec])
+  }
+
+  override def outputPartitioning: Partitioning = streamed.outputPartitioning
+
+  private val (windowExpression, objectExpression) = if (indexBuildSide == windowJoinSide) {
+    (broadcast.shape, streamShape)
+  } else {
+    (streamShape, broadcast.shape)
+  }
+
+  private val spatialExpression = radius match {
+      case Some(r) if intersects => s"ST_Distance($windowExpression, $objectExpression) <= $r"
+      case Some(r) if !intersects => s"ST_Distance($windowExpression, $objectExpression) < $r"
+      case None if intersects => s"ST_Intersects($windowExpression, $objectExpression)"
+      case None if !intersects => s"ST_Contains($windowExpression, $objectExpression)"
+    }
+
+  override def simpleString(maxFields: Int): String = super.simpleString(maxFields) + s" $spatialExpression" // SPARK3 anchor
+//  override def simpleString: String = super.simpleString + s" $spatialExpression" // SPARK2 anchor
+
+  private def windowBroadcastJoin(index: Broadcast[SpatialIndex], spatialRdd: SpatialRDD[Geometry]): RDD[(Geometry, Geometry)] = {
+    spatialRdd.getRawSpatialRDD.rdd.flatMap { row =>
+      val candidates = index.value.query(row.getEnvelopeInternal).iterator.asScala.asInstanceOf[Iterator[Geometry]]
+      candidates
+        .filter(candidate => if (intersects) candidate.intersects(row) else candidate.covers(row))
+        .map(candidate => (candidate, row))
+    }
+  }
+
+  private def objectBroadcastJoin(index: Broadcast[SpatialIndex], spatialRdd: SpatialRDD[Geometry]): RDD[(Geometry, Geometry)] = {
+    spatialRdd.getRawSpatialRDD.rdd.flatMap { row =>
+      val candidates = index.value.query(row.getEnvelopeInternal).iterator.asScala.asInstanceOf[Iterator[Geometry]]
+      candidates
+        .filter(candidate => if (intersects) row.intersects(candidate) else row.covers(candidate))
+        .map(candidate => (row, candidate))
+    }
+  }
+
+  override protected def doExecute(): RDD[InternalRow] = {
+    val boundStreamShape = BindReferences.bindReference(streamShape, streamed.output)
+    val streamResultsRaw = streamed.execute().asInstanceOf[RDD[UnsafeRow]]
+
+    // If there's a radius and the objects are being broadcast, we need to build the CircleRDD on the window stream side
+    val streamShapes = radius match {
+      case Some(radiusExpression) if indexBuildSide != windowJoinSide =>
+        toCircleRDD(streamResultsRaw, boundStreamShape, BindReferences.bindReference(radiusExpression, streamed.output))
+      case _ =>
+        toSpatialRDD(streamResultsRaw, boundStreamShape)
+    }
+
+    val broadcastIndex = broadcast.executeBroadcast[SpatialIndex]()
+
+    val pairs = (indexBuildSide, windowJoinSide) match {
+      case (LeftSide, LeftSide) => windowBroadcastJoin(broadcastIndex, streamShapes)
+      case (LeftSide, RightSide) => objectBroadcastJoin(broadcastIndex, streamShapes).map { case (left, right) => (right, left) }
+      case (RightSide, LeftSide) => objectBroadcastJoin(broadcastIndex, streamShapes)
+      case (RightSide, RightSide) => windowBroadcastJoin(broadcastIndex, streamShapes).map { case (left, right) => (right, left) }
+    }
+
+    pairs.mapPartitions { iter =>
+      val joiner = GenerateUnsafeRowJoiner.create(left.schema, right.schema)
+      iter.map {
+        case (l, r) =>
+          val leftRow = l.getUserData.asInstanceOf[UnsafeRow]
+          val rightRow = r.getUserData.asInstanceOf[UnsafeRow]
+          joiner.join(leftRow, rightRow)
+      }.filter(boundCondition(_))
+    }
+  }
+}
diff --git a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/DistanceJoinExec.scala b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/DistanceJoinExec.scala
index 28087b4..a921979 100644
--- a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/DistanceJoinExec.scala
+++ b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/DistanceJoinExec.scala
@@ -48,7 +48,7 @@ case class DistanceJoinExec(left: SparkPlan,
                                  buildExpr: Expression,
                                  streamedRdd: RDD[UnsafeRow],
                                  streamedExpr: Expression): (SpatialRDD[Geometry], SpatialRDD[Geometry]) =
-    (toCircleRDD(buildRdd, buildExpr), toSpatialRdd(streamedRdd, streamedExpr))
+    (toCircleRDD(buildRdd, buildExpr), toSpatialRDD(streamedRdd, streamedExpr))
 
   private def toCircleRDD(rdd: RDD[UnsafeRow], shapeExpression: Expression): SpatialRDD[Geometry] = {
     val spatialRdd = new SpatialRDD[Geometry]
diff --git a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/JoinQueryDetector.scala b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/JoinQueryDetector.scala
index df624f9..18a8141 100644
--- a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/JoinQueryDetector.scala
+++ b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/JoinQueryDetector.scala
@@ -18,68 +18,112 @@
  */
 package org.apache.spark.sql.sedona_sql.strategy.join
 
-import org.apache.spark.sql.Strategy
-import org.apache.spark.sql.catalyst.expressions.{Expression, LessThan, LessThanOrEqual}
+import org.apache.sedona.core.enums.IndexType
+import org.apache.sedona.core.utils.SedonaConf
+import org.apache.spark.sql.{SparkSession, Strategy}
+import org.apache.spark.sql.catalyst.expressions.{And, Expression, LessThan, LessThanOrEqual}
 import org.apache.spark.sql.catalyst.plans.Inner
-import org.apache.spark.sql.catalyst.plans.logical.{Join, LogicalPlan}
+import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.sql.sedona_sql.expressions._
 
+
+
+case class JoinQueryDetection(
+  left: LogicalPlan,
+  right: LogicalPlan,
+  leftShape: Expression,
+  rightShape: Expression,
+  intersects: Boolean,
+  extraCondition: Option[Expression] = None,
+  radius: Option[Expression] = None
+)
+
 /**
   * Plans `RangeJoinExec` for inner joins on spatial relationships ST_Contains(a, b)
   * and ST_Intersects(a, b).
   *
   * Plans `DistanceJoinExec` for inner joins on spatial relationship ST_Distance(a, b) < r.
+  * 
+  * Plans `BroadcastIndexJoinExec` for inner joins on spatial relationships with a broadcast hint.
   */
-object JoinQueryDetector extends Strategy {
+class JoinQueryDetector(sparkSession: SparkSession) extends Strategy {
+
+  private def getJoinDetection(
+    left: LogicalPlan,
+    right: LogicalPlan,
+    predicate: ST_Predicate,
+    extraCondition: Option[Expression] = None): Option[JoinQueryDetection] = {
+      predicate match {
+        case ST_Contains(Seq(leftShape, rightShape)) =>
+          Some(JoinQueryDetection(left, right, leftShape, rightShape, false, extraCondition))
+        case ST_Intersects(Seq(leftShape, rightShape)) =>
+          Some(JoinQueryDetection(left, right, leftShape, rightShape, true, extraCondition))
+        case ST_Within(Seq(leftShape, rightShape)) =>
+          Some(JoinQueryDetection(right, left, rightShape, leftShape, false, extraCondition))
+        case ST_Overlaps(Seq(leftShape, rightShape)) =>
+          Some(JoinQueryDetection(right, left, rightShape, leftShape, false, extraCondition))
+        case ST_Touches(Seq(leftShape, rightShape)) =>
+          Some(JoinQueryDetection(left, right, leftShape, rightShape, true, extraCondition))
+        case ST_Equals(Seq(leftShape, rightShape)) =>
+          Some(JoinQueryDetection(left, right, leftShape, rightShape, false, extraCondition))
+        case ST_Crosses(Seq(leftShape, rightShape)) =>
+          Some(JoinQueryDetection(right, left, rightShape, leftShape, false, extraCondition))
+        case _ => None
+      }
+    }
 
   def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
-
-    // ST_Contains(a, b) - a contains b
-case Join(left, right, Inner, Some(ST_Contains(Seq(leftShape, rightShape))), _) => // SPARK3 anchor
-//case Join(left, right, Inner, Some(ST_Contains(Seq(leftShape, rightShape)))) => // SPARK2 anchor
-      planSpatialJoin(left, right, Seq(leftShape, rightShape), false)
-
-    // ST_Intersects(a, b) - a intersects b
-case Join(left, right, Inner, Some(ST_Intersects(Seq(leftShape, rightShape))), _) => // SPARK3 anchor
-//case Join(left, right, Inner, Some(ST_Intersects(Seq(leftShape, rightShape)))) => // SPARK2 anchor
-      planSpatialJoin(left, right, Seq(leftShape, rightShape), true)
-
-    // ST_WITHIN(a, b) - a is within b
-case Join(left, right, Inner, Some(ST_Within(Seq(leftShape, rightShape))), _) => // SPARK3 anchor
-//case Join(left, right, Inner, Some(ST_Within(Seq(leftShape, rightShape)))) => // SPARK2 anchor
-      planSpatialJoin(right, left, Seq(rightShape, leftShape), false)
-
-    // ST_Overlaps(a, b) - a overlaps b
-case Join(left, right, Inner, Some(ST_Overlaps(Seq(leftShape, rightShape))), _) => // SPARK3 anchor
-//case Join(left, right, Inner, Some(ST_Overlaps(Seq(leftShape, rightShape)))) => // SPARK2 anchor
-      planSpatialJoin(right, left, Seq(rightShape, leftShape), false)
-
-    // ST_Touches(a, b) - a touches b
-case Join(left, right, Inner, Some(ST_Touches(Seq(leftShape, rightShape))), _) => // SPARK3 anchor
-//case Join(left, right, Inner, Some(ST_Touches(Seq(leftShape, rightShape)))) => // SPARK2 anchor
-      planSpatialJoin(left, right, Seq(leftShape, rightShape), true)
-
-    // ST_Distance(a, b) <= radius consider boundary intersection
-case Join(left, right, Inner, Some(LessThanOrEqual(ST_Distance(Seq(leftShape, rightShape)), radius)), _) => // SPARK3 anchor
-//case Join(left, right, Inner, Some(LessThanOrEqual(ST_Distance(Seq(leftShape, rightShape)), radius))) => // SPARK2 anchor
-      planDistanceJoin(left, right, Seq(leftShape, rightShape), radius, true)
-
-    // ST_Distance(a, b) < radius don't consider boundary intersection
-case Join(left, right, Inner, Some(LessThan(ST_Distance(Seq(leftShape, rightShape)), radius)), _) => // SPARK3 anchor
-//case Join(left, right, Inner, Some(LessThan(ST_Distance(Seq(leftShape, rightShape)), radius))) => // SPARK2 anchor
-      planDistanceJoin(left, right, Seq(leftShape, rightShape), radius, false)
-
-    // ST_Equals(a, b) - a is equal to b
-case Join(left, right, Inner, Some(ST_Equals(Seq(leftShape, rightShape))), _) => // SPARK3 anchor
-//case Join(left, right, Inner, Some(ST_Equals(Seq(leftShape, rightShape)))) => // SPARK2 anchor
-      planSpatialJoin(left, right, Seq(leftShape, rightShape), false)
-
-    // ST_Crosses(a, b) - a crosses b
-case Join(left, right, Inner, Some(ST_Crosses(Seq(leftShape, rightShape))), _) => // SPARK3 anchor
-//case Join(left, right, Inner, Some(ST_Crosses(Seq(leftShape, rightShape)))) => // SPARK2 anchor
-      planSpatialJoin(right, left, Seq(rightShape, leftShape), false)
-
+    case Join(left, right, Inner, condition, JoinHint(leftHint, rightHint)) => { // SPARK3 anchor
+//    case Join(left, right, Inner, condition) => { // SPARK2 anchor
+      val broadcastLeft = leftHint.exists(_.strategy.contains(BROADCAST)) // SPARK3 anchor
+      val broadcastRight = rightHint.exists(_.strategy.contains(BROADCAST)) // SPARK3 anchor
+//      val broadcastLeft = left.isInstanceOf[ResolvedHint] && left.asInstanceOf[ResolvedHint].hints.broadcast  // SPARK2 anchor
+//      val broadcastRight = right.isInstanceOf[ResolvedHint] && right.asInstanceOf[ResolvedHint].hints.broadcast // SPARK2 anchor
+
+      val queryDetection: Option[JoinQueryDetection] = condition match {
+        case Some(predicate: ST_Predicate) =>
+          getJoinDetection(left, right, predicate)
+        case Some(And(predicate: ST_Predicate, extraCondition)) =>
+          getJoinDetection(left, right, predicate, Some(extraCondition))
+        case Some(And(extraCondition, predicate: ST_Predicate)) =>
+          getJoinDetection(left, right, predicate, Some(extraCondition))
+        case Some(LessThanOrEqual(ST_Distance(Seq(leftShape, rightShape)), radius)) =>
+          Some(JoinQueryDetection(left, right, leftShape, rightShape, true, None, Some(radius)))
+        case Some(And(LessThanOrEqual(ST_Distance(Seq(leftShape, rightShape)), radius), extraCondition)) =>
+          Some(JoinQueryDetection(left, right, leftShape, rightShape, true, Some(extraCondition), Some(radius)))
+        case Some(And(extraCondition, LessThanOrEqual(ST_Distance(Seq(leftShape, rightShape)), radius))) =>
+          Some(JoinQueryDetection(left, right, leftShape, rightShape, true, Some(extraCondition), Some(radius)))
+        case Some(LessThan(ST_Distance(Seq(leftShape, rightShape)), radius)) =>
+          Some(JoinQueryDetection(left, right, leftShape, rightShape, false, None, Some(radius)))
+        case Some(And(LessThan(ST_Distance(Seq(leftShape, rightShape)), radius), extraCondition)) =>
+          Some(JoinQueryDetection(left, right, leftShape, rightShape, false, Some(extraCondition), Some(radius)))
+        case Some(And(extraCondition, LessThan(ST_Distance(Seq(leftShape, rightShape)), radius))) =>
+          Some(JoinQueryDetection(left, right, leftShape, rightShape, false, Some(extraCondition), Some(radius)))
+        case _ =>
+          None
+      }
+
+      val sedonaConf = new SedonaConf(sparkSession.sparkContext.conf)
+
+      if ((broadcastLeft || broadcastRight) && sedonaConf.getUseIndex) {
+        queryDetection match {
+          case Some(JoinQueryDetection(left, right, leftShape, rightShape, intersects, extraCondition, radius)) =>
+            planBroadcastJoin(left, right, Seq(leftShape, rightShape), intersects, sedonaConf.getIndexType, broadcastLeft, extraCondition, radius)
+          case _ =>
+            Nil
+        }
+      } else {
+        queryDetection match {
+          case Some(JoinQueryDetection(left, right, leftShape, rightShape, intersects, extraCondition, None)) =>
+            planSpatialJoin(left, right, Seq(leftShape, rightShape), intersects, extraCondition)
+          case Some(JoinQueryDetection(left, right, leftShape, rightShape, intersects, extraCondition, Some(radius))) =>
+            planDistanceJoin(left, right, Seq(leftShape, rightShape), radius, intersects, extraCondition)
+          case None => 
+            Nil
+        }
+      }
+    }
     case _ =>
       Nil
   }
@@ -95,11 +139,11 @@ case Join(left, right, Inner, Some(ST_Crosses(Seq(leftShape, rightShape))), _) =
   private def matchExpressionsToPlans(exprA: Expression,
                                       exprB: Expression,
                                       planA: LogicalPlan,
-                                      planB: LogicalPlan): Option[(LogicalPlan, LogicalPlan)] =
+                                      planB: LogicalPlan): Option[(LogicalPlan, LogicalPlan, Boolean)] =
     if (matches(exprA, planA) && matches(exprB, planB)) {
-      Some((planA, planB))
+      Some((planA, planB, false))
     } else if (matches(exprA, planB) && matches(exprB, planA)) {
-      Some((planB, planA))
+      Some((planB, planA, true))
     } else {
       None
     }
@@ -115,7 +159,7 @@ case Join(left, right, Inner, Some(ST_Crosses(Seq(leftShape, rightShape))), _) =
     val relationship = if (intersects) "ST_Intersects" else "ST_Contains";
 
     matchExpressionsToPlans(a, b, left, right) match {
-      case Some((planA, planB)) =>
+      case Some((planA, planB, _)) =>
         logInfo(s"Planning spatial join for $relationship relationship")
         RangeJoinExec(planLater(planA), planLater(planB), a, b, intersects, extraCondition) :: Nil
       case None =>
@@ -138,7 +182,7 @@ case Join(left, right, Inner, Some(ST_Crosses(Seq(leftShape, rightShape))), _) =
     val relationship = if (intersects) "ST_Distance <=" else "ST_Distance <";
 
     matchExpressionsToPlans(a, b, left, right) match {
-      case Some((planA, planB)) =>
+      case Some((planA, planB, _)) =>
         if (radius.references.isEmpty || matches(radius, planA)) {
           logInfo("Planning spatial distance join")
           DistanceJoinExec(planLater(planA), planLater(planB), a, b, radius, intersects, extraCondition) :: Nil
@@ -158,4 +202,45 @@ case Join(left, right, Inner, Some(ST_Crosses(Seq(leftShape, rightShape))), _) =
         Nil
     }
   }
+
+  private def planBroadcastJoin(left: LogicalPlan,
+                                right: LogicalPlan,
+                                children: Seq[Expression],
+                                intersects: Boolean,
+                                indexType: IndexType,
+                                broadcastLeft: Boolean,
+                                extraCondition: Option[Expression],
+                                radius: Option[Expression]): Seq[SparkPlan] = {
+    val a = children.head
+    val b = children.tail.head
+
+    val relationship = radius match {
+      case Some(_) if intersects => "ST_Distance <="
+      case Some(_) if !intersects => "ST_Distance <"
+      case None if intersects => "ST_Intersects"
+      case None if !intersects => "ST_Contains"
+    }
+
+    matchExpressionsToPlans(a, b, left, right) match {
+      case Some((_, _, swapped)) =>
+        logInfo(s"Planning spatial join for $relationship relationship")
+        val broadcastSide = if (broadcastLeft) LeftSide else RightSide
+        val (leftPlan, rightPlan, streamShape, windowSide) = (broadcastSide, swapped) match {
+          case (LeftSide, false) => // Broadcast the left side, windows on the left
+            (SpatialIndexExec(planLater(left), a, indexType, radius), planLater(right), b, LeftSide)
+          case (LeftSide, true) => // Broadcast the left side, objects on the left
+            (SpatialIndexExec(planLater(left), b, indexType), planLater(right), a, RightSide)
+          case (RightSide, false) => // Broadcast the right side, windows on the left
+            (planLater(left), SpatialIndexExec(planLater(right), b, indexType), a, LeftSide)
+          case (RightSide, true) => // Broadcast the right side, objects on the left
+            (planLater(left), SpatialIndexExec(planLater(right), a, indexType, radius), b, RightSide)
+        }
+        BroadcastIndexJoinExec(leftPlan, rightPlan, streamShape, broadcastSide, windowSide, intersects, extraCondition, radius) :: Nil
+      case None =>
+        logInfo(
+          s"Spatial join for $relationship with arguments not aligned " +
+            "with join relations is not supported")
+        Nil
+    }
+  }
 }
diff --git a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/SpatialIndexExec.scala b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/SpatialIndexExec.scala
new file mode 100644
index 0000000..c76ccfc
--- /dev/null
+++ b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/SpatialIndexExec.scala
@@ -0,0 +1,62 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.spark.sql.sedona_sql.strategy.join
+
+import scala.collection.JavaConverters._
+
+import org.apache.sedona.core.enums.IndexType
+import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.internal.Logging
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{Attribute, BindReferences, Expression, UnsafeRow}
+import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.execution.exchange.Exchange
+
+
+
+case class SpatialIndexExec(child: SparkPlan,
+                            shape: Expression,
+                            indexType: IndexType,
+                            radius: Option[Expression] = None)
+  extends Exchange
+    with TraitJoinQueryBase
+    with Logging {
+
+  override def output: Seq[Attribute] = child.output
+  
+  override protected def doExecute(): RDD[InternalRow] = {
+    throw new UnsupportedOperationException(
+      "SpatialIndex does not support the execute() code path.")
+  }
+
+  override protected[sql] def doExecuteBroadcast[T](): Broadcast[T] = {
+    val boundShape = BindReferences.bindReference(shape, child.output)
+
+    val resultRaw = child.execute().asInstanceOf[RDD[UnsafeRow]].coalesce(1)
+
+    val spatialRDD = radius match {
+      case Some(radiusExpression) => toCircleRDD(resultRaw, boundShape, BindReferences.bindReference(radiusExpression, child.output))
+      case None => toSpatialRDD(resultRaw, boundShape)
+    }
+
+    spatialRDD.buildIndex(indexType, false)
+    sparkContext.broadcast(spatialRDD.indexedRawRDD.take(1).asScala.head).asInstanceOf[Broadcast[T]]
+  }
+}
diff --git a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/DistanceJoinExec.scala b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/TraitJoinQueryBase.scala
similarity index 52%
copy from sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/DistanceJoinExec.scala
copy to sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/TraitJoinQueryBase.scala
index 28087b4..3b7a0e1 100644
--- a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/DistanceJoinExec.scala
+++ b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/TraitJoinQueryBase.scala
@@ -20,37 +20,39 @@ package org.apache.spark.sql.sedona_sql.strategy.join
 
 import org.apache.sedona.core.geometryObjects.Circle
 import org.apache.sedona.core.spatialRDD.SpatialRDD
+import org.apache.sedona.core.utils.SedonaConf
 import org.apache.sedona.sql.utils.GeometrySerializer
-import org.apache.spark.internal.Logging
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.catalyst.expressions.{BindReferences, Expression, UnsafeRow}
+import org.apache.spark.sql.catalyst.expressions.{Expression, UnsafeRow}
 import org.apache.spark.sql.catalyst.util.ArrayData
-import org.apache.spark.sql.execution.{BinaryExecNode, SparkPlan}
+import org.apache.spark.sql.execution.SparkPlan
 import org.locationtech.jts.geom.Geometry
 
-// ST_Distance(left, right) <= radius
-// radius can be literal or a computation over 'left'
-case class DistanceJoinExec(left: SparkPlan,
-                            right: SparkPlan,
-                            leftShape: Expression,
-                            rightShape: Expression,
-                            radius: Expression,
-                            intersects: Boolean,
-                            extraCondition: Option[Expression] = None)
-  extends BinaryExecNode
-    with TraitJoinQueryExec
-    with Logging {
+trait TraitJoinQueryBase {
+  self: SparkPlan =>
 
-  private val boundRadius = BindReferences.bindReference(radius, left.output)
+  def toSpatialRddPair(buildRdd: RDD[UnsafeRow],
+                       buildExpr: Expression,
+                       streamedRdd: RDD[UnsafeRow],
+                       streamedExpr: Expression): (SpatialRDD[Geometry], SpatialRDD[Geometry]) =
+    (toSpatialRDD(buildRdd, buildExpr), toSpatialRDD(streamedRdd, streamedExpr))
 
-  override def toSpatialRddPair(
-                                 buildRdd: RDD[UnsafeRow],
-                                 buildExpr: Expression,
-                                 streamedRdd: RDD[UnsafeRow],
-                                 streamedExpr: Expression): (SpatialRDD[Geometry], SpatialRDD[Geometry]) =
-    (toCircleRDD(buildRdd, buildExpr), toSpatialRdd(streamedRdd, streamedExpr))
+  def toSpatialRDD(rdd: RDD[UnsafeRow], shapeExpression: Expression): SpatialRDD[Geometry] = {
+    val spatialRdd = new SpatialRDD[Geometry]
+    spatialRdd.setRawSpatialRDD(
+      rdd
+        .map { x => {
+          val shape = GeometrySerializer.deserialize(shapeExpression.eval(x).asInstanceOf[ArrayData])
+          //logInfo(shape.toString)
+          shape.setUserData(x.copy)
+          shape
+        }
+        }
+        .toJavaRDD())
+    spatialRdd
+  }
 
-  private def toCircleRDD(rdd: RDD[UnsafeRow], shapeExpression: Expression): SpatialRDD[Geometry] = {
+  def toCircleRDD(rdd: RDD[UnsafeRow], shapeExpression: Expression, boundRadius: Expression): SpatialRDD[Geometry] = {
     val spatialRdd = new SpatialRDD[Geometry]
     spatialRdd.setRawSpatialRDD(
       rdd
@@ -65,4 +67,9 @@ case class DistanceJoinExec(left: SparkPlan,
     spatialRdd
   }
 
+  def doSpatialPartitioning(dominantShapes: SpatialRDD[Geometry], followerShapes: SpatialRDD[Geometry],
+                            numPartitions: Integer, sedonaConf: SedonaConf): Unit = {
+    dominantShapes.spatialPartitioning(sedonaConf.getJoinGridType, numPartitions)
+    followerShapes.spatialPartitioning(dominantShapes.getPartitioner)
+  }
 }
diff --git a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/TraitJoinQueryExec.scala b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/TraitJoinQueryExec.scala
index d00194b..9850b78 100644
--- a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/TraitJoinQueryExec.scala
+++ b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/TraitJoinQueryExec.scala
@@ -21,25 +21,21 @@ package org.apache.spark.sql.sedona_sql.strategy.join
 import org.apache.sedona.core.enums.JoinSparitionDominantSide
 import org.apache.sedona.core.spatialOperator.JoinQuery
 import org.apache.sedona.core.spatialOperator.JoinQuery.JoinParams
-import org.apache.sedona.core.spatialRDD.SpatialRDD
 import org.apache.sedona.core.utils.SedonaConf
-import org.apache.sedona.sql.utils.GeometrySerializer
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeRowJoiner
 import org.apache.spark.sql.catalyst.expressions.{Attribute, BindReferences, Expression, Predicate, UnsafeRow}
-import org.apache.spark.sql.catalyst.util.ArrayData
+import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeRowJoiner
 import org.apache.spark.sql.execution.SparkPlan
-import org.locationtech.jts.geom.Geometry
 
-trait TraitJoinQueryExec {
+trait TraitJoinQueryExec extends TraitJoinQueryBase {
   self: SparkPlan =>
 
   // Using lazy val to avoid serialization
   @transient private lazy val boundCondition: (InternalRow => Boolean) = {
     if (extraCondition.isDefined) {
-Predicate.create(extraCondition.get, left.output ++ right.output).eval _ // SPARK3 anchor
-//newPredicate(extraCondition.get, left.output ++ right.output).eval _ // SPARK2 anchor
+      Predicate.create(extraCondition.get, left.output ++ right.output).eval _ // SPARK3 anchor
+//      newPredicate(extraCondition.get, left.output ++ right.output).eval _ // SPARK2 anchor
     } else { (r: InternalRow) =>
       true
     }
@@ -134,8 +130,8 @@ Predicate.create(extraCondition.get, left.output ++ right.output).eval _ // SPAR
     matches.rdd.mapPartitions { iter =>
       val filtered =
         if (extraCondition.isDefined) {
-val boundCondition = Predicate.create(extraCondition.get, left.output ++ right.output) // SPARK3 anchor
-//val boundCondition = newPredicate(extraCondition.get, left.output ++ right.output) // SPARK2 anchor
+          val boundCondition = Predicate.create(extraCondition.get, left.output ++ right.output) // SPARK3 anchor
+//          val boundCondition = newPredicate(extraCondition.get, left.output ++ right.output) // SPARK2 anchor
           iter.filter {
             case (l, r) =>
               val leftRow = l.getUserData.asInstanceOf[UnsafeRow]
@@ -157,35 +153,6 @@ val boundCondition = Predicate.create(extraCondition.get, left.output ++ right.o
     }
   }
 
-  def toSpatialRddPair(buildRdd: RDD[UnsafeRow],
-                       buildExpr: Expression,
-                       streamedRdd: RDD[UnsafeRow],
-                       streamedExpr: Expression): (SpatialRDD[Geometry], SpatialRDD[Geometry]) =
-    (toSpatialRdd(buildRdd, buildExpr), toSpatialRdd(streamedRdd, streamedExpr))
-
-  protected def toSpatialRdd(rdd: RDD[UnsafeRow],
-                             shapeExpression: Expression): SpatialRDD[Geometry] = {
-
-    val spatialRdd = new SpatialRDD[Geometry]
-    spatialRdd.setRawSpatialRDD(
-      rdd
-        .map { x => {
-          val shape = GeometrySerializer.deserialize(shapeExpression.eval(x).asInstanceOf[ArrayData])
-          //logInfo(shape.toString)
-          shape.setUserData(x.copy)
-          shape
-        }
-        }
-        .toJavaRDD())
-    spatialRdd
-  }
-
-  def doSpatialPartitioning(dominantShapes: SpatialRDD[Geometry], followerShapes: SpatialRDD[Geometry],
-                            numPartitions: Integer, sedonaConf: SedonaConf): Unit = {
-    dominantShapes.spatialPartitioning(sedonaConf.getJoinGridType, numPartitions)
-    followerShapes.spatialPartitioning(dominantShapes.getPartitioner)
-  }
-
   def joinPartitionNumOptimizer(dominantSidePartNum: Int, followerSidePartNum: Int, dominantSideCount: Long): Int = {
     log.info("[SedonaSQL] Dominant side count: " + dominantSideCount)
     var numPartition = -1
diff --git a/sql/src/main/scala/org/apache/sedona/sql/utils/SedonaSQLRegistrator.scala b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/enums.scala
similarity index 50%
copy from sql/src/main/scala/org/apache/sedona/sql/utils/SedonaSQLRegistrator.scala
copy to sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/enums.scala
index 12870c8..7ce02a5 100644
--- a/sql/src/main/scala/org/apache/sedona/sql/utils/SedonaSQLRegistrator.scala
+++ b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/enums.scala
@@ -16,27 +16,9 @@
  * specific language governing permissions and limitations
  * under the License.
  */
-package org.apache.sedona.sql.utils
+package org.apache.spark.sql.sedona_sql.strategy.join
 
-import org.apache.sedona.sql.UDF.UdfRegistrator
-import org.apache.sedona.sql.UDT.UdtRegistrator
-import org.apache.spark.sql.sedona_sql.strategy.join.JoinQueryDetector
-import org.apache.spark.sql.{SQLContext, SparkSession}
+sealed trait JoinSide
 
-object SedonaSQLRegistrator {
-  def registerAll(sqlContext: SQLContext): Unit = {
-    sqlContext.experimental.extraStrategies = JoinQueryDetector :: Nil
-    UdtRegistrator.registerAll()
-    UdfRegistrator.registerAll(sqlContext)
-  }
-
-  def registerAll(sparkSession: SparkSession): Unit = {
-    sparkSession.experimental.extraStrategies = JoinQueryDetector :: Nil
-    UdtRegistrator.registerAll()
-    UdfRegistrator.registerAll(sparkSession)
-  }
-
-  def dropAll(sparkSession: SparkSession): Unit = {
-    UdfRegistrator.dropAll(sparkSession)
-  }
-}
+case object LeftSide extends JoinSide
+case object RightSide extends JoinSide
diff --git a/sql/src/test/scala/org/apache/sedona/sql/BroadcastIndexJoinSuite.scala b/sql/src/test/scala/org/apache/sedona/sql/BroadcastIndexJoinSuite.scala
new file mode 100644
index 0000000..1689332
--- /dev/null
+++ b/sql/src/test/scala/org/apache/sedona/sql/BroadcastIndexJoinSuite.scala
@@ -0,0 +1,206 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sedona.sql
+
+import org.apache.spark.sql.sedona_sql.strategy.join.BroadcastIndexJoinExec
+import org.apache.spark.sql.functions._
+
+class BroadcastIndexJoinSuite extends TestBaseScala {
+
+  describe("Sedona-SQL Broadcast Index Join Test") {
+
+    // Using UDFs rather than lit prevents optimizations that would circumvent the checks we want to test
+    val one = udf(() => 1)
+    val two = udf(() => 2)
+
+    it("Passed Correct partitioning for broadcast join for ST_Polygon and ST_Point") {
+      val polygonDf = buildPolygonDf.repartition(3)
+      val pointDf = buildPointDf.repartition(5)
+
+      var broadcastJoinDf = pointDf.alias("pointDf").join(broadcast(polygonDf).alias("polygonDf"), expr("ST_Contains(polygonDf.polygonshape, pointDf.pointshape)"))
+      assert(broadcastJoinDf.queryExecution.sparkPlan.collect{ case p: BroadcastIndexJoinExec => p }.size === 1)
+      assert(broadcastJoinDf.rdd.getNumPartitions == pointDf.rdd.getNumPartitions)
+      assert(broadcastJoinDf.count() == 1000)
+
+      broadcastJoinDf = broadcast(polygonDf).alias("polygonDf").join(pointDf.alias("pointDf"), expr("ST_Contains(polygonDf.polygonshape, pointDf.pointshape)"))
+      assert(broadcastJoinDf.queryExecution.sparkPlan.collect{ case p: BroadcastIndexJoinExec => p }.size === 1)
+      assert(broadcastJoinDf.rdd.getNumPartitions == pointDf.rdd.getNumPartitions)
+      assert(broadcastJoinDf.count() == 1000)
+
+      broadcastJoinDf = broadcast(pointDf).alias("pointDf").join(polygonDf.alias("polygonDf"), expr("ST_Contains(polygonDf.polygonshape, pointDf.pointshape)"))
+      assert(broadcastJoinDf.queryExecution.sparkPlan.collect{ case p: BroadcastIndexJoinExec => p }.size === 1)
+      assert(broadcastJoinDf.rdd.getNumPartitions == polygonDf.rdd.getNumPartitions)
+      assert(broadcastJoinDf.count() == 1000)
+
+      broadcastJoinDf = polygonDf.alias("polygonDf").join(broadcast(pointDf).alias("pointDf"), expr("ST_Contains(polygonDf.polygonshape, pointDf.pointshape)"))
+      assert(broadcastJoinDf.queryExecution.sparkPlan.collect{ case p: BroadcastIndexJoinExec => p }.size === 1)
+      assert(broadcastJoinDf.rdd.getNumPartitions == polygonDf.rdd.getNumPartitions)
+      assert(broadcastJoinDf.count() == 1000)
+    }
+
+    it("Passed Broadcasts the left side if both sides have a broadcast hint") {
+      val polygonDf = buildPolygonDf.repartition(3)
+      val pointDf = buildPointDf.repartition(5)
+
+      var broadcastJoinDf = broadcast(pointDf).alias("pointDf").join(broadcast(polygonDf).alias("polygonDf"), expr("ST_Contains(polygonDf.polygonshape, pointDf.pointshape)"))
+      assert(broadcastJoinDf.queryExecution.sparkPlan.collect{ case p: BroadcastIndexJoinExec => p }.size === 1)
+      assert(broadcastJoinDf.rdd.getNumPartitions == polygonDf.rdd.getNumPartitions)
+      assert(broadcastJoinDf.count() == 1000)
+    }
+
+    it("Passed Can access attributes of both sides of broadcast join") {
+      val polygonDf = buildPolygonDf.withColumn("window_extra", one())
+      val pointDf = buildPointDf.withColumn("object_extra", one())
+      
+      var broadcastJoinDf = polygonDf.alias("polygonDf").join(broadcast(pointDf).alias("pointDf"), expr("ST_Contains(polygonDf.polygonshape, pointDf.pointshape)"))
+      assert(broadcastJoinDf.select(sum("object_extra")).collect().head(0) == 1000)
+      assert(broadcastJoinDf.select(sum("window_extra")).collect().head(0) == 1000)
+
+      broadcastJoinDf = broadcast(polygonDf).alias("polygonDf").join(pointDf.alias("pointDf"), expr("ST_Contains(polygonDf.polygonshape, pointDf.pointshape)"))
+      assert(broadcastJoinDf.select(sum("object_extra")).collect().head(0) == 1000)
+      assert(broadcastJoinDf.select(sum("window_extra")).collect().head(0) == 1000)
+
+      broadcastJoinDf = broadcast(pointDf).alias("pointDf").join(polygonDf.alias("polygonDf"), expr("ST_Contains(polygonDf.polygonshape, pointDf.pointshape)"))
+      assert(broadcastJoinDf.select(sum("object_extra")).collect().head(0) == 1000)
+      assert(broadcastJoinDf.select(sum("window_extra")).collect().head(0) == 1000)
+
+      broadcastJoinDf = pointDf.alias("pointDf").join(broadcast(polygonDf).alias("polygonDf"), expr("ST_Contains(polygonDf.polygonshape, pointDf.pointshape)"))
+      assert(broadcastJoinDf.select(sum("object_extra")).collect().head(0) == 1000)
+      assert(broadcastJoinDf.select(sum("window_extra")).collect().head(0) == 1000)
+    }
+
+    it("Passed Handles extra conditions on a broadcast join") {
+      val polygonDf = buildPolygonDf.withColumn("window_extra", one())
+      val pointDf = buildPointDf.withColumn("object_extra", two())
+
+      var broadcastJoinDf = pointDf
+        .alias("pointDf")
+        .join(
+          broadcast(polygonDf.alias("polygonDf")),
+          expr("ST_Contains(polygonshape, pointshape) AND window_extra <= object_extra")
+        )
+
+      assert(broadcastJoinDf.queryExecution.sparkPlan.collect{ case p: BroadcastIndexJoinExec => p }.size === 1)
+      assert(broadcastJoinDf.count() == 1000)
+
+      broadcastJoinDf = pointDf
+        .alias("pointDf")
+        .join(
+          broadcast(polygonDf.alias("polygonDf")),
+          expr("ST_Contains(polygonshape, pointshape) AND window_extra > object_extra")
+        )
+
+      assert(broadcastJoinDf.queryExecution.sparkPlan.collect{ case p: BroadcastIndexJoinExec => p }.size === 1)
+      assert(broadcastJoinDf.count() == 0)
+
+      broadcastJoinDf = pointDf
+        .alias("pointDf")
+        .join(
+          broadcast(polygonDf.alias("polygonDf")),
+          expr("window_extra <= object_extra AND ST_Contains(polygonshape, pointshape)")
+        )
+
+      assert(broadcastJoinDf.queryExecution.sparkPlan.collect{ case p: BroadcastIndexJoinExec => p }.size === 1)
+      assert(broadcastJoinDf.count() == 1000)
+
+      broadcastJoinDf = pointDf
+        .alias("pointDf")
+        .join(
+          broadcast(polygonDf.alias("polygonDf")),
+          expr("window_extra > object_extra AND ST_Contains(polygonshape, pointshape)")
+        )
+
+      assert(broadcastJoinDf.queryExecution.sparkPlan.collect{ case p: BroadcastIndexJoinExec => p }.size === 1)
+      assert(broadcastJoinDf.count() == 0)
+    }
+
+    it("Passed Handles multiple extra conditions on a broadcast join with the ST predicate last") {
+      val polygonDf = buildPolygonDf.withColumn("window_extra", one()).withColumn("window_extra2", one())
+      val pointDf = buildPointDf.withColumn("object_extra", two()).withColumn("object_extra2", two())
+
+      var broadcastJoinDf = pointDf
+        .alias("pointDf")
+        .join(
+          broadcast(polygonDf.alias("polygonDf")),
+          expr("window_extra <= object_extra AND window_extra2 <= object_extra2 AND ST_Contains(polygonshape, pointshape)")
+        )
+
+      assert(broadcastJoinDf.queryExecution.sparkPlan.collect{ case p: BroadcastIndexJoinExec => p }.size === 1)
+      assert(broadcastJoinDf.count() == 1000)
+
+      broadcastJoinDf = pointDf
+        .alias("pointDf")
+        .join(
+          broadcast(polygonDf.alias("polygonDf")),
+          expr("window_extra > object_extra AND window_extra2 > object_extra2 AND ST_Contains(polygonshape, pointshape)")
+        )
+
+      assert(broadcastJoinDf.queryExecution.sparkPlan.collect{ case p: BroadcastIndexJoinExec => p }.size === 1)
+      assert(broadcastJoinDf.count() == 0)
+    }
+
+    it("Passed ST_Distance <= radius in a broadcast join") {
+      var pointDf1 = buildPointDf
+      var pointDf2 = buildPointDf
+
+      var distanceJoinDf = pointDf1.alias("pointDf1").join(broadcast(pointDf2).alias("pointDf2"), expr("ST_Distance(pointDf1.pointshape, pointDf2.pointshape) <= 2"))
+      assert(distanceJoinDf.queryExecution.sparkPlan.collect{ case p: BroadcastIndexJoinExec => p }.size === 1)
+      assert(distanceJoinDf.count() == 2998)
+
+      distanceJoinDf = broadcast(pointDf1).alias("pointDf1").join(pointDf2.alias("pointDf2"), expr("ST_Distance(pointDf1.pointshape, pointDf2.pointshape) <= 2"))
+      assert(distanceJoinDf.queryExecution.sparkPlan.collect{ case p: BroadcastIndexJoinExec => p }.size === 1)
+      assert(distanceJoinDf.count() == 2998)
+    }
+
+    it("Passed ST_Distance < radius in a broadcast join") {
+      var pointDf1 = buildPointDf
+      var pointDf2 = buildPointDf
+
+      var distanceJoinDf = pointDf1.alias("pointDf1").join(broadcast(pointDf2).alias("pointDf2"), expr("ST_Distance(pointDf1.pointshape, pointDf2.pointshape) < 2"))
+      assert(distanceJoinDf.queryExecution.sparkPlan.collect{ case p: BroadcastIndexJoinExec => p }.size === 1)
+      assert(distanceJoinDf.count() == 2998)
+
+      distanceJoinDf = broadcast(pointDf1).alias("pointDf1").join(pointDf2.alias("pointDf2"), expr("ST_Distance(pointDf1.pointshape, pointDf2.pointshape) < 2"))
+      assert(distanceJoinDf.queryExecution.sparkPlan.collect{ case p: BroadcastIndexJoinExec => p }.size === 1)
+      assert(distanceJoinDf.count() == 2998)
+    }
+
+    it("Passed ST_Distance radius is bound to first expression") {
+      var pointDf1 = buildPointDf.withColumn("radius", two())
+      var pointDf2 = buildPointDf
+
+      var distanceJoinDf = pointDf1.alias("pointDf1").join(broadcast(pointDf2).alias("pointDf2"), expr("ST_Distance(pointDf1.pointshape, pointDf2.pointshape) < radius"))
+      assert(distanceJoinDf.queryExecution.sparkPlan.collect{ case p: BroadcastIndexJoinExec => p }.size === 1)
+      assert(distanceJoinDf.count() == 2998)
+
+      distanceJoinDf = broadcast(pointDf1).alias("pointDf1").join(pointDf2.alias("pointDf2"), expr("ST_Distance(pointDf1.pointshape, pointDf2.pointshape) < radius"))
+      assert(distanceJoinDf.queryExecution.sparkPlan.collect{ case p: BroadcastIndexJoinExec => p }.size === 1)
+      assert(distanceJoinDf.count() == 2998)
+
+      distanceJoinDf = pointDf2.alias("pointDf2").join(broadcast(pointDf1).alias("pointDf1"), expr("ST_Distance(pointDf1.pointshape, pointDf2.pointshape) < radius"))
+      assert(distanceJoinDf.queryExecution.sparkPlan.collect{ case p: BroadcastIndexJoinExec => p }.size === 1)
+      assert(distanceJoinDf.count() == 2998)
+
+      distanceJoinDf = broadcast(pointDf2).alias("pointDf2").join(pointDf1.alias("pointDf1"), expr("ST_Distance(pointDf1.pointshape, pointDf2.pointshape) < radius"))
+      assert(distanceJoinDf.queryExecution.sparkPlan.collect{ case p: BroadcastIndexJoinExec => p }.size === 1)
+      assert(distanceJoinDf.count() == 2998)
+    }
+  }
+}
diff --git a/sql/src/test/scala/org/apache/sedona/sql/TestBaseScala.scala b/sql/src/test/scala/org/apache/sedona/sql/TestBaseScala.scala
index 4409106..1c33fe9 100644
--- a/sql/src/test/scala/org/apache/sedona/sql/TestBaseScala.scala
+++ b/sql/src/test/scala/org/apache/sedona/sql/TestBaseScala.scala
@@ -22,7 +22,7 @@ import org.apache.log4j.{Level, Logger}
 import org.apache.sedona.core.serde.SedonaKryoRegistrator
 import org.apache.sedona.sql.utils.SedonaSQLRegistrator
 import org.apache.spark.serializer.KryoSerializer
-import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.{DataFrame, SparkSession}
 import org.scalatest.{BeforeAndAfterAll, FunSpec}
 
 trait TestBaseScala extends FunSpec with BeforeAndAfterAll {
@@ -39,7 +39,6 @@ trait TestBaseScala extends FunSpec with BeforeAndAfterAll {
     .getOrCreate()
 
   val resourceFolder = System.getProperty("user.dir") + "/../core/src/test/resources/"
-
   val mixedWkbGeometryInputLocation = resourceFolder + "county_small_wkb.tsv"
   val mixedWktGeometryInputLocation = resourceFolder + "county_small.tsv"
   val shapefileInputLocation = resourceFolder + "shapefiles/dbf"
@@ -70,4 +69,11 @@ trait TestBaseScala extends FunSpec with BeforeAndAfterAll {
     //SedonaSQLRegistrator.dropAll(spark)
     //spark.stop
   }
+
+  def loadCsv(path: String): DataFrame = {
+    sparkSession.read.format("csv").option("delimiter", ",").option("header", "false").load(path)
+  }
+
+  lazy val buildPointDf = loadCsv(csvPointInputLocation).selectExpr("ST_Point(cast(_c0 as Decimal(24,20)),cast(_c1 as Decimal(24,20))) as pointshape")
+  lazy val buildPolygonDf = loadCsv(csvPolygonInputLocation).selectExpr("ST_PolygonFromEnvelope(cast(_c0 as Decimal(24,20)),cast(_c1 as Decimal(24,20)), cast(_c2 as Decimal(24,20)), cast(_c3 as Decimal(24,20))) as polygonshape")
 }
diff --git a/sql/src/test/scala/org/apache/sedona/sql/predicateJoinTestScala.scala b/sql/src/test/scala/org/apache/sedona/sql/predicateJoinTestScala.scala
index dffadb0..dc6bb0a 100644
--- a/sql/src/test/scala/org/apache/sedona/sql/predicateJoinTestScala.scala
+++ b/sql/src/test/scala/org/apache/sedona/sql/predicateJoinTestScala.scala
@@ -21,7 +21,6 @@ package org.apache.sedona.sql
 
 import org.apache.sedona.core.utils.SedonaConf
 import org.apache.spark.sql.Row
-import org.apache.spark.sql.sedona_sql.strategy.join.JoinQueryDetector
 import org.apache.spark.sql.types._
 
 class predicateJoinTestScala extends TestBaseScala {
@@ -127,8 +126,6 @@ class predicateJoinTestScala extends TestBaseScala {
     }
 
     it("Passed ST_Distance <= radius in a join") {
-      sparkSession.experimental.extraStrategies = JoinQueryDetector :: Nil
-
       var pointCsvDF1 = sparkSession.read.format("csv").option("delimiter", ",").option("header", "false").load(csvPointInputLocation)
       pointCsvDF1.createOrReplaceTempView("pointtable")
       var pointDf1 = sparkSession.sql("select ST_Point(cast(pointtable._c0 as Decimal(24,20)),cast(pointtable._c1 as Decimal(24,20))) as pointshape1 from pointtable")
@@ -145,8 +142,6 @@ class predicateJoinTestScala extends TestBaseScala {
     }
 
     it("Passed ST_Distance < radius in a join") {
-      sparkSession.experimental.extraStrategies = JoinQueryDetector :: Nil
-
       var pointCsvDF1 = sparkSession.read.format("csv").option("delimiter", ",").option("header", "false").load(csvPointInputLocation)
       pointCsvDF1.createOrReplaceTempView("pointtable")
       var pointDf1 = sparkSession.sql("select ST_Point(cast(pointtable._c0 as Decimal(24,20)),cast(pointtable._c1 as Decimal(24,20))) as pointshape1 from pointtable")