You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@sedona.apache.org by ji...@apache.org on 2023/02/20 17:51:29 UTC
[sedona] branch master updated: [SEDONA-235] Integrate S2, add ST_S2CellIDs (#764)
This is an automated email from the ASF dual-hosted git repository.
jiayu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/sedona.git
The following commit(s) were added to refs/heads/master by this push:
new ef8a0d05 [SEDONA-235] Integrate S2, add ST_S2CellIDs (#764)
ef8a0d05 is described below
commit ef8a0d05676ecc4ec2747db5cf005928be4dead1
Author: zongsi.zhang <kr...@gmail.com>
AuthorDate: Tue Feb 21 01:51:22 2023 +0800
[SEDONA-235] Integrate S2, add ST_S2CellIDs (#764)
Co-authored-by: zzs-wherobots <zz...@ZongsideMac-Studio.local>
Co-authored-by: Zongsi Zhang <zo...@Zongsis-MacBook-Pro.local>
---
common/pom.xml | 4 +
.../java/org/apache/sedona/common/Functions.java | 42 +++--
.../org/apache/sedona/common/utils/GeomUtils.java | 24 +++
.../org/apache/sedona/common/utils/S2Utils.java | 99 ++++++++++++
.../org/apache/sedona/common/FunctionsTest.java | 177 +++++++++++++++++++--
.../org/apache/sedona/common/GeometryUtilTest.java | 67 ++++++++
.../java/org/apache/sedona/common/S2UtilTest.java | 102 ++++++++++++
docs/api/flink/Function.md | 20 +++
docs/api/sql/Function.md | 24 +++
.../main/java/org/apache/sedona/flink/Catalog.java | 1 +
.../apache/sedona/flink/expressions/Functions.java | 9 ++
.../java/org/apache/sedona/flink/FunctionTest.java | 41 ++++-
.../java/org/apache/sedona/flink/TestBase.java | 12 +-
pom.xml | 8 +-
python/.gitignore | 1 +
python/sedona/sql/st_functions.py | 14 ++
python/tests/sql/test_dataframe_api.py | 2 +
python/tests/sql/test_function.py | 14 ++
.../scala/org/apache/sedona/sql/UDF/Catalog.scala | 1 +
.../sql/sedona_sql/expressions/Functions.scala | 8 +
.../expressions/NullSafeExpressions.scala | 14 ++
.../sql/sedona_sql/expressions/st_functions.scala | 4 +
.../apache/sedona/sql/dataFrameAPITestScala.scala | 13 +-
.../apache/sedona/sql/functions/STS2CellIDs.scala | 103 ++++++++++++
24 files changed, 772 insertions(+), 32 deletions(-)
diff --git a/common/pom.xml b/common/pom.xml
index 04ad4a2f..eb7bc188 100644
--- a/common/pom.xml
+++ b/common/pom.xml
@@ -57,6 +57,10 @@
<groupId>org.wololo</groupId>
<artifactId>jts2geojson</artifactId>
</dependency>
+ <dependency>
+ <groupId>com.google.geometry</groupId>
+ <artifactId>s2-geometry</artifactId>
+ </dependency>
</dependencies>
<build>
<sourceDirectory>src/main/java</sourceDirectory>
diff --git a/common/src/main/java/org/apache/sedona/common/Functions.java b/common/src/main/java/org/apache/sedona/common/Functions.java
index ccc6f7a0..17809520 100644
--- a/common/src/main/java/org/apache/sedona/common/Functions.java
+++ b/common/src/main/java/org/apache/sedona/common/Functions.java
@@ -13,25 +13,21 @@
*/
package org.apache.sedona.common;
+import com.google.common.geometry.S2CellId;
+import com.google.common.geometry.S2Point;
+import com.google.common.geometry.S2Region;
+import com.google.common.geometry.S2RegionCoverer;
+import org.apache.commons.lang3.ArrayUtils;
import org.apache.sedona.common.geometryObjects.Circle;
import org.apache.sedona.common.utils.GeomUtils;
import org.apache.sedona.common.utils.GeometryGeoHashEncoder;
import org.apache.sedona.common.utils.GeometrySplitter;
+import org.apache.sedona.common.utils.S2Utils;
import org.geotools.geometry.jts.JTS;
import org.geotools.referencing.CRS;
import org.locationtech.jts.algorithm.MinimumBoundingCircle;
import org.locationtech.jts.algorithm.hull.ConcaveHull;
-import org.locationtech.jts.geom.Coordinate;
-import org.locationtech.jts.geom.Geometry;
-import org.locationtech.jts.geom.GeometryCollection;
-import org.locationtech.jts.geom.GeometryFactory;
-import org.locationtech.jts.geom.LineString;
-import org.locationtech.jts.geom.MultiLineString;
-import org.locationtech.jts.geom.MultiPoint;
-import org.locationtech.jts.geom.MultiPolygon;
-import org.locationtech.jts.geom.Point;
-import org.locationtech.jts.geom.Polygon;
-import org.locationtech.jts.geom.PrecisionModel;
+import org.locationtech.jts.geom.*;
import org.locationtech.jts.geom.util.GeometryFixer;
import org.locationtech.jts.io.gml2.GMLWriter;
import org.locationtech.jts.io.kml.KMLWriter;
@@ -50,7 +46,10 @@ import org.wololo.jts2geojson.GeoJSONWriter;
import java.util.ArrayList;
import java.util.Arrays;
+import java.util.HashSet;
import java.util.List;
+import java.util.function.Function;
+import java.util.stream.Collectors;
public class Functions {
@@ -543,4 +542,25 @@ public class Functions {
// check input geometry
return new GeometrySplitter(GEOMETRY_FACTORY).split(input, blade);
}
+
+
+    /**
+     * Cover a geometry with Google S2 cells and return the ids of those cells, all at the same level.
+     * Geometry collections are flattened first. Every part that can be converted to an S2 region
+     * is covered as a region; any other part (for example a Point) is covered by the cells of its coordinates.
+     * @param input geometry to cover
+     * @param level S2 level of the returned cells
+     * @return ids of the distinct cells at {@code level} covering the geometry, in no particular order
+     */
+    public static Long[] s2CellIDs(Geometry input, int level) {
+        HashSet<S2CellId> cellIds = new HashSet<>();
+        List<Geometry> geoms = GeomUtils.extractGeometryCollection(input);
+        for (Geometry geom : geoms) {
+            try {
+                cellIds.addAll(S2Utils.s2RegionToCellIDs(S2Utils.toS2Region(geom), 1, level, Integer.MAX_VALUE - 1));
+            } catch (IllegalArgumentException ignored) {
+                // the geometry can't be cast to region, we cover its coordinates, for example, Point
+                cellIds.addAll(Arrays.stream(geom.getCoordinates()).map(c -> S2Utils.coordinateToCellID(c, level)).collect(Collectors.toList()));
+            }
+        }
+        // Let the stream size the array: rounding to one level changes the number of cells, so
+        // cellIds.size() is not the length of the result (a too-long array would end in nulls).
+        return S2Utils.roundCellsToSameLevel(new ArrayList<>(cellIds), level).stream()
+                .map(S2CellId::id)
+                .toArray(Long[]::new);
+    }
}
diff --git a/common/src/main/java/org/apache/sedona/common/utils/GeomUtils.java b/common/src/main/java/org/apache/sedona/common/utils/GeomUtils.java
index 5c0440cb..98d11e60 100644
--- a/common/src/main/java/org/apache/sedona/common/utils/GeomUtils.java
+++ b/common/src/main/java/org/apache/sedona/common/utils/GeomUtils.java
@@ -381,4 +381,28 @@ public class GeomUtils {
}
return pCount;
}
+
+ public static List<Geometry> extractGeometryCollection(Geometry geom){
+ ArrayList<Geometry> leafs = new ArrayList<>();
+ if (!(geom instanceof GeometryCollection)) {
+ leafs.add(geom);
+ return leafs;
+ }
+ LinkedList<GeometryCollection> parents = new LinkedList<>();
+ parents.add((GeometryCollection) geom);
+ while (!parents.isEmpty()) {
+ GeometryCollection parent = parents.removeFirst();
+ for (int i = 0;i < parent.getNumGeometries(); i++) {
+ Geometry child = parent.getGeometryN(i);
+ if (child instanceof GeometryCollection) {
+ parents.add((GeometryCollection) child);
+ } else {
+ leafs.add(child);
+ }
+ }
+ }
+ return leafs;
+ }
+
+
}
diff --git a/common/src/main/java/org/apache/sedona/common/utils/S2Utils.java b/common/src/main/java/org/apache/sedona/common/utils/S2Utils.java
new file mode 100644
index 00000000..3a59dc59
--- /dev/null
+++ b/common/src/main/java/org/apache/sedona/common/utils/S2Utils.java
@@ -0,0 +1,99 @@
+package org.apache.sedona.common.utils;
+
+import com.google.common.geometry.*;
+import org.locationtech.jts.algorithm.Orientation;
+import org.locationtech.jts.geom.*;
+
+import javax.swing.*;
+import java.util.*;
+import java.util.stream.Collectors;
+
+public class S2Utils {
+
+ /**
+ * @param coord Coordinate: convert a jts coordinate to a S2Point
+ * @return
+ */
+ public static S2Point toS2Point(Coordinate coord) {
+ return S2LatLng.fromDegrees(coord.y, coord.x).toPoint();
+ }
+
+ public static List<S2Point> toS2Points(Coordinate[] coords) {
+ return Arrays.stream(coords).map(S2Utils::toS2Point).collect(Collectors.toList());
+ }
+
+ /**
+ * @param line
+ * @return
+ */
+ public static S2Polyline toS2PolyLine(LineString line) {
+ return new S2Polyline(toS2Points(line.getCoordinates()));
+ }
+
+ public static S2Loop toS2Loop(LinearRing ring) {
+ return new S2Loop(
+ Orientation.isCCW(ring.getCoordinates()) ? toS2Points(ring.getCoordinates()) : toS2Points(ring.reverse().getCoordinates())
+ );
+ }
+
+ public static S2Polygon toS2Polygon(Polygon polygon) {
+ List<LinearRing> rings = new ArrayList<>();
+ rings.add(polygon.getExteriorRing());
+ for (int i = 0; i < polygon.getNumInteriorRing(); i++){
+ rings.add(polygon.getInteriorRingN(i));
+ }
+ List<S2Loop> s2Loops = rings.stream().map(
+ S2Utils::toS2Loop
+ ).collect(Collectors.toList());
+ return new S2Polygon(s2Loops);
+ }
+
+ public static List<S2CellId> s2RegionToCellIDs(S2Region region, int minLevel, int maxLevel, int maxNum) {
+ S2RegionCoverer.Builder coverBuilder = S2RegionCoverer.builder();
+ coverBuilder.setMinLevel(minLevel);
+ coverBuilder.setMaxLevel(maxLevel);
+ coverBuilder.setMaxCells(maxNum);
+ return coverBuilder.build().getCovering(region).cellIds();
+ }
+
+ public static S2CellId coordinateToCellID(Coordinate coordinate, int level) {
+ return S2CellId.fromPoint(S2Utils.toS2Point(coordinate)).parent(level);
+ }
+
+ public static List<S2CellId> roundCellsToSameLevel(List<S2CellId> cellIDs, int level) {
+ Set<Long> results = new HashSet<>();
+ for (S2CellId cellID : cellIDs) {
+ if (cellID.level() > level) {
+ results.add(cellID.parent(level).id());
+ } else if(cellID.level() < level) {
+ for (S2CellId c = cellID.childBegin(level); !c.equals(cellID.childEnd(level)); c = c.next()) {
+ results.add(c.id());
+ }
+ } else {
+ results.add(cellID.id());
+ }
+ }
+ return results.stream().map(S2CellId::new).collect(Collectors.toList());
+ }
+
+ public static Polygon toJTSPolygon(S2CellId cellId) {
+ S2LatLngRect bound = new S2Cell(cellId).getRectBound();
+ Coordinate[] coords = new Coordinate[5];
+ int[] iters = new int[] {0, 1, 2, 3, 0};
+ for (int i = 0;i < 5; i++) {
+ coords[i] = new Coordinate(bound.getVertex(iters[i]).lngDegrees(), bound.getVertex(iters[i]).latDegrees());
+ }
+ return new GeometryFactory().createPolygon(coords);
+ }
+
+ public static S2Region toS2Region(Geometry geom) throws IllegalArgumentException {
+ if (geom instanceof Polygon) {
+ return S2Utils.toS2Polygon((Polygon) geom);
+ } else if (geom instanceof LineString) {
+ return S2Utils.toS2PolyLine((LineString) geom);
+ }
+ throw new IllegalArgumentException(
+ "only object of Polygon, LinearRing, LineString type can be converted to S2Region"
+ );
+ }
+}
diff --git a/common/src/test/java/org/apache/sedona/common/FunctionsTest.java b/common/src/test/java/org/apache/sedona/common/FunctionsTest.java
index abe522f9..eabaaf36 100644
--- a/common/src/test/java/org/apache/sedona/common/FunctionsTest.java
+++ b/common/src/test/java/org/apache/sedona/common/FunctionsTest.java
@@ -13,20 +13,19 @@
*/
package org.apache.sedona.common;
+import com.google.common.geometry.S2CellId;
+import org.apache.sedona.common.utils.GeomUtils;
+import org.apache.sedona.common.utils.S2Utils;
import org.junit.Test;
-import org.locationtech.jts.geom.Coordinate;
-import org.locationtech.jts.geom.Geometry;
-import org.locationtech.jts.geom.GeometryCollection;
-import org.locationtech.jts.geom.GeometryFactory;
-import org.locationtech.jts.geom.LinearRing;
-import org.locationtech.jts.geom.LineString;
-import org.locationtech.jts.geom.MultiLineString;
-import org.locationtech.jts.geom.MultiPoint;
-import org.locationtech.jts.geom.MultiPolygon;
-import org.locationtech.jts.geom.Polygon;
-import org.locationtech.jts.io.ParseException;
-
-import static org.junit.Assert.*;
+import org.locationtech.jts.geom.*;
+
+import java.util.Arrays;
+import java.util.HashSet;
+import java.util.Set;
+import java.util.stream.Collectors;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNull;
public class FunctionsTest {
public static final GeometryFactory GEOMETRY_FACTORY = new GeometryFactory();
@@ -238,4 +237,156 @@ public class FunctionsTest {
assertNull(actualResult);
}
+
+ private static boolean intersects(Set<?> s1, Set<?> s2) { // true when s1 and s2 share at least one element
+ Set<?> copy = new HashSet<>(s1); // work on a copy so that retainAll does not modify s1
+ copy.retainAll(s2);
+ return !copy.isEmpty();
+ }
+
+ @Test
+ public void getGoogleS2CellIDsPoint() {
+ Point point = GEOMETRY_FACTORY.createPoint(new Coordinate(1, 2));
+ Long[] cid = Functions.s2CellIDs(point, 30);
+ Polygon reversedPolygon = S2Utils.toJTSPolygon(new S2CellId(cid[0]));
+ // convert the cell back to a rectangle; it must contain the original point
+ assert(reversedPolygon.contains(point));
+ }
+
+    @Test
+    public void getGoogleS2CellIDsPolygon() {
+        // polygon with a hole
+        Polygon target = GEOMETRY_FACTORY.createPolygon(
+                GEOMETRY_FACTORY.createLinearRing(coordArray(0.1, 0.1, 0.5, 0.1, 1.0, 0.3, 1.0, 1.0, 0.1, 1.0, 0.1, 0.1)),
+                new LinearRing[] {
+                        GEOMETRY_FACTORY.createLinearRing(coordArray(0.2, 0.2, 0.5, 0.2, 0.6, 0.7, 0.2, 0.6, 0.2, 0.2))
+                }
+        );
+        // polygon inside the hole, its cells shouldn't intersect the cells of the target polygon
+        Polygon polygonInHole = GEOMETRY_FACTORY.createPolygon(coordArray(0.3, 0.3, 0.4, 0.3, 0.3, 0.4, 0.3, 0.3));
+        // mbr of the target polygon
+        Geometry mbr = target.getEnvelope();
+        HashSet<Long> targetCells = new HashSet<>(Arrays.asList(Functions.s2CellIDs(target, 10)));
+        HashSet<Long> inHoleCells = new HashSet<>(Arrays.asList(Functions.s2CellIDs(polygonInHole, 10)));
+        HashSet<Long> mbrCells = new HashSet<>(Arrays.asList(Functions.s2CellIDs(mbr, 10)));
+        // every cell covering the target must also be part of the cover of its mbr
+        assert mbrCells.containsAll(targetCells);
+        // the polygon inside the hole must not share any cell with the target
+        assert !intersects(targetCells, inHoleCells);
+    }
+
+ @Test
+ public void getGoogleS2CellIDsLineString() {
+ // a target line string and a second line string that crosses it
+ LineString target = GEOMETRY_FACTORY.createLineString(coordArray(0.2, 0.2, 0.3, 0.4, 0.4, 0.6));
+ LineString crossLine = GEOMETRY_FACTORY.createLineString(coordArray(0.4, 0.1, 0.1, 0.4));
+ // mbr of the target line string
+ Geometry mbr = target.getEnvelope();
+ // cover both line strings and the mbr with level 15 cells
+ HashSet<Long> targetCells = new HashSet<>(Arrays.asList(Functions.s2CellIDs(target, 15)));
+ HashSet<Long> crossCells = new HashSet<>(Arrays.asList(Functions.s2CellIDs(crossLine, 15)));
+ HashSet<Long> mbrCells = new HashSet<>(Arrays.asList(Functions.s2CellIDs(mbr, 15)));
+ assert intersects(targetCells, crossCells);
+ assert mbrCells.containsAll(targetCells);
+ }
+
+ @Test
+ public void getGoogleS2CellIDsMultiPolygon() {
+ // two polygons combined into one multi-polygon
+ Polygon[] geoms = new Polygon[] {
+ GEOMETRY_FACTORY.createPolygon(coordArray(0.1, 0.1, 0.5, 0.1, 0.1, 0.6, 0.1, 0.1)),
+ GEOMETRY_FACTORY.createPolygon(coordArray(0.2, 0.1, 0.6, 0.3, 0.7, 0.6, 0.2, 0.5, 0.2, 0.1))
+ };
+ MultiPolygon target = GEOMETRY_FACTORY.createMultiPolygon(geoms);
+ Geometry mbr = target.getEnvelope();
+ HashSet<Long> targetCells = new HashSet<>(Arrays.asList(Functions.s2CellIDs(target, 10)));
+ HashSet<Long> mbrCells = new HashSet<>(Arrays.asList(Functions.s2CellIDs(mbr, 10)));
+ HashSet<Long> separateCoverCells = new HashSet<>();
+ for(Geometry geom: geoms) {
+ separateCoverCells.addAll(Arrays.asList(Functions.s2CellIDs(geom, 10)));
+ }
+ assert mbrCells.containsAll(targetCells);
+ assert targetCells.equals(separateCoverCells);
+ }
+
+ @Test
+ public void getGoogleS2CellIDsMultiLineString() {
+ // two line strings combined into one multi-line string
+ MultiLineString target = GEOMETRY_FACTORY.createMultiLineString(
+ new LineString[] {
+ GEOMETRY_FACTORY.createLineString(coordArray(0.1, 0.1, 0.2, 0.1, 0.3, 0.4, 0.5, 0.9)),
+ GEOMETRY_FACTORY.createLineString(coordArray(0.5, 0.1, 0.1, 0.5, 0.3, 0.1))
+ }
+ );
+ Geometry mbr = target.getEnvelope();
+ Point outsidePoint = GEOMETRY_FACTORY.createPoint(new Coordinate(0.3, 0.7));
+ HashSet<Long> targetCells = new HashSet<>(Arrays.asList(Functions.s2CellIDs(target, 10)));
+ HashSet<Long> mbrCells = new HashSet<>(Arrays.asList(Functions.s2CellIDs(mbr, 10)));
+ Long outsideCell = Functions.s2CellIDs(outsidePoint, 10)[0];
+ // the cells should all be within mbr
+ assert mbrCells.containsAll(targetCells);
+ // the point's cell is in the mbr cover but must not be in the cover of the line strings
+ assert mbrCells.contains(outsideCell);
+ assert !targetCells.contains(outsideCell);
+ }
+
+ @Test
+ public void getGoogleS2CellIDsMultiPoint() {
+ // four distinct points combined into one multi-point
+ MultiPoint target = GEOMETRY_FACTORY.createMultiPoint(new Point[] {
+ GEOMETRY_FACTORY.createPoint(new Coordinate(0.1, 0.1)),
+ GEOMETRY_FACTORY.createPoint(new Coordinate(0.2, 0.1)),
+ GEOMETRY_FACTORY.createPoint(new Coordinate(0.3, 0.2)),
+ GEOMETRY_FACTORY.createPoint(new Coordinate(0.5, 0.4))
+ });
+ Geometry mbr = target.getEnvelope();
+ Point outsidePoint = GEOMETRY_FACTORY.createPoint(new Coordinate(0.3, 0.7)); // NOTE(review): not used by any assertion below
+ HashSet<Long> targetCells = new HashSet<>(Arrays.asList(Functions.s2CellIDs(target, 10)));
+ HashSet<Long> mbrCells = new HashSet<>(Arrays.asList(Functions.s2CellIDs(mbr, 10)));
+ // the cells should all be within mbr
+ assert mbrCells.containsAll(targetCells);
+ assert targetCells.size() == 4;
+ }
+
+ @Test
+ public void getGoogleS2CellIDsGeometryCollection() {
+ // a line string, a polygon and a multi-point in one geometry collection
+ Geometry[] geoms = new Geometry[] {
+ GEOMETRY_FACTORY.createLineString(coordArray(0.1, 0.1, 0.2, 0.1, 0.3, 0.4, 0.5, 0.9)),
+ GEOMETRY_FACTORY.createPolygon(coordArray(0.1, 0.1, 0.5, 0.1, 0.1, 0.6, 0.1, 0.1)),
+ GEOMETRY_FACTORY.createMultiPoint(new Point[] {
+ GEOMETRY_FACTORY.createPoint(new Coordinate(0.1, 0.1)),
+ GEOMETRY_FACTORY.createPoint(new Coordinate(0.2, 0.1)),
+ GEOMETRY_FACTORY.createPoint(new Coordinate(0.3, 0.2)),
+ GEOMETRY_FACTORY.createPoint(new Coordinate(0.5, 0.4))
+ })
+ };
+ GeometryCollection target = GEOMETRY_FACTORY.createGeometryCollection(geoms);
+ Geometry mbr = target.getEnvelope();
+ HashSet<Long> targetCells = new HashSet<>(Arrays.asList(Functions.s2CellIDs(target, 10)));
+ HashSet<Long> mbrCells = new HashSet<>(Arrays.asList(Functions.s2CellIDs(mbr, 10)));
+ HashSet<Long> separateCoverCells = new HashSet<>();
+ for(Geometry geom: geoms) {
+ separateCoverCells.addAll(Arrays.asList(Functions.s2CellIDs(geom, 10)));
+ }
+ // the cells should all be within mbr
+ assert mbrCells.containsAll(targetCells);
+ // covering each member separately should return the same cells as covering the collection
+ assert separateCoverCells.equals(targetCells);
+ }
+
+ @Test
+ public void getGoogleS2CellIDsAllSameLevel() {
+ // a polygon and a point in one collection; every returned cell must be at the requested level
+ GeometryCollection target = GEOMETRY_FACTORY.createGeometryCollection(
+ new Geometry[]{
+ GEOMETRY_FACTORY.createPolygon(coordArray(0.3, 0.3, 0.4, 0.3, 0.3, 0.4, 0.3, 0.3)),
+ GEOMETRY_FACTORY.createPoint(new Coordinate(0.7, 1.2))
+ }
+ );
+ Long[] cellIds = Functions.s2CellIDs(target, 10);
+ HashSet<Integer> levels = Arrays.stream(cellIds).map(c -> new S2CellId(c).level()).collect(Collectors.toCollection(HashSet::new));
+ HashSet<Integer> expects = new HashSet<>();
+ expects.add(10);
+ assertEquals(expects, levels);
+ }
}
diff --git a/common/src/test/java/org/apache/sedona/common/GeometryUtilTest.java b/common/src/test/java/org/apache/sedona/common/GeometryUtilTest.java
new file mode 100644
index 00000000..855b077f
--- /dev/null
+++ b/common/src/test/java/org/apache/sedona/common/GeometryUtilTest.java
@@ -0,0 +1,67 @@
+/**
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ * <p>
+ * http://www.apache.org/licenses/LICENSE-2.0
+ * <p>
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.sedona.common;
+
+import org.apache.sedona.common.utils.GeomUtils;
+import org.junit.Test;
+import org.locationtech.jts.geom.*;
+import org.locationtech.jts.io.ParseException;
+
+import java.io.IOException;
+import java.util.List;
+import java.util.Objects;
+
+/** Tests for the helpers in GeomUtils. */
+public class GeometryUtilTest {
+    public static final GeometryFactory GEOMETRY_FACTORY = new GeometryFactory();
+
+    /** Build a coordinate array from a flat list of x, y values. */
+    private Coordinate[] coordArray(double... coordValues) {
+        Coordinate[] coords = new Coordinate[coordValues.length / 2];
+        for (int i = 0, j = 0; i < coordValues.length; i += 2, j++) {
+            coords[j] = new Coordinate(coordValues[i], coordValues[i + 1]);
+        }
+        return coords;
+    }
+
+    /** A nested collection must be flattened to its non-collection parts, outer parts first. */
+    @Test
+    public void extractGeometryCollection() throws ParseException, IOException {
+        Polygon first = GEOMETRY_FACTORY.createPolygon(coordArray(0, 1,3, 0,4, 3,0, 4,0, 1));
+        Polygon second = GEOMETRY_FACTORY.createPolygon(coordArray(3, 4,6, 3,5, 5,3, 4));
+        MultiPolygon multiPolygon = GEOMETRY_FACTORY.createMultiPolygon(new Polygon[] {first, second});
+        Point point = GEOMETRY_FACTORY.createPoint(new Coordinate(5.0, 8.0));
+        // inner collection: the multi-polygon plus a point
+        GeometryCollection gc1 = GEOMETRY_FACTORY.createGeometryCollection(new Geometry[] {multiPolygon, point});
+        // outer collection: a copy of the multi-polygon plus the inner collection
+        GeometryCollection gc = GEOMETRY_FACTORY.createGeometryCollection(new Geometry[] {multiPolygon.copy(), gc1});
+        List<Geometry> geoms = GeomUtils.extractGeometryCollection(gc);
+        GeometryCollection flattened = GEOMETRY_FACTORY.createGeometryCollection(geoms.toArray(new Geometry[geoms.size()]));
+        String actualWkt = GeomUtils.getWKT(flattened);
+        assert (
+                Objects.equals(
+                        actualWkt,
+                        "GEOMETRYCOLLECTION (POLYGON ((0 1, 3 0, 4 3, 0 4, 0 1)), POLYGON ((3 4, 6 3, 5 5, 3 4)), POINT (5 8), POLYGON ((0 1, 3 0, 4 3, 0 4, 0 1)), POLYGON ((3 4, 6 3, 5 5, 3 4)))"
+                )
+        );
+    }
+}
diff --git a/common/src/test/java/org/apache/sedona/common/S2UtilTest.java b/common/src/test/java/org/apache/sedona/common/S2UtilTest.java
new file mode 100644
index 00000000..22a58f30
--- /dev/null
+++ b/common/src/test/java/org/apache/sedona/common/S2UtilTest.java
@@ -0,0 +1,102 @@
+/**
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ * <p>
+ * http://www.apache.org/licenses/LICENSE-2.0
+ * <p>
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.sedona.common;
+
+import com.google.common.geometry.S2CellId;
+import com.google.common.geometry.S2LatLng;
+import com.google.common.geometry.S2Loop;
+import com.google.common.geometry.S2Point;
+import org.apache.sedona.common.utils.S2Utils;
+import org.junit.Test;
+import org.locationtech.jts.geom.*;
+import org.locationtech.jts.io.ParseException;
+import org.locationtech.jts.io.WKTReader;
+
+import java.text.DecimalFormat;
+import java.util.Comparator;
+import java.util.List;
+import java.util.stream.Collectors;
+
+import static org.junit.Assert.assertEquals;
+
+/** Tests for the JTS to S2 conversions and the cell covering helpers in S2Utils. */
+public class S2UtilTest {
+    public static final GeometryFactory GEOMETRY_FACTORY = new GeometryFactory();
+
+    /** Build a coordinate array from a flat list of x, y values. */
+    private Coordinate[] coordArray(double... coordValues) {
+        Coordinate[] coords = new Coordinate[(int)(coordValues.length / 2)];
+        for (int i = 0; i < coordValues.length; i += 2) {
+            coords[(int)(i / 2)] = new Coordinate(coordValues[i], coordValues[i+1]);
+        }
+        return coords;
+    }
+
+    @Test
+    public void toS2Point() throws ParseException {
+        Coordinate jtsCoord = new Coordinate(1, 2);
+        S2Point s2Point = S2Utils.toS2Point(jtsCoord);
+        S2LatLng latLng = new S2LatLng(s2Point);
+        // x maps to longitude and y to latitude; expected value goes first in assertEquals
+        assertEquals(1, Math.round(latLng.lngDegrees()));
+        assertEquals(2, Math.round(latLng.latDegrees()));
+    }
+
+    @Test
+    public void coverPolygon() throws ParseException {
+        Polygon polygon = (Polygon) new WKTReader().read("POLYGON ((0.5 0.5, 5 0, 6 3, 5 5, 0 5, 0.5 0.5))");
+        List<S2CellId> cellIds = S2Utils.s2RegionToCellIDs(S2Utils.toS2Polygon(polygon), 1, 5, Integer.MAX_VALUE);
+        assertEquals(5, cellIds.size());
+        // the deepest cell of the covering is at the requested max level
+        assertEquals(5, cellIds.stream().max(Comparator.comparingLong(S2CellId::level)).get().level());
+    }
+
+    @Test
+    public void coverPolygonWithHole() throws ParseException {
+        Polygon polygon = (Polygon) new WKTReader().read("POLYGON((0.5 0.5,5 0,5 5,0 5,0.5 0.5), (1.5 1,4 1,4 3,1.5 1))");
+        Polygon hole = (Polygon) new WKTReader().read("POLYGON((1.5 1,4 1,4 3,1.5 1))");
+        List<S2CellId> cellIds = S2Utils.roundCellsToSameLevel(S2Utils.s2RegionToCellIDs(S2Utils.toS2Polygon(polygon), 1, 8, Integer.MAX_VALUE-1), 8);
+        S2CellId holeCentroidCell = S2Utils.coordinateToCellID(hole.getCentroid().getCoordinate(), 8);
+        S2CellId holeFirstVertexCell = S2Utils.coordinateToCellID(hole.getCoordinates()[0], 8);
+        assertEquals(8, cellIds.stream().max(Comparator.comparingLong(S2CellId::level)).get().level());
+        // the centre of the hole is not covered, but a vertex on the hole's boundary is
+        assert(!cellIds.contains(holeCentroidCell));
+        assert(cellIds.contains(holeFirstVertexCell));
+
+    }
+
+    @Test
+    public void coverLineString() throws ParseException {
+        LineString line = (LineString) new WKTReader().read("LINESTRING (1.5 2.45, 3.21 4)");
+        List<S2CellId> cellIds = S2Utils.s2RegionToCellIDs(S2Utils.toS2PolyLine(line), 1, 8, Integer.MAX_VALUE);
+        assertEquals(12, cellIds.size());
+        assertEquals(8, cellIds.stream().max(Comparator.comparingLong(S2CellId::level)).get().level());
+    }
+
+    @Test
+    public void coverLinearLoop() throws ParseException {
+        LineString line = GEOMETRY_FACTORY.createLineString(new WKTReader().read("LINESTRING (1.5 2.45, 3.21 4, 5 2, 1.5 2.45)").getCoordinates());
+        List<S2CellId> cellIds = S2Utils.s2RegionToCellIDs(S2Utils.toS2PolyLine(line), 1, 8, Integer.MAX_VALUE);
+        assertEquals(31, cellIds.size());
+        assertEquals(8, cellIds.stream().max(Comparator.comparingLong(S2CellId::level)).get().level());
+    }
+
+    @Test
+    public void toS2Loop() throws ParseException {
+        LinearRing ringCW = GEOMETRY_FACTORY.createLinearRing(new WKTReader().read("LINESTRING (1.5 2.45, 3.21 4, 5 2, 1.5 2.45)").getCoordinates());
+        LinearRing ringCCW = GEOMETRY_FACTORY.createLinearRing(new WKTReader().read("LINESTRING (1.5 2.45, 5 2, 3.21 4, 1.5 2.45)").getCoordinates());
+        // compare by value: a reference comparison of two distinct objects would always pass
+        assert(!ringCCW.equals(ringCW));
+        S2Loop s2Loop = S2Utils.toS2Loop(ringCW);
+        // round the loop vertices to two decimals before comparing them with the CCW ring
+        DecimalFormat df = new DecimalFormat("#.##");
+        LinearRing reversedRing = GEOMETRY_FACTORY.createLinearRing(
+                s2Loop.vertices().stream().map(S2LatLng::new).map(l -> new Coordinate(Double.parseDouble(df.format(l.lngDegrees())), Double.parseDouble(df.format(l.latDegrees())))).collect(Collectors.toList()).toArray(new Coordinate[4])
+        );
+        assertEquals(ringCCW, reversedRing);
+    }
+}
diff --git a/docs/api/flink/Function.md b/docs/api/flink/Function.md
index 72dd31f6..ee2a4f60 100644
--- a/docs/api/flink/Function.md
+++ b/docs/api/flink/Function.md
@@ -681,6 +681,26 @@ SELECT ST_RemovePoint(ST_GeomFromText("LINESTRING(0 0, 1 1, 1 0)"), 1)
Output: `LINESTRING(0 0, 1 0)`
+## ST_S2CellIDs
+
+Introduction: Covers the geometry with Google S2 cells and returns the IDs of the covering cells at the given level.
+The level indicates the [size of cells](https://s2geometry.io/resources/s2cell_statistics.html). With a higher level,
+the cells are smaller and the coverage is more accurate, but the number of returned cell IDs grows exponentially.
+
+Format: `ST_S2CellIDs(geom: geometry, level: Int)`
+
+Since: `v1.4.0`
+
+Example:
+```SQL
+SELECT ST_S2CellIDs(ST_GeomFromText('LINESTRING(1 3 4, 5 6 7)'), 6)
+```
+
+Output:
+```
+[1159395429071192064, 1159958379024613376, 1160521328978034688, 1161084278931456000, 1170091478186196992, 1170654428139618304]
+```
+
## ST_SetPoint
Introduction: Replace Nth point of linestring with given point. Index is 0-based. Negative index are counted backwards, e.g., -1 is last point.
diff --git a/docs/api/sql/Function.md b/docs/api/sql/Function.md
index 9dfeb3ad..d15bc735 100644
--- a/docs/api/sql/Function.md
+++ b/docs/api/sql/Function.md
@@ -1101,6 +1101,30 @@ Result:
+---------------------------------------------------------------+
```
+## ST_S2CellIDs
+
+Introduction: Covers the geometry with Google S2 cells and returns the IDs of the covering cells at the given level.
+The level indicates the [size of cells](https://s2geometry.io/resources/s2cell_statistics.html). With a higher level,
+the cells are smaller and the coverage is more accurate, but the number of returned cell IDs grows exponentially.
+
+Format: `ST_S2CellIDs(geom: geometry, level: Int)`
+
+Since: `v1.4.0`
+
+Spark SQL example:
+```SQL
+SELECT ST_S2CellIDs(ST_GeomFromText('LINESTRING(1 3 4, 5 6 7)'), 6)
+```
+
+Output:
+```
++------------------------------------------------------------------------------------------------------------------------------+
+|st_s2cellids(st_geomfromtext(LINESTRING(1 3 4, 5 6 7), 0), 6) |
++------------------------------------------------------------------------------------------------------------------------------+
+|[1159395429071192064, 1159958379024613376, 1160521328978034688, 1161084278931456000, 1170091478186196992, 1170654428139618304]|
++------------------------------------------------------------------------------------------------------------------------------+
+```
+
## ST_SetPoint
Introduction: Replace Nth point of linestring with given point. Index is 0-based. Negative index are counted backwards, e.g., -1 is last point.
diff --git a/flink/src/main/java/org/apache/sedona/flink/Catalog.java b/flink/src/main/java/org/apache/sedona/flink/Catalog.java
index 40e4c4dc..83a99029 100644
--- a/flink/src/main/java/org/apache/sedona/flink/Catalog.java
+++ b/flink/src/main/java/org/apache/sedona/flink/Catalog.java
@@ -89,6 +89,7 @@ public class Catalog {
new Functions.ST_SetPoint(),
new Functions.ST_LineFromMultiPoint(),
new Functions.ST_Split(),
+ new Functions.ST_S2CellIDs()
};
}
diff --git a/flink/src/main/java/org/apache/sedona/flink/expressions/Functions.java b/flink/src/main/java/org/apache/sedona/flink/expressions/Functions.java
index 198df1d3..a22b91c2 100644
--- a/flink/src/main/java/org/apache/sedona/flink/expressions/Functions.java
+++ b/flink/src/main/java/org/apache/sedona/flink/expressions/Functions.java
@@ -488,4 +488,13 @@ public class Functions {
return org.apache.sedona.common.Functions.split(input, blade);
}
}
+
+ public static class ST_S2CellIDs extends ScalarFunction { // Flink scalar function delegating to the common Functions.s2CellIDs
+ @DataTypeHint(value = "ARRAY<BIGINT>")
+ public Long[] eval(@DataTypeHint(value = "RAW", bridgedTo = org.locationtech.jts.geom.Geometry.class) Object o,
+ @DataTypeHint("INT") Integer level) {
+ Geometry geom = (Geometry) o; // the RAW argument is bridged to a JTS Geometry
+ return org.apache.sedona.common.Functions.s2CellIDs(geom, level);
+ }
+ }
}
diff --git a/flink/src/test/java/org/apache/sedona/flink/FunctionTest.java b/flink/src/test/java/org/apache/sedona/flink/FunctionTest.java
index 4e2a0a14..4dcb3aab 100644
--- a/flink/src/test/java/org/apache/sedona/flink/FunctionTest.java
+++ b/flink/src/test/java/org/apache/sedona/flink/FunctionTest.java
@@ -27,13 +27,14 @@ import org.locationtech.jts.geom.Polygon;
import org.opengis.referencing.FactoryException;
import org.opengis.referencing.crs.CoordinateReferenceSystem;
+import java.util.Arrays;
+import java.util.Objects;
+import java.util.stream.Collectors;
+
import static junit.framework.TestCase.assertNull;
import static org.apache.flink.table.api.Expressions.$;
import static org.apache.flink.table.api.Expressions.call;
-import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.assertFalse;
-import static org.junit.Assert.assertNotNull;
-import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.*;
public class FunctionTest extends TestBase{
@BeforeClass
@@ -581,4 +582,36 @@ public class FunctionTest extends TestBase{
Table pointTable = tableEnv.sqlQuery("SELECT ST_Split(ST_GeomFromWKT('LINESTRING (0 0, 1.5 1.5, 2 2)'), ST_GeomFromWKT('MULTIPOINT (0.5 0.5, 1 1)'))");
assertEquals("MULTILINESTRING ((0 0, 0.5 0.5), (0.5 0.5, 1 1), (1 1, 1.5 1.5, 2 2))", ((Geometry)first(pointTable).getField(0)).norm().toText());
}
+
+ @Test
+ public void testS2CellIDs() {
+ String initExplodeQuery = "SELECT id, geom, cell_tbl.cell from (VALUES %s) as raw_tbl(id, geom, cells) CROSS JOIN UNNEST(raw_tbl.cells) AS cell_tbl (cell)";
+ // left side: a single polygon (id = 1), exploded into one row per level-10 S2 cell ID
+ tableEnv.createTemporaryView(
+ "lefts",
+ tableEnv.sqlQuery(String.format(initExplodeQuery, "(1, ST_GeomFromWKT('POLYGON ((0 0, 0.2 0, 0.2 0.2, 0 0.2, 0 0))'), ST_S2CellIDs(ST_GeomFromWKT('POLYGON ((0 0, 0.2 0, 0.2 0.2, 0 0.2, 0 0))'), 10))"))
+ );
+ // right side: three test points, each exploded into its level-10 S2 cell IDs
+ String points = String.join(", ", new String[] {
+ "(2, ST_GeomFromWKT('POINT (0.1 0.1)'), ST_S2CellIDs(ST_GeomFromWKT('POINT (0.1 0.1)'), 10))", // point within the polygon
+ "(3, ST_GeomFromWKT('POINT (0.25 0.1)'), ST_S2CellIDs(ST_GeomFromWKT('POINT (0.25 0.1)'), 10))", // point outside the polygon
+ "(4, ST_GeomFromWKT('POINT (0.2005 0.1)'), ST_S2CellIDs(ST_GeomFromWKT('POINT (0.2005 0.1)'), 10))" // point outside the polygon, but very close to its border
+ });
+ tableEnv.createTemporaryView(
+ "rights",
+ tableEnv.sqlQuery(String.format(initExplodeQuery, points))
+ );
+ Table joinTable = tableEnv.sqlQuery("select lefts.id, rights.id from lefts join rights on lefts.cell = rights.cell group by (lefts.id, rights.id)");
+ assertEquals(2, count(joinTable));
+ // use a JUnit assertion here: the bare `assert` keyword is skipped whenever the JVM runs without -ea
+ assertTrue(take(joinTable, 2).stream().map(
+ r -> Objects.requireNonNull(r.getField(1)).toString()
+ ).collect(Collectors.toSet()).containsAll(Arrays.asList("2", "4")));
+ // At level 10, point id = 4 falls into the same cell as the border of polygon id = 1,
+ // so join and filter by ST_Intersects to exclude that false-positive match
+ Table joinCleanedTable = tableEnv.sqlQuery("select lefts.id, rights.id from lefts join rights on lefts.cell = rights.cell where ST_Intersects(lefts.geom, rights.geom) is true group by (lefts.id, rights.id)");
+ // after filtering by ST_Intersects, only point id = 2 remains
+ assertEquals(1, count(joinCleanedTable));
+ assertEquals(2, first(joinCleanedTable).getField(1));
+ }
}
diff --git a/flink/src/test/java/org/apache/sedona/flink/TestBase.java b/flink/src/test/java/org/apache/sedona/flink/TestBase.java
index 86d1774f..be88c5fb 100644
--- a/flink/src/test/java/org/apache/sedona/flink/TestBase.java
+++ b/flink/src/test/java/org/apache/sedona/flink/TestBase.java
@@ -34,8 +34,7 @@ import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
-import static org.apache.flink.table.api.Expressions.$;
-import static org.apache.flink.table.api.Expressions.call;
+import static org.apache.flink.table.api.Expressions.*;
public class TestBase {
protected static StreamExecutionEnvironment env;
@@ -380,6 +379,15 @@ public class TestBase {
return firstRow;
}
+ static List<Row> take(Table table, int n) { // collects at most n rows from the table's result
+ CloseableIterator<Row> it = iterate(table); // NOTE(review): the iterator is never closed in this helper — confirm whether that is acceptable for callers that stop early
+ List<Row> rows = new ArrayList<>();
+ while (rows.size() < n && it.hasNext()) { // size check first, so the iterator is never polled once n rows are held
+ rows.add(it.next());
+ }
+ return rows;
+ }
+
static long count(Table table) {
CloseableIterator<Row> it = iterate(table);
long count = 0;
diff --git a/pom.xml b/pom.xml
index 8c85ef11..7bc4bc7d 100644
--- a/pom.xml
+++ b/pom.xml
@@ -76,7 +76,8 @@
<slf4j.version>1.7.36</slf4j.version>
<spark.version>3.3.0</spark.version>
<spark.compat.version>3.0</spark.compat.version>
-
+ <googles2.version>2.0.0</googles2.version>
+
<!-- Actual scala version will be set by a profile.
Setting a default value helps IDE:s that can't make sense of profiles. -->
<scala.compat.version>2.12</scala.compat.version>
@@ -319,6 +320,11 @@
<version>3.1.1</version>
<scope>test</scope>
</dependency>
+ <dependency>
+ <groupId>com.google.geometry</groupId>
+ <artifactId>s2-geometry</artifactId>
+ <version>${googles2.version}</version>
+ </dependency>
</dependencies>
</dependencyManagement>
<repositories>
diff --git a/python/.gitignore b/python/.gitignore
index bda58367..9c41b9af 100644
--- a/python/.gitignore
+++ b/python/.gitignore
@@ -5,3 +5,4 @@
/apache_sedona.egg-info/
/build/
/dist/
+/sedona/utils/*.so
diff --git a/python/sedona/sql/st_functions.py b/python/sedona/sql/st_functions.py
index 6e5932af..d2fd8edc 100644
--- a/python/sedona/sql/st_functions.py
+++ b/python/sedona/sql/st_functions.py
@@ -82,6 +82,7 @@ __all__ = [
"ST_PrecisionReduce",
"ST_RemovePoint",
"ST_Reverse",
+ "ST_S2CellIDs",
"ST_SetPoint",
"ST_SetSRID",
"ST_SRID",
@@ -877,6 +878,19 @@ def ST_Reverse(geometry: ColumnOrName) -> Column:
"""
return _call_st_function("ST_Reverse", geometry)
+@validate_argument_types
+def ST_S2CellIDs(geometry: ColumnOrName, level: Union[ColumnOrName, int]) -> Column:
+ """Cover a geometry with S2 cells and return a column holding the list of long cell IDs.
+ :param geometry: Geometry column to generate cell IDs for.
+ :type geometry: ColumnOrName
+ :param level: value between 1 and 31, controls the size of the cells used for coverage. With a bigger level, the cells will be smaller, the coverage will be more accurate, but the result size will be exponentially increasing.
+ :type level: Union[ColumnOrName, int]
+ :return: Column whose values are the lists of S2 cell IDs covering each geometry.
+ :rtype: Column
+ """
+ args = (geometry, level)
+ return _call_st_function("ST_S2CellIDs", args)
+
@validate_argument_types
def ST_SetPoint(line_string: ColumnOrName, index: Union[ColumnOrName, int], point: ColumnOrName) -> Column:
diff --git a/python/tests/sql/test_dataframe_api.py b/python/tests/sql/test_dataframe_api.py
index b0926d04..dd5f275a 100644
--- a/python/tests/sql/test_dataframe_api.py
+++ b/python/tests/sql/test_dataframe_api.py
@@ -110,6 +110,7 @@ test_configurations = [
(stf.ST_PrecisionReduce, ("geom", 1), "precision_reduce_point", "", "POINT (0.1 0.2)"),
(stf.ST_RemovePoint, ("line", 1), "linestring_geom", "", "LINESTRING (0 0, 2 0, 3 0, 4 0, 5 0)"),
(stf.ST_Reverse, ("line",), "linestring_geom", "", "LINESTRING (5 0, 4 0, 3 0, 2 0, 1 0, 0 0)"),
+ (stf.ST_S2CellIDs, ("point", 30), "point_geom", "", [1153451514845492609]),
(stf.ST_SetPoint, ("line", 1, lambda: f.expr("ST_Point(1.0, 1.0)")), "linestring_geom", "", "LINESTRING (0 0, 1 1, 2 0, 3 0, 4 0, 5 0)"),
(stf.ST_SetSRID, ("point", 3021), "point_geom", "ST_SRID(geom)", 3021),
(stf.ST_SimplifyPreserveTopology, ("geom", 0.2), "0.9_poly", "", "POLYGON ((0 0, 1 0, 1 1, 0 0))"),
@@ -254,6 +255,7 @@ wrong_type_configurations = [
(stf.ST_RemovePoint, ("", None)),
(stf.ST_RemovePoint, ("", 1.0)),
(stf.ST_Reverse, (None,)),
+ (stf.ST_S2CellIDs, (None, 2)),
(stf.ST_SetPoint, (None, 1, "")),
(stf.ST_SetPoint, ("", None, "")),
(stf.ST_SetPoint, ("", 1, None)),
diff --git a/python/tests/sql/test_function.py b/python/tests/sql/test_function.py
index 50bbb3c8..059fd30b 100644
--- a/python/tests/sql/test_function.py
+++ b/python/tests/sql/test_function.py
@@ -1060,3 +1060,17 @@ class TestPredicateJoin(TestBase):
for input_geom, expected_geom in test_cases.items():
line_geometry = self.spark.sql("select ST_AsText(ST_LineFromMultiPoint(ST_GeomFromText({})))".format(input_geom))
assert line_geometry.take(1)[0][0] == expected_geom
+
+ def test_st_s2_cell_ids(self):
+ test_cases = [  # one polygon, one linestring and one point, each as a quoted WKT literal
+ "'POLYGON((-1 0, 1 0, 0 0, 0 1, -1 0))'",
+ "'LINESTRING(0 0, 1 2, 2 4, 3 6)'",
+ "'POINT(1 2)'"
+ ]
+ for input_geom in test_cases:
+ cell_ids = self.spark.sql("select ST_S2CellIDs(ST_GeomFromText({}), 6)".format(input_geom)).take(1)[0][0]
+ assert isinstance(cell_ids, list)
+ assert isinstance(cell_ids[0], int)  # only the type of the first ID is checked, not its value
+ # a NULL geometry must produce a NULL result rather than a list
+ cell_ids = self.spark.sql("select ST_S2CellIDs(null, 6)").take(1)[0][0]
+ assert cell_ids is None
diff --git a/sql/src/main/scala/org/apache/sedona/sql/UDF/Catalog.scala b/sql/src/main/scala/org/apache/sedona/sql/UDF/Catalog.scala
index 292bee07..f637cb34 100644
--- a/sql/src/main/scala/org/apache/sedona/sql/UDF/Catalog.scala
+++ b/sql/src/main/scala/org/apache/sedona/sql/UDF/Catalog.scala
@@ -141,6 +141,7 @@ object Catalog {
function[ST_MPolyFromText](0),
function[ST_MLineFromText](0),
function[ST_Split](),
+ function[ST_S2CellIDs](),
// Expression for rasters
function[RS_NormalizedDifference](),
function[RS_Mean](),
diff --git a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala
index 5fdcc586..2b850cff 100644
--- a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala
+++ b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala
@@ -1055,3 +1055,11 @@ case class ST_Split(inputExpressions: Seq[Expression])
copy(inputExpressions = newChildren)
}
}
+
+case class ST_S2CellIDs(inputExpressions: Seq[Expression])
+ extends InferredBinaryExpression(Functions.s2CellIDs) with FoldableExpression {
+ // SQL expression for ST_S2CellIDs: wraps Functions.s2CellIDs through InferredBinaryExpression.
+ protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = {
+ copy(inputExpressions = newChildren) // same expression, rebuilt over the replacement children
+ }
+}
diff --git a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/NullSafeExpressions.scala b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/NullSafeExpressions.scala
index ea30ecd3..4a29eb4a 100644
--- a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/NullSafeExpressions.scala
+++ b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/NullSafeExpressions.scala
@@ -21,6 +21,7 @@ package org.apache.spark.sql.sedona_sql.expressions
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
+import org.apache.spark.sql.catalyst.util.ArrayData
import org.apache.spark.sql.sedona_sql.expressions.implicits._
import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT
import org.apache.spark.sql.types._
@@ -53,6 +54,8 @@ abstract class UnaryGeometryExpression extends Expression with ExpectsInputTypes
}
protected def nullSafeEval(geometry: Geometry): Any
+
+
}
abstract class BinaryGeometryExpression extends Expression with ExpectsInputTypes {
@@ -95,6 +98,8 @@ object InferrableType {
new InferrableType[String] {}
implicit val binaryInstance: InferrableType[Array[Byte]] =
new InferrableType[Array[Byte]] {}
+ implicit val longArrayInstance: InferrableType[Array[java.lang.Long]] =
+ new InferrableType[Array[java.lang.Long]] {}
}
object InferredTypes {
@@ -121,6 +126,13 @@ object InferredTypes {
} else {
null
}
+ } else if (typeOf[T] =:= typeOf[Array[java.lang.Long]]) {
+ output: T =>
+ if (output != null) {
+ ArrayData.toArrayData(output)
+ } else {
+ null
+ }
} else {
output: T => output
}
@@ -141,6 +153,8 @@ object InferredTypes {
StringType
} else if (typeOf[T] =:= typeOf[Array[Byte]]) {
BinaryType
+ } else if (typeOf[T] =:= typeOf[Array[java.lang.Long]]) {
+ DataTypes.createArrayType(LongType)
} else {
BooleanType
}
diff --git a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/st_functions.scala b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/st_functions.scala
index 48c5a446..2ef94936 100644
--- a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/st_functions.scala
+++ b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/st_functions.scala
@@ -210,6 +210,10 @@ object st_functions extends DataFrameAPI {
def ST_Reverse(geometry: Column): Column = wrapExpression[ST_Reverse](geometry)
def ST_Reverse(geometry: String): Column = wrapExpression[ST_Reverse](geometry)
+ def ST_S2CellIDs(geometry: Column, level: Column): Column = wrapExpression[ST_S2CellIDs](geometry, level)
+ // Overload taking the geometry as a String (used with a column name in the tests) and the level as a plain Int.
+ def ST_S2CellIDs(geometry: String, level: Int): Column = wrapExpression[ST_S2CellIDs](geometry, level)
+
def ST_SetPoint(lineString: Column, index: Column, point: Column): Column = wrapExpression[ST_SetPoint](lineString, index, point)
def ST_SetPoint(lineString: String, index: Int, point: String): Column = wrapExpression[ST_SetPoint](lineString, index, point)
diff --git a/sql/src/test/scala/org/apache/sedona/sql/dataFrameAPITestScala.scala b/sql/src/test/scala/org/apache/sedona/sql/dataFrameAPITestScala.scala
index d77c9f0b..db11c415 100644
--- a/sql/src/test/scala/org/apache/sedona/sql/dataFrameAPITestScala.scala
+++ b/sql/src/test/scala/org/apache/sedona/sql/dataFrameAPITestScala.scala
@@ -20,13 +20,15 @@ package org.apache.sedona.sql
import scala.collection.mutable.WrappedArray
import org.apache.commons.codec.binary.Hex
-import org.apache.spark.sql.functions.lit
+import org.apache.spark.sql.functions.{col, lit}
import org.locationtech.jts.geom.Geometry
import org.apache.spark.sql.sedona_sql.expressions.st_constructors._
import org.apache.spark.sql.sedona_sql.expressions.st_functions._
import org.apache.spark.sql.sedona_sql.expressions.st_predicates._
import org.apache.spark.sql.sedona_sql.expressions.st_aggregates._
+import scala.collection.mutable
+
class dataFrameAPITestScala extends TestBaseScala {
import sparkSession.implicits._
@@ -900,5 +902,14 @@ class dataFrameAPITestScala extends TestBaseScala {
val expectedResult = "LINESTRING (10 40, 40 30, 20 20, 30 10)"
assert(actualResult == expectedResult)
}
+
+ it("Passed ST_S2CellIDs") {
+ val polygonDf = sparkSession.sql("SELECT ST_GeomFromWKT('Polygon ((0 0, 1 2, 2 2, 3 2, 5 0, 4 0, 3 1, 2 1, 1 0, 0 0))') as geom")
+ val polygonCellsDf = polygonDf.select(ST_S2CellIDs("geom", 6)) // String / Int overload on the polygon itself
+ val envelopeCellsDf = polygonDf.select(ST_S2CellIDs(ST_Envelope(col("geom")), lit(6))) // Column overload on the polygon's envelope
+ val polygonCells = polygonCellsDf.take(1)(0).getAs[mutable.WrappedArray[Long]](0).toSet
+ val envelopeCells = envelopeCellsDf.take(1)(0).getAs[mutable.WrappedArray[Long]](0).toSet
+ assert(polygonCells.subsetOf(envelopeCells)) // cells covering the polygon must all appear in its envelope's covering
+ }
}
}
diff --git a/sql/src/test/scala/org/apache/sedona/sql/functions/STS2CellIDs.scala b/sql/src/test/scala/org/apache/sedona/sql/functions/STS2CellIDs.scala
new file mode 100644
index 00000000..dffed533
--- /dev/null
+++ b/sql/src/test/scala/org/apache/sedona/sql/functions/STS2CellIDs.scala
@@ -0,0 +1,103 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.sedona.sql.functions
+
+import org.apache.sedona.sql.{GeometrySample, TestBaseScala}
+import org.apache.spark.sql.functions.{col, expr, lit, monotonically_increasing_id}
+import org.scalatest.{GivenWhenThen, Matchers}
+
+import scala.collection.mutable
+
+class STS2CellIDs extends TestBaseScala with Matchers with GeometrySample with GivenWhenThen {
+ import sparkSession.implicits._
+
+ describe("should pass ST_S2CellIDs"){
+ // Spark SQL tests for ST_S2CellIDs: null input, result type, and use of cell IDs as a join key.
+ it("should return null while using ST_S2CellIDs when geometry is null") {
+ val geometryTable = sparkSession.sparkContext.parallelize(1 to 10).toDF()
+ .withColumn("geom", lit(null))
+
+ When("using ST_S2CellIDs on null geometries")
+ val geometryTableWithCellIDs = geometryTable
+ .withColumn("cell_ids", expr("ST_S2CellIDs(geom, 4)"))
+ .select("cell_ids").collect().filter(
+ r => r.get(0)!= null
+ )
+
+ Then("no exception should be raised and every result should be null")
+ require(geometryTableWithCellIDs.isEmpty)
+ }
+
+ it("should correctly return array of cell ids use of ST_S2CellIDs"){
+ Given("DataFrame with valid Geometries")
+ val geometryTable = Seq(
+ "POINT(1 2)",
+ "LINESTRING(-5 8, -6 1, -8 6, -2 5, -6 1, -5 8)",
+ "POLYGON ((75 29, 77 29, 77 29, 75 29))"
+
+ ).map(geom => Tuple1(wktReader.read(geom))).toDF("geom")
+
+ When("generate ST_S2CellIDs from those geometries")
+ val geometryDfWithCellIDs = geometryTable
+ .withColumn("cell_ids", expr("ST_S2CellIDs(geom, 5)"))
+
+ Then("every geometry should have a list of Long type cell ids returned")
+ geometryDfWithCellIDs.select("cell_ids").collect().foreach( // NOTE(review): r.size is the Row's column count, not the array length; the Long element type is erased at runtime
+ r => require(r.get(0).isInstanceOf[mutable.WrappedArray[Long]] && r.size > 0)
+ )
+ }
+
+ it("use ST_S2CellIDs for spatial join") {
+ Given("DataFrame with valid polygons")
+ val polygonDf = sparkSession.read.format("csv").option("delimiter", "\t").option("header", "false").load(geojsonInputLocation)
+ .select(expr("ST_GeomFromGeoJSON(_c0)").as("countyshape"))
+ .select(
+ monotonically_increasing_id.as("id"),
+ col("countyshape").as("geom")
+ ).limit(100)
+ val rightPolygons = polygonDf.filter("id > 50")
+ rightPolygons.createOrReplaceTempView("rights")
+ // the remaining polygons (id <= 50) form the left side of the join
+ val leftPolygons = polygonDf.filter("id <= 50")
+ leftPolygons.createOrReplaceTempView("lefts")
+ When("generate the cellIds for both set of polygons, and explode into separate rows, join them by cellIds")
+ val joinedDf = sparkSession.sql(
+ """
+ |with lcs as (
+ | select id, geom, explode(ST_S2CellIDs(geom, 15)) as cellId from lefts
+ |)
+ |, rcs as (
+ | select id, geom, explode(ST_S2CellIDs(geom, 15)) as cellId from rights
+ |)
+ |select sum(if(ST_Intersects(lcs.geom, rcs.geom), 1, 0)) count_true_positive, count(1) count_positive from lcs join rcs on lcs.cellId = rcs.cellId
+ |""".stripMargin
+ )
+ Then("the geoms joined by cell ids should all really intersect in this case. " +
+ "Note that, cellIds equal doesn't necessarily mean the geoms intersect. " +
+ "If a coordinate falls on the border of 2 cells, S2 covers it with both cells. Use ST_Intersects to filter out false positives")
+ val res = joinedDf.collect()(0)
+ require(
+ res.get(1).asInstanceOf[Long] == 48
+ )
+ require(
+ res.get(0) == res.get(1)
+ )
+ }
+ }
+}