You are viewing a plain-text version of this content. The canonical link for it is here (the hyperlink itself is not preserved in this plain-text rendering).
Posted to commits@sedona.apache.org by ji...@apache.org on 2023/04/24 04:12:17 UTC

[sedona] branch master updated: [SEDONA-274] move sql function collectionExtract and collect to common (#823)

This is an automated email from the ASF dual-hosted git repository.

jiayu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/sedona.git


The following commit(s) were added to refs/heads/master by this push:
     new 7ccc33f6 [SEDONA-274] move sql function collectionExtract and collect to common (#823)
7ccc33f6 is described below

commit 7ccc33f6daa6d9044e114e8134c2c209af2706f8
Author: zongsi.zhang <kr...@gmail.com>
AuthorDate: Mon Apr 24 12:12:12 2023 +0800

    [SEDONA-274] move sql function collectionExtract and collect to common (#823)
---
 .../java/org/apache/sedona/common/Functions.java   |  60 ++++++++++++
 .../org/apache/sedona/common/utils/GeomUtils.java  |  14 ++-
 .../org/apache/sedona/common/GeometryUtilTest.java |   3 +-
 .../scala/org/apache/sedona/sql/UDF/Catalog.scala  |   4 +-
 .../sql/sedona_sql/expressions/Functions.scala     |  10 ++
 .../sedona_sql/expressions/collect/Collect.scala   |  53 -----------
 .../expressions/collect/ST_Collect.scala           |  10 +-
 .../expressions/collect/ST_CollectionExtract.scala | 103 ---------------------
 .../sql/sedona_sql/expressions/st_functions.scala  |   7 +-
 9 files changed, 91 insertions(+), 173 deletions(-)

diff --git a/common/src/main/java/org/apache/sedona/common/Functions.java b/common/src/main/java/org/apache/sedona/common/Functions.java
index c429c4c8..c86572f8 100644
--- a/common/src/main/java/org/apache/sedona/common/Functions.java
+++ b/common/src/main/java/org/apache/sedona/common/Functions.java
@@ -47,6 +47,7 @@ import java.util.Arrays;
 import java.util.HashSet;
 import java.util.List;
 import java.util.stream.Collectors;
+import java.util.stream.Stream;
 
 
 public class Functions {
@@ -670,4 +671,63 @@ public class Functions {
             return null;
         }
     }
+
+    public static Geometry createMultiGeometry(Geometry[] geometries) {
+        if (geometries.length > 1){
+            return GEOMETRY_FACTORY.buildGeometry(Arrays.asList(geometries));
+        }
+        else if(geometries.length==1){
+            return createMultiGeometryFromOneElement(geometries[0]);
+        }
+        else{
+            return GEOMETRY_FACTORY.createGeometryCollection();
+        }
+    }
+
+    public static Geometry collectionExtract(Geometry geometry, Integer geomType) {
+        if (geomType == null) {
+            return collectionExtract(geometry);
+        }
+        Class<? extends Geometry> geomClass;
+        GeometryCollection emptyResult;
+        switch (geomType) {
+            case 1:
+                geomClass = Point.class;
+                emptyResult = GEOMETRY_FACTORY.createMultiPoint();
+                break;
+            case 2:
+                geomClass = LineString.class;
+                emptyResult = GEOMETRY_FACTORY.createMultiLineString();
+                break;
+            case 3:
+                geomClass = Polygon.class;
+                emptyResult = GEOMETRY_FACTORY.createMultiPolygon();
+                break;
+            default:
+                throw new IllegalArgumentException("Invalid geometry type");
+        }
+        List<Geometry> geometries = GeomUtils.extractGeometryCollection(geometry, geomClass);
+        if (geometries.isEmpty()) {
+            return emptyResult;
+        }
+        return Functions.createMultiGeometry(geometries.toArray(new Geometry[0]));
+    }
+
+    public static Geometry collectionExtract(Geometry geometry) {
+        List<Geometry> geometries = GeomUtils.extractGeometryCollection(geometry);
+        Polygon[] polygons = geometries.stream().filter(g -> g instanceof Polygon).toArray(Polygon[]::new);
+        if (polygons.length > 0) {
+            return GEOMETRY_FACTORY.createMultiPolygon(polygons);
+        }
+        LineString[] lines = geometries.stream().filter(g -> g instanceof LineString).toArray(LineString[]::new);
+        if (lines.length > 0) {
+            return GEOMETRY_FACTORY.createMultiLineString(lines);
+        }
+        Point[] points = geometries.stream().filter(g -> g instanceof Point).toArray(Point[]::new);
+        if (points.length > 0) {
+            return GEOMETRY_FACTORY.createMultiPoint(points);
+        }
+        return GEOMETRY_FACTORY.createGeometryCollection();
+    }
+
 }
diff --git a/common/src/main/java/org/apache/sedona/common/utils/GeomUtils.java b/common/src/main/java/org/apache/sedona/common/utils/GeomUtils.java
index 58191e94..635c8cd4 100644
--- a/common/src/main/java/org/apache/sedona/common/utils/GeomUtils.java
+++ b/common/src/main/java/org/apache/sedona/common/utils/GeomUtils.java
@@ -382,10 +382,12 @@ public class GeomUtils {
         return pCount;
     }
 
-    public static List<Geometry> extractGeometryCollection(Geometry geom){
+    public static <T extends Geometry> List<Geometry> extractGeometryCollection(Geometry geom, Class<T> geomType){
         ArrayList<Geometry> leafs = new ArrayList<>();
         if (!(geom instanceof GeometryCollection)) {
-            leafs.add(geom);
+            if (geomType.isAssignableFrom(geom.getClass())) {
+                leafs.add(geom);
+            }
             return leafs;
         }
         LinkedList<GeometryCollection> parents = new LinkedList<>();
@@ -397,13 +399,19 @@ public class GeomUtils {
                 if (child instanceof GeometryCollection) {
                     parents.add((GeometryCollection) child);
                 } else {
-                    leafs.add(child);
+                    if (geomType.isAssignableFrom(child.getClass())) {
+                        leafs.add(child);
+                    }
                 }
             }
         }
         return leafs;
     }
 
+    public static List<Geometry> extractGeometryCollection(Geometry geom){
+        return extractGeometryCollection(geom, Geometry.class);
+    }
+
     public static Geometry[] getSubGeometries(Geometry geom) {
         Geometry[] geometries = new Geometry[geom.getNumGeometries()];
         for ( int i = 0; i < geom.getNumGeometries() ; i++) {
diff --git a/common/src/test/java/org/apache/sedona/common/GeometryUtilTest.java b/common/src/test/java/org/apache/sedona/common/GeometryUtilTest.java
index 855b077f..418f0332 100644
--- a/common/src/test/java/org/apache/sedona/common/GeometryUtilTest.java
+++ b/common/src/test/java/org/apache/sedona/common/GeometryUtilTest.java
@@ -34,7 +34,7 @@ public class GeometryUtilTest {
     }
 
     @Test
-    public void extractGeometryCollection() throws ParseException, IOException {
+    public void extractGeometryCollection() {
         MultiPolygon multiPolygon = GEOMETRY_FACTORY.createMultiPolygon(
                 new Polygon[] {
                         GEOMETRY_FACTORY.createPolygon(coordArray(0, 1,3, 0,4, 3,0, 4,0, 1)),
@@ -60,7 +60,6 @@ public class GeometryUtilTest {
                         "GEOMETRYCOLLECTION (POLYGON ((0 1, 3 0, 4 3, 0 4, 0 1)), POLYGON ((3 4, 6 3, 5 5, 3 4)), POINT (5 8), POLYGON ((0 1, 3 0, 4 3, 0 4, 0 1)), POLYGON ((3 4, 6 3, 5 5, 3 4)))"
                 )
         );
-
     }
 
 
diff --git a/sql/src/main/scala/org/apache/sedona/sql/UDF/Catalog.scala b/sql/src/main/scala/org/apache/sedona/sql/UDF/Catalog.scala
index fa3493ab..1c1347c1 100644
--- a/sql/src/main/scala/org/apache/sedona/sql/UDF/Catalog.scala
+++ b/sql/src/main/scala/org/apache/sedona/sql/UDF/Catalog.scala
@@ -22,7 +22,7 @@ import org.apache.spark.sql.catalyst.FunctionIdentifier
 import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
 import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, ExpressionInfo, Literal}
 import org.apache.spark.sql.expressions.Aggregator
-import org.apache.spark.sql.sedona_sql.expressions.collect.{ST_Collect, ST_CollectionExtract}
+import org.apache.spark.sql.sedona_sql.expressions.collect.{ST_Collect}
 import org.apache.spark.sql.sedona_sql.expressions.raster._
 import org.apache.spark.sql.sedona_sql.expressions._
 import org.locationtech.jts.geom.Geometry
@@ -135,7 +135,7 @@ object Catalog {
     function[ST_XMin](),
     function[ST_BuildArea](),
     function[ST_OrderingEquals](),
-    function[ST_CollectionExtract](),
+    function[ST_CollectionExtract](defaultArgs = null),
     function[ST_Normalize](),
     function[ST_LineFromMultiPoint](),
     function[ST_MPolyFromText](0),
diff --git a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala
index 2c56845b..9d065a78 100644
--- a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala
+++ b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala
@@ -923,3 +923,13 @@ case class ST_S2CellIDs(inputExpressions: Seq[Expression])
     copy(inputExpressions = newChildren)
   }
 }
+
+case class ST_CollectionExtract(inputExpressions: Seq[Expression])
+  extends InferredBinaryExpression(Functions.collectionExtract) with FoldableExpression {
+
+  protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = {
+    copy(inputExpressions = newChildren)
+  }
+
+  override def allowRightNull: Boolean = true
+}
diff --git a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/collect/Collect.scala b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/collect/Collect.scala
deleted file mode 100644
index aba2e86d..00000000
--- a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/collect/Collect.scala
+++ /dev/null
@@ -1,53 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.spark.sql.sedona_sql.expressions.collect
-
-import org.apache.sedona.common.geometryObjects.Circle
-import org.locationtech.jts.geom.{Geometry, GeometryCollection, GeometryFactory, LineString, Point, Polygon}
-import scala.collection.JavaConverters._
-
-
-object Collect {
-  private val geomFactory = new GeometryFactory()
-
-  def createMultiGeometry(geometries : Seq[Geometry]): Geometry = {
-    if (geometries.length>1){
-      geomFactory.buildGeometry(geometries.asJava)
-    }
-    else if(geometries.length==1){
-      createMultiGeometryFromOneElement(geometries.head)
-    }
-    else{
-      geomFactory.createGeometryCollection()
-    }
-  }
-
-  def createMultiGeometryFromOneElement(geom: Geometry): Geometry = {
-    geom match {
-      case circle: Circle                 => geomFactory.createGeometryCollection(Array(circle))
-      case collection: GeometryCollection => collection
-      case string: LineString =>
-        geomFactory.createMultiLineString(Array(string))
-      case point: Point     => geomFactory.createMultiPoint(Array(point))
-      case polygon: Polygon => geomFactory.createMultiPolygon(Array(polygon))
-      case _                => geomFactory.createGeometryCollection()
-    }
-  }
-}
diff --git a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/collect/ST_Collect.scala b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/collect/ST_Collect.scala
index c259c366..794dedf6 100644
--- a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/collect/ST_Collect.scala
+++ b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/collect/ST_Collect.scala
@@ -19,17 +19,15 @@
 package org.apache.spark.sql.sedona_sql.expressions.collect
 
 
+import org.apache.sedona.common.Functions
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
-
 import org.apache.spark.sql.catalyst.expressions.Expression
-
 import org.apache.spark.sql.catalyst.util.ArrayData
 import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT
 import org.apache.spark.sql.sedona_sql.expressions.implicits._
 import org.apache.spark.sql.sedona_sql.expressions.SerdeAware
 import org.apache.spark.sql.types.{ArrayType, _}
-
 import org.locationtech.jts.geom.Geometry
 
 case class ST_Collect(inputExpressions: Seq[Expression])
@@ -58,13 +56,13 @@ case class ST_Collect(inputExpressions: Seq[Expression])
               .filter(_ != null)
               .map(_.toGeometry)
 
-            Collect.createMultiGeometry(geomElements)
-          case _ => Collect.createMultiGeometry(Seq())
+            Functions.createMultiGeometry(geomElements.toArray)
+          case _ => Functions.createMultiGeometry(Array())
         }
       case _ =>
         val geomElements =
           inputExpressions.map(_.toGeometry(input)).filter(_ != null)
-        Collect.createMultiGeometry(geomElements)
+        Functions.createMultiGeometry(geomElements.toArray)
     }
   }
 
diff --git a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/collect/ST_CollectionExtract.scala b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/collect/ST_CollectionExtract.scala
deleted file mode 100644
index 415c5c13..00000000
--- a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/collect/ST_CollectionExtract.scala
+++ /dev/null
@@ -1,103 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.spark.sql.sedona_sql.expressions.collect
-
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.Expression
-import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
-import org.apache.spark.sql.catalyst.util.GenericArrayData
-import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT
-import org.apache.spark.sql.sedona_sql.expressions.collect.GeomType.GeomTypeVal
-import org.apache.spark.sql.sedona_sql.expressions.implicits.{GeometryEnhancer, InputExpressionEnhancer}
-import org.apache.spark.sql.types.DataType
-import org.locationtech.jts.geom.{Geometry, GeometryCollection, LineString, Point, Polygon}
-
-import java.util
-
-object GeomType extends Enumeration(1) {
-  case class GeomTypeVal(empty: Function[Geometry, Geometry], multi: Function[util.List[Geometry], Geometry]) extends super.Val {}
-  import scala.language.implicitConversions
-  implicit def valueToGeomTypeVal(x: Value): GeomTypeVal = x.asInstanceOf[GeomTypeVal]
-
-  def getGeometryType(geometry: Geometry): GeomTypeVal = {
-    (geometry: Geometry) match {
-      case (geometry: GeometryCollection) =>
-        GeomType.apply(Range(0, geometry.getNumGeometries).map(i => geometry.getGeometryN(i)).map(geom => getGeometryType(geom).id).reduce(scala.math.max))
-      case (geometry: Point) => point
-      case (geometry: LineString) => line
-      case (geometry: Polygon) => polygon
-    }
-  }
-
-  val point = GeomTypeVal(geom => geom.getFactory.createMultiPoint(), geoms => geoms.get(0).getFactory().createMultiPoint(geoms.toArray(Array[Point]())))
-  val line = GeomTypeVal(geom => geom.getFactory.createMultiLineString(), geoms => geoms.get(0).getFactory().createMultiLineString(geoms.toArray(Array[LineString]())))
-  val polygon = GeomTypeVal(geom => geom.getFactory.createMultiPolygon(), geoms => geoms.get(0).getFactory().createMultiPolygon(geoms.toArray(Array[Polygon]())))
-}
-
-case class ST_CollectionExtract(inputExpressions: Seq[Expression]) extends Expression with CodegenFallback {
-
-  override def dataType: DataType = GeometryUDT
-
-  override def children: Seq[Expression] = inputExpressions
-
-  protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression = {
-    copy(inputExpressions = newChildren)
-  }
-
-  override def nullable: Boolean = true
-
-  def nullSafeEval(geometry: Geometry, geomType: GeomTypeVal): Array[Byte] = {
-    val geometries : util.ArrayList[Geometry] = new util.ArrayList[Geometry]()
-    filterGeometry(geometries, geometry, geomType);
-
-    if (geometries.isEmpty()) {
-      geomType.empty(geometry).toGenericArrayData
-    }
-    else{
-      geomType.multi(geometries).toGenericArrayData
-    }
-  }
-
-  def filterGeometry(geometries: util.ArrayList[Geometry], geometry: Geometry, geomType: GeomTypeVal):Unit = {
-    (geometry: Geometry) match {
-      case (geometry: GeometryCollection) =>
-        Range(0, geometry.getNumGeometries).map(i => geometry.getGeometryN(i)).foreach(geom => filterGeometry(geometries, geom, geomType))
-      case (geometry: Geometry) => {
-        if(geomType==GeomType.getGeometryType(geometry))
-          geometries.add(geometry)
-      }
-
-    }
-  }
-  override def eval(input: InternalRow): Any = {
-
-    val geometry = inputExpressions.head.toGeometry(input)
-    val geomType = if (inputExpressions.length == 2) {
-      GeomType.apply(inputExpressions(1).eval(input).asInstanceOf[Int])
-    } else {
-      GeomType.getGeometryType(geometry)
-    }
-
-    (geometry) match {
-      case (geometry: Geometry) => nullSafeEval(geometry, geomType)
-      case _ => null
-    }
-  }
-}
diff --git a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/st_functions.scala b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/st_functions.scala
index c8c7ac73..ce077d99 100644
--- a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/st_functions.scala
+++ b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/st_functions.scala
@@ -19,8 +19,7 @@
 package org.apache.spark.sql.sedona_sql.expressions
 
 import org.apache.spark.sql.Column
-import org.apache.spark.sql.sedona_sql.expressions.collect.{ST_Collect, ST_CollectionExtract}
-import org.locationtech.jts.geom.Geometry
+import org.apache.spark.sql.sedona_sql.expressions.collect.{ST_Collect}
 import org.locationtech.jts.operation.buffer.BufferParameters
 
 object st_functions extends DataFrameAPI {
@@ -75,8 +74,8 @@ object st_functions extends DataFrameAPI {
   def ST_Collect(geoms: String): Column = wrapExpression[ST_Collect](geoms)
   def ST_Collect(geoms: Any*): Column = wrapVarArgExpression[ST_Collect](geoms)
 
-  def ST_CollectionExtract(collection: Column): Column = wrapExpression[ST_CollectionExtract](collection)
-  def ST_CollectionExtract(collection: String): Column = wrapExpression[ST_CollectionExtract](collection)
+  def ST_CollectionExtract(collection: Column): Column = wrapExpression[ST_CollectionExtract](collection, null)
+  def ST_CollectionExtract(collection: String): Column = wrapExpression[ST_CollectionExtract](collection, null)
   def ST_CollectionExtract(collection: Column, geomType: Column): Column = wrapExpression[ST_CollectionExtract](collection, geomType)
   def ST_CollectionExtract(collection: String, geomType: Int): Column = wrapExpression[ST_CollectionExtract](collection, geomType)