You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@sedona.apache.org by ji...@apache.org on 2022/10/21 02:58:17 UTC
[incubator-sedona] branch master updated: [SEDONA-178] Correctness issue in distance joins (#701)
This is an automated email from the ASF dual-hosted git repository.
jiayu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-sedona.git
The following commit(s) were added to refs/heads/master by this push:
new 8d7e3fbe [SEDONA-178] Correctness issue in distance joins (#701)
8d7e3fbe is described below
commit 8d7e3fbedad3b643e409714c055b654f742cd812
Author: Martin Andersson <u....@gmail.com>
AuthorDate: Fri Oct 21 04:58:11 2022 +0200
[SEDONA-178] Correctness issue in distance joins (#701)
---
docs/tutorial/core-python.md | 2 +
docs/tutorial/rdd.md | 2 +
pom.xml | 4 ++
.../strategy/join/BroadcastIndexJoinExec.scala | 12 ++--
.../strategy/join/DistanceJoinExec.scala | 45 +++++++-------
.../strategy/join/JoinQueryDetector.scala | 70 +++++++++++-----------
.../strategy/join/SpatialIndexExec.scala | 6 +-
.../strategy/join/TraitJoinQueryBase.scala | 12 ++--
.../sedona/sql/BroadcastIndexJoinSuite.scala | 23 ++++++-
.../org/apache/sedona/sql/SpatialJoinSuite.scala | 5 +-
.../org/apache/sedona/sql/functionTestScala.scala | 13 ++--
.../apache/sedona/sql/predicateJoinTestScala.scala | 20 ++++++-
12 files changed, 130 insertions(+), 84 deletions(-)
diff --git a/docs/tutorial/core-python.md b/docs/tutorial/core-python.md
index 3d4d1c3b..158071a8 100644
--- a/docs/tutorial/core-python.md
+++ b/docs/tutorial/core-python.md
@@ -450,6 +450,8 @@ Each object on the left is covered/intersected by the object on the right.
## Write a Distance Join Query
+!!!warning RDD distance joins are only reliable for points. For other geometry types, please use Spatial SQL.
+
A distance join query takes two spatial RDD assuming that we have two SpatialRDD's:
<li> object_rdd </li>
<li> spatial_rdd </li>
diff --git a/docs/tutorial/rdd.md b/docs/tutorial/rdd.md
index 8ea06ee6..ca0f177a 100644
--- a/docs/tutorial/rdd.md
+++ b/docs/tutorial/rdd.md
@@ -466,6 +466,8 @@ Each object on the left is covered/intersected by the object on the right.
## Write a Distance Join Query
+!!!warning RDD distance joins are only reliable for points. For other geometry types, please use Spatial SQL.
+
A distance join query takes as input two Spatial RDDs A and B and a distance. For each geometry in A, it finds the geometries (from B) that are within the given distance to it. A and B can be any geometry type and do not need to have the same geometry type. The unit of the distance is explained [here](#transform-the-coordinate-reference-system).
Assume you now have two SpatialRDDs (typed or generic). You can use the following code to issue a Distance Join Query on them.
diff --git a/pom.xml b/pom.xml
index 43e8d10b..3e373901 100644
--- a/pom.xml
+++ b/pom.xml
@@ -68,6 +68,10 @@
<sedona.jackson.version>2.13.3</sedona.jackson.version>
<hadoop.version>3.2.4</hadoop.version>
<maven.deploy.skip>false</maven.deploy.skip>
+
+ <!-- Actual scala version will be set by a profile.
+ Setting a default value helps IDEs that can't make sense of profiles. -->
+ <scala.compat.version>2.12</scala.compat.version>
</properties>
<dependencies>
diff --git a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/BroadcastIndexJoinExec.scala b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/BroadcastIndexJoinExec.scala
index f60b8947..7925c552 100644
--- a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/BroadcastIndexJoinExec.scala
+++ b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/BroadcastIndexJoinExec.scala
@@ -42,7 +42,7 @@ case class BroadcastIndexJoinExec(left: SparkPlan,
windowJoinSide: JoinSide,
spatialPredicate: SpatialPredicate,
extraCondition: Option[Expression] = None,
- radius: Option[Expression] = None)
+ distance: Option[Expression] = None)
extends SedonaBinaryExecNode
with TraitJoinQueryBase
with Logging {
@@ -71,7 +71,7 @@ case class BroadcastIndexJoinExec(left: SparkPlan,
(streamShape, broadcast.shape)
}
- private val spatialExpression = (radius, spatialPredicate) match {
+ private val spatialExpression = (distance, spatialPredicate) match {
case (Some(r), SpatialPredicate.INTERSECTS) => s"ST_Distance($windowExpression, $objectExpression) <= $r"
case (Some(r), _) => s"ST_Distance($windowExpression, $objectExpression) < $r"
case (None, _) => s"ST_$spatialPredicate($windowExpression, $objectExpression)"
@@ -104,10 +104,10 @@ case class BroadcastIndexJoinExec(left: SparkPlan,
val boundStreamShape = BindReferences.bindReference(streamShape, streamed.output)
val streamResultsRaw = streamed.execute().asInstanceOf[RDD[UnsafeRow]]
- // If there's a radius and the objects are being broadcast, we need to build the CircleRDD on the window stream side
- val streamShapes = radius match {
- case Some(radiusExpression) if indexBuildSide != windowJoinSide =>
- toCircleRDD(streamResultsRaw, boundStreamShape, BindReferences.bindReference(radiusExpression, streamed.output))
+ // If there's a distance and the objects are being broadcast, we need to build the expanded envelope on the window stream side
+ val streamShapes = distance match {
+ case Some(distanceExpression) if indexBuildSide != windowJoinSide =>
+ toExpandedEnvelopeRDD(streamResultsRaw, boundStreamShape, BindReferences.bindReference(distanceExpression, streamed.output))
case _ =>
toSpatialRDD(streamResultsRaw, boundStreamShape)
}
diff --git a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/DistanceJoinExec.scala b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/DistanceJoinExec.scala
index ebaf0005..4df455ef 100644
--- a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/DistanceJoinExec.scala
+++ b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/DistanceJoinExec.scala
@@ -18,54 +18,53 @@
*/
package org.apache.spark.sql.sedona_sql.strategy.join
-import org.apache.sedona.common.geometryObjects.Circle
import org.apache.sedona.core.spatialOperator.SpatialPredicate
import org.apache.sedona.core.spatialRDD.SpatialRDD
-import org.apache.sedona.sql.utils.GeometrySerializer
import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.expressions.{BindReferences, Expression, UnsafeRow}
-import org.apache.spark.sql.catalyst.util.ArrayData
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.sedona_sql.execution.SedonaBinaryExecNode
import org.locationtech.jts.geom.Geometry
-// ST_Distance(left, right) <= radius
-// radius can be literal or a computation over 'left'
+/**
+ * Distance joins require matching geometries to be in the same partition, despite not necessarily overlapping.
+ * To create an overlap and guarantee matching geometries end up in the same partition, the left geometry is expanded
+ * before partitioning. It's the logical equivalent of:
+ *
+ * select * from a join b on ST_Distance(a.geom, b.geom) <= 1
+ *
+ * becomes
+ *
+ * select * from a join b on ST_Intersects(ST_Envelope(ST_Buffer(a.geom, 1)), b.geom) and ST_Distance(a.geom, b.geom) <= 1
+ *
+ * @param left
+ * @param right
+ * @param leftShape
+ * @param rightShape
+ * @param distance - ST_Distance(left, right) <= distance. Distance can be literal or a computation over 'left'.
+ * @param spatialPredicate
+ * @param extraCondition
+ */
case class DistanceJoinExec(left: SparkPlan,
right: SparkPlan,
leftShape: Expression,
rightShape: Expression,
- radius: Expression,
+ distance: Expression,
spatialPredicate: SpatialPredicate,
extraCondition: Option[Expression] = None)
extends SedonaBinaryExecNode
with TraitJoinQueryExec
with Logging {
- private val boundRadius = BindReferences.bindReference(radius, left.output)
+ private val boundRadius = BindReferences.bindReference(distance, left.output)
override def toSpatialRddPair(
buildRdd: RDD[UnsafeRow],
buildExpr: Expression,
streamedRdd: RDD[UnsafeRow],
streamedExpr: Expression): (SpatialRDD[Geometry], SpatialRDD[Geometry]) =
- (toCircleRDD(buildRdd, buildExpr), toSpatialRDD(streamedRdd, streamedExpr))
-
- private def toCircleRDD(rdd: RDD[UnsafeRow], shapeExpression: Expression): SpatialRDD[Geometry] = {
- val spatialRdd = new SpatialRDD[Geometry]
- spatialRdd.setRawSpatialRDD(
- rdd
- .map { x => {
- val shape = GeometrySerializer.deserialize(shapeExpression.eval(x).asInstanceOf[ArrayData])
- val circle = new Circle(shape, boundRadius.eval(x).asInstanceOf[Double])
- circle.setUserData(x.copy)
- circle.asInstanceOf[Geometry]
- }
- }
- .toJavaRDD())
- spatialRdd
- }
+ (toExpandedEnvelopeRDD(buildRdd, buildExpr, boundRadius), toSpatialRDD(streamedRdd, streamedExpr))
protected def withNewChildrenInternal(newLeft: SparkPlan, newRight: SparkPlan): SparkPlan = {
copy(left = newLeft, right = newRight)
diff --git a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/JoinQueryDetector.scala b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/JoinQueryDetector.scala
index 6a020e33..1e419b17 100644
--- a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/JoinQueryDetector.scala
+++ b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/JoinQueryDetector.scala
@@ -30,13 +30,13 @@ import org.apache.spark.sql.sedona_sql.expressions._
case class JoinQueryDetection(
- left: LogicalPlan,
- right: LogicalPlan,
- leftShape: Expression,
- rightShape: Expression,
- spatialPredicate: SpatialPredicate,
- extraCondition: Option[Expression] = None,
- radius: Option[Expression] = None
+ left: LogicalPlan,
+ right: LogicalPlan,
+ leftShape: Expression,
+ rightShape: Expression,
+ spatialPredicate: SpatialPredicate,
+ extraCondition: Option[Expression] = None,
+ distance: Option[Expression] = None
)
/**
@@ -92,18 +92,20 @@ class JoinQueryDetector(sparkSession: SparkSession) extends Strategy {
getJoinDetection(left, right, predicate, Some(extraCondition))
case Some(And(extraCondition, predicate: ST_Predicate)) =>
getJoinDetection(left, right, predicate, Some(extraCondition))
- case Some(LessThanOrEqual(ST_Distance(Seq(leftShape, rightShape)), radius)) =>
- Some(JoinQueryDetection(left, right, leftShape, rightShape, SpatialPredicate.INTERSECTS, None, Some(radius)))
- case Some(And(LessThanOrEqual(ST_Distance(Seq(leftShape, rightShape)), radius), extraCondition)) =>
- Some(JoinQueryDetection(left, right, leftShape, rightShape, SpatialPredicate.INTERSECTS, Some(extraCondition), Some(radius)))
- case Some(And(extraCondition, LessThanOrEqual(ST_Distance(Seq(leftShape, rightShape)), radius))) =>
- Some(JoinQueryDetection(left, right, leftShape, rightShape, SpatialPredicate.INTERSECTS, Some(extraCondition), Some(radius)))
- case Some(LessThan(ST_Distance(Seq(leftShape, rightShape)), radius)) =>
- Some(JoinQueryDetection(left, right, leftShape, rightShape, SpatialPredicate.COVERS, None, Some(radius)))
- case Some(And(LessThan(ST_Distance(Seq(leftShape, rightShape)), radius), extraCondition)) =>
- Some(JoinQueryDetection(left, right, leftShape, rightShape, SpatialPredicate.COVERS, Some(extraCondition), Some(radius)))
- case Some(And(extraCondition, LessThan(ST_Distance(Seq(leftShape, rightShape)), radius))) =>
- Some(JoinQueryDetection(left, right, leftShape, rightShape, SpatialPredicate.COVERS, Some(extraCondition), Some(radius)))
+
+ // For distance joins we execute the actual predicate (condition) and not only extraConditions.
+ case Some(LessThanOrEqual(ST_Distance(Seq(leftShape, rightShape)), distance)) =>
+ Some(JoinQueryDetection(left, right, leftShape, rightShape, SpatialPredicate.INTERSECTS, condition, Some(distance)))
+ case Some(And(LessThanOrEqual(ST_Distance(Seq(leftShape, rightShape)), distance), _)) =>
+ Some(JoinQueryDetection(left, right, leftShape, rightShape, SpatialPredicate.INTERSECTS, condition, Some(distance)))
+ case Some(And(_, LessThanOrEqual(ST_Distance(Seq(leftShape, rightShape)), distance))) =>
+ Some(JoinQueryDetection(left, right, leftShape, rightShape, SpatialPredicate.INTERSECTS, condition, Some(distance)))
+ case Some(LessThan(ST_Distance(Seq(leftShape, rightShape)), distance)) =>
+ Some(JoinQueryDetection(left, right, leftShape, rightShape, SpatialPredicate.INTERSECTS, condition, Some(distance)))
+ case Some(And(LessThan(ST_Distance(Seq(leftShape, rightShape)), distance), _)) =>
+ Some(JoinQueryDetection(left, right, leftShape, rightShape, SpatialPredicate.INTERSECTS, condition, Some(distance)))
+ case Some(And(_, LessThan(ST_Distance(Seq(leftShape, rightShape)), distance))) =>
+ Some(JoinQueryDetection(left, right, leftShape, rightShape, SpatialPredicate.INTERSECTS, condition, Some(distance)))
case _ =>
None
}
@@ -112,8 +114,8 @@ class JoinQueryDetector(sparkSession: SparkSession) extends Strategy {
if ((broadcastLeft || broadcastRight) && sedonaConf.getUseIndex) {
queryDetection match {
- case Some(JoinQueryDetection(left, right, leftShape, rightShape, spatialPredicate, extraCondition, radius)) =>
- planBroadcastJoin(left, right, Seq(leftShape, rightShape), spatialPredicate, sedonaConf.getIndexType, broadcastLeft, extraCondition, radius)
+ case Some(JoinQueryDetection(left, right, leftShape, rightShape, spatialPredicate, extraCondition, distance)) =>
+ planBroadcastJoin(left, right, Seq(leftShape, rightShape), spatialPredicate, sedonaConf.getIndexType, broadcastLeft, extraCondition, distance)
case _ =>
Nil
}
@@ -121,8 +123,8 @@ class JoinQueryDetector(sparkSession: SparkSession) extends Strategy {
queryDetection match {
case Some(JoinQueryDetection(left, right, leftShape, rightShape, spatialPredicate, extraCondition, None)) =>
planSpatialJoin(left, right, Seq(leftShape, rightShape), spatialPredicate, extraCondition)
- case Some(JoinQueryDetection(left, right, leftShape, rightShape, spatialPredicate, extraCondition, Some(radius))) =>
- planDistanceJoin(left, right, Seq(leftShape, rightShape), radius, spatialPredicate, extraCondition)
+ case Some(JoinQueryDetection(left, right, leftShape, rightShape, spatialPredicate, extraCondition, Some(distance))) =>
+ planDistanceJoin(left, right, Seq(leftShape, rightShape), distance, spatialPredicate, extraCondition)
case None =>
Nil
}
@@ -176,7 +178,7 @@ class JoinQueryDetector(sparkSession: SparkSession) extends Strategy {
private def planDistanceJoin(left: LogicalPlan,
right: LogicalPlan,
children: Seq[Expression],
- radius: Expression,
+ distance: Expression,
spatialPredicate: SpatialPredicate,
extraCondition: Option[Expression] = None): Seq[SparkPlan] = {
val a = children.head
@@ -184,15 +186,15 @@ class JoinQueryDetector(sparkSession: SparkSession) extends Strategy {
matchExpressionsToPlans(a, b, left, right) match {
case Some((planA, planB, _)) =>
- if (radius.references.isEmpty || matches(radius, planA)) {
+ if (distance.references.isEmpty || matches(distance, planA)) {
logInfo("Planning spatial distance join")
- DistanceJoinExec(planLater(planA), planLater(planB), a, b, radius, spatialPredicate, extraCondition) :: Nil
- } else if (matches(radius, planB)) {
+ DistanceJoinExec(planLater(planA), planLater(planB), a, b, distance, spatialPredicate, extraCondition) :: Nil
+ } else if (matches(distance, planB)) {
logInfo("Planning spatial distance join")
- DistanceJoinExec(planLater(planB), planLater(planA), b, a, radius, spatialPredicate, extraCondition) :: Nil
+ DistanceJoinExec(planLater(planB), planLater(planA), b, a, distance, spatialPredicate, extraCondition) :: Nil
} else {
logInfo(
- "Spatial distance join for ST_Distance with non-scalar radius " +
+ "Spatial distance join for ST_Distance with non-scalar distance " +
"that is not a computation over just one side of the join is not supported")
Nil
}
@@ -211,11 +213,11 @@ class JoinQueryDetector(sparkSession: SparkSession) extends Strategy {
indexType: IndexType,
broadcastLeft: Boolean,
extraCondition: Option[Expression],
- radius: Option[Expression]): Seq[SparkPlan] = {
+ distance: Option[Expression]): Seq[SparkPlan] = {
val a = children.head
val b = children.tail.head
- val relationship = (radius, spatialPredicate) match {
+ val relationship = (distance, spatialPredicate) match {
case (Some(_), SpatialPredicate.INTERSECTS) => "ST_Distance <="
case (Some(_), _) => "ST_Distance <"
case (None, _) => s"ST_$spatialPredicate"
@@ -227,15 +229,15 @@ class JoinQueryDetector(sparkSession: SparkSession) extends Strategy {
val broadcastSide = if (broadcastLeft) LeftSide else RightSide
val (leftPlan, rightPlan, streamShape, windowSide) = (broadcastSide, swapped) match {
case (LeftSide, false) => // Broadcast the left side, windows on the left
- (SpatialIndexExec(planLater(left), a, indexType, radius), planLater(right), b, LeftSide)
+ (SpatialIndexExec(planLater(left), a, indexType, distance), planLater(right), b, LeftSide)
case (LeftSide, true) => // Broadcast the left side, objects on the left
(SpatialIndexExec(planLater(left), b, indexType), planLater(right), a, RightSide)
case (RightSide, false) => // Broadcast the right side, windows on the left
(planLater(left), SpatialIndexExec(planLater(right), b, indexType), a, LeftSide)
case (RightSide, true) => // Broadcast the right side, objects on the left
- (planLater(left), SpatialIndexExec(planLater(right), a, indexType, radius), b, RightSide)
+ (planLater(left), SpatialIndexExec(planLater(right), a, indexType, distance), b, RightSide)
}
- BroadcastIndexJoinExec(leftPlan, rightPlan, streamShape, broadcastSide, windowSide, spatialPredicate, extraCondition, radius) :: Nil
+ BroadcastIndexJoinExec(leftPlan, rightPlan, streamShape, broadcastSide, windowSide, spatialPredicate, extraCondition, distance) :: Nil
case None =>
logInfo(
s"Spatial join for $relationship with arguments not aligned " +
diff --git a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/SpatialIndexExec.scala b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/SpatialIndexExec.scala
index 43eb7191..f9f9a4ed 100644
--- a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/SpatialIndexExec.scala
+++ b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/SpatialIndexExec.scala
@@ -33,7 +33,7 @@ import org.apache.spark.sql.sedona_sql.execution.SedonaUnaryExecNode
case class SpatialIndexExec(child: SparkPlan,
shape: Expression,
indexType: IndexType,
- radius: Option[Expression] = None)
+ distance: Option[Expression] = None)
extends SedonaUnaryExecNode
with TraitJoinQueryBase
with Logging {
@@ -50,8 +50,8 @@ case class SpatialIndexExec(child: SparkPlan,
val resultRaw = child.execute().asInstanceOf[RDD[UnsafeRow]].coalesce(1)
- val spatialRDD = radius match {
- case Some(radiusExpression) => toCircleRDD(resultRaw, boundShape, BindReferences.bindReference(radiusExpression, child.output))
+ val spatialRDD = distance match {
+ case Some(distanceExpression) => toExpandedEnvelopeRDD(resultRaw, boundShape, BindReferences.bindReference(distanceExpression, child.output))
case None => toSpatialRDD(resultRaw, boundShape)
}
diff --git a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/TraitJoinQueryBase.scala b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/TraitJoinQueryBase.scala
index 46a8ac4a..22d44a4a 100644
--- a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/TraitJoinQueryBase.scala
+++ b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/TraitJoinQueryBase.scala
@@ -18,7 +18,6 @@
*/
package org.apache.spark.sql.sedona_sql.strategy.join
-import org.apache.sedona.common.geometryObjects.Circle
import org.apache.sedona.core.spatialRDD.SpatialRDD
import org.apache.sedona.core.utils.SedonaConf
import org.apache.sedona.sql.utils.GeometrySerializer
@@ -52,15 +51,18 @@ trait TraitJoinQueryBase {
spatialRdd
}
- def toCircleRDD(rdd: RDD[UnsafeRow], shapeExpression: Expression, boundRadius: Expression): SpatialRDD[Geometry] = {
+ def toExpandedEnvelopeRDD(rdd: RDD[UnsafeRow], shapeExpression: Expression, boundRadius: Expression): SpatialRDD[Geometry] = {
val spatialRdd = new SpatialRDD[Geometry]
spatialRdd.setRawSpatialRDD(
rdd
.map { x => {
val shape = GeometrySerializer.deserialize(shapeExpression.eval(x).asInstanceOf[ArrayData])
- val circle = new Circle(shape, boundRadius.eval(x).asInstanceOf[Double])
- circle.setUserData(x.copy)
- circle.asInstanceOf[Geometry]
+ val envelope = shape.getEnvelopeInternal.copy()
+ envelope.expandBy(boundRadius.eval(x).asInstanceOf[Double])
+
+ val expandedEnvelope = shape.getFactory.toGeometry(envelope)
+ expandedEnvelope.setUserData(x.copy)
+ expandedEnvelope
}
}
.toJavaRDD())
diff --git a/sql/src/test/scala/org/apache/sedona/sql/BroadcastIndexJoinSuite.scala b/sql/src/test/scala/org/apache/sedona/sql/BroadcastIndexJoinSuite.scala
index 8395063e..445af754 100644
--- a/sql/src/test/scala/org/apache/sedona/sql/BroadcastIndexJoinSuite.scala
+++ b/sql/src/test/scala/org/apache/sedona/sql/BroadcastIndexJoinSuite.scala
@@ -156,7 +156,7 @@ class BroadcastIndexJoinSuite extends TestBaseScala {
assert(broadcastJoinDf.count() == 0)
}
- it("Passed ST_Distance <= radius in a broadcast join") {
+ it("Passed ST_Distance <= distance in a broadcast join") {
var pointDf1 = buildPointDf
var pointDf2 = buildPointDf
@@ -169,7 +169,7 @@ class BroadcastIndexJoinSuite extends TestBaseScala {
assert(distanceJoinDf.count() == 2998)
}
- it("Passed ST_Distance < radius in a broadcast join") {
+ it("Passed ST_Distance < distance in a broadcast join") {
var pointDf1 = buildPointDf
var pointDf2 = buildPointDf
@@ -182,7 +182,7 @@ class BroadcastIndexJoinSuite extends TestBaseScala {
assert(distanceJoinDf.count() == 2998)
}
- it("Passed ST_Distance radius is bound to first expression") {
+ it("Passed ST_Distance distance is bound to first expression") {
var pointDf1 = buildPointDf.withColumn("radius", two())
var pointDf2 = buildPointDf
@@ -229,5 +229,22 @@ class BroadcastIndexJoinSuite extends TestBaseScala {
assert(broadcastJoinDf.count() == 1000)
sparkSession.conf.set("spark.sql.adaptive.enabled", false)
}
+
+ it("Passed broadcast distance join with LineString") {
+ assert(sparkSession.sql(
+ """
+ |select /*+ BROADCAST(a) */ *
+ |from (select ST_LineFromText('LineString(1 1, 1 3, 3 3)') as geom) a
+ |join (select ST_Point(2.0,2.0) as geom) b
+ |on ST_Distance(a.geom, b.geom) < 0.1
+ |""".stripMargin).isEmpty)
+ assert(sparkSession.sql(
+ """
+ |select /*+ BROADCAST(a) */ *
+ |from (select ST_LineFromText('LineString(1 1, 1 4)') as geom) a
+ |join (select ST_Point(1.0,5.0) as geom) b
+ |on ST_Distance(a.geom, b.geom) < 1.5
+ |""".stripMargin).count() == 1)
+ }
}
}
diff --git a/sql/src/test/scala/org/apache/sedona/sql/SpatialJoinSuite.scala b/sql/src/test/scala/org/apache/sedona/sql/SpatialJoinSuite.scala
index 42da9063..62de4f5b 100644
--- a/sql/src/test/scala/org/apache/sedona/sql/SpatialJoinSuite.scala
+++ b/sql/src/test/scala/org/apache/sedona/sql/SpatialJoinSuite.scala
@@ -128,11 +128,10 @@ class SpatialJoinSuite extends TestBaseScala with TableDrivenPropertyChecks {
case "ST_Touches" => (l: Geometry, r: Geometry) => l.touches(r)
case "ST_Within" => (l: Geometry, r: Geometry) => l.within(r)
case "ST_Distance" =>
- // XXX: ST_Distance has a weird behavior, it is wildly different from `l.distance(r)`.
if (joinCondition.contains("<=")) {
- (l: Geometry, r: Geometry) => new Circle(l, 1.0).intersects(r)
+ (l: Geometry, r: Geometry) => l.distance(r) <= 1.0
} else {
- (l: Geometry, r: Geometry) => new Circle(l, 1.0).covers(r)
+ (l: Geometry, r: Geometry) => l.distance(r) < 1.0
}
}
left.flatMap { case (id, geom) =>
diff --git a/sql/src/test/scala/org/apache/sedona/sql/functionTestScala.scala b/sql/src/test/scala/org/apache/sedona/sql/functionTestScala.scala
index ccbb30f7..89bfb1be 100644
--- a/sql/src/test/scala/org/apache/sedona/sql/functionTestScala.scala
+++ b/sql/src/test/scala/org/apache/sedona/sql/functionTestScala.scala
@@ -134,14 +134,15 @@ class functionTestScala extends TestBaseScala with Matchers with GeometrySample
var polygonDf = sparkSession.sql("select ST_GeomFromWKT(polygontable._c0) as countyshape from polygontable")
polygonDf.createOrReplaceTempView("polygondf")
val polygon = "POLYGON ((110.54671 55.818002, 110.54671 55.143743, 110.940494 55.143743, 110.940494 55.818002, 110.54671 55.818002))"
- val forceXYExpect = "POLYGON ((471596.69167460164 6185916.951191288, 471107.5623640998 6110880.974228167, 496207.109151055 6110788.804712435, 496271.31937046186 6185825.60569904, 471596.69167460164 6185916.951191288))"
+ // Floats don't have an exact decimal representation. String format varies across JVMs. Do an approximate match.
+ val forceXYExpect = "POLYGON \\(\\(471596.69167460\\d* 6185916.95119\\d*, 471107.562364\\d* 6110880.97422\\d*, 496207.10915\\d* 6110788.80471\\d*, 496271.3193704\\d* 6185825.6056\\d*, 471596.6916746\\d* 6185916.95119\\d*\\)\\)"
sparkSession.createDataset(Seq(polygon))
.withColumn("geom", expr("ST_GeomFromWKT(value)"))
.createOrReplaceTempView("df")
val forceXYResult = sparkSession.sql(s"""select ST_Transform(ST_FlipCoordinates(ST_geomFromWKT('$polygon')),'EPSG:4326', 'EPSG:32649', false)""").rdd.map(row => row.getAs[Geometry](0).toString).collect()(0)
- assert(forceXYResult == forceXYExpect)
+ forceXYResult should fullyMatch regex(forceXYExpect)
}
it("Passed ST_transform WKT version"){
@@ -150,7 +151,7 @@ class functionTestScala extends TestBaseScala with Matchers with GeometrySample
var polygonDf = sparkSession.sql("select ST_GeomFromWKT(polygontable._c0) as countyshape from polygontable")
polygonDf.createOrReplaceTempView("polygondf")
val polygon = "POLYGON ((110.54671 55.818002, 110.54671 55.143743, 110.940494 55.143743, 110.940494 55.818002, 110.54671 55.818002))"
- val forceXYExpect = "POLYGON ((471596.69167460164 6185916.951191288, 471107.5623640998 6110880.974228167, 496207.109151055 6110788.804712435, 496271.31937046186 6185825.60569904, 471596.69167460164 6185916.951191288))"
+ val forceXYExpect = "POLYGON \\(\\(471596.6916746\\d* 6185916.95119\\d*, 471107.562364\\d* 6110880.97422\\d*, 496207.10915\\d* 6110788.80471\\d*, 496271.319370\\d* 6185825.6056\\d*, 471596.6916746\\d* 6185916.95119\\d*\\)\\)"
val EPSG_TGT_CRS = CRS.decode("EPSG:32649")
val EPSG_TGT_WKT = EPSG_TGT_CRS.toWKT()
@@ -162,13 +163,13 @@ class functionTestScala extends TestBaseScala with Matchers with GeometrySample
.createOrReplaceTempView("df")
val forceXYResult_TGT_WKT = sparkSession.sql(s"""select ST_Transform(ST_FlipCoordinates(ST_geomFromWKT('$polygon')),'EPSG:4326', '$EPSG_TGT_WKT', false)""").rdd.map(row => row.getAs[Geometry](0).toString).collect()(0)
- assert(forceXYResult_TGT_WKT == forceXYExpect)
+ forceXYResult_TGT_WKT should fullyMatch regex(forceXYExpect)
val forceXYResult_SRC_WKT = sparkSession.sql(s"""select ST_Transform(ST_FlipCoordinates(ST_geomFromWKT('$polygon')),'$EPSG_SRC_WKT', 'EPSG:32649', false)""").rdd.map(row => row.getAs[Geometry](0).toString).collect()(0)
- assert(forceXYResult_SRC_WKT == forceXYExpect)
+ forceXYResult_SRC_WKT should fullyMatch regex(forceXYExpect)
val forceXYResult_SRC_TGT_WKT = sparkSession.sql(s"""select ST_Transform(ST_FlipCoordinates(ST_geomFromWKT('$polygon')),'$EPSG_SRC_WKT', '$EPSG_TGT_WKT', false)""").rdd.map(row => row.getAs[Geometry](0).toString).collect()(0)
- assert(forceXYResult_SRC_TGT_WKT == forceXYExpect)
+ forceXYResult_SRC_TGT_WKT should fullyMatch regex(forceXYExpect)
}
diff --git a/sql/src/test/scala/org/apache/sedona/sql/predicateJoinTestScala.scala b/sql/src/test/scala/org/apache/sedona/sql/predicateJoinTestScala.scala
index 45ec24aa..48684bc5 100644
--- a/sql/src/test/scala/org/apache/sedona/sql/predicateJoinTestScala.scala
+++ b/sql/src/test/scala/org/apache/sedona/sql/predicateJoinTestScala.scala
@@ -21,6 +21,7 @@ package org.apache.sedona.sql
import org.apache.sedona.core.utils.SedonaConf
import org.apache.spark.sql.Row
+import org.apache.spark.sql.execution.ExplainMode
import org.apache.spark.sql.types._
import org.locationtech.jts.geom.Geometry
import org.locationtech.jts.io.WKTWriter
@@ -218,7 +219,7 @@ class predicateJoinTestScala extends TestBaseScala {
assert(distanceJoinDf.count() == 2998)
}
- it("Passed ST_Distance < radius in a join") {
+ it("Passed ST_Distance < distance in a join") {
var pointCsvDF1 = sparkSession.read.format("csv").option("delimiter", ",").option("header", "false").load(csvPointInputLocation)
pointCsvDF1.createOrReplaceTempView("pointtable")
var pointDf1 = sparkSession.sql("select ST_Point(cast(pointtable._c0 as Decimal(24,20)),cast(pointtable._c1 as Decimal(24,20))) as pointshape1 from pointtable")
@@ -234,6 +235,23 @@ class predicateJoinTestScala extends TestBaseScala {
assert(distanceJoinDf.count() == 2998)
}
+ it("Passed ST_Distance < distance with LineString in a join") {
+ assert(sparkSession.sql(
+ """
+ |select *
+ |from (select ST_LineFromText('LineString(1 1, 1 3, 3 3)') as geom) a
+ |join (select ST_Point(2.0,2.0) as geom) b
+ |on ST_Distance(a.geom, b.geom) < 0.1
+ |""".stripMargin).isEmpty)
+ assert(sparkSession.sql(
+ """
+ |select *
+ |from (select ST_LineFromText('LineString(1 1, 1 4)') as geom) a
+ |join (select ST_Point(1.0,5.0) as geom) b
+ |on ST_Distance(a.geom, b.geom) < 1.5
+ |""".stripMargin).count() == 1)
+ }
+
it("Passed ST_Contains in a range and join") {
var polygonCsvDf = sparkSession.read.format("csv").option("delimiter", ",").option("header", "false").load(csvPolygonInputLocation)
polygonCsvDf.createOrReplaceTempView("polygontable")