You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@sedona.apache.org by zo...@apache.org on 2023/04/18 16:15:32 UTC
[sedona] 01/01: move expressions to inferred typed, and implementation to commons
This is an automated email from the ASF dual-hosted git repository.
zongsizhang pushed a commit to branch fix/move-sql-funcitons-implementation-to-commons
in repository https://gitbox.apache.org/repos/asf/sedona.git
commit 80c4bf69856d7129baf58a387cbf598a49bc2860
Author: zongsi.zhang <kr...@gmail.com>
AuthorDate: Wed Apr 19 00:14:25 2023 +0800
move expressions to inferred typed, and implementation to commons
---
.../org/apache/sedona/common/Constructors.java | 66 +++++++
.../java/org/apache/sedona/common/Functions.java | 101 +++++++++-
.../java/org/apache/sedona/common/Predicates.java | 52 ++++++
.../apache/sedona/common/utils/GeoHashDecoder.java | 84 +++++++++
.../scala/org/apache/sedona/sql/UDF/Catalog.scala | 2 +-
.../sql/sedona_sql/expressions/Constructors.scala | 194 ++-----------------
.../sql/sedona_sql/expressions/Functions.scala | 207 ++-------------------
.../expressions/NullSafeExpressions.scala | 17 +-
.../sql/sedona_sql/expressions/Predicates.scala | 32 ++--
.../sql/sedona_sql/expressions/implicits.scala | 14 ++
.../sedona_sql/expressions/st_constructors.scala | 4 +
.../sql/sedona_sql/expressions/st_functions.scala | 5 +-
12 files changed, 381 insertions(+), 397 deletions(-)
diff --git a/common/src/main/java/org/apache/sedona/common/Constructors.java b/common/src/main/java/org/apache/sedona/common/Constructors.java
index 4cc5c48e..793738de 100644
--- a/common/src/main/java/org/apache/sedona/common/Constructors.java
+++ b/common/src/main/java/org/apache/sedona/common/Constructors.java
@@ -16,15 +16,25 @@ package org.apache.sedona.common;
import org.apache.sedona.common.enums.FileDataSplitter;
import org.apache.sedona.common.enums.GeometryType;
import org.apache.sedona.common.utils.FormatUtils;
+import org.apache.sedona.common.utils.GeoHashDecoder;
import org.locationtech.jts.geom.Coordinate;
import org.locationtech.jts.geom.Geometry;
import org.locationtech.jts.geom.GeometryFactory;
import org.locationtech.jts.geom.PrecisionModel;
import org.locationtech.jts.io.ParseException;
+import org.locationtech.jts.io.WKBReader;
import org.locationtech.jts.io.WKTReader;
+import org.locationtech.jts.io.gml2.GMLReader;
+import org.locationtech.jts.io.kml.KMLReader;
+import org.xml.sax.SAXException;
+
+import javax.xml.parsers.ParserConfigurationException;
+import java.io.IOException;
public class Constructors {
+ private static final GeometryFactory GEOMETRY_FACTORY = new GeometryFactory();
+
public static Geometry geomFromWKT(String wkt, int srid) throws ParseException {
if (wkt == null) {
return null;
@@ -33,6 +43,10 @@ public class Constructors {
return new WKTReader(geometryFactory).read(wkt);
}
+    /**
+     * Reads a geometry from its Well-Known Binary encoding.
+     *
+     * @param wkb the WKB-encoded bytes
+     * @return the geometry produced by a fresh {@link WKBReader}
+     * @throws ParseException propagated from the reader
+     */
+    public static Geometry geomFromWKB(byte[] wkb) throws ParseException {
+        WKBReader reader = new WKBReader();
+        return reader.read(wkb);
+    }
+
public static Geometry mLineFromText(String wkt, int srid) throws ParseException {
if (wkt == null || !wkt.startsWith("MULTILINESTRING")) {
return null;
@@ -100,4 +114,56 @@ public class Constructors {
throw new RuntimeException(e);
}
}
+
+    /**
+     * Builds a point from text by delegating to {@code geomFromText} with
+     * {@link GeometryType#POINT}.
+     *
+     * @param geomString the text holding the coordinates
+     * @param geomFormat the format/delimiter name, forwarded unchanged
+     * @return whatever {@code geomFromText} returns for a POINT request
+     */
+    public static Geometry pointFromText(String geomString, String geomFormat) {
+        final GeometryType targetType = GeometryType.POINT;
+        return geomFromText(geomString, geomFormat, targetType);
+    }
+
+    /**
+     * Builds a polygon from text by delegating to {@code geomFromText} with
+     * {@link GeometryType#POLYGON}.
+     *
+     * @param geomString the text holding the coordinates
+     * @param geomFormat the format/delimiter name, forwarded unchanged
+     * @return whatever {@code geomFromText} returns for a POLYGON request
+     */
+    public static Geometry polygonFromText(String geomString, String geomFormat) {
+        final GeometryType targetType = GeometryType.POLYGON;
+        return geomFromText(geomString, geomFormat, targetType);
+    }
+
+    /**
+     * Builds a linestring from text by delegating to {@code geomFromText} with
+     * {@link GeometryType#LINESTRING}.
+     *
+     * @param geomString the text holding the coordinates
+     * @param geomFormat the format/delimiter name, forwarded unchanged
+     * @return whatever {@code geomFromText} returns for a LINESTRING request
+     */
+    public static Geometry lineStringFromText(String geomString, String geomFormat) {
+        final GeometryType targetType = GeometryType.LINESTRING;
+        return geomFromText(geomString, geomFormat, targetType);
+    }
+
+    /**
+     * Parses {@code geomString} as WKT and keeps the result only when it is a
+     * line geometry.
+     *
+     * @param geomString the WKT text to parse
+     * @return the parsed geometry when its type name contains "LineString"
+     *         (which also matches multi-linestrings); {@code null} otherwise
+     */
+    public static Geometry lineFromText(String geomString) {
+        Geometry parsed = Constructors.geomFromText(geomString, FileDataSplitter.WKT);
+        boolean isLineGeometry = parsed.getGeometryType().contains("LineString");
+        return isLineGeometry ? parsed : null;
+    }
+
+    /**
+     * Creates the rectangular polygon spanned by two corner points.
+     *
+     * <p>The ring runs (minX, minY) -> (minX, maxY) -> (maxX, maxY) ->
+     * (maxX, minY) and closes on the starting vertex.
+     *
+     * @param minX lower x bound
+     * @param minY lower y bound
+     * @param maxX upper x bound
+     * @param maxY upper y bound
+     * @return a polygon whose shell is the five-coordinate closed ring above
+     */
+    public static Geometry polygonFromEnvelope(double minX, double minY, double maxX, double maxY) {
+        Coordinate start = new Coordinate(minX, minY);
+        Coordinate[] ring = {
+            start,
+            new Coordinate(minX, maxY),
+            new Coordinate(maxX, maxY),
+            new Coordinate(maxX, minY),
+            start // the closing vertex is the very same instance as the first one
+        };
+        return GEOMETRY_FACTORY.createPolygon(ring);
+    }
+
+    /**
+     * Decodes a geohash into the polygon of its bounding box.
+     *
+     * @param geoHash   the geohash string to decode
+     * @param precision number of leading geohash characters to decode; may be
+     *                  {@code null}, in which case the decoder uses the whole string
+     * @return the decoded geometry, or {@code null} when the decoder rejects the
+     *         input with {@link GeoHashDecoder.InvalidGeoHashException}
+     */
+    public static Geometry geomFromGeoHash(String geoHash, Integer precision) {
+        try {
+            return GeoHashDecoder.decode(geoHash, precision);
+        } catch (GeoHashDecoder.InvalidGeoHashException ignored) {
+            // Deliberate best-effort: an invalid geohash maps to a null geometry
+            // instead of failing the caller.
+            return null;
+        }
+    }
+
+    /**
+     * Parses a GML fragment into a geometry built with the shared factory.
+     *
+     * @param gml the GML text
+     * @return the geometry read by a fresh {@link GMLReader}
+     * @throws IOException                  propagated from the reader
+     * @throws ParserConfigurationException propagated from the reader
+     * @throws SAXException                 propagated from the reader
+     */
+    public static Geometry geomFromGML(String gml) throws IOException, ParserConfigurationException, SAXException {
+        GMLReader reader = new GMLReader();
+        return reader.read(gml, GEOMETRY_FACTORY);
+    }
+
+    /**
+     * Parses a KML fragment into a geometry.
+     *
+     * @param kml the KML text
+     * @return the geometry read by a fresh {@link KMLReader}
+     * @throws ParseException propagated from the reader
+     */
+    public static Geometry geomFromKML(String kml) throws ParseException {
+        KMLReader reader = new KMLReader();
+        return reader.read(kml);
+    }
+
+
}
diff --git a/common/src/main/java/org/apache/sedona/common/Functions.java b/common/src/main/java/org/apache/sedona/common/Functions.java
index 2c90d06b..cde62dcd 100644
--- a/common/src/main/java/org/apache/sedona/common/Functions.java
+++ b/common/src/main/java/org/apache/sedona/common/Functions.java
@@ -14,11 +14,8 @@
package org.apache.sedona.common;
import com.google.common.geometry.S2CellId;
-import com.google.common.geometry.S2Point;
-import com.google.common.geometry.S2Region;
-import com.google.common.geometry.S2RegionCoverer;
-import org.apache.commons.lang3.ArrayUtils;
import org.apache.sedona.common.geometryObjects.Circle;
+import org.apache.sedona.common.subDivide.GeometrySubDivider;
import org.apache.sedona.common.utils.GeomUtils;
import org.apache.sedona.common.utils.GeometryGeoHashEncoder;
import org.apache.sedona.common.utils.GeometrySplitter;
@@ -37,6 +34,7 @@ import org.locationtech.jts.operation.linemerge.LineMerger;
import org.locationtech.jts.operation.valid.IsSimpleOp;
import org.locationtech.jts.operation.valid.IsValidOp;
import org.locationtech.jts.precision.GeometryPrecisionReducer;
+import org.locationtech.jts.simplify.TopologyPreservingSimplifier;
import org.opengis.referencing.FactoryException;
import org.opengis.referencing.NoSuchAuthorityCodeException;
import org.opengis.referencing.crs.CoordinateReferenceSystem;
@@ -48,7 +46,6 @@ import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
-import java.util.function.Function;
import java.util.stream.Collectors;
@@ -575,4 +572,98 @@ public class Functions {
}
return S2Utils.roundCellsToSameLevel(new ArrayList<>(cellIds), level).stream().map(S2CellId::id).collect(Collectors.toList()).toArray(new Long[cellIds.size()]);
}
+
+
+    /**
+     * Simplifies a geometry with JTS's {@link TopologyPreservingSimplifier}.
+     *
+     * @param geometry          the geometry to simplify
+     * @param distanceTolerance the distance tolerance handed to the simplifier
+     * @return the simplified geometry
+     */
+    public static Geometry simplifyPreserveTopology(Geometry geometry, double distanceTolerance) {
+        Geometry simplified = TopologyPreservingSimplifier.simplify(geometry, distanceTolerance);
+        return simplified;
+    }
+
+    /**
+     * Returns the geometry's type name with an {@code ST_} prefix.
+     *
+     * @param geometry the geometry to inspect
+     * @return {@code "ST_"} followed by {@code geometry.getGeometryType()}
+     */
+    public static String geometryType(Geometry geometry) {
+        String typeName = geometry.getGeometryType();
+        return "ST_" + typeName;
+    }
+
+    /**
+     * Returns the first point of a linestring.
+     *
+     * @param geometry the geometry to inspect
+     * @return {@code getStartPoint()} of the linestring, or {@code null} when
+     *         {@code geometry} is not a {@link LineString}
+     */
+    public static Geometry startPoint(Geometry geometry) {
+        if (!(geometry instanceof LineString)) {
+            return null;
+        }
+        return ((LineString) geometry).getStartPoint();
+    }
+
+    /**
+     * Returns the last point of a linestring.
+     *
+     * @param geometry the geometry to inspect
+     * @return {@code getEndPoint()} of the linestring, or {@code null} when
+     *         {@code geometry} is not a {@link LineString}
+     */
+    public static Geometry endPoint(Geometry geometry) {
+        if (!(geometry instanceof LineString)) {
+            return null;
+        }
+        return ((LineString) geometry).getEndPoint();
+    }
+
+    /**
+     * Splits a geometry into its direct member geometries.
+     *
+     * @param geometry the geometry to split
+     * @return for a {@link GeometryCollection}, an array holding
+     *         {@code getGeometryN(i)} for every member; for any other geometry,
+     *         a one-element array holding the geometry itself
+     */
+    public static Geometry[] dump(Geometry geometry) {
+        // Read the count up front, exactly as before the collection check.
+        int memberCount = geometry.getNumGeometries();
+        if (!(geometry instanceof GeometryCollection)) {
+            return new Geometry[] {geometry};
+        }
+        Geometry[] members = new Geometry[geometry.getNumGeometries()];
+        if (memberCount == members.length) {
+            Arrays.setAll(members, geometry::getGeometryN);
+        } else {
+            // Defensive mirror of the original loop bound should the two reads ever disagree.
+            for (int idx = 0; idx < memberCount; idx++) {
+                members[idx] = geometry.getGeometryN(idx);
+            }
+        }
+        return members;
+    }
+
+    /**
+     * Turns every coordinate of a geometry into its own point.
+     *
+     * @param geometry the geometry whose coordinates are dumped
+     * @return one point per entry of {@code geometry.getCoordinates()}, in the
+     *         same order
+     */
+    public static Geometry[] dumpPoints(Geometry geometry) {
+        return Arrays.stream(geometry.getCoordinates())
+                .map(coordinate -> GEOMETRY_FACTORY.createPoint(coordinate))
+                .toArray(Point[]::new);
+    }
+
+    /**
+     * Computes the symmetric difference of two geometries.
+     *
+     * @param leftGeom  the geometry the operation is invoked on
+     * @param rightGeom the geometry passed as argument
+     * @return {@code leftGeom.symDifference(rightGeom)}
+     */
+    public static Geometry symDifference(Geometry leftGeom, Geometry rightGeom) {
+        Geometry difference = leftGeom.symDifference(rightGeom);
+        return difference;
+    }
+
+    /**
+     * Computes the union of two geometries.
+     *
+     * @param leftGeom  the geometry the operation is invoked on
+     * @param rightGeom the geometry passed as argument
+     * @return {@code leftGeom.union(rightGeom)}
+     */
+    public static Geometry union(Geometry leftGeom, Geometry rightGeom) {
+        Geometry merged = leftGeom.union(rightGeom);
+        return merged;
+    }
+
+    /**
+     * Wraps a single geometry in the matching multi-geometry type.
+     *
+     * <p>The type tests run in a fixed order: Circle, GeometryCollection,
+     * LineString, Point, Polygon. NOTE(review): Circle is tested first,
+     * presumably so it is not caught by a later branch -- confirm against
+     * Circle's superclass.
+     *
+     * @param geometry the geometry to wrap
+     * @return a geometry collection holding the circle; the input itself when
+     *         it already is a collection; a multi-linestring, multi-point or
+     *         multi-polygon for the respective single type; otherwise an empty
+     *         geometry collection
+     */
+    public static Geometry createMultiGeometryFromOneElement(Geometry geometry) {
+        if (geometry instanceof Circle) {
+            Circle[] circles = {(Circle) geometry};
+            return GEOMETRY_FACTORY.createGeometryCollection(circles);
+        }
+        if (geometry instanceof GeometryCollection) {
+            return geometry;
+        }
+        if (geometry instanceof LineString) {
+            LineString[] lines = {(LineString) geometry};
+            return GEOMETRY_FACTORY.createMultiLineString(lines);
+        }
+        if (geometry instanceof Point) {
+            Point[] points = {(Point) geometry};
+            return GEOMETRY_FACTORY.createMultiPoint(points);
+        }
+        if (geometry instanceof Polygon) {
+            Polygon[] polygons = {(Polygon) geometry};
+            return GEOMETRY_FACTORY.createMultiPolygon(polygons);
+        }
+        return GEOMETRY_FACTORY.createGeometryCollection();
+    }
+
+    /**
+     * Subdivides a geometry by delegating to {@link GeometrySubDivider}.
+     *
+     * @param geometry    the geometry to subdivide
+     * @param maxVertices the vertex limit forwarded to the subdivider
+     * @return the pieces returned by {@code GeometrySubDivider.subDivide}
+     */
+    public static Geometry[] subDivide(Geometry geometry, int maxVertices) {
+        Geometry[] pieces = GeometrySubDivider.subDivide(geometry, maxVertices);
+        return pieces;
+    }
+
+    /**
+     * Builds a polygon from a shell and optional holes.
+     *
+     * <p>A hole is used only when it is non-null, non-empty, a
+     * {@link LineString} and closed; every other entry is silently dropped.
+     * When no usable hole remains (or {@code holes} is {@code null}) the
+     * polygon consists of the shell alone.
+     *
+     * @param shell the geometry whose coordinates form the exterior ring
+     * @param holes candidate interior rings; may be {@code null}
+     * @return the polygon created by the shared geometry factory
+     */
+    public static Geometry makePolygon(Geometry shell, Geometry[] holes) {
+        // Convert the usable holes exactly once; the result serves both the
+        // emptiness check and the polygon construction.
+        LinearRing[] interiorRings = holes == null
+                ? new LinearRing[0]
+                : Arrays.stream(holes)
+                        .filter(h -> h != null && !h.isEmpty() && h instanceof LineString && ((LineString) h).isClosed())
+                        .map(h -> GEOMETRY_FACTORY.createLinearRing(h.getCoordinates()))
+                        .toArray(LinearRing[]::new);
+
+        // The shell ring is created after the holes, as before.
+        LinearRing exteriorRing = GEOMETRY_FACTORY.createLinearRing(shell.getCoordinates());
+        if (interiorRings.length == 0) {
+            return GEOMETRY_FACTORY.createPolygon(exteriorRing);
+        }
+        return GEOMETRY_FACTORY.createPolygon(exteriorRing, interiorRings);
+    }
}
diff --git a/common/src/main/java/org/apache/sedona/common/Predicates.java b/common/src/main/java/org/apache/sedona/common/Predicates.java
new file mode 100644
index 00000000..cb37b5c1
--- /dev/null
+++ b/common/src/main/java/org/apache/sedona/common/Predicates.java
@@ -0,0 +1,52 @@
+/**
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ * <p>
+ * http://www.apache.org/licenses/LICENSE-2.0
+ * <p>
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.sedona.common;
+
+import org.locationtech.jts.geom.Geometry;
+
+/**
+ * Static two-geometry predicates. Unless noted otherwise each method forwards
+ * to the JTS {@link Geometry} method of the same name, invoked on
+ * {@code leftGeometry} with {@code rightGeometry} as argument.
+ */
+public class Predicates {
+
+    /** @return {@code leftGeometry.contains(rightGeometry)} */
+    public static boolean contains(Geometry leftGeometry, Geometry rightGeometry) {
+        boolean result = leftGeometry.contains(rightGeometry);
+        return result;
+    }
+
+    /** @return {@code leftGeometry.intersects(rightGeometry)} */
+    public static boolean intersects(Geometry leftGeometry, Geometry rightGeometry) {
+        boolean result = leftGeometry.intersects(rightGeometry);
+        return result;
+    }
+
+    /** @return {@code leftGeometry.within(rightGeometry)} */
+    public static boolean within(Geometry leftGeometry, Geometry rightGeometry) {
+        boolean result = leftGeometry.within(rightGeometry);
+        return result;
+    }
+
+    /** @return {@code leftGeometry.covers(rightGeometry)} */
+    public static boolean covers(Geometry leftGeometry, Geometry rightGeometry) {
+        boolean result = leftGeometry.covers(rightGeometry);
+        return result;
+    }
+
+    /** @return {@code leftGeometry.coveredBy(rightGeometry)} */
+    public static boolean coveredBy(Geometry leftGeometry, Geometry rightGeometry) {
+        boolean result = leftGeometry.coveredBy(rightGeometry);
+        return result;
+    }
+
+    /** @return {@code leftGeometry.crosses(rightGeometry)} */
+    public static boolean crosses(Geometry leftGeometry, Geometry rightGeometry) {
+        boolean result = leftGeometry.crosses(rightGeometry);
+        return result;
+    }
+
+    /** @return {@code leftGeometry.overlaps(rightGeometry)} */
+    public static boolean overlaps(Geometry leftGeometry, Geometry rightGeometry) {
+        boolean result = leftGeometry.overlaps(rightGeometry);
+        return result;
+    }
+
+    /** @return {@code leftGeometry.touches(rightGeometry)} */
+    public static boolean touches(Geometry leftGeometry, Geometry rightGeometry) {
+        boolean result = leftGeometry.touches(rightGeometry);
+        return result;
+    }
+
+    /**
+     * Tests equality through the symmetric difference rather than a JTS
+     * {@code equals} call.
+     *
+     * @return {@code true} when the symmetric difference of the two geometries is empty
+     */
+    public static boolean equals(Geometry leftGeometry, Geometry rightGeometry) {
+        Geometry difference = leftGeometry.symDifference(rightGeometry);
+        return difference.isEmpty();
+    }
+
+    /** @return {@code leftGeometry.disjoint(rightGeometry)} */
+    public static boolean disjoint(Geometry leftGeometry, Geometry rightGeometry) {
+        boolean result = leftGeometry.disjoint(rightGeometry);
+        return result;
+    }
+
+    /** @return {@code leftGeometry.equalsExact(rightGeometry)} */
+    public static boolean orderingEquals(Geometry leftGeometry, Geometry rightGeometry) {
+        boolean result = leftGeometry.equalsExact(rightGeometry);
+        return result;
+    }
+}
diff --git a/common/src/main/java/org/apache/sedona/common/utils/GeoHashDecoder.java b/common/src/main/java/org/apache/sedona/common/utils/GeoHashDecoder.java
new file mode 100644
index 00000000..dd1bba1c
--- /dev/null
+++ b/common/src/main/java/org/apache/sedona/common/utils/GeoHashDecoder.java
@@ -0,0 +1,84 @@
+/**
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ * <p>
+ * http://www.apache.org/licenses/LICENSE-2.0
+ * <p>
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.sedona.common.utils;
+
+import org.locationtech.jts.geom.Geometry;
+
+/**
+ * Decodes a geohash string into the polygon of the cell it denotes.
+ */
+public class GeoHashDecoder {
+    /** Masks for the five bits of one geohash character, highest bit first. */
+    private static final int[] BITS = {16, 8, 4, 2, 1};
+
+    /** Geohash alphabet; the index of a character is its 5-bit value. */
+    private static final String BASE32 = "0123456789bcdefghjkmnpqrstuvwxyz";
+
+    /** Signals a geohash or precision argument that cannot be decoded. */
+    public static class InvalidGeoHashException extends Exception {
+        public InvalidGeoHashException(String message) {
+            super(message);
+        }
+    }
+
+    /**
+     * Decodes a geohash into the polygon of its bounding box.
+     *
+     * @param geohash   the geohash to decode; must not be {@code null}
+     * @param precision number of leading characters to decode; {@code null}
+     *                  means the whole string, larger values are capped at the
+     *                  string length
+     * @return the bounding-box polygon of the decoded cell
+     * @throws InvalidGeoHashException if {@code precision} is negative or a
+     *                                 decoded character is outside the geohash alphabet
+     */
+    public static Geometry decode(String geohash, Integer precision) throws InvalidGeoHashException {
+        return decodeGeoHashBBox(geohash, precision).getBbox().toPolygon();
+    }
+
+    /** Mutable [lower, upper] longitude and latitude ranges narrowed while decoding. */
+    private static class LatLon {
+        // Primitive arrays: the bounds are rewritten once per decoded bit, so
+        // boxed Double[] would allocate on every step.
+        private final double[] lons;
+
+        private final double[] lats;
+
+        LatLon(double[] lons, double[] lats) {
+            this.lons = lons;
+            this.lats = lats;
+        }
+
+        BBox getBbox() {
+            return new BBox(
+                    lons[0],
+                    lons[1],
+                    lats[0],
+                    lats[1]
+            );
+        }
+    }
+
+    /**
+     * Narrows the whole-world range one bit at a time; bits alternate between
+     * longitude and latitude, starting with longitude.
+     */
+    private static LatLon decodeGeoHashBBox(String geohash, Integer precision) throws InvalidGeoHashException {
+        LatLon latLon = new LatLon(new double[] {-180.0, 180.0}, new double[] {-90.0, 90.0});
+        // Locale.ROOT keeps the case mapping independent of the JVM's default locale.
+        String geoHashLowered = geohash.toLowerCase(java.util.Locale.ROOT);
+        int geoHashLength = geohash.length();
+        int targetPrecision = geoHashLength;
+        if (precision != null) {
+            if (precision < 0) {
+                throw new InvalidGeoHashException("Precision can not be negative");
+            }
+            targetPrecision = Math.min(geoHashLength, precision);
+        }
+        boolean isEven = true;
+
+        for (int i = 0; i < targetPrecision; i++) {
+            char c = geoHashLowered.charAt(i);
+            int cd = BASE32.indexOf(c);
+            if (cd == -1) {
+                throw new InvalidGeoHashException(String.format("Invalid character '%s' found at index %d", c, i));
+            }
+            for (int mask : BITS) {
+                // A set bit keeps the upper half (moves the lower bound up);
+                // a clear bit keeps the lower half (moves the upper bound down).
+                int index = (mask & cd) == 0 ? 1 : 0;
+                double[] range = isEven ? latLon.lons : latLon.lats;
+                range[index] = (range[0] + range[1]) / 2;
+                isEven = !isEven;
+            }
+        }
+        return latLon;
+    }
+
+}
diff --git a/sql/src/main/scala/org/apache/sedona/sql/UDF/Catalog.scala b/sql/src/main/scala/org/apache/sedona/sql/UDF/Catalog.scala
index 03dbbab5..fa3493ab 100644
--- a/sql/src/main/scala/org/apache/sedona/sql/UDF/Catalog.scala
+++ b/sql/src/main/scala/org/apache/sedona/sql/UDF/Catalog.scala
@@ -117,7 +117,7 @@ object Catalog {
function[ST_LineInterpolatePoint](),
function[ST_SubDivideExplode](),
function[ST_SubDivide](),
- function[ST_MakePolygon](),
+ function[ST_MakePolygon](null),
function[ST_GeoHash](),
function[ST_GeomFromGeoHash](null),
function[ST_Collect](),
diff --git a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Constructors.scala b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Constructors.scala
index 188ab906..fff08228 100644
--- a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Constructors.scala
+++ b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Constructors.scala
@@ -19,20 +19,15 @@
package org.apache.spark.sql.sedona_sql.expressions
import org.apache.sedona.common.Constructors
-import org.apache.sedona.common.enums.{FileDataSplitter, GeometryType}
+import org.apache.sedona.common.enums.FileDataSplitter
import org.apache.sedona.sql.utils.GeometrySerializer
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{Expression, ImplicitCastInputTypes}
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
+import org.apache.spark.sql.catalyst.expressions.{Expression, ImplicitCastInputTypes}
import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT
-import org.apache.spark.sql.sedona_sql.expressions.geohash.GeoHashDecoder
import org.apache.spark.sql.sedona_sql.expressions.implicits.GeometryEnhancer
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
-import org.locationtech.jts.geom.{Coordinate, GeometryFactory}
-import org.locationtech.jts.io.WKBReader
-import org.locationtech.jts.io.gml2.GMLReader
-import org.locationtech.jts.io.kml.KMLReader
/**
* Return a point from a string. The string must be plain string and each coordinate must be separated by a delimiter.
@@ -41,25 +36,7 @@ import org.locationtech.jts.io.kml.KMLReader
* string, the second parameter is the delimiter. String format should be similar to CSV/TSV
*/
case class ST_PointFromText(inputExpressions: Seq[Expression])
- extends Expression with FoldableExpression with ImplicitCastInputTypes with CodegenFallback with UserDataGeneratator {
- // This is an expression which takes two input expressions.
- assert(inputExpressions.length == 2)
-
- override def nullable: Boolean = false
-
- override def eval(inputRow: InternalRow): Any = {
- val geomString = inputExpressions(0).eval(inputRow).asInstanceOf[UTF8String].toString
- val geomFormat = inputExpressions(1).eval(inputRow).asInstanceOf[UTF8String].toString
- val geometry = Constructors.geomFromText(geomString, geomFormat, GeometryType.POINT)
- GeometrySerializer.serialize(geometry)
- }
-
- override def dataType: DataType = GeometryUDT
-
- override def inputTypes: Seq[AbstractDataType] = Seq(StringType, StringType)
-
- override def children: Seq[Expression] = inputExpressions
-
+ extends InferredBinaryExpression(Constructors.pointFromText) with FoldableExpression {
protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = {
copy(inputExpressions = newChildren)
}
@@ -71,26 +48,7 @@ case class ST_PointFromText(inputExpressions: Seq[Expression])
* @param inputExpressions
*/
case class ST_PolygonFromText(inputExpressions: Seq[Expression])
- extends Expression with FoldableExpression with ImplicitCastInputTypes with CodegenFallback with UserDataGeneratator {
- // This is an expression which takes two input expressions.
- assert(inputExpressions.length == 2)
-
- override def nullable: Boolean = false
-
- override def eval(inputRow: InternalRow): Any = {
- val geomString = inputExpressions(0).eval(inputRow).asInstanceOf[UTF8String].toString
- val geomFormat = inputExpressions(1).eval(inputRow).asInstanceOf[UTF8String].toString
-
- var geometry = Constructors.geomFromText(geomString, geomFormat, GeometryType.POLYGON)
- GeometrySerializer.serialize(geometry)
- }
-
- override def dataType: DataType = GeometryUDT
-
- override def inputTypes: Seq[AbstractDataType] = Seq(StringType, StringType)
-
- override def children: Seq[Expression] = inputExpressions
-
+ extends InferredBinaryExpression(Constructors.polygonFromText) with FoldableExpression {
protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = {
copy(inputExpressions = newChildren)
}
@@ -102,68 +60,23 @@ case class ST_PolygonFromText(inputExpressions: Seq[Expression])
* @param inputExpressions
*/
case class ST_LineFromText(inputExpressions: Seq[Expression])
- extends Expression with FoldableExpression with ImplicitCastInputTypes with CodegenFallback with UserDataGeneratator {
- // This is an expression which takes one input expressions.
- assert(inputExpressions.length == 1)
-
- override def nullable: Boolean = true
-
- override def eval(inputRow: InternalRow): Any = {
- val lineString = inputExpressions(0).eval(inputRow).asInstanceOf[UTF8String].toString
-
- val fileDataSplitter = FileDataSplitter.WKT
- val geometry = Constructors.geomFromText(lineString, fileDataSplitter)
- if(geometry.getGeometryType.contains("LineString")) {
- GeometrySerializer.serialize(geometry)
- } else {
- null
- }
- }
-
- override def dataType: DataType = GeometryUDT
-
- override def inputTypes: Seq[AbstractDataType] = Seq(StringType)
-
- override def children: Seq[Expression] = inputExpressions
-
+ extends InferredUnaryExpression(Constructors.lineFromText) with FoldableExpression {
protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = {
copy(inputExpressions = newChildren)
}
}
-
/**
* Return a linestring from a string. The string must be plain string and each coordinate must be separated by a delimiter.
*
* @param inputExpressions
*/
case class ST_LineStringFromText(inputExpressions: Seq[Expression])
- extends Expression with FoldableExpression with ImplicitCastInputTypes with CodegenFallback with UserDataGeneratator {
- // This is an expression which takes two input expressions.
- assert(inputExpressions.length == 2)
-
- override def nullable: Boolean = false
-
- override def eval(inputRow: InternalRow): Any = {
- val geomString = inputExpressions(0).eval(inputRow).asInstanceOf[UTF8String].toString
- val geomFormat = inputExpressions(1).eval(inputRow).asInstanceOf[UTF8String].toString
-
- val geometry = Constructors.geomFromText(geomString, geomFormat, GeometryType.LINESTRING)
-
- GeometrySerializer.serialize(geometry)
- }
-
- override def dataType: DataType = GeometryUDT
-
- override def inputTypes: Seq[AbstractDataType] = Seq(StringType, StringType)
-
- override def children: Seq[Expression] = inputExpressions
-
+ extends InferredBinaryExpression(Constructors.lineStringFromText) with FoldableExpression {
protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = {
copy(inputExpressions = newChildren)
}
}
-
/**
* Return a Geometry from a WKT string
*
@@ -212,7 +125,7 @@ case class ST_GeomFromWKB(inputExpressions: Seq[Expression])
}
case (wkb: Array[Byte]) => {
// convert raw wkb byte array to geometry
- new WKBReader().read(wkb).toGenericArrayData
+ Constructors.geomFromWKB(wkb).toGenericArrayData
}
case null => null
}
@@ -294,32 +207,7 @@ case class ST_PointZ(inputExpressions: Seq[Expression])
* @param inputExpressions
*/
case class ST_PolygonFromEnvelope(inputExpressions: Seq[Expression])
- extends Expression with FoldableExpression with ImplicitCastInputTypes with CodegenFallback with UserDataGeneratator {
- assert(inputExpressions.length == 4)
-
- override def nullable: Boolean = false
-
- override def eval(input: InternalRow): Any = {
- val minX = inputExpressions(0).eval(input).asInstanceOf[Double]
- val minY = inputExpressions(1).eval(input).asInstanceOf[Double]
- val maxX = inputExpressions(2).eval(input).asInstanceOf[Double]
- val maxY = inputExpressions(3).eval(input).asInstanceOf[Double]
- var coordinates = new Array[Coordinate](5)
- coordinates(0) = new Coordinate(minX, minY)
- coordinates(1) = new Coordinate(minX, maxY)
- coordinates(2) = new Coordinate(maxX, maxY)
- coordinates(3) = new Coordinate(maxX, minY)
- coordinates(4) = coordinates(0)
- val geometryFactory = new GeometryFactory()
- val polygon = geometryFactory.createPolygon(coordinates)
- GeometrySerializer.serialize(polygon)
- }
-
- override def dataType: DataType = GeometryUDT
-
- override def inputTypes: Seq[AbstractDataType] = Seq(DoubleType, DoubleType, DoubleType, DoubleType)
-
- override def children: Seq[Expression] = inputExpressions
+ extends InferredQuarternaryExpression(Constructors.polygonFromEnvelope) with FoldableExpression {
protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = {
copy(inputExpressions = newChildren)
@@ -337,81 +225,23 @@ trait UserDataGeneratator {
}
}
-
case class ST_GeomFromGeoHash(inputExpressions: Seq[Expression])
- extends Expression with FoldableExpression with ImplicitCastInputTypes with CodegenFallback {
- override def nullable: Boolean = true
-
- override def eval(input: InternalRow): Any = {
- val geoHash = Option(inputExpressions.head.eval(input))
- .map(_.asInstanceOf[UTF8String].toString)
- val precision = Option(inputExpressions(1).eval(input)).map(_.asInstanceOf[Int])
-
- try {
- geoHash match {
- case Some(value) => GeoHashDecoder.decode(value, precision).toGenericArrayData
- case None => null
- }
- }
- catch {
- case e: Exception => null
- }
- }
-
- override def dataType: DataType = GeometryUDT
-
- override def inputTypes: Seq[AbstractDataType] = Seq(StringType, IntegerType)
-
- override def children: Seq[Expression] = inputExpressions
-
+ extends InferredBinaryExpression(Constructors.geomFromGeoHash) with FoldableExpression {
protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = {
copy(inputExpressions = newChildren)
}
+ override def allowRightNull: Boolean = true
}
case class ST_GeomFromGML(inputExpressions: Seq[Expression])
- extends Expression with FoldableExpression with ImplicitCastInputTypes with CodegenFallback {
- assert(inputExpressions.length == 1)
- override def nullable: Boolean = true
-
- override def eval(inputRow: InternalRow): Any = {
- (inputExpressions(0).eval(inputRow)) match {
- case geomString: UTF8String =>
- new GMLReader().read(geomString.toString, new GeometryFactory()).toGenericArrayData
- case _ => null
- }
- }
-
- override def dataType: DataType = GeometryUDT
-
- override def inputTypes: Seq[AbstractDataType] = Seq(StringType)
-
- override def children: Seq[Expression] = inputExpressions
-
+ extends InferredUnaryExpression(Constructors.geomFromGML) with FoldableExpression {
protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = {
copy(inputExpressions = newChildren)
}
}
case class ST_GeomFromKML(inputExpressions: Seq[Expression])
- extends Expression with FoldableExpression with ImplicitCastInputTypes with CodegenFallback {
- assert(inputExpressions.length == 1)
- override def nullable: Boolean = true
-
- override def eval(inputRow: InternalRow): Any = {
- inputExpressions(0).eval(inputRow) match {
- case geomString: UTF8String =>
- new KMLReader().read(geomString.toString).toGenericArrayData
- case _ => null
- }
- }
-
- override def dataType: DataType = GeometryUDT
-
- override def inputTypes: Seq[AbstractDataType] = Seq(StringType)
-
- override def children: Seq[Expression] = inputExpressions
-
+ extends InferredUnaryExpression(Constructors.geomFromKML) with FoldableExpression {
protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = {
copy(inputExpressions = newChildren)
}
diff --git a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala
index a75b044f..2c56845b 100644
--- a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala
+++ b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala
@@ -19,19 +19,15 @@
package org.apache.spark.sql.sedona_sql.expressions
import org.apache.sedona.common.Functions
-import org.apache.sedona.common.subDivide.GeometrySubDivider
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
-import org.apache.spark.sql.catalyst.expressions.{Expression, Generator, ImplicitCastInputTypes}
+import org.apache.spark.sql.catalyst.expressions.{Expression, Generator}
import org.apache.spark.sql.catalyst.util.ArrayData
import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT
-import org.apache.spark.sql.sedona_sql.expressions.collect.Collect
import org.apache.spark.sql.sedona_sql.expressions.implicits._
import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.types.UTF8String
import org.locationtech.jts.algorithm.MinimumBoundingCircle
-import org.locationtech.jts.geom.{Geometry, _}
-import org.locationtech.jts.simplify.TopologyPreservingSimplifier
+import org.locationtech.jts.geom._
/**
* Return the distance between two geometries.
@@ -221,26 +217,7 @@ case class ST_Centroid(inputExpressions: Seq[Expression])
* @param inputExpressions
*/
case class ST_Transform(inputExpressions: Seq[Expression])
- extends Expression with FoldableExpression with ImplicitCastInputTypes with CodegenFallback {
-
- override def nullable: Boolean = true
-
- override def eval(input: InternalRow): Any = {
- val geometry = inputExpressions(0).toGeometry(input)
- val sourceCRSString = inputExpressions(1).asString(input)
- val targetCRSString = inputExpressions(2).asString(input)
- val lenient = inputExpressions(3).eval(input).asInstanceOf[Boolean]
- (geometry,sourceCRSString,targetCRSString,lenient) match {
- case (null,_,_,_) => null
- case _ => Functions.transform(geometry, sourceCRSString, targetCRSString, lenient).toGenericArrayData
- }
- }
-
- override def inputTypes: Seq[AbstractDataType] = Seq(GeometryUDT, StringType, StringType, BooleanType)
-
- override def dataType: DataType = GeometryUDT
-
- override def children: Seq[Expression] = inputExpressions
+ extends InferredQuarternaryExpression(Functions.transform) with FoldableExpression {
protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = {
copy(inputExpressions = newChildren)
@@ -309,7 +286,7 @@ case class ST_IsSimple(inputExpressions: Seq[Expression])
* second arg is distance tolerance for the simplification(all vertices in the simplified geometry will be within this distance of the original geometry)
*/
case class ST_SimplifyPreserveTopology(inputExpressions: Seq[Expression])
- extends InferredBinaryExpression(TopologyPreservingSimplifier.simplify) with FoldableExpression {
+ extends InferredBinaryExpression(Functions.simplifyPreserveTopology) with FoldableExpression {
protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = {
copy(inputExpressions = newChildren)
@@ -379,15 +356,7 @@ case class ST_SetSRID(inputExpressions: Seq[Expression])
}
case class ST_GeometryType(inputExpressions: Seq[Expression])
- extends UnaryGeometryExpression with FoldableExpression with CodegenFallback {
-
- override protected def nullSafeEval(geometry: Geometry): Any = {
- UTF8String.fromString("ST_" + geometry.getGeometryType)
- }
-
- override def dataType: DataType = StringType
-
- override def children: Seq[Expression] = inputExpressions
+ extends InferredUnaryExpression(Functions.geometryType) with FoldableExpression {
protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = {
copy(inputExpressions = newChildren)
@@ -443,27 +412,13 @@ case class ST_Z(inputExpressions: Seq[Expression])
}
case class ST_StartPoint(inputExpressions: Seq[Expression])
- extends UnaryGeometryExpression with FoldableExpression with CodegenFallback {
-
- override protected def nullSafeEval(geometry: Geometry): Any = {
- geometry match {
- case line: LineString => {
- line.getPointN(0)
- }
- case _ => null
- }
- }
-
- override def dataType: DataType = GeometryUDT
-
- override def children: Seq[Expression] = inputExpressions
+ extends InferredUnaryExpression(Functions.startPoint) with FoldableExpression {
protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = {
copy(inputExpressions = newChildren)
}
}
-
case class ST_Boundary(inputExpressions: Seq[Expression])
extends InferredUnaryExpression(Functions.boundary) with FoldableExpression {
@@ -552,20 +507,8 @@ case class ST_LineInterpolatePoint(inputExpressions: Seq[Expression])
}
}
-
case class ST_EndPoint(inputExpressions: Seq[Expression])
- extends UnaryGeometryExpression with FoldableExpression with CodegenFallback {
-
- override protected def nullSafeEval(geometry: Geometry): Any = {
- geometry match {
- case string: LineString => string.getEndPoint
- case _ => null
- }
- }
-
- override def dataType: DataType = GeometryUDT
-
- override def children: Seq[Expression] = inputExpressions
+ extends InferredUnaryExpression(Functions.endPoint) with FoldableExpression {
protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = {
copy(inputExpressions = newChildren)
@@ -598,32 +541,7 @@ case class ST_InteriorRingN(inputExpressions: Seq[Expression])
}
case class ST_Dump(inputExpressions: Seq[Expression])
- extends UnaryGeometryExpression with FoldableExpression with CodegenFallback {
-
- override protected def nullSafeEval(geometry: Geometry): Any = {
- geometry match {
- case collection: GeometryCollection => {
- val numberOfGeometries = collection.getNumGeometries
- (0 until numberOfGeometries).map(
- index => collection.getGeometryN(index)
- ).toArray
- }
- case geom: Geometry => Array(geom)
- }
- }
-
- override protected def serializeResult(result: Any): Any = {
- result match {
- case array: Array[Geometry] => ArrayData.toArrayData(
- array.map(_.toGenericArrayData)
- )
- case _ => null
- }
- }
-
- override def dataType: DataType = ArrayType(GeometryUDT)
-
- override def children: Seq[Expression] = inputExpressions
+ extends InferredUnaryExpression(Functions.dump) with FoldableExpression {
protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = {
copy(inputExpressions = newChildren)
@@ -631,25 +549,7 @@ case class ST_Dump(inputExpressions: Seq[Expression])
}
case class ST_DumpPoints(inputExpressions: Seq[Expression])
- extends UnaryGeometryExpression with FoldableExpression with CodegenFallback {
-
- override protected def nullSafeEval(geometry: Geometry): Any = {
- geometry.getPoints.map(geom => geom).toArray
- }
-
- override protected def serializeResult(result: Any): Any = {
- result match {
- case array: Array[Geometry] => ArrayData.toArrayData(
- array.map(geom => geom.toGenericArrayData)
- )
- case _ => null
- }
-
- }
-
- override def dataType: DataType = ArrayType(GeometryUDT)
-
- override def children: Seq[Expression] = inputExpressions
+ extends InferredUnaryExpression(Functions.dumpPoints) with FoldableExpression {
protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = {
copy(inputExpressions = newChildren)
@@ -746,24 +646,7 @@ case class ST_FlipCoordinates(inputExpressions: Seq[Expression])
}
case class ST_SubDivide(inputExpressions: Seq[Expression])
- extends Expression with FoldableExpression with ImplicitCastInputTypes with CodegenFallback {
-
- override def nullable: Boolean = true
-
- override def eval(input: InternalRow): Any = {
- inputExpressions(0).toGeometry(input) match {
- case geom: Geometry => ArrayData.toArrayData(
- GeometrySubDivider.subDivide(geom, inputExpressions(1).toInt(input)).map(_.toGenericArrayData)
- )
- case null => null
- }
- }
-
- override def dataType: DataType = ArrayType(GeometryUDT)
-
- override def inputTypes: Seq[AbstractDataType] = Seq(GeometryUDT, IntegerType)
-
- override def children: Seq[Expression] = inputExpressions
+ extends InferredBinaryExpression(Functions.subDivide) with FoldableExpression {
protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = {
copy(inputExpressions = newChildren)
@@ -779,9 +662,9 @@ case class ST_SubDivideExplode(children: Seq[Expression])
val maxVerticesRaw = children(1)
geometryRaw.toGeometry(input) match {
case geom: Geometry => ArrayData.toArrayData(
- GeometrySubDivider.subDivide(geom, maxVerticesRaw.toInt(input)).map(_.toGenericArrayData)
+ Functions.subDivide(geom, maxVerticesRaw.toInt(input)).map(_.toGenericArrayData)
)
- GeometrySubDivider.subDivide(geom, maxVerticesRaw.toInt(input)).map(_.toGenericArrayData).map(InternalRow(_))
+ Functions.subDivide(geom, maxVerticesRaw.toInt(input)).map(_.toGenericArrayData).map(InternalRow(_))
case _ => new Array[InternalRow](0)
}
}
@@ -795,52 +678,14 @@ case class ST_SubDivideExplode(children: Seq[Expression])
}
}
-
case class ST_MakePolygon(inputExpressions: Seq[Expression])
- extends Expression with FoldableExpression with CodegenFallback {
- inputExpressions.betweenLength(1, 2)
-
- override def nullable: Boolean = true
- private val geometryFactory = new GeometryFactory()
-
- override def eval(input: InternalRow): Any = {
- val exteriorRing = inputExpressions.head
- val possibleHolesRaw = inputExpressions.tail.headOption.map(_.eval(input).asInstanceOf[ArrayData])
- val numOfElements = possibleHolesRaw.map(_.numElements()).getOrElse(0)
-
- val holes = (0 until numOfElements).map(el => possibleHolesRaw match {
- case Some(value) => Some(value.getBinary(el))
- case None => None
- }).filter(_.nonEmpty)
- .map(el => el.map(_.toGeometry))
- .flatMap{
- case maybeLine: Option[LineString] =>
- maybeLine.map(line => geometryFactory.createLinearRing(line.getCoordinates))
- case _ => None
- }
-
- exteriorRing.toGeometry(input) match {
- case geom: LineString =>
- try {
- val poly = new Polygon(geometryFactory.createLinearRing(geom.getCoordinates), holes.toArray, geometryFactory)
- poly.toGenericArrayData
- }
- catch {
- case e: Exception => null
- }
-
- case _ => null
- }
-
- }
-
- override def dataType: DataType = GeometryUDT
-
- override def children: Seq[Expression] = inputExpressions
+ extends InferredBinaryExpression(Functions.makePolygon) with FoldableExpression {
protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = {
copy(inputExpressions = newChildren)
}
+
+ override def allowRightNull: Boolean = true
}
case class ST_GeoHash(inputExpressions: Seq[Expression])
@@ -870,15 +715,7 @@ case class ST_Difference(inputExpressions: Seq[Expression])
* @param inputExpressions
*/
case class ST_SymDifference(inputExpressions: Seq[Expression])
- extends BinaryGeometryExpression with FoldableExpression with CodegenFallback {
-
- override protected def nullSafeEval(leftGeometry: Geometry, rightGeometry: Geometry): Any = {
- leftGeometry.symDifference(rightGeometry)
- }
-
- override def dataType: DataType = GeometryUDT
-
- override def children: Seq[Expression] = inputExpressions
+ extends InferredBinaryExpression(Functions.symDifference) with FoldableExpression {
protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = {
copy(inputExpressions = newChildren)
@@ -891,15 +728,7 @@ case class ST_SymDifference(inputExpressions: Seq[Expression])
* @param inputExpressions
*/
case class ST_Union(inputExpressions: Seq[Expression])
- extends BinaryGeometryExpression with FoldableExpression with CodegenFallback {
-
- override protected def nullSafeEval(leftGeometry: Geometry, rightGeometry: Geometry): Any = {
- leftGeometry.union(rightGeometry)
- }
-
- override def dataType: DataType = GeometryUDT
-
- override def children: Seq[Expression] = inputExpressions
+ extends InferredBinaryExpression(Functions.union) with FoldableExpression {
protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = {
copy(inputExpressions = newChildren)
@@ -907,7 +736,7 @@ case class ST_Union(inputExpressions: Seq[Expression])
}
case class ST_Multi(inputExpressions: Seq[Expression])
- extends InferredUnaryExpression(Collect.createMultiGeometryFromOneElement) with FoldableExpression {
+ extends InferredUnaryExpression(Functions.createMultiGeometryFromOneElement) with FoldableExpression {
protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = {
copy(inputExpressions = newChildren)
diff --git a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/NullSafeExpressions.scala b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/NullSafeExpressions.scala
index 05434771..fc4fbb6e 100644
--- a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/NullSafeExpressions.scala
+++ b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/NullSafeExpressions.scala
@@ -123,6 +123,8 @@ sealed class InferrableType[T: TypeTag]
object InferrableType {
implicit val geometryInstance: InferrableType[Geometry] =
new InferrableType[Geometry] {}
+ implicit val geometryArrayInstance: InferrableType[Array[Geometry]] =
+ new InferrableType[Array[Geometry]] {}
implicit val javaDoubleInstance: InferrableType[java.lang.Double] =
new InferrableType[java.lang.Double] {}
implicit val javaIntegerInstance: InferrableType[java.lang.Integer] =
@@ -145,6 +147,8 @@ object InferredTypes {
def buildExtractor[T: TypeTag](expr: Expression): InternalRow => T = {
if (typeOf[T] =:= typeOf[Geometry]) {
input: InternalRow => expr.toGeometry(input).asInstanceOf[T]
+ } else if (typeOf[T] =:= typeOf[Array[Geometry]]) {
+ input: InternalRow => expr.toGeometryArray(input).asInstanceOf[T]
} else if (typeOf[T] =:= typeOf[String]) {
input: InternalRow => expr.asString(input).asInstanceOf[T]
} else {
@@ -172,6 +176,13 @@ object InferredTypes {
} else {
null
}
+ } else if (typeOf[T] =:= typeOf[Array[Geometry]]) {
+ output: T =>
+ if (output != null) {
+ ArrayData.toArrayData(output.asInstanceOf[Array[Geometry]].map(_.toGenericArrayData))
+ } else {
+ null
+ }
} else {
output: T => output
}
@@ -180,6 +191,8 @@ object InferredTypes {
def inferSparkType[T: TypeTag]: DataType = {
if (typeOf[T] =:= typeOf[Geometry]) {
GeometryUDT
+ } else if (typeOf[T] =:= typeOf[Array[Geometry]]) {
+ DataTypes.createArrayType(GeometryUDT)
} else if (typeOf[T] =:= typeOf[java.lang.Double]) {
DoubleType
} else if (typeOf[T] =:= typeOf[java.lang.Integer]) {
@@ -254,6 +267,8 @@ abstract class InferredBinaryExpression[A1: InferrableType, A2: InferrableType,
override def nullable: Boolean = true
+ def allowRightNull: Boolean = false
+
override def dataType = inferSparkType[R]
lazy val extractLeft = buildExtractor[A1](inputExpressions(0))
@@ -266,7 +281,7 @@ abstract class InferredBinaryExpression[A1: InferrableType, A2: InferrableType,
override def evalWithoutSerialization(input: InternalRow): Any = {
val left = extractLeft(input)
val right = extractRight(input)
- if (left != null && right != null) {
+ if (left != null && (right != null || allowRightNull)) {
f(left, right)
} else {
null
diff --git a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Predicates.scala b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Predicates.scala
index 39f68d4c..00923fd2 100644
--- a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Predicates.scala
+++ b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Predicates.scala
@@ -18,15 +18,14 @@
*/
package org.apache.spark.sql.sedona_sql.expressions
+import org.apache.sedona.common.Predicates
import org.apache.sedona.sql.utils.GeometrySerializer
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, NullIntolerant}
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
-import org.apache.spark.sql.catalyst.util.ArrayData
-import org.apache.spark.sql.types.{BooleanType, DataType}
-import org.locationtech.jts.geom.Geometry
-import org.apache.spark.sql.types.AbstractDataType
+import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, NullIntolerant}
import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT
+import org.apache.spark.sql.types.{AbstractDataType, BooleanType, DataType}
+import org.locationtech.jts.geom.Geometry
abstract class ST_Predicate extends Expression
with FoldableExpression
@@ -73,7 +72,7 @@ case class ST_Contains(inputExpressions: Seq[Expression])
extends ST_Predicate with CodegenFallback {
override def evalGeom(leftGeometry: Geometry, rightGeometry: Geometry): Boolean = {
- leftGeometry.contains(rightGeometry)
+ Predicates.contains(leftGeometry, rightGeometry)
}
protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = {
@@ -90,7 +89,7 @@ case class ST_Intersects(inputExpressions: Seq[Expression])
extends ST_Predicate with CodegenFallback {
override def evalGeom(leftGeometry: Geometry, rightGeometry: Geometry): Boolean = {
- leftGeometry.intersects(rightGeometry)
+ Predicates.intersects(leftGeometry, rightGeometry)
}
protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = {
@@ -107,7 +106,7 @@ case class ST_Within(inputExpressions: Seq[Expression])
extends ST_Predicate with CodegenFallback {
override def evalGeom(leftGeometry: Geometry, rightGeometry: Geometry): Boolean = {
- leftGeometry.within(rightGeometry)
+ Predicates.within(leftGeometry, rightGeometry)
}
protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = {
@@ -124,7 +123,7 @@ case class ST_Covers(inputExpressions: Seq[Expression])
extends ST_Predicate with CodegenFallback {
override def evalGeom(leftGeometry: Geometry, rightGeometry: Geometry): Boolean = {
- leftGeometry.covers(rightGeometry)
+ Predicates.covers(leftGeometry, rightGeometry)
}
protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = {
@@ -141,7 +140,7 @@ case class ST_CoveredBy(inputExpressions: Seq[Expression])
extends ST_Predicate with CodegenFallback {
override def evalGeom(leftGeometry: Geometry, rightGeometry: Geometry): Boolean = {
- leftGeometry.coveredBy(rightGeometry)
+ Predicates.coveredBy(leftGeometry, rightGeometry)
}
protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = {
@@ -158,7 +157,7 @@ case class ST_Crosses(inputExpressions: Seq[Expression])
extends ST_Predicate with CodegenFallback {
override def evalGeom(leftGeometry: Geometry, rightGeometry: Geometry): Boolean = {
- leftGeometry.crosses(rightGeometry)
+ Predicates.crosses(leftGeometry, rightGeometry)
}
protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = {
@@ -176,7 +175,7 @@ case class ST_Overlaps(inputExpressions: Seq[Expression])
extends ST_Predicate with CodegenFallback {
override def evalGeom(leftGeometry: Geometry, rightGeometry: Geometry): Boolean = {
- leftGeometry.overlaps(rightGeometry)
+ Predicates.overlaps(leftGeometry, rightGeometry)
}
protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = {
@@ -193,7 +192,7 @@ case class ST_Touches(inputExpressions: Seq[Expression])
extends ST_Predicate with CodegenFallback {
override def evalGeom(leftGeometry: Geometry, rightGeometry: Geometry): Boolean = {
- leftGeometry.touches(rightGeometry)
+ Predicates.touches(leftGeometry, rightGeometry)
}
protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = {
@@ -211,8 +210,7 @@ case class ST_Equals(inputExpressions: Seq[Expression])
override def evalGeom(leftGeometry: Geometry, rightGeometry: Geometry): Boolean = {
// Returns GeometryCollection object
- val symDifference = leftGeometry.symDifference(rightGeometry)
- symDifference.isEmpty
+ Predicates.equals(leftGeometry, rightGeometry)
}
protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = {
@@ -229,7 +227,7 @@ case class ST_Disjoint(inputExpressions: Seq[Expression])
extends ST_Predicate with CodegenFallback {
override def evalGeom(leftGeometry: Geometry, rightGeometry: Geometry): Boolean = {
- leftGeometry.disjoint(rightGeometry)
+ Predicates.disjoint(leftGeometry, rightGeometry)
}
protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = {
@@ -246,7 +244,7 @@ case class ST_OrderingEquals(inputExpressions: Seq[Expression])
extends ST_Predicate with CodegenFallback {
override def evalGeom(leftGeometry: Geometry, rightGeometry: Geometry): Boolean = {
- leftGeometry.equalsExact(rightGeometry)
+ Predicates.orderingEquals(leftGeometry, rightGeometry)
}
protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = {
diff --git a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/implicits.scala b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/implicits.scala
index 2bacc766..941a14c2 100644
--- a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/implicits.scala
+++ b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/implicits.scala
@@ -23,6 +23,7 @@ import org.apache.sedona.sql.utils.GeometrySerializer
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData}
+import org.apache.spark.sql.types.{ByteType, DataTypes}
import org.apache.spark.unsafe.types.UTF8String
import org.locationtech.jts.geom.{Geometry, GeometryFactory, Point}
@@ -40,6 +41,19 @@ object implicits {
}
}
+ def toGeometryArray(input: InternalRow): Array[Geometry] = {
+ inputExpression.eval(input).asInstanceOf[ArrayData] match {
+ case arrayData: ArrayData =>
+ val length = arrayData.numElements()
+ val geometries = new Array[Geometry](length)
+ for (i <- 0 until length) {
+ geometries(i) = arrayData.getBinary(i).toGeometry
+ }
+ geometries
+ case _ => null
+ }
+ }
+
def toInt(input: InternalRow): Int = {
inputExpression.eval(input).asInstanceOf[Int]
}
diff --git a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/st_constructors.scala b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/st_constructors.scala
index 005d112f..aa9eada8 100644
--- a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/st_constructors.scala
+++ b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/st_constructors.scala
@@ -25,6 +25,10 @@ object st_constructors extends DataFrameAPI {
def ST_GeomFromGeoHash(geohash: Column, precision: Column): Column = wrapExpression[ST_GeomFromGeoHash](geohash, precision)
def ST_GeomFromGeoHash(geohash: String, precision: Int): Column = wrapExpression[ST_GeomFromGeoHash](geohash, precision)
+ def ST_GeomFromGeoHash(geohash: Column): Column = wrapExpression[ST_GeomFromGeoHash](geohash, null)
+
+ def ST_GeomFromGeoHash(geohash: String): Column = wrapExpression[ST_GeomFromGeoHash](geohash, null)
+
def ST_GeomFromGeoJSON(geojsonString: Column): Column = wrapExpression[ST_GeomFromGeoJSON](geojsonString)
def ST_GeomFromGeoJSON(geojsonString: String): Column = wrapExpression[ST_GeomFromGeoJSON](geojsonString)
diff --git a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/st_functions.scala b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/st_functions.scala
index 2ef94936..c8c7ac73 100644
--- a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/st_functions.scala
+++ b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/st_functions.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.sedona_sql.expressions
import org.apache.spark.sql.Column
import org.apache.spark.sql.sedona_sql.expressions.collect.{ST_Collect, ST_CollectionExtract}
+import org.locationtech.jts.geom.Geometry
import org.locationtech.jts.operation.buffer.BufferParameters
object st_functions extends DataFrameAPI {
@@ -159,8 +160,8 @@ object st_functions extends DataFrameAPI {
def ST_LineSubstring(lineString: Column, startFraction: Column, endFraction: Column): Column = wrapExpression[ST_LineSubstring](lineString, startFraction, endFraction)
def ST_LineSubstring(lineString: String, startFraction: Double, endFraction: Double): Column = wrapExpression[ST_LineSubstring](lineString, startFraction, endFraction)
- def ST_MakePolygon(lineString: Column): Column = wrapExpression[ST_MakePolygon](lineString)
- def ST_MakePolygon(lineString: String): Column = wrapExpression[ST_MakePolygon](lineString)
+ def ST_MakePolygon(lineString: Column): Column = wrapExpression[ST_MakePolygon](lineString, null)
+ def ST_MakePolygon(lineString: String): Column = wrapExpression[ST_MakePolygon](lineString, null)
def ST_MakePolygon(lineString: Column, holes: Column): Column = wrapExpression[ST_MakePolygon](lineString, holes)
def ST_MakePolygon(lineString: String, holes: String): Column = wrapExpression[ST_MakePolygon](lineString, holes)