You are viewing a plain-text version of this content; the canonical link to it was a hyperlink on the original page and is not preserved in this copy.
Posted to commits@sedona.apache.org by ji...@apache.org on 2022/11/10 09:43:19 UTC
[incubator-sedona] branch master updated: [SEDONA-186] Fix ordering of columns in results of queries like SELECT * A JOIN B (#706)
This is an automated email from the ASF dual-hosted git repository.
jiayu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-sedona.git
The following commit(s) were added to refs/heads/master by this push:
new dcba933f [SEDONA-186] Fix ordering of columns in results of queries like SELECT * A JOIN B (#706)
dcba933f is described below
commit dcba933f9343094969bf95fae8844599ca41e3fd
Author: Kristin Cowalcijk <mo...@yeah.net>
AuthorDate: Thu Nov 10 17:43:12 2022 +0800
[SEDONA-186] Fix ordering of columns in results of queries like SELECT * A JOIN B (#706)
---
.../strategy/join/DistanceJoinExec.scala | 14 +++---
.../strategy/join/JoinQueryDetector.scala | 10 ++--
.../sedona_sql/strategy/join/RangeJoinExec.scala | 10 ++--
.../strategy/join/TraitJoinQueryExec.scala | 54 +++++++++-------------
.../org/apache/sedona/sql/SpatialJoinSuite.scala | 41 +++++++++++++++-
5 files changed, 82 insertions(+), 47 deletions(-)
diff --git a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/DistanceJoinExec.scala b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/DistanceJoinExec.scala
index 4df455ef..ee440e46 100644
--- a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/DistanceJoinExec.scala
+++ b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/DistanceJoinExec.scala
@@ -38,18 +38,20 @@ import org.locationtech.jts.geom.Geometry
*
* select * from a join b on ST_Intersects(ST_Envelope(ST_Buffer(a.geom, 1)), b.geom) and ST_Distance(a.geom, b.geom) <= 1
*
- * @param left
- * @param right
- * @param leftShape
- * @param rightShape
+ * @param left left side of the join
+ * @param right right side of the join
+ * @param leftShape expression for the first argument of spatialPredicate
+ * @param rightShape expression for the second argument of spatialPredicate
+ * @param swappedLeftAndRight boolean indicating whether left and right plans were swapped
* @param distance - ST_Distance(left, right) <= distance. Distance can be literal or a computation over 'left'.
- * @param spatialPredicate
- * @param extraCondition
+ * @param spatialPredicate spatial predicate as join condition
+ * @param extraCondition extra join condition other than spatialPredicate
*/
case class DistanceJoinExec(left: SparkPlan,
right: SparkPlan,
leftShape: Expression,
rightShape: Expression,
+ swappedLeftAndRight: Boolean,
distance: Expression,
spatialPredicate: SpatialPredicate,
extraCondition: Option[Expression] = None)
diff --git a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/JoinQueryDetector.scala b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/JoinQueryDetector.scala
index 1e419b17..f5e086c0 100644
--- a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/JoinQueryDetector.scala
+++ b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/JoinQueryDetector.scala
@@ -164,9 +164,9 @@ class JoinQueryDetector(sparkSession: SparkSession) extends Strategy {
val relationship = s"ST_$spatialPredicate"
matchExpressionsToPlans(a, b, left, right) match {
- case Some((planA, planB, _)) =>
+ case Some((planA, planB, swappedLeftAndRight)) =>
logInfo(s"Planning spatial join for $relationship relationship")
- RangeJoinExec(planLater(planA), planLater(planB), a, b, spatialPredicate, extraCondition) :: Nil
+ RangeJoinExec(planLater(planA), planLater(planB), a, b, swappedLeftAndRight, spatialPredicate, extraCondition) :: Nil
case None =>
logInfo(
s"Spatial join for $relationship with arguments not aligned " +
@@ -185,13 +185,13 @@ class JoinQueryDetector(sparkSession: SparkSession) extends Strategy {
val b = children.tail.head
matchExpressionsToPlans(a, b, left, right) match {
- case Some((planA, planB, _)) =>
+ case Some((planA, planB, swappedLeftAndRight)) =>
if (distance.references.isEmpty || matches(distance, planA)) {
logInfo("Planning spatial distance join")
- DistanceJoinExec(planLater(planA), planLater(planB), a, b, distance, spatialPredicate, extraCondition) :: Nil
+ DistanceJoinExec(planLater(planA), planLater(planB), a, b, swappedLeftAndRight, distance, spatialPredicate, extraCondition) :: Nil
} else if (matches(distance, planB)) {
logInfo("Planning spatial distance join")
- DistanceJoinExec(planLater(planB), planLater(planA), b, a, distance, spatialPredicate, extraCondition) :: Nil
+ DistanceJoinExec(planLater(planB), planLater(planA), b, a, swappedLeftAndRight, distance, spatialPredicate, extraCondition) :: Nil
} else {
logInfo(
"Spatial distance join for ST_Distance with non-scalar distance " +
diff --git a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/RangeJoinExec.scala b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/RangeJoinExec.scala
index feffe99d..ada871a6 100644
--- a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/RangeJoinExec.scala
+++ b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/RangeJoinExec.scala
@@ -31,15 +31,17 @@ import org.apache.spark.sql.sedona_sql.execution.SedonaBinaryExecNode
*
* @param left left side of the join
* @param right right side of the join
- * @param leftShape expression for the first argument of ST_Contains or ST_Intersects
- * @param rightShape expression for the second argument of ST_Contains or ST_Intersects
- * @param intersects boolean indicating whether spatial relationship is 'intersects' (true)
- * or 'contains' (false)
+ * @param leftShape expression for the first argument of spatialPredicate
+ * @param rightShape expression for the second argument of spatialPredicate
+ * @param swappedLeftAndRight boolean indicating whether left and right plans were swapped
+ * @param spatialPredicate spatial predicate as join condition
+ * @param extraCondition extra join condition other than spatialPredicate
*/
case class RangeJoinExec(left: SparkPlan,
right: SparkPlan,
leftShape: Expression,
rightShape: Expression,
+ swappedLeftAndRight: Boolean,
spatialPredicate: SpatialPredicate,
extraCondition: Option[Expression] = None)
extends SedonaBinaryExecNode
diff --git a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/TraitJoinQueryExec.scala b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/TraitJoinQueryExec.scala
index ca5418b3..eb912798 100644
--- a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/TraitJoinQueryExec.scala
+++ b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/TraitJoinQueryExec.scala
@@ -33,23 +33,17 @@ import org.locationtech.jts.geom.Geometry
trait TraitJoinQueryExec extends TraitJoinQueryBase {
self: SparkPlan =>
- // Using lazy val to avoid serialization
- @transient private lazy val boundCondition: (InternalRow => Boolean) = {
- if (extraCondition.isDefined) {
- Predicate.create(extraCondition.get, left.output ++ right.output).eval _ // SPARK3 anchor
-// newPredicate(extraCondition.get, left.output ++ right.output).eval _ // SPARK2 anchor
- } else { (r: InternalRow) =>
- true
- }
- }
val left: SparkPlan
val right: SparkPlan
val leftShape: Expression
val rightShape: Expression
+ val swappedLeftAndRight: Boolean
val spatialPredicate: SpatialPredicate
val extraCondition: Option[Expression]
- override def output: Seq[Attribute] = left.output ++ right.output
+ override def output: Seq[Attribute] = {
+ if (!swappedLeftAndRight) left.output ++ right.output else right.output ++ left.output
+ }
override protected def doExecute(): RDD[InternalRow] = {
val boundLeftShape = BindReferences.bindReference(leftShape, left.output)
@@ -58,7 +52,7 @@ trait TraitJoinQueryExec extends TraitJoinQueryBase {
val leftResultsRaw = left.execute().asInstanceOf[RDD[UnsafeRow]]
val rightResultsRaw = right.execute().asInstanceOf[RDD[UnsafeRow]]
- var sedonaConf = new SedonaConf(sparkContext.conf)
+ val sedonaConf = new SedonaConf(sparkContext.conf)
val (leftShapes, rightShapes) =
toSpatialRddPair(leftResultsRaw, boundLeftShape, rightResultsRaw, boundRightShape)
@@ -135,27 +129,25 @@ trait TraitJoinQueryExec extends TraitJoinQueryBase {
logDebug(s"Join result has ${matchesRDD.count()} rows")
matchesRDD.mapPartitions { iter =>
- val filtered =
- if (extraCondition.isDefined) {
- val boundCondition = Predicate.create(extraCondition.get, left.output ++ right.output) // SPARK3 anchor
-// val boundCondition = newPredicate(extraCondition.get, left.output ++ right.output) // SPARK2 anchor
- iter.filter {
- case (l, r) =>
- val leftRow = l.getUserData.asInstanceOf[UnsafeRow]
- val rightRow = r.getUserData.asInstanceOf[UnsafeRow]
- var joiner = GenerateUnsafeRowJoiner.create(left.schema, right.schema)
- boundCondition.eval(joiner.join(leftRow, rightRow))
- }
- } else {
- iter
- }
+ val joinRow = if (!swappedLeftAndRight) {
+ val joiner = GenerateUnsafeRowJoiner.create(left.schema, right.schema)
+ (l: UnsafeRow, r: UnsafeRow) => joiner.join(l, r)
+ } else {
+ val joiner = GenerateUnsafeRowJoiner.create(right.schema, left.schema)
+ (l: UnsafeRow, r: UnsafeRow) => joiner.join(r, l)
+ }
+
+ val joined = iter.map { case (l, r) =>
+ val leftRow = l.getUserData.asInstanceOf[UnsafeRow]
+ val rightRow = r.getUserData.asInstanceOf[UnsafeRow]
+ joinRow(leftRow, rightRow)
+ }
- filtered.map {
- case (l, r) =>
- val leftRow = l.getUserData.asInstanceOf[UnsafeRow]
- val rightRow = r.getUserData.asInstanceOf[UnsafeRow]
- var joiner = GenerateUnsafeRowJoiner.create(left.schema, right.schema)
- joiner.join(leftRow, rightRow)
+ extraCondition match {
+ case Some(condition) =>
+ val boundCondition = Predicate.create(condition, output)
+ joined.filter(row => boundCondition.eval(row))
+ case None => joined
}
}
}
diff --git a/sql/src/test/scala/org/apache/sedona/sql/SpatialJoinSuite.scala b/sql/src/test/scala/org/apache/sedona/sql/SpatialJoinSuite.scala
index 62de4f5b..9ec443ff 100644
--- a/sql/src/test/scala/org/apache/sedona/sql/SpatialJoinSuite.scala
+++ b/sql/src/test/scala/org/apache/sedona/sql/SpatialJoinSuite.scala
@@ -19,7 +19,6 @@
package org.apache.sedona.sql
-import org.apache.sedona.common.geometryObjects.Circle
import org.apache.spark.sql.Column
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.col
@@ -98,15 +97,55 @@ class SpatialJoinSuite extends TestBaseScala with TableDrivenPropertyChecks {
}
}
+ describe("Sedona-SQL Spatial Join Test with SELECT *") {
+ val joinConditions = Table("join condition",
+ "ST_Contains(df1.geom, df2.geom)",
+ "ST_Contains(df2.geom, df1.geom)",
+ "ST_Distance(df1.geom, df2.geom) < 1.0",
+ "ST_Distance(df2.geom, df1.geom) < 1.0"
+ )
+
+ forAll (joinConditions) { joinCondition =>
+ it(s"should SELECT * in join query with $joinCondition produce correct result") {
+ prepareTempViewsForTestData()
+ val resultAll = sparkSession.sql(s"SELECT * FROM df1 JOIN df2 ON $joinCondition").collect()
+ val result = resultAll.map(row => (row.getInt(0), row.getInt(2))).sorted
+ val expected = buildExpectedResult(joinCondition)
+ assert(result.nonEmpty)
+ assert(result === expected)
+ }
+
+ it(s"should SELECT * in join query with $joinCondition produce correct result, broadcast the left side") {
+ prepareTempViewsForTestData()
+ val resultAll = sparkSession.sql(s"SELECT /*+ BROADCAST(df1) */ * FROM df1 JOIN df2 ON $joinCondition").collect()
+ val result = resultAll.map(row => (row.getInt(0), row.getInt(2))).sorted
+ val expected = buildExpectedResult(joinCondition)
+ assert(result.nonEmpty)
+ assert(result === expected)
+ }
+
+ it(s"should SELECT * in join query with $joinCondition produce correct result, broadcast the right side") {
+ prepareTempViewsForTestData()
+ val resultAll = sparkSession.sql(s"SELECT /*+ BROADCAST(df2) */ * FROM df1 JOIN df2 ON $joinCondition").collect()
+ val result = resultAll.map(row => (row.getInt(0), row.getInt(2))).sorted
+ val expected = buildExpectedResult(joinCondition)
+ assert(result.nonEmpty)
+ assert(result === expected)
+ }
+ }
+ }
+
private def prepareTempViewsForTestData(): (DataFrame, DataFrame) = {
val df1 = sparkSession.read.format("csv").option("header", "false").option("delimiter", testDataDelimiter)
.load(spatialJoinLeftInputLocation)
.withColumn("id", col("_c0").cast(IntegerType))
.withColumn("geom", ST_GeomFromText(new Column("_c2")))
+ .select("id", "geom")
val df2 = sparkSession.read.format("csv").option("header", "false").option("delimiter", testDataDelimiter)
.load(spatialJoinRightInputLocation)
.withColumn("id", col("_c0").cast(IntegerType))
.withColumn("geom", ST_GeomFromText(new Column("_c2")))
+ .select("id", "geom")
df1.createOrReplaceTempView("df1")
df2.createOrReplaceTempView("df2")
(df1, df2)