Posted to commits@sedona.apache.org by ji...@apache.org on 2021/03/26 00:56:32 UTC
[incubator-sedona] branch master updated: [SEDONA-26] Add broadcast join support (#515)
This is an automated email from the ASF dual-hosted git repository.
jiayu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-sedona.git
The following commit(s) were added to refs/heads/master by this push:
new 369b51a [SEDONA-26] Add broadcast join support (#515)
369b51a is described below
commit 369b51a68a7983b0c6fa79c313a29bdc1ff6be11
Author: Adam Binford <ad...@gmail.com>
AuthorDate: Thu Mar 25 20:56:26 2021 -0400
[SEDONA-26] Add broadcast join support (#515)
* Initial broadcast join working in Spark 3
* Add distance join, extra condition, and spark 2 support for broadcast join
* Add some documentation for broadcast joins
* Fix small issue with broadcast distance joins with expression for radius
* Simplify tests
* Use custom enum for Spark 3.1 compatibility and try to add a python test
* Simplify join detection with ST_Predicate parent class and detect both sides of an And expression
Co-authored-by: Adam Binford <ad...@maxar.com>
---
.gitignore | 3 +
docs/api/sql/GeoSparkSQL-Optimizer.md | 37 ++++
docs/api/sql/GeoSparkSQL-Overview.md | 5 +
python-adapter/.gitignore | 4 +
python/tests/sql/test_predicate_join.py | 47 +++++
spark-version-converter.py | 3 +-
.../sedona/sql/utils/SedonaSQLRegistrator.scala | 8 +-
.../sql/sedona_sql/expressions/Predicates.scala | 16 +-
.../strategy/join/BroadcastIndexJoinExec.scala | 130 +++++++++++++
.../strategy/join/DistanceJoinExec.scala | 2 +-
.../strategy/join/JoinQueryDetector.scala | 195 +++++++++++++------
.../strategy/join/SpatialIndexExec.scala | 62 +++++++
...anceJoinExec.scala => TraitJoinQueryBase.scala} | 53 +++---
.../strategy/join/TraitJoinQueryExec.scala | 45 +----
.../sql/sedona_sql/strategy/join/enums.scala} | 26 +--
.../sedona/sql/BroadcastIndexJoinSuite.scala | 206 +++++++++++++++++++++
.../org/apache/sedona/sql/TestBaseScala.scala | 10 +-
.../apache/sedona/sql/predicateJoinTestScala.scala | 5 -
18 files changed, 697 insertions(+), 160 deletions(-)
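For readers skimming the diff below, the idea behind the new `BroadcastIndexJoinExec` can be sketched outside Spark. What follows is a minimal illustration in plain Python, not Sedona's implementation. It assumes axis-aligned rectangles as the indexed (broadcast) side, points as the streamed side, and a uniform grid standing in for the quadtree or R-tree index. The names `GridIndex`, `broadcast_index_join`, and `cell_size` are hypothetical.

```python
from collections import defaultdict
from math import floor


class GridIndex:
    """Toy spatial index (assumption: a uniform grid, not a quadtree/R-tree)."""

    def __init__(self, cell_size=1.0):
        self.cell_size = cell_size
        self.cells = defaultdict(list)

    def _cell(self, x, y):
        return (floor(x / self.cell_size), floor(y / self.cell_size))

    def insert(self, rect):
        # rect is (min_x, min_y, max_x, max_y); register it in every cell it overlaps.
        min_cx, min_cy = self._cell(rect[0], rect[1])
        max_cx, max_cy = self._cell(rect[2], rect[3])
        for cx in range(min_cx, max_cx + 1):
            for cy in range(min_cy, max_cy + 1):
                self.cells[(cx, cy)].append(rect)

    def query(self, x, y):
        # Returns candidates only; the exact predicate must still be checked.
        return self.cells.get(self._cell(x, y), [])


def covers(rect, point):
    min_x, min_y, max_x, max_y = rect
    x, y = point
    return min_x <= x <= max_x and min_y <= y <= max_y


def broadcast_index_join(point_partitions, rects, cell_size=1.0):
    # Build the index once on the small side (this is what would be broadcast).
    index = GridIndex(cell_size)
    for rect in rects:
        index.insert(rect)
    # Probe the index per streamed row, partition by partition, so the
    # streamed side keeps its partitioning and nothing is shuffled.
    return [
        [(rect, p) for p in partition for rect in index.query(*p) if covers(rect, p)]
        for partition in point_partitions
    ]


rects = [(0.0, 0.0, 2.0, 2.0), (5.0, 5.0, 6.0, 6.0)]
partitions = [[(1.0, 1.0), (9.0, 9.0)], [(5.5, 5.5)], []]
result = broadcast_index_join(partitions, rects)
```

In this sketch `result` has one output list per input partition, mirroring the `getNumPartitions() == 9` assertions in the Python test added by this commit.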
diff --git a/.gitignore b/.gitignore
index bbeea00..4ced234 100644
--- a/.gitignore
+++ b/.gitignore
@@ -10,3 +10,6 @@
/conf/
/log/
/site/
+/.bloop/
+/.metals/
+/.vscode/
\ No newline at end of file
diff --git a/docs/api/sql/GeoSparkSQL-Optimizer.md b/docs/api/sql/GeoSparkSQL-Optimizer.md
index 222a8de..92a66f2 100644
--- a/docs/api/sql/GeoSparkSQL-Optimizer.md
+++ b/docs/api/sql/GeoSparkSQL-Optimizer.md
@@ -72,6 +72,43 @@ DistanceJoin pointshape1#12: geometry, pointshape2#33: geometry, 2.0, true
!!!warning
Sedona doesn't control the distance's unit (degree or meter); it is the same as the geometry's unit. To change the geometry's unit, please transform the coordinate reference system. See [ST_Transform](GeoSparkSQL-Function.md#st_transform).
+## Broadcast join
+Introduction: Perform a range join or distance join but broadcast one of the sides of the join. This maintains the partitioning of the non-broadcast side and doesn't require a shuffle.
+
+```Scala
+pointDf.alias("pointDf").join(broadcast(polygonDf).alias("polygonDf"), expr("ST_Contains(polygonDf.polygonshape, pointDf.pointshape)"))
+```
+
+Spark SQL Physical plan:
+```
+== Physical Plan ==
+BroadcastIndexJoin pointshape#52: geometry, BuildRight, BuildRight, false ST_Contains(polygonshape#30, pointshape#52)
+:- Project [st_point(cast(_c0#48 as decimal(24,20)), cast(_c1#49 as decimal(24,20))) AS pointshape#52]
+: +- FileScan csv
++- SpatialIndex polygonshape#30: geometry, QUADTREE, [id=#62]
+ +- Project [st_polygonfromenvelope(cast(_c0#22 as decimal(24,20)), cast(_c1#23 as decimal(24,20)), cast(_c2#24 as decimal(24,20)), cast(_c3#25 as decimal(24,20))) AS polygonshape#30]
+ +- FileScan csv
+```
+
+This also works for distance joins:
+
+```Scala
+pointDf1.alias("pointDf1").join(broadcast(pointDf2).alias("pointDf2"), expr("ST_Distance(pointDf1.pointshape, pointDf2.pointshape) <= 2"))
+```
+
+Spark SQL Physical plan:
+```
+== Physical Plan ==
+BroadcastIndexJoin pointshape#52: geometry, BuildRight, BuildLeft, true, 2.0 ST_Distance(pointshape#52, pointshape#415) <= 2.0
+:- Project [st_point(cast(_c0#48 as decimal(24,20)), cast(_c1#49 as decimal(24,20))) AS pointshape#52]
+: +- FileScan csv
++- SpatialIndex pointshape#415: geometry, QUADTREE, [id=#1068]
+ +- Project [st_point(cast(_c0#48 as decimal(24,20)), cast(_c1#49 as decimal(24,20))) AS pointshape#415]
+ +- FileScan csv
+```
+
+Note: If the distance is an expression, it is only evaluated on the first argument to ST_Distance (`pointDf1` above).
+
## Predicate pushdown
Introduction: Given a join query and a predicate in the same WHERE clause, first executes the Predicate as a filter, then executes the join query*
diff --git a/docs/api/sql/GeoSparkSQL-Overview.md b/docs/api/sql/GeoSparkSQL-Overview.md
index cf596be..15edc9f 100644
--- a/docs/api/sql/GeoSparkSQL-Overview.md
+++ b/docs/api/sql/GeoSparkSQL-Overview.md
@@ -6,6 +6,11 @@ SedonaSQL supports SQL/MM Part3 Spatial SQL Standard. It includes four kinds of
var myDataFrame = sparkSession.sql("YOUR_SQL")
```
+Alternatively, `expr` and `selectExpr` can be used:
+```Scala
+myDataFrame.withColumn("geometry", expr("ST_*")).selectExpr("ST_*")
+```
+
* Constructor: Construct a Geometry given an input string or coordinates
* Example: ST_GeomFromWKT (string). Create a Geometry from a WKT String.
* Documentation: [Here](../GeoSparkSQL-Constructor)
diff --git a/python-adapter/.gitignore b/python-adapter/.gitignore
index b83d222..804f11a 100644
--- a/python-adapter/.gitignore
+++ b/python-adapter/.gitignore
@@ -1 +1,5 @@
/target/
+bin
+/.settings
+/.classpath
+/.project
\ No newline at end of file
diff --git a/python/tests/sql/test_predicate_join.py b/python/tests/sql/test_predicate_join.py
index dfce133..8aee350 100644
--- a/python/tests/sql/test_predicate_join.py
+++ b/python/tests/sql/test_predicate_join.py
@@ -16,6 +16,7 @@
# under the License.
from pyspark import Row
+from pyspark.sql.functions import broadcast, expr
from pyspark.sql.types import StructType, StringType, IntegerType, StructField, DoubleType
from tests import csv_polygon_input_location, csv_point_input_location, overlap_polygon_input_location, \
@@ -449,3 +450,49 @@ class TestPredicateJoin(TestBase):
equal_join_df.explain()
equal_join_df.show(3)
assert equal_join_df.count() == 0, f"Expected 0 but got {equal_join_df.count()}"
+
+ def test_st_contains_in_broadcast_join(self):
+ polygon_csv_df = self.spark.read.format("csv").\
+ option("delimiter", ",").\
+ option("header", "false").load(
+ csv_polygon_input_location
+ )
+ polygon_csv_df.createOrReplaceTempView("polygontable")
+ polygon_csv_df.show()
+
+ polygon_df = self.spark.sql(
+ "select ST_PolygonFromEnvelope(cast(polygontable._c0 as Decimal(24,20)),cast(polygontable._c1 as Decimal(24,20)), cast(polygontable._c2 as Decimal(24,20)), cast(polygontable._c3 as Decimal(24,20))) as polygonshape from polygontable")
+ polygon_df = polygon_df.repartition(7)
+ polygon_df.createOrReplaceTempView("polygondf")
+ polygon_df.show()
+
+ point_csv_df = self.spark.read.format("csv").\
+ option("delimiter", ",").\
+ option("header", "false").load(
+ csv_point_input_location
+ )
+ point_csv_df.createOrReplaceTempView("pointtable")
+ point_csv_df.show()
+
+ point_df = self.spark.sql(
+ "select ST_Point(cast(pointtable._c0 as Decimal(24,20)),cast(pointtable._c1 as Decimal(24,20))) as pointshape from pointtable")
+ point_df = point_df.repartition(9)
+ point_df.createOrReplaceTempView("pointdf")
+ point_df.show()
+
+ range_join_df = self.spark.sql(
+ "select /*+ BROADCAST(polygondf) */ * from polygondf, pointdf where ST_Contains(polygondf.polygonshape,pointdf.pointshape) ")
+
+ range_join_df.explain()
+ range_join_df.show(3)
+ assert range_join_df.rdd.getNumPartitions() == 9
+ assert range_join_df.count() == 1000
+
+ range_join_df = point_df.alias("pointdf").join(broadcast(polygon_df).alias("polygondf"), on=expr("ST_Contains(polygondf.polygonshape, pointdf.pointshape)"))
+
+ range_join_df.explain()
+ range_join_df.show(3)
+ assert range_join_df.rdd.getNumPartitions() == 9
+ assert range_join_df.count() == 1000
+
+
diff --git a/spark-version-converter.py b/spark-version-converter.py
index 8e5bca1..c04ffe2 100644
--- a/spark-version-converter.py
+++ b/spark-version-converter.py
@@ -23,7 +23,8 @@ spark2_anchor = 'SPARK2 anchor'
spark3_anchor = 'SPARK3 anchor'
files = ['sql/src/main/scala/org/apache/sedona/sql/UDF/UdfRegistrator.scala',
'sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/TraitJoinQueryExec.scala',
- 'sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/JoinQueryDetector.scala']
+ 'sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/JoinQueryDetector.scala',
+ 'sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/BroadcastIndexJoinExec.scala']
def switch_version(line):
if line[:2] == '//':
diff --git a/sql/src/main/scala/org/apache/sedona/sql/utils/SedonaSQLRegistrator.scala b/sql/src/main/scala/org/apache/sedona/sql/utils/SedonaSQLRegistrator.scala
index 12870c8..3e60627 100644
--- a/sql/src/main/scala/org/apache/sedona/sql/utils/SedonaSQLRegistrator.scala
+++ b/sql/src/main/scala/org/apache/sedona/sql/utils/SedonaSQLRegistrator.scala
@@ -20,18 +20,16 @@ package org.apache.sedona.sql.utils
import org.apache.sedona.sql.UDF.UdfRegistrator
import org.apache.sedona.sql.UDT.UdtRegistrator
-import org.apache.spark.sql.sedona_sql.strategy.join.JoinQueryDetector
import org.apache.spark.sql.{SQLContext, SparkSession}
+import org.apache.spark.sql.sedona_sql.strategy.join.JoinQueryDetector
object SedonaSQLRegistrator {
def registerAll(sqlContext: SQLContext): Unit = {
- sqlContext.experimental.extraStrategies = JoinQueryDetector :: Nil
- UdtRegistrator.registerAll()
- UdfRegistrator.registerAll(sqlContext)
+ registerAll(sqlContext.sparkSession)
}
def registerAll(sparkSession: SparkSession): Unit = {
- sparkSession.experimental.extraStrategies = JoinQueryDetector :: Nil
+ sparkSession.experimental.extraStrategies = new JoinQueryDetector(sparkSession) :: Nil
UdtRegistrator.registerAll()
UdfRegistrator.registerAll(sparkSession)
}
diff --git a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Predicates.scala b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Predicates.scala
index 4e71b86..959b8de 100644
--- a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Predicates.scala
+++ b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Predicates.scala
@@ -25,13 +25,15 @@ import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.sql.catalyst.util.ArrayData
import org.apache.spark.sql.types.BooleanType
+abstract class ST_Predicate extends Expression
+
/**
* Test if leftGeometry fully contains rightGeometry
*
* @param inputExpressions
*/
case class ST_Contains(inputExpressions: Seq[Expression])
- extends Expression with CodegenFallback {
+ extends ST_Predicate with CodegenFallback {
// This is a binary expression
assert(inputExpressions.length == 2)
@@ -62,7 +64,7 @@ case class ST_Contains(inputExpressions: Seq[Expression])
* @param inputExpressions
*/
case class ST_Intersects(inputExpressions: Seq[Expression])
- extends Expression with CodegenFallback {
+ extends ST_Predicate with CodegenFallback {
override def nullable: Boolean = false
// This is a binary expression
@@ -92,7 +94,7 @@ case class ST_Intersects(inputExpressions: Seq[Expression])
* @param inputExpressions
*/
case class ST_Within(inputExpressions: Seq[Expression])
- extends Expression with CodegenFallback {
+ extends ST_Predicate with CodegenFallback {
override def nullable: Boolean = false
// This is a binary expression
@@ -123,7 +125,7 @@ case class ST_Within(inputExpressions: Seq[Expression])
* @param inputExpressions
*/
case class ST_Crosses(inputExpressions: Seq[Expression])
- extends Expression with CodegenFallback {
+ extends ST_Predicate with CodegenFallback {
override def nullable: Boolean = false
override def toString: String = s" **${ST_Crosses.getClass.getName}** "
@@ -153,7 +155,7 @@ case class ST_Crosses(inputExpressions: Seq[Expression])
* @param inputExpressions
*/
case class ST_Overlaps(inputExpressions: Seq[Expression])
- extends Expression with CodegenFallback {
+ extends ST_Predicate with CodegenFallback {
override def nullable: Boolean = false
// This is a binary expression
@@ -183,7 +185,7 @@ case class ST_Overlaps(inputExpressions: Seq[Expression])
* @param inputExpressions
*/
case class ST_Touches(inputExpressions: Seq[Expression])
- extends Expression with CodegenFallback {
+ extends ST_Predicate with CodegenFallback {
override def nullable: Boolean = false
// This is a binary expression
@@ -213,7 +215,7 @@ case class ST_Touches(inputExpressions: Seq[Expression])
* @param inputExpressions
*/
case class ST_Equals(inputExpressions: Seq[Expression])
- extends Expression with CodegenFallback {
+ extends ST_Predicate with CodegenFallback {
override def nullable: Boolean = false
// This is a binary expression
diff --git a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/BroadcastIndexJoinExec.scala b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/BroadcastIndexJoinExec.scala
new file mode 100644
index 0000000..0d9ec4f
--- /dev/null
+++ b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/BroadcastIndexJoinExec.scala
@@ -0,0 +1,130 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.spark.sql.sedona_sql.strategy.join
+
+import collection.JavaConverters._
+
+import org.apache.sedona.core.spatialRDD.SpatialRDD
+import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.internal.Logging
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{Attribute, BindReferences, Expression, Predicate, UnsafeRow}
+import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeRowJoiner
+import org.apache.spark.sql.catalyst.plans.physical.Partitioning
+import org.apache.spark.sql.execution.{BinaryExecNode, SparkPlan}
+import org.locationtech.jts.geom.Geometry
+import org.locationtech.jts.index.SpatialIndex
+
+case class BroadcastIndexJoinExec(left: SparkPlan,
+ right: SparkPlan,
+ streamShape: Expression,
+ indexBuildSide: JoinSide,
+ windowJoinSide: JoinSide,
+ intersects: Boolean,
+ extraCondition: Option[Expression] = None,
+ radius: Option[Expression] = None)
+ extends BinaryExecNode
+ with TraitJoinQueryBase
+ with Logging {
+
+ override def output: Seq[Attribute] = left.output ++ right.output
+
+ // Using lazy val to avoid serialization
+ @transient private lazy val boundCondition: (InternalRow => Boolean) = extraCondition match {
+ case Some(condition) =>
+ Predicate.create(condition, output).eval _ // SPARK3 anchor
+// newPredicate(condition, output).eval _ // SPARK2 anchor
+ case None =>
+ (r: InternalRow) => true
+ }
+
+ private val (streamed, broadcast) = indexBuildSide match {
+ case LeftSide => (right, left.asInstanceOf[SpatialIndexExec])
+ case RightSide => (left, right.asInstanceOf[SpatialIndexExec])
+ }
+
+ override def outputPartitioning: Partitioning = streamed.outputPartitioning
+
+ private val (windowExpression, objectExpression) = if (indexBuildSide == windowJoinSide) {
+ (broadcast.shape, streamShape)
+ } else {
+ (streamShape, broadcast.shape)
+ }
+
+ private val spatialExpression = radius match {
+ case Some(r) if intersects => s"ST_Distance($windowExpression, $objectExpression) <= $r"
+ case Some(r) if !intersects => s"ST_Distance($windowExpression, $objectExpression) < $r"
+ case None if intersects => s"ST_Intersects($windowExpression, $objectExpression)"
+ case None if !intersects => s"ST_Contains($windowExpression, $objectExpression)"
+ }
+
+ override def simpleString(maxFields: Int): String = super.simpleString(maxFields) + s" $spatialExpression" // SPARK3 anchor
+// override def simpleString: String = super.simpleString + s" $spatialExpression" // SPARK2 anchor
+
+ private def windowBroadcastJoin(index: Broadcast[SpatialIndex], spatialRdd: SpatialRDD[Geometry]): RDD[(Geometry, Geometry)] = {
+ spatialRdd.getRawSpatialRDD.rdd.flatMap { row =>
+ val candidates = index.value.query(row.getEnvelopeInternal).iterator.asScala.asInstanceOf[Iterator[Geometry]]
+ candidates
+ .filter(candidate => if (intersects) candidate.intersects(row) else candidate.covers(row))
+ .map(candidate => (candidate, row))
+ }
+ }
+
+ private def objectBroadcastJoin(index: Broadcast[SpatialIndex], spatialRdd: SpatialRDD[Geometry]): RDD[(Geometry, Geometry)] = {
+ spatialRdd.getRawSpatialRDD.rdd.flatMap { row =>
+ val candidates = index.value.query(row.getEnvelopeInternal).iterator.asScala.asInstanceOf[Iterator[Geometry]]
+ candidates
+ .filter(candidate => if (intersects) row.intersects(candidate) else row.covers(candidate))
+ .map(candidate => (row, candidate))
+ }
+ }
+
+ override protected def doExecute(): RDD[InternalRow] = {
+ val boundStreamShape = BindReferences.bindReference(streamShape, streamed.output)
+ val streamResultsRaw = streamed.execute().asInstanceOf[RDD[UnsafeRow]]
+
+ // If there's a radius and the objects are being broadcast, we need to build the CircleRDD on the window stream side
+ val streamShapes = radius match {
+ case Some(radiusExpression) if indexBuildSide != windowJoinSide =>
+ toCircleRDD(streamResultsRaw, boundStreamShape, BindReferences.bindReference(radiusExpression, streamed.output))
+ case _ =>
+ toSpatialRDD(streamResultsRaw, boundStreamShape)
+ }
+
+ val broadcastIndex = broadcast.executeBroadcast[SpatialIndex]()
+
+ val pairs = (indexBuildSide, windowJoinSide) match {
+ case (LeftSide, LeftSide) => windowBroadcastJoin(broadcastIndex, streamShapes)
+ case (LeftSide, RightSide) => objectBroadcastJoin(broadcastIndex, streamShapes).map { case (left, right) => (right, left) }
+ case (RightSide, LeftSide) => objectBroadcastJoin(broadcastIndex, streamShapes)
+ case (RightSide, RightSide) => windowBroadcastJoin(broadcastIndex, streamShapes).map { case (left, right) => (right, left) }
+ }
+
+ pairs.mapPartitions { iter =>
+ val joiner = GenerateUnsafeRowJoiner.create(left.schema, right.schema)
+ iter.map {
+ case (l, r) =>
+ val leftRow = l.getUserData.asInstanceOf[UnsafeRow]
+ val rightRow = r.getUserData.asInstanceOf[UnsafeRow]
+ joiner.join(leftRow, rightRow)
+ }.filter(boundCondition(_))
+ }
+ }
+}
diff --git a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/DistanceJoinExec.scala b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/DistanceJoinExec.scala
index 28087b4..a921979 100644
--- a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/DistanceJoinExec.scala
+++ b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/DistanceJoinExec.scala
@@ -48,7 +48,7 @@ case class DistanceJoinExec(left: SparkPlan,
buildExpr: Expression,
streamedRdd: RDD[UnsafeRow],
streamedExpr: Expression): (SpatialRDD[Geometry], SpatialRDD[Geometry]) =
- (toCircleRDD(buildRdd, buildExpr), toSpatialRdd(streamedRdd, streamedExpr))
+ (toCircleRDD(buildRdd, buildExpr), toSpatialRDD(streamedRdd, streamedExpr))
private def toCircleRDD(rdd: RDD[UnsafeRow], shapeExpression: Expression): SpatialRDD[Geometry] = {
val spatialRdd = new SpatialRDD[Geometry]
diff --git a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/JoinQueryDetector.scala b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/JoinQueryDetector.scala
index df624f9..18a8141 100644
--- a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/JoinQueryDetector.scala
+++ b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/JoinQueryDetector.scala
@@ -18,68 +18,112 @@
*/
package org.apache.spark.sql.sedona_sql.strategy.join
-import org.apache.spark.sql.Strategy
-import org.apache.spark.sql.catalyst.expressions.{Expression, LessThan, LessThanOrEqual}
+import org.apache.sedona.core.enums.IndexType
+import org.apache.sedona.core.utils.SedonaConf
+import org.apache.spark.sql.{SparkSession, Strategy}
+import org.apache.spark.sql.catalyst.expressions.{And, Expression, LessThan, LessThanOrEqual}
import org.apache.spark.sql.catalyst.plans.Inner
-import org.apache.spark.sql.catalyst.plans.logical.{Join, LogicalPlan}
+import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.sedona_sql.expressions._
+
+
+case class JoinQueryDetection(
+ left: LogicalPlan,
+ right: LogicalPlan,
+ leftShape: Expression,
+ rightShape: Expression,
+ intersects: Boolean,
+ extraCondition: Option[Expression] = None,
+ radius: Option[Expression] = None
+)
+
/**
* Plans `RangeJoinExec` for inner joins on spatial relationships ST_Contains(a, b)
* and ST_Intersects(a, b).
*
* Plans `DistanceJoinExec` for inner joins on spatial relationship ST_Distance(a, b) < r.
+ *
+ * Plans `BroadcastIndexJoinExec` for inner joins on spatial relationships with a broadcast hint.
*/
-object JoinQueryDetector extends Strategy {
+class JoinQueryDetector(sparkSession: SparkSession) extends Strategy {
+
+ private def getJoinDetection(
+ left: LogicalPlan,
+ right: LogicalPlan,
+ predicate: ST_Predicate,
+ extraCondition: Option[Expression] = None): Option[JoinQueryDetection] = {
+ predicate match {
+ case ST_Contains(Seq(leftShape, rightShape)) =>
+ Some(JoinQueryDetection(left, right, leftShape, rightShape, false, extraCondition))
+ case ST_Intersects(Seq(leftShape, rightShape)) =>
+ Some(JoinQueryDetection(left, right, leftShape, rightShape, true, extraCondition))
+ case ST_Within(Seq(leftShape, rightShape)) =>
+ Some(JoinQueryDetection(right, left, rightShape, leftShape, false, extraCondition))
+ case ST_Overlaps(Seq(leftShape, rightShape)) =>
+ Some(JoinQueryDetection(right, left, rightShape, leftShape, false, extraCondition))
+ case ST_Touches(Seq(leftShape, rightShape)) =>
+ Some(JoinQueryDetection(left, right, leftShape, rightShape, true, extraCondition))
+ case ST_Equals(Seq(leftShape, rightShape)) =>
+ Some(JoinQueryDetection(left, right, leftShape, rightShape, false, extraCondition))
+ case ST_Crosses(Seq(leftShape, rightShape)) =>
+ Some(JoinQueryDetection(right, left, rightShape, leftShape, false, extraCondition))
+ case _ => None
+ }
+ }
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
-
- // ST_Contains(a, b) - a contains b
-case Join(left, right, Inner, Some(ST_Contains(Seq(leftShape, rightShape))), _) => // SPARK3 anchor
-//case Join(left, right, Inner, Some(ST_Contains(Seq(leftShape, rightShape)))) => // SPARK2 anchor
- planSpatialJoin(left, right, Seq(leftShape, rightShape), false)
-
- // ST_Intersects(a, b) - a intersects b
-case Join(left, right, Inner, Some(ST_Intersects(Seq(leftShape, rightShape))), _) => // SPARK3 anchor
-//case Join(left, right, Inner, Some(ST_Intersects(Seq(leftShape, rightShape)))) => // SPARK2 anchor
- planSpatialJoin(left, right, Seq(leftShape, rightShape), true)
-
- // ST_WITHIN(a, b) - a is within b
-case Join(left, right, Inner, Some(ST_Within(Seq(leftShape, rightShape))), _) => // SPARK3 anchor
-//case Join(left, right, Inner, Some(ST_Within(Seq(leftShape, rightShape)))) => // SPARK2 anchor
- planSpatialJoin(right, left, Seq(rightShape, leftShape), false)
-
- // ST_Overlaps(a, b) - a overlaps b
-case Join(left, right, Inner, Some(ST_Overlaps(Seq(leftShape, rightShape))), _) => // SPARK3 anchor
-//case Join(left, right, Inner, Some(ST_Overlaps(Seq(leftShape, rightShape)))) => // SPARK2 anchor
- planSpatialJoin(right, left, Seq(rightShape, leftShape), false)
-
- // ST_Touches(a, b) - a touches b
-case Join(left, right, Inner, Some(ST_Touches(Seq(leftShape, rightShape))), _) => // SPARK3 anchor
-//case Join(left, right, Inner, Some(ST_Touches(Seq(leftShape, rightShape)))) => // SPARK2 anchor
- planSpatialJoin(left, right, Seq(leftShape, rightShape), true)
-
- // ST_Distance(a, b) <= radius consider boundary intersection
-case Join(left, right, Inner, Some(LessThanOrEqual(ST_Distance(Seq(leftShape, rightShape)), radius)), _) => // SPARK3 anchor
-//case Join(left, right, Inner, Some(LessThanOrEqual(ST_Distance(Seq(leftShape, rightShape)), radius))) => // SPARK2 anchor
- planDistanceJoin(left, right, Seq(leftShape, rightShape), radius, true)
-
- // ST_Distance(a, b) < radius don't consider boundary intersection
-case Join(left, right, Inner, Some(LessThan(ST_Distance(Seq(leftShape, rightShape)), radius)), _) => // SPARK3 anchor
-//case Join(left, right, Inner, Some(LessThan(ST_Distance(Seq(leftShape, rightShape)), radius))) => // SPARK2 anchor
- planDistanceJoin(left, right, Seq(leftShape, rightShape), radius, false)
-
- // ST_Equals(a, b) - a is equal to b
-case Join(left, right, Inner, Some(ST_Equals(Seq(leftShape, rightShape))), _) => // SPARK3 anchor
-//case Join(left, right, Inner, Some(ST_Equals(Seq(leftShape, rightShape)))) => // SPARK2 anchor
- planSpatialJoin(left, right, Seq(leftShape, rightShape), false)
-
- // ST_Crosses(a, b) - a crosses b
-case Join(left, right, Inner, Some(ST_Crosses(Seq(leftShape, rightShape))), _) => // SPARK3 anchor
-//case Join(left, right, Inner, Some(ST_Crosses(Seq(leftShape, rightShape)))) => // SPARK2 anchor
- planSpatialJoin(right, left, Seq(rightShape, leftShape), false)
-
+ case Join(left, right, Inner, condition, JoinHint(leftHint, rightHint)) => { // SPARK3 anchor
+// case Join(left, right, Inner, condition) => { // SPARK2 anchor
+ val broadcastLeft = leftHint.exists(_.strategy.contains(BROADCAST)) // SPARK3 anchor
+ val broadcastRight = rightHint.exists(_.strategy.contains(BROADCAST)) // SPARK3 anchor
+// val broadcastLeft = left.isInstanceOf[ResolvedHint] && left.asInstanceOf[ResolvedHint].hints.broadcast // SPARK2 anchor
+// val broadcastRight = right.isInstanceOf[ResolvedHint] && right.asInstanceOf[ResolvedHint].hints.broadcast // SPARK2 anchor
+
+ val queryDetection: Option[JoinQueryDetection] = condition match {
+ case Some(predicate: ST_Predicate) =>
+ getJoinDetection(left, right, predicate)
+ case Some(And(predicate: ST_Predicate, extraCondition)) =>
+ getJoinDetection(left, right, predicate, Some(extraCondition))
+ case Some(And(extraCondition, predicate: ST_Predicate)) =>
+ getJoinDetection(left, right, predicate, Some(extraCondition))
+ case Some(LessThanOrEqual(ST_Distance(Seq(leftShape, rightShape)), radius)) =>
+ Some(JoinQueryDetection(left, right, leftShape, rightShape, true, None, Some(radius)))
+ case Some(And(LessThanOrEqual(ST_Distance(Seq(leftShape, rightShape)), radius), extraCondition)) =>
+ Some(JoinQueryDetection(left, right, leftShape, rightShape, true, Some(extraCondition), Some(radius)))
+ case Some(And(extraCondition, LessThanOrEqual(ST_Distance(Seq(leftShape, rightShape)), radius))) =>
+ Some(JoinQueryDetection(left, right, leftShape, rightShape, true, Some(extraCondition), Some(radius)))
+ case Some(LessThan(ST_Distance(Seq(leftShape, rightShape)), radius)) =>
+ Some(JoinQueryDetection(left, right, leftShape, rightShape, false, None, Some(radius)))
+ case Some(And(LessThan(ST_Distance(Seq(leftShape, rightShape)), radius), extraCondition)) =>
+ Some(JoinQueryDetection(left, right, leftShape, rightShape, false, Some(extraCondition), Some(radius)))
+ case Some(And(extraCondition, LessThan(ST_Distance(Seq(leftShape, rightShape)), radius))) =>
+ Some(JoinQueryDetection(left, right, leftShape, rightShape, false, Some(extraCondition), Some(radius)))
+ case _ =>
+ None
+ }
+
+ val sedonaConf = new SedonaConf(sparkSession.sparkContext.conf)
+
+ if ((broadcastLeft || broadcastRight) && sedonaConf.getUseIndex) {
+ queryDetection match {
+ case Some(JoinQueryDetection(left, right, leftShape, rightShape, intersects, extraCondition, radius)) =>
+ planBroadcastJoin(left, right, Seq(leftShape, rightShape), intersects, sedonaConf.getIndexType, broadcastLeft, extraCondition, radius)
+ case _ =>
+ Nil
+ }
+ } else {
+ queryDetection match {
+ case Some(JoinQueryDetection(left, right, leftShape, rightShape, intersects, extraCondition, None)) =>
+ planSpatialJoin(left, right, Seq(leftShape, rightShape), intersects, extraCondition)
+ case Some(JoinQueryDetection(left, right, leftShape, rightShape, intersects, extraCondition, Some(radius))) =>
+ planDistanceJoin(left, right, Seq(leftShape, rightShape), radius, intersects, extraCondition)
+ case None =>
+ Nil
+ }
+ }
+ }
case _ =>
Nil
}
@@ -95,11 +139,11 @@ case Join(left, right, Inner, Some(ST_Crosses(Seq(leftShape, rightShape))), _) =
private def matchExpressionsToPlans(exprA: Expression,
exprB: Expression,
planA: LogicalPlan,
- planB: LogicalPlan): Option[(LogicalPlan, LogicalPlan)] =
+ planB: LogicalPlan): Option[(LogicalPlan, LogicalPlan, Boolean)] =
if (matches(exprA, planA) && matches(exprB, planB)) {
- Some((planA, planB))
+ Some((planA, planB, false))
} else if (matches(exprA, planB) && matches(exprB, planA)) {
- Some((planB, planA))
+ Some((planB, planA, true))
} else {
None
}
@@ -115,7 +159,7 @@ case Join(left, right, Inner, Some(ST_Crosses(Seq(leftShape, rightShape))), _) =
val relationship = if (intersects) "ST_Intersects" else "ST_Contains";
matchExpressionsToPlans(a, b, left, right) match {
- case Some((planA, planB)) =>
+ case Some((planA, planB, _)) =>
logInfo(s"Planning spatial join for $relationship relationship")
RangeJoinExec(planLater(planA), planLater(planB), a, b, intersects, extraCondition) :: Nil
case None =>
@@ -138,7 +182,7 @@ case Join(left, right, Inner, Some(ST_Crosses(Seq(leftShape, rightShape))), _) =
val relationship = if (intersects) "ST_Distance <=" else "ST_Distance <";
matchExpressionsToPlans(a, b, left, right) match {
- case Some((planA, planB)) =>
+ case Some((planA, planB, _)) =>
if (radius.references.isEmpty || matches(radius, planA)) {
logInfo("Planning spatial distance join")
DistanceJoinExec(planLater(planA), planLater(planB), a, b, radius, intersects, extraCondition) :: Nil
@@ -158,4 +202,45 @@ case Join(left, right, Inner, Some(ST_Crosses(Seq(leftShape, rightShape))), _) =
Nil
}
}
+
+ private def planBroadcastJoin(left: LogicalPlan,
+ right: LogicalPlan,
+ children: Seq[Expression],
+ intersects: Boolean,
+ indexType: IndexType,
+ broadcastLeft: Boolean,
+ extraCondition: Option[Expression],
+ radius: Option[Expression]): Seq[SparkPlan] = {
+ val a = children.head
+ val b = children.tail.head
+
+ val relationship = radius match {
+ case Some(_) if intersects => "ST_Distance <="
+ case Some(_) if !intersects => "ST_Distance <"
+ case None if intersects => "ST_Intersects"
+ case None if !intersects => "ST_Contains"
+ }
+
+ matchExpressionsToPlans(a, b, left, right) match {
+ case Some((_, _, swapped)) =>
+ logInfo(s"Planning spatial join for $relationship relationship")
+ val broadcastSide = if (broadcastLeft) LeftSide else RightSide
+ val (leftPlan, rightPlan, streamShape, windowSide) = (broadcastSide, swapped) match {
+ case (LeftSide, false) => // Broadcast the left side, windows on the left
+ (SpatialIndexExec(planLater(left), a, indexType, radius), planLater(right), b, LeftSide)
+ case (LeftSide, true) => // Broadcast the left side, objects on the left
+ (SpatialIndexExec(planLater(left), b, indexType), planLater(right), a, RightSide)
+ case (RightSide, false) => // Broadcast the right side, windows on the left
+ (planLater(left), SpatialIndexExec(planLater(right), b, indexType), a, LeftSide)
+ case (RightSide, true) => // Broadcast the right side, objects on the left
+ (planLater(left), SpatialIndexExec(planLater(right), a, indexType, radius), b, RightSide)
+ }
+ BroadcastIndexJoinExec(leftPlan, rightPlan, streamShape, broadcastSide, windowSide, intersects, extraCondition, radius) :: Nil
+ case None =>
+ logInfo(
+ s"Spatial join for $relationship with arguments not aligned " +
+ "with join relations is not supported")
+ Nil
+ }
+ }
}
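
Reviewer note: the four-way match in planBroadcastJoin above reduces to two decisions, which the sketch below restates as a pure function. It is only an illustration under stated assumptions, not code from this commit: `BroadcastLayoutSketch`, `BroadcastLayout`, `layout`, and the string labels "a" and "b" (standing for the first and second predicate arguments) are hypothetical names, and the JoinSide objects are redeclared locally so the snippet compiles without Spark or Sedona.

```scala
object BroadcastLayoutSketch {
  // Local stand-ins for the JoinSide enum this commit adds in enums.scala.
  sealed trait JoinSide
  case object LeftSide extends JoinSide
  case object RightSide extends JoinSide

  // Hypothetical summary of what planBroadcastJoin decides for one join.
  final case class BroadcastLayout(
      windowSide: JoinSide,          // side holding the first predicate argument ("a")
      indexedShape: String,          // which argument is built into the broadcast index
      streamedShape: String,         // which argument is evaluated per streamed row
      radiusPassedToIndex: Boolean)  // whether the radius option is handed to the index node

  def layout(broadcastSide: JoinSide, swapped: Boolean): BroadcastLayout = {
    // "swapped" means the first predicate argument was matched to the right plan.
    val windowSide: JoinSide = if (swapped) RightSide else LeftSide
    // The index is built over "a" exactly when the broadcast side is the window side.
    if (broadcastSide == windowSide)
      BroadcastLayout(windowSide, indexedShape = "a", streamedShape = "b", radiusPassedToIndex = true)
    else
      BroadcastLayout(windowSide, indexedShape = "b", streamedShape = "a", radiusPassedToIndex = false)
  }
}
```

Read this way, the radius only reaches SpatialIndexExec when the broadcast side is the one holding the first ST_Distance argument, which matches the two cases in the diff that pass `radius` to the index node.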
diff --git a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/SpatialIndexExec.scala b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/SpatialIndexExec.scala
new file mode 100644
index 0000000..c76ccfc
--- /dev/null
+++ b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/SpatialIndexExec.scala
@@ -0,0 +1,62 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.spark.sql.sedona_sql.strategy.join
+
+import scala.collection.JavaConverters._
+
+import org.apache.sedona.core.enums.IndexType
+import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.internal.Logging
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{Attribute, BindReferences, Expression, UnsafeRow}
+import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.execution.exchange.Exchange
+
+
+
+case class SpatialIndexExec(child: SparkPlan,
+ shape: Expression,
+ indexType: IndexType,
+ radius: Option[Expression] = None)
+ extends Exchange
+ with TraitJoinQueryBase
+ with Logging {
+
+ override def output: Seq[Attribute] = child.output
+
+ override protected def doExecute(): RDD[InternalRow] = {
+ throw new UnsupportedOperationException(
+ "SpatialIndex does not support the execute() code path.")
+ }
+
+ override protected[sql] def doExecuteBroadcast[T](): Broadcast[T] = {
+ val boundShape = BindReferences.bindReference(shape, child.output)
+
+ val resultRaw = child.execute().asInstanceOf[RDD[UnsafeRow]].coalesce(1)
+
+ val spatialRDD = radius match {
+ case Some(radiusExpression) => toCircleRDD(resultRaw, boundShape, BindReferences.bindReference(radiusExpression, child.output))
+ case None => toSpatialRDD(resultRaw, boundShape)
+ }
+
+ spatialRDD.buildIndex(indexType, false)
+ sparkContext.broadcast(spatialRDD.indexedRawRDD.take(1).asScala.head).asInstanceOf[Broadcast[T]]
+ }
+}
diff --git a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/DistanceJoinExec.scala b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/TraitJoinQueryBase.scala
similarity index 52%
copy from sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/DistanceJoinExec.scala
copy to sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/TraitJoinQueryBase.scala
index 28087b4..3b7a0e1 100644
--- a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/DistanceJoinExec.scala
+++ b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/TraitJoinQueryBase.scala
@@ -20,37 +20,39 @@ package org.apache.spark.sql.sedona_sql.strategy.join
import org.apache.sedona.core.geometryObjects.Circle
import org.apache.sedona.core.spatialRDD.SpatialRDD
+import org.apache.sedona.core.utils.SedonaConf
import org.apache.sedona.sql.utils.GeometrySerializer
-import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.catalyst.expressions.{BindReferences, Expression, UnsafeRow}
+import org.apache.spark.sql.catalyst.expressions.{Expression, UnsafeRow}
import org.apache.spark.sql.catalyst.util.ArrayData
-import org.apache.spark.sql.execution.{BinaryExecNode, SparkPlan}
+import org.apache.spark.sql.execution.SparkPlan
import org.locationtech.jts.geom.Geometry
-// ST_Distance(left, right) <= radius
-// radius can be literal or a computation over 'left'
-case class DistanceJoinExec(left: SparkPlan,
- right: SparkPlan,
- leftShape: Expression,
- rightShape: Expression,
- radius: Expression,
- intersects: Boolean,
- extraCondition: Option[Expression] = None)
- extends BinaryExecNode
- with TraitJoinQueryExec
- with Logging {
+trait TraitJoinQueryBase {
+ self: SparkPlan =>
- private val boundRadius = BindReferences.bindReference(radius, left.output)
+ def toSpatialRddPair(buildRdd: RDD[UnsafeRow],
+ buildExpr: Expression,
+ streamedRdd: RDD[UnsafeRow],
+ streamedExpr: Expression): (SpatialRDD[Geometry], SpatialRDD[Geometry]) =
+ (toSpatialRDD(buildRdd, buildExpr), toSpatialRDD(streamedRdd, streamedExpr))
- override def toSpatialRddPair(
- buildRdd: RDD[UnsafeRow],
- buildExpr: Expression,
- streamedRdd: RDD[UnsafeRow],
- streamedExpr: Expression): (SpatialRDD[Geometry], SpatialRDD[Geometry]) =
- (toCircleRDD(buildRdd, buildExpr), toSpatialRdd(streamedRdd, streamedExpr))
+ def toSpatialRDD(rdd: RDD[UnsafeRow], shapeExpression: Expression): SpatialRDD[Geometry] = {
+ val spatialRdd = new SpatialRDD[Geometry]
+ spatialRdd.setRawSpatialRDD(
+ rdd
+ .map { x => {
+ val shape = GeometrySerializer.deserialize(shapeExpression.eval(x).asInstanceOf[ArrayData])
+ //logInfo(shape.toString)
+ shape.setUserData(x.copy)
+ shape
+ }
+ }
+ .toJavaRDD())
+ spatialRdd
+ }
- private def toCircleRDD(rdd: RDD[UnsafeRow], shapeExpression: Expression): SpatialRDD[Geometry] = {
+ def toCircleRDD(rdd: RDD[UnsafeRow], shapeExpression: Expression, boundRadius: Expression): SpatialRDD[Geometry] = {
val spatialRdd = new SpatialRDD[Geometry]
spatialRdd.setRawSpatialRDD(
rdd
@@ -65,4 +67,9 @@ case class DistanceJoinExec(left: SparkPlan,
spatialRdd
}
+ def doSpatialPartitioning(dominantShapes: SpatialRDD[Geometry], followerShapes: SpatialRDD[Geometry],
+ numPartitions: Integer, sedonaConf: SedonaConf): Unit = {
+ dominantShapes.spatialPartitioning(sedonaConf.getJoinGridType, numPartitions)
+ followerShapes.spatialPartitioning(dominantShapes.getPartitioner)
+ }
}
diff --git a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/TraitJoinQueryExec.scala b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/TraitJoinQueryExec.scala
index d00194b..9850b78 100644
--- a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/TraitJoinQueryExec.scala
+++ b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/TraitJoinQueryExec.scala
@@ -21,25 +21,21 @@ package org.apache.spark.sql.sedona_sql.strategy.join
import org.apache.sedona.core.enums.JoinSparitionDominantSide
import org.apache.sedona.core.spatialOperator.JoinQuery
import org.apache.sedona.core.spatialOperator.JoinQuery.JoinParams
-import org.apache.sedona.core.spatialRDD.SpatialRDD
import org.apache.sedona.core.utils.SedonaConf
-import org.apache.sedona.sql.utils.GeometrySerializer
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeRowJoiner
import org.apache.spark.sql.catalyst.expressions.{Attribute, BindReferences, Expression, Predicate, UnsafeRow}
-import org.apache.spark.sql.catalyst.util.ArrayData
+import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeRowJoiner
import org.apache.spark.sql.execution.SparkPlan
-import org.locationtech.jts.geom.Geometry
-trait TraitJoinQueryExec {
+trait TraitJoinQueryExec extends TraitJoinQueryBase {
self: SparkPlan =>
// Using lazy val to avoid serialization
@transient private lazy val boundCondition: (InternalRow => Boolean) = {
if (extraCondition.isDefined) {
-Predicate.create(extraCondition.get, left.output ++ right.output).eval _ // SPARK3 anchor
-//newPredicate(extraCondition.get, left.output ++ right.output).eval _ // SPARK2 anchor
+ Predicate.create(extraCondition.get, left.output ++ right.output).eval _ // SPARK3 anchor
+// newPredicate(extraCondition.get, left.output ++ right.output).eval _ // SPARK2 anchor
} else { (r: InternalRow) =>
true
}
@@ -134,8 +130,8 @@ Predicate.create(extraCondition.get, left.output ++ right.output).eval _ // SPAR
matches.rdd.mapPartitions { iter =>
val filtered =
if (extraCondition.isDefined) {
-val boundCondition = Predicate.create(extraCondition.get, left.output ++ right.output) // SPARK3 anchor
-//val boundCondition = newPredicate(extraCondition.get, left.output ++ right.output) // SPARK2 anchor
+ val boundCondition = Predicate.create(extraCondition.get, left.output ++ right.output) // SPARK3 anchor
+// val boundCondition = newPredicate(extraCondition.get, left.output ++ right.output) // SPARK2 anchor
iter.filter {
case (l, r) =>
val leftRow = l.getUserData.asInstanceOf[UnsafeRow]
@@ -157,35 +153,6 @@ val boundCondition = Predicate.create(extraCondition.get, left.output ++ right.o
}
}
- def toSpatialRddPair(buildRdd: RDD[UnsafeRow],
- buildExpr: Expression,
- streamedRdd: RDD[UnsafeRow],
- streamedExpr: Expression): (SpatialRDD[Geometry], SpatialRDD[Geometry]) =
- (toSpatialRdd(buildRdd, buildExpr), toSpatialRdd(streamedRdd, streamedExpr))
-
- protected def toSpatialRdd(rdd: RDD[UnsafeRow],
- shapeExpression: Expression): SpatialRDD[Geometry] = {
-
- val spatialRdd = new SpatialRDD[Geometry]
- spatialRdd.setRawSpatialRDD(
- rdd
- .map { x => {
- val shape = GeometrySerializer.deserialize(shapeExpression.eval(x).asInstanceOf[ArrayData])
- //logInfo(shape.toString)
- shape.setUserData(x.copy)
- shape
- }
- }
- .toJavaRDD())
- spatialRdd
- }
-
- def doSpatialPartitioning(dominantShapes: SpatialRDD[Geometry], followerShapes: SpatialRDD[Geometry],
- numPartitions: Integer, sedonaConf: SedonaConf): Unit = {
- dominantShapes.spatialPartitioning(sedonaConf.getJoinGridType, numPartitions)
- followerShapes.spatialPartitioning(dominantShapes.getPartitioner)
- }
-
def joinPartitionNumOptimizer(dominantSidePartNum: Int, followerSidePartNum: Int, dominantSideCount: Long): Int = {
log.info("[SedonaSQL] Dominant side count: " + dominantSideCount)
var numPartition = -1
diff --git a/sql/src/main/scala/org/apache/sedona/sql/utils/SedonaSQLRegistrator.scala b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/enums.scala
similarity index 50%
copy from sql/src/main/scala/org/apache/sedona/sql/utils/SedonaSQLRegistrator.scala
copy to sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/enums.scala
index 12870c8..7ce02a5 100644
--- a/sql/src/main/scala/org/apache/sedona/sql/utils/SedonaSQLRegistrator.scala
+++ b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/enums.scala
@@ -16,27 +16,9 @@
* specific language governing permissions and limitations
* under the License.
*/
-package org.apache.sedona.sql.utils
+package org.apache.spark.sql.sedona_sql.strategy.join
-import org.apache.sedona.sql.UDF.UdfRegistrator
-import org.apache.sedona.sql.UDT.UdtRegistrator
-import org.apache.spark.sql.sedona_sql.strategy.join.JoinQueryDetector
-import org.apache.spark.sql.{SQLContext, SparkSession}
+sealed trait JoinSide
-object SedonaSQLRegistrator {
- def registerAll(sqlContext: SQLContext): Unit = {
- sqlContext.experimental.extraStrategies = JoinQueryDetector :: Nil
- UdtRegistrator.registerAll()
- UdfRegistrator.registerAll(sqlContext)
- }
-
- def registerAll(sparkSession: SparkSession): Unit = {
- sparkSession.experimental.extraStrategies = JoinQueryDetector :: Nil
- UdtRegistrator.registerAll()
- UdfRegistrator.registerAll(sparkSession)
- }
-
- def dropAll(sparkSession: SparkSession): Unit = {
- UdfRegistrator.dropAll(sparkSession)
- }
-}
+case object LeftSide extends JoinSide
+case object RightSide extends JoinSide
diff --git a/sql/src/test/scala/org/apache/sedona/sql/BroadcastIndexJoinSuite.scala b/sql/src/test/scala/org/apache/sedona/sql/BroadcastIndexJoinSuite.scala
new file mode 100644
index 0000000..1689332
--- /dev/null
+++ b/sql/src/test/scala/org/apache/sedona/sql/BroadcastIndexJoinSuite.scala
@@ -0,0 +1,206 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sedona.sql
+
+import org.apache.spark.sql.sedona_sql.strategy.join.BroadcastIndexJoinExec
+import org.apache.spark.sql.functions._
+
+class BroadcastIndexJoinSuite extends TestBaseScala {
+
+ describe("Sedona-SQL Broadcast Index Join Test") {
+
+ // Using UDFs rather than lit prevents optimizations that would circumvent the checks we want to test
+ val one = udf(() => 1)
+ val two = udf(() => 2)
+
+ it("Passed Correct partitioning for broadcast join for ST_Polygon and ST_Point") {
+ val polygonDf = buildPolygonDf.repartition(3)
+ val pointDf = buildPointDf.repartition(5)
+
+ var broadcastJoinDf = pointDf.alias("pointDf").join(broadcast(polygonDf).alias("polygonDf"), expr("ST_Contains(polygonDf.polygonshape, pointDf.pointshape)"))
+ assert(broadcastJoinDf.queryExecution.sparkPlan.collect{ case p: BroadcastIndexJoinExec => p }.size === 1)
+ assert(broadcastJoinDf.rdd.getNumPartitions == pointDf.rdd.getNumPartitions)
+ assert(broadcastJoinDf.count() == 1000)
+
+ broadcastJoinDf = broadcast(polygonDf).alias("polygonDf").join(pointDf.alias("pointDf"), expr("ST_Contains(polygonDf.polygonshape, pointDf.pointshape)"))
+ assert(broadcastJoinDf.queryExecution.sparkPlan.collect{ case p: BroadcastIndexJoinExec => p }.size === 1)
+ assert(broadcastJoinDf.rdd.getNumPartitions == pointDf.rdd.getNumPartitions)
+ assert(broadcastJoinDf.count() == 1000)
+
+ broadcastJoinDf = broadcast(pointDf).alias("pointDf").join(polygonDf.alias("polygonDf"), expr("ST_Contains(polygonDf.polygonshape, pointDf.pointshape)"))
+ assert(broadcastJoinDf.queryExecution.sparkPlan.collect{ case p: BroadcastIndexJoinExec => p }.size === 1)
+ assert(broadcastJoinDf.rdd.getNumPartitions == polygonDf.rdd.getNumPartitions)
+ assert(broadcastJoinDf.count() == 1000)
+
+ broadcastJoinDf = polygonDf.alias("polygonDf").join(broadcast(pointDf).alias("pointDf"), expr("ST_Contains(polygonDf.polygonshape, pointDf.pointshape)"))
+ assert(broadcastJoinDf.queryExecution.sparkPlan.collect{ case p: BroadcastIndexJoinExec => p }.size === 1)
+ assert(broadcastJoinDf.rdd.getNumPartitions == polygonDf.rdd.getNumPartitions)
+ assert(broadcastJoinDf.count() == 1000)
+ }
+
+ it("Passed Broadcasts the left side if both sides have a broadcast hint") {
+ val polygonDf = buildPolygonDf.repartition(3)
+ val pointDf = buildPointDf.repartition(5)
+
+ var broadcastJoinDf = broadcast(pointDf).alias("pointDf").join(broadcast(polygonDf).alias("polygonDf"), expr("ST_Contains(polygonDf.polygonshape, pointDf.pointshape)"))
+ assert(broadcastJoinDf.queryExecution.sparkPlan.collect{ case p: BroadcastIndexJoinExec => p }.size === 1)
+ assert(broadcastJoinDf.rdd.getNumPartitions == polygonDf.rdd.getNumPartitions)
+ assert(broadcastJoinDf.count() == 1000)
+ }
+
+ it("Passed Can access attributes of both sides of broadcast join") {
+ val polygonDf = buildPolygonDf.withColumn("window_extra", one())
+ val pointDf = buildPointDf.withColumn("object_extra", one())
+
+ var broadcastJoinDf = polygonDf.alias("polygonDf").join(broadcast(pointDf).alias("pointDf"), expr("ST_Contains(polygonDf.polygonshape, pointDf.pointshape)"))
+ assert(broadcastJoinDf.select(sum("object_extra")).collect().head(0) == 1000)
+ assert(broadcastJoinDf.select(sum("window_extra")).collect().head(0) == 1000)
+
+ broadcastJoinDf = broadcast(polygonDf).alias("polygonDf").join(pointDf.alias("pointDf"), expr("ST_Contains(polygonDf.polygonshape, pointDf.pointshape)"))
+ assert(broadcastJoinDf.select(sum("object_extra")).collect().head(0) == 1000)
+ assert(broadcastJoinDf.select(sum("window_extra")).collect().head(0) == 1000)
+
+ broadcastJoinDf = broadcast(pointDf).alias("pointDf").join(polygonDf.alias("polygonDf"), expr("ST_Contains(polygonDf.polygonshape, pointDf.pointshape)"))
+ assert(broadcastJoinDf.select(sum("object_extra")).collect().head(0) == 1000)
+ assert(broadcastJoinDf.select(sum("window_extra")).collect().head(0) == 1000)
+
+ broadcastJoinDf = pointDf.alias("pointDf").join(broadcast(polygonDf).alias("polygonDf"), expr("ST_Contains(polygonDf.polygonshape, pointDf.pointshape)"))
+ assert(broadcastJoinDf.select(sum("object_extra")).collect().head(0) == 1000)
+ assert(broadcastJoinDf.select(sum("window_extra")).collect().head(0) == 1000)
+ }
+
+ it("Passed Handles extra conditions on a broadcast join") {
+ val polygonDf = buildPolygonDf.withColumn("window_extra", one())
+ val pointDf = buildPointDf.withColumn("object_extra", two())
+
+ var broadcastJoinDf = pointDf
+ .alias("pointDf")
+ .join(
+ broadcast(polygonDf.alias("polygonDf")),
+ expr("ST_Contains(polygonshape, pointshape) AND window_extra <= object_extra")
+ )
+
+ assert(broadcastJoinDf.queryExecution.sparkPlan.collect{ case p: BroadcastIndexJoinExec => p }.size === 1)
+ assert(broadcastJoinDf.count() == 1000)
+
+ broadcastJoinDf = pointDf
+ .alias("pointDf")
+ .join(
+ broadcast(polygonDf.alias("polygonDf")),
+ expr("ST_Contains(polygonshape, pointshape) AND window_extra > object_extra")
+ )
+
+ assert(broadcastJoinDf.queryExecution.sparkPlan.collect{ case p: BroadcastIndexJoinExec => p }.size === 1)
+ assert(broadcastJoinDf.count() == 0)
+
+ broadcastJoinDf = pointDf
+ .alias("pointDf")
+ .join(
+ broadcast(polygonDf.alias("polygonDf")),
+ expr("window_extra <= object_extra AND ST_Contains(polygonshape, pointshape)")
+ )
+
+ assert(broadcastJoinDf.queryExecution.sparkPlan.collect{ case p: BroadcastIndexJoinExec => p }.size === 1)
+ assert(broadcastJoinDf.count() == 1000)
+
+ broadcastJoinDf = pointDf
+ .alias("pointDf")
+ .join(
+ broadcast(polygonDf.alias("polygonDf")),
+ expr("window_extra > object_extra AND ST_Contains(polygonshape, pointshape)")
+ )
+
+ assert(broadcastJoinDf.queryExecution.sparkPlan.collect{ case p: BroadcastIndexJoinExec => p }.size === 1)
+ assert(broadcastJoinDf.count() == 0)
+ }
+
+ it("Passed Handles multiple extra conditions on a broadcast join with the ST predicate last") {
+ val polygonDf = buildPolygonDf.withColumn("window_extra", one()).withColumn("window_extra2", one())
+ val pointDf = buildPointDf.withColumn("object_extra", two()).withColumn("object_extra2", two())
+
+ var broadcastJoinDf = pointDf
+ .alias("pointDf")
+ .join(
+ broadcast(polygonDf.alias("polygonDf")),
+ expr("window_extra <= object_extra AND window_extra2 <= object_extra2 AND ST_Contains(polygonshape, pointshape)")
+ )
+
+ assert(broadcastJoinDf.queryExecution.sparkPlan.collect{ case p: BroadcastIndexJoinExec => p }.size === 1)
+ assert(broadcastJoinDf.count() == 1000)
+
+ broadcastJoinDf = pointDf
+ .alias("pointDf")
+ .join(
+ broadcast(polygonDf.alias("polygonDf")),
+ expr("window_extra > object_extra AND window_extra2 > object_extra2 AND ST_Contains(polygonshape, pointshape)")
+ )
+
+ assert(broadcastJoinDf.queryExecution.sparkPlan.collect{ case p: BroadcastIndexJoinExec => p }.size === 1)
+ assert(broadcastJoinDf.count() == 0)
+ }
+
+ it("Passed ST_Distance <= radius in a broadcast join") {
+ var pointDf1 = buildPointDf
+ var pointDf2 = buildPointDf
+
+ var distanceJoinDf = pointDf1.alias("pointDf1").join(broadcast(pointDf2).alias("pointDf2"), expr("ST_Distance(pointDf1.pointshape, pointDf2.pointshape) <= 2"))
+ assert(distanceJoinDf.queryExecution.sparkPlan.collect{ case p: BroadcastIndexJoinExec => p }.size === 1)
+ assert(distanceJoinDf.count() == 2998)
+
+ distanceJoinDf = broadcast(pointDf1).alias("pointDf1").join(pointDf2.alias("pointDf2"), expr("ST_Distance(pointDf1.pointshape, pointDf2.pointshape) <= 2"))
+ assert(distanceJoinDf.queryExecution.sparkPlan.collect{ case p: BroadcastIndexJoinExec => p }.size === 1)
+ assert(distanceJoinDf.count() == 2998)
+ }
+
+ it("Passed ST_Distance < radius in a broadcast join") {
+ var pointDf1 = buildPointDf
+ var pointDf2 = buildPointDf
+
+ var distanceJoinDf = pointDf1.alias("pointDf1").join(broadcast(pointDf2).alias("pointDf2"), expr("ST_Distance(pointDf1.pointshape, pointDf2.pointshape) < 2"))
+ assert(distanceJoinDf.queryExecution.sparkPlan.collect{ case p: BroadcastIndexJoinExec => p }.size === 1)
+ assert(distanceJoinDf.count() == 2998)
+
+ distanceJoinDf = broadcast(pointDf1).alias("pointDf1").join(pointDf2.alias("pointDf2"), expr("ST_Distance(pointDf1.pointshape, pointDf2.pointshape) < 2"))
+ assert(distanceJoinDf.queryExecution.sparkPlan.collect{ case p: BroadcastIndexJoinExec => p }.size === 1)
+ assert(distanceJoinDf.count() == 2998)
+ }
+
+ it("Passed ST_Distance radius is bound to first expression") {
+ var pointDf1 = buildPointDf.withColumn("radius", two())
+ var pointDf2 = buildPointDf
+
+ var distanceJoinDf = pointDf1.alias("pointDf1").join(broadcast(pointDf2).alias("pointDf2"), expr("ST_Distance(pointDf1.pointshape, pointDf2.pointshape) < radius"))
+ assert(distanceJoinDf.queryExecution.sparkPlan.collect{ case p: BroadcastIndexJoinExec => p }.size === 1)
+ assert(distanceJoinDf.count() == 2998)
+
+ distanceJoinDf = broadcast(pointDf1).alias("pointDf1").join(pointDf2.alias("pointDf2"), expr("ST_Distance(pointDf1.pointshape, pointDf2.pointshape) < radius"))
+ assert(distanceJoinDf.queryExecution.sparkPlan.collect{ case p: BroadcastIndexJoinExec => p }.size === 1)
+ assert(distanceJoinDf.count() == 2998)
+
+ distanceJoinDf = pointDf2.alias("pointDf2").join(broadcast(pointDf1).alias("pointDf1"), expr("ST_Distance(pointDf1.pointshape, pointDf2.pointshape) < radius"))
+ assert(distanceJoinDf.queryExecution.sparkPlan.collect{ case p: BroadcastIndexJoinExec => p }.size === 1)
+ assert(distanceJoinDf.count() == 2998)
+
+ distanceJoinDf = broadcast(pointDf2).alias("pointDf2").join(pointDf1.alias("pointDf1"), expr("ST_Distance(pointDf1.pointshape, pointDf2.pointshape) < radius"))
+ assert(distanceJoinDf.queryExecution.sparkPlan.collect{ case p: BroadcastIndexJoinExec => p }.size === 1)
+ assert(distanceJoinDf.count() == 2998)
+ }
+ }
+}
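
Reviewer note: every test above checks the physical plan with `sparkPlan.collect { case p: BroadcastIndexJoinExec => p }.size === 1`. For readers unfamiliar with that idiom, the sketch below imitates it on a toy tree. It is an assumption-laden illustration, not Spark code: `PlanCollectSketch`, `Node`, `Scan`, `IndexJoin`, `Project`, and this `collect` helper are hypothetical stand-ins for the plan classes and for the collect method the tests call.

```scala
object PlanCollectSketch {
  // Toy plan nodes; real physical plans are Spark SparkPlan instances.
  sealed trait Node { def children: Seq[Node] }
  final case class Scan(name: String) extends Node {
    val children: Seq[Node] = Nil
  }
  final case class IndexJoin(left: Node, right: Node) extends Node {
    val children: Seq[Node] = Seq(left, right)
  }
  final case class Project(child: Node) extends Node {
    val children: Seq[Node] = Seq(child)
  }

  // Pre-order traversal that keeps the result for every node the
  // partial function is defined on and skips all other nodes.
  def collect[A](root: Node)(pf: PartialFunction[Node, A]): Seq[A] =
    pf.lift(root).toSeq ++ root.children.flatMap(child => collect(child)(pf))

  val plan: Node = Project(IndexJoin(Scan("points"), Scan("polygons")))

  // Same shape as the assertions in the suite: select nodes by type, then count them.
  val joins: Seq[IndexJoin] = collect[IndexJoin](plan) { case j: IndexJoin => j }
}
```

The assertion in the suite therefore fails both when the planner falls back to another join operator (zero matches) and when more than one broadcast index join appears in the plan.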
diff --git a/sql/src/test/scala/org/apache/sedona/sql/TestBaseScala.scala b/sql/src/test/scala/org/apache/sedona/sql/TestBaseScala.scala
index 4409106..1c33fe9 100644
--- a/sql/src/test/scala/org/apache/sedona/sql/TestBaseScala.scala
+++ b/sql/src/test/scala/org/apache/sedona/sql/TestBaseScala.scala
@@ -22,7 +22,7 @@ import org.apache.log4j.{Level, Logger}
import org.apache.sedona.core.serde.SedonaKryoRegistrator
import org.apache.sedona.sql.utils.SedonaSQLRegistrator
import org.apache.spark.serializer.KryoSerializer
-import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.{DataFrame, SparkSession}
import org.scalatest.{BeforeAndAfterAll, FunSpec}
trait TestBaseScala extends FunSpec with BeforeAndAfterAll {
@@ -39,7 +39,6 @@ trait TestBaseScala extends FunSpec with BeforeAndAfterAll {
.getOrCreate()
val resourceFolder = System.getProperty("user.dir") + "/../core/src/test/resources/"
-
val mixedWkbGeometryInputLocation = resourceFolder + "county_small_wkb.tsv"
val mixedWktGeometryInputLocation = resourceFolder + "county_small.tsv"
val shapefileInputLocation = resourceFolder + "shapefiles/dbf"
@@ -70,4 +69,11 @@ trait TestBaseScala extends FunSpec with BeforeAndAfterAll {
//SedonaSQLRegistrator.dropAll(spark)
//spark.stop
}
+
+ def loadCsv(path: String): DataFrame = {
+ sparkSession.read.format("csv").option("delimiter", ",").option("header", "false").load(path)
+ }
+
+ lazy val buildPointDf = loadCsv(csvPointInputLocation).selectExpr("ST_Point(cast(_c0 as Decimal(24,20)),cast(_c1 as Decimal(24,20))) as pointshape")
+ lazy val buildPolygonDf = loadCsv(csvPolygonInputLocation).selectExpr("ST_PolygonFromEnvelope(cast(_c0 as Decimal(24,20)),cast(_c1 as Decimal(24,20)), cast(_c2 as Decimal(24,20)), cast(_c3 as Decimal(24,20))) as polygonshape")
}
diff --git a/sql/src/test/scala/org/apache/sedona/sql/predicateJoinTestScala.scala b/sql/src/test/scala/org/apache/sedona/sql/predicateJoinTestScala.scala
index dffadb0..dc6bb0a 100644
--- a/sql/src/test/scala/org/apache/sedona/sql/predicateJoinTestScala.scala
+++ b/sql/src/test/scala/org/apache/sedona/sql/predicateJoinTestScala.scala
@@ -21,7 +21,6 @@ package org.apache.sedona.sql
import org.apache.sedona.core.utils.SedonaConf
import org.apache.spark.sql.Row
-import org.apache.spark.sql.sedona_sql.strategy.join.JoinQueryDetector
import org.apache.spark.sql.types._
class predicateJoinTestScala extends TestBaseScala {
@@ -127,8 +126,6 @@ class predicateJoinTestScala extends TestBaseScala {
}
it("Passed ST_Distance <= radius in a join") {
- sparkSession.experimental.extraStrategies = JoinQueryDetector :: Nil
-
var pointCsvDF1 = sparkSession.read.format("csv").option("delimiter", ",").option("header", "false").load(csvPointInputLocation)
pointCsvDF1.createOrReplaceTempView("pointtable")
var pointDf1 = sparkSession.sql("select ST_Point(cast(pointtable._c0 as Decimal(24,20)),cast(pointtable._c1 as Decimal(24,20))) as pointshape1 from pointtable")
@@ -145,8 +142,6 @@ class predicateJoinTestScala extends TestBaseScala {
}
it("Passed ST_Distance < radius in a join") {
- sparkSession.experimental.extraStrategies = JoinQueryDetector :: Nil
-
var pointCsvDF1 = sparkSession.read.format("csv").option("delimiter", ",").option("header", "false").load(csvPointInputLocation)
pointCsvDF1.createOrReplaceTempView("pointtable")
var pointDf1 = sparkSession.sql("select ST_Point(cast(pointtable._c0 as Decimal(24,20)),cast(pointtable._c1 as Decimal(24,20))) as pointshape1 from pointtable")