You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@sedona.apache.org by ji...@apache.org on 2022/04/04 14:28:36 UTC

[incubator-sedona] branch master updated: [SEDONA-100] Add st_multi function (#595)

This is an automated email from the ASF dual-hosted git repository.

jiayu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-sedona.git


The following commit(s) were added to refs/heads/master by this push:
     new 1dbc30ee [SEDONA-100] Add st_multi function (#595)
1dbc30ee is described below

commit 1dbc30eef4251248573eed4b65ad59f8bf1af2b6
Author: aggunr <me...@gmail.com>
AuthorDate: Mon Apr 4 16:28:31 2022 +0200

    [SEDONA-100] Add st_multi function (#595)
    
    Co-authored-by: Aggun <ag...@purplescout.se>
---
 docs/api/sql/Function.md                           | 33 ++++++++++++++++++++-
 .../scala/org/apache/sedona/sql/UDF/Catalog.scala  |  1 +
 .../sql/sedona_sql/expressions/Functions.scala     | 15 ++++++++++
 .../sedona_sql/expressions/collect/Collect.scala   | 34 ++++++++++++++++++++++
 .../expressions/collect/ST_Collect.scala           | 33 ++-------------------
 .../org/apache/sedona/sql/functionTestScala.scala  |  6 ++++
 6 files changed, 91 insertions(+), 31 deletions(-)

diff --git a/docs/api/sql/Function.md b/docs/api/sql/Function.md
index 7922fdb9..7085cdee 100644
--- a/docs/api/sql/Function.md
+++ b/docs/api/sql/Function.md
@@ -887,12 +887,14 @@ Format
 Since: `v1.2.0`
 
 Example: 
+
 ```SQL
 SELECT ST_Collect(
     ST_GeomFromText('POINT(21.427834 52.042576573)'),
     ST_GeomFromText('POINT(45.342524 56.342354355)')
 ) AS geom
 ```
+
 Result:
 
 ```
@@ -924,6 +926,35 @@ Result:
 +---------------------------------------------------------------+
 ```
 
+## ST_Multi
+
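+Introduction: Returns the multi-geometry counterpart of the input geometry: a Point becomes a MultiPoint, a LineString a MultiLineString, and a Polygon a MultiPolygon; a geometry that is already a collection is returned unchanged.
+ST_Multi is equivalent to calling ST_Collect with a single geometry.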
+
+Format
+
+`ST_Multi(geom: geometry)`
+
+Since: `v1.2.1`
+
+Example:
+
+```SQL
+SELECT ST_Multi(
+    ST_GeomFromText('POINT(1 1)')
+) AS geom
+```
+
+Result:
+
+```
++---------------------------------------------------------------+
+|geom                                                           |
++---------------------------------------------------------------+
+|MULTIPOINT ((1 1))                                             |
++---------------------------------------------------------------+
+```
+
 ## ST_Difference
 
 Introduction: Return the difference between geometry A and B (return part of geometry A that does not intersect geometry B)
@@ -984,4 +1015,4 @@ Result:
 
 ```
 POLYGON ((3 -1, 3 -3, -3 -3, -3 3, 3 3, 3 1, 5 0, 3 -1))
-```
\ No newline at end of file
+```
diff --git a/sql/src/main/scala/org/apache/sedona/sql/UDF/Catalog.scala b/sql/src/main/scala/org/apache/sedona/sql/UDF/Catalog.scala
index 56c38c55..473eb5c7 100644
--- a/sql/src/main/scala/org/apache/sedona/sql/UDF/Catalog.scala
+++ b/sql/src/main/scala/org/apache/sedona/sql/UDF/Catalog.scala
@@ -100,6 +100,7 @@ object Catalog {
     ST_GeoHash,
     ST_GeomFromGeoHash,
     ST_Collect,
+    ST_Multi,
     // Expression for rasters
     RS_NormalizedDifference,
     RS_Mean,
diff --git a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala
index 35309c6b..effe50bb 100644
--- a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala
+++ b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala
@@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, Codege
 import org.apache.spark.sql.catalyst.expressions.{BoundReference, Expression, Generator}
 import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData}
 import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT
+import org.apache.spark.sql.sedona_sql.expressions.collect.Collect
 import org.apache.spark.sql.sedona_sql.expressions.geohash.{GeoHashDecoder, GeometryGeoHashEncoder, InvalidGeoHashException}
 import org.apache.spark.sql.sedona_sql.expressions.implicits._
 import org.apache.spark.sql.sedona_sql.expressions.subdivide.GeometrySubDivider
@@ -1522,4 +1523,18 @@ case class ST_Union(inputExpressions: Seq[Expression])
   protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = {
     copy(inputExpressions = newChildren)
   }
+}
+
+case class ST_Multi(inputExpressions: Seq[Expression]) extends UnaryGeometryExpression with CodegenFallback{
+  override def dataType: DataType = GeometryUDT
+
+  override def children: Seq[Expression] = inputExpressions
+
+  protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = {
+    copy(inputExpressions = newChildren)
+  }
+
+  override protected def nullSafeEval(geometry: Geometry): Any ={
+    Collect.createMultiGeometry(Seq(geometry)).toGenericArrayData
+  }
 }
\ No newline at end of file
diff --git a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/collect/Collect.scala b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/collect/Collect.scala
new file mode 100644
index 00000000..8b8cb267
--- /dev/null
+++ b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/collect/Collect.scala
@@ -0,0 +1,34 @@
+package org.apache.spark.sql.sedona_sql.expressions.collect
+
+import org.apache.sedona.core.geometryObjects.Circle
+import org.locationtech.jts.geom.{Geometry, GeometryCollection, GeometryFactory, LineString, Point, Polygon}
+import scala.collection.JavaConverters._
+
+
+object Collect {
+  private val geomFactory = new GeometryFactory()
+
+  def createMultiGeometry(geometries : Seq[Geometry]): Geometry = {
+    if (geometries.length>1){
+      geomFactory.buildGeometry(geometries.asJava)
+    }
+    else if(geometries.length==1){
+      createMultiGeometryFromOneElement(geometries.head)
+    }
+    else{
+      geomFactory.createGeometryCollection()
+    }
+  }
+
+  private def createMultiGeometryFromOneElement(geom: Geometry): Geometry = {
+    geom match {
+      case circle: Circle                 => geomFactory.createGeometryCollection(Array(circle))
+      case collection: GeometryCollection => collection
+      case string: LineString =>
+        geomFactory.createMultiLineString(Array(string))
+      case point: Point     => geomFactory.createMultiPoint(Array(point))
+      case polygon: Polygon => geomFactory.createMultiPolygon(Array(polygon))
+      case _                => geomFactory.createGeometryCollection()
+    }
+  }
+}
diff --git a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/collect/ST_Collect.scala b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/collect/ST_Collect.scala
index c7b4fa35..cdded9d8 100644
--- a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/collect/ST_Collect.scala
+++ b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/collect/ST_Collect.scala
@@ -18,7 +18,6 @@
  */
 package org.apache.spark.sql.sedona_sql.expressions.collect
 
-import org.apache.sedona.core.geometryObjects.Circle
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
 
@@ -28,14 +27,11 @@ import org.apache.spark.sql.catalyst.util.ArrayData
 import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT
 import org.apache.spark.sql.sedona_sql.expressions.implicits._
 import org.apache.spark.sql.types.{ArrayType, _}
-import org.locationtech.jts.geom._
-import scala.collection.JavaConverters._
 
 case class ST_Collect(inputExpressions: Seq[Expression])
     extends Expression
     with CodegenFallback {
   assert(inputExpressions.length >= 1)
-  private val geomFactory = new GeometryFactory()
 
   override def nullable: Boolean = true
 
@@ -53,37 +49,14 @@ case class ST_Collect(inputExpressions: Seq[Expression])
               .filter(_ != null)
               .map(_.toGeometry)
 
-            createMultiGeometry(geomElements)
-          case _ => emptyCollection
+            Collect.createMultiGeometry(geomElements).toGenericArrayData
+          case _ => Collect.createMultiGeometry(Seq()).toGenericArrayData
         }
       case _ =>
         val geomElements =
           inputExpressions.map(_.toGeometry(input)).filter(_ != null)
-        val length = geomElements.length
-        if (length > 1) createMultiGeometry(geomElements)
-        else if (length == 1)
-          createMultiGeometryFromOneElement(
-            geomElements.head
-          ).toGenericArrayData
-        else emptyCollection
-    }
-  }
-
-  private def createMultiGeometry(geomElements: Seq[Geometry]) =
-    geomFactory.buildGeometry(geomElements.asJava).toGenericArrayData
-
-  private def emptyCollection =
-    geomFactory.createGeometryCollection().toGenericArrayData
+        Collect.createMultiGeometry(geomElements).toGenericArrayData
 
-  private def createMultiGeometryFromOneElement(geom: Geometry): Geometry = {
-    geom match {
-      case circle: Circle                 => geomFactory.createGeometryCollection(Array(circle))
-      case collection: GeometryCollection => collection
-      case string: LineString =>
-        geomFactory.createMultiLineString(Array(string))
-      case point: Point     => geomFactory.createMultiPoint(Array(point))
-      case polygon: Polygon => geomFactory.createMultiPolygon(Array(polygon))
-      case _                => geomFactory.createGeometryCollection()
     }
   }
 
diff --git a/sql/src/test/scala/org/apache/sedona/sql/functionTestScala.scala b/sql/src/test/scala/org/apache/sedona/sql/functionTestScala.scala
index 422afcf0..9612fd56 100644
--- a/sql/src/test/scala/org/apache/sedona/sql/functionTestScala.scala
+++ b/sql/src/test/scala/org/apache/sedona/sql/functionTestScala.scala
@@ -1278,6 +1278,12 @@ class functionTestScala extends TestBaseScala with Matchers with GeometrySample
       )
   }
 
+  it ("Should pass ST_Multi"){
+    val df = sparkSession.sql("select ST_Astext(ST_Multi(ST_Point(1.0,1.0)))")
+    val result = df.collect()
+    assert(result.head.get(0).asInstanceOf[String]=="MULTIPOINT ((1 1))")
+
+  }
 
   it("handles nulls") {
     var functionDf: DataFrame = null