You are viewing a plain-text version of this content. The canonical link for it was a hyperlink ("here") that is not preserved in this plain-text rendering.
Posted to commits@sedona.apache.org by ji...@apache.org on 2023/03/15 04:34:58 UTC
[sedona] branch master updated: [SEDONA-231] Redundant Serde Elimination (#792)
This is an automated email from the ASF dual-hosted git repository.
jiayu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/sedona.git
The following commit(s) were added to refs/heads/master by this push:
new 7dfa2c3f [SEDONA-231] Redundant Serde Elimination (#792)
7dfa2c3f is described below
commit 7dfa2c3f566beade3cd524ab64327fcde836e6e8
Author: Douglas Dennis <do...@gmail.com>
AuthorDate: Tue Mar 14 21:34:51 2023 -0700
[SEDONA-231] Redundant Serde Elimination (#792)
---
.gitignore | 2 +
.../java/org/apache/sedona/common/Functions.java | 18 +++++-
.../java/org/apache/sedona/flink/FunctionTest.java | 14 ++---
pom.xml | 6 ++
sql/pom.xml | 4 ++
.../spark/sql/sedona_sql/UDT/GeometryUDT.scala | 1 -
.../sql/sedona_sql/expressions/Functions.scala | 56 +++++++++++++----
.../expressions/NullSafeExpressions.scala | 73 +++++++++++++++++-----
.../sql/sedona_sql/expressions/SerdeAware.scala | 25 ++++++++
.../expressions/collect/ST_Collect.scala | 16 +++--
.../sql/sedona_sql/expressions/implicits.scala | 10 ++-
.../org/apache/sedona/sql/serdeAwareTest.scala | 62 ++++++++++++++++++
12 files changed, 243 insertions(+), 44 deletions(-)
diff --git a/.gitignore b/.gitignore
index 8f6bbd43..955529bd 100644
--- a/.gitignore
+++ b/.gitignore
@@ -17,3 +17,5 @@
/.vscode/
.Rproj.user
__pycache__
+/.bsp
+/.scala-build
diff --git a/common/src/main/java/org/apache/sedona/common/Functions.java b/common/src/main/java/org/apache/sedona/common/Functions.java
index 17809520..2c90d06b 100644
--- a/common/src/main/java/org/apache/sedona/common/Functions.java
+++ b/common/src/main/java/org/apache/sedona/common/Functions.java
@@ -71,7 +71,11 @@ public class Functions {
}
public static Geometry boundary(Geometry geometry) {
- return geometry.getBoundary();
+ Geometry boundary = geometry.getBoundary();
+ if (boundary instanceof LinearRing) {
+ boundary = GEOMETRY_FACTORY.createLineString(boundary.getCoordinates());
+ }
+ return boundary;
}
public static Geometry buffer(Geometry geometry, double radius) {
@@ -236,7 +240,11 @@ public class Functions {
if (geometry instanceof Polygon) {
Polygon polygon = (Polygon) geometry;
if (n < polygon.getNumInteriorRing()) {
- return polygon.getInteriorRingN(n);
+ Geometry interiorRing = polygon.getInteriorRingN(n);
+ if (interiorRing instanceof LinearRing) {
+ interiorRing = GEOMETRY_FACTORY.createLineString(interiorRing.getCoordinates());
+ }
+ return interiorRing;
}
}
return null;
@@ -250,7 +258,11 @@ public class Functions {
}
public static Geometry exteriorRing(Geometry geometry) {
- return GeomUtils.getExteriorRing(geometry);
+ Geometry ring = GeomUtils.getExteriorRing(geometry);
+ if (ring instanceof LinearRing) {
+ ring = GEOMETRY_FACTORY.createLineString(ring.getCoordinates());
+ }
+ return ring;
}
public static String asEWKT(Geometry geometry) {
diff --git a/flink/src/test/java/org/apache/sedona/flink/FunctionTest.java b/flink/src/test/java/org/apache/sedona/flink/FunctionTest.java
index 4dcb3aab..94546edd 100644
--- a/flink/src/test/java/org/apache/sedona/flink/FunctionTest.java
+++ b/flink/src/test/java/org/apache/sedona/flink/FunctionTest.java
@@ -21,7 +21,7 @@ import org.junit.Assert;
import org.junit.BeforeClass;
import org.junit.Test;
import org.locationtech.jts.geom.Geometry;
-import org.locationtech.jts.geom.LinearRing;
+import org.locationtech.jts.geom.LineString;
import org.locationtech.jts.geom.Point;
import org.locationtech.jts.geom.Polygon;
import org.opengis.referencing.FactoryException;
@@ -62,7 +62,7 @@ public class FunctionTest extends TestBase{
Table polygonTable = tableEnv.sqlQuery("SELECT ST_GeomFromWKT('POLYGON ((1 1, 0 0, -1 1, 1 1))') AS geom");
Table boundaryTable = polygonTable.select(call(Functions.ST_Boundary.class.getSimpleName(), $("geom")));
Geometry result = (Geometry) first(boundaryTable).getField(0);
- assertEquals("LINEARRING (1 1, 0 0, -1 1, 1 1)", result.toString());
+ assertEquals("LINESTRING (1 1, 0 0, -1 1, 1 1)", result.toString());
}
@Test
@@ -221,8 +221,8 @@ public class FunctionTest extends TestBase{
public void testInteriorRingN() {
Table polygonTable = tableEnv.sqlQuery("SELECT ST_GeomFromText('POLYGON((7 9,8 7,11 6,15 8,16 6,17 7,17 10,18 12,17 14,15 15,11 15,10 13,9 12,7 9),(9 9,10 10,11 11,11 10,10 8,9 9),(12 14,15 14,13 11,12 14))') AS polygon");
Table resultTable = polygonTable.select(call(Functions.ST_InteriorRingN.class.getSimpleName(), $("polygon"), 1));
- LinearRing linearRing = (LinearRing) first(resultTable).getField(0);
- assertEquals("LINEARRING (12 14, 15 14, 13 11, 12 14)", linearRing.toString());
+ LineString lineString = (LineString) first(resultTable).getField(0);
+ assertEquals("LINESTRING (12 14, 15 14, 13 11, 12 14)", lineString.toString());
}
@Test
@@ -272,9 +272,9 @@ public class FunctionTest extends TestBase{
public void testExteriorRing() {
Table polygonTable = createPolygonTable(1);
Table linearRingTable = polygonTable.select(call(Functions.ST_ExteriorRing.class.getSimpleName(), $(polygonColNames[0])));
- LinearRing linearRing = (LinearRing) first(linearRingTable).getField(0);
- assertNotNull(linearRing);
- Assert.assertEquals("LINEARRING (-0.5 -0.5, -0.5 0.5, 0.5 0.5, 0.5 -0.5, -0.5 -0.5)", linearRing.toString());
+ LineString lineString = (LineString) first(linearRingTable).getField(0);
+ assertNotNull(lineString);
+ Assert.assertEquals("LINESTRING (-0.5 -0.5, -0.5 0.5, 0.5 0.5, 0.5 -0.5, -0.5 -0.5)", lineString.toString());
}
@Test
diff --git a/pom.xml b/pom.xml
index 6c50d34b..0b7c74b3 100644
--- a/pom.xml
+++ b/pom.xml
@@ -331,6 +331,12 @@
<artifactId>s2-geometry</artifactId>
<version>${googles2.version}</version>
</dependency>
+ <dependency>
+ <groupId>org.mockito</groupId>
+ <artifactId>mockito-inline</artifactId>
+ <version>4.11.0</version>
+ <scope>test</scope>
+ </dependency>
</dependencies>
</dependencyManagement>
<repositories>
diff --git a/sql/pom.xml b/sql/pom.xml
index 0d4de07d..287c2f6f 100644
--- a/sql/pom.xml
+++ b/sql/pom.xml
@@ -127,6 +127,10 @@
<groupId>org.scalatest</groupId>
<artifactId>scalatest_${scala.compat.version}</artifactId>
</dependency>
+ <dependency>
+ <groupId>org.mockito</groupId>
+ <artifactId>mockito-inline</artifactId>
+ </dependency>
</dependencies>
<build>
<sourceDirectory>src/main/scala</sourceDirectory>
diff --git a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/UDT/GeometryUDT.scala b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/UDT/GeometryUDT.scala
index 8cf94680..427ff8ac 100644
--- a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/UDT/GeometryUDT.scala
+++ b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/UDT/GeometryUDT.scala
@@ -58,4 +58,3 @@ class GeometryUDT extends UserDefinedType[Geometry] {
}
case object GeometryUDT extends org.apache.spark.sql.sedona_sql.UDT.GeometryUDT with scala.Serializable
-
diff --git a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala
index 2b850cff..9de64be9 100644
--- a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala
+++ b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala
@@ -448,7 +448,9 @@ case class ST_StartPoint(inputExpressions: Seq[Expression])
override protected def nullSafeEval(geometry: Geometry): Any = {
geometry match {
- case line: LineString => line.getPointN(0).toGenericArrayData
+ case line: LineString => {
+ line.getPointN(0)
+ }
case _ => null
}
}
@@ -473,11 +475,23 @@ case class ST_Boundary(inputExpressions: Seq[Expression])
case class ST_MinimumBoundingRadius(inputExpressions: Seq[Expression])
- extends UnaryGeometryExpression with FoldableExpression with CodegenFallback {
+ extends Expression with FoldableExpression with CodegenFallback {
+
+ override def nullable: Boolean = true
+
private val geometryFactory = new GeometryFactory()
- override protected def nullSafeEval(geometry: Geometry): Any = {
- getMinimumBoundingRadius(geometry)
+ override def eval(input: InternalRow): Any = {
+ val expr = inputExpressions(0)
+ val geometry = expr match {
+ case s: SerdeAware => s.evalWithoutSerialization(input)
+ case _ => expr.toGeometry(input)
+ }
+
+ geometry match {
+ case geometry: Geometry => getMinimumBoundingRadius(geometry)
+ case _ => null
+ }
}
private def getMinimumBoundingRadius(geom: Geometry): InternalRow = {
@@ -545,7 +559,7 @@ case class ST_EndPoint(inputExpressions: Seq[Expression])
override protected def nullSafeEval(geometry: Geometry): Any = {
geometry match {
- case string: LineString => string.getEndPoint.toGenericArrayData
+ case string: LineString => string.getEndPoint
case _ => null
}
}
@@ -588,16 +602,24 @@ case class ST_Dump(inputExpressions: Seq[Expression])
extends UnaryGeometryExpression with FoldableExpression with CodegenFallback {
override protected def nullSafeEval(geometry: Geometry): Any = {
- val geometryCollection = geometry match {
+ geometry match {
case collection: GeometryCollection => {
val numberOfGeometries = collection.getNumGeometries
(0 until numberOfGeometries).map(
- index => collection.getGeometryN(index).toGenericArrayData
+ index => collection.getGeometryN(index)
).toArray
}
- case geom: Geometry => Array(geom.toGenericArrayData)
+ case geom: Geometry => Array(geom)
+ }
+ }
+
+ override protected def serializeResult(result: Any): Any = {
+ result match {
+ case array: Array[Geometry] => ArrayData.toArrayData(
+ array.map(_.toGenericArrayData)
+ )
+ case _ => null
}
- ArrayData.toArrayData(geometryCollection)
}
override def dataType: DataType = ArrayType(GeometryUDT)
@@ -613,7 +635,17 @@ case class ST_DumpPoints(inputExpressions: Seq[Expression])
extends UnaryGeometryExpression with FoldableExpression with CodegenFallback {
override protected def nullSafeEval(geometry: Geometry): Any = {
- ArrayData.toArrayData(geometry.getPoints.map(geom => geom.toGenericArrayData))
+ geometry.getPoints.map(geom => geom).toArray
+ }
+
+ override protected def serializeResult(result: Any): Any = {
+ result match {
+ case array: Array[Geometry] => ArrayData.toArrayData(
+ array.map(geom => geom.toGenericArrayData)
+ )
+ case _ => null
+ }
+
}
override def dataType: DataType = ArrayType(GeometryUDT)
@@ -842,7 +874,7 @@ case class ST_SymDifference(inputExpressions: Seq[Expression])
extends BinaryGeometryExpression with FoldableExpression with CodegenFallback {
override protected def nullSafeEval(leftGeometry: Geometry, rightGeometry: Geometry): Any = {
- leftGeometry.symDifference(rightGeometry).toGenericArrayData
+ leftGeometry.symDifference(rightGeometry)
}
override def dataType: DataType = GeometryUDT
@@ -863,7 +895,7 @@ case class ST_Union(inputExpressions: Seq[Expression])
extends BinaryGeometryExpression with FoldableExpression with CodegenFallback {
override protected def nullSafeEval(leftGeometry: Geometry, rightGeometry: Geometry): Any = {
- leftGeometry.union(rightGeometry).toGenericArrayData
+ leftGeometry.union(rightGeometry)
}
override def dataType: DataType = GeometryUDT
diff --git a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/NullSafeExpressions.scala b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/NullSafeExpressions.scala
index 4a29eb4a..05434771 100644
--- a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/NullSafeExpressions.scala
+++ b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/NullSafeExpressions.scala
@@ -38,7 +38,7 @@ trait FoldableExpression extends Expression {
override def foldable: Boolean = children.forall(_.foldable)
}
-abstract class UnaryGeometryExpression extends Expression with ExpectsInputTypes {
+abstract class UnaryGeometryExpression extends Expression with SerdeAware with ExpectsInputTypes {
def inputExpressions: Seq[Expression]
override def nullable: Boolean = true
@@ -46,19 +46,36 @@ abstract class UnaryGeometryExpression extends Expression with ExpectsInputTypes
override def inputTypes: Seq[AbstractDataType] = Seq(GeometryUDT)
override def eval(input: InternalRow): Any = {
- val geometry = inputExpressions.head.toGeometry(input)
+ val result = evalWithoutSerialization(input)
+ serializeResult(result)
+ }
+
+ override def evalWithoutSerialization(input: InternalRow): Any ={
+ val inputExpression = inputExpressions.head
+ val geometry = inputExpression match {
+ case expr: SerdeAware => expr.evalWithoutSerialization(input)
+ case expr: Any => expr.toGeometry(input)
+ }
+
(geometry) match {
case (geometry: Geometry) => nullSafeEval(geometry)
case _ => null
}
}
+ protected def serializeResult(result: Any): Any = {
+ result match {
+ case geometry: Geometry => geometry.toGenericArrayData
+ case _ => result
+ }
+ }
+
protected def nullSafeEval(geometry: Geometry): Any
}
-abstract class BinaryGeometryExpression extends Expression with ExpectsInputTypes {
+abstract class BinaryGeometryExpression extends Expression with SerdeAware with ExpectsInputTypes {
def inputExpressions: Seq[Expression]
override def nullable: Boolean = true
@@ -66,14 +83,36 @@ abstract class BinaryGeometryExpression extends Expression with ExpectsInputType
override def inputTypes: Seq[AbstractDataType] = Seq(GeometryUDT, GeometryUDT)
override def eval(input: InternalRow): Any = {
- val leftGeometry = inputExpressions(0).toGeometry(input)
- val rightGeometry = inputExpressions(1).toGeometry(input)
+ val result = evalWithoutSerialization(input)
+ serializeResult(result)
+ }
+
+ override def evalWithoutSerialization(input: InternalRow): Any = {
+ val leftExpression = inputExpressions(0)
+ val leftGeometry = leftExpression match {
+ case expr: SerdeAware => expr.evalWithoutSerialization(input)
+ case _ => leftExpression.toGeometry(input)
+ }
+
+ val rightExpression = inputExpressions(1)
+ val rightGeometry = rightExpression match {
+ case expr: SerdeAware => expr.evalWithoutSerialization(input)
+ case _ => rightExpression.toGeometry(input)
+ }
+
(leftGeometry, rightGeometry) match {
case (leftGeometry: Geometry, rightGeometry: Geometry) => nullSafeEval(leftGeometry, rightGeometry)
case _ => null
}
}
+ protected def serializeResult(result: Any): Any = {
+ result match {
+ case geometry: Geometry => geometry.toGenericArrayData
+ case _ => result
+ }
+ }
+
protected def nullSafeEval(leftGeometry: Geometry, rightGeometry: Geometry): Any
}
@@ -168,7 +207,7 @@ object InferredTypes {
abstract class InferredUnaryExpression[A1: InferrableType, R: InferrableType]
(f: (A1) => R)
(implicit val a1Tag: TypeTag[A1], implicit val rTag: TypeTag[R])
- extends Expression with ImplicitCastInputTypes with CodegenFallback with Serializable {
+ extends Expression with ImplicitCastInputTypes with SerdeAware with CodegenFallback with Serializable {
import InferredTypes._
def inputExpressions: Seq[Expression]
@@ -187,10 +226,12 @@ abstract class InferredUnaryExpression[A1: InferrableType, R: InferrableType]
lazy val serialize = buildSerializer[R]
- override def eval(input: InternalRow): Any = {
+ override def eval(input: InternalRow): Any = serialize(evalWithoutSerialization(input).asInstanceOf[R])
+
+ override def evalWithoutSerialization(input: InternalRow): Any = {
val value = extract(input)
if (value != null) {
- serialize(f(value))
+ f(value)
} else {
null
}
@@ -200,7 +241,7 @@ abstract class InferredUnaryExpression[A1: InferrableType, R: InferrableType]
abstract class InferredBinaryExpression[A1: InferrableType, A2: InferrableType, R: InferrableType]
(f: (A1, A2) => R)
(implicit val a1Tag: TypeTag[A1], implicit val a2Tag: TypeTag[A2], implicit val rTag: TypeTag[R])
- extends Expression with ImplicitCastInputTypes with CodegenFallback with Serializable {
+ extends Expression with ImplicitCastInputTypes with SerdeAware with CodegenFallback with Serializable {
import InferredTypes._
def inputExpressions: Seq[Expression]
@@ -220,11 +261,13 @@ abstract class InferredBinaryExpression[A1: InferrableType, A2: InferrableType,
lazy val serialize = buildSerializer[R]
- override def eval(input: InternalRow): Any = {
+ override def eval(input: InternalRow): Any = serialize(evalWithoutSerialization(input).asInstanceOf[R])
+
+ override def evalWithoutSerialization(input: InternalRow): Any = {
val left = extractLeft(input)
val right = extractRight(input)
if (left != null && right != null) {
- serialize(f(left, right))
+ f(left, right)
} else {
null
}
@@ -234,7 +277,7 @@ abstract class InferredBinaryExpression[A1: InferrableType, A2: InferrableType,
abstract class InferredTernaryExpression[A1: InferrableType, A2: InferrableType, A3: InferrableType, R: InferrableType]
(f: (A1, A2, A3) => R)
(implicit val a1Tag: TypeTag[A1], implicit val a2Tag: TypeTag[A2], implicit val a3Tag: TypeTag[A3], implicit val rTag: TypeTag[R])
- extends Expression with ImplicitCastInputTypes with CodegenFallback with Serializable {
+ extends Expression with ImplicitCastInputTypes with SerdeAware with CodegenFallback with Serializable {
import InferredTypes._
def inputExpressions: Seq[Expression]
@@ -255,12 +298,14 @@ abstract class InferredTernaryExpression[A1: InferrableType, A2: InferrableType,
lazy val serialize = buildSerializer[R]
- override def eval(input: InternalRow): Any = {
+ override def eval(input: InternalRow): Any = serialize(evalWithoutSerialization(input).asInstanceOf[R])
+
+ override def evalWithoutSerialization(input: InternalRow): Any = {
val first = extractFirst(input)
val second = extractSecond(input)
val third = extractThird(input)
if (first != null && second != null && third != null) {
- serialize(f(first, second, third))
+ f(first, second, third)
} else {
null
}
diff --git a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/SerdeAware.scala b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/SerdeAware.scala
new file mode 100644
index 00000000..a46d7e08
--- /dev/null
+++ b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/SerdeAware.scala
@@ -0,0 +1,25 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.spark.sql.sedona_sql.expressions
+
+import org.apache.spark.sql.catalyst.InternalRow
+
+trait SerdeAware {
+ def evalWithoutSerialization(input: InternalRow): Any
+}
diff --git a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/collect/ST_Collect.scala b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/collect/ST_Collect.scala
index 6644560f..c259c366 100644
--- a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/collect/ST_Collect.scala
+++ b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/collect/ST_Collect.scala
@@ -18,6 +18,7 @@
*/
package org.apache.spark.sql.sedona_sql.expressions.collect
+
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
@@ -26,16 +27,24 @@ import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.util.ArrayData
import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT
import org.apache.spark.sql.sedona_sql.expressions.implicits._
+import org.apache.spark.sql.sedona_sql.expressions.SerdeAware
import org.apache.spark.sql.types.{ArrayType, _}
+import org.locationtech.jts.geom.Geometry
+
case class ST_Collect(inputExpressions: Seq[Expression])
extends Expression
+ with SerdeAware
with CodegenFallback {
assert(inputExpressions.length >= 1)
override def nullable: Boolean = true
override def eval(input: InternalRow): Any = {
+ evalWithoutSerialization(input).asInstanceOf[Geometry].toGenericArrayData
+ }
+
+ override def evalWithoutSerialization(input: InternalRow): Any = {
val firstElement = inputExpressions.head
firstElement.dataType match {
@@ -49,14 +58,13 @@ case class ST_Collect(inputExpressions: Seq[Expression])
.filter(_ != null)
.map(_.toGeometry)
- Collect.createMultiGeometry(geomElements).toGenericArrayData
- case _ => Collect.createMultiGeometry(Seq()).toGenericArrayData
+ Collect.createMultiGeometry(geomElements)
+ case _ => Collect.createMultiGeometry(Seq())
}
case _ =>
val geomElements =
inputExpressions.map(_.toGeometry(input)).filter(_ != null)
- Collect.createMultiGeometry(geomElements).toGenericArrayData
-
+ Collect.createMultiGeometry(geomElements)
}
}
diff --git a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/implicits.scala b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/implicits.scala
index f3ec6162..2bacc766 100644
--- a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/implicits.scala
+++ b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/implicits.scala
@@ -30,9 +30,13 @@ object implicits {
implicit class InputExpressionEnhancer(inputExpression: Expression) {
def toGeometry(input: InternalRow): Geometry = {
- inputExpression.eval(input).asInstanceOf[Array[Byte]] match {
- case binary: Array[Byte] => GeometrySerializer.deserialize(binary)
- case _ => null
+ if (inputExpression.isInstanceOf[SerdeAware]) {
+ inputExpression.asInstanceOf[SerdeAware].evalWithoutSerialization(input).asInstanceOf[Geometry]
+ } else {
+ inputExpression.eval(input).asInstanceOf[Array[Byte]] match {
+ case binary: Array[Byte] => GeometrySerializer.deserialize(binary)
+ case _ => null
+ }
}
}
diff --git a/sql/src/test/scala/org/apache/sedona/sql/serdeAwareTest.scala b/sql/src/test/scala/org/apache/sedona/sql/serdeAwareTest.scala
new file mode 100644
index 00000000..ac9ce0b2
--- /dev/null
+++ b/sql/src/test/scala/org/apache/sedona/sql/serdeAwareTest.scala
@@ -0,0 +1,62 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sedona.sql
+
+import org.apache.sedona.common.geometrySerde.GeometrySerializer
+import org.apache.spark.sql.catalyst.expressions.Literal
+import org.apache.spark.sql.sedona_sql.expressions.{ST_Buffer, ST_GeomFromText, ST_Point, ST_Union}
+import org.locationtech.jts.geom.{Coordinate, Geometry, GeometryFactory}
+import org.mockito.ArgumentMatchers.any
+import org.mockito.Mockito.{atMost, mockStatic}
+
+class SerdeAwareFunctionSpec extends TestBaseScala {
+
+ describe("SerdeAwareFunction") {
+ it("should save us some serialization and deserialization cost") {
+ // Mock GeometrySerializer
+ val factory = new GeometryFactory
+ val stubGeom = factory.createPoint(new Coordinate(1, 2))
+ val mocked = mockStatic(classOf[GeometrySerializer])
+ mocked.when(() => GeometrySerializer.deserialize(any(classOf[Array[Byte]]))).thenReturn(stubGeom)
+ mocked.when(() => GeometrySerializer.serialize(any(classOf[Geometry]))).thenReturn(Array[Byte](1, 2, 3))
+
+ val expr = ST_Union(Seq(
+ ST_Buffer(Seq(ST_GeomFromText(Seq(Literal("POINT (1 2)"), Literal(0))), Literal(1.0))),
+ ST_Point(Seq(Literal(1.0), Literal(2.0), Literal(null)))
+ ))
+
+ try {
+ // Evaluate an expression
+ expr.eval(null)
+
+ // Verify number of invocations
+ mocked.verify(
+ () => GeometrySerializer.deserialize(any(classOf[Array[Byte]])),
+ atMost(0))
+ mocked.verify(
+ () => GeometrySerializer.serialize(any(classOf[Geometry])),
+ atMost(1))
+ } finally {
+ // Undo the mock
+ mocked.close()
+ }
+ }
+ }
+}