You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@sedona.apache.org by ji...@apache.org on 2023/04/15 01:11:47 UTC

[sedona] branch master updated: [SEDONA-275] Add raster function RS_SetSRID (#817)

This is an automated email from the ASF dual-hosted git repository.

jiayu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/sedona.git


The following commit(s) were added to refs/heads/master by this push:
     new ed45311f [SEDONA-275] Add raster function RS_SetSRID (#817)
ed45311f is described below

commit ed45311f1401b1dc18bb900b4132a3203fed0326
Author: Martin Andersson <u....@gmail.com>
AuthorDate: Sat Apr 15 03:11:40 2023 +0200

    [SEDONA-275] Add raster function RS_SetSRID (#817)
---
 .../org/apache/sedona/common/raster/Functions.java | 16 +++++++++++
 .../apache/sedona/common/raster/FunctionsTest.java | 15 ++++++++++
 docs/api/sql/Raster-operators.md                   | 14 +++++++++
 .../scala/org/apache/sedona/sql/UDF/Catalog.scala  |  1 +
 .../sedona_sql/expressions/raster/Functions.scala  | 33 ++++++++++++++++++++--
 .../org/apache/sedona/sql/rasteralgebraTest.scala  | 11 ++++++++
 6 files changed, 88 insertions(+), 2 deletions(-)

diff --git a/common/src/main/java/org/apache/sedona/common/raster/Functions.java b/common/src/main/java/org/apache/sedona/common/raster/Functions.java
index a7761bb4..809f9b33 100644
--- a/common/src/main/java/org/apache/sedona/common/raster/Functions.java
+++ b/common/src/main/java/org/apache/sedona/common/raster/Functions.java
@@ -13,11 +13,15 @@
  */
 package org.apache.sedona.common.raster;
 
+import org.geotools.coverage.CoverageFactoryFinder;
 import org.geotools.coverage.grid.GridCoordinates2D;
 import org.geotools.coverage.grid.GridCoverage2D;
+import org.geotools.coverage.grid.GridCoverageFactory;
 import org.geotools.coverage.grid.GridGeometry2D;
+import org.geotools.gce.geotiff.GeoTiffWriter;
 import org.geotools.geometry.DirectPosition2D;
 import org.geotools.geometry.Envelope2D;
+import org.geotools.geometry.jts.ReferencedEnvelope;
 import org.geotools.referencing.CRS;
 import org.geotools.referencing.crs.DefaultEngineeringCRS;
 import org.locationtech.jts.geom.*;
@@ -48,6 +52,35 @@ public class Functions {
         return raster.getNumSampleDimensions();
     }
 
+    /**
+     * Returns a copy of {@code raster} labelled with the coordinate reference system identified by
+     * {@code srid}. Only the CRS changes: no reprojection is performed, and the pixel data, grid
+     * extent, grid-to-world transform, sample dimensions and coverage properties are carried over.
+     *
+     * @param raster the source raster
+     * @param srid an EPSG code, or 0 to assign a generic 2D Cartesian CRS
+     * @return a new raster with the requested CRS
+     * @throws FactoryException if the EPSG code cannot be decoded
+     */
+    public static GridCoverage2D setSrid(GridCoverage2D raster, int srid) throws FactoryException {
+        CoordinateReferenceSystem crs;
+        if (srid == 0) {
+            crs = DefaultEngineeringCRS.CARTESIAN_2D;
+        } else {
+            // Force longitude (x) first: plain CRS.decode returns lat/lon axis order for
+            // geographic codes such as EPSG:4326, which would not match the raster's x/y transform.
+            crs = CRS.decode("EPSG:" + srid, true);
+        }
+        GridGeometry2D gridGeometry = raster.getGridGeometry();
+        // Reuse the existing grid extent and grid-to-CRS transform rather than rebuilding the
+        // coverage from its envelope, which would drop any rotation/skew terms.
+        GridGeometry2D newGridGeometry =
+                new GridGeometry2D(gridGeometry.getGridRange(), gridGeometry.getGridToCRS(), crs);
+        GridCoverageFactory gridCoverageFactory = CoverageFactoryFinder.getGridCoverageFactory(null);
+        return gridCoverageFactory.create(raster.getName().toString(), raster.getRenderedImage(),
+                newGridGeometry, raster.getSampleDimensions(), null, raster.getProperties());
+    }
+
     public static int srid(GridCoverage2D raster) throws FactoryException {
         CoordinateReferenceSystem crs = raster.getCoordinateReferenceSystem();
         if (crs instanceof DefaultEngineeringCRS) {
diff --git a/common/src/test/java/org/apache/sedona/common/raster/FunctionsTest.java b/common/src/test/java/org/apache/sedona/common/raster/FunctionsTest.java
index f9b3a40b..90e0d68f 100644
--- a/common/src/test/java/org/apache/sedona/common/raster/FunctionsTest.java
+++ b/common/src/test/java/org/apache/sedona/common/raster/FunctionsTest.java
@@ -13,6 +13,7 @@
  */
 package org.apache.sedona.common.raster;
 
+import org.geotools.coverage.grid.GridCoverage2D;
 import org.junit.Test;
 import org.locationtech.jts.geom.Coordinate;
 import org.locationtech.jts.geom.Geometry;
@@ -45,6 +46,24 @@ public class FunctionsTest extends RasterTestBase {
         assertEquals(4, Functions.numBands(multiBandRaster));
     }
 
+    /** setSrid must relabel the CRS without changing the raster's footprint or band count. */
+    @Test
+    public void testSetSrid() throws FactoryException {
+        assertEquals(0, Functions.srid(oneBandRaster));
+        assertEquals(4326, Functions.srid(multiBandRaster));
+
+        GridCoverage2D oneBandRasterWithUpdatedSrid = Functions.setSrid(oneBandRaster, 4326);
+        assertEquals(4326, Functions.srid(oneBandRasterWithUpdatedSrid));
+        assertEquals(4326, Functions.envelope(oneBandRasterWithUpdatedSrid).getSRID());
+        assertTrue(Functions.envelope(oneBandRasterWithUpdatedSrid).equalsTopo(Functions.envelope(oneBandRaster)));
+
+        GridCoverage2D multiBandRasterWithUpdatedSrid = Functions.setSrid(multiBandRaster, 0);
+        assertEquals(0, Functions.srid(multiBandRasterWithUpdatedSrid));
+        assertEquals(0, Functions.envelope(multiBandRasterWithUpdatedSrid).getSRID());
+        assertTrue(Functions.envelope(multiBandRasterWithUpdatedSrid).equalsTopo(Functions.envelope(multiBandRaster)));
+        assertEquals(Functions.numBands(multiBandRaster), Functions.numBands(multiBandRasterWithUpdatedSrid));
+    }
+
     @Test
     public void testSrid() throws FactoryException {
         assertEquals(0, Functions.srid(oneBandRaster));
diff --git a/docs/api/sql/Raster-operators.md b/docs/api/sql/Raster-operators.md
index 0388d35a..0b49daf0 100644
--- a/docs/api/sql/Raster-operators.md
+++ b/docs/api/sql/Raster-operators.md
@@ -35,6 +35,20 @@ Output:
 4
 ```
 
+### RS_SetSRID
+
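+Introduction: Sets the spatial reference system identifier (SRID) of the raster. The SRID is interpreted as an EPSG code; 0 assigns a generic 2D Cartesian CRS. The raster data itself is not reprojected.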
+
+Format: `RS_SetSRID (raster: Raster, srid: Integer)`
+
+Since: `v1.4.1`
+
+Spark SQL example:
+```sql
+SELECT RS_SetSRID(raster, 4326)
+FROM raster_table
+```
+
 ### RS_SRID
 
 Introduction: Returns the spatial reference system identifier (SRID) of the raster geometry.
diff --git a/sql/src/main/scala/org/apache/sedona/sql/UDF/Catalog.scala b/sql/src/main/scala/org/apache/sedona/sql/UDF/Catalog.scala
index 109d095d..03dbbab5 100644
--- a/sql/src/main/scala/org/apache/sedona/sql/UDF/Catalog.scala
+++ b/sql/src/main/scala/org/apache/sedona/sql/UDF/Catalog.scala
@@ -173,6 +173,7 @@ object Catalog {
     function[RS_FromGeoTiff](),
     function[RS_Envelope](),
     function[RS_NumBands](),
+    function[RS_SetSRID](),
     function[RS_SRID](),
     function[RS_Value](1),
     function[RS_Values](1)
diff --git a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/Functions.scala b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/Functions.scala
index 62c43a3f..d904bf9d 100644
--- a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/Functions.scala
+++ b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/Functions.scala
@@ -20,16 +20,17 @@
 package org.apache.spark.sql.sedona_sql.expressions.raster
 
 import org.apache.sedona.common.geometrySerde.GeometrySerializer
-import org.apache.sedona.common.raster.Functions
+import org.apache.sedona.common.raster.{Functions, Serde}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
 import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression}
 import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData}
 import org.apache.spark.sql.sedona_sql.UDT.{GeometryUDT, RasterUDT}
-import org.apache.spark.sql.sedona_sql.expressions.UserDataGeneratator
+import org.apache.spark.sql.sedona_sql.expressions.{SerdeAware, UserDataGeneratator}
 import org.apache.spark.sql.sedona_sql.expressions.implicits._
 import org.apache.spark.sql.sedona_sql.expressions.raster.implicits._
 import org.apache.spark.sql.types._
+import org.geotools.coverage.grid.GridCoverage2D
 
 
 
@@ -855,6 +856,42 @@ case class RS_NumBands(inputExpressions: Seq[Expression]) extends Expression wit
   override def inputTypes: Seq[AbstractDataType] = Seq(RasterUDT)
 }
 
+/**
+ * Returns a copy of the input raster whose CRS is replaced by the one identified by an SRID
+ * (an EPSG code, or 0 for a generic 2D Cartesian CRS). Returns null if the raster or SRID is null.
+ */
+case class RS_SetSRID(inputExpressions: Seq[Expression]) extends Expression with CodegenFallback with ExpectsInputTypes with SerdeAware {
+  override def nullable: Boolean = true
+
+  override def eval(input: InternalRow): Any = {
+    Option(evalWithoutSerialization(input)).map(Serde.serialize).orNull
+  }
+
+  override def evalWithoutSerialization(input: InternalRow): GridCoverage2D = {
+    val raster = inputExpressions(0).toRaster(input)
+    if (raster == null) {
+      null
+    } else {
+      // Match on the boxed value: `asInstanceOf[Int]` would unbox a null SRID to 0 and
+      // silently assign a Cartesian CRS instead of propagating the null.
+      inputExpressions(1).eval(input) match {
+        case srid: Int => Functions.setSrid(raster, srid)
+        case _ => null
+      }
+    }
+  }
+
+  override def dataType: DataType = RasterUDT
+
+  override def children: Seq[Expression] = inputExpressions
+
+  protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = {
+    copy(inputExpressions = newChildren)
+  }
+
+  override def inputTypes: Seq[AbstractDataType] = Seq(RasterUDT, IntegerType)
+}
+
 case class RS_SRID(inputExpressions: Seq[Expression]) extends Expression with CodegenFallback with ExpectsInputTypes {
   override def nullable: Boolean = true
 
diff --git a/sql/src/test/scala/org/apache/sedona/sql/rasteralgebraTest.scala b/sql/src/test/scala/org/apache/sedona/sql/rasteralgebraTest.scala
index 09ec4d44..7dcc678d 100644
--- a/sql/src/test/scala/org/apache/sedona/sql/rasteralgebraTest.scala
+++ b/sql/src/test/scala/org/apache/sedona/sql/rasteralgebraTest.scala
@@ -290,6 +290,17 @@ class rasteralgebraTest extends TestBaseScala with BeforeAndAfter with GivenWhen
       assert(result == 1)
     }
 
+    it("Passed RS_SetSRID should handle null values") {
+      val row = sparkSession.sql("select RS_SetSRID(null, 0)").first()
+      assert(row.isNullAt(0))
+    }
+
+    it("Passed RS_SetSRID with raster") {
+      val binaryDf = sparkSession.read.format("binaryFile").load(s"${resourceFolder}raster/test1.tiff")
+      val updatedSrid = binaryDf.selectExpr("RS_SRID(RS_SetSRID(RS_FromGeoTiff(content), 4326))").head().getInt(0)
+      assert(updatedSrid == 4326)
+    }
+
     it("Passed RS_SRID should handle null values") {
       val result = sparkSession.sql("select RS_SRID(null)").first().get(0)
       assert(result == null)