You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@sedona.apache.org by ji...@apache.org on 2023/01/07 18:32:57 UTC
[sedona] branch master updated: [SEDONA-156] Support spatial filter push-down for GeoParquetFileFormat (#744)
This is an automated email from the ASF dual-hosted git repository.
jiayu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/sedona.git
The following commit(s) were added to refs/heads/master by this push:
new 112596a3 [SEDONA-156] Support spatial filter push-down for GeoParquetFileFormat (#744)
112596a3 is described below
commit 112596a315577dc12d40e70fcda9b03ad38bd574
Author: Kristin Cowalcijk <mo...@yeah.net>
AuthorDate: Sun Jan 8 02:32:52 2023 +0800
[SEDONA-156] Support spatial filter push-down for GeoParquetFileFormat (#744)
---
docs/api/sql/Optimizer.md | 20 +-
docs/image/geoparquet-pred-pushdown.png | Bin 0 -> 39204 bytes
docs/image/scan-parquet-with-spatial-pred.png | Bin 0 -> 37681 bytes
docs/image/scan-parquet-without-spatial-pred.png | Bin 0 -> 38213 bytes
docs/tutorial/sql.md | 2 +
mkdocs.yml | 2 +-
python/tests/__init__.py | 3 +-
python/tests/sql/test_geoparquet.py | 8 +
.../sedona/sql/utils/SedonaSQLRegistrator.scala | 8 +-
.../datasources/parquet/GeoParquetFileFormat.scala | 139 ++++++------
.../datasources/parquet/GeoParquetMetaData.scala | 12 +-
.../parquet/GeoParquetSchemaConverter.scala | 10 +-
.../parquet/GeoParquetSpatialFilter.scala | 68 ++++++
.../SpatialFilterPushDownForGeoParquet.scala | 190 ++++++++++++++++
.../sql/GeoParquetSpatialFilterPushDownSuite.scala | 248 +++++++++++++++++++++
.../org/apache/sedona/sql/geoparquetIOTests.scala | 10 +
16 files changed, 644 insertions(+), 76 deletions(-)
diff --git a/docs/api/sql/Optimizer.md b/docs/api/sql/Optimizer.md
index 35f05878..84c2e351 100644
--- a/docs/api/sql/Optimizer.md
+++ b/docs/api/sql/Optimizer.md
@@ -140,4 +140,22 @@ RangeJoin polygonshape#20: geometry, pointshape#43: geometry, false
: +- *FileScan csv
+- Project [st_point(cast(_c0#31 as decimal(24,20)), cast(_c1#32 as decimal(24,20)), myPointId) AS pointshape#43]
+- *FileScan csv
-```
\ No newline at end of file
+```
+
+### GeoParquet
+
+Sedona supports spatial predicate push-down for GeoParquet files. When spatial filters are applied to DataFrames backed by GeoParquet files, Sedona uses the
+[`bbox` properties in the metadata](https://github.com/opengeospatial/geoparquet/blob/v1.0.0-beta.1/format-specs/geoparquet.md#bbox)
+to determine whether all data in a file would be discarded by the spatial predicate; such files are skipped entirely. This optimization can reduce the number of files scanned
+when the queried GeoParquet dataset is partitioned by spatial proximity.
+
+The following figure visualizes a GeoParquet dataset. The `bbox` of each GeoParquet file is plotted as a blue rectangle, and the query window is plotted as a red rectangle. Sedona only scans 1 of the 6 files to
+answer queries such as `SELECT * FROM geoparquet_dataset WHERE ST_Intersects(geom, <query window>)`, so only the data covered by the light green rectangle needs to be scanned.
+
+![](../../image/geoparquet-pred-pushdown.png)
+
+Comparing the metrics of querying the GeoParquet dataset with and without the spatial predicate shows that querying with the spatial predicate results in fewer rows being scanned.
+
+| Without spatial predicate | With spatial predicate |
+| ----------- | ----------- |
+| ![](../../image/scan-parquet-without-spatial-pred.png) | ![](../../image/scan-parquet-with-spatial-pred.png) |
diff --git a/docs/image/geoparquet-pred-pushdown.png b/docs/image/geoparquet-pred-pushdown.png
new file mode 100644
index 00000000..7dd7f64b
Binary files /dev/null and b/docs/image/geoparquet-pred-pushdown.png differ
diff --git a/docs/image/scan-parquet-with-spatial-pred.png b/docs/image/scan-parquet-with-spatial-pred.png
new file mode 100644
index 00000000..cd0b48b7
Binary files /dev/null and b/docs/image/scan-parquet-with-spatial-pred.png differ
diff --git a/docs/image/scan-parquet-without-spatial-pred.png b/docs/image/scan-parquet-without-spatial-pred.png
new file mode 100644
index 00000000..92ae14b0
Binary files /dev/null and b/docs/image/scan-parquet-without-spatial-pred.png differ
diff --git a/docs/tutorial/sql.md b/docs/tutorial/sql.md
index 07732f7b..456e52bc 100644
--- a/docs/tutorial/sql.md
+++ b/docs/tutorial/sql.md
@@ -156,6 +156,8 @@ root
|-- geometry: geometry (nullable = true)
```
+Sedona supports spatial predicate push-down for GeoParquet files. Please refer to the [SedonaSQL query optimizer](../api/sql/Optimizer.md) documentation for details.
+
## Transform the Coordinate Reference System
Sedona doesn't control the coordinate unit (degree-based or meter-based) of all geometries in a Geometry column. The unit of all related distances in SedonaSQL is same as the unit of all geometries in a Geometry column.
diff --git a/mkdocs.yml b/mkdocs.yml
index b73cf245..afb9cb1d 100644
--- a/mkdocs.yml
+++ b/mkdocs.yml
@@ -59,7 +59,7 @@ nav:
- Function: api/sql/Function.md
- Predicate: api/sql/Predicate.md
- Aggregate function: api/sql/AggregateFunction.md
- - Join query (optimizer): api/sql/Optimizer.md
+ - SedonaSQL query optimizer: api/sql/Optimizer.md
- Raster data:
- Raster input and output: api/sql/Raster-loader.md
- Raster operators: api/sql/Raster-operators.md
diff --git a/python/tests/__init__.py b/python/tests/__init__.py
index 6bb0a4be..1695ee2a 100644
--- a/python/tests/__init__.py
+++ b/python/tests/__init__.py
@@ -43,4 +43,5 @@ overlap_polygon_input_location = os.path.join(tests_resource, "testenvelope_over
union_polygon_input_location = os.path.join(tests_resource, "testunion.csv")
csv_point1_input_location = os.path.join(tests_resource, "equalitycheckfiles/testequals_point1.csv")
csv_point2_input_location = os.path.join(tests_resource, "equalitycheckfiles/testequals_point2.csv")
-geojson_id_input_location = os.path.join(tests_resource, "testContainsId.json")
\ No newline at end of file
+geojson_id_input_location = os.path.join(tests_resource, "testContainsId.json")
+geoparquet_input_location = os.path.join(tests_resource, "geoparquet/example1.parquet")
diff --git a/python/tests/sql/test_geoparquet.py b/python/tests/sql/test_geoparquet.py
index 00e1a160..c12433bd 100644
--- a/python/tests/sql/test_geoparquet.py
+++ b/python/tests/sql/test_geoparquet.py
@@ -24,6 +24,7 @@ from shapely.wkt import loads as wkt_loads
import geopandas
from tests.test_base import TestBase
+from tests import geoparquet_input_location
class TestGeoParquet(TestBase):
@@ -53,3 +54,10 @@ class TestGeoParquet(TestBase):
assert df2.count() == 2
row = df2.collect()[0]
assert isinstance(row['g'], BaseGeometry)
+
+ def test_load_geoparquet_with_spatial_filter(self):
+ df = self.spark.read.format("geoparquet").load(geoparquet_input_location)\
+ .where("ST_Contains(geometry, ST_GeomFromText('POINT (35.174722 -6.552465)'))")
+ rows = df.collect()
+ assert len(rows) == 1
+ assert rows[0]['name'] == 'Tanzania'
diff --git a/sql/src/main/scala/org/apache/sedona/sql/utils/SedonaSQLRegistrator.scala b/sql/src/main/scala/org/apache/sedona/sql/utils/SedonaSQLRegistrator.scala
index 3e60627b..91871eb9 100644
--- a/sql/src/main/scala/org/apache/sedona/sql/utils/SedonaSQLRegistrator.scala
+++ b/sql/src/main/scala/org/apache/sedona/sql/utils/SedonaSQLRegistrator.scala
@@ -20,6 +20,7 @@ package org.apache.sedona.sql.utils
import org.apache.sedona.sql.UDF.UdfRegistrator
import org.apache.sedona.sql.UDT.UdtRegistrator
+import org.apache.spark.sql.sedona_sql.optimization.SpatialFilterPushDownForGeoParquet
import org.apache.spark.sql.{SQLContext, SparkSession}
import org.apache.spark.sql.sedona_sql.strategy.join.JoinQueryDetector
@@ -29,7 +30,12 @@ object SedonaSQLRegistrator {
}
+  /** Registers Sedona's UDTs and UDFs, the spatial join strategy and the GeoParquet spatial filter push-down rule. */
def registerAll(sparkSession: SparkSession): Unit = {
- sparkSession.experimental.extraStrategies = new JoinQueryDetector(sparkSession) :: Nil
+    // Append (rather than overwrite) Sedona's planner strategy and optimizer rule, and skip each one when it is
+    // already registered, so that calling registerAll repeatedly does not add duplicates.
+ if (!sparkSession.experimental.extraStrategies.exists(_.isInstanceOf[JoinQueryDetector])) {
+ sparkSession.experimental.extraStrategies ++= Seq(new JoinQueryDetector(sparkSession))
+ }
+ if (!sparkSession.experimental.extraOptimizations.exists(_.isInstanceOf[SpatialFilterPushDownForGeoParquet])) {
+ sparkSession.experimental.extraOptimizations ++= Seq(new SpatialFilterPushDownForGeoParquet(sparkSession))
+ }
UdtRegistrator.registerAll()
UdfRegistrator.registerAll(sparkSession)
}
diff --git a/sql/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetFileFormat.scala b/sql/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetFileFormat.scala
index c9fac3f5..279efbbd 100644
--- a/sql/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetFileFormat.scala
+++ b/sql/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetFileFormat.scala
@@ -1,5 +1,4 @@
-/**
- *
+/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
@@ -46,14 +45,21 @@ import scala.collection.JavaConverters._
import scala.util.Failure
import scala.util.Try
-class GeoParquetFileFormat extends ParquetFileFormat with FileFormat with DataSourceRegister with Logging with Serializable {
+class GeoParquetFileFormat(val spatialFilter: Option[GeoParquetSpatialFilter])
+ extends ParquetFileFormat with FileFormat with DataSourceRegister with Logging with Serializable {
+
+  /** Creates a file format that has no spatial filter pushed down to it. */
+  def this() = this(None)
override def shortName(): String = "geoparquet"
- override def equals(other: Any): Boolean = other.isInstanceOf[GeoParquetFileFormat]
+  // Two instances are only interchangeable when they carry the same pushed-down spatial filter.
+  override def equals(other: Any): Boolean = other match {
+    case that: GeoParquetFileFormat => that.spatialFilter == spatialFilter
+    case _ => false
+  }
override def hashCode(): Int = getClass.hashCode()
+  /**
+   * Returns a new file format that prunes files using `spatialFilter`.
+   *
+   * @param spatialFilter spatial filter evaluated against the per-column bbox metadata of each file
+   */
+  def withSpatialPredicates(spatialFilter: GeoParquetSpatialFilter): GeoParquetFileFormat =
+    new GeoParquetFileFormat(Some(spatialFilter))
+
override def inferSchema(
sparkSession: SparkSession,
parameters: Map[String, String],
@@ -224,7 +230,7 @@ class GeoParquetFileFormat extends ParquetFileFormat with FileFormat with DataSo
val sharedConf = broadcastedHadoopConf.value.value
- lazy val footerFileMetaData =
+ val footerFileMetaData =
ParquetFileReader.readFooter(sharedConf, filePath, SKIP_ROW_GROUPS).getFileMetaData
// Try to push down filters when filter push-down is enabled.
val pushed = if (enableParquetFilterPushDown) {
@@ -241,67 +247,76 @@ class GeoParquetFileFormat extends ParquetFileFormat with FileFormat with DataSo
None
}
- // PARQUET_INT96_TIMESTAMP_CONVERSION says to apply timezone conversions to int96 timestamps'
- // *only* if the file was created by something other than "parquet-mr", so check the actual
- // writer here for this file. We have to do this per-file, as each file in the table may
- // have different writers.
- // Define isCreatedByParquetMr as function to avoid unnecessary parquet footer reads.
- def isCreatedByParquetMr: Boolean =
- footerFileMetaData.getCreatedBy().startsWith("parquet-mr")
-
- val convertTz =
- if (timestampConversion && !isCreatedByParquetMr) {
- Some(DateTimeUtils.getZoneId(sharedConf.get(SQLConf.SESSION_LOCAL_TIMEZONE.key)))
- } else {
- None
- }
- val datetimeRebaseMode = GeoDataSourceUtils.datetimeRebaseMode(
- footerFileMetaData.getKeyValueMetaData.get,
- SQLConf.get.getConfString(GeoDataSourceUtils.PARQUET_REBASE_MODE_IN_READ))
- val int96RebaseMode = GeoDataSourceUtils.int96RebaseMode(
- footerFileMetaData.getKeyValueMetaData.get,
- SQLConf.get.getConfString(GeoDataSourceUtils.PARQUET_INT96_REBASE_MODE_IN_READ))
-
- val attemptId = new TaskAttemptID(new TaskID(new JobID(), TaskType.MAP, 0), 0)
- val hadoopAttemptContext =
- new TaskAttemptContextImpl(broadcastedHadoopConf.value.value, attemptId)
-
- // Try to push down filters when filter push-down is enabled.
- // Notice: This push-down is RowGroups level, not individual records.
- if (pushed.isDefined) {
- ParquetInputFormat.setFilterPredicate(hadoopAttemptContext.getConfiguration, pushed.get)
- }
- val taskContext = Option(TaskContext.get())
- if (enableVectorizedReader) {
- logWarning(s"GeoParquet currently does not support vectorized reader. Falling back to parquet-mr")
+ // Prune file scans using pushed down spatial filters and per-column bboxes in geoparquet metadata
+ val shouldScanFile = GeoParquetMetaData.parseKeyValueMetaData(footerFileMetaData.getKeyValueMetaData).forall {
+ metadata => spatialFilter.forall(_.evaluate(metadata.columns))
}
- logDebug(s"Falling back to parquet-mr")
- // ParquetRecordReader returns InternalRow
- val readSupport = new GeoParquetReadSupport(
- convertTz,
- enableVectorizedReader = false,
- datetimeRebaseMode,
- int96RebaseMode)
- val reader = if (pushed.isDefined && enableRecordFilter) {
- val parquetFilter = FilterCompat.get(pushed.get, null)
- new ParquetRecordReader[InternalRow](readSupport, parquetFilter)
+ if (!shouldScanFile) {
+ // The entire file is pruned so that we don't need to scan this file.
+ Seq.empty[InternalRow].iterator
} else {
- new ParquetRecordReader[InternalRow](readSupport)
- }
- val iter = new RecordReaderIterator[InternalRow](reader)
- // SPARK-23457 Register a task completion listener before `initialization`.
- taskContext.foreach(_.addTaskCompletionListener[Unit](_ => iter.close()))
- reader.initialize(split, hadoopAttemptContext)
+ // PARQUET_INT96_TIMESTAMP_CONVERSION says to apply timezone conversions to int96 timestamps'
+ // *only* if the file was created by something other than "parquet-mr", so check the actual
+ // writer here for this file. We have to do this per-file, as each file in the table may
+ // have different writers.
+ // Define isCreatedByParquetMr as function to avoid unnecessary parquet footer reads.
+ def isCreatedByParquetMr: Boolean =
+ footerFileMetaData.getCreatedBy().startsWith("parquet-mr")
+
+ val convertTz =
+ if (timestampConversion && !isCreatedByParquetMr) {
+ Some(DateTimeUtils.getZoneId(sharedConf.get(SQLConf.SESSION_LOCAL_TIMEZONE.key)))
+ } else {
+ None
+ }
+ val datetimeRebaseMode = GeoDataSourceUtils.datetimeRebaseMode(
+ footerFileMetaData.getKeyValueMetaData.get,
+ SQLConf.get.getConfString(GeoDataSourceUtils.PARQUET_REBASE_MODE_IN_READ))
+ val int96RebaseMode = GeoDataSourceUtils.int96RebaseMode(
+ footerFileMetaData.getKeyValueMetaData.get,
+ SQLConf.get.getConfString(GeoDataSourceUtils.PARQUET_INT96_REBASE_MODE_IN_READ))
+
+ val attemptId = new TaskAttemptID(new TaskID(new JobID(), TaskType.MAP, 0), 0)
+ val hadoopAttemptContext =
+ new TaskAttemptContextImpl(broadcastedHadoopConf.value.value, attemptId)
+
+ // Try to push down filters when filter push-down is enabled.
+ // Notice: This push-down is RowGroups level, not individual records.
+ if (pushed.isDefined) {
+ ParquetInputFormat.setFilterPredicate(hadoopAttemptContext.getConfiguration, pushed.get)
+ }
+ val taskContext = Option(TaskContext.get())
+ if (enableVectorizedReader) {
+ logWarning(s"GeoParquet currently does not support vectorized reader. Falling back to parquet-mr")
+ }
+ logDebug(s"Falling back to parquet-mr")
+ // ParquetRecordReader returns InternalRow
+ val readSupport = new GeoParquetReadSupport(
+ convertTz,
+ enableVectorizedReader = false,
+ datetimeRebaseMode,
+ int96RebaseMode)
+ val reader = if (pushed.isDefined && enableRecordFilter) {
+ val parquetFilter = FilterCompat.get(pushed.get, null)
+ new ParquetRecordReader[InternalRow](readSupport, parquetFilter)
+ } else {
+ new ParquetRecordReader[InternalRow](readSupport)
+ }
+ val iter = new RecordReaderIterator[InternalRow](reader)
+ // SPARK-23457 Register a task completion listener before `initialization`.
+ taskContext.foreach(_.addTaskCompletionListener[Unit](_ => iter.close()))
+ reader.initialize(split, hadoopAttemptContext)
- val fullSchema = requiredSchema.toAttributes ++ partitionSchema.toAttributes
- val unsafeProjection = GenerateUnsafeProjection.generate(fullSchema, fullSchema)
+ val fullSchema = requiredSchema.toAttributes ++ partitionSchema.toAttributes
+ val unsafeProjection = GenerateUnsafeProjection.generate(fullSchema, fullSchema)
- if (partitionSchema.length == 0) {
- // There is no partition columns
- iter.map(unsafeProjection)
- } else {
- val joinedRow = new JoinedRow()
- iter.map(d => unsafeProjection(joinedRow(d, file.partitionValues)))
+ if (partitionSchema.length == 0) {
+ // There is no partition columns
+ iter.map(unsafeProjection)
+ } else {
+ val joinedRow = new JoinedRow()
+ iter.map(d => unsafeProjection(joinedRow(d, file.partitionValues)))
+ }
}
}
}
diff --git a/sql/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetMetaData.scala b/sql/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetMetaData.scala
index c2419410..a3ea9ec4 100644
--- a/sql/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetMetaData.scala
+++ b/sql/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetMetaData.scala
@@ -1,5 +1,4 @@
-/**
- *
+/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
@@ -14,6 +13,8 @@
*/
package org.apache.spark.sql.execution.datasources.parquet
+import org.json4s.jackson.JsonMethods.parse
+
/**
* A case class that holds CRS metadata for geometry columns. This class is left empty since CRS
* metadata was not implemented yet.
@@ -48,4 +49,11 @@ object GeoParquetMetaData {
// https://github.com/opengeospatial/geoparquet/blob/v1.0.0-beta.1/format-specs/geoparquet.md
// for more details.
val VERSION = "1.0.0-beta.1"
+
+  /**
+   * Extracts GeoParquet metadata from the key-value metadata of a Parquet file.
+   *
+   * @param keyValueMetaData key-value metadata of a Parquet file
+   * @return the metadata parsed from the JSON stored under the "geo" key (keys are camelized before
+   *         extraction), or None when that key is absent or maps to null
+   */
+  def parseKeyValueMetaData(keyValueMetaData: java.util.Map[String, String]): Option[GeoParquetMetaData] = {
+    implicit val formats: org.json4s.Formats = org.json4s.DefaultFormats
+    val geoJson = Option(keyValueMetaData.get("geo"))
+    geoJson.map(json => parse(json).camelizeKeys.extract[GeoParquetMetaData])
+  }
}
diff --git a/sql/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetSchemaConverter.scala b/sql/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetSchemaConverter.scala
index 3ca1d948..7afdc043 100644
--- a/sql/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetSchemaConverter.scala
+++ b/sql/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetSchemaConverter.scala
@@ -28,7 +28,6 @@ import org.apache.spark.sql.execution.datasources.parquet.ParquetSchemaConverter
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT
import org.apache.spark.sql.types._
-import org.json4s.jackson.JsonMethods.parse
/**
* This converter class is used to convert Parquet [[MessageType]] to Spark SQL [[StructType]].
@@ -48,13 +47,8 @@ class GeoParquetToSparkSchemaConverter(
assumeBinaryIsString: Boolean = SQLConf.PARQUET_BINARY_AS_STRING.defaultValue.get,
assumeInt96IsTimestamp: Boolean = SQLConf.PARQUET_INT96_AS_TIMESTAMP.defaultValue.get) {
- private val geoParquetMetaData: GeoParquetMetaData = {
- val geo = keyValueMetaData.get("geo")
- if (geo == null) {
- throw new AnalysisException("GeoParquet file does not contain valid geo metadata")
- }
- implicit val formats: org.json4s.Formats = org.json4s.DefaultFormats
- parse(geo).camelizeKeys.extract[GeoParquetMetaData]
+  // Parse the "geo" entry of the key-value metadata; a file without it is rejected with an AnalysisException.
+ private val geoParquetMetaData: GeoParquetMetaData = GeoParquetMetaData.parseKeyValueMetaData(keyValueMetaData).getOrElse {
+ throw new AnalysisException("GeoParquet file does not contain valid geo metadata")
}
def this(keyValueMetaData: java.util.Map[String, String], conf: SQLConf) = this(
diff --git a/sql/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetSpatialFilter.scala b/sql/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetSpatialFilter.scala
new file mode 100644
index 00000000..c100be6b
--- /dev/null
+++ b/sql/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetSpatialFilter.scala
@@ -0,0 +1,68 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.datasources.parquet
+
+import org.apache.sedona.core.spatialOperator.SpatialPredicate
+import org.locationtech.jts.geom.Envelope
+import org.locationtech.jts.geom.Geometry
+
+/**
+ * Filters containing spatial predicates such as `ST_Within(geom, ST_GeomFromText(...))` will be converted
+ * to [[GeoParquetSpatialFilter]] and get pushed down to [[GeoParquetFileFormat]] by
+ * [[org.apache.spark.sql.sedona_sql.optimization.SpatialFilterPushDownForGeoParquet]].
+ */
+trait GeoParquetSpatialFilter {
+  /**
+   * Decides whether a file may contain rows satisfying this filter, using only its geometry column metadata.
+   *
+   * @param columns GeoParquet metadata of the file's geometry columns, keyed by column name
+   * @return false when the whole file can be skipped; true when the file must be scanned
+   */
+ def evaluate(columns: Map[String, GeometryFieldMetaData]): Boolean
+}
+
+object GeoParquetSpatialFilter {
+
+  /** Conjunction of two spatial filters: a file is scanned only if both filters may match it. */
+  case class AndFilter(left: GeoParquetSpatialFilter, right: GeoParquetSpatialFilter) extends GeoParquetSpatialFilter {
+    override def evaluate(columns: Map[String, GeometryFieldMetaData]): Boolean =
+      left.evaluate(columns) && right.evaluate(columns)
+  }
+
+  /** Disjunction of two spatial filters: a file is scanned if either filter may match it. */
+  case class OrFilter(left: GeoParquetSpatialFilter, right: GeoParquetSpatialFilter) extends GeoParquetSpatialFilter {
+    override def evaluate(columns: Map[String, GeometryFieldMetaData]): Boolean =
+      left.evaluate(columns) || right.evaluate(columns)
+  }
+
+  /**
+   * Spatial predicate pushed down to GeoParquet data source. We'll use the bbox in column metadata to prune
+   * unrelated files.
+   *
+   * Both 2D (`[xmin, ymin, xmax, ymax]`) and 3D (`[xmin, ymin, zmin, xmax, ymax, zmax]`) bboxes are supported;
+   * the z range is ignored. A file is never pruned when its bbox cannot be interpreted as a plain x/y rectangle:
+   * a bbox of any other length (e.g. an absent, empty bbox), NaN bounds, or `xmin > xmax`, which RFC 7946 uses
+   * for bboxes crossing the antimeridian. Files without metadata for `columnName` are never pruned either.
+   *
+   * @param columnName name of filtered geometry column
+   * @param predicateType type of spatial predicate, should be one of COVERS and INTERSECTS
+   * @param queryWindow query window
+   */
+  case class LeafFilter(
+      columnName: String,
+      predicateType: SpatialPredicate,
+      queryWindow: Geometry) extends GeoParquetSpatialFilter {
+    def evaluate(columns: Map[String, GeometryFieldMetaData]): Boolean = {
+      columns.get(columnName).forall { column =>
+        val bbox = column.bbox
+        // The max corner starts right after the min corner: index 2 for 2D bboxes, index 3 for 3D bboxes.
+        val xyBounds =
+          if (bbox.length == 4 || bbox.length == 6) {
+            val maxOffset = bbox.length / 2
+            Some((bbox(0), bbox(1), bbox(maxOffset), bbox(maxOffset + 1)))
+          } else {
+            None
+          }
+        xyBounds match {
+          // NaN bounds fail these comparisons as well, so they fall through to the conservative branch.
+          case Some((minX, minY, maxX, maxY)) if minX <= maxX && minY <= maxY =>
+            val columnEnvelope = queryWindow.getFactory.toGeometry(new Envelope(minX, maxX, minY, maxY))
+            predicateType match {
+              case SpatialPredicate.COVERS => columnEnvelope.covers(queryWindow)
+              case SpatialPredicate.INTERSECTS =>
+                // XXX: We must call the intersects method of queryWindow instead of columnEnvelope, since queryWindow
+                // may be a Circle object and geom.intersects(circle) may not work correctly.
+                queryWindow.intersects(columnEnvelope)
+              case _ => throw new IllegalArgumentException(s"Unexpected predicate type: $predicateType")
+            }
+          case _ =>
+            // The bbox is absent, malformed, or crosses the antimeridian. Building an Envelope from it would
+            // silently normalize the bounds and could prune files that do match, so scan the file instead.
+            true
+        }
+      }
+    }
+  }
+}
diff --git a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/optimization/SpatialFilterPushDownForGeoParquet.scala b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/optimization/SpatialFilterPushDownForGeoParquet.scala
new file mode 100644
index 00000000..e0262bd9
--- /dev/null
+++ b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/optimization/SpatialFilterPushDownForGeoParquet.scala
@@ -0,0 +1,190 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.spark.sql.sedona_sql.optimization
+
+import org.apache.sedona.common.geometryObjects.Circle
+import org.apache.sedona.core.spatialOperator.SpatialPredicate
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.expressions.And
+import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.expressions.LessThan
+import org.apache.spark.sql.catalyst.expressions.LessThanOrEqual
+import org.apache.spark.sql.catalyst.expressions.Literal
+import org.apache.spark.sql.catalyst.expressions.Not
+import org.apache.spark.sql.catalyst.expressions.Or
+import org.apache.spark.sql.catalyst.expressions.SubqueryExpression
+import org.apache.spark.sql.catalyst.plans.logical.Filter
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.parseColumnPath
+import org.apache.spark.sql.execution.datasources.DataSourceStrategy
+import org.apache.spark.sql.execution.datasources.HadoopFsRelation
+import org.apache.spark.sql.execution.datasources.LogicalRelation
+import org.apache.spark.sql.execution.datasources.PushableColumn
+import org.apache.spark.sql.execution.datasources.PushableColumnBase
+import org.apache.spark.sql.execution.datasources.parquet.GeoParquetFileFormat
+import org.apache.spark.sql.execution.datasources.parquet.GeoParquetSpatialFilter
+import org.apache.spark.sql.execution.datasources.parquet.GeoParquetSpatialFilter.AndFilter
+import org.apache.spark.sql.execution.datasources.parquet.GeoParquetSpatialFilter.LeafFilter
+import org.apache.spark.sql.execution.datasources.parquet.GeoParquetSpatialFilter.OrFilter
+import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT
+import org.apache.spark.sql.sedona_sql.expressions.ST_Contains
+import org.apache.spark.sql.sedona_sql.expressions.ST_CoveredBy
+import org.apache.spark.sql.sedona_sql.expressions.ST_Covers
+import org.apache.spark.sql.sedona_sql.expressions.ST_Crosses
+import org.apache.spark.sql.sedona_sql.expressions.ST_Distance
+import org.apache.spark.sql.sedona_sql.expressions.ST_Equals
+import org.apache.spark.sql.sedona_sql.expressions.ST_Intersects
+import org.apache.spark.sql.sedona_sql.expressions.ST_OrderingEquals
+import org.apache.spark.sql.sedona_sql.expressions.ST_Overlaps
+import org.apache.spark.sql.sedona_sql.expressions.ST_Touches
+import org.apache.spark.sql.sedona_sql.expressions.ST_Within
+import org.apache.spark.sql.types.DoubleType
+import org.locationtech.jts.geom.Geometry
+import org.locationtech.jts.geom.Point
+
+/**
+ * Optimizer rule that translates spatial predicates of a [[Filter]] sitting directly on top of a GeoParquet
+ * relation into a [[GeoParquetSpatialFilter]] and attaches it to the relation's [[GeoParquetFileFormat]], so
+ * that files whose geometry column `bbox` metadata cannot satisfy the predicates are skipped at scan time.
+ * The [[Filter]] itself is kept in the plan, so the pushed-down filter only prunes files.
+ */
+class SpatialFilterPushDownForGeoParquet(sparkSession: SparkSession) extends Rule[LogicalPlan] {
+
+  override def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+    case filter @ Filter(condition, lr: LogicalRelation) if isGeoParquetRelation(lr) =>
+      val filters = splitConjunctivePredicates(condition)
+      val normalizedFilters = DataSourceStrategy.normalizeExprs(filters, lr.output)
+      // Predicates containing subqueries cannot be evaluated against file-level metadata.
+      val normalizedFiltersWithoutSubquery = normalizedFilters.filterNot(SubqueryExpression.hasSubquery)
+      val geoParquetSpatialFilters = translateToGeoParquetSpatialFilters(normalizedFiltersWithoutSubquery)
+      geoParquetSpatialFilters.reduceOption(AndFilter) match {
+        case None => filter
+        case Some(combinedSpatialFilter) =>
+          val hadoopFsRelation = lr.relation.asInstanceOf[HadoopFsRelation]
+          val fileFormat = hadoopFsRelation.fileFormat.asInstanceOf[GeoParquetFileFormat]
+          val newFileFormat = fileFormat.withSpatialPredicates(combinedSpatialFilter)
+          val newRelation = hadoopFsRelation.copy(fileFormat = newFileFormat)(sparkSession)
+          filter.copy(child = lr.copy(relation = newRelation))
+      }
+  }
+
+  /** True when `lr` is a file-based relation read through [[GeoParquetFileFormat]]. */
+  private def isGeoParquetRelation(lr: LogicalRelation): Boolean = lr.relation match {
+    case fsRelation: HadoopFsRelation => fsRelation.fileFormat.isInstanceOf[GeoParquetFileFormat]
+    case _ => false
+  }
+
+  private def translateToGeoParquetSpatialFilters(predicates: Seq[Expression]): Seq[GeoParquetSpatialFilter] = {
+    val pushableColumn = PushableColumn(nestedPredicatePushdownEnabled = false)
+    predicates.flatMap(predicate => translateToGeoParquetSpatialFilter(predicate, pushableColumn))
+  }
+
+  /**
+   * Translates a Catalyst predicate into a [[GeoParquetSpatialFilter]], or None when it cannot be translated.
+   * The result is meant to be weaker than the predicate: an untranslatable side of an `And` is dropped, while
+   * `Or` requires both sides to be translatable and `Not` is never translated.
+   *
+   * Only comparisons between a top-level column and a non-NULL literal are translated. A NULL geometry literal
+   * cannot be deserialized, and such a predicate never evaluates to true anyway.
+   */
+  private def translateToGeoParquetSpatialFilter(
+      predicate: Expression,
+      pushableColumn: PushableColumnBase): Option[GeoParquetSpatialFilter] = {
+    predicate match {
+      case And(left, right) =>
+        val spatialFilterLeft = translateToGeoParquetSpatialFilter(left, pushableColumn)
+        val spatialFilterRight = translateToGeoParquetSpatialFilter(right, pushableColumn)
+        (spatialFilterLeft, spatialFilterRight) match {
+          case (Some(l), Some(r)) => Some(AndFilter(l, r))
+          case (Some(l), None) => Some(l)
+          case (None, Some(r)) => Some(r)
+          case _ => None
+        }
+
+      case Or(left, right) =>
+        for {
+          spatialFilterLeft <- translateToGeoParquetSpatialFilter(left, pushableColumn)
+          spatialFilterRight <- translateToGeoParquetSpatialFilter(right, pushableColumn)
+        } yield OrFilter(spatialFilterLeft, spatialFilterRight)
+
+      // File bboxes cannot rule out rows satisfying a negated predicate.
+      case Not(_) => None
+
+      // When the column geometry contains/covers the literal, the file bbox must cover the literal; when the
+      // literal contains/covers the column geometry, the file bbox must intersect the literal.
+      case ST_Contains(Seq(pushableColumn(name), Literal(v, _))) if v != null =>
+        Some(LeafFilter(unquote(name), SpatialPredicate.COVERS, GeometryUDT.deserialize(v)))
+      case ST_Contains(Seq(Literal(v, _), pushableColumn(name))) if v != null =>
+        Some(LeafFilter(unquote(name), SpatialPredicate.INTERSECTS, GeometryUDT.deserialize(v)))
+
+      case ST_Covers(Seq(pushableColumn(name), Literal(v, _))) if v != null =>
+        Some(LeafFilter(unquote(name), SpatialPredicate.COVERS, GeometryUDT.deserialize(v)))
+      case ST_Covers(Seq(Literal(v, _), pushableColumn(name))) if v != null =>
+        Some(LeafFilter(unquote(name), SpatialPredicate.INTERSECTS, GeometryUDT.deserialize(v)))
+
+      case ST_Within(Seq(pushableColumn(name), Literal(v, _))) if v != null =>
+        Some(LeafFilter(unquote(name), SpatialPredicate.INTERSECTS, GeometryUDT.deserialize(v)))
+      case ST_Within(Seq(Literal(v, _), pushableColumn(name))) if v != null =>
+        Some(LeafFilter(unquote(name), SpatialPredicate.COVERS, GeometryUDT.deserialize(v)))
+
+      case ST_CoveredBy(Seq(pushableColumn(name), Literal(v, _))) if v != null =>
+        Some(LeafFilter(unquote(name), SpatialPredicate.INTERSECTS, GeometryUDT.deserialize(v)))
+      case ST_CoveredBy(Seq(Literal(v, _), pushableColumn(name))) if v != null =>
+        Some(LeafFilter(unquote(name), SpatialPredicate.COVERS, GeometryUDT.deserialize(v)))
+
+      case ST_Equals(_) | ST_OrderingEquals(_) =>
+        for ((name, value) <- resolveNameAndLiteral(predicate.children, pushableColumn))
+          yield LeafFilter(unquote(name), SpatialPredicate.COVERS, GeometryUDT.deserialize(value))
+
+      case ST_Intersects(_) | ST_Crosses(_) | ST_Overlaps(_) | ST_Touches(_) =>
+        for ((name, value) <- resolveNameAndLiteral(predicate.children, pushableColumn))
+          yield LeafFilter(unquote(name), SpatialPredicate.INTERSECTS, GeometryUDT.deserialize(value))
+
+      // A NULL distance threshold makes the comparison NULL, so there is nothing useful to push down.
+      case LessThan(ST_Distance(distArgs), Literal(d, DoubleType)) if d != null =>
+        for ((name, value) <- resolveNameAndLiteral(distArgs, pushableColumn))
+          yield distanceFilter(name, GeometryUDT.deserialize(value), d.asInstanceOf[Double])
+
+      case LessThanOrEqual(ST_Distance(distArgs), Literal(d, DoubleType)) if d != null =>
+        for ((name, value) <- resolveNameAndLiteral(distArgs, pushableColumn))
+          yield distanceFilter(name, GeometryUDT.deserialize(value), d.asInstanceOf[Double])
+
+      case _ => None
+    }
+  }
+
+  /**
+   * Builds an INTERSECTS filter for "distance to `geom` is within `distance`". A point becomes a [[Circle]] of
+   * radius `distance`; any other geometry becomes its envelope expanded by `distance`.
+   */
+  private def distanceFilter(name: String, geom: Geometry, distance: Double): LeafFilter = {
+    val queryWindow = geom match {
+      case point: Point => new Circle(point, distance)
+      case _ =>
+        val envelope = geom.getEnvelopeInternal
+        envelope.expandBy(distance)
+        geom.getFactory.toGeometry(envelope)
+    }
+    LeafFilter(unquote(name), SpatialPredicate.INTERSECTS, queryWindow)
+  }
+
+  /** Strips back-quotes from a (possibly quoted) column name. */
+  private def unquote(name: String): String = {
+    parseColumnPath(name).mkString(".")
+  }
+
+  /** Matches a (column, non-NULL literal) argument pair in either order and returns the column name and value. */
+  private def resolveNameAndLiteral(expressions: Seq[Expression], pushableColumn: PushableColumnBase): Option[(String, Any)] = {
+    expressions match {
+      case Seq(pushableColumn(name), Literal(v, _)) if v != null => Some((name, v))
+      case Seq(Literal(v, _), pushableColumn(name)) if v != null => Some((name, v))
+      case _ => None
+    }
+  }
+
+  /**
+   * This is a polyfill for running on Spark 3.0 while compiling against Spark 3.3. We'd really like to mixin
+   * `PredicateHelper` here, but the class hierarchy of `PredicateHelper` has changed between Spark 3.0 and 3.3 so
+   * it would raise `java.lang.ClassNotFoundException: org.apache.spark.sql.catalyst.expressions.AliasHelper`
+   * at runtime on Spark 3.0.
+   * @param condition filter condition to split
+   * @return A list of conjunctive conditions
+   */
+  private def splitConjunctivePredicates(condition: Expression): Seq[Expression] = {
+    condition match {
+      case And(cond1, cond2) =>
+        splitConjunctivePredicates(cond1) ++ splitConjunctivePredicates(cond2)
+      case other => other :: Nil
+    }
+  }
+}
diff --git a/sql/src/test/scala/org/apache/sedona/sql/GeoParquetSpatialFilterPushDownSuite.scala b/sql/src/test/scala/org/apache/sedona/sql/GeoParquetSpatialFilterPushDownSuite.scala
new file mode 100644
index 00000000..d4cc9d75
--- /dev/null
+++ b/sql/src/test/scala/org/apache/sedona/sql/GeoParquetSpatialFilterPushDownSuite.scala
@@ -0,0 +1,248 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.sedona.sql
+
+import org.apache.commons.io.FileUtils
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.Path
+import org.apache.parquet.hadoop.ParquetFileReader
+import org.apache.parquet.hadoop.util.HadoopInputFile
+import org.apache.sedona.sql.GeoParquetSpatialFilterPushDownSuite.generateTestData
+import org.apache.sedona.sql.GeoParquetSpatialFilterPushDownSuite.readGeoParquetMetaDataMap
+import org.apache.sedona.sql.GeoParquetSpatialFilterPushDownSuite.writeTestDataAsGeoParquet
+import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.execution.FileSourceScanExec
+import org.apache.spark.sql.execution.datasources.parquet.GeoParquetFileFormat
+import org.apache.spark.sql.execution.datasources.parquet.GeoParquetMetaData
+import org.apache.spark.sql.execution.datasources.parquet.GeoParquetSpatialFilter
+import org.locationtech.jts.geom.Coordinate
+import org.locationtech.jts.geom.Geometry
+import org.locationtech.jts.geom.GeometryFactory
+import org.scalatest.prop.TableDrivenPropertyChecks
+
+import java.io.File
+import java.nio.file.Files
+
+/**
+ * Tests that spatial predicates are pushed down into `GeoParquetFileFormat`.
+ *
+ * The test data is partitioned by region, one region per quadrant. Each test:
+ *  1. checks which regions survive when the pushed-down filter is evaluated against each region's GeoParquet
+ *     metadata;
+ *  2. checks that the query results match the same query run on the un-pushed-down source DataFrame.
+ */
+class GeoParquetSpatialFilterPushDownSuite extends TestBaseScala with TableDrivenPropertyChecks {
+
+  val tempDir: String = Files.createTempDirectory("sedona_geoparquet_test_").toFile.getAbsolutePath
+  val geoParquetDir: String = tempDir + "/geoparquet"
+  var df: DataFrame = _
+  var geoParquetDf: DataFrame = _
+  var geoParquetMetaDataMap: Map[Int, Seq[GeoParquetMetaData]] = _
+
+  override def beforeAll(): Unit = {
+    super.beforeAll()
+    df = generateTestData(sparkSession)
+    writeTestDataAsGeoParquet(df, geoParquetDir)
+    geoParquetDf = sparkSession.read.format("geoparquet").load(geoParquetDir)
+    geoParquetMetaDataMap = readGeoParquetMetaDataMap(geoParquetDir)
+  }
+
+  override def afterAll(): Unit = {
+    // Delete the temporary output, then always run whatever cleanup the parent suite defines.
+    try FileUtils.deleteDirectory(new File(tempDir))
+    finally super.afterAll()
+  }
+
+  describe("GeoParquet spatial filter push down tests") {
+    it("Push down ST_Contains") {
+      testFilter("ST_Contains(ST_GeomFromText('POLYGON ((0 0, 0 10, 10 10, 10 0, 0 0))'), geom)", Seq(1))
+      testFilter("ST_Contains(ST_GeomFromText('POLYGON ((-16 14, -16 16, -14 16, -14 14, -16 14))'), geom)", Seq(0))
+      testFilter("ST_Contains(ST_GeomFromText('POLYGON ((0 0, 0 1, 1 1, 1 0, 0 0))'), geom)", Seq.empty)
+      testFilter("ST_Contains(geom, ST_GeomFromText('POINT (15 -15)'))", Seq(3))
+      testFilter("ST_Contains(geom, ST_GeomFromText('POLYGON ((4 -5, 5 -5, 5 -4, 4 -4, 4 -5))'))", Seq(3))
+      testFilter("ST_Contains(geom, ST_GeomFromText('POLYGON ((1 -5, 5 -5, 5 -1, 1 -1, 1 -5))'))", Seq.empty)
+    }
+
+    it("Push down ST_Covers") {
+      testFilter("ST_Covers(ST_GeomFromText('POLYGON ((0 0, 0 10, 10 10, 10 0, 0 0))'), geom)", Seq(1))
+      testFilter("ST_Covers(ST_GeomFromText('POLYGON ((-16 14, -16 16, -14 16, -14 14, -16 14))'), geom)", Seq(0))
+      testFilter("ST_Covers(ST_GeomFromText('POLYGON ((0 0, 0 1, 1 1, 1 0, 0 0))'), geom)", Seq.empty)
+      testFilter("ST_Covers(geom, ST_GeomFromText('POINT (15 -15)'))", Seq(3))
+      testFilter("ST_Covers(geom, ST_GeomFromText('POLYGON ((4 -5, 5 -5, 5 -4, 4 -4, 4 -5))'))", Seq(3))
+      testFilter("ST_Covers(geom, ST_GeomFromText('POLYGON ((1 -5, 5 -5, 5 -1, 1 -1, 1 -5))'))", Seq.empty)
+    }
+
+    it("Push down ST_Within") {
+      testFilter("ST_Within(geom, ST_GeomFromText('POLYGON ((0 0, 0 10, 10 10, 10 0, 0 0))'))", Seq(1))
+      testFilter("ST_Within(geom, ST_GeomFromText('POLYGON ((-16 14, -16 16, -14 16, -14 14, -16 14))'))", Seq(0))
+      testFilter("ST_Within(geom, ST_GeomFromText('POLYGON ((0 0, 0 1, 1 1, 1 0, 0 0))'))", Seq.empty)
+      testFilter("ST_Within(ST_GeomFromText('POINT (15 -15)'), geom)", Seq(3))
+      testFilter("ST_Within(ST_GeomFromText('POLYGON ((4 -5, 5 -5, 5 -4, 4 -4, 4 -5))'), geom)", Seq(3))
+      testFilter("ST_Within(ST_GeomFromText('POLYGON ((1 -5, 5 -5, 5 -1, 1 -1, 1 -5))'), geom)", Seq.empty)
+    }
+
+    it("Push down ST_CoveredBy") {
+      testFilter("ST_CoveredBy(geom, ST_GeomFromText('POLYGON ((0 0, 0 10, 10 10, 10 0, 0 0))'))", Seq(1))
+      testFilter("ST_CoveredBy(geom, ST_GeomFromText('POLYGON ((-16 14, -16 16, -14 16, -14 14, -16 14))'))", Seq(0))
+      testFilter("ST_CoveredBy(geom, ST_GeomFromText('POLYGON ((0 0, 0 1, 1 1, 1 0, 0 0))'))", Seq.empty)
+      testFilter("ST_CoveredBy(ST_GeomFromText('POINT (15 -15)'), geom)", Seq(3))
+      testFilter("ST_CoveredBy(ST_GeomFromText('POLYGON ((4 -5, 5 -5, 5 -4, 4 -4, 4 -5))'), geom)", Seq(3))
+      testFilter("ST_CoveredBy(ST_GeomFromText('POLYGON ((1 -5, 5 -5, 5 -1, 1 -1, 1 -5))'), geom)", Seq.empty)
+    }
+
+    it("Push down ST_Intersects") {
+      testFilter("ST_Intersects(ST_GeomFromText('POLYGON ((0 0, 0 10, 10 10, 10 0, 0 0))'), geom)", Seq(1))
+      testFilter("ST_Intersects(ST_GeomFromText('POLYGON ((-16 14, -16 16, -14 16, -14 14, -16 14))'), geom)", Seq(0))
+      testFilter("ST_Intersects(geom, ST_GeomFromText('POLYGON ((0 0, 0 1, 1 1, 1 0, 0 0))'))", Seq.empty)
+      testFilter("ST_Intersects(geom, ST_GeomFromText('POINT (15 -15)'))", Seq(3))
+      testFilter("ST_Intersects(geom, ST_GeomFromText('POLYGON ((4 -5, 5 -5, 5 -4, 4 -4, 4 -5))'))", Seq(3))
+      testFilter("ST_Intersects(geom, ST_GeomFromText('POLYGON ((1 -5, 5 -5, 5 -1, 1 -1, 1 -5))'))", Seq(3))
+      testFilter("ST_Intersects(geom, ST_GeomFromText('POLYGON ((5 -5, 15 -5, 15 5, 5 5, 5 -5))'))", Seq(1, 3))
+    }
+
+    it("Push down ST_Equals") {
+      testFilter("ST_Equals(geom, ST_GeomFromText('POLYGON ((-16 -16, -16 -14, -14 -14, -14 -16, -16 -16))'))", Seq(2))
+      testFilter("ST_Equals(geom, ST_GeomFromText('POINT (-15 -15)'))", Seq(2))
+      testFilter("ST_Equals(geom, ST_GeomFromText('POINT (-16 -16)'))", Seq(2))
+      testFilter("ST_Equals(geom, ST_GeomFromText('POLYGON ((1 -5, 5 -5, 5 -1, 1 -1, 1 -5))'))", Seq.empty)
+    }
+
+    forAll(Table("<", "<=")) { op =>
+      it(s"Push down ST_Distance $op d") {
+        testFilter(s"ST_Distance(geom, ST_GeomFromText('POINT (0 0)')) $op 1", Seq.empty)
+        testFilter(s"ST_Distance(geom, ST_GeomFromText('POINT (0 0)')) $op 5", Seq.empty)
+        testFilter(s"ST_Distance(geom, ST_GeomFromText('POINT (3 4)')) $op 1", Seq(1))
+        testFilter(s"ST_Distance(geom, ST_GeomFromText('POINT (0 0)')) $op 7.1", Seq(0, 1, 2, 3))
+        testFilter(s"ST_Distance(geom, ST_GeomFromText('POINT (-5 -5)')) $op 1", Seq(2))
+        testFilter(s"ST_Distance(geom, ST_GeomFromText('POLYGON ((-1 -1, 1 -1, 1 1, -1 1, -1 -1))')) $op 2", Seq.empty)
+        testFilter(s"ST_Distance(geom, ST_GeomFromText('POLYGON ((-1 -1, 1 -1, 1 1, -1 1, -1 -1))')) $op 3", Seq(0, 1, 2, 3))
+        testFilter(s"ST_Distance(geom, ST_GeomFromText('LINESTRING (17 17, 18 18)')) $op 1", Seq(1))
+      }
+    }
+
+    it("Push down And(filters...)") {
+      testFilter("ST_Intersects(geom, ST_GeomFromText('POLYGON ((5 -5, 15 -5, 15 5, 5 5, 5 -5))')) AND ST_Intersects(ST_GeomFromText('POLYGON ((0 0, 0 10, 10 10, 10 0, 0 0))'), geom)", Seq(1))
+      testFilter("ST_Intersects(geom, ST_GeomFromText('POLYGON ((5 -5, 15 -5, 15 5, 5 5, 5 -5))')) AND ST_Intersects(geom, ST_GeomFromText('POLYGON ((4 -5, 5 -5, 5 -4, 4 -4, 4 -5))'))", Seq(3))
+    }
+
+    it("Push down Or(filters...)") {
+      testFilter("ST_Intersects(ST_GeomFromText('POLYGON ((0 0, 0 10, 10 10, 10 0, 0 0))'), geom) OR ST_Intersects(ST_GeomFromText('POLYGON ((-16 14, -16 16, -14 16, -14 14, -16 14))'), geom)", Seq(0, 1))
+      testFilter("ST_Distance(geom, ST_GeomFromText('POINT (-5 -5)')) <= 1 OR ST_Intersects(ST_GeomFromText('POLYGON ((0 0, 0 10, 10 10, 10 0, 0 0))'), geom)", Seq(1, 2))
+    }
+
+    it("Ignore negated spatial filters") {
+      testFilter("NOT ST_Contains(ST_GeomFromText('POLYGON ((0 0, 0 10, 10 10, 10 0, 0 0))'), geom)", Seq(0, 1, 2, 3))
+      testFilter("ST_Contains(geom, ST_GeomFromText('POLYGON ((4 -5, 5 -5, 5 -4, 4 -4, 4 -5))')) AND NOT ST_Contains(ST_GeomFromText('POLYGON ((0 0, 0 10, 10 10, 10 0, 0 0))'), geom)", Seq(3))
+      testFilter("ST_Contains(geom, ST_GeomFromText('POLYGON ((4 -5, 5 -5, 5 -4, 4 -4, 4 -5))')) OR NOT ST_Contains(ST_GeomFromText('POLYGON ((0 0, 0 10, 10 10, 10 0, 0 0))'), geom)", Seq(0, 1, 2, 3))
+    }
+
+    it("Mixed spatial filter with other filter") {
+      testFilter("id < 10 AND ST_Intersects(geom, ST_GeomFromText('POLYGON ((5 -5, 15 -5, 15 5, 5 5, 5 -5))'))", Seq(1, 3))
+    }
+  }
+
+  /**
+   * Runs a query with the given condition and checks two things:
+   *  - the pushed-down spatial filter prunes exactly the expected regions;
+   *  - the query results match those of the same query on the un-pushed-down source DataFrame.
+   * When no spatial filter is pushed down, every region is treated as preserved.
+   * @param condition SQL query condition
+   * @param expectedPreservedRegions Regions that should be preserved after filter push down, in ascending order
+   */
+  private def testFilter(condition: String, expectedPreservedRegions: Seq[Int]): Unit = {
+    val dfFiltered = geoParquetDf.where(condition)
+    val preservedRegions = getPushedDownSpatialFilter(dfFiltered).fold(allRegions)(resolvePreservedRegions)
+    assert(expectedPreservedRegions == preservedRegions)
+    val expectedResult = df.where(condition).orderBy("region", "id").select("region", "id").collect()
+    val actualResult = dfFiltered.orderBy("region", "id").select("region", "id").collect()
+    assert(expectedResult sameElements actualResult)
+  }
+
+  /** All region IDs present in the written test data, in ascending order. */
+  private def allRegions: Seq[Int] = geoParquetMetaDataMap.keys.toSeq.sorted
+
+  /**
+   * Extracts the spatial filter pushed into the GeoParquet file scan of `df`'s executed plan.
+   * Fails the test if the plan has no file scan, or if that scan does not read GeoParquet.
+   */
+  private def getPushedDownSpatialFilter(df: DataFrame): Option[GeoParquetSpatialFilter] = {
+    val executedPlan = df.queryExecution.executedPlan
+    executedPlan.collectFirst { case scan: FileSourceScanExec => scan.relation.fileFormat } match {
+      case Some(format: GeoParquetFileFormat) => format.spatialFilter
+      case Some(other) => fail(s"Expected GeoParquetFileFormat but found $other")
+      case None => fail(s"No FileSourceScanExec found in plan:\n$executedPlan")
+    }
+  }
+
+  /**
+   * Returns the regions for which at least one file's GeoParquet metadata passes `spatialFilter`.
+   * The result is sorted so it can be compared to expected sequences regardless of `Map` iteration order.
+   */
+  private def resolvePreservedRegions(spatialFilter: GeoParquetSpatialFilter): Seq[Int] = {
+    geoParquetMetaDataMap.collect {
+      case (region, metaDataList) if metaDataList.exists(metadata => spatialFilter.evaluate(metadata.columns)) => region
+    }.toSeq.sorted
+  }
+}
+
+object GeoParquetSpatialFilterPushDownSuite {
+  case class TestDataItem(id: Int, region: Int, geom: Geometry)
+
+  /**
+   * Generate test data centered at (0, 0). The dataset is divided into 4 quadrants, each with a unique
+   * region ID, and contains 4 points and 4 polygons in each quadrant.
+   * @param sparkSession SparkSession object
+   * @return DataFrame containing test data
+   */
+  def generateTestData(sparkSession: SparkSession): DataFrame = {
+    import sparkSession.implicits._
+    // Index in this list is the region ID.
+    val regionCenters = Seq((-10, 10), (10, 10), (-10, -10), (10, -10))
+    val testData = regionCenters.zipWithIndex.flatMap { case ((x, y), i) => generateTestDataForRegion(i, x, y) }
+    testData.toDF()
+  }
+
+  /**
+   * Creates the items for one region:
+   *  - ids 0-3 are points offset by (+/-5, +/-5) from the region center;
+   *  - ids 4-7 are the 2x2 boxes around those points, in the same order.
+   */
+  private def generateTestDataForRegion(region: Int, centerX: Double, centerY: Double): Seq[TestDataItem] = {
+    val factory = new GeometryFactory()
+    val points = Seq(
+      factory.createPoint(new Coordinate(centerX - 5, centerY + 5)),
+      factory.createPoint(new Coordinate(centerX + 5, centerY + 5)),
+      factory.createPoint(new Coordinate(centerX - 5, centerY - 5)),
+      factory.createPoint(new Coordinate(centerX + 5, centerY - 5))
+    )
+    val polygons = points.map { p =>
+      val envelope = p.getEnvelopeInternal
+      envelope.expandBy(1)
+      factory.toGeometry(envelope)
+    }
+    (points ++ polygons).zipWithIndex.map { case (g, i) => TestDataItem(i, region, g) }
+  }
+
+  /**
+   * Write the test dataframe as GeoParquet files, partitioned by region so each region lands in its own
+   * directory. Spatial filter push down is then tested by checking which regions are preserved or pruned
+   * when the pushed-down filters are evaluated against each region's metadata.
+   * @param testData dataframe containing test data
+   * @param path path to write GeoParquet files
+   */
+  def writeTestDataAsGeoParquet(testData: DataFrame, path: String): Unit = {
+    testData.coalesce(1).write.partitionBy("region").format("geoparquet").save(path)
+  }
+
+  /**
+   * Load GeoParquet metadata for each region. A region may contain several files, so each region ID is
+   * associated with a list of GeoParquet metadata.
+   * @param path path to directory containing GeoParquet files
+   * @return Map of region ID to list of GeoParquet metadata
+   */
+  def readGeoParquetMetaDataMap(path: String): Map[Int, Seq[GeoParquetMetaData]] = {
+    (0 until 4).map { k =>
+      k -> readGeoParquetMetaDataByRegion(path, k)
+    }.toMap
+  }
+
+  /**
+   * Reads the "geo" key-value metadata from every `.parquet` file in the `region=<region>` partition directory.
+   * @throws IllegalStateException if the partition directory cannot be listed
+   */
+  private def readGeoParquetMetaDataByRegion(geoParquetSavePath: String, region: Int): Seq[GeoParquetMetaData] = {
+    val regionDir = new File(geoParquetSavePath + s"/region=$region")
+    // File.listFiles returns null (rather than throwing) when the directory is missing or unreadable.
+    val parquetFiles = Option(regionDir.listFiles())
+      .getOrElse(throw new IllegalStateException(s"Cannot list GeoParquet files in $regionDir"))
+      .filter(_.getName.endsWith(".parquet"))
+    parquetFiles.toSeq.flatMap { file =>
+      val reader = ParquetFileReader.open(HadoopInputFile.fromPath(new Path(file.getPath), new Configuration()))
+      // The footer is fully read before close(), so the returned metadata map remains usable afterwards.
+      val metadata = try reader.getFooter.getFileMetaData.getKeyValueMetaData finally reader.close()
+      assert(metadata.containsKey("geo"))
+      GeoParquetMetaData.parseKeyValueMetaData(metadata)
+    }
+  }
+}
diff --git a/sql/src/test/scala/org/apache/sedona/sql/geoparquetIOTests.scala b/sql/src/test/scala/org/apache/sedona/sql/geoparquetIOTests.scala
index cdb15f55..14b334c7 100644
--- a/sql/src/test/scala/org/apache/sedona/sql/geoparquetIOTests.scala
+++ b/sql/src/test/scala/org/apache/sedona/sql/geoparquetIOTests.scala
@@ -26,7 +26,10 @@ import org.apache.parquet.hadoop.util.HadoopInputFile
import org.apache.spark.SparkException
import org.apache.spark.sql.Row
import org.apache.spark.sql.SaveMode
+import org.apache.spark.sql.functions.col
import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT
+import org.apache.spark.sql.sedona_sql.expressions.st_constructors.ST_Point
+import org.apache.spark.sql.sedona_sql.expressions.st_predicates.ST_Intersects
import org.apache.spark.sql.types.IntegerType
import org.apache.spark.sql.types.StructField
import org.apache.spark.sql.types.StructType
@@ -175,5 +178,12 @@ class geoparquetIOTests extends TestBaseScala with BeforeAndAfterAll {
}
assert(e.getMessage.contains("does not contain valid geo metadata"))
}
+
+    it("GeoParquet load with spatial predicates") {
+      // A point-in-polygon query through the DataFrame API should match exactly one country.
+      val rows = sparkSession.read.format("geoparquet").load(geoparquetdatalocation1)
+        .where(ST_Intersects(ST_Point(35.174722, -6.552465), col("geometry")))
+        .collect()
+      val names = rows.map(_.getAs[String]("name")).toSeq
+      assert(names == Seq("Tanzania"))
+    }
}
}