You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@sedona.apache.org by ji...@apache.org on 2023/03/25 03:29:55 UTC

[sedona] branch master updated: [SEDONA-270] Remove redundant serialization for rasters (#810)

This is an automated email from the ASF dual-hosted git repository.

jiayu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/sedona.git


The following commit(s) were added to refs/heads/master by this push:
     new 0579e9a5 [SEDONA-270] Remove redundant serialization for rasters (#810)
0579e9a5 is described below

commit 0579e9a5b3039066eade44f404ef286a067416b0
Author: Martin Andersson <u....@gmail.com>
AuthorDate: Sat Mar 25 04:29:49 2023 +0100

    [SEDONA-270] Remove redundant serialization for rasters (#810)
---
 .../expressions/raster/Constructors.scala          | 42 +++++++++++++--------
 .../sedona_sql/expressions/raster/implicits.scala  | 11 ++++--
 .../org/apache/sedona/sql/serdeAwareTest.scala     | 43 +++++++++++++++++++++-
 3 files changed, 77 insertions(+), 19 deletions(-)

diff --git a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/Constructors.scala b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/Constructors.scala
index 37fb93ab..c71f9963 100644
--- a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/Constructors.scala
+++ b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/Constructors.scala
@@ -18,26 +18,24 @@
  */
 package org.apache.spark.sql.sedona_sql.expressions.raster
 
-import org.apache.sedona.common.raster.Constructors
+import org.apache.sedona.common.raster.{Constructors, Serde}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression}
 import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
 import org.apache.spark.sql.sedona_sql.UDT.RasterUDT
+import org.apache.spark.sql.sedona_sql.expressions.SerdeAware
 import org.apache.spark.sql.types.{AbstractDataType, BinaryType, DataType}
 import org.apache.spark.sql.sedona_sql.expressions.raster.implicits._
+import org.geotools.coverage.grid.GridCoverage2D
 
 
-case class RS_FromArcInfoAsciiGrid(inputExpressions: Seq[Expression]) extends Expression with CodegenFallback with ExpectsInputTypes {
+case class RS_FromArcInfoAsciiGrid(inputExpressions: Seq[Expression]) extends Expression with CodegenFallback
+  with ExpectsInputTypes with SerdeAware {
 
   override def nullable: Boolean = true
 
   override def eval(input: InternalRow): Any = {
-    val bytes = inputExpressions(0).eval(input).asInstanceOf[Array[Byte]]
-    if (bytes == null) {
-      null
-    } else {
-      Constructors.fromArcInfoAsciiGrid(bytes).serialize
-    }
+    Option(evalWithoutSerialization(input)).map(Serde.serialize).orNull
   }
 
   override def dataType: DataType = RasterUDT
@@ -49,20 +47,25 @@ case class RS_FromArcInfoAsciiGrid(inputExpressions: Seq[Expression]) extends Ex
   }
 
   override def inputTypes: Seq[AbstractDataType] = Seq(BinaryType)
-}
 
-case class RS_FromGeoTiff(inputExpressions: Seq[Expression]) extends Expression with CodegenFallback with ExpectsInputTypes {
-
-  override def nullable: Boolean = true
-
-  override def eval(input: InternalRow): Any = {
+  override def evalWithoutSerialization(input: InternalRow): GridCoverage2D = {
     val bytes = inputExpressions(0).eval(input).asInstanceOf[Array[Byte]]
     if (bytes == null) {
       null
     } else {
-      Constructors.fromGeoTiff(bytes).serialize
+      Constructors.fromArcInfoAsciiGrid(bytes)
     }
   }
+}
+
+case class RS_FromGeoTiff(inputExpressions: Seq[Expression]) extends Expression with CodegenFallback
+  with ExpectsInputTypes with SerdeAware {
+
+  override def nullable: Boolean = true
+
+  override def eval(input: InternalRow): Any = {
+    Option(evalWithoutSerialization(input)).map(Serde.serialize).orNull
+  }
 
   override def dataType: DataType = RasterUDT
 
@@ -73,4 +76,13 @@ case class RS_FromGeoTiff(inputExpressions: Seq[Expression]) extends Expression
   }
 
   override def inputTypes: Seq[AbstractDataType] = Seq(BinaryType)
+
+  override def evalWithoutSerialization(input: InternalRow): GridCoverage2D = {
+    val bytes = inputExpressions(0).eval(input).asInstanceOf[Array[Byte]]
+    if (bytes == null) {
+      null
+    } else {
+      Constructors.fromGeoTiff(bytes)
+    }
+  }
 }
diff --git a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/implicits.scala b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/implicits.scala
index c4f4931a..f1e1c6bf 100644
--- a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/implicits.scala
+++ b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/implicits.scala
@@ -21,15 +21,20 @@ package org.apache.spark.sql.sedona_sql.expressions.raster
 import org.apache.sedona.common.raster.Serde
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.sedona_sql.expressions.SerdeAware
 import org.geotools.coverage.grid.GridCoverage2D
 
 object implicits {
 
   implicit class RasterInputExpressionEnhancer(inputExpression: Expression) {
     def toRaster(input: InternalRow): GridCoverage2D = {
-      inputExpression.eval(input).asInstanceOf[Array[Byte]] match {
-        case binary: Array[Byte] => Serde.deserialize(binary)
-        case _ => null
+      if (inputExpression.isInstanceOf[SerdeAware]) {
+        inputExpression.asInstanceOf[SerdeAware].evalWithoutSerialization(input).asInstanceOf[GridCoverage2D]
+      } else {
+        inputExpression.eval(input).asInstanceOf[Array[Byte]] match {
+          case binary: Array[Byte] => Serde.deserialize(binary)
+          case _ => null
+        }
       }
     }
   }
diff --git a/sql/src/test/scala/org/apache/sedona/sql/serdeAwareTest.scala b/sql/src/test/scala/org/apache/sedona/sql/serdeAwareTest.scala
index ac9ce0b2..27f0ca54 100644
--- a/sql/src/test/scala/org/apache/sedona/sql/serdeAwareTest.scala
+++ b/sql/src/test/scala/org/apache/sedona/sql/serdeAwareTest.scala
@@ -20,8 +20,12 @@
 package org.apache.sedona.sql
 
 import org.apache.sedona.common.geometrySerde.GeometrySerializer
+import org.apache.sedona.common.raster.Constructors.fromArcInfoAsciiGrid
+import org.apache.sedona.common.raster.Serde
 import org.apache.spark.sql.catalyst.expressions.Literal
 import org.apache.spark.sql.sedona_sql.expressions.{ST_Buffer, ST_GeomFromText, ST_Point, ST_Union}
+import org.apache.spark.sql.sedona_sql.expressions.raster.{RS_FromArcInfoAsciiGrid, RS_NumBands}
+import org.geotools.coverage.grid.GridCoverage2D
 import org.locationtech.jts.geom.{Coordinate, Geometry, GeometryFactory}
 import org.mockito.ArgumentMatchers.any
 import org.mockito.Mockito.{atMost, mockStatic}
@@ -29,7 +33,7 @@ import org.mockito.Mockito.{atMost, mockStatic}
 class SerdeAwareFunctionSpec extends TestBaseScala {
 
   describe("SerdeAwareFunction") {
-    it("should save us some serialization and deserialization cost") {
+    it("should save us some serialization and deserialization cost for geometries") {
       // Mock GeometrySerializer
       val factory = new GeometryFactory
       val stubGeom = factory.createPoint(new Coordinate(1, 2))
@@ -58,5 +62,42 @@ class SerdeAwareFunctionSpec extends TestBaseScala {
         mocked.close()
       }
     }
+
+    it("should save us some serialization and deserialization cost for rasters") {
+      // Mock RasterSerializer
+      val ascGrid =
+        """
+          |NCOLS 2
+          |NROWS 2
+          |XLLCORNER 378922
+          |YLLCORNER 4072345
+          |CELLSIZE 30
+          |NODATA_VALUE 0
+          |0 1 2 3
+          |""".stripMargin
+      val mocked = mockStatic(classOf[Serde])
+      mocked.when(() => Serde.deserialize(any(classOf[Array[Byte]]))).thenReturn(fromArcInfoAsciiGrid(ascGrid.getBytes))
+      mocked.when(() => Serde.serialize(any(classOf[GridCoverage2D]))).thenReturn(Array[Byte](1, 2, 3))
+
+      val expr = RS_NumBands(Seq(
+        RS_FromArcInfoAsciiGrid(Seq(Literal(ascGrid.getBytes)))
+      ))
+
+      try {
+        // Evaluate an expression
+        expr.eval(null)
+
+        // Verify number of invocations
+        mocked.verify(
+          () => Serde.deserialize(any(classOf[Array[Byte]])),
+          atMost(0))
+        mocked.verify(
+          () => Serde.serialize(any(classOf[GridCoverage2D])),
+          atMost(1))
+      } finally {
+        // Undo the mock
+        mocked.close()
+      }
+    }
   }
 }