You are viewing a plain-text version of this content. The canonical version was linked from the original page; that hyperlink is not preserved in this plain-text rendering.
Posted to commits@sedona.apache.org by ji...@apache.org on 2023/03/12 20:42:58 UTC
[sedona] branch master updated: [SEDONA-261] Allow distance expression in distance join to reference attributes from the right-side relation when running broadcast join. (#796)
This is an automated email from the ASF dual-hosted git repository.
jiayu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/sedona.git
The following commit(s) were added to refs/heads/master by this push:
new 7b107f91 [SEDONA-261] Allow distance expression in distance join to reference attributes from the right-side relation when running broadcast join. (#796)
7b107f91 is described below
commit 7b107f91bf3bd2dba155d8c84234a90f9034da53
Author: Kristin Cowalcijk <bo...@wherobots.com>
AuthorDate: Mon Mar 13 04:42:51 2023 +0800
[SEDONA-261] Allow distance expression in distance join to reference attributes from the right-side relation when running broadcast join. (#796)
---
.../strategy/join/BroadcastIndexJoinExec.scala | 3 +--
.../strategy/join/JoinQueryDetector.scala | 31 ++++++++++++++++++----
.../org/apache/sedona/sql/SpatialJoinSuite.scala | 31 +++++++++++++++-------
3 files changed, 49 insertions(+), 16 deletions(-)
diff --git a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/BroadcastIndexJoinExec.scala b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/BroadcastIndexJoinExec.scala
index 445362c6..693f4d1c 100644
--- a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/BroadcastIndexJoinExec.scala
+++ b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/BroadcastIndexJoinExec.scala
@@ -257,9 +257,8 @@ case class BroadcastIndexJoinExec(
}
private def createStreamShapes(streamResultsRaw: RDD[UnsafeRow], boundStreamShape: Expression) = {
- // If there's a distance and the objects are being broadcast, we need to build the expanded envelope on the window stream side
distance match {
- case Some(distanceExpression) if indexBuildSide != windowJoinSide =>
+ case Some(distanceExpression) =>
streamResultsRaw.map(row => {
val geom = boundStreamShape.eval(row).asInstanceOf[Array[Byte]]
if (geom == null) {
diff --git a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/JoinQueryDetector.scala b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/JoinQueryDetector.scala
index 393cbae6..a17eabfd 100644
--- a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/JoinQueryDetector.scala
+++ b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/JoinQueryDetector.scala
@@ -175,6 +175,16 @@ class JoinQueryDetector(sparkSession: SparkSession) extends Strategy {
None
}
+ private def matchDistanceExpressionToJoinSide(distance: Expression, left: LogicalPlan, right: LogicalPlan): Option[JoinSide] = {
+ if (distance.references.isEmpty || matches(distance, left)) {
+ Some(LeftSide)
+ } else if (matches(distance, right)) {
+ Some(RightSide)
+ } else {
+ None
+ }
+ }
+
private def planSpatialJoin(
left: LogicalPlan,
right: LogicalPlan,
@@ -277,20 +287,31 @@ class JoinQueryDetector(sparkSession: SparkSession) extends Strategy {
case (None, _) => s"ST_$spatialPredicate"
}
+ val (distanceOnIndexSide, distanceOnStreamSide) = distance.map { distanceExpr =>
+ matchDistanceExpressionToJoinSide(distanceExpr, left, right) match {
+ case Some(side) =>
+ if (broadcastSide.get == side) (Some(distanceExpr), None)
+ else if (distanceExpr.references.isEmpty) (Some(distanceExpr), None)
+ else (None, Some(distanceExpr))
+ case _ => throw new IllegalArgumentException("Distance expression must be bound to one side of the join")
+ }
+ }.getOrElse((None, None))
+
matchExpressionsToPlans(a, b, left, right) match {
case Some((_, _, swapped)) =>
logInfo(s"Planning spatial join for $relationship relationship")
val (leftPlan, rightPlan, streamShape, windowSide) = (broadcastSide.get, swapped) match {
case (LeftSide, false) => // Broadcast the left side, windows on the left
- (SpatialIndexExec(planLater(left), a, indexType, distance), planLater(right), b, LeftSide)
+ (SpatialIndexExec(planLater(left), a, indexType, distanceOnIndexSide), planLater(right), b, LeftSide)
case (LeftSide, true) => // Broadcast the left side, objects on the left
- (SpatialIndexExec(planLater(left), b, indexType), planLater(right), a, RightSide)
+ (SpatialIndexExec(planLater(left), b, indexType, distanceOnIndexSide), planLater(right), a, RightSide)
case (RightSide, false) => // Broadcast the right side, windows on the left
- (planLater(left), SpatialIndexExec(planLater(right), b, indexType), a, LeftSide)
+ (planLater(left), SpatialIndexExec(planLater(right), b, indexType, distanceOnIndexSide), a, LeftSide)
case (RightSide, true) => // Broadcast the right side, objects on the left
- (planLater(left), SpatialIndexExec(planLater(right), a, indexType, distance), b, RightSide)
+ (planLater(left), SpatialIndexExec(planLater(right), a, indexType, distanceOnIndexSide), b, RightSide)
}
- BroadcastIndexJoinExec(leftPlan, rightPlan, streamShape, broadcastSide.get, windowSide, joinType, spatialPredicate, extraCondition, distance) :: Nil
+ BroadcastIndexJoinExec(leftPlan, rightPlan, streamShape, broadcastSide.get, windowSide, joinType,
+ spatialPredicate, extraCondition, distanceOnStreamSide) :: Nil
case None =>
logInfo(
s"Spatial join for $relationship with arguments not aligned " +
diff --git a/sql/src/test/scala/org/apache/sedona/sql/SpatialJoinSuite.scala b/sql/src/test/scala/org/apache/sedona/sql/SpatialJoinSuite.scala
index 9ec443ff..a155ee21 100644
--- a/sql/src/test/scala/org/apache/sedona/sql/SpatialJoinSuite.scala
+++ b/sql/src/test/scala/org/apache/sedona/sql/SpatialJoinSuite.scala
@@ -21,7 +21,7 @@ package org.apache.sedona.sql
import org.apache.spark.sql.Column
import org.apache.spark.sql.DataFrame
-import org.apache.spark.sql.functions.col
+import org.apache.spark.sql.functions.{col, expr}
import org.apache.spark.sql.sedona_sql.expressions.st_constructors.ST_GeomFromText
import org.apache.spark.sql.types.IntegerType
import org.locationtech.jts.geom.Geometry
@@ -58,7 +58,12 @@ class SpatialJoinSuite extends TestBaseScala with TableDrivenPropertyChecks {
"ST_Distance(df1.geom, df2.geom) < 1.0",
"ST_Distance(df1.geom, df2.geom) <= 1.0",
"ST_Distance(df2.geom, df1.geom) < 1.0",
- "ST_Distance(df2.geom, df1.geom) <= 1.0"
+ "ST_Distance(df2.geom, df1.geom) <= 1.0",
+
+ "ST_Distance(df1.geom, df2.geom) < df1.dist",
+ "ST_Distance(df1.geom, df2.geom) < df2.dist",
+ "ST_Distance(df2.geom, df1.geom) < df1.dist",
+ "ST_Distance(df2.geom, df1.geom) < df2.dist"
)
var spatialJoinPartitionSide = "left"
@@ -109,7 +114,7 @@ class SpatialJoinSuite extends TestBaseScala with TableDrivenPropertyChecks {
it(s"should SELECT * in join query with $joinCondition produce correct result") {
prepareTempViewsForTestData()
val resultAll = sparkSession.sql(s"SELECT * FROM df1 JOIN df2 ON $joinCondition").collect()
- val result = resultAll.map(row => (row.getInt(0), row.getInt(2))).sorted
+ val result = resultAll.map(row => (row.getInt(0), row.getInt(3))).sorted
val expected = buildExpectedResult(joinCondition)
assert(result.nonEmpty)
assert(result === expected)
@@ -118,7 +123,7 @@ class SpatialJoinSuite extends TestBaseScala with TableDrivenPropertyChecks {
it(s"should SELECT * in join query with $joinCondition produce correct result, broadcast the left side") {
prepareTempViewsForTestData()
val resultAll = sparkSession.sql(s"SELECT /*+ BROADCAST(df1) */ * FROM df1 JOIN df2 ON $joinCondition").collect()
- val result = resultAll.map(row => (row.getInt(0), row.getInt(2))).sorted
+ val result = resultAll.map(row => (row.getInt(0), row.getInt(3))).sorted
val expected = buildExpectedResult(joinCondition)
assert(result.nonEmpty)
assert(result === expected)
@@ -127,7 +132,7 @@ class SpatialJoinSuite extends TestBaseScala with TableDrivenPropertyChecks {
it(s"should SELECT * in join query with $joinCondition produce correct result, broadcast the right side") {
prepareTempViewsForTestData()
val resultAll = sparkSession.sql(s"SELECT /*+ BROADCAST(df2) */ * FROM df1 JOIN df2 ON $joinCondition").collect()
- val result = resultAll.map(row => (row.getInt(0), row.getInt(2))).sorted
+ val result = resultAll.map(row => (row.getInt(0), row.getInt(3))).sorted
val expected = buildExpectedResult(joinCondition)
assert(result.nonEmpty)
assert(result === expected)
@@ -141,11 +146,13 @@ class SpatialJoinSuite extends TestBaseScala with TableDrivenPropertyChecks {
.withColumn("id", col("_c0").cast(IntegerType))
.withColumn("geom", ST_GeomFromText(new Column("_c2")))
.select("id", "geom")
+ .withColumn("dist", expr("ST_Area(geom)"))
val df2 = sparkSession.read.format("csv").option("header", "false").option("delimiter", testDataDelimiter)
.load(spatialJoinRightInputLocation)
.withColumn("id", col("_c0").cast(IntegerType))
.withColumn("geom", ST_GeomFromText(new Column("_c2")))
.select("id", "geom")
+ .withColumn("dist", expr("ST_Area(geom)"))
df1.createOrReplaceTempView("df1")
df2.createOrReplaceTempView("df2")
(df1, df2)
@@ -167,10 +174,16 @@ class SpatialJoinSuite extends TestBaseScala with TableDrivenPropertyChecks {
case "ST_Touches" => (l: Geometry, r: Geometry) => l.touches(r)
case "ST_Within" => (l: Geometry, r: Geometry) => l.within(r)
case "ST_Distance" =>
- if (joinCondition.contains("<=")) {
- (l: Geometry, r: Geometry) => l.distance(r) <= 1.0
- } else {
- (l: Geometry, r: Geometry) => l.distance(r) < 1.0
+ if (joinCondition contains "df1.dist")
+ (l: Geometry, r: Geometry) => l.distance(r) < (if (!swapped) l.getArea else r.getArea)
+ else if (joinCondition contains "df2.dist")
+ (l: Geometry, r: Geometry) => l.distance(r) < (if (!swapped) r.getArea else l.getArea)
+ else {
+ if (joinCondition.contains("<=")) {
+ (l: Geometry, r: Geometry) => l.distance(r) <= 1.0
+ } else {
+ (l: Geometry, r: Geometry) => l.distance(r) < 1.0
+ }
}
}
left.flatMap { case (id, geom) =>