You are viewing a plain-text version of this content. The canonical link to it was a hyperlink in the original page and is not preserved in this plain-text copy.
Posted to commits@sedona.apache.org by ji...@apache.org on 2023/03/15 04:34:58 UTC

[sedona] branch master updated: [SEDONA-231] Redundant Serde Elimination (#792)

This is an automated email from the ASF dual-hosted git repository.

jiayu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/sedona.git


The following commit(s) were added to refs/heads/master by this push:
     new 7dfa2c3f [SEDONA-231] Redundant Serde Elimination (#792)
7dfa2c3f is described below

commit 7dfa2c3f566beade3cd524ab64327fcde836e6e8
Author: Douglas Dennis <do...@gmail.com>
AuthorDate: Tue Mar 14 21:34:51 2023 -0700

    [SEDONA-231] Redundant Serde Elimination (#792)
---
 .gitignore                                         |  2 +
 .../java/org/apache/sedona/common/Functions.java   | 18 +++++-
 .../java/org/apache/sedona/flink/FunctionTest.java | 14 ++---
 pom.xml                                            |  6 ++
 sql/pom.xml                                        |  4 ++
 .../spark/sql/sedona_sql/UDT/GeometryUDT.scala     |  1 -
 .../sql/sedona_sql/expressions/Functions.scala     | 56 +++++++++++++----
 .../expressions/NullSafeExpressions.scala          | 73 +++++++++++++++++-----
 .../sql/sedona_sql/expressions/SerdeAware.scala    | 25 ++++++++
 .../expressions/collect/ST_Collect.scala           | 16 +++--
 .../sql/sedona_sql/expressions/implicits.scala     | 10 ++-
 .../org/apache/sedona/sql/serdeAwareTest.scala     | 62 ++++++++++++++++++
 12 files changed, 243 insertions(+), 44 deletions(-)

diff --git a/.gitignore b/.gitignore
index 8f6bbd43..955529bd 100644
--- a/.gitignore
+++ b/.gitignore
@@ -17,3 +17,5 @@
 /.vscode/
 .Rproj.user
 __pycache__
+/.bsp
+/.scala-build
diff --git a/common/src/main/java/org/apache/sedona/common/Functions.java b/common/src/main/java/org/apache/sedona/common/Functions.java
index 17809520..2c90d06b 100644
--- a/common/src/main/java/org/apache/sedona/common/Functions.java
+++ b/common/src/main/java/org/apache/sedona/common/Functions.java
@@ -71,7 +71,11 @@ public class Functions {
     }
 
     public static Geometry boundary(Geometry geometry) {
-        return geometry.getBoundary();
+        Geometry boundary = geometry.getBoundary();
+        if (boundary instanceof LinearRing) {
+            boundary = GEOMETRY_FACTORY.createLineString(boundary.getCoordinates());
+        }
+        return boundary;
     }
 
     public static Geometry buffer(Geometry geometry, double radius) {
@@ -236,7 +240,11 @@ public class Functions {
         if (geometry instanceof Polygon) {
             Polygon polygon = (Polygon) geometry;
             if (n < polygon.getNumInteriorRing()) {
-                return polygon.getInteriorRingN(n);
+                Geometry interiorRing = polygon.getInteriorRingN(n);
+                if (interiorRing instanceof LinearRing) {
+                    interiorRing = GEOMETRY_FACTORY.createLineString(interiorRing.getCoordinates());
+                }
+                return interiorRing;
             }
         }
         return null;
@@ -250,7 +258,11 @@ public class Functions {
     }
 
     public static Geometry exteriorRing(Geometry geometry) {
-        return GeomUtils.getExteriorRing(geometry);
+        Geometry ring = GeomUtils.getExteriorRing(geometry);
+        if (ring instanceof LinearRing) {
+            ring = GEOMETRY_FACTORY.createLineString(ring.getCoordinates());
+        }
+        return ring;
     }
 
     public static String asEWKT(Geometry geometry) {
diff --git a/flink/src/test/java/org/apache/sedona/flink/FunctionTest.java b/flink/src/test/java/org/apache/sedona/flink/FunctionTest.java
index 4dcb3aab..94546edd 100644
--- a/flink/src/test/java/org/apache/sedona/flink/FunctionTest.java
+++ b/flink/src/test/java/org/apache/sedona/flink/FunctionTest.java
@@ -21,7 +21,7 @@ import org.junit.Assert;
 import org.junit.BeforeClass;
 import org.junit.Test;
 import org.locationtech.jts.geom.Geometry;
-import org.locationtech.jts.geom.LinearRing;
+import org.locationtech.jts.geom.LineString;
 import org.locationtech.jts.geom.Point;
 import org.locationtech.jts.geom.Polygon;
 import org.opengis.referencing.FactoryException;
@@ -62,7 +62,7 @@ public class FunctionTest extends TestBase{
         Table polygonTable = tableEnv.sqlQuery("SELECT ST_GeomFromWKT('POLYGON ((1 1, 0 0, -1 1, 1 1))') AS geom");
         Table boundaryTable = polygonTable.select(call(Functions.ST_Boundary.class.getSimpleName(), $("geom")));
         Geometry result = (Geometry) first(boundaryTable).getField(0);
-        assertEquals("LINEARRING (1 1, 0 0, -1 1, 1 1)", result.toString());
+        assertEquals("LINESTRING (1 1, 0 0, -1 1, 1 1)", result.toString());
     }
 
     @Test
@@ -221,8 +221,8 @@ public class FunctionTest extends TestBase{
     public void testInteriorRingN() {
         Table polygonTable = tableEnv.sqlQuery("SELECT ST_GeomFromText('POLYGON((7 9,8 7,11 6,15 8,16 6,17 7,17 10,18 12,17 14,15 15,11 15,10 13,9 12,7 9),(9 9,10 10,11 11,11 10,10 8,9 9),(12 14,15 14,13 11,12 14))') AS polygon");
         Table resultTable = polygonTable.select(call(Functions.ST_InteriorRingN.class.getSimpleName(), $("polygon"), 1));
-        LinearRing linearRing = (LinearRing) first(resultTable).getField(0);
-        assertEquals("LINEARRING (12 14, 15 14, 13 11, 12 14)", linearRing.toString());
+        LineString lineString = (LineString) first(resultTable).getField(0);
+        assertEquals("LINESTRING (12 14, 15 14, 13 11, 12 14)", lineString.toString());
     }
 
     @Test
@@ -272,9 +272,9 @@ public class FunctionTest extends TestBase{
     public void testExteriorRing() {
         Table polygonTable = createPolygonTable(1);
         Table linearRingTable = polygonTable.select(call(Functions.ST_ExteriorRing.class.getSimpleName(), $(polygonColNames[0])));
-        LinearRing linearRing = (LinearRing) first(linearRingTable).getField(0);
-        assertNotNull(linearRing);
-        Assert.assertEquals("LINEARRING (-0.5 -0.5, -0.5 0.5, 0.5 0.5, 0.5 -0.5, -0.5 -0.5)", linearRing.toString());
+        LineString lineString = (LineString) first(linearRingTable).getField(0);
+        assertNotNull(lineString);
+        Assert.assertEquals("LINESTRING (-0.5 -0.5, -0.5 0.5, 0.5 0.5, 0.5 -0.5, -0.5 -0.5)", lineString.toString());
     }
 
     @Test
diff --git a/pom.xml b/pom.xml
index 6c50d34b..0b7c74b3 100644
--- a/pom.xml
+++ b/pom.xml
@@ -331,6 +331,12 @@
                 <artifactId>s2-geometry</artifactId>
                 <version>${googles2.version}</version>
             </dependency>
+            <dependency>
+                <groupId>org.mockito</groupId>
+                <artifactId>mockito-inline</artifactId>
+                <version>4.11.0</version>
+                <scope>test</scope>
+            </dependency>
         </dependencies>
     </dependencyManagement>
     <repositories>
diff --git a/sql/pom.xml b/sql/pom.xml
index 0d4de07d..287c2f6f 100644
--- a/sql/pom.xml
+++ b/sql/pom.xml
@@ -127,6 +127,10 @@
             <groupId>org.scalatest</groupId>
             <artifactId>scalatest_${scala.compat.version}</artifactId>
         </dependency>
+        <dependency>
+                <groupId>org.mockito</groupId>
+                <artifactId>mockito-inline</artifactId>
+        </dependency>
     </dependencies>
 	<build>
         <sourceDirectory>src/main/scala</sourceDirectory>
diff --git a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/UDT/GeometryUDT.scala b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/UDT/GeometryUDT.scala
index 8cf94680..427ff8ac 100644
--- a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/UDT/GeometryUDT.scala
+++ b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/UDT/GeometryUDT.scala
@@ -58,4 +58,3 @@ class GeometryUDT extends UserDefinedType[Geometry] {
 }
 
 case object GeometryUDT extends org.apache.spark.sql.sedona_sql.UDT.GeometryUDT with scala.Serializable
-
diff --git a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala
index 2b850cff..9de64be9 100644
--- a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala
+++ b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala
@@ -448,7 +448,9 @@ case class ST_StartPoint(inputExpressions: Seq[Expression])
 
   override protected def nullSafeEval(geometry: Geometry): Any = {
     geometry match {
-      case line: LineString => line.getPointN(0).toGenericArrayData
+      case line: LineString => {
+        line.getPointN(0)
+      }
       case _ => null
     }
   }
@@ -473,11 +475,23 @@ case class ST_Boundary(inputExpressions: Seq[Expression])
 
 
 case class ST_MinimumBoundingRadius(inputExpressions: Seq[Expression])
-  extends UnaryGeometryExpression with FoldableExpression with CodegenFallback {
+  extends Expression with FoldableExpression with CodegenFallback {
+
+  override def nullable: Boolean = true
+
   private val geometryFactory = new GeometryFactory()
 
-  override protected def nullSafeEval(geometry: Geometry): Any = {
-    getMinimumBoundingRadius(geometry)
+  override def eval(input: InternalRow): Any = {
+    val expr = inputExpressions(0)
+    val geometry = expr match {
+      case s: SerdeAware => s.evalWithoutSerialization(input)
+      case _ => expr.toGeometry(input)
+    }
+
+    geometry match {
+      case geometry: Geometry => getMinimumBoundingRadius(geometry)
+      case _ => null
+    }
   }
 
   private def getMinimumBoundingRadius(geom: Geometry): InternalRow = {
@@ -545,7 +559,7 @@ case class ST_EndPoint(inputExpressions: Seq[Expression])
 
   override protected def nullSafeEval(geometry: Geometry): Any = {
     geometry match {
-      case string: LineString => string.getEndPoint.toGenericArrayData
+      case string: LineString => string.getEndPoint
       case _ => null
     }
   }
@@ -588,16 +602,24 @@ case class ST_Dump(inputExpressions: Seq[Expression])
   extends UnaryGeometryExpression with FoldableExpression with CodegenFallback {
 
   override protected def nullSafeEval(geometry: Geometry): Any = {
-    val geometryCollection = geometry match {
+    geometry match {
       case collection: GeometryCollection => {
         val numberOfGeometries = collection.getNumGeometries
         (0 until numberOfGeometries).map(
-          index => collection.getGeometryN(index).toGenericArrayData
+          index => collection.getGeometryN(index)
         ).toArray
       }
-      case geom: Geometry => Array(geom.toGenericArrayData)
+      case geom: Geometry => Array(geom)
+    }
+  }
+
+  override protected def serializeResult(result: Any): Any = {
+    result match {
+      case array: Array[Geometry] => ArrayData.toArrayData(
+        array.map(_.toGenericArrayData)
+      )
+      case _ => null
     }
-    ArrayData.toArrayData(geometryCollection)
   }
 
   override def dataType: DataType = ArrayType(GeometryUDT)
@@ -613,7 +635,17 @@ case class ST_DumpPoints(inputExpressions: Seq[Expression])
   extends UnaryGeometryExpression with FoldableExpression with CodegenFallback {
 
   override protected def nullSafeEval(geometry: Geometry): Any = {
-    ArrayData.toArrayData(geometry.getPoints.map(geom => geom.toGenericArrayData))
+    geometry.getPoints.map(geom => geom).toArray
+  }
+
+  override protected def serializeResult(result: Any): Any = {
+    result match {
+      case array: Array[Geometry] => ArrayData.toArrayData(
+        array.map(geom => geom.toGenericArrayData)
+      )
+      case _ => null
+    }
+
   }
 
   override def dataType: DataType = ArrayType(GeometryUDT)
@@ -842,7 +874,7 @@ case class ST_SymDifference(inputExpressions: Seq[Expression])
   extends BinaryGeometryExpression with FoldableExpression with CodegenFallback {
 
   override protected def nullSafeEval(leftGeometry: Geometry, rightGeometry: Geometry): Any = {
-    leftGeometry.symDifference(rightGeometry).toGenericArrayData
+    leftGeometry.symDifference(rightGeometry)
   }
 
   override def dataType: DataType = GeometryUDT
@@ -863,7 +895,7 @@ case class ST_Union(inputExpressions: Seq[Expression])
   extends BinaryGeometryExpression with FoldableExpression with CodegenFallback {
 
   override protected def nullSafeEval(leftGeometry: Geometry, rightGeometry: Geometry): Any = {
-    leftGeometry.union(rightGeometry).toGenericArrayData
+    leftGeometry.union(rightGeometry)
   }
 
   override def dataType: DataType = GeometryUDT
diff --git a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/NullSafeExpressions.scala b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/NullSafeExpressions.scala
index 4a29eb4a..05434771 100644
--- a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/NullSafeExpressions.scala
+++ b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/NullSafeExpressions.scala
@@ -38,7 +38,7 @@ trait FoldableExpression extends Expression {
   override def foldable: Boolean = children.forall(_.foldable)
 }
 
-abstract class UnaryGeometryExpression extends Expression with ExpectsInputTypes {
+abstract class UnaryGeometryExpression extends Expression with SerdeAware with ExpectsInputTypes {
   def inputExpressions: Seq[Expression]
 
   override def nullable: Boolean = true
@@ -46,19 +46,36 @@ abstract class UnaryGeometryExpression extends Expression with ExpectsInputTypes
   override def inputTypes: Seq[AbstractDataType] = Seq(GeometryUDT)
 
   override def eval(input: InternalRow): Any = {
-    val geometry = inputExpressions.head.toGeometry(input)
+    val result = evalWithoutSerialization(input)
+    serializeResult(result)
+  }
+
+  override def evalWithoutSerialization(input: InternalRow): Any ={
+    val inputExpression = inputExpressions.head
+    val geometry = inputExpression match {
+      case expr: SerdeAware => expr.evalWithoutSerialization(input)
+      case expr: Any => expr.toGeometry(input)
+    }
+
     (geometry) match {
       case (geometry: Geometry) => nullSafeEval(geometry)
       case _ => null
     }
   }
 
+  protected def serializeResult(result: Any): Any = {
+    result match {
+      case geometry: Geometry => geometry.toGenericArrayData
+      case _ => result
+    }
+  }
+
   protected def nullSafeEval(geometry: Geometry): Any
 
 
 }
 
-abstract class BinaryGeometryExpression extends Expression with ExpectsInputTypes {
+abstract class BinaryGeometryExpression extends Expression with SerdeAware with ExpectsInputTypes {
   def inputExpressions: Seq[Expression]
 
   override def nullable: Boolean = true
@@ -66,14 +83,36 @@ abstract class BinaryGeometryExpression extends Expression with ExpectsInputType
   override def inputTypes: Seq[AbstractDataType] = Seq(GeometryUDT, GeometryUDT)
 
   override def eval(input: InternalRow): Any = {
-    val leftGeometry = inputExpressions(0).toGeometry(input)
-    val rightGeometry = inputExpressions(1).toGeometry(input)
+    val result = evalWithoutSerialization(input)
+    serializeResult(result)
+  }
+
+  override def evalWithoutSerialization(input: InternalRow): Any = {
+    val leftExpression = inputExpressions(0)
+    val leftGeometry = leftExpression match {
+      case expr: SerdeAware => expr.evalWithoutSerialization(input)
+      case _ => leftExpression.toGeometry(input)
+    }
+
+    val rightExpression = inputExpressions(1)
+    val rightGeometry = rightExpression match {
+      case expr: SerdeAware => expr.evalWithoutSerialization(input)
+      case _ => rightExpression.toGeometry(input)
+    }
+
     (leftGeometry, rightGeometry) match {
       case (leftGeometry: Geometry, rightGeometry: Geometry) => nullSafeEval(leftGeometry, rightGeometry)
       case _ => null
     }
   }
 
+  protected def serializeResult(result: Any): Any = {
+    result match {
+      case geometry: Geometry => geometry.toGenericArrayData
+      case _ => result
+    }
+  }
+
   protected def nullSafeEval(leftGeometry: Geometry, rightGeometry: Geometry): Any
 }
 
@@ -168,7 +207,7 @@ object InferredTypes {
 abstract class InferredUnaryExpression[A1: InferrableType, R: InferrableType]
     (f: (A1) => R)
     (implicit val a1Tag: TypeTag[A1], implicit val rTag: TypeTag[R])
-    extends Expression with ImplicitCastInputTypes with CodegenFallback with Serializable {
+    extends Expression with ImplicitCastInputTypes with SerdeAware with CodegenFallback with Serializable {
   import InferredTypes._
 
   def inputExpressions: Seq[Expression]
@@ -187,10 +226,12 @@ abstract class InferredUnaryExpression[A1: InferrableType, R: InferrableType]
 
   lazy val serialize = buildSerializer[R]
 
-  override def eval(input: InternalRow): Any = {
+  override def eval(input: InternalRow): Any = serialize(evalWithoutSerialization(input).asInstanceOf[R])
+
+  override def evalWithoutSerialization(input: InternalRow): Any = {
     val value = extract(input)
     if (value != null) {
-      serialize(f(value))
+      f(value)
     } else {
       null
     }
@@ -200,7 +241,7 @@ abstract class InferredUnaryExpression[A1: InferrableType, R: InferrableType]
 abstract class InferredBinaryExpression[A1: InferrableType, A2: InferrableType, R: InferrableType]
     (f: (A1, A2) => R)
     (implicit val a1Tag: TypeTag[A1], implicit val a2Tag: TypeTag[A2], implicit val rTag: TypeTag[R])
-    extends Expression with ImplicitCastInputTypes with CodegenFallback with Serializable {
+    extends Expression with ImplicitCastInputTypes with SerdeAware with CodegenFallback with Serializable {
   import InferredTypes._
 
   def inputExpressions: Seq[Expression]
@@ -220,11 +261,13 @@ abstract class InferredBinaryExpression[A1: InferrableType, A2: InferrableType,
 
   lazy val serialize = buildSerializer[R]
 
-  override def eval(input: InternalRow): Any = {
+  override def eval(input: InternalRow): Any = serialize(evalWithoutSerialization(input).asInstanceOf[R])
+
+  override def evalWithoutSerialization(input: InternalRow): Any = {
     val left = extractLeft(input)
     val right = extractRight(input)
     if (left != null && right != null) {
-      serialize(f(left, right))
+        f(left, right)
     } else {
       null
     }
@@ -234,7 +277,7 @@ abstract class InferredBinaryExpression[A1: InferrableType, A2: InferrableType,
 abstract class InferredTernaryExpression[A1: InferrableType, A2: InferrableType, A3: InferrableType, R: InferrableType]
 (f: (A1, A2, A3) => R)
 (implicit val a1Tag: TypeTag[A1], implicit val a2Tag: TypeTag[A2], implicit val a3Tag: TypeTag[A3], implicit val rTag: TypeTag[R])
-  extends Expression with ImplicitCastInputTypes with CodegenFallback with Serializable {
+  extends Expression with ImplicitCastInputTypes with SerdeAware with CodegenFallback with Serializable {
   import InferredTypes._
 
   def inputExpressions: Seq[Expression]
@@ -255,12 +298,14 @@ abstract class InferredTernaryExpression[A1: InferrableType, A2: InferrableType,
 
   lazy val serialize = buildSerializer[R]
 
-  override def eval(input: InternalRow): Any = {
+  override def eval(input: InternalRow): Any = serialize(evalWithoutSerialization(input).asInstanceOf[R])
+
+  override def evalWithoutSerialization(input: InternalRow): Any = {
     val first = extractFirst(input)
     val second = extractSecond(input)
     val third = extractThird(input)
     if (first != null && second != null && third != null) {
-      serialize(f(first, second, third))
+      f(first, second, third)
     } else {
       null
     }
diff --git a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/SerdeAware.scala b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/SerdeAware.scala
new file mode 100644
index 00000000..a46d7e08
--- /dev/null
+++ b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/SerdeAware.scala
@@ -0,0 +1,25 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.spark.sql.sedona_sql.expressions
+
+import org.apache.spark.sql.catalyst.InternalRow
+
+trait SerdeAware {
+  def evalWithoutSerialization(input: InternalRow): Any
+}
diff --git a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/collect/ST_Collect.scala b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/collect/ST_Collect.scala
index 6644560f..c259c366 100644
--- a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/collect/ST_Collect.scala
+++ b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/collect/ST_Collect.scala
@@ -18,6 +18,7 @@
  */
 package org.apache.spark.sql.sedona_sql.expressions.collect
 
+
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
 
@@ -26,16 +27,24 @@ import org.apache.spark.sql.catalyst.expressions.Expression
 import org.apache.spark.sql.catalyst.util.ArrayData
 import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT
 import org.apache.spark.sql.sedona_sql.expressions.implicits._
+import org.apache.spark.sql.sedona_sql.expressions.SerdeAware
 import org.apache.spark.sql.types.{ArrayType, _}
 
+import org.locationtech.jts.geom.Geometry
+
 case class ST_Collect(inputExpressions: Seq[Expression])
     extends Expression
+    with SerdeAware
     with CodegenFallback {
   assert(inputExpressions.length >= 1)
 
   override def nullable: Boolean = true
 
   override def eval(input: InternalRow): Any = {
+    evalWithoutSerialization(input).asInstanceOf[Geometry].toGenericArrayData
+  }
+
+  override def evalWithoutSerialization(input: InternalRow): Any = {
     val firstElement = inputExpressions.head
 
     firstElement.dataType match {
@@ -49,14 +58,13 @@ case class ST_Collect(inputExpressions: Seq[Expression])
               .filter(_ != null)
               .map(_.toGeometry)
 
-            Collect.createMultiGeometry(geomElements).toGenericArrayData
-          case _ => Collect.createMultiGeometry(Seq()).toGenericArrayData
+            Collect.createMultiGeometry(geomElements)
+          case _ => Collect.createMultiGeometry(Seq())
         }
       case _ =>
         val geomElements =
           inputExpressions.map(_.toGeometry(input)).filter(_ != null)
-        Collect.createMultiGeometry(geomElements).toGenericArrayData
-
+        Collect.createMultiGeometry(geomElements)
     }
   }
 
diff --git a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/implicits.scala b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/implicits.scala
index f3ec6162..2bacc766 100644
--- a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/implicits.scala
+++ b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/implicits.scala
@@ -30,9 +30,13 @@ object implicits {
 
   implicit class InputExpressionEnhancer(inputExpression: Expression) {
     def toGeometry(input: InternalRow): Geometry = {
-      inputExpression.eval(input).asInstanceOf[Array[Byte]] match {
-        case binary: Array[Byte] => GeometrySerializer.deserialize(binary)
-        case _ => null
+      if (inputExpression.isInstanceOf[SerdeAware]) {
+        inputExpression.asInstanceOf[SerdeAware].evalWithoutSerialization(input).asInstanceOf[Geometry]
+      } else {
+        inputExpression.eval(input).asInstanceOf[Array[Byte]] match {
+          case binary: Array[Byte] => GeometrySerializer.deserialize(binary)
+          case _ => null
+        }
       }
     }
 
diff --git a/sql/src/test/scala/org/apache/sedona/sql/serdeAwareTest.scala b/sql/src/test/scala/org/apache/sedona/sql/serdeAwareTest.scala
new file mode 100644
index 00000000..ac9ce0b2
--- /dev/null
+++ b/sql/src/test/scala/org/apache/sedona/sql/serdeAwareTest.scala
@@ -0,0 +1,62 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sedona.sql
+
+import org.apache.sedona.common.geometrySerde.GeometrySerializer
+import org.apache.spark.sql.catalyst.expressions.Literal
+import org.apache.spark.sql.sedona_sql.expressions.{ST_Buffer, ST_GeomFromText, ST_Point, ST_Union}
+import org.locationtech.jts.geom.{Coordinate, Geometry, GeometryFactory}
+import org.mockito.ArgumentMatchers.any
+import org.mockito.Mockito.{atMost, mockStatic}
+
+class SerdeAwareFunctionSpec extends TestBaseScala {
+
+  describe("SerdeAwareFunction") {
+    it("should save us some serialization and deserialization cost") {
+      // Mock GeometrySerializer
+      val factory = new GeometryFactory
+      val stubGeom = factory.createPoint(new Coordinate(1, 2))
+      val mocked = mockStatic(classOf[GeometrySerializer])
+      mocked.when(() => GeometrySerializer.deserialize(any(classOf[Array[Byte]]))).thenReturn(stubGeom)
+      mocked.when(() => GeometrySerializer.serialize(any(classOf[Geometry]))).thenReturn(Array[Byte](1, 2, 3))
+
+      val expr = ST_Union(Seq(
+        ST_Buffer(Seq(ST_GeomFromText(Seq(Literal("POINT (1 2)"), Literal(0))), Literal(1.0))),
+        ST_Point(Seq(Literal(1.0), Literal(2.0), Literal(null)))
+      ))
+
+      try {
+        // Evaluate an expression
+        expr.eval(null)
+
+        // Verify number of invocations
+        mocked.verify(
+          () => GeometrySerializer.deserialize(any(classOf[Array[Byte]])),
+          atMost(0))
+        mocked.verify(
+          () => GeometrySerializer.serialize(any(classOf[Geometry])),
+          atMost(1))
+      } finally {
+        // Undo the mock
+        mocked.close()
+      }
+    }
+  }
+}