Posted to commits@sedona.apache.org by ji...@apache.org on 2022/04/21 04:35:09 UTC
[incubator-sedona] branch master updated: [SEDONA-108] Write Support for GeoTiff Raster Images (#612)
This is an automated email from the ASF dual-hosted git repository.
jiayu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-sedona.git
The following commit(s) were added to refs/heads/master by this push:
new 37f495c6 [SEDONA-108] Write Support for GeoTiff Raster Images (#612)
37f495c6 is described below
commit 37f495c6465abd3e2cccfc2c8d62b546b026bf83
Author: Kanchan Chowdhury <43...@users.noreply.github.com>
AuthorDate: Wed Apr 20 21:35:04 2022 -0700
[SEDONA-108] Write Support for GeoTiff Raster Images (#612)
---
docs/api/sql/Raster-loader.md | 109 ++++++++-
.../sql/sedona_sql/io/GeotiffFileFormat.scala | 161 ++++++++++++-
.../spark/sql/sedona_sql/io/GeotiffSchema.scala | 41 +++-
.../{ImageOptions.scala => ImageReadOptions.scala} | 11 +-
...{ImageOptions.scala => ImageWriteOptions.scala} | 18 +-
.../scala/org/apache/sedona/sql/rasterIOTest.scala | 265 ++++++++++++++++++++-
6 files changed, 578 insertions(+), 27 deletions(-)
diff --git a/docs/api/sql/Raster-loader.md b/docs/api/sql/Raster-loader.md
index 9d0b9b73..08bf0505 100644
--- a/docs/api/sql/Raster-loader.md
+++ b/docs/api/sql/Raster-loader.md
@@ -29,10 +29,38 @@ Output:
| | |-- element: double (containsNull = true)
```
+There are three more optional parameters for reading GeoTiff:
+
+```html
+ |-- readFromCRS: Coordinate reference system of the geometry coordinates representing the location of the GeoTiff. It is used only when the GeoTiff file does not specify its own CRS. An example value of readFromCRS is EPSG:4326.
+ |-- readToCRS: If you want to transform the GeoTiff location geometry coordinates to a different coordinate reference system, define the target coordinate reference system with this option. If it is not set, no transformation is applied.
+ |-- disableErrorInCRS: (Default value false) => Indicates whether to ignore errors in CRS transformation.
+```
+
+An example with all GeoTiff read options:
+
+```Scala
+var geotiffDF = sparkSession.read.format("geotiff").option("dropInvalid", true).option("readFromCRS", "EPSG:4499").option("readToCRS", "EPSG:4326").option("disableErrorInCRS", true).load("YOUR_PATH")
+geotiffDF.printSchema()
+```
+
+Output:
+
+```html
+ |-- image: struct (nullable = true)
+ | |-- origin: string (nullable = true)
+ | |-- geometry: string (nullable = true)
+ | |-- height: integer (nullable = true)
+ | |-- width: integer (nullable = true)
+ | |-- nBands: integer (nullable = true)
+ | |-- data: array (nullable = true)
+ | | |-- element: double (containsNull = true)
+```
+
You can also select sub-attributes individually to construct a new DataFrame
```Scala
-geotiffDF = geotiffDF.selectExpr("image.origin as origin","ST_GeomFromWkt(image.wkt) as Geom", "image.height as height", "image.width as width", "image.data as data", "image.nBands as bands")
+geotiffDF = geotiffDF.selectExpr("image.origin as origin","ST_GeomFromWkt(image.geometry) as Geom", "image.height as height", "image.width as width", "image.data as data", "image.nBands as bands")
geotiffDF.createOrReplaceTempView("GeotiffDataframe")
geotiffDF.show()
```
@@ -142,3 +170,82 @@ Output:
+--------------------+
```
+### Geotiff Dataframe Writer
+
+Introduction: You can write a GeoTiff dataframe as GeoTiff images using the Spark `write` API with the format `geotiff`.
+
+Since: `v1.2.1`
+
+Spark SQL example:
+
+The schema of the GeoTiff dataframe to be written must match one of the following two schemas:
+
+```html
+ |-- image: struct (nullable = true)
+ | |-- origin: string (nullable = true)
+ | |-- geometry: geometry (nullable = true)
+ | |-- height: integer (nullable = true)
+ | |-- width: integer (nullable = true)
+ | |-- nBands: integer (nullable = true)
+ | |-- data: array (nullable = true)
+ | | |-- element: double (containsNull = true)
+```
+
+or
+
+```html
+ |-- origin: string (nullable = true)
+ |-- geometry: geometry (nullable = true)
+ |-- height: integer (nullable = true)
+ |-- width: integer (nullable = true)
+ |-- nBands: integer (nullable = true)
+ |-- data: array (nullable = true)
+ | |-- element: double (containsNull = true)
+```
+
+Field names can be renamed (see the options below), but the schema must otherwise match one of the two schemas above. The geometry field may hold either a WKT string or a geometry. The output path is a directory in which the GeoTiff images will be saved. If the directory already exists, call `write` in `overwrite` mode.
+
+```Scala
+var dfToWrite = sparkSession.read.format("geotiff").option("dropInvalid", true).option("readToCRS", "EPSG:4326").load("PATH_TO_INPUT_GEOTIFF_IMAGES")
+dfToWrite.write.format("geotiff").save("DESTINATION_PATH")
+```
+
+You can overwrite an existing path as follows:
+
+```Scala
+dfToWrite.write.mode("overwrite").format("geotiff").save("DESTINATION_PATH")
+```
+
+You can also extract the columns nested within the `image` column and write the resulting dataframe as GeoTiff images.
+
+```Scala
+dfToWrite = dfToWrite.selectExpr("image.origin as origin","image.geometry as geometry", "image.height as height", "image.width as width", "image.data as data", "image.nBands as nBands")
+dfToWrite.write.mode("overwrite").format("geotiff").save("DESTINATION_PATH")
+```
+
+If you want all GeoTiff images to be written by a single task instead of being spread across multiple partitions, call `coalesce(1)` to merge all partitions into one.
+
+```Scala
+dfToWrite.coalesce(1).write.mode("overwrite").format("geotiff").save("DESTINATION_PATH")
+```
+
+If you rename the columns of the GeoTiff dataframe, set the corresponding column names with the `option` parameter. All available optional parameters are listed below:
+
+```html
+ |-- writeToCRS: (Default value "EPSG:4326") => Coordinate reference system of the geometry coordinates representing the location of the GeoTiff.
+ |-- fieldImage: (Default value "image") => Indicates the image column of the GeoTiff DataFrame.
+ |-- fieldOrigin: (Default value "origin") => Indicates the origin column of the GeoTiff DataFrame.
+ |-- fieldNBands: (Default value "nBands") => Indicates the nBands column of the GeoTiff DataFrame.
+ |-- fieldWidth: (Default value "width") => Indicates the width column of the GeoTiff DataFrame.
+ |-- fieldHeight: (Default value "height") => Indicates the height column of the GeoTiff DataFrame.
+ |-- fieldGeometry: (Default value "geometry") => Indicates the geometry column of the GeoTiff DataFrame.
+ |-- fieldData: (Default value "data") => Indicates the data column of the GeoTiff DataFrame.
+```
+
+An example:
+
+```Scala
+dfToWrite = sparkSession.read.format("geotiff").option("dropInvalid", true).option("readToCRS", "EPSG:4326").load("PATH_TO_INPUT_GEOTIFF_IMAGES")
+dfToWrite = dfToWrite.selectExpr("image.origin as source","ST_GeomFromWkt(image.geometry) as geom", "image.height as height", "image.width as width", "image.data as data", "image.nBands as bands")
+dfToWrite.write.mode("overwrite").format("geotiff").option("writeToCRS", "EPSG:4326").option("fieldOrigin", "source").option("fieldGeometry", "geom").option("fieldNBands", "bands").save("DESTINATION_PATH")
+```
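The GeoTiff writer added in this commit (`GeotiffFileWriter.write`) rebuilds each raster from the flat `data` column, reading band `k`, row `i`, column `j` at index `height*width*k + i*width + j`. The sketch below illustrates that band-major layout in plain Scala; `BandMajorLayout` and its methods are hypothetical helpers for illustration, not part of Sedona.

```scala
// Hypothetical helpers (not part of Sedona) illustrating the band-major layout
// of the `data` column: all of band 0 row by row, then all of band 1, and so on.
object BandMajorLayout {
  /** Flat index of band `k`, row `i`, column `j`; same expression as the writer. */
  def index(width: Int, height: Int, k: Int, i: Int, j: Int): Int =
    height * width * k + i * width + j

  /** Values of every band for the pixel at row `i`, column `j`. */
  def pixel(data: Array[Double], width: Int, height: Int, nBands: Int, i: Int, j: Int): Array[Double] =
    Array.tabulate(nBands)(k => data(index(width, height, k, i, j)))

  /** Flattens bands(k)(i)(j) into a single band-major array. */
  def flatten(bands: Array[Array[Array[Double]]]): Array[Double] =
    bands.flatten.flatten
}
```

For the 517 x 512 single-band test image, `index(512, 517, 0, 1, 160)` is 672, which is position 160 of the second row that the tests check.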
diff --git a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/io/GeotiffFileFormat.scala b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/io/GeotiffFileFormat.scala
index cf9a3140..3e38cbe3 100644
--- a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/io/GeotiffFileFormat.scala
+++ b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/io/GeotiffFileFormat.scala
@@ -20,19 +20,36 @@
package org.apache.spark.sql.sedona_sql.io
-
import com.google.common.io.{ByteStreams, Closeables}
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileStatus, Path}
-import org.apache.hadoop.mapreduce.Job
-import org.apache.spark.sql.SparkSession
+import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext}
+import org.apache.sedona.sql.utils.GeometrySerializer
+import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
-import org.apache.spark.sql.execution.datasources.{FileFormat, OutputWriterFactory, PartitionedFile}
+import org.apache.spark.sql.catalyst.util.ArrayData
+import org.apache.spark.sql.execution.datasources.{FileFormat, OutputWriter, OutputWriterFactory, PartitionedFile}
import org.apache.spark.sql.sources.{DataSourceRegister, Filter}
import org.apache.spark.sql.types.StructType
+import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.SerializableConfiguration
+import org.geotools.coverage.CoverageFactoryFinder
+import org.geotools.coverage.grid.io.{AbstractGridFormat, GridFormatFinder}
+import org.geotools.gce.geotiff.{GeoTiffFormat, GeoTiffWriteParams, GeoTiffWriter}
+import org.geotools.geometry.jts.ReferencedEnvelope
+import org.geotools.referencing.CRS
+import org.geotools.util.factory.Hints
+import org.locationtech.jts.geom.{Coordinate, Polygon}
+import org.locationtech.jts.io.WKTReader
+import org.opengis.parameter.GeneralParameterValue
+
+import java.awt.image.DataBuffer
+import java.io.IOException
+import java.nio.file.Paths
+import javax.imageio.ImageWriteParam
+import javax.media.jai.RasterFactory
private[spark] class GeotiffFileFormat extends FileFormat with DataSourceRegister {
@@ -46,7 +63,18 @@ private[spark] class GeotiffFileFormat extends FileFormat with DataSourceRegiste
job: Job,
options: Map[String, String],
dataSchema: StructType): OutputWriterFactory = {
- throw new UnsupportedOperationException("Write is not supported for image data source")
+ val imageWriteOptions = new ImageWriteOptions(options)
+ if (!isValidGeoTiffSchema(imageWriteOptions, dataSchema)) {
+ throw new IllegalArgumentException("Invalid GeoTiff Schema")
+ }
+
+ new OutputWriterFactory {
+ override def getFileExtension(context: TaskAttemptContext): String = ""
+
+ override def newInstance(path: String, dataSchema: StructType, context: TaskAttemptContext): OutputWriter = {
+ new GeotiffFileWriter(path, imageWriteOptions, dataSchema, context)
+ }
+ }
}
override def shortName(): String = "geotiff"
@@ -66,7 +94,7 @@ private[spark] class GeotiffFileFormat extends FileFormat with DataSourceRegiste
val broadcastedHadoopConf =
sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf))
- val imageSourceOptions = new ImageOptions(options)
+ val imageSourceOptions = new ImageReadOptions(options)
(file: PartitionedFile) => {
val emptyUnsafeRow = new UnsafeRow(0)
@@ -83,7 +111,7 @@ private[spark] class GeotiffFileFormat extends FileFormat with DataSourceRegiste
Closeables.close(stream, true)
}
- val resultOpt = GeotiffSchema.decode(origin, bytes)
+ val resultOpt = GeotiffSchema.decode(origin, bytes, imageSourceOptions)
val filteredResult = if (imageSourceOptions.dropInvalid) {
resultOpt.toIterator
} else {
@@ -101,4 +129,123 @@ private[spark] class GeotiffFileFormat extends FileFormat with DataSourceRegiste
}
}
}
+
+ private def isValidGeoTiffSchema(imageWriteOptions: ImageWriteOptions, dataSchema: StructType): Boolean = {
+ val fields = dataSchema.fieldNames
+ if (fields.contains(imageWriteOptions.colImage) ){
+ val schemaFields = dataSchema.fields(dataSchema.fieldIndex(imageWriteOptions.colImage)).dataType.asInstanceOf[StructType]
+ if (schemaFields.fieldNames.length != 6) return false
+ }
+ else {
+ if (fields.length != 6) return false
+ }
+ true
+ }
+
+}
+
+// class for writing GeoTiff images
+private class GeotiffFileWriter(savePath: String,
+ imageWriteOptions: ImageWriteOptions,
+ dataSchema: StructType,
+ context: TaskAttemptContext) extends OutputWriter {
+
+ // set writing parameters
+ private val DEFAULT_WRITE_PARAMS: GeoTiffWriteParams = new GeoTiffWriteParams()
+ DEFAULT_WRITE_PARAMS.setCompressionMode(ImageWriteParam.MODE_EXPLICIT)
+ DEFAULT_WRITE_PARAMS.setCompressionType("LZW")
+ DEFAULT_WRITE_PARAMS.setCompressionQuality(0.75F)
+ DEFAULT_WRITE_PARAMS.setTilingMode(ImageWriteParam.MODE_EXPLICIT)
+ DEFAULT_WRITE_PARAMS.setTiling(512, 512)
+
+ private val hfs = new Path(savePath).getFileSystem(context.getConfiguration)
+
+ override def write(row: InternalRow): Unit = {
+ // retrieving the metadata of a geotiff image
+ var rowFields: InternalRow = row
+ var schemaFields: StructType = dataSchema
+ val fields = dataSchema.fieldNames
+
+ if (fields.contains(imageWriteOptions.colImage)) {
+ schemaFields = dataSchema.fields(dataSchema.fieldIndex(imageWriteOptions.colImage)).dataType.asInstanceOf[StructType]
+ rowFields = row.getStruct(dataSchema.fieldIndex(imageWriteOptions.colImage), 6)
+ }
+
+ val tiffOrigin = rowFields.getString(schemaFields.fieldIndex(imageWriteOptions.colOrigin))
+ val tiffBands = rowFields.getInt(schemaFields.fieldIndex(imageWriteOptions.colBands))
+ val tiffWidth = rowFields.getInt(schemaFields.fieldIndex(imageWriteOptions.colWidth))
+ val tiffHeight = rowFields.getInt(schemaFields.fieldIndex(imageWriteOptions.colHeight))
+ val tiffGeometry = Row.fromSeq(rowFields.toSeq(schemaFields)).get(schemaFields.fieldIndex(imageWriteOptions.colGeometry))
+ val tiffData = rowFields.getArray(schemaFields.fieldIndex(imageWriteOptions.colData)).toDoubleArray()
+
+ // if an image is invalid, fields are -1 and data array is empty. Skip writing that image
+ if (tiffBands == -1) return
+
+ // create a writable raster object
+ val raster = RasterFactory.createBandedRaster(DataBuffer.TYPE_DOUBLE, tiffWidth, tiffHeight, tiffBands, null)
+
+ // extract the pixels of the geotiff image and write to the writable raster
+ val pixelVal = Array.ofDim[Double](tiffBands)
+ for (i <- 0 until tiffHeight) {
+ for (j <- 0 until tiffWidth) {
+ for (k <- 0 until tiffBands) {
+ pixelVal(k) = tiffData(tiffHeight*tiffWidth*k + i * tiffWidth + j)
+ }
+ raster.setPixel(j, i, pixelVal)
+ }
+ }
+
+ // CRS is decoded to user-provided option "writeToCRS", default value is "EPSG:4326"
+ val crs = CRS.decode(imageWriteOptions.writeToCRS, true)
+
+ // Extract the geometry coordinates and set the envelope of the GeoTiff source
+ var coordinateList: Array[Coordinate] = null
+ if (tiffGeometry.isInstanceOf[UTF8String]) {
+ val wktReader = new WKTReader()
+ val envGeom = wktReader.read(tiffGeometry.toString).asInstanceOf[Polygon]
+ coordinateList = envGeom.getCoordinates()
+ } else {
+ val envGeom = GeometrySerializer.deserialize(tiffGeometry.asInstanceOf[ArrayData])
+ coordinateList = envGeom.getCoordinates()
+ }
+ val referencedEnvelope = new ReferencedEnvelope(coordinateList(0).x, coordinateList(2).x, coordinateList(0).y, coordinateList(2).y, crs)
+
+ // create the write path
+ val writePath = Paths.get(savePath, new Path(tiffOrigin).getName).toString
+ val out = hfs.create(new Path(writePath))
+
+ val format = GridFormatFinder.findFormat(out)
+ var hints: Hints = null
+ if (format.isInstanceOf[GeoTiffFormat]) {
+ hints = new Hints(Hints.FORCE_LONGITUDE_FIRST_AXIS_ORDER, true)
+ }
+
+ // create the writer object
+ val factory = CoverageFactoryFinder.getGridCoverageFactory(hints)
+ val gc = factory.create("GRID", raster, referencedEnvelope)
+ val writer = new GeoTiffWriter(out, hints)
+
+ val gtiffParams = new GeoTiffFormat().getWriteParameters
+ gtiffParams.parameter(AbstractGridFormat.GEOTOOLS_WRITE_PARAMS.getName.toString).setValue(DEFAULT_WRITE_PARAMS)
+ val wps: Array[GeneralParameterValue] = gtiffParams.values.toArray(new Array[GeneralParameterValue](1))
+
+ // write the geotiff image to file
+ try {
+ writer.write(gc, wps)
+ writer.dispose()
+ out.close()
+ } catch {
+ case e@(_: IllegalArgumentException | _: IOException) =>
+ // print the error and skip writing this image
+ e.printStackTrace()
+ }
+ }
+
+ override def close(): Unit = {
+ hfs.close()
+ }
+
+ def path(): String = {
+ savePath
+ }
}
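`GeotiffFileWriter` builds its `ReferencedEnvelope(x1, x2, y1, y2, crs)` from vertices 0 and 2 of the footprint polygon, which relies on the corner order produced by the reader. As an alternative design, an order-independent envelope can take the min/max over all vertices, as in this dependency-free sketch. `FootprintEnvelope` is a hypothetical name, and vertices are plain `(x, y)` tuples rather than JTS `Coordinate`s.

```scala
// Hypothetical, dependency-free sketch (not Sedona's code): compute an
// axis-aligned envelope from every vertex of a footprint polygon, so the
// result does not depend on which corner comes first.
object FootprintEnvelope {
  // Field order mirrors the (x1, x2, y1, y2) arguments of ReferencedEnvelope.
  final case class Envelope2D(minX: Double, maxX: Double, minY: Double, maxY: Double)

  /** Vertices are (x, y) pairs; a closed ring may repeat its first vertex. */
  def of(vertices: Seq[(Double, Double)]): Envelope2D = {
    require(vertices.nonEmpty, "a polygon needs at least one vertex")
    val xs = vertices.map(_._1)
    val ys = vertices.map(_._2)
    Envelope2D(xs.min, xs.max, ys.min, ys.max)
  }
}
```

Applied to the footprint asserted in the read tests, this yields the same envelope whether the ring is traversed forwards or backwards.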
diff --git a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/io/GeotiffSchema.scala b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/io/GeotiffSchema.scala
index 08234338..5a3a3595 100644
--- a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/io/GeotiffSchema.scala
+++ b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/io/GeotiffSchema.scala
@@ -26,9 +26,11 @@ import org.geotools.coverage.grid.{GridCoordinates2D, GridCoverage2D}
import org.geotools.gce.geotiff.GeoTiffReader
import org.geotools.geometry.jts.JTS
import org.geotools.referencing.CRS
-import org.locationtech.jts.geom.{Coordinate, GeometryFactory}
+import org.locationtech.jts.geom.{Coordinate, GeometryFactory, Polygon}
import org.opengis.coverage.grid.{GridCoordinates, GridEnvelope}
import org.opengis.parameter.{GeneralParameterValue, ParameterValue}
+import org.opengis.referencing.crs.CoordinateReferenceSystem
+import org.opengis.referencing.operation.MathTransform
import java.io.ByteArrayInputStream
@@ -40,7 +42,7 @@ object GeotiffSchema {
*/
val columnSchema = StructType(
StructField("origin", StringType, true) ::
- StructField("wkt", StringType, true) ::
+ StructField("geometry", StringType, true) ::
StructField("height", IntegerType, false) ::
StructField("width", IntegerType, false) ::
StructField("nBands", IntegerType, false) ::
@@ -117,7 +119,7 @@ object GeotiffSchema {
*
*/
- private[io] def decode(origin: String, bytes: Array[Byte]): Option[Row] = {
+ private[io] def decode(origin: String, bytes: Array[Byte], imageSourceOptions: ImageReadOptions): Option[Row] = {
val policy: ParameterValue[OverviewPolicy] = AbstractGridFormat.OVERVIEW_POLICY.createValue
policy.setValue(OverviewPolicy.IGNORE)
@@ -141,9 +143,31 @@ object GeotiffSchema {
}
// Fetch geometry from given image
- val source = coverage.getCoordinateReferenceSystem
- val target = CRS.decode("EPSG:4326", true)
- val targetCRS = CRS.findMathTransform(source, target)
+ var source: CoordinateReferenceSystem = try {
+ coverage.getCoordinateReferenceSystem
+ }
+ catch {
+ case _: Exception => null
+ }
+ if (source == null && imageSourceOptions.readFromCRS != "") {
+ source = CRS.decode(imageSourceOptions.readFromCRS, true)
+ }
+
+ val target: CoordinateReferenceSystem = if (imageSourceOptions.readToCRS != "") {
+ CRS.decode(imageSourceOptions.readToCRS, true)
+ } else {
+ null
+ }
+
+ var targetCRS: MathTransform = null
+ if (target != null) {
+ if (source == null) {
+ throw new IllegalArgumentException("Invalid arguments. Source coordinate reference system was not found.")
+ } else {
+ targetCRS = CRS.findMathTransform(source, target, imageSourceOptions.disableErrorInCRS)
+ }
+ }
+
val gridRange2D = coverage.getGridGeometry.getGridRange
val cords = Array(Array(gridRange2D.getLow(0), gridRange2D.getLow(1)), Array(gridRange2D.getLow(0), gridRange2D.getHigh(1)), Array(gridRange2D.getHigh(0), gridRange2D.getHigh(1)), Array(gridRange2D.getHigh(0), gridRange2D.getLow(1)))
val polyCoordinates = new Array[Coordinate](5)
@@ -160,7 +184,10 @@ object GeotiffSchema {
polyCoordinates(index) = polyCoordinates(0)
val factory = new GeometryFactory
- val polygon = JTS.transform(factory.createPolygon(polyCoordinates), targetCRS)
+ var polygon = factory.createPolygon(polyCoordinates)
+ if (targetCRS != null) {
+ polygon = JTS.transform(polygon, targetCRS).asInstanceOf[Polygon]
+ }
// Fetch band values from given image
val nBands: Int = coverage.getNumSampleDimensions
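The reader's CRS handling in `GeotiffSchema.decode` follows a few rules: the CRS stored in the file takes precedence, `readFromCRS` is only a fallback, no reprojection happens when `readToCRS` is empty, and a target without any source CRS is rejected. Below is a minimal sketch of those rules, assuming CRS codes are represented as plain strings instead of GeoTools objects; `CrsResolution` and its members are hypothetical.

```scala
// Hypothetical sketch (not Sedona's API) of the CRS rules in GeotiffSchema.decode,
// with CRS codes as plain strings instead of GeoTools CoordinateReferenceSystem objects.
object CrsResolution {
  sealed trait Reprojection
  case object KeepAsIs extends Reprojection
  final case class Transform(from: String, to: String) extends Reprojection

  /**
   * @param embedded    CRS read from the GeoTiff file, if any
   * @param readFromCRS value of the readFromCRS option ("" when unset)
   * @param readToCRS   value of the readToCRS option ("" when unset)
   */
  def resolve(embedded: Option[String], readFromCRS: String, readToCRS: String): Reprojection = {
    // The file's own CRS wins; readFromCRS is only used as a fallback.
    val source = embedded.orElse(Some(readFromCRS).filter(_.nonEmpty))
    Some(readToCRS).filter(_.nonEmpty) match {
      case None => KeepAsIs // no target CRS: the footprint is not reprojected
      case Some(target) =>
        source match {
          case Some(src) => Transform(src, target)
          case None =>
            throw new IllegalArgumentException(
              "Invalid arguments. Source coordinate reference system was not found.")
        }
    }
  }
}
```

This mirrors why the tests with `readFromCRS` alone produce the same footprint as the tests without it: the file already carries a CRS, so the option is ignored.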
diff --git a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/io/ImageOptions.scala b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/io/ImageReadOptions.scala
similarity index 63%
copy from sql/src/main/scala/org/apache/spark/sql/sedona_sql/io/ImageOptions.scala
copy to sql/src/main/scala/org/apache/spark/sql/sedona_sql/io/ImageReadOptions.scala
index 94d21830..f73fc7cf 100644
--- a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/io/ImageOptions.scala
+++ b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/io/ImageReadOptions.scala
@@ -20,11 +20,18 @@ package org.apache.spark.sql.sedona_sql.io
import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
-private[io] class ImageOptions(@transient private val parameters: CaseInsensitiveMap[String]) extends Serializable {
+private[io] class ImageReadOptions(@transient private val parameters: CaseInsensitiveMap[String]) extends Serializable {
def this(parameters: Map[String, String]) = this(CaseInsensitiveMap(parameters))
/**
- * Whether to drop invalid images. If true, invalid images will be removed, otherwise
+ * Optional parameters for reading GeoTiff
+ * dropInvalid indicates whether to drop invalid images. If true, invalid images will be removed, otherwise
 * invalid images will be returned with empty data and all other field filled with `-1`.
+ * disableErrorInCRS indicates whether to ignore errors in CRS transformation
+ * readFromCRS and readToCRS indicate source and target coordinate reference system, respectively.
*/
val dropInvalid = parameters.getOrElse("dropInvalid", "false").toBoolean
+ val disableErrorInCRS = parameters.getOrElse("disableErrorInCRS", "false").toBoolean
+ val readFromCRS = parameters.getOrElse("readFromCRS", "")
+ val readToCRS = parameters.getOrElse("readToCRS", "")
+
}
\ No newline at end of file
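`ImageReadOptions` wraps the options in Spark's `CaseInsensitiveMap`, so option keys match regardless of case (`readfromCRS` and `readFromCRS` are equivalent). The sketch below mimics that parsing by lower-casing the keys; `ReadOptionsSketch` is hypothetical, and its defaults are copied from the class above.

```scala
// Hypothetical sketch (not Sedona's API) of how the GeoTiff read options are parsed.
// Keys are lower-cased so lookups ignore case, similar in spirit to Spark's CaseInsensitiveMap.
object ReadOptionsSketch {
  final case class GeoTiffReadOptions(
      dropInvalid: Boolean,
      disableErrorInCRS: Boolean,
      readFromCRS: String,
      readToCRS: String)

  def parse(options: Map[String, String]): GeoTiffReadOptions = {
    val lowered = options.map { case (k, v) => k.toLowerCase -> v }
    def get(key: String, default: String): String = lowered.getOrElse(key.toLowerCase, default)
    GeoTiffReadOptions(
      dropInvalid = get("dropInvalid", "false").toBoolean,
      disableErrorInCRS = get("disableErrorInCRS", "false").toBoolean,
      readFromCRS = get("readFromCRS", ""), // "" means: not set
      readToCRS = get("readToCRS", ""))     // "" means: no reprojection
  }
}
```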
diff --git a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/io/ImageOptions.scala b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/io/ImageWriteOptions.scala
similarity index 59%
rename from sql/src/main/scala/org/apache/spark/sql/sedona_sql/io/ImageOptions.scala
rename to sql/src/main/scala/org/apache/spark/sql/sedona_sql/io/ImageWriteOptions.scala
index 94d21830..8653c93a 100644
--- a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/io/ImageOptions.scala
+++ b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/io/ImageWriteOptions.scala
@@ -20,11 +20,17 @@ package org.apache.spark.sql.sedona_sql.io
import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
-private[io] class ImageOptions(@transient private val parameters: CaseInsensitiveMap[String]) extends Serializable {
+private[io] class ImageWriteOptions(@transient private val parameters: CaseInsensitiveMap[String]) extends Serializable {
def this(parameters: Map[String, String]) = this(CaseInsensitiveMap(parameters))
- /**
- * Whether to drop invalid images. If true, invalid images will be removed, otherwise
- * invalid images will be returned with empty data and all other field filled with `-1`.
- */
- val dropInvalid = parameters.getOrElse("dropInvalid", "false").toBoolean
+
+ // Optional parameters for writing GeoTiff
+ val writeToCRS = parameters.getOrElse("writeToCRS", "EPSG:4326")
+ val colImage = parameters.getOrElse("fieldImage", "image")
+ val colOrigin = parameters.getOrElse("fieldOrigin", "origin")
+ val colBands = parameters.getOrElse("fieldNBands", "nBands")
+ val colWidth = parameters.getOrElse("fieldWidth", "width")
+ val colHeight = parameters.getOrElse("fieldHeight", "height")
+ val colGeometry = parameters.getOrElse("fieldGeometry", "geometry")
+ val colData = parameters.getOrElse("fieldData", "data")
+
}
\ No newline at end of file
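`isValidGeoTiffSchema` only checks that six fields are present. The column names configured through the `field*` options are looked up later, in `GeotiffFileWriter.write`, with `StructType.fieldIndex`, which throws for a missing name. A hypothetical stricter pre-check could report missing columns up front, as sketched here on plain field-name lists. `WriteSchemaCheck` is not part of Sedona, and unlike the real options it matches option keys case-sensitively.

```scala
// Hypothetical stricter pre-check (not part of Sedona): report which configured
// columns are missing instead of only counting fields. Option keys are matched
// case-sensitively here, unlike the real CaseInsensitiveMap-backed options.
object WriteSchemaCheck {
  // Default column names, copied from the write options in this commit.
  val defaults: Seq[(String, String)] = Seq(
    "fieldOrigin" -> "origin",
    "fieldNBands" -> "nBands",
    "fieldWidth" -> "width",
    "fieldHeight" -> "height",
    "fieldGeometry" -> "geometry",
    "fieldData" -> "data")

  /** Configured column names (in the order above) that are absent from `fieldNames`. */
  def missingColumns(fieldNames: Seq[String], options: Map[String, String]): Seq[String] =
    defaults
      .map { case (option, default) => options.getOrElse(option, default) }
      .filterNot(name => fieldNames.contains(name))
}
```

With the renamed columns from the documentation example (`source`, `geom`, `bands`), the check passes only when the matching `fieldOrigin`, `fieldGeometry` and `fieldNBands` options are supplied.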
diff --git a/sql/src/test/scala/org/apache/sedona/sql/rasterIOTest.scala b/sql/src/test/scala/org/apache/sedona/sql/rasterIOTest.scala
index 406f7086..418daea4 100644
--- a/sql/src/test/scala/org/apache/sedona/sql/rasterIOTest.scala
+++ b/sql/src/test/scala/org/apache/sedona/sql/rasterIOTest.scala
@@ -22,6 +22,7 @@ package org.apache.sedona.sql
import org.locationtech.jts.geom.Geometry
import org.scalatest.{BeforeAndAfter, GivenWhenThen}
+import java.io.File
import scala.collection.mutable
class rasterIOTest extends TestBaseScala with BeforeAndAfter with GivenWhenThen {
@@ -29,9 +30,72 @@ class rasterIOTest extends TestBaseScala with BeforeAndAfter with GivenWhenThen
var rasterdatalocation: String = resourceFolder + "raster/"
describe("Raster IO test") {
- it("Should Pass geotiff loading") {
+ it("Should Pass geotiff loading without readFromCRS and readToCRS") {
var df = sparkSession.read.format("geotiff").option("dropInvalid", true).load(rasterdatalocation)
- df = df.selectExpr("image.origin as origin","ST_GeomFromWkt(image.wkt) as Geom", "image.height as height", "image.width as width", "image.data as data", "image.nBands as bands")
+ df = df.selectExpr("image.origin as origin","ST_GeomFromWkt(image.geometry) as Geom", "image.height as height", "image.width as width", "image.data as data", "image.nBands as bands")
+ assert(df.first().getAs[Geometry](1).toText == "POLYGON ((-13095782 4021226.5, -13095782 3983905, -13058822 3983905, -13058822 4021226.5, -13095782 4021226.5))")
+ assert(df.first().getInt(2) == 517)
+ assert(df.first().getInt(3) == 512)
+ assert(df.first().getInt(5) == 1)
+ val blackBand = df.first().getAs[mutable.WrappedArray[Double]](4)
+ val line1 = blackBand.slice(0, 512)
+ val line2 = blackBand.slice(512, 1024)
+ assert(line1(0) == 0.0) // The first value at line 1 is black
+ assert(line2(159) == 0.0 && line2(160) == 123.0) // In the second line, value at 159 is black and at 160 is not black
+ }
+
+ it("Should Pass geotiff loading with readToCRS") {
+ var df = sparkSession.read.format("geotiff").option("dropInvalid", true).option("readToCRS", "EPSG:4326").load(rasterdatalocation)
+ df = df.selectExpr("image.origin as origin","ST_GeomFromWkt(image.geometry) as Geom", "image.height as height", "image.width as width", "image.data as data", "image.nBands as bands")
+ assert(df.first().getAs[Geometry](1).toText == "POLYGON ((-117.64141128097314 33.94356351407699, -117.64141128097314 33.664978146501284, -117.30939395196258 33.664978146501284," +
+ " -117.30939395196258 33.94356351407699, -117.64141128097314 33.94356351407699))")
+ assert(df.first().getInt(2) == 517)
+ assert(df.first().getInt(3) == 512)
+ assert(df.first().getInt(5) == 1)
+ val blackBand = df.first().getAs[mutable.WrappedArray[Double]](4)
+ val line1 = blackBand.slice(0, 512)
+ val line2 = blackBand.slice(512, 1024)
+ assert(line1(0) == 0.0) // The first value at line 1 is black
+ assert(line2(159) == 0.0 && line2(160) == 123.0) // In the second line, value at 159 is black and at 160 is not black
+ }
+
+ it("Should Pass geotiff loading with readFromCRS") {
+ var df = sparkSession.read.format("geotiff").option("dropInvalid", true).option("readFromCRS", "EPSG:4499").load(rasterdatalocation)
+ df = df.selectExpr("image.origin as origin","ST_GeomFromWkt(image.geometry) as Geom", "image.height as height", "image.width as width", "image.data as data", "image.nBands as bands")
+ assert(df.first().getAs[Geometry](1).toText == "POLYGON ((-13095782 4021226.5, -13095782 3983905, -13058822 3983905, -13058822 4021226.5, -13095782 4021226.5))")
+ assert(df.first().getInt(2) == 517)
+ assert(df.first().getInt(3) == 512)
+ assert(df.first().getInt(5) == 1)
+ val blackBand = df.first().getAs[mutable.WrappedArray[Double]](4)
+ val line1 = blackBand.slice(0, 512)
+ val line2 = blackBand.slice(512, 1024)
+ assert(line1(0) == 0.0) // The first value at line 1 is black
+ assert(line2(159) == 0.0 && line2(160) == 123.0) // In the second line, value at 159 is black and at 160 is not black
+ }
+
+ it("Should Pass geotiff loading with readFromCRS and readToCRS") {
+ var df = sparkSession.read.format("geotiff").option("dropInvalid", true).option("readFromCRS", "EPSG:4499").option("readToCRS", "EPSG:4326").load(rasterdatalocation)
+ df = df.selectExpr("image.origin as origin","ST_GeomFromWkt(image.geometry) as Geom", "image.height as height", "image.width as width", "image.data as data", "image.nBands as bands")
+ assert(df.first().getAs[Geometry](1).toText == "POLYGON ((-117.64141128097314 33.94356351407699, -117.64141128097314 33.664978146501284, -117.30939395196258 33.664978146501284," +
+ " -117.30939395196258 33.94356351407699, -117.64141128097314 33.94356351407699))")
+ assert(df.first().getInt(2) == 517)
+ assert(df.first().getInt(3) == 512)
+ assert(df.first().getInt(5) == 1)
+ val blackBand = df.first().getAs[mutable.WrappedArray[Double]](4)
+ val line1 = blackBand.slice(0, 512)
+ val line2 = blackBand.slice(512, 1024)
+ assert(line1(0) == 0.0) // The first value at line 1 is black
+ assert(line2(159) == 0.0 && line2(160) == 123.0) // In the second line, value at 159 is black and at 160 is not black
+ }
+
+ it("Should Pass geotiff loading with all read options") {
+ var df = sparkSession.read.format("geotiff")
+ .option("dropInvalid", true)
+ .option("readFromCRS", "EPSG:4499")
+ .option("readToCRS", "EPSG:4326")
+ .option("disableErrorInCRS", true)
+ .load(rasterdatalocation)
+ df = df.selectExpr("image.origin as origin","ST_GeomFromWkt(image.geometry) as Geom", "image.height as height", "image.width as width", "image.data as data", "image.nBands as bands")
assert(df.first().getAs[Geometry](1).toText == "POLYGON ((-117.64141128097314 33.94356351407699, -117.64141128097314 33.664978146501284, -117.30939395196258 33.664978146501284," +
" -117.30939395196258 33.94356351407699, -117.64141128097314 33.94356351407699))")
assert(df.first().getInt(2) == 517)
@@ -53,7 +117,7 @@ class rasterIOTest extends TestBaseScala with BeforeAndAfter with GivenWhenThen
it("should pass RS_Base64") {
var df = sparkSession.read.format("geotiff").option("dropInvalid", true).load(resourceFolder + "raster/")
- df = df.selectExpr("image.origin as origin","ST_GeomFromWkt(image.wkt) as Geom", "image.height as height", "image.width as width", "image.data as data", "image.nBands as bands")
+ df = df.selectExpr("image.origin as origin","ST_GeomFromWkt(image.geometry) as Geom", "image.height as height", "image.width as width", "image.data as data", "image.nBands as bands")
df = df.selectExpr("RS_GetBand(data, 1, bands) as targetBand", "width","height")
df.createOrReplaceTempView("geotiff")
df = sparkSession.sql("Select RS_base64(height, width, targetBand, RS_Array(height*width, 0), RS_Array(height*width, 0)) as encodedstring from geotiff")
@@ -62,7 +126,7 @@ class rasterIOTest extends TestBaseScala with BeforeAndAfter with GivenWhenThen
it("should pass RS_HTML") {
var df = sparkSession.read.format("geotiff").option("dropInvalid", true).load(resourceFolder + "raster/")
- df = df.selectExpr("image.origin as origin","ST_GeomFromWkt(image.wkt) as Geom", "image.height as height", "image.width as width", "image.data as data", "image.nBands as bands")
+ df = df.selectExpr("image.origin as origin","ST_GeomFromWkt(image.geometry) as Geom", "image.height as height", "image.width as width", "image.data as data", "image.nBands as bands")
df = df.selectExpr("RS_GetBand(data, 1, bands) as targetBand", "width","height")
df.createOrReplaceTempView("geotiff")
df = sparkSession.sql("Select RS_base64(height, width, targetBand, RS_Array(height*width, 0.0), RS_Array(height*width, 0.0)) as encodedstring from geotiff")
@@ -91,6 +155,199 @@ class rasterIOTest extends TestBaseScala with BeforeAndAfter with GivenWhenThen
df = df.selectExpr("RS_GetBand(data, 4, bands) as targetBand")
assert(df.first().getAs[mutable.WrappedArray[Double]](0)(2) == 0.0)
}
+
+ it("Should Pass geotiff file writing with coalesce") {
+ var df = sparkSession.read.format("geotiff").option("dropInvalid", true).option("readToCRS", "EPSG:4326").load(rasterdatalocation)
+ df = df.selectExpr("image.origin as origin","image.geometry as geometry", "image.height as height", "image.width as width", "image.data as data", "image.nBands as nBands")
+ val savePath = resourceFolder + "raster-written/"
+ df.coalesce(1).write.mode("overwrite").format("geotiff").save(savePath)
+
+ var loadPath = savePath
+ val tempFile = new File(loadPath)
+ val fileList = tempFile.listFiles()
+ for (i <- 0 until fileList.length) {
+ if (fileList(i).isDirectory) loadPath = fileList(i).getAbsolutePath
+ }
+
+ var dfWritten = sparkSession.read.format("geotiff").option("dropInvalid", true).load(loadPath)
+ dfWritten = dfWritten.selectExpr("image.origin as origin","ST_GeomFromWkt(image.geometry) as Geom", "image.height as height", "image.width as width", "image.data as data", "image.nBands as bands")
+ val rowFirst = dfWritten.first()
+ assert(rowFirst.getInt(2) == 517)
+ assert(rowFirst.getInt(3) == 512)
+ assert(rowFirst.getInt(5) == 1)
+
+ val blackBand = rowFirst.getAs[mutable.WrappedArray[Double]](4)
+ val line1 = blackBand.slice(0, 512)
+ val line2 = blackBand.slice(512, 1024)
+ assert(line1(0) == 0.0) // The first value at line 1 is black
+ assert(line2(159) == 0.0 && line2(160) == 123.0) // In the second line, value at 159 is black and at 160 is not black
+ }
+
+ it("Should Pass geotiff file writing with writeToCRS") {
+ var df = sparkSession.read.format("geotiff").option("dropInvalid", true).load(rasterdatalocation)
+ df = df.selectExpr("image.origin as origin","image.geometry as geometry", "image.height as height", "image.width as width", "image.data as data", "image.nBands as nBands")
+ val savePath = resourceFolder + "raster-written/"
+ df.coalesce(1).write.mode("overwrite").format("geotiff").option("writeToCRS", "EPSG:4499").save(savePath)
+
+ var loadPath = savePath
+ val tempFile = new File(loadPath)
+ val fileList = tempFile.listFiles()
+ for (i <- 0 until fileList.length) {
+ if (fileList(i).isDirectory) loadPath = fileList(i).getAbsolutePath
+ }
+
+ var dfWritten = sparkSession.read.format("geotiff").option("dropInvalid", true).load(loadPath)
+ dfWritten = dfWritten.selectExpr("image.origin as origin","ST_GeomFromWkt(image.geometry) as Geom", "image.height as height", "image.width as width", "image.data as data", "image.nBands as bands")
+ val rowFirst = dfWritten.first()
+ assert(rowFirst.getInt(2) == 517)
+ assert(rowFirst.getInt(3) == 512)
+ assert(rowFirst.getInt(5) == 1)
+
+ val blackBand = rowFirst.getAs[mutable.WrappedArray[Double]](4)
+ val line1 = blackBand.slice(0, 512)
+ val line2 = blackBand.slice(512, 1024)
+ assert(line1(0) == 0.0) // The first value at line 1 is black
+ assert(line2(159) == 0.0 && line2(160) == 123.0) // In the second line, value at 159 is black and at 160 is not black
+ }
+
+ it("Should Pass geotiff file writing without coalesce") {
+ var df = sparkSession.read.format("geotiff").option("dropInvalid", true).load(rasterdatalocation)
+ df = df.selectExpr("image.origin as origin","image.geometry as geometry", "image.height as height", "image.width as width", "image.data as data", "image.nBands as nBands")
+ val savePath = resourceFolder + "raster-written/"
+ df.write.mode("overwrite").format("geotiff").save(savePath)
+
+ var imageCount = 0
+    def getFile(loadPath: String): Unit = {
+ val tempFile = new File(loadPath)
+ val fileList = tempFile.listFiles()
+ if (fileList == null) return
+ for (i <- 0 until fileList.length) {
+ if (fileList(i).isDirectory) getFile(fileList(i).getAbsolutePath)
+ else if (fileList(i).getAbsolutePath.endsWith(".tiff") || fileList(i).getAbsolutePath.endsWith(".tif")) imageCount += 1
+ }
+ }
+
+ getFile(savePath)
+ assert(imageCount == 3)
+ }
+
+ it("Should Pass geotiff file writing with nested schema") {
+ val df = sparkSession.read.format("geotiff").option("dropInvalid", true).load(rasterdatalocation)
+ val savePath = resourceFolder + "raster-written/"
+ df.write.mode("overwrite").format("geotiff").save(savePath)
+
+ var imageCount = 0
+    def getFile(loadPath: String): Unit = {
+ val tempFile = new File(loadPath)
+ val fileList = tempFile.listFiles()
+ if (fileList == null) return
+ for (i <- 0 until fileList.length) {
+ if (fileList(i).isDirectory) getFile(fileList(i).getAbsolutePath)
+ else if (fileList(i).getAbsolutePath.endsWith(".tiff") || fileList(i).getAbsolutePath.endsWith(".tif")) imageCount += 1
+ }
+ }
+
+ getFile(savePath)
+ assert(imageCount == 3)
+ }
+
+ it("Should Pass geotiff file writing with renamed fields") {
+ var df = sparkSession.read.format("geotiff").option("dropInvalid", true).load(rasterdatalocation)
+ df = df.selectExpr("image.origin as source","image.geometry as geom", "image.height as height", "image.width as width", "image.data as data", "image.nBands as bands")
+ val savePath = resourceFolder + "raster-written/"
+ df.write
+ .mode("overwrite")
+ .format("geotiff")
+ .option("fieldOrigin", "source")
+ .option("fieldGeometry", "geom")
+ .option("fieldNBands", "bands")
+ .save(savePath)
+
+ var imageCount = 0
+    def getFile(loadPath: String): Unit = {
+ val tempFile = new File(loadPath)
+ val fileList = tempFile.listFiles()
+ if (fileList == null) return
+ for (i <- 0 until fileList.length) {
+ if (fileList(i).isDirectory) getFile(fileList(i).getAbsolutePath)
+ else if (fileList(i).getAbsolutePath.endsWith(".tiff") || fileList(i).getAbsolutePath.endsWith(".tif")) imageCount += 1
+ }
+ }
+
+ getFile(savePath)
+ assert(imageCount == 3)
+ }
+
+ it("Should Pass geotiff file writing with nested schema and renamed fields") {
+ var df = sparkSession.read.format("geotiff").option("dropInvalid", true).load(rasterdatalocation)
+ df = df.selectExpr("image as tiff_image")
+ val savePath = resourceFolder + "raster-written/"
+ df.write
+ .mode("overwrite")
+ .format("geotiff")
+ .option("fieldImage", "tiff_image")
+ .save(savePath)
+
+ var imageCount = 0
+    def getFile(loadPath: String): Unit = {
+ val tempFile = new File(loadPath)
+ val fileList = tempFile.listFiles()
+ if (fileList == null) return
+ for (i <- 0 until fileList.length) {
+ if (fileList(i).isDirectory) getFile(fileList(i).getAbsolutePath)
+ else if (fileList(i).getAbsolutePath.endsWith(".tiff") || fileList(i).getAbsolutePath.endsWith(".tif")) imageCount += 1
+ }
+ }
+
+ getFile(savePath)
+ assert(imageCount == 3)
+ }
+
+ it("Should Pass geotiff file writing with converted geometry") {
+ var df = sparkSession.read.format("geotiff").option("dropInvalid", true).load(rasterdatalocation)
+ df = df.selectExpr("image.origin as source","ST_GeomFromWkt(image.geometry) as geom", "image.height as height", "image.width as width", "image.data as data", "image.nBands as bands")
+ val savePath = resourceFolder + "raster-written/"
+ df.write
+ .mode("overwrite")
+ .format("geotiff")
+ .option("fieldOrigin", "source")
+ .option("fieldGeometry", "geom")
+ .option("fieldNBands", "bands")
+ .save(savePath)
+
+ var imageCount = 0
+    def getFile(loadPath: String): Unit = {
+ val tempFile = new File(loadPath)
+ val fileList = tempFile.listFiles()
+ if (fileList == null) return
+ for (i <- 0 until fileList.length) {
+ if (fileList(i).isDirectory) getFile(fileList(i).getAbsolutePath)
+ else if (fileList(i).getAbsolutePath.endsWith(".tiff") || fileList(i).getAbsolutePath.endsWith(".tif")) imageCount += 1
+ }
+ }
+
+ getFile(savePath)
+ assert(imageCount == 3)
+ }
+
+ it("Should Pass geotiff file writing with handling invalid schema") {
+ var df = sparkSession.read.format("geotiff").option("dropInvalid", true).load(rasterdatalocation)
+ df = df.selectExpr("image.origin as origin","image.geometry as geometry", "image.height as height", "image.width as width", "image.data as data")
+ val savePath = resourceFolder + "raster-written/"
+
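+    // intercept fails the test if no IllegalArgumentException is thrown,
+    // unlike a bare try/catch, which would pass silently.
+    val exception = intercept[IllegalArgumentException] {
+      df.write
+        .mode("overwrite")
+        .format("geotiff")
+        .option("fieldImage", "tiff_image")
+        .save(savePath)
+    }
+    assert(exception.getMessage == "Invalid GeoTiff Schema")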
+ }
}
}