You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@sedona.apache.org by ji...@apache.org on 2022/06/05 23:15:05 UTC

[incubator-sedona] branch master updated: [SEDONA-124] Add ST_CollectionExtract to Apache Sedona (#629)

This is an automated email from the ASF dual-hosted git repository.

jiayu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-sedona.git


The following commit(s) were added to refs/heads/master by this push:
     new dae17a7a [SEDONA-124] Add ST_CollectionExtract to Apache Sedona (#629)
dae17a7a is described below

commit dae17a7a50b956899eb1366dbd0eb3ff75f5647a
Author: aggunr <me...@gmail.com>
AuthorDate: Mon Jun 6 01:15:00 2022 +0200

    [SEDONA-124] Add ST_CollectionExtract to Apache Sedona (#629)
    
    Co-authored-by: Aggun <ag...@purplescout.se>
---
 docs/api/sql/Function.md                           | 41 +++++++++++
 .../scala/org/apache/sedona/sql/UDF/Catalog.scala  |  3 +-
 .../expressions/collect/ST_CollectionExtract.scala | 84 ++++++++++++++++++++++
 .../org/apache/sedona/sql/functionTestScala.scala  | 15 ++++
 4 files changed, 142 insertions(+), 1 deletion(-)

diff --git a/docs/api/sql/Function.md b/docs/api/sql/Function.md
index b02d26cf..2fcbc247 100644
--- a/docs/api/sql/Function.md
+++ b/docs/api/sql/Function.md
@@ -1259,9 +1259,50 @@ SELECT ST_BuildArea(
 Result:
 
 ```
+
 +----------------------------------------------------------------------------+
 |geom                                                                        |
 +----------------------------------------------------------------------------+
 |POLYGON((0 0,0 20,20 20,20 0,0 0),(2 2,18 2,18 18,2 18,2 2))                |
 +----------------------------------------------------------------------------+
 ```
+
+## ST_CollectionExtract
+
+Introduction: Returns a homogeneous multi-geometry containing the members of a given geometry collection (nested collections included) that have a single geometry type.
+
+The type numbers are: 
+1. POINT
+2. LINESTRING
+3. POLYGON
+
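+If the type parameter is omitted, a multi-geometry of the highest dimension present in the collection is returned. If no member has the requested type, an empty multi-geometry of that type is returned.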
+
+Format: `ST_CollectionExtract (A:geometry)`
+
+Format: `ST_CollectionExtract (A:geometry, type:Int)`
+
+Since: `v1.2.1`
+
+Example:
+
+```SQL
+WITH test_data as (
+    SELECT ST_GeomFromText(
+        'GEOMETRYCOLLECTION(POINT(40 10), POLYGON((0 0, 0 5, 5 5, 5 0, 0 0)))'
+    ) as geom
+)
+SELECT ST_CollectionExtract(geom) as c1, ST_CollectionExtract(geom, 1) as c2 
+FROM test_data
+
+```
+
+Result:
+
+```
++------------------------------------------+--------------------+
+|c1                                        |c2                  |
++------------------------------------------+--------------------+
+|MULTIPOLYGON (((0 0, 0 5, 5 5, 5 0, 0 0)))|MULTIPOINT ((40 10))|
++------------------------------------------+--------------------+
+```
\ No newline at end of file
diff --git a/sql/src/main/scala/org/apache/sedona/sql/UDF/Catalog.scala b/sql/src/main/scala/org/apache/sedona/sql/UDF/Catalog.scala
index dbb402f5..d0babdc1 100644
--- a/sql/src/main/scala/org/apache/sedona/sql/UDF/Catalog.scala
+++ b/sql/src/main/scala/org/apache/sedona/sql/UDF/Catalog.scala
@@ -21,7 +21,7 @@ package org.apache.sedona.sql.UDF
 import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
 import org.apache.spark.sql.expressions.{Aggregator, UserDefinedAggregateFunction}
 import org.apache.spark.sql.sedona_sql.expressions.{ST_YMax, ST_YMin, _}
-import org.apache.spark.sql.sedona_sql.expressions.collect.ST_Collect
+import org.apache.spark.sql.sedona_sql.expressions.collect.{ST_Collect, ST_CollectionExtract}
 import org.apache.spark.sql.sedona_sql.expressions.raster.{RS_Add, RS_Append, RS_Array, RS_Base64, RS_BitwiseAnd, RS_BitwiseOr, RS_Count, RS_Divide, RS_FetchRegion, RS_GetBand, RS_GreaterThan, RS_GreaterThanEqual, RS_HTML, RS_LessThan, RS_LessThanEqual, RS_LogicalDifference, RS_LogicalOver, RS_Mean, RS_Mode, RS_Modulo, RS_Multiply, RS_MultiplyFactor, RS_Normalize, RS_NormalizedDifference, RS_SquareRoot, RS_Subtract}
 import org.locationtech.jts.geom.Geometry
 
@@ -115,6 +115,7 @@ object Catalog {
     ST_XMin,
     ST_BuildArea,
     ST_OrderingEquals,
+    ST_CollectionExtract,
     // Expression for rasters
     RS_NormalizedDifference,
     RS_Mean,
diff --git a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/collect/ST_CollectionExtract.scala b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/collect/ST_CollectionExtract.scala
new file mode 100644
index 00000000..28871111
--- /dev/null
+++ b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/collect/ST_CollectionExtract.scala
@@ -0,0 +1,84 @@
+package org.apache.spark.sql.sedona_sql.expressions.collect
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
+import org.apache.spark.sql.catalyst.util.GenericArrayData
+import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT
+import org.apache.spark.sql.sedona_sql.expressions.collect.GeomType.GeomTypeVal
+import org.apache.spark.sql.sedona_sql.expressions.implicits.{GeometryEnhancer, InputExpressionEnhancer}
+import org.apache.spark.sql.types.DataType
+import org.locationtech.jts.geom.{Geometry, GeometryCollection, LineString, Point, Polygon}
+
+import java.util
+
+/**
+ * Geometry types that ST_CollectionExtract can extract, numbered from 1 as in the SQL API:
+ * 1 = point, 2 = line, 3 = polygon. Ids are assigned in declaration order, so the order of
+ * `point`, `line` and `polygon` below must not change; a higher id means a higher dimension.
+ */
+object GeomType extends Enumeration(1) {
+  /**
+   * Enumeration value carrying the constructors of its multi-geometry.
+   *
+   * @param empty builds an empty multi-geometry with the factory of the given geometry
+   * @param multi builds a multi-geometry from a non-empty list of geometries of this type,
+   *              using the factory of the first element
+   */
+  case class GeomTypeVal(empty: Function[Geometry, Geometry], multi: Function[util.List[Geometry], Geometry]) extends super.Val {}
+  import scala.language.implicitConversions
+  // All values declared in this object are GeomTypeVal instances, so the downcast is safe for them.
+  implicit def valueToGeomTypeVal(x: Value): GeomTypeVal = x.asInstanceOf[GeomTypeVal]
+
+  /**
+   * Returns the type of `geometry`. For a collection, the members are visited recursively and
+   * the type with the highest id wins. An empty collection (e.g. `GEOMETRYCOLLECTION EMPTY`)
+   * yields `point` instead of failing on a reduce over no elements.
+   *
+   * @throws IllegalArgumentException if `geometry` is null or not one of the handled JTS types
+   */
+  def getGeometryType(geometry: Geometry): GeomTypeVal = geometry match {
+    case collection: GeometryCollection =>
+      val memberIds = (0 until collection.getNumGeometries).map(i => getGeometryType(collection.getGeometryN(i)).id)
+      if (memberIds.isEmpty) point else GeomType(memberIds.max)
+    case _: Point => point
+    case _: LineString => line
+    case _: Polygon => polygon
+    case other => throw new IllegalArgumentException(s"Unsupported geometry for ST_CollectionExtract: $other")
+  }
+
+  // Declaration order fixes the ids: point = 1, line = 2, polygon = 3.
+  val point = GeomTypeVal(geom => geom.getFactory.createMultiPoint(), geoms => geoms.get(0).getFactory().createMultiPoint(geoms.toArray(Array[Point]())))
+  val line = GeomTypeVal(geom => geom.getFactory.createMultiLineString(), geoms => geoms.get(0).getFactory().createMultiLineString(geoms.toArray(Array[LineString]())))
+  val polygon = GeomTypeVal(geom => geom.getFactory.createMultiPolygon(), geoms => geoms.get(0).getFactory().createMultiPolygon(geoms.toArray(Array[Polygon]())))
+}
+
+/**
+ * Extracts the members of a geometry collection that share one geometry type and returns them
+ * as a homogeneous multi-geometry (MULTIPOINT, MULTILINESTRING or MULTIPOLYGON).
+ *
+ * Arguments:
+ *  - `inputExpressions(0)`: the geometry to extract from; nested collections are walked recursively.
+ *  - `inputExpressions(1)` (optional): the type number to extract, 1 = point, 2 = line, 3 = polygon.
+ *    When omitted, the type is inferred from the geometry via `GeomType.getGeometryType`.
+ *
+ * Returns null when the geometry or the type number is null. When no member matches, an empty
+ * multi-geometry of the requested type is returned.
+ */
+case class ST_CollectionExtract(inputExpressions: Seq[Expression]) extends Expression with CodegenFallback {
+
+  override def dataType: DataType = GeometryUDT
+
+  override def children: Seq[Expression] = inputExpressions
+
+  protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression = {
+    copy(inputExpressions = newChildren)
+  }
+
+  override def nullable: Boolean = true
+
+  /** Collects the members of `geometry` whose type is `geomType` and serializes the resulting multi-geometry. */
+  def nullSafeEval(geometry: Geometry, geomType: GeomTypeVal): GenericArrayData = {
+    val geometries: util.ArrayList[Geometry] = new util.ArrayList[Geometry]()
+    filterGeometry(geometries, geometry, geomType)
+
+    if (geometries.isEmpty) {
+      // Nothing matched: build an empty multi-geometry with the input geometry's factory.
+      geomType.empty(geometry).toGenericArrayData
+    } else {
+      geomType.multi(geometries).toGenericArrayData
+    }
+  }
+
+  /** Appends to `geometries` every non-collection member of `geometry` (recursively) whose type is `geomType`. */
+  def filterGeometry(geometries: util.ArrayList[Geometry], geometry: Geometry, geomType: GeomTypeVal): Unit = {
+    geometry match {
+      case collection: GeometryCollection =>
+        (0 until collection.getNumGeometries).foreach(i => filterGeometry(geometries, collection.getGeometryN(i), geomType))
+      case single: Geometry =>
+        if (geomType == GeomType.getGeometryType(single)) {
+          geometries.add(single)
+        }
+    }
+  }
+
+  override def eval(input: InternalRow): Any = {
+    val geometry = inputExpressions.head.toGeometry(input)
+    if (geometry == null) {
+      // Checked before any type inference: GeomType.getGeometryType does not accept null.
+      null
+    } else if (inputExpressions.length == 2) {
+      inputExpressions(1).eval(input) match {
+        case null => null
+        case typeNumber: Int => nullSafeEval(geometry, toGeomType(typeNumber))
+        case other => throw new IllegalArgumentException(
+          s"ST_CollectionExtract expects an integer geometry type number, but got: $other")
+      }
+    } else {
+      nullSafeEval(geometry, GeomType.getGeometryType(geometry))
+    }
+  }
+
+  /** Maps a user-supplied type number (1 = point, 2 = line, 3 = polygon) to its GeomType value. */
+  private def toGeomType(typeNumber: Int): GeomTypeVal =
+    GeomType.values.find(_.id == typeNumber) match {
+      case Some(value) => GeomType.valueToGeomTypeVal(value)
+      case None => throw new IllegalArgumentException(
+        s"ST_CollectionExtract geometry type number must be 1 (point), 2 (line) or 3 (polygon), but got: $typeNumber")
+    }
+}
diff --git a/sql/src/test/scala/org/apache/sedona/sql/functionTestScala.scala b/sql/src/test/scala/org/apache/sedona/sql/functionTestScala.scala
index a75d70aa..75c68eac 100644
--- a/sql/src/test/scala/org/apache/sedona/sql/functionTestScala.scala
+++ b/sql/src/test/scala/org/apache/sedona/sql/functionTestScala.scala
@@ -1276,6 +1276,8 @@ class functionTestScala extends TestBaseScala with Matchers with GeometrySample
 
   }
 
+
+
   it ("Should pass ST_PointOnSurface") {
 
     val geomTestCases1 = Map(
@@ -1574,4 +1576,17 @@ class functionTestScala extends TestBaseScala with Matchers with GeometrySample
     functionDf = sparkSession.sql("select ST_BuildArea(null)")
     assert(functionDf.first().get(0) == null)
   }
+
+  it ("Should pass St_CollectionExtract") {
+    // Mixed collection holding one point, one line string and one polygon.
+    val mixed = sparkSession.sql("SELECT ST_GeomFromText('GEOMETRYCOLLECTION(POINT(40 10), LINESTRING(0 5, 0 10), POLYGON((0 0, 0 5, 5 5, 5 0, 0 0)))') as geom")
+    // Without a type argument the highest-dimension type (polygon) is extracted.
+    assert(mixed.selectExpr("ST_AsText(ST_CollectionExtract(geom))").collect().head.get(0) == "MULTIPOLYGON (((0 0, 0 5, 5 5, 5 0, 0 0)))")
+    assert(mixed.selectExpr("ST_AsText(ST_CollectionExtract(geom, 3))").collect().head.get(0) == "MULTIPOLYGON (((0 0, 0 5, 5 5, 5 0, 0 0)))")
+    assert(mixed.selectExpr("ST_AsText(ST_CollectionExtract(geom, 1))").collect().head.get(0) == "MULTIPOINT ((40 10))")
+    assert(mixed.selectExpr("ST_AsText(ST_CollectionExtract(geom, 2))").collect().head.get(0) == "MULTILINESTRING ((0 5, 0 10))")
+
+    // Point-only collection: duplicates are kept and absent types yield empty multi-geometries.
+    val pointsOnly = sparkSession.sql("SELECT ST_GeomFromText('GEOMETRYCOLLECTION (POINT (40 10), POINT (40 10))') as geom")
+    assert(pointsOnly.selectExpr("ST_AsText(ST_CollectionExtract(geom, 1))").collect().head.get(0) == "MULTIPOINT ((40 10), (40 10))")
+    assert(pointsOnly.selectExpr("ST_AsText(ST_CollectionExtract(geom, 2))").collect().head.get(0) == "MULTILINESTRING EMPTY")
+    assert(pointsOnly.selectExpr("ST_AsText(ST_CollectionExtract(geom, 3))").collect().head.get(0) == "MULTIPOLYGON EMPTY")
+    assert(pointsOnly.selectExpr("ST_AsText(ST_CollectionExtract(geom))").collect().head.get(0) == "MULTIPOINT ((40 10), (40 10))")
+  }
 }
\ No newline at end of file