You are viewing a plain-text version of this content. The hyperlink to the canonical version was not preserved in this plain-text rendering.
Posted to commits@sedona.apache.org by ji...@apache.org on 2021/09/22 21:59:55 UTC
[incubator-sedona] branch master updated: [SEDONA-22] [SEDONA-60] Fix join in Spark SQL when one side has no rows or only one row (#546)
This is an automated email from the ASF dual-hosted git repository.
jiayu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-sedona.git
The following commit(s) were added to refs/heads/master by this push:
new 4c9bed6 [SEDONA-22] [SEDONA-60] Fix join in Spark SQL when one side has no rows or only one row (#546)
4c9bed6 is described below
commit 4c9bed6782735b8227bc7d45f29181c3ddf33c4f
Author: Kristin Cowalcijk <mo...@yeah.net>
AuthorDate: Thu Sep 23 05:59:48 2021 +0800
[SEDONA-22] [SEDONA-60] Fix join in Spark SQL when one side has no rows or only one row (#546)
---
.../apache/sedona/core/spatialRDD/SpatialRDD.java | 5 +--
.../apache/sedona/core/utils/RDDSampleUtils.java | 6 ++--
.../core/spatialOperator/PolygonJoinTest.java | 40 ++++++++++++++++++++-
.../sedona/core/utils/RDDSampleUtilsTest.java | 2 ++
.../strategy/join/TraitJoinQueryBase.scala | 6 ++--
.../strategy/join/TraitJoinQueryExec.scala | 12 +++++--
.../apache/sedona/sql/predicateJoinTestScala.scala | 42 ++++++++++++++++++++++
7 files changed, 102 insertions(+), 11 deletions(-)
diff --git a/core/src/main/java/org/apache/sedona/core/spatialRDD/SpatialRDD.java b/core/src/main/java/org/apache/sedona/core/spatialRDD/SpatialRDD.java
index 3eb6ae7..0403f47 100644
--- a/core/src/main/java/org/apache/sedona/core/spatialRDD/SpatialRDD.java
+++ b/core/src/main/java/org/apache/sedona/core/spatialRDD/SpatialRDD.java
@@ -220,7 +220,7 @@ public class SpatialRDD<T extends Geometry>
throw new Exception("[AbstractSpatialRDD][spatialPartitioning] SpatialRDD boundary is null. Please call analyze() first.");
}
if (this.approximateTotalCount == -1) {
- throw new Exception("[AbstractSpatialRDD][spatialPartitioning] SpatialRDD total count is unkown. Please call analyze() first.");
+ throw new Exception("[AbstractSpatialRDD][spatialPartitioning] SpatialRDD total count is unknown. Please call analyze() first.");
}
//Calculate the number of samples we need to take.
@@ -259,7 +259,8 @@ public class SpatialRDD<T extends Geometry>
break;
}
case KDBTREE: {
- final KDBTree tree = new KDBTree(samples.size() / numPartitions, numPartitions, paddedBoundary);
+ int maxItemsPerNode = Math.max(samples.size() / numPartitions, 1);
+ final KDBTree tree = new KDBTree(maxItemsPerNode, numPartitions, paddedBoundary);
for (final Envelope sample : samples) {
tree.insert(sample);
}
diff --git a/core/src/main/java/org/apache/sedona/core/utils/RDDSampleUtils.java b/core/src/main/java/org/apache/sedona/core/utils/RDDSampleUtils.java
index 6e8f4af..7bc96cd 100644
--- a/core/src/main/java/org/apache/sedona/core/utils/RDDSampleUtils.java
+++ b/core/src/main/java/org/apache/sedona/core/utils/RDDSampleUtils.java
@@ -56,7 +56,7 @@ public class RDDSampleUtils
}
// Make sure that number of records >= 2 * number of partitions
- if (totalNumberOfRecords < 2 * numPartitions) {
+ if (numPartitions > (totalNumberOfRecords + 1) / 2) {
throw new IllegalArgumentException("[Sedona] Number of partitions " + numPartitions + " cannot be larger than half of total records num " + totalNumberOfRecords);
}
@@ -64,7 +64,7 @@ public class RDDSampleUtils
return (int) totalNumberOfRecords;
}
- final int minSampleCnt = numPartitions * 2;
+ final long minSampleCnt = Math.min(numPartitions * 2L, totalNumberOfRecords);
return (int) Math.max(minSampleCnt, Math.min(totalNumberOfRecords / 100, Integer.MAX_VALUE));
}
-}
\ No newline at end of file
+}
diff --git a/core/src/test/java/org/apache/sedona/core/spatialOperator/PolygonJoinTest.java b/core/src/test/java/org/apache/sedona/core/spatialOperator/PolygonJoinTest.java
index 1b1606d..97b14e7 100644
--- a/core/src/test/java/org/apache/sedona/core/spatialOperator/PolygonJoinTest.java
+++ b/core/src/test/java/org/apache/sedona/core/spatialOperator/PolygonJoinTest.java
@@ -22,6 +22,7 @@ import org.apache.sedona.core.enums.GridType;
import org.apache.sedona.core.enums.IndexType;
import org.apache.sedona.core.enums.JoinBuildSide;
import org.apache.sedona.core.spatialRDD.PolygonRDD;
+import org.apache.spark.storage.StorageLevel;
import org.junit.AfterClass;
import org.junit.BeforeClass;
import org.junit.Test;
@@ -32,6 +33,7 @@ import scala.Tuple2;
import java.util.Arrays;
import java.util.Collection;
+import java.util.Collections;
import java.util.List;
import static org.junit.Assert.assertEquals;
@@ -217,4 +219,40 @@ public class PolygonJoinTest
{
return intersects ? expectedIntersectsWithOriginalDuplicatesCount : expectedContainsWithOriginalDuplicatesCount;
}
-}
\ No newline at end of file
+
+ @Test
+ public void testJoinWithSingletonRDD() throws Exception
+ {
+ PolygonRDD queryRDD = createPolygonRDD(InputLocationQueryPolygon);
+ PolygonRDD spatialRDD = createPolygonRDD(InputLocation);
+ PolygonRDD singletonRDD = new PolygonRDD();
+ Polygon queryPolygon = queryRDD.rawSpatialRDD.first();
+ singletonRDD.rawSpatialRDD = sc.parallelize(Collections.singletonList(queryPolygon), 1);
+ singletonRDD.analyze(StorageLevel.MEMORY_ONLY());
+
+ // Joining with a singleton RDD is essentially the same with a range query
+ long expectedResultCount = RangeQuery.SpatialRangeQuery(spatialRDD, queryPolygon, true, false).count();
+
+ partitionRdds(singletonRDD, spatialRDD);
+ List<Tuple2<Polygon, Polygon>> result = JoinQuery.SpatialJoinQueryFlat(spatialRDD, singletonRDD, false, true).collect();
+ sanityCheckFlatJoinResults(result);
+ assertEquals(expectedResultCount, result.size());
+
+ partitionRdds(spatialRDD, singletonRDD);
+ result = JoinQuery.SpatialJoinQueryFlat(singletonRDD, spatialRDD, false, true).collect();
+ sanityCheckFlatJoinResults(result);
+ assertEquals(expectedResultCount, result.size());
+
+ partitionRdds(singletonRDD, spatialRDD);
+ spatialRDD.buildIndex(indexType, true);
+ result = JoinQuery.SpatialJoinQueryFlat(spatialRDD, singletonRDD, true, true).collect();
+ sanityCheckFlatJoinResults(result);
+ assertEquals(expectedResultCount, result.size());
+
+ partitionRdds(spatialRDD, singletonRDD);
+ singletonRDD.buildIndex(indexType, true);
+ result = JoinQuery.SpatialJoinQueryFlat(singletonRDD, spatialRDD, true, true).collect();
+ sanityCheckFlatJoinResults(result);
+ assertEquals(expectedResultCount, result.size());
+ }
+}
diff --git a/core/src/test/java/org/apache/sedona/core/utils/RDDSampleUtilsTest.java b/core/src/test/java/org/apache/sedona/core/utils/RDDSampleUtilsTest.java
index 1cffae1..812f200 100644
--- a/core/src/test/java/org/apache/sedona/core/utils/RDDSampleUtilsTest.java
+++ b/core/src/test/java/org/apache/sedona/core/utils/RDDSampleUtilsTest.java
@@ -41,6 +41,7 @@ public class RDDSampleUtilsTest
assertEquals(99, RDDSampleUtils.getSampleNumbers(6, 100011, 99));
assertEquals(999, RDDSampleUtils.getSampleNumbers(20, 999, -1));
assertEquals(40, RDDSampleUtils.getSampleNumbers(20, 1000, -1));
+ assertEquals(1, RDDSampleUtils.getSampleNumbers(1, 1, -1));
}
/**
@@ -52,6 +53,7 @@ public class RDDSampleUtilsTest
assertFailure(505, 999);
assertFailure(505, 1000);
assertFailure(10, 1000, 2100);
+ assertFailure(2, 1, -1);
}
private void assertFailure(int numPartitions, long totalNumberOfRecords)
diff --git a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/TraitJoinQueryBase.scala b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/TraitJoinQueryBase.scala
index 3b7a0e1..0e179b5 100644
--- a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/TraitJoinQueryBase.scala
+++ b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/TraitJoinQueryBase.scala
@@ -69,7 +69,9 @@ trait TraitJoinQueryBase {
def doSpatialPartitioning(dominantShapes: SpatialRDD[Geometry], followerShapes: SpatialRDD[Geometry],
numPartitions: Integer, sedonaConf: SedonaConf): Unit = {
- dominantShapes.spatialPartitioning(sedonaConf.getJoinGridType, numPartitions)
- followerShapes.spatialPartitioning(dominantShapes.getPartitioner)
+ if (dominantShapes.approximateTotalCount > 0) {
+ dominantShapes.spatialPartitioning(sedonaConf.getJoinGridType, numPartitions)
+ followerShapes.spatialPartitioning(dominantShapes.getPartitioner)
+ }
}
}
diff --git a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/TraitJoinQueryExec.scala b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/TraitJoinQueryExec.scala
index 9850b78..d0f59ae 100644
--- a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/TraitJoinQueryExec.scala
+++ b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/TraitJoinQueryExec.scala
@@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Attribute, BindReferences, Expression, Predicate, UnsafeRow}
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeRowJoiner
import org.apache.spark.sql.execution.SparkPlan
+import org.locationtech.jts.geom.Geometry
trait TraitJoinQueryExec extends TraitJoinQueryBase {
self: SparkPlan =>
@@ -123,11 +124,16 @@ trait TraitJoinQueryExec extends TraitJoinQueryBase {
//logInfo(s"leftShape count ${leftShapes.spatialPartitionedRDD.count()}")
//logInfo(s"rightShape count ${rightShapes.spatialPartitionedRDD.count()}")
- val matches = JoinQuery.spatialJoin(leftShapes, rightShapes, joinParams)
+ val matchesRDD: RDD[(Geometry, Geometry)] = (leftShapes.spatialPartitionedRDD, rightShapes.spatialPartitionedRDD) match {
+ case (null, null) =>
+ // Dominant side is empty, skipped creating partitioned RDDs. Result of join should also be empty.
+ sparkContext.parallelize(Seq[(Geometry, Geometry)]())
+ case _ => JoinQuery.spatialJoin(leftShapes, rightShapes, joinParams).rdd
+ }
- logDebug(s"Join result has ${matches.count()} rows")
+ logDebug(s"Join result has ${matchesRDD.count()} rows")
- matches.rdd.mapPartitions { iter =>
+ matchesRDD.mapPartitions { iter =>
val filtered =
if (extraCondition.isDefined) {
val boundCondition = Predicate.create(extraCondition.get, left.output ++ right.output) // SPARK3 anchor
diff --git a/sql/src/test/scala/org/apache/sedona/sql/predicateJoinTestScala.scala b/sql/src/test/scala/org/apache/sedona/sql/predicateJoinTestScala.scala
index dc6bb0a..688016b 100644
--- a/sql/src/test/scala/org/apache/sedona/sql/predicateJoinTestScala.scala
+++ b/sql/src/test/scala/org/apache/sedona/sql/predicateJoinTestScala.scala
@@ -22,6 +22,8 @@ package org.apache.sedona.sql
import org.apache.sedona.core.utils.SedonaConf
import org.apache.spark.sql.Row
import org.apache.spark.sql.types._
+import org.locationtech.jts.geom.Geometry
+import org.locationtech.jts.io.WKTWriter
class predicateJoinTestScala extends TestBaseScala {
@@ -125,6 +127,46 @@ class predicateJoinTestScala extends TestBaseScala {
assert(rangeJoinDf.count() == 1000)
}
+ it("Passed ST_Intersects in a join with singleton dataframe") {
+ var polygonCsvDf = sparkSession.read.format("csv").option("delimiter", ",").option("header", "false").load(csvPolygonInputLocation).limit(1).repartition(1)
+ polygonCsvDf.createOrReplaceTempView("polygontable")
+ var polygonDf = sparkSession.sql("select ST_PolygonFromEnvelope(cast(polygontable._c0 as Decimal(24,20)),cast(polygontable._c1 as Decimal(24,20)), cast(polygontable._c2 as Decimal(24,20)), cast(polygontable._c3 as Decimal(24,20))) as polygonshape from polygontable")
+ polygonDf.createOrReplaceTempView("polygondf")
+
+ var pointCsvDF = sparkSession.read.format("csv").option("delimiter", ",").option("header", "false").load(csvPointInputLocation)
+ pointCsvDF.createOrReplaceTempView("pointtable")
+ var pointDf = sparkSession.sql("select ST_Point(cast(pointtable._c0 as Decimal(24,20)),cast(pointtable._c1 as Decimal(24,20))) as pointshape from pointtable")
+ pointDf.createOrReplaceTempView("pointdf")
+
+ // Join with a singleton dataframe is essentially a range query
+ val polygon = polygonDf.first().getAs[Geometry]("polygonshape")
+ val rangeQueryDf = sparkSession.sql(s"select * from pointdf where ST_Intersects(pointdf.pointshape, ST_GeomFromWKT('${polygon.toString}'))")
+ val rangeQueryCount = rangeQueryDf.count()
+
+ // Perform spatial join and compare results
+ val rangeJoinDf1 = sparkSession.sql("select * from polygondf, pointdf where ST_Intersects(pointdf.pointshape, polygondf.polygonshape)")
+ val rangeJoinDf2 = sparkSession.sql("select * from pointdf, polygondf where ST_Intersects(polygondf.polygonshape, pointdf.pointshape)")
+ assert(rangeJoinDf1.count() == rangeQueryCount)
+ assert(rangeJoinDf2.count() == rangeQueryCount)
+ }
+
+ it("Passed ST_Intersects in a join with empty dataframe") {
+ var polygonCsvDf = sparkSession.read.format("csv").option("delimiter", ",").option("header", "false").load(csvPolygonInputLocation)
+ polygonCsvDf.createOrReplaceTempView("polygontable")
+ var polygonDf = sparkSession.sql("select ST_PolygonFromEnvelope(cast(polygontable._c0 as Decimal(24,20)),cast(polygontable._c1 as Decimal(24,20)), cast(polygontable._c2 as Decimal(24,20)), cast(polygontable._c3 as Decimal(24,20))) as polygonshape from polygontable")
+ polygonDf.createOrReplaceTempView("polygondf")
+
+ var pointCsvDF = sparkSession.read.format("csv").option("delimiter", ",").option("header", "false").load(csvPointInputLocation)
+ pointCsvDF.createOrReplaceTempView("pointtable")
+ var pointDf = sparkSession.sql("select ST_Point(cast(pointtable._c0 as Decimal(24,20)),cast(pointtable._c1 as Decimal(24,20))) as pointshape from pointtable where pointtable._c0 > 10000").repartition(1)
+ pointDf.createOrReplaceTempView("pointdf")
+
+ val rangeJoinDf1 = sparkSession.sql("select * from polygondf, pointdf where ST_Intersects(pointdf.pointshape, polygondf.polygonshape)")
+ val rangeJoinDf2 = sparkSession.sql("select * from pointdf, polygondf where ST_Intersects(polygondf.polygonshape, pointdf.pointshape)")
+ assert(rangeJoinDf1.count() == 0)
+ assert(rangeJoinDf2.count() == 0)
+ }
+
it("Passed ST_Distance <= radius in a join") {
var pointCsvDF1 = sparkSession.read.format("csv").option("delimiter", ",").option("header", "false").load(csvPointInputLocation)
pointCsvDF1.createOrReplaceTempView("pointtable")