You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@sedona.apache.org by zo...@apache.org on 2023/01/30 17:52:48 UTC

[sedona] 01/01: Add the Google S2 library, a function converting geometries to S2 cell IDs, and a Spark SQL function

This is an automated email from the ASF dual-hosted git repository.

zongsizhang pushed a commit to branch feature/google-s2
in repository https://gitbox.apache.org/repos/asf/sedona.git

commit b723ab1c9af8c3bcc3155aa5544b4a8f3231934a
Author: zzs-wherobots <zz...@ZongsideMac-Studio.local>
AuthorDate: Tue Jan 31 00:07:37 2023 +0800

    Add the Google S2 library, a function that converts a geometry to S2 cell IDs, and the ST_GetGoogleS2CellIDs Spark SQL function
---
 common/pom.xml                                     |  5 ++
 .../java/org/apache/sedona/common/Functions.java   | 25 ++++++++
 .../org/apache/sedona/common/FunctionsTest.java    | 35 ++++++++----
 pom.xml                                            |  1 +
 sql/pom.xml                                        |  5 ++
 .../scala/org/apache/sedona/sql/UDF/Catalog.scala  |  3 +-
 .../sql/sedona_sql/expressions/Functions.scala     |  8 +++
 .../expressions/NullSafeExpressions.scala          | 14 +++++
 .../sql/sedona_sql/expressions/st_functions.scala  |  3 +
 .../sql/functions/StGetGoogleS2CellIDs.scala       | 66 ++++++++++++++++++++++
 10 files changed, 154 insertions(+), 11 deletions(-)

diff --git a/common/pom.xml b/common/pom.xml
index 04ad4a2f..37697941 100644
--- a/common/pom.xml
+++ b/common/pom.xml
@@ -57,6 +57,11 @@
             <groupId>org.wololo</groupId>
             <artifactId>jts2geojson</artifactId>
         </dependency>
+        <dependency>
+            <groupId>io.sgr</groupId>
+            <artifactId>s2-geometry-library-java</artifactId>
+            <version>1.0.1</version>
+        </dependency>
     </dependencies>
     <build>
         <sourceDirectory>src/main/java</sourceDirectory>
diff --git a/common/src/main/java/org/apache/sedona/common/Functions.java b/common/src/main/java/org/apache/sedona/common/Functions.java
index ccc6f7a0..e69d65f1 100644
--- a/common/src/main/java/org/apache/sedona/common/Functions.java
+++ b/common/src/main/java/org/apache/sedona/common/Functions.java
@@ -13,6 +13,9 @@
  */
 package org.apache.sedona.common;
 
+import com.google.common.geometry.S2CellId;
+import com.google.common.geometry.S2Point;
+import org.apache.commons.lang3.ArrayUtils;
 import org.apache.sedona.common.geometryObjects.Circle;
 import org.apache.sedona.common.utils.GeomUtils;
 import org.apache.sedona.common.utils.GeometryGeoHashEncoder;
@@ -543,4 +546,26 @@ public class Functions {
         // check input geometry
         return new GeometrySplitter(GEOMETRY_FACTORY).split(input, blade);
     }
+
+
+    /**
+     * get the coordinates of a geometry and transform to Google s2 cell id
+     * @param input Geometry
+     * @return List of coordinates
+     */
+    public static Long[] getGoogleS2CellIDs(Geometry input) {
+        ArrayList<Long> cellIds = new ArrayList<>();
+        for (Coordinate coordinate: input.getCoordinates()) {
+            cellIds.add(
+                    S2CellId.fromPoint(
+                            new S2Point(
+                                    coordinate.getX(),
+                                    coordinate.getY(),
+                                    coordinate.getZ()
+                            )
+                    ).id()
+            );
+        }
+        return cellIds.toArray(new Long[cellIds.size()]);
+    }
 }
diff --git a/common/src/test/java/org/apache/sedona/common/FunctionsTest.java b/common/src/test/java/org/apache/sedona/common/FunctionsTest.java
index abe522f9..8e5d8641 100644
--- a/common/src/test/java/org/apache/sedona/common/FunctionsTest.java
+++ b/common/src/test/java/org/apache/sedona/common/FunctionsTest.java
@@ -14,18 +14,11 @@
 package org.apache.sedona.common;
 
 import org.junit.Test;
-import org.locationtech.jts.geom.Coordinate;
-import org.locationtech.jts.geom.Geometry;
-import org.locationtech.jts.geom.GeometryCollection;
-import org.locationtech.jts.geom.GeometryFactory;
-import org.locationtech.jts.geom.LinearRing;
-import org.locationtech.jts.geom.LineString;
-import org.locationtech.jts.geom.MultiLineString;
-import org.locationtech.jts.geom.MultiPoint;
-import org.locationtech.jts.geom.MultiPolygon;
-import org.locationtech.jts.geom.Polygon;
+import org.locationtech.jts.geom.*;
 import org.locationtech.jts.io.ParseException;
 
+import java.util.ArrayList;
+
 import static org.junit.Assert.*;
 
 public class FunctionsTest {
@@ -238,4 +231,26 @@ public class FunctionsTest {
 
         assertNull(actualResult);
     }
+
+    @Test
+    public void getGoogleS2CellIDs() {
+        Geometry[] geometries = new Geometry[]{
+                GEOMETRY_FACTORY.createPoint(new Coordinate(0.0, 0.1, 1.0)), // test point
+                GEOMETRY_FACTORY.createLineString(coordArray(0.0, 0.0, 1.5, 1.5)), // test linestring
+                GEOMETRY_FACTORY.createPolygon(coordArray(0.0, 0.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0)) // test polygon
+        };
+        for (Geometry geom : geometries) {
+            Long[] cellIDs = Functions.getGoogleS2CellIDs(geom);
+            if (geom instanceof Point) {
+                assertEquals(1, cellIDs.length);
+            }
+            if (geom instanceof LineString) {
+                assertEquals(2, cellIDs.length);
+            }
+            if (geom instanceof Polygon) {
+                assertEquals(4, cellIDs.length);
+            }
+        }
+
+    }
 }
diff --git a/pom.xml b/pom.xml
index f8a5227e..c76560d0 100644
--- a/pom.xml
+++ b/pom.xml
@@ -84,6 +84,7 @@
         <geotools.scope>provided</geotools.scope>
         <!-- Because it's not in Maven central, make it provided by default -->
         <cdm.scope>provided</cdm.scope>
+        <googles2.scope>compile</googles2.scope>
     </properties>
 
     <dependencies>
diff --git a/sql/pom.xml b/sql/pom.xml
index 12461afd..e262c977 100644
--- a/sql/pom.xml
+++ b/sql/pom.xml
@@ -123,6 +123,11 @@
             <groupId>org.scalatest</groupId>
             <artifactId>scalatest_${scala.compat.version}</artifactId>
         </dependency>
+        <dependency>
+            <groupId>io.sgr</groupId>
+            <artifactId>s2-geometry-library-java</artifactId>
+            <version>1.0.1</version>
+        </dependency>
     </dependencies>
 	<build>
         <sourceDirectory>src/main/scala</sourceDirectory>
diff --git a/sql/src/main/scala/org/apache/sedona/sql/UDF/Catalog.scala b/sql/src/main/scala/org/apache/sedona/sql/UDF/Catalog.scala
index 292bee07..2165b9ba 100644
--- a/sql/src/main/scala/org/apache/sedona/sql/UDF/Catalog.scala
+++ b/sql/src/main/scala/org/apache/sedona/sql/UDF/Catalog.scala
@@ -167,7 +167,8 @@ object Catalog {
     function[RS_HTML](),
     function[RS_Array](),
     function[RS_Normalize](),
-    function[RS_Append]()
+    function[RS_Append](),
+    function[ST_GetGoogleS2CellIDs](),
   )
 
   val aggregateExpressions: Seq[Aggregator[Geometry, Geometry, Geometry]] = Seq(
diff --git a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala
index 5fdcc586..c9439397 100644
--- a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala
+++ b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala
@@ -1055,3 +1055,11 @@ case class ST_Split(inputExpressions: Seq[Expression])
     copy(inputExpressions = newChildren)
   }
 }
+
+case class ST_GetGoogleS2CellIDs(inputExpressions: Seq[Expression])
+  extends InferredUnaryExpression(Functions.getGoogleS2CellIDs) with FoldableExpression {
+  // Yields the S2 cell ids of every coordinate of the input geometry (via Functions.getGoogleS2CellIDs).
+
+  protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): ST_GetGoogleS2CellIDs =
+    copy(inputExpressions = newChildren)
+}
diff --git a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/NullSafeExpressions.scala b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/NullSafeExpressions.scala
index ea30ecd3..4a29eb4a 100644
--- a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/NullSafeExpressions.scala
+++ b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/NullSafeExpressions.scala
@@ -21,6 +21,7 @@ package org.apache.spark.sql.sedona_sql.expressions
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
+import org.apache.spark.sql.catalyst.util.ArrayData
 import org.apache.spark.sql.sedona_sql.expressions.implicits._
 import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT
 import org.apache.spark.sql.types._
@@ -53,6 +54,8 @@ abstract class UnaryGeometryExpression extends Expression with ExpectsInputTypes
   }
 
   protected def nullSafeEval(geometry: Geometry): Any
+
+
 }
 
 abstract class BinaryGeometryExpression extends Expression with ExpectsInputTypes {
@@ -95,6 +98,8 @@ object InferrableType {
     new InferrableType[String] {}
   implicit val binaryInstance: InferrableType[Array[Byte]] =
     new InferrableType[Array[Byte]] {}
+  implicit val longArrayInstance: InferrableType[Array[java.lang.Long]] =
+    new InferrableType[Array[java.lang.Long]] {}
 }
 
 object InferredTypes {
@@ -121,6 +126,13 @@ object InferredTypes {
       } else {
         null
       }
+    } else if (typeOf[T] =:= typeOf[Array[java.lang.Long]]) {
+      output: T =>
+        if (output != null) {
+          ArrayData.toArrayData(output)
+        } else {
+          null
+        }
     } else {
       output: T => output
     }
@@ -141,6 +153,8 @@ object InferredTypes {
       StringType
     } else if (typeOf[T] =:= typeOf[Array[Byte]]) {
       BinaryType
+    } else if (typeOf[T] =:= typeOf[Array[java.lang.Long]]) {
+      DataTypes.createArrayType(LongType)
     } else {
       BooleanType
     }
diff --git a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/st_functions.scala b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/st_functions.scala
index 48c5a446..3b7c9bb9 100644
--- a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/st_functions.scala
+++ b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/st_functions.scala
@@ -271,4 +271,7 @@ object st_functions extends DataFrameAPI {
 
   def ST_ZMin(geometry: Column): Column = wrapExpression[ST_ZMin](geometry)
   def ST_ZMin(geometry: String): Column = wrapExpression[ST_ZMin](geometry)
+  // ST_GetGoogleS2CellIDs: geometry given as a Column, or as a String (presumably a column name — see wrapExpression).
+  def ST_GetGoogleS2CellIDs(geometry: Column): Column = wrapExpression[ST_GetGoogleS2CellIDs](geometry)
+  def ST_GetGoogleS2CellIDs(geometry: String): Column = wrapExpression[ST_GetGoogleS2CellIDs](geometry)
 }
diff --git a/sql/src/test/scala/org/apache/sedona/sql/functions/StGetGoogleS2CellIDs.scala b/sql/src/test/scala/org/apache/sedona/sql/functions/StGetGoogleS2CellIDs.scala
new file mode 100644
index 00000000..0f98d8f1
--- /dev/null
+++ b/sql/src/test/scala/org/apache/sedona/sql/functions/StGetGoogleS2CellIDs.scala
@@ -0,0 +1,66 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.sedona.sql.functions
+
+import org.apache.sedona.sql.{GeometrySample, TestBaseScala}
+import org.apache.spark.sql.functions.{col, expr, lit, when}
+import org.scalatest.{GivenWhenThen, Matchers}
+
+class StGetGoogleS2CellIDs extends TestBaseScala with Matchers with GeometrySample with GivenWhenThen {
+  import sparkSession.implicits._
+
+  describe("should pass ST_GetGoogleS2CellIDs"){
+
+    it("should return null while using ST_GetGoogleS2CellIDs when geometry is null") {
+      Given("DataFrame with null geometries")
+      val geometryTable = sparkSession.sparkContext.parallelize(1 to 10).toDF()
+        .withColumn("geom", lit(null))
+
+      When("using ST_GetGoogleS2CellIDs on null geometries")
+      // collect() forces evaluation; building the DataFrame alone would never run the function
+      val cellIds = geometryTable
+        .withColumn("cell_ids", expr("ST_GetGoogleS2CellIDs(geom)"))
+        .select("cell_ids").collect().map(_.get(0))
+      Then("no exception should be raised and every result should be null")
+      cellIds.length shouldBe 10
+      cellIds.forall(_ == null) shouldBe true
+    }
+
+    it("should return one cell id per coordinate using ST_GetGoogleS2CellIDs"){
+      Given("DataFrame with valid geometries")
+      val geometryTable = Seq(
+        "POINT(1 2)",
+        "LINESTRING(-5 8, -6 1, -8 6, -2 5, -6 1, -5 8)",
+        "POLYGON ((75 29, 77 29, 77 29, 75 29))"
+      ).map(geom => Tuple1(wktReader.read(geom))).toDF("geom")
+
+      When("using ST_GetGoogleS2CellIDs on those geometries")
+      val geometryDfWithCellIDs = geometryTable
+        .withColumn("cell_ids", expr("ST_GetGoogleS2CellIDs(geom)"))
+
+      Then("one cell id should be returned per coordinate")
+      geometryDfWithCellIDs.selectExpr("size(cell_ids)")
+        .collect().map(r => r.getInt(0)) should contain theSameElementsAs Seq(1, 6, 4)
+      And("repeated coordinates should map to the same cell id")
+      // the line string repeats (-5 8) at indices 0/5 and (-6 1) at indices 1/4
+      geometryDfWithCellIDs.where("size(cell_ids) = 6")
+        .selectExpr("cell_ids[0] = cell_ids[5] AND cell_ids[1] = cell_ids[4]")
+        .collect().map(r => r.getBoolean(0)) should contain theSameElementsAs Seq(true)
+    }
+  }
+}