You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@sedona.apache.org by ji...@apache.org on 2022/11/10 09:43:19 UTC

[incubator-sedona] branch master updated: [SEDONA-186] Fix ordering of columns in results of queries like SELECT * A JOIN B (#706)

This is an automated email from the ASF dual-hosted git repository.

jiayu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-sedona.git


The following commit(s) were added to refs/heads/master by this push:
     new dcba933f [SEDONA-186] Fix ordering of columns in results of queries like SELECT * A JOIN B (#706)
dcba933f is described below

commit dcba933f9343094969bf95fae8844599ca41e3fd
Author: Kristin Cowalcijk <mo...@yeah.net>
AuthorDate: Thu Nov 10 17:43:12 2022 +0800

    [SEDONA-186] Fix ordering of columns in results of queries like SELECT * A JOIN B (#706)
---
 .../strategy/join/DistanceJoinExec.scala           | 14 +++---
 .../strategy/join/JoinQueryDetector.scala          | 10 ++--
 .../sedona_sql/strategy/join/RangeJoinExec.scala   | 10 ++--
 .../strategy/join/TraitJoinQueryExec.scala         | 54 +++++++++-------------
 .../org/apache/sedona/sql/SpatialJoinSuite.scala   | 41 +++++++++++++++-
 5 files changed, 82 insertions(+), 47 deletions(-)

diff --git a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/DistanceJoinExec.scala b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/DistanceJoinExec.scala
index 4df455ef..ee440e46 100644
--- a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/DistanceJoinExec.scala
+++ b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/DistanceJoinExec.scala
@@ -38,18 +38,20 @@ import org.locationtech.jts.geom.Geometry
  *
  * select * from a join b on ST_Intersects(ST_Envelope(ST_Buffer(a.geom, 1)), b.geom) and ST_Distance(a.geom, b.geom) <= 1
  *
- * @param left
- * @param right
- * @param leftShape
- * @param rightShape
+ * @param left left side of the join
+ * @param right right side of the join
+ * @param leftShape expression for the first argument of spatialPredicate
+ * @param rightShape expression for the second argument of spatialPredicate
+ * @param swappedLeftAndRight boolean indicating whether left and right plans were swapped
  * @param distance - ST_Distance(left, right) <= distance. Distance can be literal or a computation over 'left'.
- * @param spatialPredicate
- * @param extraCondition
+ * @param spatialPredicate spatial predicate as join condition
+ * @param extraCondition extra join condition other than spatialPredicate
  */
 case class DistanceJoinExec(left: SparkPlan,
                             right: SparkPlan,
                             leftShape: Expression,
                             rightShape: Expression,
+                            swappedLeftAndRight: Boolean,
                             distance: Expression,
                             spatialPredicate: SpatialPredicate,
                             extraCondition: Option[Expression] = None)
diff --git a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/JoinQueryDetector.scala b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/JoinQueryDetector.scala
index 1e419b17..f5e086c0 100644
--- a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/JoinQueryDetector.scala
+++ b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/JoinQueryDetector.scala
@@ -164,9 +164,9 @@ class JoinQueryDetector(sparkSession: SparkSession) extends Strategy {
     val relationship = s"ST_$spatialPredicate"
 
     matchExpressionsToPlans(a, b, left, right) match {
-      case Some((planA, planB, _)) =>
+      case Some((planA, planB, swappedLeftAndRight)) =>
         logInfo(s"Planning spatial join for $relationship relationship")
-        RangeJoinExec(planLater(planA), planLater(planB), a, b, spatialPredicate, extraCondition) :: Nil
+        RangeJoinExec(planLater(planA), planLater(planB), a, b, swappedLeftAndRight, spatialPredicate, extraCondition) :: Nil
       case None =>
         logInfo(
           s"Spatial join for $relationship with arguments not aligned " +
@@ -185,13 +185,13 @@ class JoinQueryDetector(sparkSession: SparkSession) extends Strategy {
     val b = children.tail.head
 
     matchExpressionsToPlans(a, b, left, right) match {
-      case Some((planA, planB, _)) =>
+      case Some((planA, planB, swappedLeftAndRight)) =>
         if (distance.references.isEmpty || matches(distance, planA)) {
           logInfo("Planning spatial distance join")
-          DistanceJoinExec(planLater(planA), planLater(planB), a, b, distance, spatialPredicate, extraCondition) :: Nil
+          DistanceJoinExec(planLater(planA), planLater(planB), a, b, swappedLeftAndRight, distance, spatialPredicate, extraCondition) :: Nil
         } else if (matches(distance, planB)) {
           logInfo("Planning spatial distance join")
-          DistanceJoinExec(planLater(planB), planLater(planA), b, a, distance, spatialPredicate, extraCondition) :: Nil
+          DistanceJoinExec(planLater(planB), planLater(planA), b, a, swappedLeftAndRight, distance, spatialPredicate, extraCondition) :: Nil
         } else {
           logInfo(
             "Spatial distance join for ST_Distance with non-scalar distance " +
diff --git a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/RangeJoinExec.scala b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/RangeJoinExec.scala
index feffe99d..ada871a6 100644
--- a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/RangeJoinExec.scala
+++ b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/RangeJoinExec.scala
@@ -31,15 +31,17 @@ import org.apache.spark.sql.sedona_sql.execution.SedonaBinaryExecNode
   *
   * @param left       left side of the join
   * @param right      right side of the join
-  * @param leftShape  expression for the first argument of ST_Contains or ST_Intersects
-  * @param rightShape expression for the second argument of ST_Contains or ST_Intersects
-  * @param intersects boolean indicating whether spatial relationship is 'intersects' (true)
-  *                   or 'contains' (false)
+  * @param leftShape  expression for the first argument of spatialPredicate
+  * @param rightShape expression for the second argument of spatialPredicate
+  * @param swappedLeftAndRight boolean indicating whether left and right plans were swapped
+  * @param spatialPredicate spatial predicate as join condition
+  * @param extraCondition extra join condition other than spatialPredicate
   */
 case class RangeJoinExec(left: SparkPlan,
                          right: SparkPlan,
                          leftShape: Expression,
                          rightShape: Expression,
+                         swappedLeftAndRight: Boolean,
                          spatialPredicate: SpatialPredicate,
                          extraCondition: Option[Expression] = None)
   extends SedonaBinaryExecNode
diff --git a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/TraitJoinQueryExec.scala b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/TraitJoinQueryExec.scala
index ca5418b3..eb912798 100644
--- a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/TraitJoinQueryExec.scala
+++ b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/TraitJoinQueryExec.scala
@@ -33,23 +33,17 @@ import org.locationtech.jts.geom.Geometry
 trait TraitJoinQueryExec extends TraitJoinQueryBase {
   self: SparkPlan =>
 
-  // Using lazy val to avoid serialization
-  @transient private lazy val boundCondition: (InternalRow => Boolean) = {
-    if (extraCondition.isDefined) {
-      Predicate.create(extraCondition.get, left.output ++ right.output).eval _ // SPARK3 anchor
-//      newPredicate(extraCondition.get, left.output ++ right.output).eval _ // SPARK2 anchor
-    } else { (r: InternalRow) =>
-      true
-    }
-  }
   val left: SparkPlan
   val right: SparkPlan
   val leftShape: Expression
   val rightShape: Expression
+  val swappedLeftAndRight: Boolean
   val spatialPredicate: SpatialPredicate
   val extraCondition: Option[Expression]
 
-  override def output: Seq[Attribute] = left.output ++ right.output
+  override def output: Seq[Attribute] = {
+    if (!swappedLeftAndRight) left.output ++ right.output else right.output ++ left.output
+  }
 
   override protected def doExecute(): RDD[InternalRow] = {
     val boundLeftShape = BindReferences.bindReference(leftShape, left.output)
@@ -58,7 +52,7 @@ trait TraitJoinQueryExec extends TraitJoinQueryBase {
     val leftResultsRaw = left.execute().asInstanceOf[RDD[UnsafeRow]]
     val rightResultsRaw = right.execute().asInstanceOf[RDD[UnsafeRow]]
 
-    var sedonaConf = new SedonaConf(sparkContext.conf)
+    val sedonaConf = new SedonaConf(sparkContext.conf)
     val (leftShapes, rightShapes) =
       toSpatialRddPair(leftResultsRaw, boundLeftShape, rightResultsRaw, boundRightShape)
 
@@ -135,27 +129,25 @@ trait TraitJoinQueryExec extends TraitJoinQueryBase {
     logDebug(s"Join result has ${matchesRDD.count()} rows")
 
     matchesRDD.mapPartitions { iter =>
-      val filtered =
-        if (extraCondition.isDefined) {
-          val boundCondition = Predicate.create(extraCondition.get, left.output ++ right.output) // SPARK3 anchor
-//          val boundCondition = newPredicate(extraCondition.get, left.output ++ right.output) // SPARK2 anchor
-          iter.filter {
-            case (l, r) =>
-              val leftRow = l.getUserData.asInstanceOf[UnsafeRow]
-              val rightRow = r.getUserData.asInstanceOf[UnsafeRow]
-              var joiner = GenerateUnsafeRowJoiner.create(left.schema, right.schema)
-              boundCondition.eval(joiner.join(leftRow, rightRow))
-          }
-        } else {
-          iter
-        }
+      val joinRow = if (!swappedLeftAndRight) {
+        val joiner = GenerateUnsafeRowJoiner.create(left.schema, right.schema)
+        (l: UnsafeRow, r: UnsafeRow) => joiner.join(l, r)
+      } else {
+        val joiner = GenerateUnsafeRowJoiner.create(right.schema, left.schema)
+        (l: UnsafeRow, r: UnsafeRow) => joiner.join(r, l)
+      }
+
+      val joined = iter.map { case (l, r) =>
+        val leftRow = l.getUserData.asInstanceOf[UnsafeRow]
+        val rightRow = r.getUserData.asInstanceOf[UnsafeRow]
+        joinRow(leftRow, rightRow)
+      }
 
-      filtered.map {
-        case (l, r) =>
-          val leftRow = l.getUserData.asInstanceOf[UnsafeRow]
-          val rightRow = r.getUserData.asInstanceOf[UnsafeRow]
-          var joiner = GenerateUnsafeRowJoiner.create(left.schema, right.schema)
-          joiner.join(leftRow, rightRow)
+      extraCondition match {
+        case Some(condition) =>
+          val boundCondition = Predicate.create(condition, output)
+          joined.filter(row => boundCondition.eval(row))
+        case None => joined
       }
     }
   }
diff --git a/sql/src/test/scala/org/apache/sedona/sql/SpatialJoinSuite.scala b/sql/src/test/scala/org/apache/sedona/sql/SpatialJoinSuite.scala
index 62de4f5b..9ec443ff 100644
--- a/sql/src/test/scala/org/apache/sedona/sql/SpatialJoinSuite.scala
+++ b/sql/src/test/scala/org/apache/sedona/sql/SpatialJoinSuite.scala
@@ -19,7 +19,6 @@
 
 package org.apache.sedona.sql
 
-import org.apache.sedona.common.geometryObjects.Circle
 import org.apache.spark.sql.Column
 import org.apache.spark.sql.DataFrame
 import org.apache.spark.sql.functions.col
@@ -98,15 +97,55 @@ class SpatialJoinSuite extends TestBaseScala with TableDrivenPropertyChecks {
     }
   }
 
+  describe("Sedona-SQL Spatial Join Test with SELECT *") {
+    val joinConditions = Table("join condition",
+      "ST_Contains(df1.geom, df2.geom)",
+      "ST_Contains(df2.geom, df1.geom)",
+      "ST_Distance(df1.geom, df2.geom) < 1.0",
+      "ST_Distance(df2.geom, df1.geom) < 1.0"
+    )
+
+    forAll (joinConditions) { joinCondition =>
+      it(s"should SELECT * in join query with $joinCondition produce correct result") {
+        prepareTempViewsForTestData()
+        val resultAll = sparkSession.sql(s"SELECT * FROM df1 JOIN df2 ON $joinCondition").collect()
+        val result = resultAll.map(row => (row.getInt(0), row.getInt(2))).sorted
+        val expected = buildExpectedResult(joinCondition)
+        assert(result.nonEmpty)
+        assert(result === expected)
+      }
+
+      it(s"should SELECT * in join query with $joinCondition produce correct result, broadcast the left side") {
+        prepareTempViewsForTestData()
+        val resultAll = sparkSession.sql(s"SELECT /*+ BROADCAST(df1) */ * FROM df1 JOIN df2 ON $joinCondition").collect()
+        val result = resultAll.map(row => (row.getInt(0), row.getInt(2))).sorted
+        val expected = buildExpectedResult(joinCondition)
+        assert(result.nonEmpty)
+        assert(result === expected)
+      }
+
+      it(s"should SELECT * in join query with $joinCondition produce correct result, broadcast the right side") {
+        prepareTempViewsForTestData()
+        val resultAll = sparkSession.sql(s"SELECT /*+ BROADCAST(df2) */ * FROM df1 JOIN df2 ON $joinCondition").collect()
+        val result = resultAll.map(row => (row.getInt(0), row.getInt(2))).sorted
+        val expected = buildExpectedResult(joinCondition)
+        assert(result.nonEmpty)
+        assert(result === expected)
+      }
+    }
+  }
+
   private def prepareTempViewsForTestData(): (DataFrame, DataFrame) = {
     val df1 = sparkSession.read.format("csv").option("header", "false").option("delimiter", testDataDelimiter)
       .load(spatialJoinLeftInputLocation)
       .withColumn("id", col("_c0").cast(IntegerType))
       .withColumn("geom", ST_GeomFromText(new Column("_c2")))
+      .select("id", "geom")
     val df2 = sparkSession.read.format("csv").option("header", "false").option("delimiter", testDataDelimiter)
       .load(spatialJoinRightInputLocation)
       .withColumn("id", col("_c0").cast(IntegerType))
       .withColumn("geom", ST_GeomFromText(new Column("_c2")))
+      .select("id", "geom")
     df1.createOrReplaceTempView("df1")
     df2.createOrReplaceTempView("df2")
     (df1, df2)