You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@sedona.apache.org by ji...@apache.org on 2022/04/04 14:28:36 UTC
[incubator-sedona] branch master updated: [SEDONA-100] Add st_multi function (#595)
This is an automated email from the ASF dual-hosted git repository.
jiayu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-sedona.git
The following commit(s) were added to refs/heads/master by this push:
new 1dbc30ee [SEDONA-100] Add st_multi function (#595)
1dbc30ee is described below
commit 1dbc30eef4251248573eed4b65ad59f8bf1af2b6
Author: aggunr <me...@gmail.com>
AuthorDate: Mon Apr 4 16:28:31 2022 +0200
[SEDONA-100] Add st_multi function (#595)
Co-authored-by: Aggun <ag...@purplescout.se>
---
docs/api/sql/Function.md | 33 ++++++++++++++++++++-
.../scala/org/apache/sedona/sql/UDF/Catalog.scala | 1 +
.../sql/sedona_sql/expressions/Functions.scala | 15 ++++++++++
.../sedona_sql/expressions/collect/Collect.scala | 34 ++++++++++++++++++++++
.../expressions/collect/ST_Collect.scala | 33 ++-------------------
.../org/apache/sedona/sql/functionTestScala.scala | 6 ++++
6 files changed, 91 insertions(+), 31 deletions(-)
diff --git a/docs/api/sql/Function.md b/docs/api/sql/Function.md
index 7922fdb9..7085cdee 100644
--- a/docs/api/sql/Function.md
+++ b/docs/api/sql/Function.md
@@ -887,12 +887,14 @@ Format
Since: `v1.2.0`
Example:
+
```SQL
SELECT ST_Collect(
ST_GeomFromText('POINT(21.427834 52.042576573)'),
ST_GeomFromText('POINT(45.342524 56.342354355)')
) AS geom
```
+
Result:
```
@@ -924,6 +926,35 @@ Result:
+---------------------------------------------------------------+
```
+## ST_Multi
+
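+Introduction: Returns the input geometry as a multi-geometry. A Point becomes a MultiPoint, a LineString a MultiLineString and a Polygon a MultiPolygon; geometries that are already collections (including all Multi* types) are returned unchanged.
+ST_Multi behaves like ST_Collect called with a single geometry.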
+
+Format
+
+`ST_Multi(geom: geometry)`
+
+Since: `v1.2.1`
+
+Example:
+
+```SQL
+SELECT ST_Multi(
+ ST_GeomFromText('POINT(1 1)')
+) AS geom
+```
+
+Result:
+
+```
++---------------------------------------------------------------+
+|geom |
++---------------------------------------------------------------+
+|MULTIPOINT ((1 1))                                             |
++---------------------------------------------------------------+
+```
+
## ST_Difference
Introduction: Return the difference between geometry A and B (return part of geometry A that does not intersect geometry B)
@@ -984,4 +1015,4 @@ Result:
```
POLYGON ((3 -1, 3 -3, -3 -3, -3 3, 3 3, 3 1, 5 0, 3 -1))
-```
\ No newline at end of file
+```
diff --git a/sql/src/main/scala/org/apache/sedona/sql/UDF/Catalog.scala b/sql/src/main/scala/org/apache/sedona/sql/UDF/Catalog.scala
index 56c38c55..473eb5c7 100644
--- a/sql/src/main/scala/org/apache/sedona/sql/UDF/Catalog.scala
+++ b/sql/src/main/scala/org/apache/sedona/sql/UDF/Catalog.scala
@@ -100,6 +100,7 @@ object Catalog {
ST_GeoHash,
ST_GeomFromGeoHash,
ST_Collect,
+ ST_Multi,
// Expression for rasters
RS_NormalizedDifference,
RS_Mean,
diff --git a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala
index 35309c6b..effe50bb 100644
--- a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala
+++ b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala
@@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, Codege
import org.apache.spark.sql.catalyst.expressions.{BoundReference, Expression, Generator}
import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData}
import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT
+import org.apache.spark.sql.sedona_sql.expressions.collect.Collect
import org.apache.spark.sql.sedona_sql.expressions.geohash.{GeoHashDecoder, GeometryGeoHashEncoder, InvalidGeoHashException}
import org.apache.spark.sql.sedona_sql.expressions.implicits._
import org.apache.spark.sql.sedona_sql.expressions.subdivide.GeometrySubDivider
@@ -1522,4 +1523,27 @@ case class ST_Union(inputExpressions: Seq[Expression])
protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = {
copy(inputExpressions = newChildren)
}
+}
+
+/**
+  * ST_Multi(geom): returns the geometry as a multi-geometry. A Point, LineString or Polygon
+  * is promoted to the matching Multi* type; geometry collections (including Multi* types)
+  * are returned unchanged. Equivalent to ST_Collect called with a single geometry.
+  *
+  * @param inputExpressions exactly one geometry expression
+  */
+case class ST_Multi(inputExpressions: Seq[Expression]) extends UnaryGeometryExpression with CodegenFallback {
+  assert(inputExpressions.length == 1)
+
+  override def dataType: DataType = GeometryUDT
+
+  override def children: Seq[Expression] = inputExpressions
+
+  protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = {
+    copy(inputExpressions = newChildren)
+  }
+
+  override protected def nullSafeEval(geometry: Geometry): Any = {
+    Collect.createMultiGeometry(Seq(geometry)).toGenericArrayData
+  }
}
\ No newline at end of file
diff --git a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/collect/Collect.scala b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/collect/Collect.scala
new file mode 100644
index 00000000..8b8cb267
--- /dev/null
+++ b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/collect/Collect.scala
@@ -0,0 +1,68 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.spark.sql.sedona_sql.expressions.collect
+
+import org.apache.sedona.core.geometryObjects.Circle
+import org.locationtech.jts.geom.{Geometry, GeometryCollection, GeometryFactory, LineString, Point, Polygon}
+import scala.collection.JavaConverters._
+
+/**
+  * Helpers for turning geometries into multi-geometries, shared by ST_Collect and ST_Multi.
+  */
+object Collect {
+  // Fallback factory, used only when there is no input geometry to take a factory from.
+  private val geomFactory = new GeometryFactory()
+
+  /**
+    * Combines the given geometries into a single geometry:
+    *  - no geometry: an empty GeometryCollection
+    *  - one geometry: promoted to its Multi* counterpart
+    *  - several geometries: delegated to JTS `GeometryFactory.buildGeometry`
+    *
+    * Newly built geometries use the factory of the first input, so its SRID and precision
+    * model are carried over instead of being reset to the JTS defaults.
+    *
+    * @param geometries geometries to combine; expected to contain no nulls (ST_Collect filters them out)
+    * @return the combined geometry
+    */
+  def createMultiGeometry(geometries: Seq[Geometry]): Geometry = geometries match {
+    case Seq() => geomFactory.createGeometryCollection()
+    case Seq(single) => createMultiGeometryFromOneElement(single)
+    case _ => geometries.head.getFactory.buildGeometry(geometries.asJava)
+  }
+
+  /**
+    * Promotes a single geometry to a multi-geometry: Point -> MultiPoint,
+    * LineString -> MultiLineString, Polygon -> MultiPolygon. Geometry collections
+    * (which include every Multi* type) are returned as is; anything else, such as a
+    * Circle, is wrapped in a one-element GeometryCollection.
+    */
+  private def createMultiGeometryFromOneElement(geom: Geometry): Geometry = {
+    val factory = geom.getFactory
+    geom match {
+      case circle: Circle => factory.createGeometryCollection(Array(circle))
+      case collection: GeometryCollection => collection
+      case lineString: LineString => factory.createMultiLineString(Array(lineString))
+      case point: Point => factory.createMultiPoint(Array(point))
+      case polygon: Polygon => factory.createMultiPolygon(Array(polygon))
+      // Wrap unrecognised types instead of returning an empty collection, so no input is lost.
+      case other => factory.createGeometryCollection(Array(other))
+    }
+  }
+}
diff --git a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/collect/ST_Collect.scala b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/collect/ST_Collect.scala
index c7b4fa35..cdded9d8 100644
--- a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/collect/ST_Collect.scala
+++ b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/collect/ST_Collect.scala
@@ -18,7 +18,6 @@
*/
package org.apache.spark.sql.sedona_sql.expressions.collect
-import org.apache.sedona.core.geometryObjects.Circle
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
@@ -28,14 +27,11 @@ import org.apache.spark.sql.catalyst.util.ArrayData
import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT
import org.apache.spark.sql.sedona_sql.expressions.implicits._
import org.apache.spark.sql.types.{ArrayType, _}
-import org.locationtech.jts.geom._
-import scala.collection.JavaConverters._
case class ST_Collect(inputExpressions: Seq[Expression])
extends Expression
with CodegenFallback {
assert(inputExpressions.length >= 1)
- private val geomFactory = new GeometryFactory()
override def nullable: Boolean = true
@@ -53,37 +49,14 @@ case class ST_Collect(inputExpressions: Seq[Expression])
.filter(_ != null)
.map(_.toGeometry)
- createMultiGeometry(geomElements)
- case _ => emptyCollection
+ Collect.createMultiGeometry(geomElements).toGenericArrayData
+ case _ => Collect.createMultiGeometry(Seq()).toGenericArrayData
}
case _ =>
val geomElements =
inputExpressions.map(_.toGeometry(input)).filter(_ != null)
- val length = geomElements.length
- if (length > 1) createMultiGeometry(geomElements)
- else if (length == 1)
- createMultiGeometryFromOneElement(
- geomElements.head
- ).toGenericArrayData
- else emptyCollection
- }
- }
-
- private def createMultiGeometry(geomElements: Seq[Geometry]) =
- geomFactory.buildGeometry(geomElements.asJava).toGenericArrayData
-
- private def emptyCollection =
- geomFactory.createGeometryCollection().toGenericArrayData
+ Collect.createMultiGeometry(geomElements).toGenericArrayData
- private def createMultiGeometryFromOneElement(geom: Geometry): Geometry = {
- geom match {
- case circle: Circle => geomFactory.createGeometryCollection(Array(circle))
- case collection: GeometryCollection => collection
- case string: LineString =>
- geomFactory.createMultiLineString(Array(string))
- case point: Point => geomFactory.createMultiPoint(Array(point))
- case polygon: Polygon => geomFactory.createMultiPolygon(Array(polygon))
- case _ => geomFactory.createGeometryCollection()
}
}
diff --git a/sql/src/test/scala/org/apache/sedona/sql/functionTestScala.scala b/sql/src/test/scala/org/apache/sedona/sql/functionTestScala.scala
index 422afcf0..9612fd56 100644
--- a/sql/src/test/scala/org/apache/sedona/sql/functionTestScala.scala
+++ b/sql/src/test/scala/org/apache/sedona/sql/functionTestScala.scala
@@ -1278,6 +1278,21 @@ class functionTestScala extends TestBaseScala with Matchers with GeometrySample
)
}
+  it("Should pass ST_Multi") {
+    // Single geometries are promoted to their Multi* counterpart.
+    val point = sparkSession.sql("select ST_AsText(ST_Multi(ST_Point(1.0, 1.0)))").first().getString(0)
+    assert(point == "MULTIPOINT ((1 1))")
+
+    val lineString = sparkSession.sql("select ST_AsText(ST_Multi(ST_GeomFromText('LINESTRING (0 0, 1 1)')))").first().getString(0)
+    assert(lineString == "MULTILINESTRING ((0 0, 1 1))")
+
+    val polygon = sparkSession.sql("select ST_AsText(ST_Multi(ST_GeomFromText('POLYGON ((0 0, 1 0, 1 1, 0 0))')))").first().getString(0)
+    assert(polygon == "MULTIPOLYGON (((0 0, 1 0, 1 1, 0 0)))")
+
+    // Inputs that are already multi-geometries are returned unchanged.
+    val multiPoint = sparkSession.sql("select ST_AsText(ST_Multi(ST_GeomFromText('MULTIPOINT ((1 1), (2 2))')))").first().getString(0)
+    assert(multiPoint == "MULTIPOINT ((1 1), (2 2))")
+  }
it("handles nulls") {
var functionDf: DataFrame = null