Posted to commits@sedona.apache.org by ji...@apache.org on 2022/04/21 04:35:09 UTC

[incubator-sedona] branch master updated: [SEDONA-108] Write Support for GeoTiff Raster Images (#612)

This is an automated email from the ASF dual-hosted git repository.

jiayu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-sedona.git


The following commit(s) were added to refs/heads/master by this push:
     new 37f495c6 [SEDONA-108] Write Support for GeoTiff Raster Images (#612)
37f495c6 is described below

commit 37f495c6465abd3e2cccfc2c8d62b546b026bf83
Author: Kanchan Chowdhury <43...@users.noreply.github.com>
AuthorDate: Wed Apr 20 21:35:04 2022 -0700

    [SEDONA-108] Write Support for GeoTiff Raster Images (#612)
---
 docs/api/sql/Raster-loader.md                      | 109 ++++++++-
 .../sql/sedona_sql/io/GeotiffFileFormat.scala      | 161 ++++++++++++-
 .../spark/sql/sedona_sql/io/GeotiffSchema.scala    |  41 +++-
 .../{ImageOptions.scala => ImageReadOptions.scala} |  11 +-
 ...{ImageOptions.scala => ImageWriteOptions.scala} |  18 +-
 .../scala/org/apache/sedona/sql/rasterIOTest.scala | 265 ++++++++++++++++++++-
 6 files changed, 578 insertions(+), 27 deletions(-)

diff --git a/docs/api/sql/Raster-loader.md b/docs/api/sql/Raster-loader.md
index 9d0b9b73..08bf0505 100644
--- a/docs/api/sql/Raster-loader.md
+++ b/docs/api/sql/Raster-loader.md
@@ -29,10 +29,38 @@ Output:
  |    |    |-- element: double (containsNull = true)
 ```
 
+There are three more optional parameters for reading GeoTiff:
+
+```html
+ |-- readFromCRS: Coordinate reference system of the geometry coordinates representing the location of the GeoTiff. An example value of readFromCRS is EPSG:4326.
+ |-- readToCRS: If you want to transform the GeoTiff location geometry coordinates to a different coordinate reference system, you can define the target coordinate reference system with this option.
+ |-- disableErrorInCRS: (Default value false) => Indicates whether to ignore errors in CRS transformation.
+```
+
+An example with all GeoTiff read options:
+
+```Scala
+var geotiffDF = sparkSession.read.format("geotiff").option("dropInvalid", true).option("readFromCRS", "EPSG:4499").option("readToCRS", "EPSG:4326").option("disableErrorInCRS", true).load("YOUR_PATH")
+geotiffDF.printSchema()
+```
+
+Output:
+
+```html
+ |-- image: struct (nullable = true)
+ |    |-- origin: string (nullable = true)
+ |    |-- geometry: string (nullable = true)
+ |    |-- height: integer (nullable = true)
+ |    |-- width: integer (nullable = true)
+ |    |-- nBands: integer (nullable = true)
+ |    |-- data: array (nullable = true)
+ |    |    |-- element: double (containsNull = true)
+```
+
 You can also select sub-attributes individually to construct a new DataFrame
 
 ```Scala
-geotiffDF = geotiffDF.selectExpr("image.origin as origin","ST_GeomFromWkt(image.wkt) as Geom", "image.height as height", "image.width as width", "image.data as data", "image.nBands as bands")
+geotiffDF = geotiffDF.selectExpr("image.origin as origin","ST_GeomFromWkt(image.geometry) as Geom", "image.height as height", "image.width as width", "image.data as data", "image.nBands as bands")
 geotiffDF.createOrReplaceTempView("GeotiffDataframe")
 geotiffDF.show()
 ```
@@ -142,3 +170,82 @@ Output:
 +--------------------+
 ```
 
+### GeoTiff Dataframe Writer
+
+Introduction: You can write a GeoTiff dataframe as GeoTiff images using the Spark `write` feature with the format `geotiff`.
+
+Since: `v1.2.1`
+
+Spark SQL example:
+
+The schema of the GeoTiff dataframe to be written can be one of the following two schemas:
+
+```html
+ |-- image: struct (nullable = true)
+ |    |-- origin: string (nullable = true)
+ |    |-- geometry: geometry (nullable = true)
+ |    |-- height: integer (nullable = true)
+ |    |-- width: integer (nullable = true)
+ |    |-- nBands: integer (nullable = true)
+ |    |-- data: array (nullable = true)
+ |    |    |-- element: double (containsNull = true)
+```
+
+or
+
+```html
+ |-- origin: string (nullable = true)
+ |-- geometry: geometry (nullable = true)
+ |-- height: integer (nullable = true)
+ |-- width: integer (nullable = true)
+ |-- nBands: integer (nullable = true)
+ |-- data: array (nullable = true)
+ |    |-- element: double (containsNull = true)
+```
+
+Field names can be renamed, but the schema must exactly match one of the above two schemas. The geometry field may hold either a geometry value or its WKT string representation. The output path should be a directory where the GeoTiff images will be saved. If the directory already exists, `write` must be called in `overwrite` mode.
+
+```Scala
+var dfToWrite = sparkSession.read.format("geotiff").option("dropInvalid", true).option("readToCRS", "EPSG:4326").load("PATH_TO_INPUT_GEOTIFF_IMAGES")
+dfToWrite.write.format("geotiff").save("DESTINATION_PATH")
+```
+
+You can override an existing path with the following approach:
+
+```Scala
+dfToWrite.write.mode("overwrite").format("geotiff").save("DESTINATION_PATH")
+```
+
+You can also extract the columns nested within the `image` column and write the dataframe as GeoTiff images.
+
+```Scala
+dfToWrite = dfToWrite.selectExpr("image.origin as origin","image.geometry as geometry", "image.height as height", "image.width as width", "image.data as data", "image.nBands as nBands")
+dfToWrite.write.mode("overwrite").format("geotiff").save("DESTINATION_PATH")
+```
+
+If you do not want the saved GeoTiff images distributed across multiple partitions, call `coalesce` to merge all files into a single partition.
+
+```Scala
+dfToWrite.coalesce(1).write.mode("overwrite").format("geotiff").save("DESTINATION_PATH")
+```
+
+In case you rename the columns of the GeoTiff dataframe, you can set the corresponding column names with the `option` parameter. All available optional parameters are listed below:
+
+```html
+ |-- writeToCRS: (Default value "EPSG:4326") => Coordinate reference system of the geometry coordinates representing the location of the GeoTiff.
+ |-- fieldImage: (Default value "image") => Indicates the image column of GeoTiff DataFrame.
+ |-- fieldOrigin: (Default value "origin") => Indicates the origin column of GeoTiff DataFrame.
+ |-- fieldNBands: (Default value "nBands") => Indicates the nBands column of GeoTiff DataFrame.
+ |-- fieldWidth: (Default value "width") => Indicates the width column of GeoTiff DataFrame.
+ |-- fieldHeight: (Default value "height") => Indicates the height column of GeoTiff DataFrame.
+ |-- fieldGeometry: (Default value "geometry") => Indicates the geometry column of GeoTiff DataFrame.
+ |-- fieldData: (Default value "data") => Indicates the data column of GeoTiff DataFrame.
+```
+
+An example:
+
+```Scala
+dfToWrite = sparkSession.read.format("geotiff").option("dropInvalid", true).option("readToCRS", "EPSG:4326").load("PATH_TO_INPUT_GEOTIFF_IMAGES")
+dfToWrite = dfToWrite.selectExpr("image.origin as source","ST_GeomFromWkt(image.geometry) as geom", "image.height as height", "image.width as width", "image.data as data", "image.nBands as bands")
+dfToWrite.write.mode("overwrite").format("geotiff").option("writeToCRS", "EPSG:4326").option("fieldOrigin", "source").option("fieldGeometry", "geom").option("fieldNBands", "bands").save("DESTINATION_PATH")
+```
diff --git a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/io/GeotiffFileFormat.scala b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/io/GeotiffFileFormat.scala
index cf9a3140..3e38cbe3 100644
--- a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/io/GeotiffFileFormat.scala
+++ b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/io/GeotiffFileFormat.scala
@@ -20,19 +20,36 @@
 
 package org.apache.spark.sql.sedona_sql.io
 
-
 import com.google.common.io.{ByteStreams, Closeables}
 import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.fs.{FileStatus, Path}
-import org.apache.hadoop.mapreduce.Job
-import org.apache.spark.sql.SparkSession
+import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext}
+import org.apache.sedona.sql.utils.GeometrySerializer
+import org.apache.spark.sql.{Row, SparkSession}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.encoders.RowEncoder
 import org.apache.spark.sql.catalyst.expressions.UnsafeRow
-import org.apache.spark.sql.execution.datasources.{FileFormat, OutputWriterFactory, PartitionedFile}
+import org.apache.spark.sql.catalyst.util.ArrayData
+import org.apache.spark.sql.execution.datasources.{FileFormat, OutputWriter, OutputWriterFactory, PartitionedFile}
 import org.apache.spark.sql.sources.{DataSourceRegister, Filter}
 import org.apache.spark.sql.types.StructType
+import org.apache.spark.unsafe.types.UTF8String
 import org.apache.spark.util.SerializableConfiguration
+import org.geotools.coverage.CoverageFactoryFinder
+import org.geotools.coverage.grid.io.{AbstractGridFormat, GridFormatFinder}
+import org.geotools.gce.geotiff.{GeoTiffFormat, GeoTiffWriteParams, GeoTiffWriter}
+import org.geotools.geometry.jts.ReferencedEnvelope
+import org.geotools.referencing.CRS
+import org.geotools.util.factory.Hints
+import org.locationtech.jts.geom.{Coordinate, Polygon}
+import org.locationtech.jts.io.WKTReader
+import org.opengis.parameter.GeneralParameterValue
+
+import java.awt.image.DataBuffer
+import java.io.IOException
+import java.nio.file.Paths
+import javax.imageio.ImageWriteParam
+import javax.media.jai.RasterFactory
 
 private[spark] class GeotiffFileFormat extends FileFormat with DataSourceRegister {
 
@@ -46,7 +63,18 @@ private[spark] class GeotiffFileFormat extends FileFormat with DataSourceRegiste
                              job: Job,
                              options: Map[String, String],
                              dataSchema: StructType): OutputWriterFactory = {
-    throw new UnsupportedOperationException("Write is not supported for image data source")
+    val imageWriteOptions = new ImageWriteOptions(options)
+    if (!isValidGeoTiffSchema(imageWriteOptions, dataSchema)) {
+      throw new IllegalArgumentException("Invalid GeoTiff Schema")
+    }
+
+    new OutputWriterFactory {
+      override def getFileExtension(context: TaskAttemptContext): String = ""
+
+      override def newInstance(path: String, dataSchema: StructType, context: TaskAttemptContext): OutputWriter = {
+        new GeotiffFileWriter(path, imageWriteOptions, dataSchema, context)
+      }
+    }
   }
 
   override def shortName(): String = "geotiff"
@@ -66,7 +94,7 @@ private[spark] class GeotiffFileFormat extends FileFormat with DataSourceRegiste
     val broadcastedHadoopConf =
       sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf))
 
-    val imageSourceOptions = new ImageOptions(options)
+    val imageSourceOptions = new ImageReadOptions(options)
 
     (file: PartitionedFile) => {
       val emptyUnsafeRow = new UnsafeRow(0)
@@ -83,7 +111,7 @@ private[spark] class GeotiffFileFormat extends FileFormat with DataSourceRegiste
           Closeables.close(stream, true)
         }
 
-        val resultOpt = GeotiffSchema.decode(origin, bytes)
+        val resultOpt = GeotiffSchema.decode(origin, bytes, imageSourceOptions)
         val filteredResult = if (imageSourceOptions.dropInvalid) {
           resultOpt.toIterator
         } else {
@@ -101,4 +129,123 @@ private[spark] class GeotiffFileFormat extends FileFormat with DataSourceRegiste
       }
     }
   }
+
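+  // A dataframe is writable as GeoTiff if it either nests all fields inside a single
+  // struct column (default name "image") with exactly 6 fields, or exposes exactly
+  // 6 top-level fields; individual field names are resolved later from the write options.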
+  private def isValidGeoTiffSchema(imageWriteOptions: ImageWriteOptions, dataSchema: StructType): Boolean = {
+    val fields = dataSchema.fieldNames
+    if (fields.contains(imageWriteOptions.colImage)) {
+      val schemaFields = dataSchema.fields(dataSchema.fieldIndex(imageWriteOptions.colImage)).dataType.asInstanceOf[StructType]
+      if (schemaFields.fieldNames.length != 6) return false
+    }
+    else {
+      if (fields.length != 6) return false
+    }
+    true
+  }
+
+}
+
+// class for writing geoTiff images
+private class GeotiffFileWriter(savePath: String,
+                                imageWriteOptions: ImageWriteOptions,
+                                dataSchema: StructType,
+                                context: TaskAttemptContext) extends OutputWriter {
+
+  // set writing parameters
+  private val DEFAULT_WRITE_PARAMS: GeoTiffWriteParams = new GeoTiffWriteParams()
+  DEFAULT_WRITE_PARAMS.setCompressionMode(ImageWriteParam.MODE_EXPLICIT)
+  DEFAULT_WRITE_PARAMS.setCompressionType("LZW")
+  DEFAULT_WRITE_PARAMS.setCompressionQuality(0.75F)
+  DEFAULT_WRITE_PARAMS.setTilingMode(ImageWriteParam.MODE_EXPLICIT)
+  DEFAULT_WRITE_PARAMS.setTiling(512, 512)
+
+  private val hfs = new Path(savePath).getFileSystem(context.getConfiguration)
+
+  override def write(row: InternalRow): Unit = {
+    // retrieving the metadata of a geotiff image
+    var rowFields: InternalRow = row
+    var schemaFields: StructType = dataSchema
+    val fields = dataSchema.fieldNames
+
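+    // With the nested schema, unwrap the 6-field struct stored in the image column;
+    // otherwise the metadata fields are read directly from the top-level row.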
+    if (fields.contains(imageWriteOptions.colImage)) {
+      schemaFields = dataSchema.fields(dataSchema.fieldIndex(imageWriteOptions.colImage)).dataType.asInstanceOf[StructType]
+      rowFields = row.getStruct(dataSchema.fieldIndex(imageWriteOptions.colImage), 6)
+    }
+
+    val tiffOrigin = rowFields.getString(schemaFields.fieldIndex(imageWriteOptions.colOrigin))
+    val tiffBands = rowFields.getInt(schemaFields.fieldIndex(imageWriteOptions.colBands))
+    val tiffWidth = rowFields.getInt(schemaFields.fieldIndex(imageWriteOptions.colWidth))
+    val tiffHeight = rowFields.getInt(schemaFields.fieldIndex(imageWriteOptions.colHeight))
+    val tiffGeometry = Row.fromSeq(rowFields.toSeq(schemaFields)).get(schemaFields.fieldIndex(imageWriteOptions.colGeometry))
+    val tiffData = rowFields.getArray(schemaFields.fieldIndex(imageWriteOptions.colData)).toDoubleArray()
+
+    // if an image is invalid, fields are -1 and data array is empty. Skip writing that image
+    if (tiffBands == -1) return
+
+    // create a writable raster object
+    val raster = RasterFactory.createBandedRaster(DataBuffer.TYPE_DOUBLE, tiffWidth, tiffHeight, tiffBands, null)
+
+    // extract the pixels of the geotiff image and write to the writable raster
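+    // The flattened data array is band-sequential: band k occupies the range
+    // [k*width*height, (k+1)*width*height), stored row by row within each band.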
+    val pixelVal = Array.ofDim[Double](tiffBands)
+    for (i <- 0 until tiffHeight) {
+      for (j <- 0 until tiffWidth) {
+        for (k <- 0 until tiffBands) {
+          pixelVal(k) = tiffData(tiffHeight*tiffWidth*k + i * tiffWidth + j)
+        }
+        raster.setPixel(j, i, pixelVal)
+      }
+    }
+
+    // CRS is decoded to user-provided option "writeToCRS", default value is "EPSG:4326"
+    val crs = CRS.decode(imageWriteOptions.writeToCRS, true)
+
+    // Extract the geometry coordinates and set the envelope of the geotiff source
+    var coordinateList: Array[Coordinate] = null
+    if (tiffGeometry.isInstanceOf[UTF8String]) {
+      val wktReader = new WKTReader()
+      val envGeom = wktReader.read(tiffGeometry.toString).asInstanceOf[Polygon]
+      coordinateList = envGeom.getCoordinates()
+    } else {
+      val envGeom = GeometrySerializer.deserialize(tiffGeometry.asInstanceOf[ArrayData])
+      coordinateList = envGeom.getCoordinates()
+    }
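+    // Corners 0 and 2 of the bounding polygon produced by Sedona's reader are
+    // diagonally opposite, so they define the full extent of the envelope.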
+    val referencedEnvelope = new ReferencedEnvelope(coordinateList(0).x, coordinateList(2).x, coordinateList(0).y, coordinateList(2).y, crs)
+
+    // create the write path
+    val writePath = Paths.get(savePath, new Path(tiffOrigin).getName).toString
+    val out = hfs.create(new Path(writePath))
+
+    val format = GridFormatFinder.findFormat(out)
+    var hints: Hints = null
+    if (format.isInstanceOf[GeoTiffFormat]) {
+      hints = new Hints(Hints.FORCE_LONGITUDE_FIRST_AXIS_ORDER, true)
+    }
+
+    // create the writer object
+    val factory = CoverageFactoryFinder.getGridCoverageFactory(hints)
+    val gc = factory.create("GRID", raster, referencedEnvelope)
+    val writer = new GeoTiffWriter(out, hints)
+
+    val gtiffParams = new GeoTiffFormat().getWriteParameters
+    gtiffParams.parameter(AbstractGridFormat.GEOTOOLS_WRITE_PARAMS.getName.toString).setValue(DEFAULT_WRITE_PARAMS)
+    val wps: Array[GeneralParameterValue] = gtiffParams.values.toArray(new Array[GeneralParameterValue](1))
+
+    // write the geotiff image to file
+    try {
+      writer.write(gc, wps)
+      writer.dispose()
+      out.close()
+    } catch {
+      case e@(_: IllegalArgumentException | _: IOException) =>
+        // report the failure and skip writing this image
+        e.printStackTrace()
+    }
+  }
+
+  override def close(): Unit = {
+    hfs.close()
+  }
+
+  def path(): String = {
+    savePath
+  }
 }
diff --git a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/io/GeotiffSchema.scala b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/io/GeotiffSchema.scala
index 08234338..5a3a3595 100644
--- a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/io/GeotiffSchema.scala
+++ b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/io/GeotiffSchema.scala
@@ -26,9 +26,11 @@ import org.geotools.coverage.grid.{GridCoordinates2D, GridCoverage2D}
 import org.geotools.gce.geotiff.GeoTiffReader
 import org.geotools.geometry.jts.JTS
 import org.geotools.referencing.CRS
-import org.locationtech.jts.geom.{Coordinate, GeometryFactory}
+import org.locationtech.jts.geom.{Coordinate, GeometryFactory, Polygon}
 import org.opengis.coverage.grid.{GridCoordinates, GridEnvelope}
 import org.opengis.parameter.{GeneralParameterValue, ParameterValue}
+import org.opengis.referencing.crs.CoordinateReferenceSystem
+import org.opengis.referencing.operation.MathTransform
 
 import java.io.ByteArrayInputStream
 
@@ -40,7 +42,7 @@ object GeotiffSchema {
    */
   val columnSchema = StructType(
     StructField("origin", StringType, true) ::
-      StructField("wkt", StringType, true) ::
+      StructField("geometry", StringType, true) ::
       StructField("height", IntegerType, false) ::
       StructField("width", IntegerType, false) ::
       StructField("nBands", IntegerType, false) ::
@@ -117,7 +119,7 @@ object GeotiffSchema {
    *
    */
 
-  private[io] def decode(origin: String, bytes: Array[Byte]): Option[Row] = {
+  private[io] def decode(origin: String, bytes: Array[Byte], imageSourceOptions: ImageReadOptions): Option[Row] = {
 
     val policy: ParameterValue[OverviewPolicy] = AbstractGridFormat.OVERVIEW_POLICY.createValue
     policy.setValue(OverviewPolicy.IGNORE)
@@ -141,9 +143,31 @@ object GeotiffSchema {
     }
 
     // Fetch geometry from given image
-    val source = coverage.getCoordinateReferenceSystem
-    val target = CRS.decode("EPSG:4326", true)
-    val targetCRS = CRS.findMathTransform(source, target)
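+    // Determine the source CRS: prefer the CRS embedded in the GeoTiff metadata,
+    // falling back to the user-provided readFromCRS option when it is missing.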
+    var source: CoordinateReferenceSystem = try {
+      coverage.getCoordinateReferenceSystem
+    }
+    catch {
+      case _: Exception => null
+    }
+    if (source == null && imageSourceOptions.readFromCRS != "") {
+      source = CRS.decode(imageSourceOptions.readFromCRS, true)
+    }
+
+    val target: CoordinateReferenceSystem = if (imageSourceOptions.readToCRS != "") {
+      CRS.decode(imageSourceOptions.readToCRS, true)
+    } else {
+      null
+    }
+
+    var targetCRS: MathTransform = null
+    if (target != null) {
+      if (source == null) {
+        throw new IllegalArgumentException("Invalid arguments. Source coordinate reference system was not found.")
+      } else {
+        targetCRS = CRS.findMathTransform(source, target, imageSourceOptions.disableErrorInCRS)
+      }
+    }
+
     val gridRange2D = coverage.getGridGeometry.getGridRange
     val cords = Array(Array(gridRange2D.getLow(0), gridRange2D.getLow(1)), Array(gridRange2D.getLow(0), gridRange2D.getHigh(1)), Array(gridRange2D.getHigh(0), gridRange2D.getHigh(1)), Array(gridRange2D.getHigh(0), gridRange2D.getLow(1)))
     val polyCoordinates = new Array[Coordinate](5)
@@ -160,7 +184,10 @@ object GeotiffSchema {
 
     polyCoordinates(index) = polyCoordinates(0)
     val factory = new GeometryFactory
-    val polygon = JTS.transform(factory.createPolygon(polyCoordinates), targetCRS)
+    var polygon = factory.createPolygon(polyCoordinates)
+    if (targetCRS != null) {
+      polygon = JTS.transform(polygon, targetCRS).asInstanceOf[Polygon]
+    }
 
     // Fetch band values from given image
     val nBands: Int = coverage.getNumSampleDimensions
diff --git a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/io/ImageOptions.scala b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/io/ImageReadOptions.scala
similarity index 63%
copy from sql/src/main/scala/org/apache/spark/sql/sedona_sql/io/ImageOptions.scala
copy to sql/src/main/scala/org/apache/spark/sql/sedona_sql/io/ImageReadOptions.scala
index 94d21830..f73fc7cf 100644
--- a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/io/ImageOptions.scala
+++ b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/io/ImageReadOptions.scala
@@ -20,11 +20,18 @@ package org.apache.spark.sql.sedona_sql.io
 
 import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
 
-private[io] class ImageOptions(@transient private val parameters: CaseInsensitiveMap[String]) extends Serializable {
+private[io] class ImageReadOptions(@transient private val parameters: CaseInsensitiveMap[String]) extends Serializable {
   def this(parameters: Map[String, String]) = this(CaseInsensitiveMap(parameters))
   /**
-   * Whether to drop invalid images. If true, invalid images will be removed, otherwise
+   * Optional parameters for reading GeoTiff:
+   * dropInvalid indicates whether to drop invalid images. If true, invalid images will be removed; otherwise
   * invalid images will be returned with empty data and all other fields filled with `-1`.
+   * disableErrorInCRS indicates whether to ignore errors in CRS transformation (a lenient transform is used).
+   * readFromCRS and readToCRS indicate the source and target coordinate reference systems, respectively.
    */
   val dropInvalid = parameters.getOrElse("dropInvalid", "false").toBoolean
+  val disableErrorInCRS = parameters.getOrElse("disableErrorInCRS", "false").toBoolean
+  val readFromCRS = parameters.getOrElse("readFromCRS", "")
+  val readToCRS = parameters.getOrElse("readToCRS", "")
+
 }
\ No newline at end of file
diff --git a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/io/ImageOptions.scala b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/io/ImageWriteOptions.scala
similarity index 59%
rename from sql/src/main/scala/org/apache/spark/sql/sedona_sql/io/ImageOptions.scala
rename to sql/src/main/scala/org/apache/spark/sql/sedona_sql/io/ImageWriteOptions.scala
index 94d21830..8653c93a 100644
--- a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/io/ImageOptions.scala
+++ b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/io/ImageWriteOptions.scala
@@ -20,11 +20,17 @@ package org.apache.spark.sql.sedona_sql.io
 
 import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
 
-private[io] class ImageOptions(@transient private val parameters: CaseInsensitiveMap[String]) extends Serializable {
+private[io] class ImageWriteOptions(@transient private val parameters: CaseInsensitiveMap[String]) extends Serializable {
   def this(parameters: Map[String, String]) = this(CaseInsensitiveMap(parameters))
-  /**
-   * Whether to drop invalid images. If true, invalid images will be removed, otherwise
-   * invalid images will be returned with empty data and all other field filled with `-1`.
-   */
-  val dropInvalid = parameters.getOrElse("dropInvalid", "false").toBoolean
+
+  // Optional parameters for writing GeoTiff
+  val writeToCRS = parameters.getOrElse("writeToCRS", "EPSG:4326")
+  val colImage = parameters.getOrElse("fieldImage", "image")
+  val colOrigin = parameters.getOrElse("fieldOrigin", "origin")
+  val colBands = parameters.getOrElse("fieldNBands", "nBands")
+  val colWidth = parameters.getOrElse("fieldWidth", "width")
+  val colHeight = parameters.getOrElse("fieldHeight", "height")
+  val colGeometry = parameters.getOrElse("fieldGeometry", "geometry")
+  val colData = parameters.getOrElse("fieldData", "data")
+
 }
\ No newline at end of file
diff --git a/sql/src/test/scala/org/apache/sedona/sql/rasterIOTest.scala b/sql/src/test/scala/org/apache/sedona/sql/rasterIOTest.scala
index 406f7086..418daea4 100644
--- a/sql/src/test/scala/org/apache/sedona/sql/rasterIOTest.scala
+++ b/sql/src/test/scala/org/apache/sedona/sql/rasterIOTest.scala
@@ -22,6 +22,7 @@ package org.apache.sedona.sql
 import org.locationtech.jts.geom.Geometry
 import org.scalatest.{BeforeAndAfter, GivenWhenThen}
 
+import java.io.File
 import scala.collection.mutable
 
 class rasterIOTest extends TestBaseScala with BeforeAndAfter with GivenWhenThen {
@@ -29,9 +30,72 @@ class rasterIOTest extends TestBaseScala with BeforeAndAfter with GivenWhenThen
   var rasterdatalocation: String = resourceFolder + "raster/"
 
   describe("Raster IO test") {
-    it("Should Pass geotiff loading") {
+    it("Should Pass geotiff loading without readFromCRS and readToCRS") {
       var df = sparkSession.read.format("geotiff").option("dropInvalid", true).load(rasterdatalocation)
-      df = df.selectExpr("image.origin as origin","ST_GeomFromWkt(image.wkt) as Geom", "image.height as height", "image.width as width", "image.data as data", "image.nBands as bands")
+      df = df.selectExpr("image.origin as origin","ST_GeomFromWkt(image.geometry) as Geom", "image.height as height", "image.width as width", "image.data as data", "image.nBands as bands")
+      assert(df.first().getAs[Geometry](1).toText == "POLYGON ((-13095782 4021226.5, -13095782 3983905, -13058822 3983905, -13058822 4021226.5, -13095782 4021226.5))")
+      assert(df.first().getInt(2) == 517)
+      assert(df.first().getInt(3) == 512)
+      assert(df.first().getInt(5) == 1)
+      val blackBand = df.first().getAs[mutable.WrappedArray[Double]](4)
+      val line1 = blackBand.slice(0, 512)
+      val line2 = blackBand.slice(512, 1024)
+      assert(line1(0) == 0.0) // The first value at line 1 is black
+      assert(line2(159) == 0.0 && line2(160) == 123.0) // In the second line, value at 159 is black and at 160 is not black
+    }
+
+    it("Should Pass geotiff loading with readToCRS") {
+      var df = sparkSession.read.format("geotiff").option("dropInvalid", true).option("readToCRS", "EPSG:4326").load(rasterdatalocation)
+      df = df.selectExpr("image.origin as origin","ST_GeomFromWkt(image.geometry) as Geom", "image.height as height", "image.width as width", "image.data as data", "image.nBands as bands")
+      assert(df.first().getAs[Geometry](1).toText == "POLYGON ((-117.64141128097314 33.94356351407699, -117.64141128097314 33.664978146501284, -117.30939395196258 33.664978146501284," +
+        " -117.30939395196258 33.94356351407699, -117.64141128097314 33.94356351407699))")
+      assert(df.first().getInt(2) == 517)
+      assert(df.first().getInt(3) == 512)
+      assert(df.first().getInt(5) == 1)
+      val blackBand = df.first().getAs[mutable.WrappedArray[Double]](4)
+      val line1 = blackBand.slice(0, 512)
+      val line2 = blackBand.slice(512, 1024)
+      assert(line1(0) == 0.0) // The first value at line 1 is black
+      assert(line2(159) == 0.0 && line2(160) == 123.0) // In the second line, value at 159 is black and at 160 is not black
+    }
+
+    it("Should Pass geotiff loading with readFromCRS") {
+      var df = sparkSession.read.format("geotiff").option("dropInvalid", true).option("readFromCRS", "EPSG:4499").load(rasterdatalocation)
+      df = df.selectExpr("image.origin as origin","ST_GeomFromWkt(image.geometry) as Geom", "image.height as height", "image.width as width", "image.data as data", "image.nBands as bands")
+      assert(df.first().getAs[Geometry](1).toText == "POLYGON ((-13095782 4021226.5, -13095782 3983905, -13058822 3983905, -13058822 4021226.5, -13095782 4021226.5))")
+      assert(df.first().getInt(2) == 517)
+      assert(df.first().getInt(3) == 512)
+      assert(df.first().getInt(5) == 1)
+      val blackBand = df.first().getAs[mutable.WrappedArray[Double]](4)
+      val line1 = blackBand.slice(0, 512)
+      val line2 = blackBand.slice(512, 1024)
+      assert(line1(0) == 0.0) // The first value at line 1 is black
+      assert(line2(159) == 0.0 && line2(160) == 123.0) // In the second line, value at 159 is black and at 160 is not black
+    }
+
+    it("Should Pass geotiff loading with readFromCRS and readToCRS") {
+      var df = sparkSession.read.format("geotiff").option("dropInvalid", true).option("readFromCRS", "EPSG:4499").option("readToCRS", "EPSG:4326").load(rasterdatalocation)
+      df = df.selectExpr("image.origin as origin","ST_GeomFromWkt(image.geometry) as Geom", "image.height as height", "image.width as width", "image.data as data", "image.nBands as bands")
+      assert(df.first().getAs[Geometry](1).toText == "POLYGON ((-117.64141128097314 33.94356351407699, -117.64141128097314 33.664978146501284, -117.30939395196258 33.664978146501284," +
+        " -117.30939395196258 33.94356351407699, -117.64141128097314 33.94356351407699))")
+      assert(df.first().getInt(2) == 517)
+      assert(df.first().getInt(3) == 512)
+      assert(df.first().getInt(5) == 1)
+      val blackBand = df.first().getAs[mutable.WrappedArray[Double]](4)
+      val line1 = blackBand.slice(0, 512)
+      val line2 = blackBand.slice(512, 1024)
+      assert(line1(0) == 0.0) // The first value at line 1 is black
+      assert(line2(159) == 0.0 && line2(160) == 123.0) // In the second line, value at 159 is black and at 160 is not black
+    }
+
+    it("Should Pass geotiff loading with all read options") {
+      var df = sparkSession.read.format("geotiff")
+        .option("dropInvalid", true)
+        .option("readFromCRS", "EPSG:4499")
+        .option("readToCRS", "EPSG:4326")
+        .option("disableErrorInCRS", true)
+        .load(rasterdatalocation)
+      df = df.selectExpr("image.origin as origin","ST_GeomFromWkt(image.geometry) as Geom", "image.height as height", "image.width as width", "image.data as data", "image.nBands as bands")
       assert(df.first().getAs[Geometry](1).toText == "POLYGON ((-117.64141128097314 33.94356351407699, -117.64141128097314 33.664978146501284, -117.30939395196258 33.664978146501284," +
         " -117.30939395196258 33.94356351407699, -117.64141128097314 33.94356351407699))")
       assert(df.first().getInt(2) == 517)
@@ -53,7 +117,7 @@ class rasterIOTest extends TestBaseScala with BeforeAndAfter with GivenWhenThen
 
     it("should pass RS_Base64") {
       var df = sparkSession.read.format("geotiff").option("dropInvalid", true).load(resourceFolder + "raster/")
-      df = df.selectExpr("image.origin as origin","ST_GeomFromWkt(image.wkt) as Geom", "image.height as height", "image.width as width", "image.data as data", "image.nBands as bands")
+      df = df.selectExpr("image.origin as origin","ST_GeomFromWkt(image.geometry) as Geom", "image.height as height", "image.width as width", "image.data as data", "image.nBands as bands")
       df = df.selectExpr("RS_GetBand(data, 1, bands) as targetBand", "width","height")
       df.createOrReplaceTempView("geotiff")
       df = sparkSession.sql("Select RS_base64(height, width, targetBand, RS_Array(height*width, 0), RS_Array(height*width, 0)) as encodedstring from geotiff")
@@ -62,7 +126,7 @@ class rasterIOTest extends TestBaseScala with BeforeAndAfter with GivenWhenThen
 
     it("should pass RS_HTML") {
       var df = sparkSession.read.format("geotiff").option("dropInvalid", true).load(resourceFolder + "raster/")
-      df = df.selectExpr("image.origin as origin","ST_GeomFromWkt(image.wkt) as Geom", "image.height as height", "image.width as width", "image.data as data", "image.nBands as bands")
+      df = df.selectExpr("image.origin as origin","ST_GeomFromWkt(image.geometry) as Geom", "image.height as height", "image.width as width", "image.data as data", "image.nBands as bands")
       df = df.selectExpr("RS_GetBand(data, 1, bands) as targetBand", "width","height")
       df.createOrReplaceTempView("geotiff")
       df = sparkSession.sql("Select RS_base64(height, width, targetBand, RS_Array(height*width, 0.0), RS_Array(height*width, 0.0)) as encodedstring from geotiff")
@@ -91,6 +155,199 @@ class rasterIOTest extends TestBaseScala with BeforeAndAfter with GivenWhenThen
       df = df.selectExpr("RS_GetBand(data, 4, bands) as targetBand")
       assert(df.first().getAs[mutable.WrappedArray[Double]](0)(2) == 0.0)
     }
+
+    it("Should Pass geotiff file writing with coalesce") {
+      var df = sparkSession.read.format("geotiff").option("dropInvalid", true).option("readToCRS", "EPSG:4326").load(rasterdatalocation)
+      df = df.selectExpr("image.origin as origin","image.geometry as geometry", "image.height as height", "image.width as width", "image.data as data", "image.nBands as nBands")
+      val savePath = resourceFolder + "raster-written/"
+      df.coalesce(1).write.mode("overwrite").format("geotiff").save(savePath)
+
+      var loadPath = savePath
+      val tempFile = new File(loadPath)
+      val fileList = tempFile.listFiles()
+      for (i <- 0 until fileList.length) {
+        if (fileList(i).isDirectory) loadPath = fileList(i).getAbsolutePath
+      }
+
+      var dfWritten = sparkSession.read.format("geotiff").option("dropInvalid", true).load(loadPath)
+      dfWritten = dfWritten.selectExpr("image.origin as origin","ST_GeomFromWkt(image.geometry) as Geom", "image.height as height", "image.width as width", "image.data as data", "image.nBands as bands")
+      val rowFirst = dfWritten.first()
+      assert(rowFirst.getInt(2) == 517)
+      assert(rowFirst.getInt(3) == 512)
+      assert(rowFirst.getInt(5) == 1)
+
+      val blackBand = rowFirst.getAs[mutable.WrappedArray[Double]](4)
+      val line1 = blackBand.slice(0, 512)
+      val line2 = blackBand.slice(512, 1024)
+      assert(line1(0) == 0.0) // The first value at line 1 is black
+      assert(line2(159) == 0.0 && line2(160) == 123.0) // In the second line, value at 159 is black and at 160 is not black
+    }
+
+    it("Should Pass geotiff file writing with writeToCRS") {
+      var df = sparkSession.read.format("geotiff").option("dropInvalid", true).load(rasterdatalocation)
+      df = df.selectExpr("image.origin as origin","image.geometry as geometry", "image.height as height", "image.width as width", "image.data as data", "image.nBands as nBands")
+      val savePath = resourceFolder + "raster-written/"
+      df.coalesce(1).write.mode("overwrite").format("geotiff").option("writeToCRS", "EPSG:4499").save(savePath)
+
+      var loadPath = savePath
+      val tempFile = new File(loadPath)
+      val fileList = tempFile.listFiles()
+      for (i <- 0 until fileList.length) {
+        if (fileList(i).isDirectory) loadPath = fileList(i).getAbsolutePath
+      }
+
+      var dfWritten = sparkSession.read.format("geotiff").option("dropInvalid", true).load(loadPath)
+      dfWritten = dfWritten.selectExpr("image.origin as origin","ST_GeomFromWkt(image.geometry) as Geom", "image.height as height", "image.width as width", "image.data as data", "image.nBands as bands")
+      val rowFirst = dfWritten.first()
+      assert(rowFirst.getInt(2) == 517)
+      assert(rowFirst.getInt(3) == 512)
+      assert(rowFirst.getInt(5) == 1)
+
+      val blackBand = rowFirst.getAs[mutable.WrappedArray[Double]](4)
+      val line1 = blackBand.slice(0, 512)
+      val line2 = blackBand.slice(512, 1024)
+      assert(line1(0) == 0.0) // The first value at line 1 is black
+      assert(line2(159) == 0.0 && line2(160) == 123.0) // In the second line, value at 159 is black and at 160 is not black
+    }
+
+    it("Should Pass geotiff file writing without coalesce") {
+      var df = sparkSession.read.format("geotiff").option("dropInvalid", true).load(rasterdatalocation)
+      df = df.selectExpr("image.origin as origin","image.geometry as geometry", "image.height as height", "image.width as width", "image.data as data", "image.nBands as nBands")
+      val savePath = resourceFolder + "raster-written/"
+      df.write.mode("overwrite").format("geotiff").save(savePath)
+
+      var imageCount = 0
+      def getFile(loadPath: String): Unit = {
+        val tempFile = new File(loadPath)
+        val fileList = tempFile.listFiles()
+        if (fileList == null) return
+        for (i <- 0 until fileList.length) {
+          if (fileList(i).isDirectory) getFile(fileList(i).getAbsolutePath)
+          else if (fileList(i).getAbsolutePath.endsWith(".tiff") || fileList(i).getAbsolutePath.endsWith(".tif")) imageCount += 1
+        }
+      }
+
+      getFile(savePath)
+      assert(imageCount == 3)
+    }
+
+    it("Should Pass geotiff file writing with nested schema") {
+      val df = sparkSession.read.format("geotiff").option("dropInvalid", true).load(rasterdatalocation)
+      val savePath = resourceFolder + "raster-written/"
+      df.write.mode("overwrite").format("geotiff").save(savePath)
+
+      var imageCount = 0
+      def getFile(loadPath: String): Unit = {
+        val tempFile = new File(loadPath)
+        val fileList = tempFile.listFiles()
+        if (fileList == null) return
+        for (i <- 0 until fileList.length) {
+          if (fileList(i).isDirectory) getFile(fileList(i).getAbsolutePath)
+          else if (fileList(i).getAbsolutePath.endsWith(".tiff") || fileList(i).getAbsolutePath.endsWith(".tif")) imageCount += 1
+        }
+      }
+
+      getFile(savePath)
+      assert(imageCount == 3)
+    }
+
+    it("Should Pass geotiff file writing with renamed fields") {
+      var df = sparkSession.read.format("geotiff").option("dropInvalid", true).load(rasterdatalocation)
+      df = df.selectExpr("image.origin as source","image.geometry as geom", "image.height as height", "image.width as width", "image.data as data", "image.nBands as bands")
+      val savePath = resourceFolder + "raster-written/"
+      df.write
+        .mode("overwrite")
+        .format("geotiff")
+        .option("fieldOrigin", "source")
+        .option("fieldGeometry", "geom")
+        .option("fieldNBands", "bands")
+        .save(savePath)
+
+      var imageCount = 0
+      def getFile(loadPath: String): Unit = {
+        val tempFile = new File(loadPath)
+        val fileList = tempFile.listFiles()
+        if (fileList == null) return
+        for (i <- 0 until fileList.length) {
+          if (fileList(i).isDirectory) getFile(fileList(i).getAbsolutePath)
+          else if (fileList(i).getAbsolutePath.endsWith(".tiff") || fileList(i).getAbsolutePath.endsWith(".tif")) imageCount += 1
+        }
+      }
+
+      getFile(savePath)
+      assert(imageCount == 3)
+    }
+
+    it("Should Pass geotiff file writing with nested schema and renamed fields") {
+      var df = sparkSession.read.format("geotiff").option("dropInvalid", true).load(rasterdatalocation)
+      df = df.selectExpr("image as tiff_image")
+      val savePath = resourceFolder + "raster-written/"
+      df.write
+        .mode("overwrite")
+        .format("geotiff")
+        .option("fieldImage", "tiff_image")
+        .save(savePath)
+
+      var imageCount = 0
+      def getFile(loadPath: String): Unit = {
+        val tempFile = new File(loadPath)
+        val fileList = tempFile.listFiles()
+        if (fileList == null) return
+        for (i <- 0 until fileList.length) {
+          if (fileList(i).isDirectory) getFile(fileList(i).getAbsolutePath)
+          else if (fileList(i).getAbsolutePath.endsWith(".tiff") || fileList(i).getAbsolutePath.endsWith(".tif")) imageCount += 1
+        }
+      }
+
+      getFile(savePath)
+      assert(imageCount == 3)
+    }
+
+    it("Should Pass geotiff file writing with converted geometry") {
+      var df = sparkSession.read.format("geotiff").option("dropInvalid", true).load(rasterdatalocation)
+      df = df.selectExpr("image.origin as source","ST_GeomFromWkt(image.geometry) as geom", "image.height as height", "image.width as width", "image.data as data", "image.nBands as bands")
+      val savePath = resourceFolder + "raster-written/"
+      df.write
+        .mode("overwrite")
+        .format("geotiff")
+        .option("fieldOrigin", "source")
+        .option("fieldGeometry", "geom")
+        .option("fieldNBands", "bands")
+        .save(savePath)
+
+      var imageCount = 0
+      def getFile(loadPath: String): Unit = {
+        val tempFile = new File(loadPath)
+        val fileList = tempFile.listFiles()
+        if (fileList == null) return
+        for (i <- 0 until fileList.length) {
+          if (fileList(i).isDirectory) getFile(fileList(i).getAbsolutePath)
+          else if (fileList(i).getAbsolutePath.endsWith(".tiff") || fileList(i).getAbsolutePath.endsWith(".tif")) imageCount += 1
+        }
+      }
+
+      getFile(savePath)
+      assert(imageCount == 3)
+    }
+
+    it("Should Pass geotiff file writing with handling invalid schema") {
+      var df = sparkSession.read.format("geotiff").option("dropInvalid", true).load(rasterdatalocation)
+      df = df.selectExpr("image.origin as origin","image.geometry as geometry", "image.height as height", "image.width as width", "image.data as data")
+      val savePath = resourceFolder + "raster-written/"
+
+      try {
+        df.write
+          .mode("overwrite")
+          .format("geotiff")
+          .option("fieldImage", "tiff_image")
+          .save(savePath)
+      }
+      catch {
+        case e: IllegalArgumentException => {
+          assert(e.getMessage == "Invalid GeoTiff Schema")
+        }
+      }
+    }
     
   }
 }