Posted to commits@sedona.apache.org by ji...@apache.org on 2021/09/22 21:59:55 UTC

[incubator-sedona] branch master updated: [SEDONA-22] [SEDONA-60] Fix join in Spark SQL when one side has no rows or only one row (#546)

This is an automated email from the ASF dual-hosted git repository.

jiayu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-sedona.git


The following commit(s) were added to refs/heads/master by this push:
     new 4c9bed6  [SEDONA-22] [SEDONA-60] Fix join in Spark SQL when one side has no rows or only one row (#546)
4c9bed6 is described below

commit 4c9bed6782735b8227bc7d45f29181c3ddf33c4f
Author: Kristin Cowalcijk <mo...@yeah.net>
AuthorDate: Thu Sep 23 05:59:48 2021 +0800

    [SEDONA-22] [SEDONA-60] Fix join in Spark SQL when one side has no rows or only one row (#546)
---
 .../apache/sedona/core/spatialRDD/SpatialRDD.java  |  5 +--
 .../apache/sedona/core/utils/RDDSampleUtils.java   |  6 ++--
 .../core/spatialOperator/PolygonJoinTest.java      | 40 ++++++++++++++++++++-
 .../sedona/core/utils/RDDSampleUtilsTest.java      |  2 ++
 .../strategy/join/TraitJoinQueryBase.scala         |  6 ++--
 .../strategy/join/TraitJoinQueryExec.scala         | 12 +++++--
 .../apache/sedona/sql/predicateJoinTestScala.scala | 42 ++++++++++++++++++++++
 7 files changed, 102 insertions(+), 11 deletions(-)
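
The bug is easiest to see end to end: a spatial join used to fail during
partitioning whenever one side had zero rows or a single row. A minimal
sketch of the single-row scenario, assuming a SparkSession with the Sedona
SQL functions registered and a pointdf view with a pointshape column, as in
the tests further down (the WKT literal is illustrative):

    // One-row dataframe on one side of the spatial join.
    val polygonDf = sparkSession.sql(
      "select ST_GeomFromWKT('POLYGON ((0 0, 0 10, 10 10, 10 0, 0 0))') as polygonshape")
    polygonDf.createOrReplaceTempView("polygondf")

    // Before this patch, executing the join could fail while sampling for
    // the spatial partitioner (see the KDB-tree and sampling fixes below).
    sparkSession.sql(
      "select * from polygondf, pointdf " +
        "where ST_Intersects(pointdf.pointshape, polygondf.polygonshape)").count()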

diff --git a/core/src/main/java/org/apache/sedona/core/spatialRDD/SpatialRDD.java b/core/src/main/java/org/apache/sedona/core/spatialRDD/SpatialRDD.java
index 3eb6ae7..0403f47 100644
--- a/core/src/main/java/org/apache/sedona/core/spatialRDD/SpatialRDD.java
+++ b/core/src/main/java/org/apache/sedona/core/spatialRDD/SpatialRDD.java
@@ -220,7 +220,7 @@ public class SpatialRDD<T extends Geometry>
             throw new Exception("[AbstractSpatialRDD][spatialPartitioning] SpatialRDD boundary is null. Please call analyze() first.");
         }
         if (this.approximateTotalCount == -1) {
-            throw new Exception("[AbstractSpatialRDD][spatialPartitioning] SpatialRDD total count is unkown. Please call analyze() first.");
+            throw new Exception("[AbstractSpatialRDD][spatialPartitioning] SpatialRDD total count is unknown. Please call analyze() first.");
         }
 
         //Calculate the number of samples we need to take.
@@ -259,7 +259,8 @@ public class SpatialRDD<T extends Geometry>
                 break;
             }
             case KDBTREE: {
-                final KDBTree tree = new KDBTree(samples.size() / numPartitions, numPartitions, paddedBoundary);
+                int maxItemsPerNode = Math.max(samples.size() / numPartitions, 1);
+                final KDBTree tree = new KDBTree(maxItemsPerNode, numPartitions, paddedBoundary);
                 for (final Envelope sample : samples) {
                     tree.insert(sample);
                 }
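
The new Math.max(..., 1) guard matters exactly when there are fewer samples
than partitions (e.g. a one-row RDD): integer division then yields zero as
the per-node capacity, which the KDB-tree cannot split against. A quick
sketch of the arithmetic, with illustrative values:

    val samples = 1                                    // a singleton RDD yields one sample
    val numPartitions = 4
    val naive = samples / numPartitions                // 0: invalid KDB-tree capacity
    val padded = math.max(samples / numPartitions, 1)  // 1: at least one item per node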
diff --git a/core/src/main/java/org/apache/sedona/core/utils/RDDSampleUtils.java b/core/src/main/java/org/apache/sedona/core/utils/RDDSampleUtils.java
index 6e8f4af..7bc96cd 100644
--- a/core/src/main/java/org/apache/sedona/core/utils/RDDSampleUtils.java
+++ b/core/src/main/java/org/apache/sedona/core/utils/RDDSampleUtils.java
@@ -56,7 +56,7 @@ public class RDDSampleUtils
         }
 
         // Make sure that number of records >= 2 * number of partitions
-        if (totalNumberOfRecords < 2 * numPartitions) {
+        if (numPartitions > (totalNumberOfRecords + 1) / 2) {
             throw new IllegalArgumentException("[Sedona] Number of partitions " + numPartitions + " cannot be larger than half of total records num " + totalNumberOfRecords);
         }
 
@@ -64,7 +64,7 @@ public class RDDSampleUtils
             return (int) totalNumberOfRecords;
         }
 
-        final int minSampleCnt = numPartitions * 2;
+        final long minSampleCnt = Math.min(numPartitions * 2L, totalNumberOfRecords);
         return (int) Math.max(minSampleCnt, Math.min(totalNumberOfRecords / 100, Integer.MAX_VALUE));
     }
-}
\ No newline at end of file
+}
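
Two things change in getSampleNumbers. First, the bound is rewritten so that
2 * numPartitions is never computed in int arithmetic, and so that a
singleton RDD passes: with one record and one partition, 1 > (1 + 1) / 2 is
false. (Note the rewritten bound also admits totalNumberOfRecords ==
2 * numPartitions - 1, which the old strict check rejected.) Second,
minSampleCnt is computed in long arithmetic via numPartitions * 2L and
clamped to the record count, so a tiny RDD is never asked for more samples
than it has rows. A sketch of the boundary cases, using a hypothetical
helper that mirrors the new check:

    def tooManyPartitions(totalRecords: Long, numPartitions: Int): Boolean =
      numPartitions > (totalRecords + 1) / 2

    tooManyPartitions(1L, 1)  // false: a singleton RDD with one partition is now legal
    tooManyPartitions(1L, 2)  // true:  rejected, as the new test case asserts
    tooManyPartitions(3L, 2)  // false: 2 > (3 + 1) / 2 == 2 does not hold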
diff --git a/core/src/test/java/org/apache/sedona/core/spatialOperator/PolygonJoinTest.java b/core/src/test/java/org/apache/sedona/core/spatialOperator/PolygonJoinTest.java
index 1b1606d..97b14e7 100644
--- a/core/src/test/java/org/apache/sedona/core/spatialOperator/PolygonJoinTest.java
+++ b/core/src/test/java/org/apache/sedona/core/spatialOperator/PolygonJoinTest.java
@@ -22,6 +22,7 @@ import org.apache.sedona.core.enums.GridType;
 import org.apache.sedona.core.enums.IndexType;
 import org.apache.sedona.core.enums.JoinBuildSide;
 import org.apache.sedona.core.spatialRDD.PolygonRDD;
+import org.apache.spark.storage.StorageLevel;
 import org.junit.AfterClass;
 import org.junit.BeforeClass;
 import org.junit.Test;
@@ -32,6 +33,7 @@ import scala.Tuple2;
 
 import java.util.Arrays;
 import java.util.Collection;
+import java.util.Collections;
 import java.util.List;
 
 import static org.junit.Assert.assertEquals;
@@ -217,4 +219,40 @@ public class PolygonJoinTest
     {
         return intersects ? expectedIntersectsWithOriginalDuplicatesCount : expectedContainsWithOriginalDuplicatesCount;
     }
-}
\ No newline at end of file
+
+    @Test
+    public void testJoinWithSingletonRDD() throws Exception
+    {
+        PolygonRDD queryRDD = createPolygonRDD(InputLocationQueryPolygon);
+        PolygonRDD spatialRDD = createPolygonRDD(InputLocation);
+        PolygonRDD singletonRDD = new PolygonRDD();
+        Polygon queryPolygon = queryRDD.rawSpatialRDD.first();
+        singletonRDD.rawSpatialRDD = sc.parallelize(Collections.singletonList(queryPolygon), 1);
+        singletonRDD.analyze(StorageLevel.MEMORY_ONLY());
+
+        // Joining with a singleton RDD is essentially the same as a range query
+        long expectedResultCount = RangeQuery.SpatialRangeQuery(spatialRDD, queryPolygon, true, false).count();
+
+        partitionRdds(singletonRDD, spatialRDD);
+        List<Tuple2<Polygon, Polygon>> result = JoinQuery.SpatialJoinQueryFlat(spatialRDD, singletonRDD, false, true).collect();
+        sanityCheckFlatJoinResults(result);
+        assertEquals(expectedResultCount, result.size());
+
+        partitionRdds(spatialRDD, singletonRDD);
+        result = JoinQuery.SpatialJoinQueryFlat(singletonRDD, spatialRDD, false, true).collect();
+        sanityCheckFlatJoinResults(result);
+        assertEquals(expectedResultCount, result.size());
+
+        partitionRdds(singletonRDD, spatialRDD);
+        spatialRDD.buildIndex(indexType, true);
+        result = JoinQuery.SpatialJoinQueryFlat(spatialRDD, singletonRDD, true, true).collect();
+        sanityCheckFlatJoinResults(result);
+        assertEquals(expectedResultCount, result.size());
+
+        partitionRdds(spatialRDD, singletonRDD);
+        singletonRDD.buildIndex(indexType, true);
+        result = JoinQuery.SpatialJoinQueryFlat(singletonRDD, spatialRDD, true, true).collect();
+        sanityCheckFlatJoinResults(result);
+        assertEquals(expectedResultCount, result.size());
+    }
+}
diff --git a/core/src/test/java/org/apache/sedona/core/utils/RDDSampleUtilsTest.java b/core/src/test/java/org/apache/sedona/core/utils/RDDSampleUtilsTest.java
index 1cffae1..812f200 100644
--- a/core/src/test/java/org/apache/sedona/core/utils/RDDSampleUtilsTest.java
+++ b/core/src/test/java/org/apache/sedona/core/utils/RDDSampleUtilsTest.java
@@ -41,6 +41,7 @@ public class RDDSampleUtilsTest
         assertEquals(99, RDDSampleUtils.getSampleNumbers(6, 100011, 99));
         assertEquals(999, RDDSampleUtils.getSampleNumbers(20, 999, -1));
         assertEquals(40, RDDSampleUtils.getSampleNumbers(20, 1000, -1));
+        assertEquals(1, RDDSampleUtils.getSampleNumbers(1, 1, -1));
     }
 
     /**
@@ -52,6 +53,7 @@ public class RDDSampleUtilsTest
         assertFailure(505, 999);
         assertFailure(505, 1000);
         assertFailure(10, 1000, 2100);
+        assertFailure(2, 1, -1);
     }
 
     private void assertFailure(int numPartitions, long totalNumberOfRecords)
diff --git a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/TraitJoinQueryBase.scala b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/TraitJoinQueryBase.scala
index 3b7a0e1..0e179b5 100644
--- a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/TraitJoinQueryBase.scala
+++ b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/TraitJoinQueryBase.scala
@@ -69,7 +69,9 @@ trait TraitJoinQueryBase {
 
   def doSpatialPartitioning(dominantShapes: SpatialRDD[Geometry], followerShapes: SpatialRDD[Geometry],
                             numPartitions: Integer, sedonaConf: SedonaConf): Unit = {
-    dominantShapes.spatialPartitioning(sedonaConf.getJoinGridType, numPartitions)
-    followerShapes.spatialPartitioning(dominantShapes.getPartitioner)
+    if (dominantShapes.approximateTotalCount > 0) {
+      dominantShapes.spatialPartitioning(sedonaConf.getJoinGridType, numPartitions)
+      followerShapes.spatialPartitioning(dominantShapes.getPartitioner)
+    }
   }
 }
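
With zero rows on the dominant side there are no samples from which to build
a grid, so the guard skips partitioning entirely and both
spatialPartitionedRDD fields stay null. A sketch of the invariant that
downstream code can now rely on (an illustrative assertion, not part of the
patch):

    // After doSpatialPartitioning: either both sides were partitioned, or
    // the dominant side was empty and neither side was touched.
    assert(dominantShapes.spatialPartitionedRDD != null ||
           dominantShapes.approximateTotalCount == 0)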
diff --git a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/TraitJoinQueryExec.scala b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/TraitJoinQueryExec.scala
index 9850b78..d0f59ae 100644
--- a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/TraitJoinQueryExec.scala
+++ b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/TraitJoinQueryExec.scala
@@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{Attribute, BindReferences, Expression, Predicate, UnsafeRow}
 import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeRowJoiner
 import org.apache.spark.sql.execution.SparkPlan
+import org.locationtech.jts.geom.Geometry
 
 trait TraitJoinQueryExec extends TraitJoinQueryBase {
   self: SparkPlan =>
@@ -123,11 +124,16 @@ trait TraitJoinQueryExec extends TraitJoinQueryBase {
     //logInfo(s"leftShape count ${leftShapes.spatialPartitionedRDD.count()}")
     //logInfo(s"rightShape count ${rightShapes.spatialPartitionedRDD.count()}")
 
-    val matches = JoinQuery.spatialJoin(leftShapes, rightShapes, joinParams)
+    val matchesRDD: RDD[(Geometry, Geometry)] = (leftShapes.spatialPartitionedRDD, rightShapes.spatialPartitionedRDD) match {
+      case (null, null) =>
+        // Dominant side is empty, so partitioned RDDs were never created; the join result must also be empty.
+        sparkContext.parallelize(Seq[(Geometry, Geometry)]())
+      case _ => JoinQuery.spatialJoin(leftShapes, rightShapes, joinParams).rdd
+    }
 
-    logDebug(s"Join result has ${matches.count()} rows")
+    logDebug(s"Join result has ${matchesRDD.count()} rows")
 
-    matches.rdd.mapPartitions { iter =>
+    matchesRDD.mapPartitions { iter =>
       val filtered =
         if (extraCondition.isDefined) {
           val boundCondition = Predicate.create(extraCondition.get, left.output ++ right.output) // SPARK3 anchor
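
The (null, null) case is precisely what the guard in TraitJoinQueryBase
above leaves behind: an empty dominant side means no partitioned RDDs were
built, so the join result is materialized as an empty pair RDD. As a design
note, sparkContext.emptyRDD would be an equivalent, slightly more direct way
to build the same typed empty result (a sketch, assuming the imports already
in this file):

    // Equivalent construction of the typed empty join result:
    val empty: RDD[(Geometry, Geometry)] = sparkContext.emptyRDD[(Geometry, Geometry)]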
diff --git a/sql/src/test/scala/org/apache/sedona/sql/predicateJoinTestScala.scala b/sql/src/test/scala/org/apache/sedona/sql/predicateJoinTestScala.scala
index dc6bb0a..688016b 100644
--- a/sql/src/test/scala/org/apache/sedona/sql/predicateJoinTestScala.scala
+++ b/sql/src/test/scala/org/apache/sedona/sql/predicateJoinTestScala.scala
@@ -22,6 +22,8 @@ package org.apache.sedona.sql
 import org.apache.sedona.core.utils.SedonaConf
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.types._
+import org.locationtech.jts.geom.Geometry
+import org.locationtech.jts.io.WKTWriter
 
 class predicateJoinTestScala extends TestBaseScala {
 
@@ -125,6 +127,46 @@ class predicateJoinTestScala extends TestBaseScala {
       assert(rangeJoinDf.count() == 1000)
     }
 
+    it("Passed ST_Intersects in a join with singleton dataframe") {
+      var polygonCsvDf = sparkSession.read.format("csv").option("delimiter", ",").option("header", "false").load(csvPolygonInputLocation).limit(1).repartition(1)
+      polygonCsvDf.createOrReplaceTempView("polygontable")
+      var polygonDf = sparkSession.sql("select ST_PolygonFromEnvelope(cast(polygontable._c0 as Decimal(24,20)),cast(polygontable._c1 as Decimal(24,20)), cast(polygontable._c2 as Decimal(24,20)), cast(polygontable._c3 as Decimal(24,20))) as polygonshape from polygontable")
+      polygonDf.createOrReplaceTempView("polygondf")
+
+      var pointCsvDF = sparkSession.read.format("csv").option("delimiter", ",").option("header", "false").load(csvPointInputLocation)
+      pointCsvDF.createOrReplaceTempView("pointtable")
+      var pointDf = sparkSession.sql("select ST_Point(cast(pointtable._c0 as Decimal(24,20)),cast(pointtable._c1 as Decimal(24,20))) as pointshape from pointtable")
+      pointDf.createOrReplaceTempView("pointdf")
+
+      // Joining with a singleton dataframe is essentially a range query
+      val polygon = polygonDf.first().getAs[Geometry]("polygonshape")
+      val rangeQueryDf = sparkSession.sql(s"select * from pointdf where ST_Intersects(pointdf.pointshape, ST_GeomFromWKT('${polygon.toString}'))")
+      val rangeQueryCount = rangeQueryDf.count()
+
+      // Perform spatial join and compare results
+      val rangeJoinDf1 = sparkSession.sql("select * from polygondf, pointdf where ST_Intersects(pointdf.pointshape, polygondf.polygonshape)")
+      val rangeJoinDf2 = sparkSession.sql("select * from pointdf, polygondf where ST_Intersects(polygondf.polygonshape, pointdf.pointshape)")
+      assert(rangeJoinDf1.count() == rangeQueryCount)
+      assert(rangeJoinDf2.count() == rangeQueryCount)
+    }
+
+    it("Passed ST_Intersects in a join with empty dataframe") {
+      var polygonCsvDf = sparkSession.read.format("csv").option("delimiter", ",").option("header", "false").load(csvPolygonInputLocation)
+      polygonCsvDf.createOrReplaceTempView("polygontable")
+      var polygonDf = sparkSession.sql("select ST_PolygonFromEnvelope(cast(polygontable._c0 as Decimal(24,20)),cast(polygontable._c1 as Decimal(24,20)), cast(polygontable._c2 as Decimal(24,20)), cast(polygontable._c3 as Decimal(24,20))) as polygonshape from polygontable")
+      polygonDf.createOrReplaceTempView("polygondf")
+
+      var pointCsvDF = sparkSession.read.format("csv").option("delimiter", ",").option("header", "false").load(csvPointInputLocation)
+      pointCsvDF.createOrReplaceTempView("pointtable")
+      var pointDf = sparkSession.sql("select ST_Point(cast(pointtable._c0 as Decimal(24,20)),cast(pointtable._c1 as Decimal(24,20))) as pointshape from pointtable where pointtable._c0 > 10000").repartition(1)
+      pointDf.createOrReplaceTempView("pointdf")
+
+      val rangeJoinDf1 = sparkSession.sql("select * from polygondf, pointdf where ST_Intersects(pointdf.pointshape, polygondf.polygonshape)")
+      val rangeJoinDf2 = sparkSession.sql("select * from pointdf, polygondf where ST_Intersects(polygondf.polygonshape, pointdf.pointshape)")
+      assert(rangeJoinDf1.count() == 0)
+      assert(rangeJoinDf2.count() == 0)
+    }
+
     it("Passed ST_Distance <= radius in a join") {
       var pointCsvDF1 = sparkSession.read.format("csv").option("delimiter", ",").option("header", "false").load(csvPointInputLocation)
       pointCsvDF1.createOrReplaceTempView("pointtable")