You are viewing a plain-text version of this content. The hyperlink to the canonical version ("here" in the original page) was not preserved in this copy.
Posted to commits@sedona.apache.org by ji...@apache.org on 2023/03/12 20:42:58 UTC

[sedona] branch master updated: [SEDONA-261] Allow distance expression in distance join to reference attributes from the right-side relation when running broadcast join. (#796)

This is an automated email from the ASF dual-hosted git repository.

jiayu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/sedona.git


The following commit(s) were added to refs/heads/master by this push:
     new 7b107f91 [SEDONA-261] Allow distance expression in distance join to reference attributes from the right-side relation when running broadcast join. (#796)
7b107f91 is described below

commit 7b107f91bf3bd2dba155d8c84234a90f9034da53
Author: Kristin Cowalcijk <bo...@wherobots.com>
AuthorDate: Mon Mar 13 04:42:51 2023 +0800

    [SEDONA-261] Allow distance expression in distance join to reference attributes from the right-side relation when running broadcast join. (#796)
---
 .../strategy/join/BroadcastIndexJoinExec.scala     |  3 +--
 .../strategy/join/JoinQueryDetector.scala          | 31 ++++++++++++++++++----
 .../org/apache/sedona/sql/SpatialJoinSuite.scala   | 31 +++++++++++++++-------
 3 files changed, 49 insertions(+), 16 deletions(-)

diff --git a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/BroadcastIndexJoinExec.scala b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/BroadcastIndexJoinExec.scala
index 445362c6..693f4d1c 100644
--- a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/BroadcastIndexJoinExec.scala
+++ b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/BroadcastIndexJoinExec.scala
@@ -257,9 +257,8 @@ case class BroadcastIndexJoinExec(
   }
 
   private def createStreamShapes(streamResultsRaw: RDD[UnsafeRow], boundStreamShape: Expression) = {
-    // If there's a distance and the objects are being broadcast, we need to build the expanded envelope on the window stream side
     distance match {
-      case Some(distanceExpression) if indexBuildSide != windowJoinSide =>
+      case Some(distanceExpression) =>
         streamResultsRaw.map(row => {
           val geom = boundStreamShape.eval(row).asInstanceOf[Array[Byte]]
           if (geom == null) {
diff --git a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/JoinQueryDetector.scala b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/JoinQueryDetector.scala
index 393cbae6..a17eabfd 100644
--- a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/JoinQueryDetector.scala
+++ b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/JoinQueryDetector.scala
@@ -175,6 +175,16 @@ class JoinQueryDetector(sparkSession: SparkSession) extends Strategy {
       None
     }
 
+  private def matchDistanceExpressionToJoinSide(distance: Expression, left: LogicalPlan, right: LogicalPlan): Option[JoinSide] = { // Reports which join side the distance expression is bound to; None if it matches neither side
+    if (distance.references.isEmpty || matches(distance, left)) { // an expression with no attribute references (e.g. a constant) is reported as LeftSide
+      Some(LeftSide)
+    } else if (matches(distance, right)) {
+      Some(RightSide)
+    } else {
+      None // NOTE(review): presumably references attributes from both sides; the planner below throws IllegalArgumentException for this case
+    }
+  }
+
   private def planSpatialJoin(
     left: LogicalPlan,
     right: LogicalPlan,
@@ -277,20 +287,31 @@ class JoinQueryDetector(sparkSession: SparkSession) extends Strategy {
       case (None, _) => s"ST_$spatialPredicate"
     }
 
+    val (distanceOnIndexSide, distanceOnStreamSide) = distance.map { distanceExpr =>
+      matchDistanceExpressionToJoinSide(distanceExpr, left, right) match {
+        case Some(side) =>
+          if (broadcastSide.get == side) (Some(distanceExpr), None)
+          else if (distanceExpr.references.isEmpty) (Some(distanceExpr), None)
+          else (None, Some(distanceExpr))
+        case _ => throw new IllegalArgumentException("Distance expression must be bound to one side of the join")
+      }
+    }.getOrElse((None, None))
+
     matchExpressionsToPlans(a, b, left, right) match {
       case Some((_, _, swapped)) =>
         logInfo(s"Planning spatial join for $relationship relationship")
         val (leftPlan, rightPlan, streamShape, windowSide) = (broadcastSide.get, swapped) match {
           case (LeftSide, false) => // Broadcast the left side, windows on the left
-            (SpatialIndexExec(planLater(left), a, indexType, distance), planLater(right), b, LeftSide)
+            (SpatialIndexExec(planLater(left), a, indexType, distanceOnIndexSide), planLater(right), b, LeftSide)
           case (LeftSide, true) => // Broadcast the left side, objects on the left
-            (SpatialIndexExec(planLater(left), b, indexType), planLater(right), a, RightSide)
+            (SpatialIndexExec(planLater(left), b, indexType, distanceOnIndexSide), planLater(right), a, RightSide)
           case (RightSide, false) => // Broadcast the right side, windows on the left
-            (planLater(left), SpatialIndexExec(planLater(right), b, indexType), a, LeftSide)
+            (planLater(left), SpatialIndexExec(planLater(right), b, indexType, distanceOnIndexSide), a, LeftSide)
           case (RightSide, true) => // Broadcast the right side, objects on the left
-            (planLater(left), SpatialIndexExec(planLater(right), a, indexType, distance), b, RightSide)
+            (planLater(left), SpatialIndexExec(planLater(right), a, indexType, distanceOnIndexSide), b, RightSide)
         }
-        BroadcastIndexJoinExec(leftPlan, rightPlan, streamShape, broadcastSide.get, windowSide, joinType, spatialPredicate, extraCondition, distance) :: Nil
+        BroadcastIndexJoinExec(leftPlan, rightPlan, streamShape, broadcastSide.get, windowSide, joinType,
+          spatialPredicate, extraCondition, distanceOnStreamSide) :: Nil
       case None =>
         logInfo(
           s"Spatial join for $relationship with arguments not aligned " +
diff --git a/sql/src/test/scala/org/apache/sedona/sql/SpatialJoinSuite.scala b/sql/src/test/scala/org/apache/sedona/sql/SpatialJoinSuite.scala
index 9ec443ff..a155ee21 100644
--- a/sql/src/test/scala/org/apache/sedona/sql/SpatialJoinSuite.scala
+++ b/sql/src/test/scala/org/apache/sedona/sql/SpatialJoinSuite.scala
@@ -21,7 +21,7 @@ package org.apache.sedona.sql
 
 import org.apache.spark.sql.Column
 import org.apache.spark.sql.DataFrame
-import org.apache.spark.sql.functions.col
+import org.apache.spark.sql.functions.{col, expr}
 import org.apache.spark.sql.sedona_sql.expressions.st_constructors.ST_GeomFromText
 import org.apache.spark.sql.types.IntegerType
 import org.locationtech.jts.geom.Geometry
@@ -58,7 +58,12 @@ class SpatialJoinSuite extends TestBaseScala with TableDrivenPropertyChecks {
       "ST_Distance(df1.geom, df2.geom) < 1.0",
       "ST_Distance(df1.geom, df2.geom) <= 1.0",
       "ST_Distance(df2.geom, df1.geom) < 1.0",
-      "ST_Distance(df2.geom, df1.geom) <= 1.0"
+      "ST_Distance(df2.geom, df1.geom) <= 1.0",
+
+      "ST_Distance(df1.geom, df2.geom) < df1.dist",
+      "ST_Distance(df1.geom, df2.geom) < df2.dist",
+      "ST_Distance(df2.geom, df1.geom) < df1.dist",
+      "ST_Distance(df2.geom, df1.geom) < df2.dist"
     )
 
     var spatialJoinPartitionSide = "left"
@@ -109,7 +114,7 @@ class SpatialJoinSuite extends TestBaseScala with TableDrivenPropertyChecks {
       it(s"should SELECT * in join query with $joinCondition produce correct result") {
         prepareTempViewsForTestData()
         val resultAll = sparkSession.sql(s"SELECT * FROM df1 JOIN df2 ON $joinCondition").collect()
-        val result = resultAll.map(row => (row.getInt(0), row.getInt(2))).sorted
+        val result = resultAll.map(row => (row.getInt(0), row.getInt(3))).sorted
         val expected = buildExpectedResult(joinCondition)
         assert(result.nonEmpty)
         assert(result === expected)
@@ -118,7 +123,7 @@ class SpatialJoinSuite extends TestBaseScala with TableDrivenPropertyChecks {
       it(s"should SELECT * in join query with $joinCondition produce correct result, broadcast the left side") {
         prepareTempViewsForTestData()
         val resultAll = sparkSession.sql(s"SELECT /*+ BROADCAST(df1) */ * FROM df1 JOIN df2 ON $joinCondition").collect()
-        val result = resultAll.map(row => (row.getInt(0), row.getInt(2))).sorted
+        val result = resultAll.map(row => (row.getInt(0), row.getInt(3))).sorted
         val expected = buildExpectedResult(joinCondition)
         assert(result.nonEmpty)
         assert(result === expected)
@@ -127,7 +132,7 @@ class SpatialJoinSuite extends TestBaseScala with TableDrivenPropertyChecks {
       it(s"should SELECT * in join query with $joinCondition produce correct result, broadcast the right side") {
         prepareTempViewsForTestData()
         val resultAll = sparkSession.sql(s"SELECT /*+ BROADCAST(df2) */ * FROM df1 JOIN df2 ON $joinCondition").collect()
-        val result = resultAll.map(row => (row.getInt(0), row.getInt(2))).sorted
+        val result = resultAll.map(row => (row.getInt(0), row.getInt(3))).sorted
         val expected = buildExpectedResult(joinCondition)
         assert(result.nonEmpty)
         assert(result === expected)
@@ -141,11 +146,13 @@ class SpatialJoinSuite extends TestBaseScala with TableDrivenPropertyChecks {
       .withColumn("id", col("_c0").cast(IntegerType))
       .withColumn("geom", ST_GeomFromText(new Column("_c2")))
       .select("id", "geom")
+      .withColumn("dist", expr("ST_Area(geom)"))
     val df2 = sparkSession.read.format("csv").option("header", "false").option("delimiter", testDataDelimiter)
       .load(spatialJoinRightInputLocation)
       .withColumn("id", col("_c0").cast(IntegerType))
       .withColumn("geom", ST_GeomFromText(new Column("_c2")))
       .select("id", "geom")
+      .withColumn("dist", expr("ST_Area(geom)"))
     df1.createOrReplaceTempView("df1")
     df2.createOrReplaceTempView("df2")
     (df1, df2)
@@ -167,10 +174,16 @@ class SpatialJoinSuite extends TestBaseScala with TableDrivenPropertyChecks {
       case "ST_Touches" => (l: Geometry, r: Geometry) => l.touches(r)
       case "ST_Within" => (l: Geometry, r: Geometry) => l.within(r)
       case "ST_Distance" =>
-        if (joinCondition.contains("<=")) {
-          (l: Geometry, r: Geometry) => l.distance(r) <= 1.0
-        } else {
-          (l: Geometry, r: Geometry) => l.distance(r) < 1.0
+        if (joinCondition contains "df1.dist")
+          (l: Geometry, r: Geometry) => l.distance(r) < (if (!swapped) l.getArea else r.getArea)
+        else if (joinCondition contains "df2.dist")
+          (l: Geometry, r: Geometry) => l.distance(r) < (if (!swapped) r.getArea else l.getArea)
+        else {
+          if (joinCondition.contains("<=")) {
+            (l: Geometry, r: Geometry) => l.distance(r) <= 1.0
+          } else {
+            (l: Geometry, r: Geometry) => l.distance(r) < 1.0
+          }
         }
     }
     left.flatMap { case (id, geom) =>