You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@sedona.apache.org by ji...@apache.org on 2022/10/21 02:58:17 UTC

[incubator-sedona] branch master updated: [SEDONA-178] Correctness issue in distance joins (#701)

This is an automated email from the ASF dual-hosted git repository.

jiayu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-sedona.git


The following commit(s) were added to refs/heads/master by this push:
     new 8d7e3fbe [SEDONA-178] Correctness issue in distance joins (#701)
8d7e3fbe is described below

commit 8d7e3fbedad3b643e409714c055b654f742cd812
Author: Martin Andersson <u....@gmail.com>
AuthorDate: Fri Oct 21 04:58:11 2022 +0200

    [SEDONA-178] Correctness issue in distance joins (#701)
---
 docs/tutorial/core-python.md                       |  2 +
 docs/tutorial/rdd.md                               |  2 +
 pom.xml                                            |  4 ++
 .../strategy/join/BroadcastIndexJoinExec.scala     | 12 ++--
 .../strategy/join/DistanceJoinExec.scala           | 45 +++++++-------
 .../strategy/join/JoinQueryDetector.scala          | 70 +++++++++++-----------
 .../strategy/join/SpatialIndexExec.scala           |  6 +-
 .../strategy/join/TraitJoinQueryBase.scala         | 12 ++--
 .../sedona/sql/BroadcastIndexJoinSuite.scala       | 23 ++++++-
 .../org/apache/sedona/sql/SpatialJoinSuite.scala   |  5 +-
 .../org/apache/sedona/sql/functionTestScala.scala  | 13 ++--
 .../apache/sedona/sql/predicateJoinTestScala.scala | 20 ++++++-
 12 files changed, 130 insertions(+), 84 deletions(-)

diff --git a/docs/tutorial/core-python.md b/docs/tutorial/core-python.md
index 3d4d1c3b..158071a8 100644
--- a/docs/tutorial/core-python.md
+++ b/docs/tutorial/core-python.md
@@ -450,6 +450,8 @@ Each object on the left is covered/intersected by the object on the right.
 
 ## Write a Distance Join Query
 
+!!!warning "RDD distance joins are only reliable for points. For other geometry types, please use Spatial SQL."
+
 A distance join query takes two spatial RDD assuming that we have two SpatialRDD's:
 <li> object_rdd </li>
 <li> spatial_rdd </li>
diff --git a/docs/tutorial/rdd.md b/docs/tutorial/rdd.md
index 8ea06ee6..ca0f177a 100644
--- a/docs/tutorial/rdd.md
+++ b/docs/tutorial/rdd.md
@@ -466,6 +466,8 @@ Each object on the left is covered/intersected by the object on the right.
 
 ## Write a Distance Join Query
 
+!!!warning "RDD distance joins are only reliable for points. For other geometry types, please use Spatial SQL."
+
 A distance join query takes as input two Spatial RDD A and B and a distance. For each geometry in A, finds the geometries (from B) are within the given distance to it. A and B can be any geometry type and are not necessary to have the same geometry type. The unit of the distance is explained [here](#transform-the-coordinate-reference-system).
 
 Assume you now have two SpatialRDDs (typed or generic). You can use the following code to issue an Distance Join Query on them.
diff --git a/pom.xml b/pom.xml
index 43e8d10b..3e373901 100644
--- a/pom.xml
+++ b/pom.xml
@@ -68,6 +68,10 @@
         <sedona.jackson.version>2.13.3</sedona.jackson.version>
         <hadoop.version>3.2.4</hadoop.version>
         <maven.deploy.skip>false</maven.deploy.skip>
+
+        <!-- Actual scala version will be set by a profile.
+        Setting a default value helps IDEs that can't make sense of profiles. -->
+        <scala.compat.version>2.12</scala.compat.version>
     </properties>
 
     <dependencies>
diff --git a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/BroadcastIndexJoinExec.scala b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/BroadcastIndexJoinExec.scala
index f60b8947..7925c552 100644
--- a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/BroadcastIndexJoinExec.scala
+++ b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/BroadcastIndexJoinExec.scala
@@ -42,7 +42,7 @@ case class BroadcastIndexJoinExec(left: SparkPlan,
                                   windowJoinSide: JoinSide,
                                   spatialPredicate: SpatialPredicate,
                                   extraCondition: Option[Expression] = None,
-                                  radius: Option[Expression] = None)
+                                  distance: Option[Expression] = None)
   extends SedonaBinaryExecNode
     with TraitJoinQueryBase
     with Logging {
@@ -71,7 +71,7 @@ case class BroadcastIndexJoinExec(left: SparkPlan,
     (streamShape, broadcast.shape)
   }
 
-  private val spatialExpression = (radius, spatialPredicate) match {
+  private val spatialExpression = (distance, spatialPredicate) match {
     case (Some(r), SpatialPredicate.INTERSECTS) => s"ST_Distance($windowExpression, $objectExpression) <= $r"
     case (Some(r), _) => s"ST_Distance($windowExpression, $objectExpression) < $r"
     case (None, _) => s"ST_$spatialPredicate($windowExpression, $objectExpression)"
@@ -104,10 +104,10 @@ case class BroadcastIndexJoinExec(left: SparkPlan,
     val boundStreamShape = BindReferences.bindReference(streamShape, streamed.output)
     val streamResultsRaw = streamed.execute().asInstanceOf[RDD[UnsafeRow]]
 
-    // If there's a radius and the objects are being broadcast, we need to build the CircleRDD on the window stream side
-    val streamShapes = radius match {
-      case Some(radiusExpression) if indexBuildSide != windowJoinSide =>
-        toCircleRDD(streamResultsRaw, boundStreamShape, BindReferences.bindReference(radiusExpression, streamed.output))
+    // If there's a distance and the objects are being broadcast, we need to build the expanded envelope on the window stream side
+    val streamShapes = distance match {
+      case Some(distanceExpression) if indexBuildSide != windowJoinSide =>
+        toExpandedEnvelopeRDD(streamResultsRaw, boundStreamShape, BindReferences.bindReference(distanceExpression, streamed.output))
       case _ =>
         toSpatialRDD(streamResultsRaw, boundStreamShape)
     }
diff --git a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/DistanceJoinExec.scala b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/DistanceJoinExec.scala
index ebaf0005..4df455ef 100644
--- a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/DistanceJoinExec.scala
+++ b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/DistanceJoinExec.scala
@@ -18,54 +18,53 @@
  */
 package org.apache.spark.sql.sedona_sql.strategy.join
 
-import org.apache.sedona.common.geometryObjects.Circle
 import org.apache.sedona.core.spatialOperator.SpatialPredicate
 import org.apache.sedona.core.spatialRDD.SpatialRDD
-import org.apache.sedona.sql.utils.GeometrySerializer
 import org.apache.spark.internal.Logging
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.expressions.{BindReferences, Expression, UnsafeRow}
-import org.apache.spark.sql.catalyst.util.ArrayData
 import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.sql.sedona_sql.execution.SedonaBinaryExecNode
 import org.locationtech.jts.geom.Geometry
 
-// ST_Distance(left, right) <= radius
-// radius can be literal or a computation over 'left'
+/**
+ * Distance joins require matching geometries to be in the same partition, even though they do not necessarily overlap.
+ * To create an overlap and guarantee matching geometries end up in the same partition, the left geometry is expanded
+ * before partitioning. It's the logical equivalent of:
+ *
+ * select * from a join b on ST_Distance(a.geom, b.geom) <= 1
+ *
+ * becomes
+ *
+ * select * from a join b on ST_Intersects(ST_Envelope(ST_Buffer(a.geom, 1)), b.geom) and ST_Distance(a.geom, b.geom) <= 1
+ *
+ * @param left
+ * @param right
+ * @param leftShape
+ * @param rightShape
+ * @param distance - ST_Distance(left, right) <= distance. Distance can be literal or a computation over 'left'.
+ * @param spatialPredicate
+ * @param extraCondition
+ */
 case class DistanceJoinExec(left: SparkPlan,
                             right: SparkPlan,
                             leftShape: Expression,
                             rightShape: Expression,
-                            radius: Expression,
+                            distance: Expression,
                             spatialPredicate: SpatialPredicate,
                             extraCondition: Option[Expression] = None)
   extends SedonaBinaryExecNode
     with TraitJoinQueryExec
     with Logging {
 
-  private val boundRadius = BindReferences.bindReference(radius, left.output)
+  private val boundRadius = BindReferences.bindReference(distance, left.output)
 
   override def toSpatialRddPair(
                                  buildRdd: RDD[UnsafeRow],
                                  buildExpr: Expression,
                                  streamedRdd: RDD[UnsafeRow],
                                  streamedExpr: Expression): (SpatialRDD[Geometry], SpatialRDD[Geometry]) =
-    (toCircleRDD(buildRdd, buildExpr), toSpatialRDD(streamedRdd, streamedExpr))
-
-  private def toCircleRDD(rdd: RDD[UnsafeRow], shapeExpression: Expression): SpatialRDD[Geometry] = {
-    val spatialRdd = new SpatialRDD[Geometry]
-    spatialRdd.setRawSpatialRDD(
-      rdd
-        .map { x => {
-          val shape = GeometrySerializer.deserialize(shapeExpression.eval(x).asInstanceOf[ArrayData])
-          val circle = new Circle(shape, boundRadius.eval(x).asInstanceOf[Double])
-          circle.setUserData(x.copy)
-          circle.asInstanceOf[Geometry]
-        }
-        }
-        .toJavaRDD())
-    spatialRdd
-  }
+    (toExpandedEnvelopeRDD(buildRdd, buildExpr, boundRadius), toSpatialRDD(streamedRdd, streamedExpr))
 
   protected def withNewChildrenInternal(newLeft: SparkPlan, newRight: SparkPlan): SparkPlan = {
     copy(left = newLeft, right = newRight)
diff --git a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/JoinQueryDetector.scala b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/JoinQueryDetector.scala
index 6a020e33..1e419b17 100644
--- a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/JoinQueryDetector.scala
+++ b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/JoinQueryDetector.scala
@@ -30,13 +30,13 @@ import org.apache.spark.sql.sedona_sql.expressions._
 
 
 case class JoinQueryDetection(
-  left: LogicalPlan,
-  right: LogicalPlan,
-  leftShape: Expression,
-  rightShape: Expression,
-  spatialPredicate: SpatialPredicate,
-  extraCondition: Option[Expression] = None,
-  radius: Option[Expression] = None
+                               left: LogicalPlan,
+                               right: LogicalPlan,
+                               leftShape: Expression,
+                               rightShape: Expression,
+                               spatialPredicate: SpatialPredicate,
+                               extraCondition: Option[Expression] = None,
+                               distance: Option[Expression] = None
 )
 
 /**
@@ -92,18 +92,20 @@ class JoinQueryDetector(sparkSession: SparkSession) extends Strategy {
           getJoinDetection(left, right, predicate, Some(extraCondition))
         case Some(And(extraCondition, predicate: ST_Predicate)) =>
           getJoinDetection(left, right, predicate, Some(extraCondition))
-        case Some(LessThanOrEqual(ST_Distance(Seq(leftShape, rightShape)), radius)) =>
-          Some(JoinQueryDetection(left, right, leftShape, rightShape, SpatialPredicate.INTERSECTS, None, Some(radius)))
-        case Some(And(LessThanOrEqual(ST_Distance(Seq(leftShape, rightShape)), radius), extraCondition)) =>
-          Some(JoinQueryDetection(left, right, leftShape, rightShape, SpatialPredicate.INTERSECTS, Some(extraCondition), Some(radius)))
-        case Some(And(extraCondition, LessThanOrEqual(ST_Distance(Seq(leftShape, rightShape)), radius))) =>
-          Some(JoinQueryDetection(left, right, leftShape, rightShape, SpatialPredicate.INTERSECTS, Some(extraCondition), Some(radius)))
-        case Some(LessThan(ST_Distance(Seq(leftShape, rightShape)), radius)) =>
-          Some(JoinQueryDetection(left, right, leftShape, rightShape, SpatialPredicate.COVERS, None, Some(radius)))
-        case Some(And(LessThan(ST_Distance(Seq(leftShape, rightShape)), radius), extraCondition)) =>
-          Some(JoinQueryDetection(left, right, leftShape, rightShape, SpatialPredicate.COVERS, Some(extraCondition), Some(radius)))
-        case Some(And(extraCondition, LessThan(ST_Distance(Seq(leftShape, rightShape)), radius))) =>
-          Some(JoinQueryDetection(left, right, leftShape, rightShape, SpatialPredicate.COVERS, Some(extraCondition), Some(radius)))
+
+        // For distance joins we execute the actual predicate (condition) and not only extraConditions.
+        case Some(LessThanOrEqual(ST_Distance(Seq(leftShape, rightShape)), distance)) =>
+          Some(JoinQueryDetection(left, right, leftShape, rightShape, SpatialPredicate.INTERSECTS, condition, Some(distance)))
+        case Some(And(LessThanOrEqual(ST_Distance(Seq(leftShape, rightShape)), distance), _)) =>
+          Some(JoinQueryDetection(left, right, leftShape, rightShape, SpatialPredicate.INTERSECTS, condition, Some(distance)))
+        case Some(And(_, LessThanOrEqual(ST_Distance(Seq(leftShape, rightShape)), distance))) =>
+          Some(JoinQueryDetection(left, right, leftShape, rightShape, SpatialPredicate.INTERSECTS, condition, Some(distance)))
+        case Some(LessThan(ST_Distance(Seq(leftShape, rightShape)), distance)) =>
+          Some(JoinQueryDetection(left, right, leftShape, rightShape, SpatialPredicate.INTERSECTS, condition, Some(distance)))
+        case Some(And(LessThan(ST_Distance(Seq(leftShape, rightShape)), distance), _)) =>
+          Some(JoinQueryDetection(left, right, leftShape, rightShape, SpatialPredicate.INTERSECTS, condition, Some(distance)))
+        case Some(And(_, LessThan(ST_Distance(Seq(leftShape, rightShape)), distance))) =>
+          Some(JoinQueryDetection(left, right, leftShape, rightShape, SpatialPredicate.INTERSECTS, condition, Some(distance)))
         case _ =>
           None
       }
@@ -112,8 +114,8 @@ class JoinQueryDetector(sparkSession: SparkSession) extends Strategy {
 
       if ((broadcastLeft || broadcastRight) && sedonaConf.getUseIndex) {
         queryDetection match {
-          case Some(JoinQueryDetection(left, right, leftShape, rightShape, spatialPredicate, extraCondition, radius)) =>
-            planBroadcastJoin(left, right, Seq(leftShape, rightShape), spatialPredicate, sedonaConf.getIndexType, broadcastLeft, extraCondition, radius)
+          case Some(JoinQueryDetection(left, right, leftShape, rightShape, spatialPredicate, extraCondition, distance)) =>
+            planBroadcastJoin(left, right, Seq(leftShape, rightShape), spatialPredicate, sedonaConf.getIndexType, broadcastLeft, extraCondition, distance)
           case _ =>
             Nil
         }
@@ -121,8 +123,8 @@ class JoinQueryDetector(sparkSession: SparkSession) extends Strategy {
         queryDetection match {
           case Some(JoinQueryDetection(left, right, leftShape, rightShape, spatialPredicate, extraCondition, None)) =>
             planSpatialJoin(left, right, Seq(leftShape, rightShape), spatialPredicate, extraCondition)
-          case Some(JoinQueryDetection(left, right, leftShape, rightShape, spatialPredicate, extraCondition, Some(radius))) =>
-            planDistanceJoin(left, right, Seq(leftShape, rightShape), radius, spatialPredicate, extraCondition)
+          case Some(JoinQueryDetection(left, right, leftShape, rightShape, spatialPredicate, extraCondition, Some(distance))) =>
+            planDistanceJoin(left, right, Seq(leftShape, rightShape), distance, spatialPredicate, extraCondition)
           case None => 
             Nil
         }
@@ -176,7 +178,7 @@ class JoinQueryDetector(sparkSession: SparkSession) extends Strategy {
   private def planDistanceJoin(left: LogicalPlan,
                                right: LogicalPlan,
                                children: Seq[Expression],
-                               radius: Expression,
+                               distance: Expression,
                                spatialPredicate: SpatialPredicate,
                                extraCondition: Option[Expression] = None): Seq[SparkPlan] = {
     val a = children.head
@@ -184,15 +186,15 @@ class JoinQueryDetector(sparkSession: SparkSession) extends Strategy {
 
     matchExpressionsToPlans(a, b, left, right) match {
       case Some((planA, planB, _)) =>
-        if (radius.references.isEmpty || matches(radius, planA)) {
+        if (distance.references.isEmpty || matches(distance, planA)) {
           logInfo("Planning spatial distance join")
-          DistanceJoinExec(planLater(planA), planLater(planB), a, b, radius, spatialPredicate, extraCondition) :: Nil
-        } else if (matches(radius, planB)) {
+          DistanceJoinExec(planLater(planA), planLater(planB), a, b, distance, spatialPredicate, extraCondition) :: Nil
+        } else if (matches(distance, planB)) {
           logInfo("Planning spatial distance join")
-          DistanceJoinExec(planLater(planB), planLater(planA), b, a, radius, spatialPredicate, extraCondition) :: Nil
+          DistanceJoinExec(planLater(planB), planLater(planA), b, a, distance, spatialPredicate, extraCondition) :: Nil
         } else {
           logInfo(
-            "Spatial distance join for ST_Distance with non-scalar radius " +
+            "Spatial distance join for ST_Distance with non-scalar distance " +
               "that is not a computation over just one side of the join is not supported")
           Nil
         }
@@ -211,11 +213,11 @@ class JoinQueryDetector(sparkSession: SparkSession) extends Strategy {
                                 indexType: IndexType,
                                 broadcastLeft: Boolean,
                                 extraCondition: Option[Expression],
-                                radius: Option[Expression]): Seq[SparkPlan] = {
+                                distance: Option[Expression]): Seq[SparkPlan] = {
     val a = children.head
     val b = children.tail.head
 
-    val relationship = (radius, spatialPredicate) match {
+    val relationship = (distance, spatialPredicate) match {
       case (Some(_), SpatialPredicate.INTERSECTS) => "ST_Distance <="
       case (Some(_), _) => "ST_Distance <"
       case (None, _) => s"ST_$spatialPredicate"
@@ -227,15 +229,15 @@ class JoinQueryDetector(sparkSession: SparkSession) extends Strategy {
         val broadcastSide = if (broadcastLeft) LeftSide else RightSide
         val (leftPlan, rightPlan, streamShape, windowSide) = (broadcastSide, swapped) match {
           case (LeftSide, false) => // Broadcast the left side, windows on the left
-            (SpatialIndexExec(planLater(left), a, indexType, radius), planLater(right), b, LeftSide)
+            (SpatialIndexExec(planLater(left), a, indexType, distance), planLater(right), b, LeftSide)
           case (LeftSide, true) => // Broadcast the left side, objects on the left
             (SpatialIndexExec(planLater(left), b, indexType), planLater(right), a, RightSide)
           case (RightSide, false) => // Broadcast the right side, windows on the left
             (planLater(left), SpatialIndexExec(planLater(right), b, indexType), a, LeftSide)
           case (RightSide, true) => // Broadcast the right side, objects on the left
-            (planLater(left), SpatialIndexExec(planLater(right), a, indexType, radius), b, RightSide)
+            (planLater(left), SpatialIndexExec(planLater(right), a, indexType, distance), b, RightSide)
         }
-        BroadcastIndexJoinExec(leftPlan, rightPlan, streamShape, broadcastSide, windowSide, spatialPredicate, extraCondition, radius) :: Nil
+        BroadcastIndexJoinExec(leftPlan, rightPlan, streamShape, broadcastSide, windowSide, spatialPredicate, extraCondition, distance) :: Nil
       case None =>
         logInfo(
           s"Spatial join for $relationship with arguments not aligned " +
diff --git a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/SpatialIndexExec.scala b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/SpatialIndexExec.scala
index 43eb7191..f9f9a4ed 100644
--- a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/SpatialIndexExec.scala
+++ b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/SpatialIndexExec.scala
@@ -33,7 +33,7 @@ import org.apache.spark.sql.sedona_sql.execution.SedonaUnaryExecNode
 case class SpatialIndexExec(child: SparkPlan,
                             shape: Expression,
                             indexType: IndexType,
-                            radius: Option[Expression] = None)
+                            distance: Option[Expression] = None)
   extends SedonaUnaryExecNode
     with TraitJoinQueryBase
     with Logging {
@@ -50,8 +50,8 @@ case class SpatialIndexExec(child: SparkPlan,
 
     val resultRaw = child.execute().asInstanceOf[RDD[UnsafeRow]].coalesce(1)
 
-    val spatialRDD = radius match {
-      case Some(radiusExpression) => toCircleRDD(resultRaw, boundShape, BindReferences.bindReference(radiusExpression, child.output))
+    val spatialRDD = distance match {
+      case Some(distanceExpression) => toExpandedEnvelopeRDD(resultRaw, boundShape, BindReferences.bindReference(distanceExpression, child.output))
       case None => toSpatialRDD(resultRaw, boundShape)
     }
 
diff --git a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/TraitJoinQueryBase.scala b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/TraitJoinQueryBase.scala
index 46a8ac4a..22d44a4a 100644
--- a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/TraitJoinQueryBase.scala
+++ b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/TraitJoinQueryBase.scala
@@ -18,7 +18,6 @@
  */
 package org.apache.spark.sql.sedona_sql.strategy.join
 
-import org.apache.sedona.common.geometryObjects.Circle
 import org.apache.sedona.core.spatialRDD.SpatialRDD
 import org.apache.sedona.core.utils.SedonaConf
 import org.apache.sedona.sql.utils.GeometrySerializer
@@ -52,15 +51,18 @@ trait TraitJoinQueryBase {
     spatialRdd
   }
 
-  def toCircleRDD(rdd: RDD[UnsafeRow], shapeExpression: Expression, boundRadius: Expression): SpatialRDD[Geometry] = {
+  def toExpandedEnvelopeRDD(rdd: RDD[UnsafeRow], shapeExpression: Expression, boundRadius: Expression): SpatialRDD[Geometry] = {
     val spatialRdd = new SpatialRDD[Geometry]
     spatialRdd.setRawSpatialRDD(
       rdd
         .map { x => {
           val shape = GeometrySerializer.deserialize(shapeExpression.eval(x).asInstanceOf[ArrayData])
-          val circle = new Circle(shape, boundRadius.eval(x).asInstanceOf[Double])
-          circle.setUserData(x.copy)
-          circle.asInstanceOf[Geometry]
+          val envelope = shape.getEnvelopeInternal.copy()
+          envelope.expandBy(boundRadius.eval(x).asInstanceOf[Double])
+
+          val expandedEnvelope = shape.getFactory.toGeometry(envelope)
+          expandedEnvelope.setUserData(x.copy)
+          expandedEnvelope
         }
         }
         .toJavaRDD())
diff --git a/sql/src/test/scala/org/apache/sedona/sql/BroadcastIndexJoinSuite.scala b/sql/src/test/scala/org/apache/sedona/sql/BroadcastIndexJoinSuite.scala
index 8395063e..445af754 100644
--- a/sql/src/test/scala/org/apache/sedona/sql/BroadcastIndexJoinSuite.scala
+++ b/sql/src/test/scala/org/apache/sedona/sql/BroadcastIndexJoinSuite.scala
@@ -156,7 +156,7 @@ class BroadcastIndexJoinSuite extends TestBaseScala {
       assert(broadcastJoinDf.count() == 0)
     }
 
-    it("Passed ST_Distance <= radius in a broadcast join") {
+    it("Passed ST_Distance <= distance in a broadcast join") {
       var pointDf1 = buildPointDf
       var pointDf2 = buildPointDf
 
@@ -169,7 +169,7 @@ class BroadcastIndexJoinSuite extends TestBaseScala {
       assert(distanceJoinDf.count() == 2998)
     }
 
-    it("Passed ST_Distance < radius in a broadcast join") {
+    it("Passed ST_Distance < distance in a broadcast join") {
       var pointDf1 = buildPointDf
       var pointDf2 = buildPointDf
 
@@ -182,7 +182,7 @@ class BroadcastIndexJoinSuite extends TestBaseScala {
       assert(distanceJoinDf.count() == 2998)
     }
 
-    it("Passed ST_Distance radius is bound to first expression") {
+    it("Passed ST_Distance distance is bound to first expression") {
       var pointDf1 = buildPointDf.withColumn("radius", two())
       var pointDf2 = buildPointDf
 
@@ -229,5 +229,22 @@ class BroadcastIndexJoinSuite extends TestBaseScala {
       assert(broadcastJoinDf.count() == 1000)
       sparkSession.conf.set("spark.sql.adaptive.enabled", false)
     }
+
+    it("Passed broadcast distance join with LineString") {
+      assert(sparkSession.sql(
+        """
+          |select /*+ BROADCAST(a) */ *
+          |from (select ST_LineFromText('LineString(1 1, 1 3, 3 3)') as geom) a
+          |join (select ST_Point(2.0,2.0) as geom) b
+          |on ST_Distance(a.geom, b.geom) < 0.1
+          |""".stripMargin).isEmpty)
+      assert(sparkSession.sql(
+        """
+          |select /*+ BROADCAST(a) */ *
+          |from (select ST_LineFromText('LineString(1 1, 1 4)') as geom) a
+          |join (select ST_Point(1.0,5.0) as geom) b
+          |on ST_Distance(a.geom, b.geom) < 1.5
+          |""".stripMargin).count() == 1)
+    }
   }
 }
diff --git a/sql/src/test/scala/org/apache/sedona/sql/SpatialJoinSuite.scala b/sql/src/test/scala/org/apache/sedona/sql/SpatialJoinSuite.scala
index 42da9063..62de4f5b 100644
--- a/sql/src/test/scala/org/apache/sedona/sql/SpatialJoinSuite.scala
+++ b/sql/src/test/scala/org/apache/sedona/sql/SpatialJoinSuite.scala
@@ -128,11 +128,10 @@ class SpatialJoinSuite extends TestBaseScala with TableDrivenPropertyChecks {
       case "ST_Touches" => (l: Geometry, r: Geometry) => l.touches(r)
       case "ST_Within" => (l: Geometry, r: Geometry) => l.within(r)
       case "ST_Distance" =>
-        // XXX: ST_Distance has a weird behavior, it is wildly different from `l.distance(r)`.
         if (joinCondition.contains("<=")) {
-          (l: Geometry, r: Geometry) => new Circle(l, 1.0).intersects(r)
+          (l: Geometry, r: Geometry) => l.distance(r) <= 1.0
         } else {
-          (l: Geometry, r: Geometry) => new Circle(l, 1.0).covers(r)
+          (l: Geometry, r: Geometry) => l.distance(r) < 1.0
         }
     }
     left.flatMap { case (id, geom) =>
diff --git a/sql/src/test/scala/org/apache/sedona/sql/functionTestScala.scala b/sql/src/test/scala/org/apache/sedona/sql/functionTestScala.scala
index ccbb30f7..89bfb1be 100644
--- a/sql/src/test/scala/org/apache/sedona/sql/functionTestScala.scala
+++ b/sql/src/test/scala/org/apache/sedona/sql/functionTestScala.scala
@@ -134,14 +134,15 @@ class functionTestScala extends TestBaseScala with Matchers with GeometrySample
       var polygonDf = sparkSession.sql("select ST_GeomFromWKT(polygontable._c0) as countyshape from polygontable")
       polygonDf.createOrReplaceTempView("polygondf")
       val polygon = "POLYGON ((110.54671 55.818002, 110.54671 55.143743, 110.940494 55.143743, 110.940494 55.818002, 110.54671 55.818002))"
-      val forceXYExpect = "POLYGON ((471596.69167460164 6185916.951191288, 471107.5623640998 6110880.974228167, 496207.109151055 6110788.804712435, 496271.31937046186 6185825.60569904, 471596.69167460164 6185916.951191288))"
+      // Doubles have no exact decimal representation and their string format varies across JVMs, so do an approximate (regex) match.
+      val forceXYExpect = "POLYGON \\(\\(471596.69167460\\d* 6185916.95119\\d*, 471107.562364\\d* 6110880.97422\\d*, 496207.10915\\d* 6110788.80471\\d*, 496271.3193704\\d* 6185825.6056\\d*, 471596.6916746\\d* 6185916.95119\\d*\\)\\)"
 
       sparkSession.createDataset(Seq(polygon))
         .withColumn("geom", expr("ST_GeomFromWKT(value)"))
         .createOrReplaceTempView("df")
 
       val forceXYResult = sparkSession.sql(s"""select ST_Transform(ST_FlipCoordinates(ST_geomFromWKT('$polygon')),'EPSG:4326', 'EPSG:32649', false)""").rdd.map(row => row.getAs[Geometry](0).toString).collect()(0)
-      assert(forceXYResult == forceXYExpect)
+      forceXYResult should fullyMatch regex(forceXYExpect)
     }
 
     it("Passed ST_transform WKT version"){
@@ -150,7 +151,7 @@ class functionTestScala extends TestBaseScala with Matchers with GeometrySample
       var polygonDf = sparkSession.sql("select ST_GeomFromWKT(polygontable._c0) as countyshape from polygontable")
       polygonDf.createOrReplaceTempView("polygondf")
       val polygon = "POLYGON ((110.54671 55.818002, 110.54671 55.143743, 110.940494 55.143743, 110.940494 55.818002, 110.54671 55.818002))"
-      val forceXYExpect = "POLYGON ((471596.69167460164 6185916.951191288, 471107.5623640998 6110880.974228167, 496207.109151055 6110788.804712435, 496271.31937046186 6185825.60569904, 471596.69167460164 6185916.951191288))"
+      val forceXYExpect = "POLYGON \\(\\(471596.6916746\\d* 6185916.95119\\d*, 471107.562364\\d* 6110880.97422\\d*, 496207.10915\\d* 6110788.80471\\d*, 496271.319370\\d* 6185825.6056\\d*, 471596.6916746\\d* 6185916.95119\\d*\\)\\)"
 
       val EPSG_TGT_CRS = CRS.decode("EPSG:32649")
       val EPSG_TGT_WKT = EPSG_TGT_CRS.toWKT()
@@ -162,13 +163,13 @@ class functionTestScala extends TestBaseScala with Matchers with GeometrySample
         .createOrReplaceTempView("df")
 
       val forceXYResult_TGT_WKT = sparkSession.sql(s"""select ST_Transform(ST_FlipCoordinates(ST_geomFromWKT('$polygon')),'EPSG:4326', '$EPSG_TGT_WKT', false)""").rdd.map(row => row.getAs[Geometry](0).toString).collect()(0)
-      assert(forceXYResult_TGT_WKT == forceXYExpect)
+      forceXYResult_TGT_WKT should fullyMatch regex(forceXYExpect)
 
       val forceXYResult_SRC_WKT = sparkSession.sql(s"""select ST_Transform(ST_FlipCoordinates(ST_geomFromWKT('$polygon')),'$EPSG_SRC_WKT', 'EPSG:32649', false)""").rdd.map(row => row.getAs[Geometry](0).toString).collect()(0)
-      assert(forceXYResult_SRC_WKT == forceXYExpect)
+      forceXYResult_SRC_WKT should fullyMatch regex(forceXYExpect)
 
       val forceXYResult_SRC_TGT_WKT = sparkSession.sql(s"""select ST_Transform(ST_FlipCoordinates(ST_geomFromWKT('$polygon')),'$EPSG_SRC_WKT', '$EPSG_TGT_WKT', false)""").rdd.map(row => row.getAs[Geometry](0).toString).collect()(0)
-      assert(forceXYResult_SRC_TGT_WKT == forceXYExpect)
+      forceXYResult_SRC_TGT_WKT should fullyMatch regex(forceXYExpect)
 
     }
 
diff --git a/sql/src/test/scala/org/apache/sedona/sql/predicateJoinTestScala.scala b/sql/src/test/scala/org/apache/sedona/sql/predicateJoinTestScala.scala
index 45ec24aa..48684bc5 100644
--- a/sql/src/test/scala/org/apache/sedona/sql/predicateJoinTestScala.scala
+++ b/sql/src/test/scala/org/apache/sedona/sql/predicateJoinTestScala.scala
@@ -21,6 +21,7 @@ package org.apache.sedona.sql
 
 import org.apache.sedona.core.utils.SedonaConf
 import org.apache.spark.sql.Row
+import org.apache.spark.sql.execution.ExplainMode
 import org.apache.spark.sql.types._
 import org.locationtech.jts.geom.Geometry
 import org.locationtech.jts.io.WKTWriter
@@ -218,7 +219,7 @@ class predicateJoinTestScala extends TestBaseScala {
       assert(distanceJoinDf.count() == 2998)
     }
 
-    it("Passed ST_Distance < radius in a join") {
+    it("Passed ST_Distance < distance in a join") {
       var pointCsvDF1 = sparkSession.read.format("csv").option("delimiter", ",").option("header", "false").load(csvPointInputLocation)
       pointCsvDF1.createOrReplaceTempView("pointtable")
       var pointDf1 = sparkSession.sql("select ST_Point(cast(pointtable._c0 as Decimal(24,20)),cast(pointtable._c1 as Decimal(24,20))) as pointshape1 from pointtable")
@@ -234,6 +235,23 @@ class predicateJoinTestScala extends TestBaseScala {
       assert(distanceJoinDf.count() == 2998)
     }
 
+    it("Passed ST_Distance < distance with LineString in a join") {
+      assert(sparkSession.sql(
+        """
+          |select *
+          |from (select ST_LineFromText('LineString(1 1, 1 3, 3 3)') as geom) a
+          |join (select ST_Point(2.0,2.0) as geom) b
+          |on ST_Distance(a.geom, b.geom) < 0.1
+          |""".stripMargin).isEmpty)
+      assert(sparkSession.sql(
+        """
+          |select *
+          |from (select ST_LineFromText('LineString(1 1, 1 4)') as geom) a
+          |join (select ST_Point(1.0,5.0) as geom) b
+          |on ST_Distance(a.geom, b.geom) < 1.5
+          |""".stripMargin).count() == 1)
+    }
+
     it("Passed ST_Contains in a range and join") {
       var polygonCsvDf = sparkSession.read.format("csv").option("delimiter", ",").option("header", "false").load(csvPolygonInputLocation)
       polygonCsvDf.createOrReplaceTempView("polygontable")