You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@sedona.apache.org by ji...@apache.org on 2023/04/24 04:12:17 UTC
[sedona] branch master updated: [SEDONA-274] move sql function collectionExtract and collect to common (#823)
This is an automated email from the ASF dual-hosted git repository.
jiayu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/sedona.git
The following commit(s) were added to refs/heads/master by this push:
new 7ccc33f6 [SEDONA-274] move sql function collectionExtract and collect to common (#823)
7ccc33f6 is described below
commit 7ccc33f6daa6d9044e114e8134c2c209af2706f8
Author: zongsi.zhang <kr...@gmail.com>
AuthorDate: Mon Apr 24 12:12:12 2023 +0800
[SEDONA-274] move sql function collectionExtract and collect to common (#823)
---
.../java/org/apache/sedona/common/Functions.java | 60 ++++++++++++
.../org/apache/sedona/common/utils/GeomUtils.java | 14 ++-
.../org/apache/sedona/common/GeometryUtilTest.java | 3 +-
.../scala/org/apache/sedona/sql/UDF/Catalog.scala | 4 +-
.../sql/sedona_sql/expressions/Functions.scala | 10 ++
.../sedona_sql/expressions/collect/Collect.scala | 53 -----------
.../expressions/collect/ST_Collect.scala | 10 +-
.../expressions/collect/ST_CollectionExtract.scala | 103 ---------------------
.../sql/sedona_sql/expressions/st_functions.scala | 7 +-
9 files changed, 91 insertions(+), 173 deletions(-)
diff --git a/common/src/main/java/org/apache/sedona/common/Functions.java b/common/src/main/java/org/apache/sedona/common/Functions.java
index c429c4c8..c86572f8 100644
--- a/common/src/main/java/org/apache/sedona/common/Functions.java
+++ b/common/src/main/java/org/apache/sedona/common/Functions.java
@@ -47,6 +47,7 @@ import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.stream.Collectors;
+import java.util.stream.Stream;
public class Functions {
@@ -670,4 +671,63 @@ public class Functions {
return null;
}
}
+
+    /**
+     * Combines an array of geometries into one geometry.
+     *
+     * <ul>
+     *   <li>no elements: an empty geometry collection from {@code GEOMETRY_FACTORY};</li>
+     *   <li>one element: whatever {@code createMultiGeometryFromOneElement} returns for it;</li>
+     *   <li>several elements: the result of {@code GEOMETRY_FACTORY.buildGeometry} over all of them.</li>
+     * </ul>
+     *
+     * @param geometries the geometries to combine; must not be {@code null} because its
+     *     length is read unconditionally
+     * @return the combined geometry
+     */
+    public static Geometry createMultiGeometry(Geometry[] geometries) {
+        switch (geometries.length) {
+            case 0:
+                return GEOMETRY_FACTORY.createGeometryCollection();
+            case 1:
+                return createMultiGeometryFromOneElement(geometries[0]);
+            default:
+                return GEOMETRY_FACTORY.buildGeometry(Arrays.asList(geometries));
+        }
+    }
+
+    /**
+     * Extracts from {@code geometry} the leaf geometries of one requested kind.
+     *
+     * <p>Type codes: {@code 1} selects {@code Point}, {@code 2} selects {@code LineString},
+     * {@code 3} selects {@code Polygon}. A {@code null} code delegates to
+     * {@link #collectionExtract(Geometry)}.
+     *
+     * @param geometry the geometry to search
+     * @param geomType the type code, or {@code null}
+     * @return the matches wrapped by {@link #createMultiGeometry(Geometry[])}, or an empty
+     *     multi-point / multi-line-string / multi-polygon when nothing matches
+     * @throws IllegalArgumentException if {@code geomType} is not 1, 2 or 3
+     */
+    public static Geometry collectionExtract(Geometry geometry, Integer geomType) {
+        if (geomType == null) {
+            return collectionExtract(geometry);
+        }
+        final int typeCode = geomType;
+        final Class<? extends Geometry> wanted;
+        if (typeCode == 1) {
+            wanted = Point.class;
+        } else if (typeCode == 2) {
+            wanted = LineString.class;
+        } else if (typeCode == 3) {
+            wanted = Polygon.class;
+        } else {
+            throw new IllegalArgumentException("Invalid geometry type");
+        }
+
+        List<Geometry> matches = GeomUtils.extractGeometryCollection(geometry, wanted);
+        if (!matches.isEmpty()) {
+            return Functions.createMultiGeometry(matches.toArray(new Geometry[0]));
+        }
+
+        // Nothing matched: answer with an empty collection of the requested kind.
+        if (typeCode == 1) {
+            return GEOMETRY_FACTORY.createMultiPoint();
+        }
+        if (typeCode == 2) {
+            return GEOMETRY_FACTORY.createMultiLineString();
+        }
+        return GEOMETRY_FACTORY.createMultiPolygon();
+    }
+
+    /**
+     * Extracts the leaf geometries of a single kind from {@code geometry}, choosing the kind
+     * by priority: polygons if any leaf is a {@code Polygon}, otherwise line strings,
+     * otherwise points.
+     *
+     * @param geometry the geometry to search
+     * @return a multi-polygon, multi-line-string or multi-point holding the selected leaves,
+     *     or an empty geometry collection when no leaf is of those kinds
+     */
+    public static Geometry collectionExtract(Geometry geometry) {
+        List<Geometry> leaves = GeomUtils.extractGeometryCollection(geometry);
+
+        // First pass: find out which kinds are present at all.
+        boolean sawPolygon = false;
+        boolean sawLine = false;
+        boolean sawPoint = false;
+        for (Geometry leaf : leaves) {
+            sawPolygon |= leaf instanceof Polygon;
+            sawLine |= leaf instanceof LineString;
+            sawPoint |= leaf instanceof Point;
+        }
+
+        // Second pass: keep only the highest-priority kind that was seen.
+        if (sawPolygon) {
+            return GEOMETRY_FACTORY.createMultiPolygon(
+                    leaves.stream().filter(Polygon.class::isInstance).toArray(Polygon[]::new));
+        }
+        if (sawLine) {
+            return GEOMETRY_FACTORY.createMultiLineString(
+                    leaves.stream().filter(LineString.class::isInstance).toArray(LineString[]::new));
+        }
+        if (sawPoint) {
+            return GEOMETRY_FACTORY.createMultiPoint(
+                    leaves.stream().filter(Point.class::isInstance).toArray(Point[]::new));
+        }
+        return GEOMETRY_FACTORY.createGeometryCollection();
+    }
+
}
diff --git a/common/src/main/java/org/apache/sedona/common/utils/GeomUtils.java b/common/src/main/java/org/apache/sedona/common/utils/GeomUtils.java
index 58191e94..635c8cd4 100644
--- a/common/src/main/java/org/apache/sedona/common/utils/GeomUtils.java
+++ b/common/src/main/java/org/apache/sedona/common/utils/GeomUtils.java
@@ -382,10 +382,12 @@ public class GeomUtils {
return pCount;
}
- public static List<Geometry> extractGeometryCollection(Geometry geom){
+ public static <T extends Geometry> List<Geometry> extractGeometryCollection(Geometry geom, Class<T> geomType){
ArrayList<Geometry> leafs = new ArrayList<>();
if (!(geom instanceof GeometryCollection)) {
- leafs.add(geom);
+ if (geomType.isAssignableFrom(geom.getClass())) {
+ leafs.add(geom);
+ }
return leafs;
}
LinkedList<GeometryCollection> parents = new LinkedList<>();
@@ -397,13 +399,19 @@ public class GeomUtils {
if (child instanceof GeometryCollection) {
parents.add((GeometryCollection) child);
} else {
- leafs.add(child);
+ if (geomType.isAssignableFrom(child.getClass())) {
+ leafs.add(child);
+ }
}
}
}
return leafs;
}
+    /**
+     * Flattens {@code geom} without filtering by type: delegates to the typed overload with
+     * {@code Geometry.class}, which every geometry class satisfies.
+     *
+     * @param geom the geometry to flatten
+     * @return the geometries collected by the typed overload
+     */
+    public static List<Geometry> extractGeometryCollection(Geometry geom){
+        final Class<Geometry> anyGeometry = Geometry.class;
+        return extractGeometryCollection(geom, anyGeometry);
+    }
+
public static Geometry[] getSubGeometries(Geometry geom) {
Geometry[] geometries = new Geometry[geom.getNumGeometries()];
for ( int i = 0; i < geom.getNumGeometries() ; i++) {
diff --git a/common/src/test/java/org/apache/sedona/common/GeometryUtilTest.java b/common/src/test/java/org/apache/sedona/common/GeometryUtilTest.java
index 855b077f..418f0332 100644
--- a/common/src/test/java/org/apache/sedona/common/GeometryUtilTest.java
+++ b/common/src/test/java/org/apache/sedona/common/GeometryUtilTest.java
@@ -34,7 +34,7 @@ public class GeometryUtilTest {
}
@Test
- public void extractGeometryCollection() throws ParseException, IOException {
+ public void extractGeometryCollection() {
MultiPolygon multiPolygon = GEOMETRY_FACTORY.createMultiPolygon(
new Polygon[] {
GEOMETRY_FACTORY.createPolygon(coordArray(0, 1,3, 0,4, 3,0, 4,0, 1)),
@@ -60,7 +60,6 @@ public class GeometryUtilTest {
"GEOMETRYCOLLECTION (POLYGON ((0 1, 3 0, 4 3, 0 4, 0 1)), POLYGON ((3 4, 6 3, 5 5, 3 4)), POINT (5 8), POLYGON ((0 1, 3 0, 4 3, 0 4, 0 1)), POLYGON ((3 4, 6 3, 5 5, 3 4)))"
)
);
-
}
diff --git a/sql/src/main/scala/org/apache/sedona/sql/UDF/Catalog.scala b/sql/src/main/scala/org/apache/sedona/sql/UDF/Catalog.scala
index fa3493ab..1c1347c1 100644
--- a/sql/src/main/scala/org/apache/sedona/sql/UDF/Catalog.scala
+++ b/sql/src/main/scala/org/apache/sedona/sql/UDF/Catalog.scala
@@ -22,7 +22,7 @@ import org.apache.spark.sql.catalyst.FunctionIdentifier
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, ExpressionInfo, Literal}
import org.apache.spark.sql.expressions.Aggregator
-import org.apache.spark.sql.sedona_sql.expressions.collect.{ST_Collect, ST_CollectionExtract}
+import org.apache.spark.sql.sedona_sql.expressions.collect.{ST_Collect}
import org.apache.spark.sql.sedona_sql.expressions.raster._
import org.apache.spark.sql.sedona_sql.expressions._
import org.locationtech.jts.geom.Geometry
@@ -135,7 +135,7 @@ object Catalog {
function[ST_XMin](),
function[ST_BuildArea](),
function[ST_OrderingEquals](),
- function[ST_CollectionExtract](),
+ function[ST_CollectionExtract](defaultArgs = null),
function[ST_Normalize](),
function[ST_LineFromMultiPoint](),
function[ST_MPolyFromText](0),
diff --git a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala
index 2c56845b..9d065a78 100644
--- a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala
+++ b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala
@@ -923,3 +923,13 @@ case class ST_S2CellIDs(inputExpressions: Seq[Expression])
copy(inputExpressions = newChildren)
}
}
+
+/**
+ * Spark SQL expression for ST_CollectionExtract, evaluated through
+ * `Functions.collectionExtract` via `InferredBinaryExpression`.
+ *
+ * @param inputExpressions the argument expressions of the call
+ */
+case class ST_CollectionExtract(inputExpressions: Seq[Expression])
+  extends InferredBinaryExpression(Functions.collectionExtract) with FoldableExpression {
+
+  // NOTE(review): presumably lets the function run when the second argument is null --
+  // confirm against InferredBinaryExpression.
+  override def allowRightNull: Boolean = true
+
+  protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) =
+    copy(inputExpressions = newChildren)
+}
diff --git a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/collect/Collect.scala b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/collect/Collect.scala
deleted file mode 100644
index aba2e86d..00000000
--- a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/collect/Collect.scala
+++ /dev/null
@@ -1,53 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.spark.sql.sedona_sql.expressions.collect
-
-import org.apache.sedona.common.geometryObjects.Circle
-import org.locationtech.jts.geom.{Geometry, GeometryCollection, GeometryFactory, LineString, Point, Polygon}
-import scala.collection.JavaConverters._
-
-
-object Collect {
- private val geomFactory = new GeometryFactory()
-
- def createMultiGeometry(geometries : Seq[Geometry]): Geometry = {
- if (geometries.length>1){
- geomFactory.buildGeometry(geometries.asJava)
- }
- else if(geometries.length==1){
- createMultiGeometryFromOneElement(geometries.head)
- }
- else{
- geomFactory.createGeometryCollection()
- }
- }
-
- def createMultiGeometryFromOneElement(geom: Geometry): Geometry = {
- geom match {
- case circle: Circle => geomFactory.createGeometryCollection(Array(circle))
- case collection: GeometryCollection => collection
- case string: LineString =>
- geomFactory.createMultiLineString(Array(string))
- case point: Point => geomFactory.createMultiPoint(Array(point))
- case polygon: Polygon => geomFactory.createMultiPolygon(Array(polygon))
- case _ => geomFactory.createGeometryCollection()
- }
- }
-}
diff --git a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/collect/ST_Collect.scala b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/collect/ST_Collect.scala
index c259c366..794dedf6 100644
--- a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/collect/ST_Collect.scala
+++ b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/collect/ST_Collect.scala
@@ -19,17 +19,15 @@
package org.apache.spark.sql.sedona_sql.expressions.collect
+import org.apache.sedona.common.Functions
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
-
import org.apache.spark.sql.catalyst.expressions.Expression
-
import org.apache.spark.sql.catalyst.util.ArrayData
import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT
import org.apache.spark.sql.sedona_sql.expressions.implicits._
import org.apache.spark.sql.sedona_sql.expressions.SerdeAware
import org.apache.spark.sql.types.{ArrayType, _}
-
import org.locationtech.jts.geom.Geometry
case class ST_Collect(inputExpressions: Seq[Expression])
@@ -58,13 +56,13 @@ case class ST_Collect(inputExpressions: Seq[Expression])
.filter(_ != null)
.map(_.toGeometry)
- Collect.createMultiGeometry(geomElements)
- case _ => Collect.createMultiGeometry(Seq())
+ Functions.createMultiGeometry(geomElements.toArray)
+ case _ => Functions.createMultiGeometry(Array())
}
case _ =>
val geomElements =
inputExpressions.map(_.toGeometry(input)).filter(_ != null)
- Collect.createMultiGeometry(geomElements)
+ Functions.createMultiGeometry(geomElements.toArray)
}
}
diff --git a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/collect/ST_CollectionExtract.scala b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/collect/ST_CollectionExtract.scala
deleted file mode 100644
index 415c5c13..00000000
--- a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/collect/ST_CollectionExtract.scala
+++ /dev/null
@@ -1,103 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.spark.sql.sedona_sql.expressions.collect
-
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.Expression
-import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
-import org.apache.spark.sql.catalyst.util.GenericArrayData
-import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT
-import org.apache.spark.sql.sedona_sql.expressions.collect.GeomType.GeomTypeVal
-import org.apache.spark.sql.sedona_sql.expressions.implicits.{GeometryEnhancer, InputExpressionEnhancer}
-import org.apache.spark.sql.types.DataType
-import org.locationtech.jts.geom.{Geometry, GeometryCollection, LineString, Point, Polygon}
-
-import java.util
-
-object GeomType extends Enumeration(1) {
- case class GeomTypeVal(empty: Function[Geometry, Geometry], multi: Function[util.List[Geometry], Geometry]) extends super.Val {}
- import scala.language.implicitConversions
- implicit def valueToGeomTypeVal(x: Value): GeomTypeVal = x.asInstanceOf[GeomTypeVal]
-
- def getGeometryType(geometry: Geometry): GeomTypeVal = {
- (geometry: Geometry) match {
- case (geometry: GeometryCollection) =>
- GeomType.apply(Range(0, geometry.getNumGeometries).map(i => geometry.getGeometryN(i)).map(geom => getGeometryType(geom).id).reduce(scala.math.max))
- case (geometry: Point) => point
- case (geometry: LineString) => line
- case (geometry: Polygon) => polygon
- }
- }
-
- val point = GeomTypeVal(geom => geom.getFactory.createMultiPoint(), geoms => geoms.get(0).getFactory().createMultiPoint(geoms.toArray(Array[Point]())))
- val line = GeomTypeVal(geom => geom.getFactory.createMultiLineString(), geoms => geoms.get(0).getFactory().createMultiLineString(geoms.toArray(Array[LineString]())))
- val polygon = GeomTypeVal(geom => geom.getFactory.createMultiPolygon(), geoms => geoms.get(0).getFactory().createMultiPolygon(geoms.toArray(Array[Polygon]())))
-}
-
-case class ST_CollectionExtract(inputExpressions: Seq[Expression]) extends Expression with CodegenFallback {
-
- override def dataType: DataType = GeometryUDT
-
- override def children: Seq[Expression] = inputExpressions
-
- protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression = {
- copy(inputExpressions = newChildren)
- }
-
- override def nullable: Boolean = true
-
- def nullSafeEval(geometry: Geometry, geomType: GeomTypeVal): Array[Byte] = {
- val geometries : util.ArrayList[Geometry] = new util.ArrayList[Geometry]()
- filterGeometry(geometries, geometry, geomType);
-
- if (geometries.isEmpty()) {
- geomType.empty(geometry).toGenericArrayData
- }
- else{
- geomType.multi(geometries).toGenericArrayData
- }
- }
-
- def filterGeometry(geometries: util.ArrayList[Geometry], geometry: Geometry, geomType: GeomTypeVal):Unit = {
- (geometry: Geometry) match {
- case (geometry: GeometryCollection) =>
- Range(0, geometry.getNumGeometries).map(i => geometry.getGeometryN(i)).foreach(geom => filterGeometry(geometries, geom, geomType))
- case (geometry: Geometry) => {
- if(geomType==GeomType.getGeometryType(geometry))
- geometries.add(geometry)
- }
-
- }
- }
- override def eval(input: InternalRow): Any = {
-
- val geometry = inputExpressions.head.toGeometry(input)
- val geomType = if (inputExpressions.length == 2) {
- GeomType.apply(inputExpressions(1).eval(input).asInstanceOf[Int])
- } else {
- GeomType.getGeometryType(geometry)
- }
-
- (geometry) match {
- case (geometry: Geometry) => nullSafeEval(geometry, geomType)
- case _ => null
- }
- }
-}
diff --git a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/st_functions.scala b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/st_functions.scala
index c8c7ac73..ce077d99 100644
--- a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/st_functions.scala
+++ b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/st_functions.scala
@@ -19,8 +19,7 @@
package org.apache.spark.sql.sedona_sql.expressions
import org.apache.spark.sql.Column
-import org.apache.spark.sql.sedona_sql.expressions.collect.{ST_Collect, ST_CollectionExtract}
-import org.locationtech.jts.geom.Geometry
+import org.apache.spark.sql.sedona_sql.expressions.collect.{ST_Collect}
import org.locationtech.jts.operation.buffer.BufferParameters
object st_functions extends DataFrameAPI {
@@ -75,8 +74,8 @@ object st_functions extends DataFrameAPI {
def ST_Collect(geoms: String): Column = wrapExpression[ST_Collect](geoms)
def ST_Collect(geoms: Any*): Column = wrapVarArgExpression[ST_Collect](geoms)
- def ST_CollectionExtract(collection: Column): Column = wrapExpression[ST_CollectionExtract](collection)
- def ST_CollectionExtract(collection: String): Column = wrapExpression[ST_CollectionExtract](collection)
+ def ST_CollectionExtract(collection: Column): Column = wrapExpression[ST_CollectionExtract](collection, null)
+ def ST_CollectionExtract(collection: String): Column = wrapExpression[ST_CollectionExtract](collection, null)
def ST_CollectionExtract(collection: Column, geomType: Column): Column = wrapExpression[ST_CollectionExtract](collection, geomType)
def ST_CollectionExtract(collection: String, geomType: Int): Column = wrapExpression[ST_CollectionExtract](collection, geomType)