Posted to commits@sedona.apache.org by ji...@apache.org on 2023/05/29 08:23:59 UTC
[sedona] 01/01: Add distanceJoin support for ST_DistanceSpheroid
This is an automated email from the ASF dual-hosted git repository.
jiayu pushed a commit to branch geography-join
in repository https://gitbox.apache.org/repos/asf/sedona.git
commit dadc81861cb0a912aacf69881efbe3d579fd3673
Author: Jia Yu <ji...@apache.org>
AuthorDate: Mon May 29 01:23:36 2023 -0700
Add distanceJoin support for ST_DistanceSpheroid
---
docs/api/sql/Function.md | 6 +-
docs/api/sql/Optimizer.md | 37 ++++++++--
.../strategy/join/DistanceJoinExec.scala | 5 +-
.../strategy/join/JoinQueryDetector.scala | 84 +++++++++++++---------
.../strategy/join/SpatialIndexExec.scala | 3 +-
.../strategy/join/TraitJoinQueryBase.scala | 21 +++++-
.../sedona/sql/BroadcastIndexJoinSuite.scala | 19 +++++
.../org/apache/sedona/sql/TestBaseScala.scala | 21 ++++++
.../apache/sedona/sql/predicateJoinTestScala.scala | 18 ++++-
9 files changed, 163 insertions(+), 51 deletions(-)
diff --git a/docs/api/sql/Function.md b/docs/api/sql/Function.md
index b8643268..e92ce4f3 100644
--- a/docs/api/sql/Function.md
+++ b/docs/api/sql/Function.md
@@ -51,7 +51,7 @@ FROM polygondf
## ST_AreaSpheroid
-Introduction: Return the geodesic area of A using WGS84 spheroid. Unit is meter. Works better for large geometries (country level) compared to `ST_Area` + `ST_Transform`. It is equivalent to PostGIS `ST_Area(geography, use_spheroid=true)` function and produces nearly identical results.
+Introduction: Return the geodesic area of A using WGS84 spheroid. Unit is square meter. Works better for large geometries (country level) compared to `ST_Area` + `ST_Transform`. It is equivalent to PostGIS `ST_Area(geography, use_spheroid=true)` function and produces nearly identical results.
Geometry must be in EPSG:4326 (WGS84) projection and must be in ==lat/lon== order. You can use ==ST_FlipCoordinates== to swap lat and lon.
@@ -416,7 +416,7 @@ FROM polygondf
## ST_DistanceSphere
-Introduction: Return the haversine / great-circle distance of A using a given earth radius (default radius: 6378137.0). Unit is meter. Works better for large geometries (country level) compared to `ST_Distance` + `ST_Transform`. It is equivalent to PostGIS `ST_Distance(geography, use_spheroid=false)` and `ST_DistanceSphere` function and produces nearly identical results. It provides faster but less accurate result compared to `ST_DistanceSpheroid`.
+Introduction: Return the haversine / great-circle distance of A using a given earth radius (default radius: 6378137.0). Unit is meter. Compared to `ST_Distance` + `ST_Transform`, it works better for datasets that cover large regions such as continents or the entire planet. It is equivalent to PostGIS `ST_Distance(geography, use_spheroid=false)` and `ST_DistanceSphere` function and produces nearly identical results. It provides faster but less accurate result compared to `ST_DistanceSpheroid`.
Geometry must be in EPSG:4326 (WGS84) projection and must be in ==lat/lon== order. You can use ==ST_FlipCoordinates== to swap lat and lon. For non-point data, we first take the centroids of both geometries and then compute the distance.
@@ -441,7 +441,7 @@ Output: `544405.4459192449`
## ST_DistanceSpheroid
-Introduction: Return the geodesic distance of A using WGS84 spheroid. Unit is meter. Works better for large geometries (country level) compared to `ST_Distance` + `ST_Transform`. It is equivalent to PostGIS `ST_Distance(geography, use_spheroid=true)` and `ST_DistanceSpheroid` function and produces nearly identical results. It provides slower but more accurate result compared to `ST_DistanceSphere`.
+Introduction: Return the geodesic distance of A using WGS84 spheroid. Unit is meter. Compared to `ST_Distance` + `ST_Transform`, it works better for datasets that cover large regions such as continents or the entire planet. It is equivalent to PostGIS `ST_Distance(geography, use_spheroid=true)` and `ST_DistanceSpheroid` function and produces nearly identical results. It provides slower but more accurate result compared to `ST_DistanceSphere`.
Geometry must be in EPSG:4326 (WGS84) projection and must be in ==lat/lon== order. You can use ==ST_FlipCoordinates== to swap lat and lon. For non-point data, we first take the centroids of both geometries and then compute the distance.
diff --git a/docs/api/sql/Optimizer.md b/docs/api/sql/Optimizer.md
index 025b964b..ca6cc09d 100644
--- a/docs/api/sql/Optimizer.md
+++ b/docs/api/sql/Optimizer.md
@@ -28,7 +28,9 @@ SELECT *
FROM pointdf, polygondf
WHERE ST_Within(pointdf.pointshape, polygondf.polygonshape)
```
+
Spark SQL Physical plan:
+
```
== Physical Plan ==
RangeJoin polygonshape#20: geometry, pointshape#43: geometry, false
@@ -44,9 +46,9 @@ RangeJoin polygonshape#20: geometry, pointshape#43: geometry, false
## Distance join
-Introduction: Find geometries from A and geometries from B such that the internal Euclidean distance of each geometry pair is less or equal than a certain distance
+Introduction: Find geometries from A and geometries from B such that the distance of each geometry pair is less or equal than a certain distance. It supports the planar Euclidean distance calculator `ST_Distance` and the meter-based geodesic distance calculator `ST_DistanceSpheroid`.
-Spark SQL Example:
+Spark SQL Example for planar Euclidean distance:
*Only consider ==fully within a certain distance==*
```sql
@@ -73,7 +75,26 @@ DistanceJoin pointshape1#12: geometry, pointshape2#33: geometry, 2.0, true
```
!!!warning
- Sedona doesn't control the distance's unit (degree or meter). It is same with the geometry. If your coordinates are in the longitude and latitude system, the unit of `distance` should be degree instead of meter or mile. To change the geometry's unit, please either transform the coordinate reference system to a meter-based system. See [ST_Transform](Function.md#st_transform). If you don't want to transform your data and are ok with sacrificing the query accuracy, you can use an approxima [...]
+ If you use `ST_Distance` as the predicate, Sedona doesn't control the distance's unit (degree or meter). It is same with the geometry. If your coordinates are in the longitude and latitude system, the unit of `distance` should be degree instead of meter or mile. To change the geometry's unit, please either transform the coordinate reference system to a meter-based system. See [ST_Transform](Function.md#st_transform). If you don't want to transform your data, please consider using `ST_Di [...]
+
+Spark SQL Example for meter-based geodesic distance:
+
+*Less than a certain distance==*
+```sql
+SELECT *
+FROM pointdf1, pointdf2
+WHERE ST_DistanceSpheroid(pointdf1.pointshape1,pointdf2.pointshape2) < 2
+```
+
+*Less than or equal to a certain distance==*
+```sql
+SELECT *
+FROM pointdf1, pointdf2
+WHERE ST_DistanceSpheroid(pointdf1.pointshape1,pointdf2.pointshape2) <= 2
+```
+
+!!!warning
+ If you use `ST_DistanceSpheroid ` as the predicate, the unit of the distance is meter. Currently, distance join with geodesic distance calculators work best for point data. For non-point data, it only considers their centroids.
## Broadcast index join
@@ -105,7 +126,7 @@ BroadcastIndexJoin pointshape#52: geometry, BuildRight, BuildRight, false ST_Con
+- FileScan csv
```
-This also works for distance joins:
+This also works for distance joins with `ST_Distance` or `ST_DistanceSpheroid`:
```scala
pointDf1.alias("pointDf1").join(broadcast(pointDf2).alias("pointDf2"), expr("ST_Distance(pointDf1.pointshape, pointDf2.pointshape) <= 2"))
@@ -202,14 +223,16 @@ GROUP BY (lcs_geom, rcs_geom)
This also works for distance join. You first need to use `ST_Buffer(geometry, distance)` to wrap one of your original geometry column. If your original geometry column contains points, this `ST_Buffer` will make them become circles with a radius of `distance`.
-For example. run this query first on the left table before Step 1.
+Since the coordinates are in the longitude and latitude system, so the unit of `distance` should be degree instead of meter or mile. You can get an approximation by performing `METER_DISTANCE/111000.0`, then filter out false-positives.
+
+In a nutshell, run this query first on the left table before Step 1. Please replace `METER_DISTANCE` with a meter distance. In Step 1, generate S2 IDs based on the `buffered_geom` column. Then run Step 2, 3, 4 on the original `geom` column.
```sql
-SELECT id, ST_Buffer(geom, DISTANCE), name
+SELECT id, geom , ST_Buffer(geom, METER_DISTANCE/111000.0) as buffered_geom, name
FROM lefts
```
-Since the coordinates are in the longitude and latitude system, so the unit of `distance` should be degree instead of meter or mile. You will have to estimate the corresponding degrees based on your meter values. Please use [this calculator](https://lucidar.me/en/online-unit-converter-length-to-angle/convert-degrees-to-meters/#online-converter).
+
## Regular spatial predicate pushdown
Introduction: Given a join query and a predicate in the same WHERE clause, first executes the Predicate as a filter, then executes the join query.
diff --git a/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/DistanceJoinExec.scala b/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/DistanceJoinExec.scala
index b393f2d0..615f88a2 100644
--- a/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/DistanceJoinExec.scala
+++ b/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/DistanceJoinExec.scala
@@ -54,6 +54,7 @@ case class DistanceJoinExec(left: SparkPlan,
distance: Expression,
distanceBoundToLeft: Boolean,
spatialPredicate: SpatialPredicate,
+ isGeography: Boolean,
extraCondition: Option[Expression] = None)
extends SedonaBinaryExecNode
with TraitJoinQueryExec
@@ -70,9 +71,9 @@ case class DistanceJoinExec(left: SparkPlan,
rightRdd: RDD[UnsafeRow],
rightShapeExpr: Expression): (SpatialRDD[Geometry], SpatialRDD[Geometry]) = {
if (distanceBoundToLeft) {
- (toExpandedEnvelopeRDD(leftRdd, leftShapeExpr, boundRadius), toSpatialRDD(rightRdd, rightShapeExpr))
+ (toExpandedEnvelopeRDD(leftRdd, leftShapeExpr, boundRadius, isGeography), toSpatialRDD(rightRdd, rightShapeExpr))
} else {
- (toSpatialRDD(leftRdd, leftShapeExpr), toExpandedEnvelopeRDD(rightRdd, rightShapeExpr, boundRadius))
+ (toSpatialRDD(leftRdd, leftShapeExpr), toExpandedEnvelopeRDD(rightRdd, rightShapeExpr, boundRadius, isGeography))
}
}
diff --git a/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/JoinQueryDetector.scala b/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/JoinQueryDetector.scala
index 464db3e7..4536ec50 100644
--- a/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/JoinQueryDetector.scala
+++ b/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/JoinQueryDetector.scala
@@ -21,13 +21,13 @@ package org.apache.spark.sql.sedona_sql.strategy.join
import org.apache.sedona.core.enums.{IndexType, SpatialJoinOptimizationMode}
import org.apache.sedona.core.spatialOperator.SpatialPredicate
import org.apache.sedona.core.utils.SedonaConf
-import org.apache.spark.sql.{SparkSession, Strategy}
import org.apache.spark.sql.catalyst.expressions.{And, EqualNullSafe, EqualTo, Expression, LessThan, LessThanOrEqual}
-import org.apache.spark.sql.catalyst.plans.{ExistenceJoin, FullOuter, Inner, InnerLike, JoinType, LeftAnti, LeftOuter, LeftSemi, NaturalJoin, RightOuter, UsingJoin}
import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.sedona_sql.expressions._
import org.apache.spark.sql.sedona_sql.optimization.ExpressionUtils.splitConjunctivePredicates
+import org.apache.spark.sql.{SparkSession, Strategy}
case class JoinQueryDetection(
@@ -36,6 +36,7 @@ case class JoinQueryDetection(
leftShape: Expression,
rightShape: Expression,
spatialPredicate: SpatialPredicate,
+ isGeography: Boolean,
extraCondition: Option[Expression] = None,
distance: Option[Expression] = None
)
@@ -57,23 +58,23 @@ class JoinQueryDetector(sparkSession: SparkSession) extends Strategy {
extraCondition: Option[Expression] = None): Option[JoinQueryDetection] = {
predicate match {
case ST_Contains(Seq(leftShape, rightShape)) =>
- Some(JoinQueryDetection(left, right, leftShape, rightShape, SpatialPredicate.CONTAINS, extraCondition))
+ Some(JoinQueryDetection(left, right, leftShape, rightShape, SpatialPredicate.CONTAINS, false, extraCondition))
case ST_Intersects(Seq(leftShape, rightShape)) =>
- Some(JoinQueryDetection(left, right, leftShape, rightShape, SpatialPredicate.INTERSECTS, extraCondition))
+ Some(JoinQueryDetection(left, right, leftShape, rightShape, SpatialPredicate.INTERSECTS, false, extraCondition))
case ST_Within(Seq(leftShape, rightShape)) =>
- Some(JoinQueryDetection(left, right, leftShape, rightShape, SpatialPredicate.WITHIN, extraCondition))
+ Some(JoinQueryDetection(left, right, leftShape, rightShape, SpatialPredicate.WITHIN, false, extraCondition))
case ST_Covers(Seq(leftShape, rightShape)) =>
- Some(JoinQueryDetection(left, right, leftShape, rightShape, SpatialPredicate.COVERS, extraCondition))
+ Some(JoinQueryDetection(left, right, leftShape, rightShape, SpatialPredicate.COVERS, false, extraCondition))
case ST_CoveredBy(Seq(leftShape, rightShape)) =>
- Some(JoinQueryDetection(left, right, leftShape, rightShape, SpatialPredicate.COVERED_BY, extraCondition))
+ Some(JoinQueryDetection(left, right, leftShape, rightShape, SpatialPredicate.COVERED_BY, false, extraCondition))
case ST_Overlaps(Seq(leftShape, rightShape)) =>
- Some(JoinQueryDetection(left, right, leftShape, rightShape, SpatialPredicate.OVERLAPS, extraCondition))
+ Some(JoinQueryDetection(left, right, leftShape, rightShape, SpatialPredicate.OVERLAPS, false, extraCondition))
case ST_Touches(Seq(leftShape, rightShape)) =>
- Some(JoinQueryDetection(left, right, leftShape, rightShape, SpatialPredicate.TOUCHES, extraCondition))
+ Some(JoinQueryDetection(left, right, leftShape, rightShape, SpatialPredicate.TOUCHES, false, extraCondition))
case ST_Equals(Seq(leftShape, rightShape)) =>
- Some(JoinQueryDetection(left, right, leftShape, rightShape, SpatialPredicate.EQUALS, extraCondition))
+ Some(JoinQueryDetection(left, right, leftShape, rightShape, SpatialPredicate.EQUALS, false, extraCondition))
case ST_Crosses(Seq(leftShape, rightShape)) =>
- Some(JoinQueryDetection(left, right, leftShape, rightShape, SpatialPredicate.CROSSES, extraCondition))
+ Some(JoinQueryDetection(left, right, leftShape, rightShape, SpatialPredicate.CROSSES, false, extraCondition))
case _ => None
}
}
@@ -109,20 +110,32 @@ class JoinQueryDetector(sparkSession: SparkSession) extends Strategy {
getJoinDetection(left, right, predicate, Some(extraCondition))
case Some(And(extraCondition, predicate: ST_Predicate)) =>
getJoinDetection(left, right, predicate, Some(extraCondition))
-
// For distance joins we execute the actual predicate (condition) and not only extraConditions.
case Some(LessThanOrEqual(ST_Distance(Seq(leftShape, rightShape)), distance)) =>
- Some(JoinQueryDetection(left, right, leftShape, rightShape, SpatialPredicate.INTERSECTS, condition, Some(distance)))
+ Some(JoinQueryDetection(left, right, leftShape, rightShape, SpatialPredicate.INTERSECTS, false, condition, Some(distance)))
case Some(And(LessThanOrEqual(ST_Distance(Seq(leftShape, rightShape)), distance), _)) =>
- Some(JoinQueryDetection(left, right, leftShape, rightShape, SpatialPredicate.INTERSECTS, condition, Some(distance)))
+ Some(JoinQueryDetection(left, right, leftShape, rightShape, SpatialPredicate.INTERSECTS, false, condition, Some(distance)))
case Some(And(_, LessThanOrEqual(ST_Distance(Seq(leftShape, rightShape)), distance))) =>
- Some(JoinQueryDetection(left, right, leftShape, rightShape, SpatialPredicate.INTERSECTS, condition, Some(distance)))
+ Some(JoinQueryDetection(left, right, leftShape, rightShape, SpatialPredicate.INTERSECTS, false, condition, Some(distance)))
case Some(LessThan(ST_Distance(Seq(leftShape, rightShape)), distance)) =>
- Some(JoinQueryDetection(left, right, leftShape, rightShape, SpatialPredicate.INTERSECTS, condition, Some(distance)))
+ Some(JoinQueryDetection(left, right, leftShape, rightShape, SpatialPredicate.INTERSECTS, false, condition, Some(distance)))
case Some(And(LessThan(ST_Distance(Seq(leftShape, rightShape)), distance), _)) =>
- Some(JoinQueryDetection(left, right, leftShape, rightShape, SpatialPredicate.INTERSECTS, condition, Some(distance)))
+ Some(JoinQueryDetection(left, right, leftShape, rightShape, SpatialPredicate.INTERSECTS, false, condition, Some(distance)))
case Some(And(_, LessThan(ST_Distance(Seq(leftShape, rightShape)), distance))) =>
- Some(JoinQueryDetection(left, right, leftShape, rightShape, SpatialPredicate.INTERSECTS, condition, Some(distance)))
+ Some(JoinQueryDetection(left, right, leftShape, rightShape, SpatialPredicate.INTERSECTS, false, condition, Some(distance)))
+ // ST_DistanceSpheroid
+ case Some(LessThanOrEqual(ST_DistanceSpheroid(Seq(leftShape, rightShape)), distance)) =>
+ Some(JoinQueryDetection(left, right, leftShape, rightShape, SpatialPredicate.INTERSECTS, true, condition, Some(distance)))
+ case Some(And(LessThanOrEqual(ST_DistanceSpheroid(Seq(leftShape, rightShape)), distance), _)) =>
+ Some(JoinQueryDetection(left, right, leftShape, rightShape, SpatialPredicate.INTERSECTS, true, condition, Some(distance)))
+ case Some(And(_, LessThanOrEqual(ST_DistanceSpheroid(Seq(leftShape, rightShape)), distance))) =>
+ Some(JoinQueryDetection(left, right, leftShape, rightShape, SpatialPredicate.INTERSECTS, true, condition, Some(distance)))
+ case Some(LessThan(ST_DistanceSpheroid(Seq(leftShape, rightShape)), distance)) =>
+ Some(JoinQueryDetection(left, right, leftShape, rightShape, SpatialPredicate.INTERSECTS, true, condition, Some(distance)))
+ case Some(And(LessThan(ST_DistanceSpheroid(Seq(leftShape, rightShape)), distance), _)) =>
+ Some(JoinQueryDetection(left, right, leftShape, rightShape, SpatialPredicate.INTERSECTS, true, condition, Some(distance)))
+ case Some(And(_, LessThan(ST_DistanceSpheroid(Seq(leftShape, rightShape)), distance))) =>
+ Some(JoinQueryDetection(left, right, leftShape, rightShape, SpatialPredicate.INTERSECTS, true, condition, Some(distance)))
case _ =>
None
}
@@ -131,20 +144,20 @@ class JoinQueryDetector(sparkSession: SparkSession) extends Strategy {
if ((broadcastLeft || broadcastRight) && sedonaConf.getUseIndex) {
queryDetection match {
- case Some(JoinQueryDetection(left, right, leftShape, rightShape, spatialPredicate, extraCondition, distance)) =>
+ case Some(JoinQueryDetection(left, right, leftShape, rightShape, spatialPredicate, isGeography, extraCondition, distance)) =>
planBroadcastJoin(
left, right, Seq(leftShape, rightShape), joinType,
spatialPredicate, sedonaConf.getIndexType,
- broadcastLeft, broadcastRight, extraCondition, distance)
+ broadcastLeft, broadcastRight, isGeography, extraCondition, distance)
case _ =>
Nil
}
} else {
queryDetection match {
- case Some(JoinQueryDetection(left, right, leftShape, rightShape, spatialPredicate, extraCondition, None)) =>
+ case Some(JoinQueryDetection(left, right, leftShape, rightShape, spatialPredicate, isGeography, extraCondition, None)) =>
planSpatialJoin(left, right, Seq(leftShape, rightShape), joinType, spatialPredicate, extraCondition)
- case Some(JoinQueryDetection(left, right, leftShape, rightShape, spatialPredicate, extraCondition, Some(distance))) =>
- planDistanceJoin(left, right, Seq(leftShape, rightShape), joinType, distance, spatialPredicate, extraCondition)
+ case Some(JoinQueryDetection(left, right, leftShape, rightShape, spatialPredicate, isGeography, extraCondition, Some(distance))) =>
+ planDistanceJoin(left, right, Seq(leftShape, rightShape), joinType, distance, spatialPredicate, isGeography, extraCondition)
case None =>
Nil
}
@@ -236,6 +249,7 @@ class JoinQueryDetector(sparkSession: SparkSession) extends Strategy {
joinType: JoinType,
distance: Expression,
spatialPredicate: SpatialPredicate,
+ isGeography: Boolean,
extraCondition: Option[Expression] = None): Seq[SparkPlan] = {
if (joinType != Inner) {
@@ -252,11 +266,11 @@ class JoinQueryDetector(sparkSession: SparkSession) extends Strategy {
case Some(LeftSide) =>
logInfo("Planning spatial distance join, distance bound to left relation")
DistanceJoinExec(planLater(left), planLater(right), leftShape, rightShape, distance, distanceBoundToLeft = true,
- spatialPredicate, extraCondition) :: Nil
+ spatialPredicate, isGeography, extraCondition) :: Nil
case Some(RightSide) =>
logInfo("Planning spatial distance join, distance bound to right relation")
DistanceJoinExec(planLater(left), planLater(right), leftShape, rightShape, distance, distanceBoundToLeft = false,
- spatialPredicate, extraCondition) :: Nil
+ spatialPredicate, isGeography, extraCondition) :: Nil
case _ =>
logInfo(
"Spatial distance join for ST_Distance with non-scalar distance " +
@@ -280,6 +294,7 @@ class JoinQueryDetector(sparkSession: SparkSession) extends Strategy {
indexType: IndexType,
broadcastLeft: Boolean,
broadcastRight: Boolean,
+ isGeography: Boolean,
extraCondition: Option[Expression],
distance: Option[Expression]): Seq[SparkPlan] = {
@@ -300,12 +315,13 @@ class JoinQueryDetector(sparkSession: SparkSession) extends Strategy {
val a = children.head
val b = children.tail.head
- val relationship = (distance, spatialPredicate) match {
- case (Some(_), SpatialPredicate.INTERSECTS) => "ST_Distance <="
- case (Some(_), _) => "ST_Distance <"
- case (None, _) => s"ST_$spatialPredicate"
+ val relationship = (distance, spatialPredicate, isGeography) match {
+ case (Some(_), SpatialPredicate.INTERSECTS, false) => "ST_Distance <="
+ case (Some(_), _, false) => "ST_Distance <"
+ case (Some(_), SpatialPredicate.INTERSECTS, true) => "ST_Distance (Geography) <="
+ case (Some(_), _, true) => "ST_Distance (Geography) <"
+ case (None, _, false) => s"ST_$spatialPredicate"
}
-
val (distanceOnIndexSide, distanceOnStreamSide) = distance.map { distanceExpr =>
matchDistanceExpressionToJoinSide(distanceExpr, left, right) match {
case Some(side) =>
@@ -321,13 +337,13 @@ class JoinQueryDetector(sparkSession: SparkSession) extends Strategy {
logInfo(s"Planning spatial join for $relationship relationship")
val (leftPlan, rightPlan, streamShape, windowSide) = (broadcastSide.get, swapped) match {
case (LeftSide, false) => // Broadcast the left side, windows on the left
- (SpatialIndexExec(planLater(left), a, indexType, distanceOnIndexSide), planLater(right), b, LeftSide)
+ (SpatialIndexExec(planLater(left), a, indexType, isGeography, distanceOnIndexSide), planLater(right), b, LeftSide)
case (LeftSide, true) => // Broadcast the left side, objects on the left
- (SpatialIndexExec(planLater(left), b, indexType, distanceOnIndexSide), planLater(right), a, RightSide)
+ (SpatialIndexExec(planLater(left), b, indexType, isGeography, distanceOnIndexSide), planLater(right), a, RightSide)
case (RightSide, false) => // Broadcast the right side, windows on the left
- (planLater(left), SpatialIndexExec(planLater(right), b, indexType, distanceOnIndexSide), a, LeftSide)
+ (planLater(left), SpatialIndexExec(planLater(right), b, indexType, isGeography, distanceOnIndexSide), a, LeftSide)
case (RightSide, true) => // Broadcast the right side, objects on the left
- (planLater(left), SpatialIndexExec(planLater(right), a, indexType, distanceOnIndexSide), b, RightSide)
+ (planLater(left), SpatialIndexExec(planLater(right), a, indexType, isGeography, distanceOnIndexSide), b, RightSide)
}
BroadcastIndexJoinExec(leftPlan, rightPlan, streamShape, broadcastSide.get, windowSide, joinType,
spatialPredicate, extraCondition, distanceOnStreamSide) :: Nil
diff --git a/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/SpatialIndexExec.scala b/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/SpatialIndexExec.scala
index f9f9a4ed..2c24a34e 100644
--- a/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/SpatialIndexExec.scala
+++ b/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/SpatialIndexExec.scala
@@ -33,6 +33,7 @@ import org.apache.spark.sql.sedona_sql.execution.SedonaUnaryExecNode
case class SpatialIndexExec(child: SparkPlan,
shape: Expression,
indexType: IndexType,
+ isGeography: Boolean,
distance: Option[Expression] = None)
extends SedonaUnaryExecNode
with TraitJoinQueryBase
@@ -51,7 +52,7 @@ case class SpatialIndexExec(child: SparkPlan,
val resultRaw = child.execute().asInstanceOf[RDD[UnsafeRow]].coalesce(1)
val spatialRDD = distance match {
- case Some(distanceExpression) => toExpandedEnvelopeRDD(resultRaw, boundShape, BindReferences.bindReference(distanceExpression, child.output))
+ case Some(distanceExpression) => toExpandedEnvelopeRDD(resultRaw, boundShape, BindReferences.bindReference(distanceExpression, child.output), isGeography)
case None => toSpatialRDD(resultRaw, boundShape)
}
diff --git a/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/TraitJoinQueryBase.scala b/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/TraitJoinQueryBase.scala
index df8f4cd3..0f636ab3 100644
--- a/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/TraitJoinQueryBase.scala
+++ b/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/TraitJoinQueryBase.scala
@@ -48,14 +48,14 @@ trait TraitJoinQueryBase {
spatialRdd
}
- def toExpandedEnvelopeRDD(rdd: RDD[UnsafeRow], shapeExpression: Expression, boundRadius: Expression): SpatialRDD[Geometry] = {
+ def toExpandedEnvelopeRDD(rdd: RDD[UnsafeRow], shapeExpression: Expression, boundRadius: Expression, isGeography: Boolean): SpatialRDD[Geometry] = {
val spatialRdd = new SpatialRDD[Geometry]
spatialRdd.setRawSpatialRDD(
rdd
.map { x =>
val shape = GeometrySerializer.deserialize(shapeExpression.eval(x).asInstanceOf[Array[Byte]])
val envelope = shape.getEnvelopeInternal.copy()
- envelope.expandBy(boundRadius.eval(x).asInstanceOf[Double])
+ envelope.expandBy(distanceToDegree(boundRadius.eval(x).asInstanceOf[Double], isGeography))
val expandedEnvelope = shape.getFactory.toGeometry(envelope)
expandedEnvelope.setUserData(x.copy)
@@ -72,4 +72,21 @@ trait TraitJoinQueryBase {
followerShapes.spatialPartitioning(dominantShapes.getPartitioner)
}
}
+
+ /**
+ * Convert distance to degree based on the given isGeography flag.
+ * Note that this is an approximation since the degree of longitude is not constant.
+ * We assume that the degree of longitude is 111000 meters without considering the latitude.
+ * For latitude, the degree is always 111000 meters.
+ * @param distance
+ * @param isGeography
+ * @return
+ */
+ private def distanceToDegree(distance: Double, isGeography: Boolean): Double = {
+ if (isGeography) {
+ distance / 111000.0
+ } else {
+ distance
+ }
+ }
}
diff --git a/sql/common/src/test/scala/org/apache/sedona/sql/BroadcastIndexJoinSuite.scala b/sql/common/src/test/scala/org/apache/sedona/sql/BroadcastIndexJoinSuite.scala
index 78f3aaf5..0dd17104 100644
--- a/sql/common/src/test/scala/org/apache/sedona/sql/BroadcastIndexJoinSuite.scala
+++ b/sql/common/src/test/scala/org/apache/sedona/sql/BroadcastIndexJoinSuite.scala
@@ -302,6 +302,25 @@ class BroadcastIndexJoinSuite extends TestBaseScala {
assert(rows2(0) == Row(2.0, 2.0, "left_2", 2.0, 2.0, "right_2"))
}
+
+ it("Passed ST_DistanceSpheroid in a broadcast join") {
+ val pointDf1 = buildPointDf
+ val pointDf2 = buildPointDf
+ var distanceJoinDf = pointDf1.alias("pointDf1").join(
+ broadcast(pointDf2).alias("pointDf2"), expr("ST_DistanceSpheroid(pointDf1.pointshape, pointDf2.pointshape) <= 2.0"))
+ assert(distanceJoinDf.queryExecution.sparkPlan.collect { case p: BroadcastIndexJoinExec => p }.size === 1)
+ assert(distanceJoinDf.count() == 89)
+
+ distanceJoinDf = broadcast(pointDf1).alias("pointDf1").join(
+ pointDf2.alias("pointDf2"), expr("ST_DistanceSpheroid(pointDf1.pointshape, pointDf2.pointshape) <= 2.0"))
+ assert(distanceJoinDf.queryExecution.sparkPlan.collect { case p: BroadcastIndexJoinExec => p }.size === 1)
+ assert(distanceJoinDf.count() == 89)
+
+ distanceJoinDf = broadcast(pointDf1).alias("pointDf1").join(
+ pointDf2.alias("pointDf2"), expr("ST_DistanceSpheroid(pointDf1.pointshape, pointDf2.pointshape) < 2.0"))
+ assert(distanceJoinDf.queryExecution.sparkPlan.collect { case p: BroadcastIndexJoinExec => p }.size === 1)
+ assert(distanceJoinDf.count() == 89)
+ }
}
describe("Sedona-SQL Broadcast Index Join Test for left semi joins") {
diff --git a/sql/common/src/test/scala/org/apache/sedona/sql/TestBaseScala.scala b/sql/common/src/test/scala/org/apache/sedona/sql/TestBaseScala.scala
index f12b5874..c9116533 100644
--- a/sql/common/src/test/scala/org/apache/sedona/sql/TestBaseScala.scala
+++ b/sql/common/src/test/scala/org/apache/sedona/sql/TestBaseScala.scala
@@ -20,6 +20,7 @@ package org.apache.sedona.sql
import com.google.common.math.DoubleMath
import org.apache.log4j.{Level, Logger}
+import org.apache.sedona.common.sphere.{Haversine, Spheroid}
import org.apache.sedona.core.serde.SedonaKryoRegistrator
import org.apache.sedona.sql.utils.SedonaSQLRegistrator
import org.apache.spark.serializer.KryoSerializer
@@ -97,4 +98,24 @@ trait TestBaseScala extends FunSpec with BeforeAndAfterAll {
}
}
+ protected def bruteForceDistanceJoinCountSpheroid(distance: Double): Int = {
+ buildPointDf.collect().map(row => {
+ val point1 = row.getAs[org.locationtech.jts.geom.Point](0)
+ buildPointDf.collect().map(row => {
+ val point2 = row.getAs[org.locationtech.jts.geom.Point](0)
+ if (Spheroid.distance(point1, point2) <= distance) 1 else 0
+ }).sum
+ }).sum
+ }
+
+ protected def bruteForceDistanceJoinCountSphere(distance: Double): Int = {
+ buildPointDf.collect().map(row => {
+ val point1 = row.getAs[org.locationtech.jts.geom.Point](0)
+ buildPointDf.collect().map(row => {
+ val point2 = row.getAs[org.locationtech.jts.geom.Point](0)
+ if (Haversine.distance(point1, point2) <= distance) 1 else 0
+ }).sum
+ }).sum
+ }
+
}
diff --git a/sql/common/src/test/scala/org/apache/sedona/sql/predicateJoinTestScala.scala b/sql/common/src/test/scala/org/apache/sedona/sql/predicateJoinTestScala.scala
index 596e561b..037b75c5 100644
--- a/sql/common/src/test/scala/org/apache/sedona/sql/predicateJoinTestScala.scala
+++ b/sql/common/src/test/scala/org/apache/sedona/sql/predicateJoinTestScala.scala
@@ -21,10 +21,10 @@ package org.apache.sedona.sql
import org.apache.sedona.core.utils.SedonaConf
import org.apache.spark.sql.Row
-import org.apache.spark.sql.execution.ExplainMode
+import org.apache.spark.sql.functions.expr
+import org.apache.spark.sql.sedona_sql.strategy.join.DistanceJoinExec
import org.apache.spark.sql.types._
import org.locationtech.jts.geom.Geometry
-import org.locationtech.jts.io.WKTWriter
class predicateJoinTestScala extends TestBaseScala {
@@ -361,5 +361,19 @@ class predicateJoinTestScala extends TestBaseScala {
assert(equalJoinDf.count() == 0, s"Expected 0 but got ${equalJoinDf.count()}")
}
+
+ it("Passed ST_DistanceSpheroid in a spatial join") {
+ val pointDf1 = buildPointDf
+ val pointDf2 = buildPointDf
+ var distanceJoinDf = pointDf1.alias("pointDf1").join(
+ pointDf2.alias("pointDf2"), expr("ST_DistanceSpheroid(pointDf1.pointshape, pointDf2.pointshape) <= 2.0"))
+ assert(distanceJoinDf.queryExecution.sparkPlan.collect { case p: DistanceJoinExec => p }.size === 1)
+ assert(distanceJoinDf.count() == 89)
+
+ distanceJoinDf = pointDf1.alias("pointDf1").join(
+ pointDf2.alias("pointDf2"), expr("ST_DistanceSpheroid(pointDf1.pointshape, pointDf2.pointshape) < 2.0"))
+ assert(distanceJoinDf.queryExecution.sparkPlan.collect { case p: DistanceJoinExec => p }.size === 1)
+ assert(distanceJoinDf.count() == 89)
+ }
}
}
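
Note on the degree approximation in TraitJoinQueryBase.scala: the new `distanceToDegree` helper divides a meter distance by 111000.0 regardless of latitude, and its own doc comment flags this as an approximation. The sketch below is not part of the commit. It reproduces that conversion next to a hypothetical latitude-aware variant (`metersToLongitudeDegrees`, a name invented here) to show how far the two diverge away from the equator. It assumes a spherical model in which one degree of longitude shrinks by cos(latitude).

```scala
object DegreeApproximation {
  // Same constant the commit uses: roughly the length of one degree of latitude, in meters.
  val MetersPerDegree: Double = 111000.0

  // Mirrors the commit's distanceToDegree for the isGeography = true case:
  // a latitude-independent conversion from meters to degrees.
  def metersToDegrees(meters: Double): Double =
    meters / MetersPerDegree

  // Hypothetical variant, NOT in the commit: on a sphere, one degree of longitude
  // covers about MetersPerDegree * cos(latitude) meters, so the same meter
  // distance spans more degrees of longitude as latitude grows.
  def metersToLongitudeDegrees(meters: Double, latitudeDeg: Double): Double =
    meters / (MetersPerDegree * math.cos(math.toRadians(latitudeDeg)))

  def main(args: Array[String]): Unit = {
    val meters = 2.0
    val flat = metersToDegrees(meters)
    for (lat <- Seq(0.0, 30.0, 60.0)) {
      val aware = metersToLongitudeDegrees(meters, lat)
      // A ratio above 1 means the latitude-independent value is narrower,
      // east to west, than the requested meter distance at that latitude.
      println(s"latitude=$lat flat=$flat latitudeAware=$aware ratio=${aware / flat}")
    }
  }
}
```

Under that spherical assumption, the envelope expansion produced by the flat conversion covers only about half the requested distance east to west at 60 degrees latitude, so candidate pairs separated mostly in longitude could be filtered out before the exact `ST_DistanceSpheroid` predicate runs. Whether this matters in practice depends on the latitude range of the data.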