You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@sedona.apache.org by ji...@apache.org on 2024/03/22 03:05:20 UTC
(sedona) branch master updated: [SEDONA-406] Raster deserializer for PySpark (#1281)
This is an automated email from the ASF dual-hosted git repository.
jiayu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/sedona.git
The following commit(s) were added to refs/heads/master by this push:
new 63a1de0b2 [SEDONA-406] Raster deserializer for PySpark (#1281)
63a1de0b2 is described below
commit 63a1de0b209f2f8dccca4e1c7f7ff25f6044f081
Author: Kristin Cowalcijk <bo...@wherobots.com>
AuthorDate: Fri Mar 22 11:05:15 2024 +0800
[SEDONA-406] Raster deserializer for PySpark (#1281)
* [SEDONA-406] Raster deserializer for PySpark (#116)
* Update documentation
* Add documentation for writing Python UDF to work with raster data
---
.github/workflows/python.yml | 17 +-
.../common/raster/DeepCopiedRenderedImage.java | 218 ++++++++++-----
.../raster/RasterConstructorsForTesting.java | 199 ++++++++++++++
.../org/apache/sedona/common/raster/Serde.java | 103 -------
.../common/raster/serde/AWTRasterSerializer.java | 61 +++++
.../raster/serde/AffineTransform2DSerializer.java | 47 ++++
.../common/raster/{ => serde}/CRSSerializer.java | 47 +++-
.../common/raster/serde/DataBufferSerializer.java | 106 ++++++++
.../raster/serde/GridEnvelopeSerializer.java | 39 +++
.../serde/GridSampleDimensionSerializer.java | 54 ++++
.../sedona/common/raster/serde/KryoUtil.java | 297 +++++++++++++++++++++
.../common/raster/serde/SampleModelSerializer.java | 174 ++++++++++++
.../apache/sedona/common/raster/serde/Serde.java | 179 +++++++++++++
.../common/raster/RasterBandEditorsTest.java | 5 +-
.../raster/RasterConstructorsForTestingTest.java | 111 ++++++++
.../sedona/common/raster/RasterTestBase.java | 8 +-
.../raster/{ => serde}/CRSSerializerTest.java | 2 +-
.../raster/serde/DataBufferSerializerTest.java | 153 +++++++++++
.../raster/serde/KryoSerializerTestBase.java | 34 +++
.../sedona/common/raster/serde/KryoUtilTest.java | 234 ++++++++++++++++
.../raster/serde/SampleModelSerializerTest.java | 112 ++++++++
.../common/raster/{ => serde}/SerdeTest.java | 23 +-
docs/setup/compile.md | 19 +-
docs/tutorial/raster.md | 86 ++++++
python/Pipfile | 1 +
python/sedona/raster/__init__.py | 16 ++
python/sedona/raster/awt_raster.py | 41 +++
python/sedona/raster/data_buffer.py | 39 +++
python/sedona/raster/meta.py | 112 ++++++++
python/sedona/raster/raster_serde.py | 180 +++++++++++++
python/sedona/raster/sample_model.py | 193 +++++++++++++
python/sedona/raster/sedona_raster.py | 261 ++++++++++++++++++
python/sedona/sql/types.py | 7 +-
python/setup.py | 2 +-
python/tests/raster/__init__.py | 16 ++
python/tests/raster/test_meta.py | 59 ++++
python/tests/raster/test_pandas_udf.py | 76 ++++++
python/tests/raster/test_serde.py | 121 +++++++++
spark-shaded/pom.xml | 98 +++++++
.../scala/org/apache/sedona/sql/UDF/Catalog.scala | 1 +
.../apache/sedona/sql/utils/RasterSerializer.scala | 2 +-
.../spark/sql/sedona_sql/UDT/RasterUDT.scala | 2 +-
.../expressions/raster/RasterConstructors.scala | 10 +-
.../sedona_sql/expressions/raster/implicits.scala | 2 +-
.../org/apache/sedona/sql/rasteralgebraTest.scala | 9 +-
.../org/apache/sedona/sql/serdeAwareTest.scala | 2 +-
46 files changed, 3374 insertions(+), 204 deletions(-)
diff --git a/.github/workflows/python.yml b/.github/workflows/python.yml
index b9cbb6337..a4efd768b 100644
--- a/.github/workflows/python.yml
+++ b/.github/workflows/python.yml
@@ -22,6 +22,9 @@ on:
env:
MAVEN_OPTS: -Dmaven.wagon.httpconnectionManager.ttlSeconds=60
+ JAI_CORE_VERSION: "1.1.3"
+ JAI_CODEC_VERSION: "1.1.3"
+ JAI_IMAGEIO_VERSION: "1.1"
jobs:
build:
@@ -111,11 +114,15 @@ jobs:
- env:
SPARK_VERSION: ${{ matrix.spark }}
HADOOP_VERSION: ${{ matrix.hadoop }}
- run: wget https://archive.apache.org/dist/spark/spark-${SPARK_VERSION}/spark-${SPARK_VERSION}-bin-hadoop${HADOOP_VERSION}.tgz
- - env:
- SPARK_VERSION: ${{ matrix.spark }}
- HADOOP_VERSION: ${{ matrix.hadoop }}
- run: tar -xzf spark-${SPARK_VERSION}-bin-hadoop${HADOOP_VERSION}.tgz
+ run: |
+ wget https://archive.apache.org/dist/spark/spark-${SPARK_VERSION}/spark-${SPARK_VERSION}-bin-hadoop${HADOOP_VERSION}.tgz
+ wget https://repo.osgeo.org/repository/release/javax/media/jai_core/${JAI_CORE_VERSION}/jai_core-${JAI_CORE_VERSION}.jar
+ wget https://repo.osgeo.org/repository/release/javax/media/jai_codec/${JAI_CODEC_VERSION}/jai_codec-${JAI_CODEC_VERSION}.jar
+ wget https://repo.osgeo.org/repository/release/javax/media/jai_imageio/${JAI_IMAGEIO_VERSION}/jai_imageio-${JAI_IMAGEIO_VERSION}.jar
+ tar -xzf spark-${SPARK_VERSION}-bin-hadoop${HADOOP_VERSION}.tgz
+ mv -v jai_core-${JAI_CORE_VERSION}.jar spark-${SPARK_VERSION}-bin-hadoop${HADOOP_VERSION}/jars/
+ mv -v jai_codec-${JAI_CODEC_VERSION}.jar spark-${SPARK_VERSION}-bin-hadoop${HADOOP_VERSION}/jars/
+ mv -v jai_imageio-${JAI_IMAGEIO_VERSION}.jar spark-${SPARK_VERSION}-bin-hadoop${HADOOP_VERSION}/jars/
- run: sudo apt-get -y install python3-pip python-dev-is-python3
- run: sudo pip3 install -U setuptools
- run: sudo pip3 install -U wheel
diff --git a/common/src/main/java/org/apache/sedona/common/raster/DeepCopiedRenderedImage.java b/common/src/main/java/org/apache/sedona/common/raster/DeepCopiedRenderedImage.java
index 971b168ab..0df154076 100644
--- a/common/src/main/java/org/apache/sedona/common/raster/DeepCopiedRenderedImage.java
+++ b/common/src/main/java/org/apache/sedona/common/raster/DeepCopiedRenderedImage.java
@@ -13,8 +13,16 @@
*/
package org.apache.sedona.common.raster;
+import com.esotericsoftware.kryo.Kryo;
+import com.esotericsoftware.kryo.KryoSerializable;
+import com.esotericsoftware.kryo.io.Input;
+import com.esotericsoftware.kryo.io.Output;
+import com.esotericsoftware.kryo.serializers.JavaSerializer;
+import com.sun.media.jai.rmi.ColorModelState;
import com.sun.media.jai.util.ImageUtil;
import it.geosolutions.jaiext.range.NoDataContainer;
+import org.apache.sedona.common.raster.serde.AWTRasterSerializer;
+import org.apache.sedona.common.raster.serde.KryoUtil;
import org.apache.sedona.common.utils.RasterUtils;
import javax.media.jai.JAI;
@@ -48,7 +56,7 @@ import java.util.Vector;
* object is being disposed, it tries to connect to the remote server. However, there is no remote server in deep-copy
* mode, so the dispose() method throws a java.net.SocketException.
*/
-public final class DeepCopiedRenderedImage implements RenderedImage, Serializable {
+public final class DeepCopiedRenderedImage implements RenderedImage, Serializable, KryoSerializable {
private transient RenderedImage source;
private int minX;
private int minY;
@@ -69,7 +77,7 @@ public final class DeepCopiedRenderedImage implements RenderedImage, Serializabl
private Rectangle imageBounds;
private transient Raster imageRaster;
- DeepCopiedRenderedImage() {
+ public DeepCopiedRenderedImage() {
this.sampleModel = null;
this.colorModel = null;
this.sources = null;
@@ -87,57 +95,54 @@ public final class DeepCopiedRenderedImage implements RenderedImage, Serializabl
this.properties = null;
if (source == null) {
throw new IllegalArgumentException("source cannot be null");
- } else {
- SampleModel sm = source.getSampleModel();
- if (sm != null && SerializerFactory.getSerializer(sm.getClass()) == null) {
- throw new IllegalArgumentException("sample model object is not serializable");
- } else {
- ColorModel cm = source.getColorModel();
- if (cm != null && SerializerFactory.getSerializer(cm.getClass()) == null) {
- throw new IllegalArgumentException("color model object is not serializable");
- } else {
- if (checkDataBuffer) {
- Raster ras = source.getTile(source.getMinTileX(), source.getMinTileY());
- if (ras != null) {
- DataBuffer db = ras.getDataBuffer();
- if (db != null && SerializerFactory.getSerializer(db.getClass()) == null) {
- throw new IllegalArgumentException("data buffer object is not serializable");
- }
- }
- }
-
- this.source = source;
- if (source instanceof RemoteImage) {
- throw new IllegalArgumentException("RemoteImage is not supported");
- }
- this.minX = source.getMinX();
- this.minY = source.getMinY();
- this.width = source.getWidth();
- this.height = source.getHeight();
- this.minTileX = source.getMinTileX();
- this.minTileY = source.getMinTileY();
- this.numXTiles = source.getNumXTiles();
- this.numYTiles = source.getNumYTiles();
- this.tileWidth = source.getTileWidth();
- this.tileHeight = source.getTileHeight();
- this.tileGridXOffset = source.getTileGridXOffset();
- this.tileGridYOffset = source.getTileGridYOffset();
- this.sampleModel = source.getSampleModel();
- this.colorModel = source.getColorModel();
- this.sources = new Vector<>();
- this.sources.add(source);
- this.properties = new Hashtable<>();
- String[] propertyNames = source.getPropertyNames();
- if (propertyNames != null) {
- for (String propertyName : propertyNames) {
- this.properties.put(propertyName, source.getProperty(propertyName));
- }
- }
-
- this.imageBounds = new Rectangle(this.minX, this.minY, this.width, this.height);
+ }
+ SampleModel sm = source.getSampleModel();
+ if (sm != null && SerializerFactory.getSerializer(sm.getClass()) == null) {
+ throw new IllegalArgumentException("sample model object is not serializable");
+ }
+ ColorModel cm = source.getColorModel();
+ if (cm != null && SerializerFactory.getSerializer(cm.getClass()) == null) {
+ throw new IllegalArgumentException("color model object is not serializable");
+ }
+ if (checkDataBuffer) {
+ Raster ras = source.getTile(source.getMinTileX(), source.getMinTileY());
+ if (ras != null) {
+ DataBuffer db = ras.getDataBuffer();
+ if (db != null && SerializerFactory.getSerializer(db.getClass()) == null) {
+ throw new IllegalArgumentException("data buffer object is not serializable");
}
}
}
+
+ this.source = source;
+ if (source instanceof RemoteImage) {
+ throw new IllegalArgumentException("RemoteImage is not supported");
+ }
+ this.minX = source.getMinX();
+ this.minY = source.getMinY();
+ this.width = source.getWidth();
+ this.height = source.getHeight();
+ this.minTileX = source.getMinTileX();
+ this.minTileY = source.getMinTileY();
+ this.numXTiles = source.getNumXTiles();
+ this.numYTiles = source.getNumYTiles();
+ this.tileWidth = source.getTileWidth();
+ this.tileHeight = source.getTileHeight();
+ this.tileGridXOffset = source.getTileGridXOffset();
+ this.tileGridYOffset = source.getTileGridYOffset();
+ this.sampleModel = source.getSampleModel();
+ this.colorModel = source.getColorModel();
+ this.sources = new Vector<>();
+ this.sources.add(source);
+ this.properties = new Hashtable<>();
+ String[] propertyNames = source.getPropertyNames();
+ if (propertyNames != null) {
+ for (String propertyName : propertyNames) {
+ this.properties.put(propertyName, source.getProperty(propertyName));
+ }
+ }
+
+ this.imageBounds = new Rectangle(this.minX, this.minY, this.width, this.height);
}
@Override
@@ -325,10 +330,54 @@ public final class DeepCopiedRenderedImage implements RenderedImage, Serializabl
return this.width;
}
- @SuppressWarnings("unchecked")
private void writeObject(ObjectOutputStream out) throws IOException {
out.defaultWriteObject();
+ Hashtable<String, Object> propertyTable = getSerializableProperties();
+ out.writeObject(SerializerFactory.getState(this.colorModel, null));
+ out.writeObject(propertyTable);
+ if (this.source != null) {
+ Raster serializedRaster = RasterUtils.getRaster(this.source);
+ out.writeObject(SerializerFactory.getState(serializedRaster, null));
+ } else {
+ out.writeObject(SerializerFactory.getState(imageRaster, null));
+ }
+ }
+
+ @SuppressWarnings("unchecked")
+ private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
+ this.source = null;
+ in.defaultReadObject();
+
+ SerializableState cmState = (SerializableState)in.readObject();
+ this.colorModel = (ColorModel)cmState.getObject();
+ this.properties = (Hashtable<String, Object>)in.readObject();
+ for (String key : this.properties.keySet()) {
+ Object value = this.properties.get(key);
+ // Restore the value of GC_NODATA property as a NoDataContainer object.
+ if (value instanceof SingleValueNoDataContainer) {
+ SingleValueNoDataContainer noDataContainer = (SingleValueNoDataContainer) value;
+ this.properties.put(key, new NoDataContainer(noDataContainer.singleValue));
+ }
+ }
+ SerializableState rasState = (SerializableState)in.readObject();
+ this.imageRaster = (Raster)rasState.getObject();
+
+ // The deserialized rendered image contains only one tile (imageRaster). We need to update
+ // the sample model and tile properties to reflect this.
+ this.sampleModel = this.imageRaster.getSampleModel();
+ this.tileWidth = this.width;
+ this.tileHeight = this.height;
+ this.numXTiles = 1;
+ this.numYTiles = 1;
+ this.minTileX = 0;
+ this.minTileY = 0;
+ this.tileGridXOffset = minX;
+ this.tileGridYOffset = minY;
+ }
+
+ @SuppressWarnings("unchecked")
+ private Hashtable<String, Object> getSerializableProperties() {
// Prepare serialize properties. non-serializable properties won't be serialized.
Hashtable<String, Object> propertyTable = this.properties;
boolean propertiesCloned = false;
@@ -350,25 +399,54 @@ public final class DeepCopiedRenderedImage implements RenderedImage, Serializabl
}
}
}
+ return propertyTable;
+ }
- out.writeObject(SerializerFactory.getState(this.colorModel, null));
- out.writeObject(propertyTable);
+ public static void registerKryo(Kryo kryo) {
+ kryo.register(ColorModelState.class, new JavaSerializer());
+ }
+
+ private static final AWTRasterSerializer awtRasterSerializer = new AWTRasterSerializer();
+
+ @Override
+ public void write(Kryo kryo, Output output) {
+ // write basic properties
+ output.writeInt(minX);
+ output.writeInt(minY);
+ output.writeInt(width);
+ output.writeInt(height);
+
+ // write properties
+ Hashtable<String, Object> propertyTable = getSerializableProperties();
+ KryoUtil.writeObjectWithLength(kryo, output, propertyTable);
+
+ // write color model
+ SerializableState colorModelState = SerializerFactory.getState(this.colorModel, null);
+ KryoUtil.writeObjectWithLength(kryo, output, colorModelState);
+
+ // write raster
+ Raster serializedRaster;
if (this.source != null) {
- Raster serializedRaster = RasterUtils.getRaster(this.source);
- out.writeObject(SerializerFactory.getState(serializedRaster, null));
+ serializedRaster = RasterUtils.getRaster(this.source);
} else {
- out.writeObject(SerializerFactory.getState(imageRaster, null));
+ serializedRaster = imageRaster;
}
+ awtRasterSerializer.write(kryo, output, serializedRaster);
}
@SuppressWarnings("unchecked")
- private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
- this.source = null;
- in.defaultReadObject();
-
- SerializableState cmState = (SerializableState)in.readObject();
- this.colorModel = (ColorModel)cmState.getObject();
- this.properties = (Hashtable<String, Object>)in.readObject();
+ @Override
+ public void read(Kryo kryo, Input input) {
+ // read basic properties
+ minX = input.readInt();
+ minY = input.readInt();
+ width = input.readInt();
+ height = input.readInt();
+ imageBounds = new Rectangle(minX, minY, width, height);
+
+ // read properties
+ input.readInt(); // skip the length of the property table
+ properties = kryo.readObject(input, Hashtable.class);
for (String key : this.properties.keySet()) {
Object value = this.properties.get(key);
// Restore the value of GC_NODATA property as a NoDataContainer object.
@@ -377,8 +455,14 @@ public final class DeepCopiedRenderedImage implements RenderedImage, Serializabl
this.properties.put(key, new NoDataContainer(noDataContainer.singleValue));
}
}
- SerializableState rasState = (SerializableState)in.readObject();
- this.imageRaster = (Raster)rasState.getObject();
+
+ // read color model
+ input.readInt(); // skip the length of the color model state
+ ColorModelState cmState = kryo.readObject(input, ColorModelState.class);
+ this.colorModel = (ColorModel) cmState.getObject();
+
+ // read raster
+ this.imageRaster = awtRasterSerializer.read(kryo, input, Raster.class);
// The deserialized rendered image contains only one tile (imageRaster). We need to update
// the sample model and tile properties to reflect this.
@@ -387,6 +471,10 @@ public final class DeepCopiedRenderedImage implements RenderedImage, Serializabl
this.tileHeight = this.height;
this.numXTiles = 1;
this.numYTiles = 1;
+ this.minTileX = 0;
+ this.minTileY = 0;
+ this.tileGridXOffset = minX;
+ this.tileGridYOffset = minY;
}
/**
diff --git a/common/src/main/java/org/apache/sedona/common/raster/RasterConstructorsForTesting.java b/common/src/main/java/org/apache/sedona/common/raster/RasterConstructorsForTesting.java
new file mode 100644
index 000000000..47667b9b8
--- /dev/null
+++ b/common/src/main/java/org/apache/sedona/common/raster/RasterConstructorsForTesting.java
@@ -0,0 +1,199 @@
+/**
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ * <p>
+ * http://www.apache.org/licenses/LICENSE-2.0
+ * <p>
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.sedona.common.raster;
+
+import com.sun.media.imageioimpl.common.BogusColorSpace;
+import org.apache.commons.lang3.ArrayUtils;
+import org.apache.sedona.common.FunctionsGeoTools;
+import org.apache.sedona.common.utils.RasterUtils;
+import org.geotools.coverage.CoverageFactoryFinder;
+import org.geotools.coverage.grid.GridCoverage2D;
+import org.geotools.coverage.grid.GridCoverageFactory;
+import org.geotools.coverage.grid.GridEnvelope2D;
+import org.geotools.coverage.grid.GridGeometry2D;
+import org.geotools.referencing.crs.DefaultEngineeringCRS;
+import org.geotools.referencing.operation.transform.AffineTransform2D;
+import org.opengis.referencing.crs.CoordinateReferenceSystem;
+import org.opengis.referencing.datum.PixelInCell;
+import org.opengis.referencing.operation.MathTransform;
+
+import javax.media.jai.RasterFactory;
+import java.awt.Transparency;
+import java.awt.color.ColorSpace;
+import java.awt.image.BufferedImage;
+import java.awt.image.ColorModel;
+import java.awt.image.ComponentColorModel;
+import java.awt.image.ComponentSampleModel;
+import java.awt.image.DataBuffer;
+import java.awt.image.DirectColorModel;
+import java.awt.image.IndexColorModel;
+import java.awt.image.MultiPixelPackedSampleModel;
+import java.awt.image.PixelInterleavedSampleModel;
+import java.awt.image.RenderedImage;
+import java.awt.image.SampleModel;
+import java.awt.image.SinglePixelPackedSampleModel;
+import java.awt.image.WritableRaster;
+import java.util.Arrays;
+
+/**
+ * Raster constructor for testing the Python implementation of raster deserializer.
+ */
+public class RasterConstructorsForTesting {
+ private RasterConstructorsForTesting() {}
+
+ public static GridCoverage2D makeRasterForTesting(
+ int numBand, String bandDataType, String sampleModelType,
+ int widthInPixel, int heightInPixel, double upperLeftX, double upperLeftY,
+ double scaleX, double scaleY, double skewX, double skewY,
+ int srid) {
+ CoordinateReferenceSystem crs;
+ if (srid == 0) {
+ crs = DefaultEngineeringCRS.GENERIC_2D;
+ } else {
+ // Create the CRS from the srid
+ // Longitude first, Latitude second
+ crs = FunctionsGeoTools.sridToCRS(srid);
+ }
+
+ // Create a new raster with certain pixel values
+ WritableRaster raster = createRasterWithSampleModel(sampleModelType, bandDataType, widthInPixel, heightInPixel, numBand);
+ for (int k = 0; k < numBand; k++) {
+ for (int y = 0; y < heightInPixel; y++) {
+ for (int x = 0; x < widthInPixel; x++) {
+ double value = k + y * widthInPixel + x;
+ raster.setSample(x, y, k, value);
+ }
+ }
+ }
+
+ MathTransform transform = new AffineTransform2D(scaleX, skewY, skewX, scaleY, upperLeftX, upperLeftY);
+ GridGeometry2D gridGeometry = new GridGeometry2D(
+ new GridEnvelope2D(0, 0, widthInPixel, heightInPixel),
+ PixelInCell.CELL_CORNER,
+ transform, crs, null);
+
+ int rasterDataType = raster.getDataBuffer().getDataType();
+ ColorModel colorModel;
+ if (!sampleModelType.contains("Packed")) {
+ final ColorSpace cs = new BogusColorSpace(numBand);
+ final int[] nBits = new int[numBand];
+ Arrays.fill(nBits, DataBuffer.getDataTypeSize(rasterDataType));
+ colorModel =
+ new ComponentColorModel(cs, nBits, false, true, Transparency.OPAQUE, rasterDataType);
+ } else if (sampleModelType.equals("SinglePixelPackedSampleModel")) {
+ colorModel = new DirectColorModel(32,
+ 0x0F,
+ (0x0F) << 4,
+ (0x0F) << 8,
+ (0x0F) << 12);
+ } else if (sampleModelType.equals("MultiPixelPackedSampleModel")) {
+ byte[] arr = new byte[16];
+ for (int k = 0; k < 16; k++) {
+ arr[k] = (byte) (k * 16);
+ }
+ colorModel = new IndexColorModel(4, arr.length, arr, arr, arr);
+ } else {
+ throw new IllegalArgumentException("Unknown sample model type: " + sampleModelType);
+ }
+
+ final RenderedImage image = new BufferedImage(colorModel, raster, false, null);
+ GridCoverageFactory gridCoverageFactory = CoverageFactoryFinder.getGridCoverageFactory(null);
+ return gridCoverageFactory.create("genericCoverage", image, gridGeometry, null, null, null);
+ }
+
+ private static WritableRaster createRasterWithSampleModel(String sampleModelType, String bandDataType, int widthInPixel, int heightInPixel, int numBand) {
+ int dataType = RasterUtils.getDataTypeCode(bandDataType);
+
+ // Create raster according to sample model type
+ WritableRaster raster;
+ switch (sampleModelType) {
+ case "BandedSampleModel":
+ raster = RasterFactory.createBandedRaster(dataType, widthInPixel, heightInPixel, numBand, null);
+ break;
+ case "PixelInterleavedSampleModel": {
+ int scanlineStride = widthInPixel * numBand;
+ int[] bandOffsets = new int[numBand];
+ for (int i = 0; i < numBand; i++) {
+ bandOffsets[i] = i;
+ }
+ SampleModel sm = new PixelInterleavedSampleModel(dataType, widthInPixel, heightInPixel, numBand, scanlineStride, bandOffsets);
+ raster = RasterFactory.createWritableRaster(sm, null);
+ break;
+ }
+ case "PixelInterleavedSampleModelComplex": {
+ int pixelStride = numBand + 2;
+ int scanlineStride = widthInPixel * pixelStride + 5;
+ int[] bandOffsets = new int[numBand];
+ for (int i = 0; i < numBand; i++) {
+ bandOffsets[i] = i;
+ }
+ ArrayUtils.shuffle(bandOffsets);
+ SampleModel sm = new PixelInterleavedSampleModel(dataType, widthInPixel, heightInPixel, pixelStride, scanlineStride, bandOffsets);
+ raster = RasterFactory.createWritableRaster(sm, null);
+ break;
+ }
+ case "ComponentSampleModel": {
+ int pixelStride = numBand + 1;
+ int scanlineStride = widthInPixel * pixelStride + 5;
+ int[] bankIndices = new int[numBand];
+ for (int i = 0; i < numBand; i++) {
+ bankIndices[i] = i;
+ }
+ ArrayUtils.shuffle(bankIndices);
+ int[] bandOffsets = new int[numBand];
+ for (int i = 0; i < numBand; i++) {
+ bandOffsets[i] = (int)(Math.random() * widthInPixel);
+ }
+ SampleModel sm = new ComponentSampleModel(dataType, widthInPixel, heightInPixel, pixelStride, scanlineStride, bankIndices, bandOffsets);
+ raster = RasterFactory.createWritableRaster(sm, null);
+ break;
+ }
+ case "SinglePixelPackedSampleModel": {
+ if (dataType != DataBuffer.TYPE_INT) {
+ throw new IllegalArgumentException("only supports creating SinglePixelPackedSampleModel with int data type");
+ }
+ if (numBand != 4) {
+ throw new IllegalArgumentException("only supports creating SinglePixelPackedSampleModel with 4 bands");
+ }
+ int bitsPerBand = 4;
+ int scanlineStride = widthInPixel + 5;
+ int[] bitMasks = new int[numBand];
+ int baseMask = (1 << bitsPerBand) - 1;
+ for (int i = 0; i < numBand; i++) {
+ bitMasks[i] = baseMask << (i * bitsPerBand);
+ }
+ SampleModel sm = new SinglePixelPackedSampleModel(dataType, widthInPixel, heightInPixel, scanlineStride, bitMasks);
+ raster = RasterFactory.createWritableRaster(sm, null);
+ break;
+ }
+ case "MultiPixelPackedSampleModel": {
+ if (dataType != DataBuffer.TYPE_BYTE) {
+ throw new IllegalArgumentException("only supports creating MultiPixelPackedSampleModel with byte data type");
+ }
+ if (numBand != 1) {
+ throw new IllegalArgumentException("only supports creating MultiPixelPackedSampleModel with 1 band");
+ }
+ int numberOfBits = 4;
+ int scanlineStride = widthInPixel * numberOfBits / 8 + 2;
+ SampleModel sm = new MultiPixelPackedSampleModel(dataType, widthInPixel, heightInPixel, numberOfBits, scanlineStride, 80);
+ raster = RasterFactory.createWritableRaster(sm, null);
+ break;
+ }
+ default:
+ throw new IllegalArgumentException("Unknown sample model type: " + sampleModelType);
+ }
+
+ return raster;
+ }
+}
diff --git a/common/src/main/java/org/apache/sedona/common/raster/Serde.java b/common/src/main/java/org/apache/sedona/common/raster/Serde.java
deleted file mode 100644
index 7f34d840d..000000000
--- a/common/src/main/java/org/apache/sedona/common/raster/Serde.java
+++ /dev/null
@@ -1,103 +0,0 @@
-/**
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- * <p>
- * http://www.apache.org/licenses/LICENSE-2.0
- * <p>
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.sedona.common.raster;
-
-import org.geotools.coverage.GridSampleDimension;
-import org.geotools.coverage.grid.GridCoverage2D;
-import org.geotools.coverage.grid.GridCoverageFactory;
-import org.geotools.coverage.grid.GridEnvelope2D;
-import org.geotools.coverage.grid.GridGeometry2D;
-import org.opengis.referencing.operation.MathTransform;
-
-import javax.media.jai.RenderedImageAdapter;
-import java.awt.image.RenderedImage;
-import java.io.ByteArrayInputStream;
-import java.io.ByteArrayOutputStream;
-import java.io.IOException;
-import java.io.ObjectInputStream;
-import java.io.ObjectOutputStream;
-import java.io.Serializable;
-import java.lang.reflect.Field;
-
-public class Serde {
-
- static final Field field;
-
- static {
- try {
- field = GridCoverage2D.class.getDeclaredField("serializedImage");
- field.setAccessible(true);
- } catch (NoSuchFieldException e) {
- throw new RuntimeException(e);
- }
- }
-
- private static class SerializableState implements Serializable {
- public CharSequence name;
-
- // The following three components are used to construct a GridGeometry2D object.
- // We serialize CRS separately because the default serializer is pretty slow, we use a
- // cached serializer to speed up the serialization and reuse CRS on deserialization.
- public GridEnvelope2D gridEnvelope2D;
- public MathTransform gridToCRS;
- public byte[] serializedCRS;
-
- public GridSampleDimension[] bands;
- public DeepCopiedRenderedImage image;
-
- public GridCoverage2D restore() {
- GridGeometry2D gridGeometry = new GridGeometry2D(gridEnvelope2D, gridToCRS, CRSSerializer.deserialize(serializedCRS));
- return new GridCoverageFactory().create(name, image, gridGeometry, bands, null, null);
- }
- }
-
- public static byte[] serialize(GridCoverage2D raster) throws IOException {
- // GridCoverage2D created by GridCoverage2DReaders contain references that are not serializable.
- // Wrap the RenderedImage in DeepCopiedRenderedImage to make it serializable.
- DeepCopiedRenderedImage deepCopiedRenderedImage = null;
- RenderedImage renderedImage = raster.getRenderedImage();
- while (renderedImage instanceof RenderedImageAdapter) {
- renderedImage = ((RenderedImageAdapter) renderedImage).getWrappedImage();
- }
- if (renderedImage instanceof DeepCopiedRenderedImage) {
- deepCopiedRenderedImage = (DeepCopiedRenderedImage) renderedImage;
- } else {
- deepCopiedRenderedImage = new DeepCopiedRenderedImage(renderedImage);
- }
-
- SerializableState state = new SerializableState();
- GridGeometry2D gridGeometry = raster.getGridGeometry();
- state.name = raster.getName();
- state.gridEnvelope2D = gridGeometry.getGridRange2D();
- state.gridToCRS = gridGeometry.getGridToCRS2D();
- state.serializedCRS = CRSSerializer.serialize(gridGeometry.getCoordinateReferenceSystem());
- state.bands = raster.getSampleDimensions();
- state.image = deepCopiedRenderedImage;
- try (ByteArrayOutputStream bos = new ByteArrayOutputStream()) {
- try (ObjectOutputStream oos = new ObjectOutputStream(bos)) {
- oos.writeObject(state);
- return bos.toByteArray();
- }
- }
- }
-
- public static GridCoverage2D deserialize(byte[] bytes) throws IOException, ClassNotFoundException {
- try (ByteArrayInputStream bis = new ByteArrayInputStream(bytes)) {
- try (ObjectInputStream ois = new ObjectInputStream(bis)) {
- SerializableState state = (SerializableState) ois.readObject();
- return state.restore();
- }
- }
- }
-}
diff --git a/common/src/main/java/org/apache/sedona/common/raster/serde/AWTRasterSerializer.java b/common/src/main/java/org/apache/sedona/common/raster/serde/AWTRasterSerializer.java
new file mode 100644
index 000000000..5de5744c6
--- /dev/null
+++ b/common/src/main/java/org/apache/sedona/common/raster/serde/AWTRasterSerializer.java
@@ -0,0 +1,61 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.sedona.common.raster.serde;
+
+import com.esotericsoftware.kryo.Kryo;
+import com.esotericsoftware.kryo.Serializer;
+import com.esotericsoftware.kryo.io.Input;
+import com.esotericsoftware.kryo.io.Output;
+
+import java.awt.Point;
+import java.awt.image.DataBuffer;
+import java.awt.image.Raster;
+import java.awt.image.SampleModel;
+import java.awt.image.WritableRaster;
+
+public class AWTRasterSerializer extends Serializer<Raster> {
+ private static final SampleModelSerializer sampleModelSerializer = new SampleModelSerializer();
+ private static final DataBufferSerializer dataBufferSerializer = new DataBufferSerializer();
+
+ @Override
+ public void write(Kryo kryo, Output output, Raster raster) {
+ Raster r;
+ if (raster.getParent() != null) {
+ r = raster.createCompatibleWritableRaster(raster.getBounds());
+ ((WritableRaster) r).setRect(raster);
+ } else {
+ r = raster;
+ }
+
+ output.writeInt(r.getMinX());
+ output.writeInt(r.getMinY());
+ sampleModelSerializer.write(kryo, output, r.getSampleModel());
+ dataBufferSerializer.write(kryo, output, r.getDataBuffer());
+ }
+
+ @Override
+ public Raster read(Kryo kryo, Input input, Class<Raster> type) {
+ int minX = input.readInt();
+ int minY = input.readInt();
+ Point location = new Point(minX, minY);
+ SampleModel sampleModel = sampleModelSerializer.read(kryo, input, SampleModel.class);
+ DataBuffer dataBuffer = dataBufferSerializer.read(kryo, input, DataBuffer.class);
+ return Raster.createRaster(sampleModel, dataBuffer, location);
+ }
+}
diff --git a/common/src/main/java/org/apache/sedona/common/raster/serde/AffineTransform2DSerializer.java b/common/src/main/java/org/apache/sedona/common/raster/serde/AffineTransform2DSerializer.java
new file mode 100644
index 000000000..ed0a2d576
--- /dev/null
+++ b/common/src/main/java/org/apache/sedona/common/raster/serde/AffineTransform2DSerializer.java
@@ -0,0 +1,47 @@
+/**
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ * <p>
+ * http://www.apache.org/licenses/LICENSE-2.0
+ * <p>
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.sedona.common.raster.serde;
+
+import com.esotericsoftware.kryo.Kryo;
+import com.esotericsoftware.kryo.Serializer;
+import com.esotericsoftware.kryo.io.Input;
+import com.esotericsoftware.kryo.io.Output;
+import org.geotools.referencing.operation.transform.AffineTransform2D;
+
+/**
+ * Kryo serializer for {@link AffineTransform2D}. A custom serializer is provided because the default
+ * serializer of Kryo cannot deserialize AffineTransform2D correctly.
+ * <p>
+ * The transform is stored as six doubles in this order: scaleX, shearY, shearX, scaleY, translateX,
+ * translateY.
+ */
+public class AffineTransform2DSerializer extends Serializer<AffineTransform2D> {
+    /**
+     * Write the six coefficients of the transform
+     * @param kryo the kryo instance (not used)
+     * @param output the output stream
+     * @param affineTransform2D the transform to write
+     */
+    @Override
+    public void write(Kryo kryo, Output output, AffineTransform2D affineTransform2D) {
+        final double[] coefficients = {
+            affineTransform2D.getScaleX(),
+            affineTransform2D.getShearY(),
+            affineTransform2D.getShearX(),
+            affineTransform2D.getScaleY(),
+            affineTransform2D.getTranslateX(),
+            affineTransform2D.getTranslateY()
+        };
+        for (double coefficient : coefficients) {
+            output.writeDouble(coefficient);
+        }
+    }
+
+    /**
+     * Read the six coefficients and rebuild the transform
+     * @param kryo the kryo instance (not used)
+     * @param input the input stream
+     * @param aClass the requested class (not used)
+     * @return the deserialized transform
+     */
+    @Override
+    public AffineTransform2D read(Kryo kryo, Input input, Class<AffineTransform2D> aClass) {
+        // Values come back in exactly the order write() emitted them.
+        final double m00 = input.readDouble(); // scaleX
+        final double m10 = input.readDouble(); // shearY
+        final double m01 = input.readDouble(); // shearX
+        final double m11 = input.readDouble(); // scaleY
+        final double m02 = input.readDouble(); // translateX
+        final double m12 = input.readDouble(); // translateY
+        return new AffineTransform2D(m00, m10, m01, m11, m02, m12);
+    }
+}
diff --git a/common/src/main/java/org/apache/sedona/common/raster/CRSSerializer.java b/common/src/main/java/org/apache/sedona/common/raster/serde/CRSSerializer.java
similarity index 70%
rename from common/src/main/java/org/apache/sedona/common/raster/CRSSerializer.java
rename to common/src/main/java/org/apache/sedona/common/raster/serde/CRSSerializer.java
index 216131010..9c427e34f 100644
--- a/common/src/main/java/org/apache/sedona/common/raster/CRSSerializer.java
+++ b/common/src/main/java/org/apache/sedona/common/raster/serde/CRSSerializer.java
@@ -16,19 +16,24 @@
* specific language governing permissions and limitations
* under the License.
*/
-package org.apache.sedona.common.raster;
+package org.apache.sedona.common.raster.serde;
import com.github.benmanes.caffeine.cache.Caffeine;
import com.github.benmanes.caffeine.cache.LoadingCache;
+import org.apache.commons.io.IOUtils;
import org.geotools.referencing.CRS;
+import org.geotools.referencing.wkt.Formattable;
+import org.opengis.referencing.FactoryException;
import org.opengis.referencing.crs.CoordinateReferenceSystem;
+import si.uom.NonSI;
+import si.uom.SI;
+import javax.measure.IncommensurableException;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
-import java.io.ObjectInputStream;
-import java.io.ObjectOutputStream;
import java.nio.ByteBuffer;
+import java.nio.charset.StandardCharsets;
import java.util.zip.DeflaterOutputStream;
import java.util.zip.InflaterInputStream;
@@ -42,6 +47,20 @@ import java.util.zip.InflaterInputStream;
public class CRSSerializer {
private CRSSerializer() {}
+ // Static initializer: performs one degree-to-radian unit conversion when this class is initialized.
+ static {
+ try {
+ // HACK: This warms up the piCache in tech.units.indriya.function.Calculus.
+ // Otherwise, concurrent calls to CoordinateReferenceSystem.toWKT() will cause a
+ // ConcurrentModificationException. This is a bug in tech.units.indriya, which was fixed
+ // in 2.1.4 by https://github.com/unitsofmeasurement/indriya/commit/fc370465
+ // However, 2.1.4 is not compatible with the GeoTools version we use, which is
+ // why we have this workaround here.
+ NonSI.DEGREE_ANGLE.getConverterToAny(SI.RADIAN).convert(1);
+ } catch (IncommensurableException e) {
+ // The checked exception is rethrown unchecked because a static initializer cannot declare it.
+ throw new RuntimeException(e);
+ }
+ }
+
private static class CRSKey {
private final CoordinateReferenceSystem crs;
private final int hashCode;
@@ -85,10 +104,15 @@ public class CRSSerializer {
private static byte[] doSerializeCRS(CRSKey crsKey) throws IOException {
CoordinateReferenceSystem crs = crsKey.crs;
try (ByteArrayOutputStream bos = new ByteArrayOutputStream();
- DeflaterOutputStream dos = new DeflaterOutputStream(bos);
- ObjectOutputStream oos = new ObjectOutputStream(dos)) {
- oos.writeObject(crs);
- oos.flush();
+ DeflaterOutputStream dos = new DeflaterOutputStream(bos)) {
+ String wktString;
+ if (crs instanceof Formattable) {
+ // Can specify "strict" as false to get rid of serialization errors in trade of correctness
+ wktString = ((Formattable) crs).toWKT(2, false);
+ } else {
+ wktString = crs.toWKT();
+ }
+ dos.write(wktString.getBytes(StandardCharsets.UTF_8));
dos.finish();
byte[] res = bos.toByteArray();
crsDeserializationCache.put(ByteBuffer.wrap(res), crs);
@@ -96,12 +120,13 @@ public class CRSSerializer {
}
}
- private static CoordinateReferenceSystem doDeserializeCRS(ByteBuffer byteBuffer) throws IOException, ClassNotFoundException {
+ private static CoordinateReferenceSystem doDeserializeCRS(ByteBuffer byteBuffer) throws IOException, FactoryException {
byte[] bytes = byteBuffer.array();
try (ByteArrayInputStream bis = new ByteArrayInputStream(bytes);
- InflaterInputStream dis = new InflaterInputStream(bis);
- ObjectInputStream ois = new ObjectInputStream(dis)) {
- CoordinateReferenceSystem crs = (CoordinateReferenceSystem) ois.readObject();
+ InflaterInputStream dis = new InflaterInputStream(bis)) {
+ byte[] wktBytes = IOUtils.toByteArray(dis);
+ String wktString = new String(wktBytes, StandardCharsets.UTF_8);
+ CoordinateReferenceSystem crs = CRS.parseWKT(wktString);
crsSerializationCache.put(new CRSKey(crs), bytes);
return crs;
}
diff --git a/common/src/main/java/org/apache/sedona/common/raster/serde/DataBufferSerializer.java b/common/src/main/java/org/apache/sedona/common/raster/serde/DataBufferSerializer.java
new file mode 100644
index 000000000..4d886ab43
--- /dev/null
+++ b/common/src/main/java/org/apache/sedona/common/raster/serde/DataBufferSerializer.java
@@ -0,0 +1,106 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.sedona.common.raster.serde;
+
+import com.esotericsoftware.kryo.Kryo;
+import com.esotericsoftware.kryo.Serializer;
+import com.esotericsoftware.kryo.io.Input;
+import com.esotericsoftware.kryo.io.Output;
+import com.sun.media.jai.util.DataBufferUtils;
+
+import java.awt.image.DataBuffer;
+import java.awt.image.DataBufferByte;
+import java.awt.image.DataBufferInt;
+import java.awt.image.DataBufferShort;
+import java.awt.image.DataBufferUShort;
+
+public class DataBufferSerializer extends Serializer<DataBuffer> {
+ @Override
+ public void write(Kryo kryo, Output output, DataBuffer dataBuffer) {
+ int dataType = dataBuffer.getDataType();
+ output.writeInt(dataType);
+ KryoUtil.writeIntArray(output, dataBuffer.getOffsets());
+ output.writeInt(dataBuffer.getSize());
+ switch (dataType) {
+ case DataBuffer.TYPE_BYTE:
+ byte[][] byteDataArray = ((DataBufferByte) dataBuffer).getBankData();
+ KryoUtil.writeByteArrays(output, byteDataArray);
+ break;
+ case DataBuffer.TYPE_USHORT:
+ short[][] uShortDataArray = ((DataBufferUShort) dataBuffer).getBankData();
+ KryoUtil.writeShortArrays(output, uShortDataArray);
+ break;
+ case DataBuffer.TYPE_SHORT:
+ short[][] shortDataArray = ((DataBufferShort) dataBuffer).getBankData();
+ KryoUtil.writeShortArrays(output, shortDataArray);
+ break;
+ case DataBuffer.TYPE_INT:
+ int[][] intDataArray = ((DataBufferInt) dataBuffer).getBankData();
+ KryoUtil.writeIntArrays(output, intDataArray);
+ break;
+ case DataBuffer.TYPE_FLOAT:
+ float[][] floatDataArray = DataBufferUtils.getBankDataFloat(dataBuffer);
+ KryoUtil.writeFloatArrays(output, floatDataArray);
+ break;
+ case DataBuffer.TYPE_DOUBLE:
+ double[][] doubleDataArray = DataBufferUtils.getBankDataDouble(dataBuffer);
+ KryoUtil.writeDoubleArrays(output, doubleDataArray);
+ break;
+ default:
+ throw new RuntimeException("Unknown data type: " + dataType);
+ }
+ }
+
+ @Override
+ public DataBuffer read(Kryo kryo, Input input, Class<DataBuffer> type) {
+ int dataType = input.readInt();
+ int[] offsets = KryoUtil.readIntArray(input);
+ int size = input.readInt();
+ DataBuffer dataBuffer;
+ switch (dataType) {
+ case DataBuffer.TYPE_BYTE:
+ byte[][] byteDataArray = KryoUtil.readByteArrays(input);
+ dataBuffer = new DataBufferByte(byteDataArray, size, offsets);
+ break;
+ case DataBuffer.TYPE_USHORT:
+ short[][] uShortDataArray = KryoUtil.readShortArrays(input);
+ dataBuffer = new DataBufferUShort(uShortDataArray, size, offsets);
+ break;
+ case DataBuffer.TYPE_SHORT:
+ short[][] shortDataArray = KryoUtil.readShortArrays(input);
+ dataBuffer = new DataBufferShort(shortDataArray, size, offsets);
+ break;
+ case DataBuffer.TYPE_INT:
+ int[][] intDataArray = KryoUtil.readIntArrays(input);
+ dataBuffer = new DataBufferInt(intDataArray, size, offsets);
+ break;
+ case DataBuffer.TYPE_FLOAT:
+ float[][] floatDataArray = KryoUtil.readFloatArrays(input);
+ dataBuffer = DataBufferUtils.createDataBufferFloat(floatDataArray, size, offsets);
+ break;
+ case DataBuffer.TYPE_DOUBLE:
+ double[][] doubleDataArray = KryoUtil.readDoubleArrays(input);
+ dataBuffer = DataBufferUtils.createDataBufferDouble(doubleDataArray, size, offsets);
+ break;
+ default:
+ throw new RuntimeException("Unknown data type: " + dataType);
+ }
+ return dataBuffer;
+ }
+}
diff --git a/common/src/main/java/org/apache/sedona/common/raster/serde/GridEnvelopeSerializer.java b/common/src/main/java/org/apache/sedona/common/raster/serde/GridEnvelopeSerializer.java
new file mode 100644
index 000000000..024d8c8bc
--- /dev/null
+++ b/common/src/main/java/org/apache/sedona/common/raster/serde/GridEnvelopeSerializer.java
@@ -0,0 +1,39 @@
+/**
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ * <p>
+ * http://www.apache.org/licenses/LICENSE-2.0
+ * <p>
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.sedona.common.raster.serde;
+
+import com.esotericsoftware.kryo.Kryo;
+import com.esotericsoftware.kryo.Serializer;
+import com.esotericsoftware.kryo.io.Input;
+import com.esotericsoftware.kryo.io.Output;
+import org.geotools.coverage.grid.GridEnvelope2D;
+
+/**
+ * Kryo serializer for {@link GridEnvelope2D}. The envelope is stored as four ints in this order:
+ * width, height, x, y.
+ */
+public class GridEnvelopeSerializer extends Serializer<GridEnvelope2D> {
+    /**
+     * Write the grid envelope to the output stream
+     * @param kryo the kryo instance (not used)
+     * @param output the output stream
+     * @param object the grid envelope to write
+     */
+    @Override
+    public void write(Kryo kryo, Output output, GridEnvelope2D object) {
+        // Size first, then origin; read() relies on this order.
+        final int[] fields = {object.width, object.height, object.x, object.y};
+        for (int field : fields) {
+            output.writeInt(field);
+        }
+    }
+
+    /**
+     * Read a grid envelope from the input stream
+     * @param kryo the kryo instance (not used)
+     * @param input the input stream
+     * @param type the requested class (not used)
+     * @return the deserialized grid envelope
+     */
+    @Override
+    public GridEnvelope2D read(Kryo kryo, Input input, Class<GridEnvelope2D> type) {
+        final int w = input.readInt();
+        final int h = input.readInt();
+        final int originX = input.readInt();
+        final int originY = input.readInt();
+        // The constructor takes the origin before the size, unlike the stored order.
+        return new GridEnvelope2D(originX, originY, w, h);
+    }
+}
diff --git a/common/src/main/java/org/apache/sedona/common/raster/serde/GridSampleDimensionSerializer.java b/common/src/main/java/org/apache/sedona/common/raster/serde/GridSampleDimensionSerializer.java
new file mode 100644
index 000000000..f4ca504b5
--- /dev/null
+++ b/common/src/main/java/org/apache/sedona/common/raster/serde/GridSampleDimensionSerializer.java
@@ -0,0 +1,54 @@
+/**
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ * <p>
+ * http://www.apache.org/licenses/LICENSE-2.0
+ * <p>
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.sedona.common.raster.serde;
+
+import com.esotericsoftware.kryo.Kryo;
+import com.esotericsoftware.kryo.Serializer;
+import com.esotericsoftware.kryo.io.Input;
+import com.esotericsoftware.kryo.io.Output;
+import org.apache.sedona.common.utils.RasterUtils;
+import org.geotools.coverage.Category;
+import org.geotools.coverage.GridSampleDimension;
+
+import java.util.List;
+
+/**
+ * GridSampleDimension and RenderedSampleDimension are not serializable. We need to provide a custom serializer.
+ * <p>
+ * Serialized layout: description (length-prefixed UTF-8 string), offset (double), scale (double),
+ * noDataValue (double), then the categories as a length-prefixed Kryo object.
+ */
+public class GridSampleDimensionSerializer extends Serializer<GridSampleDimension> {
+    /**
+     * Write a sample dimension to the output stream
+     * @param kryo the kryo instance, used to serialize the categories
+     * @param output the output stream
+     * @param sampleDimension the sample dimension to write
+     */
+    @Override
+    public void write(Kryo kryo, Output output, GridSampleDimension sampleDimension) {
+        String description = sampleDimension.getDescription().toString();
+        List<Category> categories = sampleDimension.getCategories();
+        double offset = sampleDimension.getOffset();
+        double scale = sampleDimension.getScale();
+        double noDataValue = RasterUtils.getNoDataValue(sampleDimension);
+        KryoUtil.writeUTF8String(output, description);
+        output.writeDouble(offset);
+        output.writeDouble(scale);
+        output.writeDouble(noDataValue); // for interoperability with Python RasterType.
+        // Write a Category[] (not Object[]) so that the array type written here is the same
+        // array type that read() asks Kryo to deserialize.
+        KryoUtil.writeObjectWithLength(kryo, output, categories.toArray(new Category[0]));
+    }
+
+    /**
+     * Read a sample dimension from the input stream
+     * @param kryo the kryo instance, used to deserialize the categories
+     * @param input the input stream
+     * @param aClass the requested class (not used)
+     * @return the deserialized sample dimension
+     */
+    @Override
+    public GridSampleDimension read(Kryo kryo, Input input, Class<GridSampleDimension> aClass) {
+        String description = KryoUtil.readUTF8String(input);
+        double offset = input.readDouble();
+        double scale = input.readDouble();
+        input.readDouble(); // noDataValue is included in categories, so we just skip it
+        input.readInt(); // skip the length of the next object
+        Category[] categories = kryo.readObject(input, Category[].class);
+        return new GridSampleDimension(description, categories, offset, scale);
+    }
+}
diff --git a/common/src/main/java/org/apache/sedona/common/raster/serde/KryoUtil.java b/common/src/main/java/org/apache/sedona/common/raster/serde/KryoUtil.java
new file mode 100644
index 000000000..8292e7b55
--- /dev/null
+++ b/common/src/main/java/org/apache/sedona/common/raster/serde/KryoUtil.java
@@ -0,0 +1,297 @@
+/**
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ * <p>
+ * http://www.apache.org/licenses/LICENSE-2.0
+ * <p>
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.sedona.common.raster.serde;
+
+import com.esotericsoftware.kryo.Kryo;
+import com.esotericsoftware.kryo.io.Input;
+import com.esotericsoftware.kryo.io.Output;
+
+import java.nio.charset.StandardCharsets;
+import java.util.HashMap;
+import java.util.Map;
+
+/**
+ * Utility methods for serializing objects with Kryo. The serialization formats are well-defined and independent
+ * of the Kryo version. This allows us to exchange serialized data with other tech stacks, such as Python.
+ * <p>
+ * Unless stated otherwise, every variable-length value is written as an int length followed by its elements.
+ */
+public class KryoUtil {
+    private KryoUtil() {}
+
+    /**
+     * Write the length of the next serialized object, followed by the serialized object.
+     * <p>
+     * The length is not known up front, so a placeholder int is written first and patched after the
+     * object has been serialized. Patching uses positions inside the buffer of {@code output}, so it
+     * only works while the placeholder is still in that buffer.
+     * @param kryo the kryo instance
+     * @param output the output stream
+     * @param object the object to serialize
+     * @throws IllegalStateException if {@code output} flushed part of its buffer while the object was
+     *         being written, in which case the length can no longer be patched in place
+     */
+    public static void writeObjectWithLength(Kryo kryo, Output output, Object object) {
+        long totalBefore = output.total();
+        int lengthOffset = output.position();
+        output.writeInt(0); // placeholder, will be overwritten later
+
+        // Write the object
+        int start = output.position();
+        kryo.writeObject(output, object);
+        int end = output.position();
+
+        // If the buffer was flushed in between, buffer positions no longer line up with the number
+        // of bytes actually written. Rewriting at lengthOffset would then corrupt unrelated bytes,
+        // so fail loudly instead.
+        if (output.total() - totalBefore != (long) end - lengthOffset) {
+            throw new IllegalStateException(
+                    "Output was flushed while writing a length-prefixed object; the length cannot be patched in place");
+        }
+
+        // Rewrite the length
+        int length = end - start;
+        output.setPosition(lengthOffset);
+        output.writeInt(length);
+        output.setPosition(end);
+    }
+
+    /**
+     * Write string as UTF-8 byte sequence, preceded by the number of bytes
+     * @param output the output stream
+     * @param value the string to write
+     */
+    public static void writeUTF8String(Output output, String value) {
+        byte[] utf8 = value.getBytes(StandardCharsets.UTF_8);
+        output.writeInt(utf8.length);
+        output.writeBytes(utf8);
+    }
+
+    /**
+     * Read UTF-8 byte sequence as string
+     * @param input the input stream
+     * @return the string
+     */
+    public static String readUTF8String(Input input) {
+        int length = input.readInt();
+        byte[] utf8 = new byte[length];
+        input.readBytes(utf8);
+        return new String(utf8, StandardCharsets.UTF_8);
+    }
+
+    /**
+     * Write an array of integers
+     * @param output the output stream
+     * @param array the array to write
+     */
+    public static void writeIntArray(Output output, int[] array) {
+        output.writeInt(array.length);
+        output.writeInts(array);
+    }
+
+    /**
+     * Read an array of integers
+     * @param input the input stream
+     * @return the array
+     */
+    public static int[] readIntArray(Input input) {
+        int length = input.readInt();
+        return input.readInts(length);
+    }
+
+    /**
+     * Write a 2-d array of ints: the number of rows, then each row as written by
+     * {@link #writeIntArray(Output, int[])}
+     * @param output the output stream
+     * @param arrays the array to write
+     */
+    public static void writeIntArrays(Output output, int[][] arrays) {
+        output.writeInt(arrays.length);
+        for (int[] array : arrays) {
+            writeIntArray(output, array);
+        }
+    }
+
+    /**
+     * Read a 2-d array of ints
+     * @param input the input stream
+     * @return the array
+     */
+    public static int[][] readIntArrays(Input input) {
+        int length = input.readInt();
+        int[][] arrays = new int[length][];
+        for (int i = 0; i < length; i++) {
+            arrays[i] = readIntArray(input);
+        }
+        return arrays;
+    }
+
+    /**
+     * Write a 2-d array of bytes: the number of rows, then each row as its length followed by its bytes
+     * @param output the output stream
+     * @param arrays the array to write
+     */
+    public static void writeByteArrays(Output output, byte[][] arrays) {
+        output.writeInt(arrays.length);
+        for (byte[] array : arrays) {
+            output.writeInt(array.length);
+            output.writeBytes(array);
+        }
+    }
+
+    /**
+     * Read a 2-d array of bytes
+     * @param input the input stream
+     * @return the array
+     */
+    public static byte[][] readByteArrays(Input input) {
+        int length = input.readInt();
+        byte[][] arrays = new byte[length][];
+        for (int i = 0; i < length; i++) {
+            int arrayLength = input.readInt();
+            arrays[i] = input.readBytes(arrayLength);
+        }
+        return arrays;
+    }
+
+    /**
+     * Write a 2-d array of doubles: the number of rows, then each row as its length followed by its elements
+     * @param output the output stream
+     * @param arrays the array to write
+     */
+    public static void writeDoubleArrays(Output output, double[][] arrays) {
+        output.writeInt(arrays.length);
+        for (double[] array : arrays) {
+            output.writeInt(array.length);
+            output.writeDoubles(array);
+        }
+    }
+
+    /**
+     * Read a 2-d array of doubles
+     * @param input the input stream
+     * @return the array
+     */
+    public static double[][] readDoubleArrays(Input input) {
+        int length = input.readInt();
+        double[][] arrays = new double[length][];
+        for (int i = 0; i < length; i++) {
+            int arrayLength = input.readInt();
+            arrays[i] = input.readDoubles(arrayLength);
+        }
+        return arrays;
+    }
+
+    /**
+     * Write a 2-d array of longs: the number of rows, then each row as its length followed by its elements
+     * @param output the output stream
+     * @param arrays the array to write
+     */
+    public static void writeLongArrays(Output output, long[][] arrays) {
+        output.writeInt(arrays.length);
+        for (long[] array : arrays) {
+            output.writeInt(array.length);
+            output.writeLongs(array);
+        }
+    }
+
+    /**
+     * Read a 2-d array of longs
+     * @param input the input stream
+     * @return the array
+     */
+    public static long[][] readLongArrays(Input input) {
+        int length = input.readInt();
+        long[][] arrays = new long[length][];
+        for (int i = 0; i < length; i++) {
+            int arrayLength = input.readInt();
+            arrays[i] = input.readLongs(arrayLength);
+        }
+        return arrays;
+    }
+
+    /**
+     * Write a 2-d array of floats: the number of rows, then each row as its length followed by its elements
+     * @param output the output stream
+     * @param arrays the array to write
+     */
+    public static void writeFloatArrays(Output output, float[][] arrays) {
+        output.writeInt(arrays.length);
+        for (float[] array : arrays) {
+            output.writeInt(array.length);
+            output.writeFloats(array);
+        }
+    }
+
+    /**
+     * Read a 2-d array of floats
+     * @param input the input stream
+     * @return the array
+     */
+    public static float[][] readFloatArrays(Input input) {
+        int length = input.readInt();
+        float[][] arrays = new float[length][];
+        for (int i = 0; i < length; i++) {
+            int arrayLength = input.readInt();
+            arrays[i] = input.readFloats(arrayLength);
+        }
+        return arrays;
+    }
+
+    /**
+     * Write a 2-d array of shorts: the number of rows, then each row as its length followed by its elements
+     * @param output the output stream
+     * @param arrays the array to write
+     */
+    public static void writeShortArrays(Output output, short[][] arrays) {
+        output.writeInt(arrays.length);
+        for (short[] array : arrays) {
+            output.writeInt(array.length);
+            output.writeShorts(array);
+        }
+    }
+
+    /**
+     * Read a 2-d array of shorts
+     * @param input the input stream
+     * @return the array
+     */
+    public static short[][] readShortArrays(Input input) {
+        int length = input.readInt();
+        short[][] arrays = new short[length][];
+        for (int i = 0; i < length; i++) {
+            int arrayLength = input.readInt();
+            arrays[i] = input.readShorts(arrayLength);
+        }
+        return arrays;
+    }
+
+    /**
+     * Write a {@code Map<String, String>} object to the output stream. A {@code null} map is written
+     * as the single int -1; otherwise the entry count is written, followed by each key and value as
+     * UTF-8 strings.
+     * @param output the output stream
+     * @param map the map to write, may be {@code null}
+     */
+    public static void writeUTF8StringMap(Output output, Map<String, String> map) {
+        if (map == null) {
+            output.writeInt(-1);
+            return;
+        }
+        output.writeInt(map.size());
+        for (Map.Entry<String, String> entry : map.entrySet()) {
+            writeUTF8String(output, entry.getKey());
+            writeUTF8String(output, entry.getValue());
+        }
+    }
+
+    /**
+     * Read a {@code Map<String, String>} object from the input stream
+     * @param input the input stream
+     * @return the map, or {@code null} if the stored entry count is -1
+     */
+    public static Map<String, String> readUTF8StringMap(Input input) {
+        int size = input.readInt();
+        if (size == -1) {
+            return null;
+        }
+        Map<String, String> params = new HashMap<>(size);
+        for (int i = 0; i < size; i++) {
+            String key = readUTF8String(input);
+            String value = readUTF8String(input);
+            params.put(key, value);
+        }
+        return params;
+    }
+}
diff --git a/common/src/main/java/org/apache/sedona/common/raster/serde/SampleModelSerializer.java b/common/src/main/java/org/apache/sedona/common/raster/serde/SampleModelSerializer.java
new file mode 100644
index 000000000..1dce0b9be
--- /dev/null
+++ b/common/src/main/java/org/apache/sedona/common/raster/serde/SampleModelSerializer.java
@@ -0,0 +1,174 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.sedona.common.raster.serde;
+
+import com.esotericsoftware.kryo.Kryo;
+import com.esotericsoftware.kryo.Serializer;
+import com.esotericsoftware.kryo.io.Input;
+import com.esotericsoftware.kryo.io.Output;
+
+import javax.media.jai.ComponentSampleModelJAI;
+import javax.media.jai.RasterFactory;
+import java.awt.image.BandedSampleModel;
+import java.awt.image.ComponentSampleModel;
+import java.awt.image.MultiPixelPackedSampleModel;
+import java.awt.image.PixelInterleavedSampleModel;
+import java.awt.image.SampleModel;
+import java.awt.image.SinglePixelPackedSampleModel;
+
+/**
+ * Serializer for SampleModelState using Kryo. This is translated from the original JAI implementation
+ * of writeObject and readObject.
+ * <p>
+ * Every sample model starts with four ints (type tag, transfer type, width, height), followed by
+ * fields specific to the type tag.
+ */
+public class SampleModelSerializer extends Serializer<SampleModel> {
+
+    // These constants are taken from SampleModelState
+    private static final int TYPE_BANDED = 1;
+    private static final int TYPE_PIXEL_INTERLEAVED = 2;
+    private static final int TYPE_SINGLE_PIXEL_PACKED = 3;
+    private static final int TYPE_MULTI_PIXEL_PACKED = 4;
+    private static final int TYPE_COMPONENT_JAI = 5;
+    private static final int TYPE_COMPONENT = 6;
+
+    /**
+     * Map a sample model to its type tag. The specialized component models are tested before the
+     * plain ComponentSampleModel because they are subclasses of it.
+     * @throws UnsupportedOperationException if the sample model is of none of the supported kinds
+     */
+    private static int sampleModelTypeOf(SampleModel sampleModel) {
+        if (sampleModel instanceof PixelInterleavedSampleModel) {
+            return TYPE_PIXEL_INTERLEAVED;
+        }
+        if (sampleModel instanceof BandedSampleModel) {
+            return TYPE_BANDED;
+        }
+        if (sampleModel instanceof ComponentSampleModelJAI) {
+            return TYPE_COMPONENT_JAI;
+        }
+        if (sampleModel instanceof ComponentSampleModel) {
+            return TYPE_COMPONENT;
+        }
+        if (sampleModel instanceof SinglePixelPackedSampleModel) {
+            return TYPE_SINGLE_PIXEL_PACKED;
+        }
+        if (sampleModel instanceof MultiPixelPackedSampleModel) {
+            return TYPE_MULTI_PIXEL_PACKED;
+        }
+        throw new UnsupportedOperationException("Unsupported SampleModel type: " + sampleModel.getClass().getName());
+    }
+
+    /**
+     * Write a sample model to the output stream
+     * @param kryo the kryo instance (not used)
+     * @param output the output stream
+     * @param sampleModel the sample model to write
+     * @throws UnsupportedOperationException if the sample model is of an unsupported kind
+     */
+    @Override
+    public void write(Kryo kryo, Output output, SampleModel sampleModel) {
+        final int tag = sampleModelTypeOf(sampleModel);
+
+        // Common header shared by all sample model kinds.
+        output.writeInt(tag);
+        output.writeInt(sampleModel.getTransferType());
+        output.writeInt(sampleModel.getWidth());
+        output.writeInt(sampleModel.getHeight());
+
+        // Kind-specific fields; read() consumes them in the same order.
+        switch (tag) {
+            case TYPE_BANDED: {
+                BandedSampleModel banded = (BandedSampleModel) sampleModel;
+                KryoUtil.writeIntArray(output, banded.getBankIndices());
+                KryoUtil.writeIntArray(output, banded.getBandOffsets());
+                break;
+            }
+
+            case TYPE_PIXEL_INTERLEAVED: {
+                PixelInterleavedSampleModel interleaved = (PixelInterleavedSampleModel) sampleModel;
+                output.writeInt(interleaved.getPixelStride());
+                output.writeInt(interleaved.getScanlineStride());
+                KryoUtil.writeIntArray(output, interleaved.getBandOffsets());
+                break;
+            }
+
+            case TYPE_COMPONENT:
+            case TYPE_COMPONENT_JAI: {
+                ComponentSampleModel component = (ComponentSampleModel) sampleModel;
+                output.writeInt(component.getPixelStride());
+                output.writeInt(component.getScanlineStride());
+                KryoUtil.writeIntArray(output, component.getBankIndices());
+                KryoUtil.writeIntArray(output, component.getBandOffsets());
+                break;
+            }
+
+            case TYPE_SINGLE_PIXEL_PACKED: {
+                SinglePixelPackedSampleModel singlePacked = (SinglePixelPackedSampleModel) sampleModel;
+                output.writeInt(singlePacked.getScanlineStride());
+                KryoUtil.writeIntArray(output, singlePacked.getBitMasks());
+                break;
+            }
+
+            case TYPE_MULTI_PIXEL_PACKED: {
+                MultiPixelPackedSampleModel multiPacked = (MultiPixelPackedSampleModel) sampleModel;
+                output.writeInt(multiPacked.getPixelBitStride());
+                output.writeInt(multiPacked.getScanlineStride());
+                output.writeInt(multiPacked.getDataBitOffset());
+                break;
+            }
+
+            default:
+                throw new UnsupportedOperationException("Unknown SampleModel type: " + sampleModel.getClass().getName());
+        }
+    }
+
+    /**
+     * Read a sample model from the input stream
+     * @param kryo the kryo instance (not used)
+     * @param input the input stream
+     * @param type the requested class (not used)
+     * @return the deserialized sample model
+     * @throws UnsupportedOperationException if the stored type tag is not recognized
+     */
+    @Override
+    public SampleModel read(Kryo kryo, Input input, Class<SampleModel> type) {
+        final int tag = input.readInt();
+        final int dataType = input.readInt();
+        final int w = input.readInt();
+        final int h = input.readInt();
+
+        switch (tag) {
+            case TYPE_BANDED: {
+                int[] banks = KryoUtil.readIntArray(input);
+                int[] offsets = KryoUtil.readIntArray(input);
+                // The number of bands is taken from the length of the bank index array.
+                return RasterFactory.createBandedSampleModel(dataType, w, h, banks.length, banks, offsets);
+            }
+
+            case TYPE_PIXEL_INTERLEAVED: {
+                int pixelStep = input.readInt();
+                int rowStep = input.readInt();
+                int[] offsets = KryoUtil.readIntArray(input);
+                return RasterFactory.createPixelInterleavedSampleModel(dataType, w, h, pixelStep, rowStep, offsets);
+            }
+
+            case TYPE_COMPONENT_JAI:
+            case TYPE_COMPONENT: {
+                int pixelStep = input.readInt();
+                int rowStep = input.readInt();
+                int[] banks = KryoUtil.readIntArray(input);
+                int[] offsets = KryoUtil.readIntArray(input);
+                // Both tags share one layout; only the concrete class to instantiate differs.
+                return tag == TYPE_COMPONENT_JAI
+                        ? new ComponentSampleModelJAI(dataType, w, h, pixelStep, rowStep, banks, offsets)
+                        : new ComponentSampleModel(dataType, w, h, pixelStep, rowStep, banks, offsets);
+            }
+
+            case TYPE_SINGLE_PIXEL_PACKED: {
+                int rowStep = input.readInt();
+                int[] masks = KryoUtil.readIntArray(input);
+                return new SinglePixelPackedSampleModel(dataType, w, h, rowStep, masks);
+            }
+
+            case TYPE_MULTI_PIXEL_PACKED: {
+                int bitsPerPixel = input.readInt();
+                int rowStep = input.readInt();
+                int bitOffset = input.readInt();
+                return new MultiPixelPackedSampleModel(dataType, w, h, bitsPerPixel, rowStep, bitOffset);
+            }
+
+            default:
+                throw new UnsupportedOperationException("Unsupported SampleModel type: " + tag);
+        }
+    }
+}
diff --git a/common/src/main/java/org/apache/sedona/common/raster/serde/Serde.java b/common/src/main/java/org/apache/sedona/common/raster/serde/Serde.java
new file mode 100644
index 000000000..616ded015
--- /dev/null
+++ b/common/src/main/java/org/apache/sedona/common/raster/serde/Serde.java
@@ -0,0 +1,179 @@
+/**
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ * <p>
+ * http://www.apache.org/licenses/LICENSE-2.0
+ * <p>
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.sedona.common.raster.serde;
+
+import com.esotericsoftware.kryo.Kryo;
+import com.esotericsoftware.kryo.KryoSerializable;
+import com.esotericsoftware.kryo.Serializer;
+import com.esotericsoftware.kryo.io.Input;
+import com.esotericsoftware.kryo.io.Output;
+import com.esotericsoftware.kryo.io.UnsafeInput;
+import com.esotericsoftware.kryo.io.UnsafeOutput;
+import org.apache.sedona.common.raster.DeepCopiedRenderedImage;
+import org.geotools.coverage.GridSampleDimension;
+import org.geotools.coverage.grid.GridCoverage2D;
+import org.geotools.coverage.grid.GridCoverageFactory;
+import org.geotools.coverage.grid.GridEnvelope2D;
+import org.geotools.coverage.grid.GridGeometry2D;
+import org.geotools.referencing.operation.transform.AffineTransform2D;
+import org.objenesis.strategy.StdInstantiatorStrategy;
+import org.opengis.referencing.operation.MathTransform;
+
+import javax.media.jai.RenderedImageAdapter;
+import java.awt.image.RenderedImage;
+import java.io.IOException;
+import java.io.Serializable;
+import java.net.URI;
+
+public class Serde {
+ private Serde() {}
+
+ /**
+ * URIs are not serializable. We provide a custom serializer that writes the URI's string form via KryoUtil and parses it back with URI.create
+ */
+ private static class URISerializer extends Serializer<java.net.URI> {
+ public URISerializer() {
+ setImmutable(true);
+ }
+
+ @Override
+ public void write(final Kryo kryo, final Output output, final URI uri) {
+ KryoUtil.writeUTF8String(output, uri.toString());
+ }
+
+ @Override
+ public URI read(final Kryo kryo, final Input input, final Class<URI> uriClass) {
+ return URI.create(KryoUtil.readUTF8String(input));
+ }
+ }
+
+ private static final ThreadLocal<Kryo> kryos = ThreadLocal.withInitial(() -> {
+ Kryo kryo = new Kryo();
+ kryo.setInstantiatorStrategy(new Kryo.DefaultInstantiatorStrategy(new StdInstantiatorStrategy()));
+ kryo.register(AffineTransform2D.class, new AffineTransform2DSerializer());
+ kryo.register(GridSampleDimension.class, new GridSampleDimensionSerializer());
+ kryo.register(URI.class, new URISerializer());
+ DeepCopiedRenderedImage.registerKryo(kryo);
+ try {
+ kryo.register(Class.forName("org.geotools.coverage.grid.RenderedSampleDimension"),
+ new GridSampleDimensionSerializer());
+ } catch (ClassNotFoundException e) {
+ throw new RuntimeException("Cannot register kryo serializer for class RenderedSampleDimension", e);
+ }
+ kryo.setClassLoader(Thread.currentThread().getContextClassLoader());
+ return kryo;
+ });
+
+ private static class SerializableState implements Serializable, KryoSerializable {
+ public CharSequence name;
+
+ // The following three components are used to construct a GridGeometry2D object.
+ // The CRS is serialized separately because the default serializer is pretty slow; a cached
+ // serializer speeds up serialization and lets deserialization reuse CRS objects.
+ public GridEnvelope2D gridEnvelope2D;
+ public MathTransform gridToCRS;
+ public byte[] serializedCRS;
+
+ public GridSampleDimension[] bands;
+ public DeepCopiedRenderedImage image;
+
+ public GridCoverage2D restore() {
+ GridGeometry2D gridGeometry = new GridGeometry2D(gridEnvelope2D, gridToCRS, CRSSerializer.deserialize(serializedCRS));
+ return new GridCoverageFactory().create(name, image, gridGeometry, bands, null, null);
+ }
+
+ private static final GridEnvelopeSerializer gridEnvelopeSerializer = new GridEnvelopeSerializer();
+ private static final AffineTransform2DSerializer affineTransform2DSerializer = new AffineTransform2DSerializer();
+ private static final GridSampleDimensionSerializer gridSampleDimensionSerializer = new GridSampleDimensionSerializer();
+
+ @Override
+ public void write(Kryo kryo, Output output) {
+ KryoUtil.writeUTF8String(output, name.toString());
+ gridEnvelopeSerializer.write(kryo, output, gridEnvelope2D);
+ if (!(gridToCRS instanceof AffineTransform2D)) {
+ throw new UnsupportedOperationException("Only AffineTransform2D is supported");
+ }
+ affineTransform2DSerializer.write(kryo, output, (AffineTransform2D) gridToCRS);
+ output.writeInt(serializedCRS.length);
+ output.writeBytes(serializedCRS);
+ output.writeInt(bands.length);
+ for (GridSampleDimension band : bands) {
+ gridSampleDimensionSerializer.write(kryo, output, band);
+ }
+ image.write(kryo, output);
+ }
+
+ @Override
+ public void read(Kryo kryo, Input input) {
+ name = KryoUtil.readUTF8String(input);
+ gridEnvelope2D = gridEnvelopeSerializer.read(kryo, input, GridEnvelope2D.class);
+ gridToCRS = affineTransform2DSerializer.read(kryo, input, AffineTransform2D.class);
+ int serializedCRSLength = input.readInt();
+ serializedCRS = input.readBytes(serializedCRSLength);
+ int bandCount = input.readInt();
+ bands = new GridSampleDimension[bandCount];
+ for (int i = 0; i < bandCount; i++) {
+ bands[i] = gridSampleDimensionSerializer.read(kryo, input, GridSampleDimension.class);
+ }
+ image = new DeepCopiedRenderedImage();
+ image.read(kryo, input);
+ }
+ }
+
+ // Leading byte reserved for supporting rasters with other storage schemes; IN_DB is the only value deserialize() accepts
+ private static final int IN_DB = 0;
+
+ public static byte[] serialize(GridCoverage2D raster) throws IOException {
+ Kryo kryo = kryos.get();
+ // GridCoverage2D objects created by GridCoverage2DReaders contain references that are not serializable.
+ // Wrap the RenderedImage in a DeepCopiedRenderedImage (unless it already is one) to make it serializable.
+ DeepCopiedRenderedImage deepCopiedRenderedImage = null;
+ RenderedImage renderedImage = raster.getRenderedImage();
+ while (renderedImage instanceof RenderedImageAdapter) {
+ renderedImage = ((RenderedImageAdapter) renderedImage).getWrappedImage();
+ }
+ if (renderedImage instanceof DeepCopiedRenderedImage) {
+ deepCopiedRenderedImage = (DeepCopiedRenderedImage) renderedImage;
+ } else {
+ deepCopiedRenderedImage = new DeepCopiedRenderedImage(renderedImage);
+ }
+
+ SerializableState state = new SerializableState();
+ GridGeometry2D gridGeometry = raster.getGridGeometry();
+ state.name = raster.getName();
+ state.gridEnvelope2D = gridGeometry.getGridRange2D();
+ state.gridToCRS = gridGeometry.getGridToCRS2D();
+ state.serializedCRS = CRSSerializer.serialize(gridGeometry.getCoordinateReferenceSystem());
+ state.bands = raster.getSampleDimensions();
+ state.image = deepCopiedRenderedImage;
+ try (UnsafeOutput out = new UnsafeOutput(4096, -1)) {
+ out.writeByte(IN_DB);
+ state.write(kryo, out);
+ return out.toBytes();
+ }
+ }
+
+ public static GridCoverage2D deserialize(byte[] bytes) throws IOException, ClassNotFoundException {
+ Kryo kryo = kryos.get();
+ try (UnsafeInput in = new UnsafeInput(bytes)) {
+ int rasterType = in.readByte();
+ if (rasterType != IN_DB) {
+ throw new IllegalArgumentException("Unsupported raster type: " + rasterType);
+ }
+ SerializableState state = new SerializableState();
+ state.read(kryo, in);
+ return state.restore();
+ }
+ }
+}
diff --git a/common/src/test/java/org/apache/sedona/common/raster/RasterBandEditorsTest.java b/common/src/test/java/org/apache/sedona/common/raster/RasterBandEditorsTest.java
index 78af992c0..19b932ad7 100644
--- a/common/src/test/java/org/apache/sedona/common/raster/RasterBandEditorsTest.java
+++ b/common/src/test/java/org/apache/sedona/common/raster/RasterBandEditorsTest.java
@@ -19,6 +19,7 @@
package org.apache.sedona.common.raster;
import org.apache.sedona.common.Constructors;
+import org.apache.sedona.common.raster.serde.Serde;
import org.geotools.coverage.grid.GridCoverage2D;
import org.junit.Test;
import org.locationtech.jts.geom.Geometry;
@@ -189,7 +190,7 @@ public class RasterBandEditorsTest extends RasterTestBase{
}
@Test
- public void testClip() throws IOException, FactoryException, TransformException, ParseException {
+ public void testClip() throws IOException, FactoryException, TransformException, ParseException, ClassNotFoundException {
GridCoverage2D raster = rasterFromGeoTiff(resourceFolder + "raster_geotiff_color/FAA_UTM18N_NAD83.tif");
String polygon = "POLYGON ((236722 4204770, 243900 4204770, 243900 4197590, 221170 4197590, 236722 4204770))";
Geometry geom = Constructors.geomFromWKT(polygon, RasterAccessors.srid(raster));
@@ -216,6 +217,8 @@ public class RasterBandEditorsTest extends RasterTestBase{
GridCoverage2D croppedRaster = RasterBandEditors.clip(raster, 1, geom, 200, true);
assertEquals(0, croppedRaster.getRenderedImage().getMinX());
assertEquals(0, croppedRaster.getRenderedImage().getMinY());
+ GridCoverage2D croppedRaster2 = Serde.deserialize(Serde.serialize(croppedRaster));
+ assertSameCoverage(croppedRaster, croppedRaster2);
points = new ArrayList<>();
points.add(Constructors.geomFromWKT("POINT(236842 4.20465e+06)", 26918));
points.add(Constructors.geomFromWKT("POINT(236961 4.20453e+06)", 26918));
diff --git a/common/src/test/java/org/apache/sedona/common/raster/RasterConstructorsForTestingTest.java b/common/src/test/java/org/apache/sedona/common/raster/RasterConstructorsForTestingTest.java
new file mode 100644
index 000000000..0dde7811d
--- /dev/null
+++ b/common/src/test/java/org/apache/sedona/common/raster/RasterConstructorsForTestingTest.java
@@ -0,0 +1,111 @@
+/**
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ * <p>
+ * http://www.apache.org/licenses/LICENSE-2.0
+ * <p>
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.sedona.common.raster;
+
+import org.apache.sedona.common.raster.serde.Serde;
+import org.apache.sedona.common.utils.RasterUtils;
+import org.geotools.coverage.grid.GridCoverage2D;
+import org.junit.Test;
+
+import java.awt.image.ComponentSampleModel;
+import java.awt.image.MultiPixelPackedSampleModel;
+import java.awt.image.PixelInterleavedSampleModel;
+import java.awt.image.Raster;
+import java.awt.image.SinglePixelPackedSampleModel;
+import java.io.IOException;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertTrue;
+
+public class RasterConstructorsForTestingTest extends RasterTestBase {
+ @Test
+ public void testBandedRaster() {
+ GridCoverage2D raster = makeRasterWithFallbackParams(4, "I", "BandedSampleModel", 4, 3);
+ assertTrue(raster.getRenderedImage().getSampleModel() instanceof ComponentSampleModel);
+ testSerde(raster);
+ }
+
+ @Test
+ public void testPixelInterleavedRaster() {
+ GridCoverage2D raster = makeRasterWithFallbackParams(4, "I", "PixelInterleavedSampleModel", 4, 3);
+ assertTrue(raster.getRenderedImage().getSampleModel() instanceof PixelInterleavedSampleModel);
+ testSerde(raster);
+ raster = makeRasterWithFallbackParams(4, "I", "PixelInterleavedSampleModelComplex", 4, 3);
+ assertTrue(raster.getRenderedImage().getSampleModel() instanceof PixelInterleavedSampleModel);
+ testSerde(raster);
+ }
+
+ @Test
+ public void testComponentSampleModel() {
+ GridCoverage2D raster = makeRasterWithFallbackParams(4, "I", "ComponentSampleModel", 4, 3);
+ assertTrue(raster.getRenderedImage().getSampleModel() instanceof ComponentSampleModel);
+ testSerde(raster);
+ }
+
+ @Test
+ public void testSinglePixelPackedSampleModel() {
+ GridCoverage2D raster = makeRasterWithFallbackParams(4, "I", "SinglePixelPackedSampleModel", 4, 3);
+ assertTrue(raster.getRenderedImage().getSampleModel() instanceof SinglePixelPackedSampleModel);
+ testSerde(raster);
+ }
+
+ @Test
+ public void testMultiPixelPackedSampleModel() {
+ GridCoverage2D raster = makeRasterWithFallbackParams(1, "B", "MultiPixelPackedSampleModel", 4, 3);
+ assertTrue(raster.getRenderedImage().getSampleModel() instanceof MultiPixelPackedSampleModel);
+ testSerde(raster);
+
+ raster = makeRasterWithFallbackParams(1, "B", "MultiPixelPackedSampleModel", 21, 8);
+ Raster r = RasterUtils.getRaster(raster.getRenderedImage());
+ assertTrue(r.getSampleModel() instanceof MultiPixelPackedSampleModel);
+ assertEquals(21, r.getWidth());
+ assertEquals(8, r.getHeight());
+ for (int y = 0; y < r.getHeight(); y++) {
+ for (int x = 0; x < r.getWidth(); x++) {
+ assertEquals((x + y * 21) % 16, r.getSample(x, y, 0));
+ }
+ }
+ }
+
+ private static GridCoverage2D makeRasterWithFallbackParams(int numBand, String bandDataType, String sampleModelType, int width, int height) {
+ return RasterConstructorsForTesting.makeRasterForTesting(numBand, bandDataType, sampleModelType, width, height,
+ 0.5, -0.5, 1, -1, 0, 0, 3857);
+ }
+
+ // Serializes and deserializes the raster, then checks band count, grid geometry and pixel values of the result.
+ private static void testSerde(GridCoverage2D raster) {
+ try {
+ byte[] bytes = Serde.serialize(raster);
+ GridCoverage2D roundTripRaster = Serde.deserialize(bytes);
+ assertNotNull(roundTripRaster);
+ assertEquals(raster.getNumSampleDimensions(), roundTripRaster.getNumSampleDimensions());
+ assertEquals(raster.getGridGeometry(), roundTripRaster.getGridGeometry());
+ int width = raster.getRenderedImage().getWidth();
+ int height = raster.getRenderedImage().getHeight();
+ // Read pixels from the deserialized raster (not the input) so that the round trip is what gets verified.
+ Raster r = RasterUtils.getRaster(roundTripRaster.getRenderedImage());
+ for (int b = 0; b < raster.getNumSampleDimensions(); b++) {
+ for (int y = 0; y < height; y++) {
+ for (int x = 0; x < width; x++) {
+ double value = b + y * width + x;
+ assertEquals(value, r.getSampleDouble(x, y, b), 0.0001);
+ }
+ }
+ }
+ } catch (IOException | ClassNotFoundException e) {
+ throw new RuntimeException(e);
+ }
+ }
+}
diff --git a/common/src/test/java/org/apache/sedona/common/raster/RasterTestBase.java b/common/src/test/java/org/apache/sedona/common/raster/RasterTestBase.java
index e1f1a9ee9..d531c77f6 100644
--- a/common/src/test/java/org/apache/sedona/common/raster/RasterTestBase.java
+++ b/common/src/test/java/org/apache/sedona/common/raster/RasterTestBase.java
@@ -49,8 +49,8 @@ public class RasterTestBase {
protected static final double FP_TOLERANCE = 1E-4;
- GridCoverage2D oneBandRaster;
- GridCoverage2D multiBandRaster;
+ protected GridCoverage2D oneBandRaster;
+ protected GridCoverage2D multiBandRaster;
byte[] geoTiff;
byte[] testNc;
String ncFile = resourceFolder + "raster/netcdf/test.nc";
@@ -121,6 +121,10 @@ public class RasterTestBase {
return factory.create("test", image, new Envelope2D(DefaultGeographicCRS.WGS84, 0, 0, 10, 10));
}
+ protected void assertSameCoverage(GridCoverage2D expected, GridCoverage2D actual) {
+ assertSameCoverage(expected, actual, 10);
+ }
+
protected void assertSameCoverage(GridCoverage2D expected, GridCoverage2D actual, int density) {
Assert.assertEquals(expected.getNumSampleDimensions(), actual.getNumSampleDimensions());
Envelope expectedEnvelope = expected.getEnvelope();
diff --git a/common/src/test/java/org/apache/sedona/common/raster/CRSSerializerTest.java b/common/src/test/java/org/apache/sedona/common/raster/serde/CRSSerializerTest.java
similarity index 97%
rename from common/src/test/java/org/apache/sedona/common/raster/CRSSerializerTest.java
rename to common/src/test/java/org/apache/sedona/common/raster/serde/CRSSerializerTest.java
index fd749bc06..bb5399fd0 100644
--- a/common/src/test/java/org/apache/sedona/common/raster/CRSSerializerTest.java
+++ b/common/src/test/java/org/apache/sedona/common/raster/serde/CRSSerializerTest.java
@@ -16,7 +16,7 @@
* specific language governing permissions and limitations
* under the License.
*/
-package org.apache.sedona.common.raster;
+package org.apache.sedona.common.raster.serde;
import org.geotools.referencing.CRS;
import org.junit.Assert;
diff --git a/common/src/test/java/org/apache/sedona/common/raster/serde/DataBufferSerializerTest.java b/common/src/test/java/org/apache/sedona/common/raster/serde/DataBufferSerializerTest.java
new file mode 100644
index 000000000..b3c6de179
--- /dev/null
+++ b/common/src/test/java/org/apache/sedona/common/raster/serde/DataBufferSerializerTest.java
@@ -0,0 +1,153 @@
+/**
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ * <p>
+ * http://www.apache.org/licenses/LICENSE-2.0
+ * <p>
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.sedona.common.raster.serde;
+
+import com.esotericsoftware.kryo.io.Input;
+import com.esotericsoftware.kryo.io.Output;
+import org.junit.Assert;
+import org.junit.Test;
+
+import java.awt.image.DataBuffer;
+import java.awt.image.DataBufferByte;
+import java.awt.image.DataBufferDouble;
+import java.awt.image.DataBufferFloat;
+import java.awt.image.DataBufferInt;
+import java.awt.image.DataBufferShort;
+import java.awt.image.DataBufferUShort;
+
+public class DataBufferSerializerTest extends KryoSerializerTestBase {
+ private static final DataBufferSerializer serializer = new DataBufferSerializer();
+
+ private static void assertEquals(DataBuffer expected, DataBuffer actual) {
+ Assert.assertEquals(expected.getDataType(), actual.getDataType());
+ Assert.assertEquals(expected.getNumBanks(), actual.getNumBanks());
+ Assert.assertEquals(expected.getSize(), actual.getSize());
+ Assert.assertArrayEquals(expected.getOffsets(), actual.getOffsets());
+ for (int bank = 0; bank < expected.getNumBanks(); bank++) {
+ for (int k = 0; k < expected.getSize(); k++) {
+ Assert.assertEquals(expected.getElemDouble(bank, k), actual.getElemDouble(bank, k), 1e-6);
+ }
+ }
+ }
+
+ @Test
+ public void serializeByteBuffer() {
+ byte[][] dataArray = {
+ {1, 2, 3, 4, 5},
+ {6, 7, 8, 9, 0}
+ };
+ int size = 5;
+ int[] offsets = {0, 0};
+ DataBufferByte dataBufferByte = new DataBufferByte(dataArray, size, offsets);
+ try (Output out = createOutput()) {
+ serializer.write(kryo, out, dataBufferByte);
+ try (Input in = createInput(out)) {
+ DataBuffer dataBufferByte1 = serializer.read(kryo, in, DataBuffer.class);
+ assertEquals(dataBufferByte, dataBufferByte1);
+ }
+ }
+ }
+
+ @Test
+ public void serializeShortBuffer() {
+ short[][] dataArray = {
+ {1, 2, 3, 4, 5},
+ {6, 7, 8, 9, 0}
+ };
+ int size = 5;
+ int[] offsets = {0, 0};
+ DataBuffer dataBufferShort = new DataBufferShort(dataArray, size, offsets);
+ try (Output out = createOutput()) {
+ serializer.write(kryo, out, dataBufferShort);
+ try (Input in = createInput(out)) {
+ DataBuffer dataBufferShort1 = serializer.read(kryo, in, DataBuffer.class);
+ Assert.assertTrue(dataBufferShort1 instanceof DataBufferShort);
+ assertEquals(dataBufferShort, dataBufferShort1);
+ }
+ }
+ }
+
+ @Test
+ public void serializeUShortBuffer() {
+ short[][] dataArray = {
+ {1, 2, 3, 4, 5},
+ {6, 7, 8, 9, 0}
+ };
+ int size = 5;
+ int[] offsets = {0, 0};
+ DataBuffer dataBufferShort = new DataBufferUShort(dataArray, size, offsets);
+ try (Output out = createOutput()) {
+ serializer.write(kryo, out, dataBufferShort);
+ try (Input in = createInput(out)) {
+ DataBuffer dataBufferShort1 = serializer.read(kryo, in, DataBuffer.class);
+ Assert.assertTrue(dataBufferShort1 instanceof DataBufferUShort);
+ assertEquals(dataBufferShort, dataBufferShort1);
+ }
+ }
+ }
+
+ @Test
+ public void serializeIntBuffer() {
+ int[][] dataArray = {
+ {1, 2, 3, 4, 5},
+ {6, 7, 8, 9, 0}
+ };
+ int size = 5;
+ int[] offsets = {0, 0};
+ DataBuffer dataBufferInt = new DataBufferInt(dataArray, size, offsets);
+ try (Output out = createOutput()) {
+ serializer.write(kryo, out, dataBufferInt);
+ try (Input in = createInput(out)) {
+ DataBuffer dataBufferInt1 = serializer.read(kryo, in, DataBuffer.class);
+ assertEquals(dataBufferInt, dataBufferInt1);
+ }
+ }
+ }
+
+ @Test
+ public void serializeFloatBuffer() {
+ float[][] dataArray = {
+ {1, 2, 3, 4, 5},
+ {6, 7, 8, 9, 0}
+ };
+ int size = 5;
+ int[] offsets = {0, 0};
+ DataBuffer dataBufferFloat = new DataBufferFloat(dataArray, size, offsets);
+ try (Output out = createOutput()) {
+ serializer.write(kryo, out, dataBufferFloat);
+ try (Input in = createInput(out)) {
+ DataBuffer dataBufferFloat1 = serializer.read(kryo, in, DataBuffer.class);
+ assertEquals(dataBufferFloat, dataBufferFloat1);
+ }
+ }
+ }
+
+ @Test
+ public void serializeDoubleBuffer() {
+ double[][] dataArray = {
+ {1, 2, 3, 4, 5},
+ {6, 7, 8, 9, 0}
+ };
+ int size = 5;
+ int[] offsets = {0, 0};
+ DataBuffer dataBufferDouble = new DataBufferDouble(dataArray, size, offsets);
+ try (Output out = createOutput()) {
+ serializer.write(kryo, out, dataBufferDouble);
+ try (Input in = createInput(out)) {
+ DataBuffer dataBufferDouble1 = serializer.read(kryo, in, DataBuffer.class);
+ assertEquals(dataBufferDouble, dataBufferDouble1);
+ }
+ }
+ }
+}
diff --git a/common/src/test/java/org/apache/sedona/common/raster/serde/KryoSerializerTestBase.java b/common/src/test/java/org/apache/sedona/common/raster/serde/KryoSerializerTestBase.java
new file mode 100644
index 000000000..af088a3cc
--- /dev/null
+++ b/common/src/test/java/org/apache/sedona/common/raster/serde/KryoSerializerTestBase.java
@@ -0,0 +1,34 @@
+/**
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ * <p>
+ * http://www.apache.org/licenses/LICENSE-2.0
+ * <p>
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.sedona.common.raster.serde;
+
+import com.esotericsoftware.kryo.Kryo;
+import com.esotericsoftware.kryo.io.Input;
+import com.esotericsoftware.kryo.io.Output;
+import com.esotericsoftware.kryo.io.UnsafeInput;
+import com.esotericsoftware.kryo.io.UnsafeOutput;
+
+public class KryoSerializerTestBase {
+ protected static final Kryo kryo = new Kryo();
+
+ protected static Output createOutput() {
+ return new UnsafeOutput(4096, -1);
+ }
+
+ protected static Input createInput(Output out) {
+ out.flush();
+ byte[] bytes = out.toBytes();
+ return new UnsafeInput(bytes);
+ }
+}
diff --git a/common/src/test/java/org/apache/sedona/common/raster/serde/KryoUtilTest.java b/common/src/test/java/org/apache/sedona/common/raster/serde/KryoUtilTest.java
new file mode 100644
index 000000000..359f54f87
--- /dev/null
+++ b/common/src/test/java/org/apache/sedona/common/raster/serde/KryoUtilTest.java
@@ -0,0 +1,234 @@
+/**
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ * <p>
+ * http://www.apache.org/licenses/LICENSE-2.0
+ * <p>
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.sedona.common.raster.serde;
+
+import com.esotericsoftware.kryo.io.Input;
+import com.esotericsoftware.kryo.io.Output;
+import org.junit.Assert;
+import org.junit.Test;
+
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.Objects;
+
+import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertEquals;
+
+public class KryoUtilTest extends KryoSerializerTestBase {
+
+ private static class TestClass {
+ private int a;
+ private String b;
+ private double[] c;
+
+ @Override
+ public boolean equals(Object o) {
+ if (this == o) return true;
+ if (o == null || getClass() != o.getClass()) return false;
+ TestClass testClass = (TestClass) o;
+ return a == testClass.a && Objects.equals(b, testClass.b) && Arrays.equals(c, testClass.c);
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hash(a, b, Arrays.hashCode(c));
+ }
+ }
+
+ @Test
+ public void writeObjectWithLength() {
+ TestClass obj = new TestClass();
+ obj.a = 1;
+ obj.b = "test";
+ obj.c = new double[]{1.0, 2.0, 3.0};
+ try (Output out = createOutput()) {
+ KryoUtil.writeObjectWithLength(kryo, out, obj);
+ try (Input in = createInput(out)) {
+ in.readInt(); // skip serialized object length
+ TestClass obj2 = kryo.readObject(in, TestClass.class);
+ assertEquals(obj, obj2);
+ }
+ }
+ }
+
+ @Test
+ public void serializeUTF8String() {
+ String str =
+ "Hello - English\n" +
+ "Hola - Spanish\n" +
+ "Bonjour - French\n" +
+ "Hallo - German\n" +
+ "Ciao - Italian\n" +
+ "你好 - Chinese\n" +
+ "こんにちは - Japanese\n" +
+ "안녕하세요 - Korean\n" +
+ "Здравствуйте - Russian\n" +
+ "नमस्ते - Hindi\n" +
+ "مرحبا - Arabic\n" +
+ "שלום - Hebrew\n" +
+ "สวัสดี - Thai\n" +
+ "Merhaba - Turkish\n" +
+ "Γεια σας - Greek";
+ try (Output out = createOutput()) {
+ KryoUtil.writeUTF8String(out, str);
+ try (Input in = createInput(out)) {
+ String str2 = KryoUtil.readUTF8String(in);
+ assertEquals(str, str2);
+ }
+ }
+ }
+
+ @Test
+ public void serializeIntArray() {
+ int[] arr = new int[]{1, 2, 3, 4, 5};
+ try (Output out = createOutput()) {
+ KryoUtil.writeIntArray(out, arr);
+ try (Input in = createInput(out)) {
+ int[] arr2 = KryoUtil.readIntArray(in);
+ assertArrayEquals(arr, arr2);
+ }
+ }
+ }
+
+ @Test
+ public void serializeIntArrays() {
+ int[][] arrs = new int[][]{
+ new int[]{1, 2, 3, 4, 5},
+ new int[]{6, 7, 8, 9, 10}
+ };
+ try (Output out = createOutput()) {
+ KryoUtil.writeIntArrays(out, arrs);
+ try (Input in = createInput(out)) {
+ int[][] arrs2 = KryoUtil.readIntArrays(in);
+ assertArrayEquals(arrs, arrs2);
+ }
+ }
+ }
+
+ @Test
+ public void serializeByteArrays() {
+ byte[][] arrs = new byte[][]{
+ new byte[]{1, 2, 3, 4, 5},
+ new byte[]{6, 7, 8, 9, 10}
+ };
+ try (Output out = createOutput()) {
+ KryoUtil.writeByteArrays(out, arrs);
+ try (Input in = createInput(out)) {
+ byte[][] arrs2 = KryoUtil.readByteArrays(in);
+ assertArrayEquals(arrs, arrs2);
+ }
+ }
+ }
+
+ @Test
+ public void serializeDoubleArrays() {
+ double[][] arrs = new double[][]{
+ new double[]{1.0, 2.0, 3.0, 4.0, 5.0},
+ new double[]{6.0, 7.0, 8.0, 9.0, 10.0}
+ };
+ try (Output out = createOutput()) {
+ KryoUtil.writeDoubleArrays(out, arrs);
+ try (Input in = createInput(out)) {
+ double[][] arrs2 = KryoUtil.readDoubleArrays(in);
+ assertArrayEquals(arrs, arrs2);
+ }
+ }
+ }
+
+ @Test
+ public void serializeLongArrays() {
+ long[][] arrs = new long[][]{
+ new long[]{1L, 2L, 3L, 4L, 5L},
+ new long[]{6L, 7L, 8L, 9L, 10L}
+ };
+ try (Output out = createOutput()) {
+ KryoUtil.writeLongArrays(out, arrs);
+ try (Input in = createInput(out)) {
+ long[][] arrs2 = KryoUtil.readLongArrays(in);
+ assertArrayEquals(arrs, arrs2);
+ }
+ }
+ }
+
+ @Test
+ public void serializeFloatArrays() {
+ float[][] arrs = new float[][]{
+ new float[]{1.0f, 2.0f, 3.0f, 4.0f, 5.0f},
+ new float[]{6.0f, 7.0f, 8.0f, 9.0f, 10.0f}
+ };
+ try (Output out = createOutput()) {
+ KryoUtil.writeFloatArrays(out, arrs);
+ try (Input in = createInput(out)) {
+ float[][] arrs2 = KryoUtil.readFloatArrays(in);
+ assertArrayEquals(arrs, arrs2);
+ }
+ }
+ }
+
+ @Test
+ public void serializeShortArrays() {
+ short[][] arrs = new short[][]{
+ new short[]{1, 2, 3, 4, 5},
+ new short[]{6, 7, 8, 9, 10}
+ };
+ try (Output out = createOutput()) {
+ KryoUtil.writeShortArrays(out, arrs);
+ try (Input in = createInput(out)) {
+ short[][] arrs2 = KryoUtil.readShortArrays(in);
+ assertArrayEquals(arrs, arrs2);
+ }
+ }
+ }
+
+ @Test
+ public void serializeNullUTF8StringMap() {
+ try (Output out = createOutput()) {
+ KryoUtil.writeUTF8StringMap(out, null);
+ try (Input in = createInput(out)) {
+ Map<String, String> map = KryoUtil.readUTF8StringMap(in);
+ Assert.assertNull(map);
+ }
+ }
+ }
+
+ @Test
+ public void serializeEmptyUTF8StringMap() {
+ try (Output out = createOutput()) {
+ KryoUtil.writeUTF8StringMap(out, Collections.emptyMap());
+ try (Input in = createInput(out)) {
+ Map<String, String> map = KryoUtil.readUTF8StringMap(in);
+ Assert.assertNotNull(map);
+ Assert.assertTrue(map.isEmpty());
+ }
+ }
+ }
+
+ @Test
+ public void serializeUTF8StringMap() {
+ Map<String, String> map = new HashMap<>();
+ map.put("key1", "value1");
+ map.put("key2", "value2");
+
+ try (Output out = createOutput()) {
+ KryoUtil.writeUTF8StringMap(out, map);
+ try (Input in = createInput(out)) {
+ Map<String, String> map2 = KryoUtil.readUTF8StringMap(in);
+ Assert.assertNotNull(map2);
+ Assert.assertEquals(map, map2);
+ }
+ }
+ }
+}
diff --git a/common/src/test/java/org/apache/sedona/common/raster/serde/SampleModelSerializerTest.java b/common/src/test/java/org/apache/sedona/common/raster/serde/SampleModelSerializerTest.java
new file mode 100644
index 000000000..5bac27f93
--- /dev/null
+++ b/common/src/test/java/org/apache/sedona/common/raster/serde/SampleModelSerializerTest.java
@@ -0,0 +1,112 @@
+/**
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ * <p>
+ * http://www.apache.org/licenses/LICENSE-2.0
+ * <p>
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.sedona.common.raster.serde;
+
+import com.esotericsoftware.kryo.io.Input;
+import com.esotericsoftware.kryo.io.Output;
+import org.junit.Assert;
+import org.junit.Test;
+
+import javax.media.jai.ComponentSampleModelJAI;
+import java.awt.image.BandedSampleModel;
+import java.awt.image.ComponentSampleModel;
+import java.awt.image.DataBuffer;
+import java.awt.image.MultiPixelPackedSampleModel;
+import java.awt.image.PixelInterleavedSampleModel;
+import java.awt.image.SampleModel;
+import java.awt.image.SinglePixelPackedSampleModel;
+
+public class SampleModelSerializerTest extends KryoSerializerTestBase {
+ private static final SampleModelSerializer serializer = new SampleModelSerializer();
+
+ @Test
+ public void serializeBandedSampleModel() {
+ int[] bankIndices = {2, 0, 1};
+ int[] bandOffsets = {4, 8, 12};
+ SampleModel sm = new BandedSampleModel(DataBuffer.TYPE_INT, 100, 80, 100, bankIndices, bandOffsets);
+ try (Output out = createOutput()) {
+ serializer.write(kryo, out, sm);
+ try (Input in = createInput(out)) {
+ SampleModel sm1 = serializer.read(kryo, in, SampleModel.class);
+ Assert.assertEquals(sm, sm1);
+ }
+ }
+ }
+
+ @Test
+ public void serializePixelInterleavedSampleModel() {
+ int[] bandOffsets = {0, 1, 2};
+ SampleModel sm = new PixelInterleavedSampleModel(DataBuffer.TYPE_INT, 100, 80, 3, 300, bandOffsets);
+ try (Output out = createOutput()) {
+ serializer.write(kryo, out, sm);
+ try (Input in = createInput(out)) {
+ SampleModel sm1 = serializer.read(kryo, in, SampleModel.class);
+ Assert.assertEquals(sm, sm1);
+ }
+ }
+ }
+
+ @Test
+ public void serializeComponentSampleModel() {
+ int[] bankIndices = {1, 0};
+ int[] bandOffsets = {0, 10000};
+ SampleModel sm = new ComponentSampleModel(DataBuffer.TYPE_INT, 100, 80, 1, 100, bankIndices, bandOffsets);
+ try (Output out = createOutput()) {
+ serializer.write(kryo, out, sm);
+ try (Input in = createInput(out)) {
+ SampleModel sm1 = serializer.read(kryo, in, SampleModel.class);
+ Assert.assertEquals(sm, sm1);
+ }
+ }
+ }
+
+ @Test
+ public void serializeComponentSampleModelJAI() {
+ int[] bankIndices = {1, 0};
+ int[] bandOffsets = {0, 10000};
+ SampleModel sm = new ComponentSampleModelJAI(DataBuffer.TYPE_INT, 100, 80, 1, 100, bankIndices, bandOffsets);
+ try (Output out = createOutput()) {
+ serializer.write(kryo, out, sm);
+ try (Input in = createInput(out)) {
+ SampleModel sm1 = serializer.read(kryo, in, SampleModel.class);
+ Assert.assertEquals(sm, sm1);
+ }
+ }
+ }
+
+ @Test
+ public void serializeSinglePixelPackedSampleModel() {
+ int[] bitMasks = {0x000000ff, 0x0000ff00, 0x00ff0000};
+ SampleModel sm = new SinglePixelPackedSampleModel(DataBuffer.TYPE_INT, 100, 80, 100, bitMasks);
+ try (Output out = createOutput()) {
+ serializer.write(kryo, out, sm);
+ try (Input in = createInput(out)) {
+ SampleModel sm1 = serializer.read(kryo, in, SampleModel.class);
+ Assert.assertEquals(sm, sm1);
+ }
+ }
+ }
+
+ @Test
+ public void serializedMultiPixelPackedSampleModel() {
+ SampleModel sm = new MultiPixelPackedSampleModel(DataBuffer.TYPE_BYTE, 100, 80, 4);
+ try (Output out = createOutput()) {
+ serializer.write(kryo, out, sm);
+ try (Input in = createInput(out)) {
+ SampleModel sm1 = serializer.read(kryo, in, SampleModel.class);
+ Assert.assertEquals(sm, sm1);
+ }
+ }
+ }
+}
diff --git a/common/src/test/java/org/apache/sedona/common/raster/SerdeTest.java b/common/src/test/java/org/apache/sedona/common/raster/serde/SerdeTest.java
similarity index 73%
rename from common/src/test/java/org/apache/sedona/common/raster/SerdeTest.java
rename to common/src/test/java/org/apache/sedona/common/raster/serde/SerdeTest.java
index 7a29aaa45..844c3d34f 100644
--- a/common/src/test/java/org/apache/sedona/common/raster/SerdeTest.java
+++ b/common/src/test/java/org/apache/sedona/common/raster/serde/SerdeTest.java
@@ -16,11 +16,14 @@
* specific language governing permissions and limitations
* under the License.
*/
-package org.apache.sedona.common.raster;
+package org.apache.sedona.common.raster.serde;
+import org.apache.sedona.common.raster.RasterConstructors;
+import org.apache.sedona.common.raster.RasterTestBase;
import org.geotools.coverage.grid.GridCoverage2D;
import org.geotools.gce.geotiff.GeoTiffReader;
import org.junit.Test;
+import org.opengis.referencing.FactoryException;
import java.io.File;
import java.io.IOException;
@@ -37,12 +40,12 @@ public class SerdeTest extends RasterTestBase {
};
@Test
- public void testRoundtripSerdeSingelbandRaster() throws IOException, ClassNotFoundException {
+ public void testRoundTripSerdeSingleBandRaster() throws IOException, ClassNotFoundException {
testRoundTrip(oneBandRaster);
}
@Test
- public void testRoundtripSerdeMultibandRaster() throws IOException, ClassNotFoundException {
+ public void testRoundTripSerdeMultiBandRaster() throws IOException, ClassNotFoundException {
testRoundTrip(multiBandRaster);
}
@@ -55,6 +58,20 @@ public class SerdeTest extends RasterTestBase {
}
}
+ @Test
+ public void testNorthPoleRaster() throws IOException, ClassNotFoundException, FactoryException {
+ // If we are not using non-strict mode to serializing CRS, this will raise an exception:
+ // org.geotools.referencing.wkt.UnformattableObjectException: This "AxisDirection" object is too complex for
+ // WKT syntax.
+ GridCoverage2D raster = RasterConstructors.makeEmptyRaster(
+ 1, "B", 256, 256,
+ -345000.000, 345000.000,
+ 2000, -2000,
+ 0, 0,
+ 3996);
+ testRoundTrip(raster);
+ }
+
// Convenience overload: delegates to the two-argument testRoundTrip, passing 10 as the
// second argument and returning whatever that overload returns.
// NOTE(review): the meaning of the value 10 is defined by the other overload, which is
// not visible in this hunk - confirm before changing it.
private GridCoverage2D testRoundTrip(GridCoverage2D raster) throws IOException, ClassNotFoundException {
    return testRoundTrip(raster, 10);
}
diff --git a/docs/setup/compile.md b/docs/setup/compile.md
index 417038c27..49e627775 100644
--- a/docs/setup/compile.md
+++ b/docs/setup/compile.md
@@ -73,11 +73,20 @@ For example,
export SPARK_HOME=$PWD/spark-3.0.1-bin-hadoop2.7
export PYTHONPATH=$SPARK_HOME/python
```
-2. Compile the Sedona Scala and Java code with `-Dgeotools` and then copy the ==sedona-spark-shaded-{{ sedona.current_version }}.jar== to ==SPARK_HOME/jars/== folder.
+2. Put the JAI jars into the ==SPARK_HOME/jars/== folder.
+```
+export JAI_CORE_VERSION="1.1.3"
+export JAI_CODEC_VERSION="1.1.3"
+export JAI_IMAGEIO_VERSION="1.1"
+wget -P $SPARK_HOME/jars/ https://repo.osgeo.org/repository/release/javax/media/jai_core/${JAI_CORE_VERSION}/jai_core-${JAI_CORE_VERSION}.jar
+wget -P $SPARK_HOME/jars/ https://repo.osgeo.org/repository/release/javax/media/jai_codec/${JAI_CODEC_VERSION}/jai_codec-${JAI_CODEC_VERSION}.jar
+wget -P $SPARK_HOME/jars/ https://repo.osgeo.org/repository/release/javax/media/jai_imageio/${JAI_IMAGEIO_VERSION}/jai_imageio-${JAI_IMAGEIO_VERSION}.jar
+```
+3. Compile the Sedona Scala and Java code with `-Dgeotools` and then copy the ==sedona-spark-shaded-{{ sedona.current_version }}.jar== to ==SPARK_HOME/jars/== folder.
```
cp spark-shaded/target/sedona-spark-shaded-xxx.jar $SPARK_HOME/jars/
```
-3. Install the following libraries
+4. Install the following libraries
```
sudo apt-get -y install python3-pip python-dev libgeos-dev
sudo pip3 install -U setuptools
@@ -86,12 +95,12 @@ sudo pip3 install -U virtualenvwrapper
sudo pip3 install -U pipenv
```
Homebrew can be used to install libgeos-dev in macOS: `brew install geos`
-4. Set up pipenv to the desired Python version: 3.7, 3.8, or 3.9
+5. Set up pipenv to the desired Python version: 3.7, 3.8, or 3.9
```
cd python
pipenv --python 3.7
```
-5. Install the PySpark version and the other dependency
+6. Install PySpark and the other dependencies
```
cd python
pipenv install pyspark
@@ -99,7 +108,7 @@ pipenv install --dev
```
`pipenv install pyspark` installs the latest version of pyspark.
In order to remain consistent with the installed spark version, use `pipenv install pyspark==<spark_version>`
-6. Run the Python tests
+7. Run the Python tests
```
cd python
pipenv run python setup.py build_ext --inplace
diff --git a/docs/tutorial/raster.md b/docs/tutorial/raster.md
index ee4a922e2..cfc64bc9c 100644
--- a/docs/tutorial/raster.md
+++ b/docs/tutorial/raster.md
@@ -583,6 +583,92 @@ SELECT RS_AsPNG(raster)
Please refer to [Raster writer docs](../../api/sql/Raster-writer) for more details.
+## Collecting raster Dataframes and working with them locally in Python
+
+Sedona allows collecting Dataframes with raster columns and working with them locally in Python since `v1.6.0`.
+The raster objects are represented as `SedonaRaster` objects in Python, which can be used to perform raster operations.
+
+```python
+df_raster = sedona.read.format("binaryFile").load("/path/to/raster.tif").selectExpr("RS_FromGeoTiff(content) as rast")
+rows = df_raster.collect()
+raster = rows[0].rast
+raster # <sedona.raster.sedona_raster.InDbSedonaRaster at 0x1618fb1f0>
+```
+
+You can retrieve the metadata of the raster by accessing the properties of the `SedonaRaster` object.
+
+```python
+raster.width # width of the raster
+raster.height # height of the raster
+raster.affine_trans # affine transformation matrix
+raster.crs_wkt # coordinate reference system as WKT
+```
+
+You can get a numpy array containing the band data of the raster using the `as_numpy` or `as_numpy_masked` method. The
+band data is organized in CHW order, i.e. the array has shape `(bands, height, width)`.
+
+```python
+raster.as_numpy() # numpy array of the raster
+raster.as_numpy_masked() # numpy array with nodata values masked as nan
+```
+
+If you want to work with the raster data using `rasterio`, you can retrieve a `rasterio.DatasetReader` object using the
+`as_rasterio` method.
+
+```python
+ds = raster.as_rasterio() # rasterio.DatasetReader object
+# Work with the raster using rasterio
+band1 = ds.read(1) # read the first band
+```
+
+## Writing Python UDF to work with raster data
+
+You can write Python UDFs to work with raster data in Python. The UDFs can take `SedonaRaster` objects as input and
+return any Spark data type as output. This is an example of a Python UDF that calculates the mean of the raster data.
+
+```python
+from pyspark.sql.functions import expr
+from pyspark.sql.types import DoubleType
+
+def mean_udf(raster):
+ return float(raster.as_numpy().mean())
+
+sedona.udf.register("mean_udf", mean_udf, DoubleType())
+df_raster.withColumn("mean", expr("mean_udf(rast)")).show()
+```
+
+```
++--------------------+------------------+
+| rast| mean|
++--------------------+------------------+
+|GridCoverage2D["g...|1542.8092886117788|
++--------------------+------------------+
+```
+
+It is much trickier to write a UDF that returns a raster object, since Sedona does not support serializing Python raster
+objects yet. However, you can write a UDF that returns the band data as an array and then construct the raster object using
+`RS_MakeRaster`. This is an example of a Python UDF that creates a mask raster based on the first band of the input raster.
+
+```python
+from pyspark.sql.types import ArrayType, DoubleType
+import numpy as np
+
+def mask_udf(raster):
+ band1 = raster.as_numpy()[0,:,:]
+ mask = (band1 < 1400).astype(np.float64)
+ return mask.flatten().tolist()
+
+sedona.udf.register("mask_udf", mask_udf, ArrayType(DoubleType()))
+df_raster.withColumn("mask", expr("mask_udf(rast)")).withColumn("mask_rast", expr("RS_MakeRaster(rast, 'I', mask)")).show()
+```
+
+```
++--------------------+--------------------+--------------------+
+| rast| mask| mask_rast|
++--------------------+--------------------+--------------------+
+|GridCoverage2D["g...|[0.0, 0.0, 0.0, 0...|GridCoverage2D["g...|
++--------------------+--------------------+--------------------+
+```
+
## Performance optimization
When working with large raster datasets, refer to the [documentation on storing raster geometries in Parquet format](../storing-blobs-in-parquet) for recommendations to optimize performance.
diff --git a/python/Pipfile b/python/Pipfile
index 110203363..cd7fdee21 100644
--- a/python/Pipfile
+++ b/python/Pipfile
@@ -19,6 +19,7 @@ attrs="*"
pyarrow="*"
keplergl = "==0.3.2"
pydeck = "===0.8.0"
+rasterio = ">=1.2.10"
[requires]
python_version = "3.7"
diff --git a/python/sedona/raster/__init__.py b/python/sedona/raster/__init__.py
new file mode 100644
index 000000000..a67d5ea25
--- /dev/null
+++ b/python/sedona/raster/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/python/sedona/raster/awt_raster.py b/python/sedona/raster/awt_raster.py
new file mode 100644
index 000000000..d95135943
--- /dev/null
+++ b/python/sedona/raster/awt_raster.py
@@ -0,0 +1,41 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from .data_buffer import DataBuffer
+from .sample_model import SampleModel
+
+
class AWTRaster:
    """Python mirror of the Java AWT Raster used by GeoTools GridCoverage2D.

    Bundles the raster origin (``min_x``, ``min_y``) and size with the sample
    model describing the pixel layout and the data buffer holding the samples.
    """
    min_x: int
    min_y: int
    width: int
    height: int
    sample_model: SampleModel
    data_buffer: DataBuffer

    def __init__(self, min_x, min_y, width, height, sample_model: SampleModel, data_buffer: DataBuffer):
        """Store the raster fields after checking the sample model size.

        Raises:
            RuntimeError: if the sample model's width/height differ from the
                given raster width/height.
        """
        size_matches = (sample_model.width, sample_model.height) == (width, height)
        if not size_matches:
            raise RuntimeError("Size of the image does not match with the sample model")
        self.min_x, self.min_y = min_x, min_y
        self.width, self.height = width, height
        self.sample_model = sample_model
        self.data_buffer = data_buffer
diff --git a/python/sedona/raster/data_buffer.py b/python/sedona/raster/data_buffer.py
new file mode 100644
index 000000000..8826e26bd
--- /dev/null
+++ b/python/sedona/raster/data_buffer.py
@@ -0,0 +1,39 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from typing import List, Any
+import numpy as np
+
+
class DataBuffer:
    """Container for the banks of sample data belonging to a raster.

    ``data_type`` is one of the ``TYPE_*`` codes below, ``bank_data`` holds one
    numpy array per bank, ``size`` and ``offsets`` are stored as given.
    """
    # Element type codes, numbered consecutively from 0.
    (TYPE_BYTE,
     TYPE_USHORT,
     TYPE_SHORT,
     TYPE_INT,
     TYPE_FLOAT,
     TYPE_DOUBLE) = range(6)

    data_type: int
    bank_data: List[np.ndarray]
    size: int
    offsets: List[int]

    def __init__(self, data_type: int, bank_data: List[np.ndarray], size: int, offsets: List[int]):
        """Store the given fields unchanged; no validation is performed."""
        self.data_type, self.bank_data = data_type, bank_data
        self.size, self.offsets = size, offsets
diff --git a/python/sedona/raster/meta.py b/python/sedona/raster/meta.py
new file mode 100644
index 000000000..b0013359d
--- /dev/null
+++ b/python/sedona/raster/meta.py
@@ -0,0 +1,112 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from enum import Enum
+from typing import List, Dict, Optional
+
+
class PixelAnchor(Enum):
    """Anchor of the pixel cell, i.e. which point of a pixel its coordinates
    refer to. GeoTools anchors the coordinates at the center of pixels, while
    GDAL anchors the coordinates at the upper-left corner of the pixels. This
    difference requires us to convert the affine transformation between these
    conventions.

    """
    # Coordinates refer to the pixel center (the GeoTools convention).
    CENTER = 1
    # Coordinates refer to the upper-left pixel corner (the GDAL convention).
    UPPER_LEFT = 2
+
+
class AffineTransform:
    """Six-coefficient 2D affine transform tagged with a pixel anchor.

    In matrix form (the layout printed by ``__repr__``) the coefficients are::

        [ scale_x  skew_x   ip_x ]
        [ skew_y   scale_y  ip_y ]
        [ 0        0        1    ]

    Instances are treated as immutable: every operation returns a new object
    (or ``self`` when nothing changes).
    """
    scale_x: float
    skew_y: float
    skew_x: float
    scale_y: float
    ip_x: float
    ip_y: float
    pixel_anchor: PixelAnchor

    def __init__(self, scale_x, skew_y, skew_x, scale_y, ip_x, ip_y, pixel_anchor: PixelAnchor):
        self.scale_x, self.skew_y = scale_x, skew_y
        self.skew_x, self.scale_y = skew_x, scale_y
        self.ip_x, self.ip_y = ip_x, ip_y
        self.pixel_anchor = pixel_anchor

    def with_anchor(self, pixel_anchor: PixelAnchor):
        """Return this transform expressed for ``pixel_anchor``.

        Returns ``self`` when the anchor is already the requested one.
        """
        already_anchored = pixel_anchor == self.pixel_anchor
        return self if already_anchored else self._do_change_pixel_anchor(self.pixel_anchor, pixel_anchor)

    def translate(self, offset_x: float, offset_y: float):
        """Return a copy whose origin is moved by (offset_x, offset_y), with the
        offsets mapped through the scale/skew coefficients of this transform.
        """
        shifted_ip_x = self.ip_x + offset_x * self.scale_x + offset_y * self.skew_x
        shifted_ip_y = self.ip_y + offset_x * self.skew_y + offset_y * self.scale_y
        return AffineTransform(self.scale_x, self.skew_y, self.skew_x, self.scale_y,
                               shifted_ip_x, shifted_ip_y, self.pixel_anchor)

    def _do_change_pixel_anchor(self, from_anchor: PixelAnchor, to_anchor: PixelAnchor):
        """Multiply this transform by a half-pixel translation matrix.

        The shift is -0.5 on both axes when leaving the CENTER anchor and +0.5
        otherwise. The two anchors must differ.
        """
        assert from_anchor != to_anchor
        half_pixel = -0.5 if from_anchor == PixelAnchor.CENTER else 0.5
        # Translation matrix [[m00, m01, m02], [m10, m11, m12]].
        m00, m01, m02 = 1.0, 0.0, half_pixel
        m10, m11, m12 = 0.0, 1.0, half_pixel

        # Coefficients of this transform in the same matrix layout.
        a00, a01, a02 = self.scale_x, self.skew_x, self.ip_x
        a10, a11, a12 = self.skew_y, self.scale_y, self.ip_y

        return AffineTransform(
            a00 * m00 + a01 * m10,          # scale_x
            a10 * m00 + a11 * m10,          # skew_y
            a00 * m01 + a01 * m11,          # skew_x
            a10 * m01 + a11 * m11,          # scale_y
            a00 * m02 + a01 * m12 + a02,    # ip_x
            a10 * m02 + a11 * m12 + a12,    # ip_y
            to_anchor)

    def __repr__(self):
        first_row = "[ {} {} {}\n".format(self.scale_x, self.skew_x, self.ip_x)
        second_row = " {} {} {}\n".format(self.skew_y, self.scale_y, self.ip_y)
        return (first_row +
                second_row +
                " 0 0 1 ]")
+
+
class SampleDimension:
    """Metadata of one raster band: its description, the offset and scale
    values, and the nodata value.

    """
    description: str
    offset: float
    scale: float
    nodata: float

    def __init__(self, description, offset, scale, nodata):
        """Store the four values unchanged."""
        self.description = description
        self.offset, self.scale = offset, scale
        self.nodata = nodata
diff --git a/python/sedona/raster/raster_serde.py b/python/sedona/raster/raster_serde.py
new file mode 100644
index 000000000..63b740c5a
--- /dev/null
+++ b/python/sedona/raster/raster_serde.py
@@ -0,0 +1,180 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from typing import Optional, Union, Tuple, List, Dict
+from io import BytesIO
+import struct
+import zlib
+import numpy as np
+
+from .sample_model import SampleModel, ComponentSampleModel, PixelInterleavedSampleModel, MultiPixelPackedSampleModel, SinglePixelPackedSampleModel
+from .data_buffer import DataBuffer
+from .awt_raster import AWTRaster
+from .meta import AffineTransform, PixelAnchor, SampleDimension
+from .sedona_raster import SedonaRaster, InDbSedonaRaster
+
+
class RasterTypes:
    # Raster type codes. deserialize() reads the code from the first byte of
    # the serialized buffer; IN_DB is the only code _deserialize() accepts.
    IN_DB = 0
+
+
def deserialize(buf: Union[bytearray, bytes]) -> Optional[SedonaRaster]:
    """Decode a serialized raster.

    The first byte of ``buf`` is the raster type; the remaining bytes are
    handed to ``_deserialize``. ``None`` is passed through as ``None``.
    """
    if buf is None:
        return None

    stream = BytesIO(buf)
    type_byte = stream.read(1)
    return _deserialize(stream, int(type_byte[0]))
+
+
def _deserialize(bio: BytesIO, raster_type: int) -> SedonaRaster:
    """Read the raster header and payload from ``bio``.

    The fields are consumed in order: name, grid envelope, affine
    transformation, CRS WKT, band metadata and - for in-DB rasters - the AWT
    raster. The affine transformation is translated by the grid envelope
    origin and converted to the upper-left pixel anchor.

    Raises:
        ValueError: if ``raster_type`` is not ``RasterTypes.IN_DB``.
    """
    _read_utf8_string(bio)  # raster name; read to advance the stream, not used
    width, height, x, y = _read_grid_envelope(bio)
    affine_trans = (_read_affine_transformation(bio)
                    .translate(x, y)
                    .with_anchor(PixelAnchor.UPPER_LEFT))
    crs_wkt = _read_crs_wkt(bio)
    bands_meta = _read_sample_dimensions(bio)

    if raster_type != RasterTypes.IN_DB:
        raise ValueError("unsupported raster_type: {}".format(raster_type))

    # In-DB raster: the pixel data follows the header in the same stream.
    awt_raster = _read_awt_raster(bio)
    return InDbSedonaRaster(width, height, bands_meta, affine_trans, crs_wkt, awt_raster)
+
+
+def _read_grid_envelope(bio: BytesIO) -> Tuple[int, int, int, int]:
+ width, height, x, y = struct.unpack("=iiii", bio.read(4 * 4))
+ return (width, height, x, y)
+
+
def _read_affine_transformation(bio: BytesIO) -> AffineTransform:
    """Read six doubles (scale_x, skew_y, skew_x, scale_y, ip_x, ip_y) and wrap
    them in an AffineTransform anchored at the pixel center.
    """
    layout = "=dddddd"
    coefficients = struct.unpack(layout, bio.read(struct.calcsize(layout)))
    return AffineTransform(*coefficients, PixelAnchor.CENTER)
+
+
+def _read_crs_wkt(bio: BytesIO) -> str:
+ size, = struct.unpack("=i", bio.read(4))
+ compressed_wkt = bio.read(size)
+ crs_wkt = zlib.decompress(compressed_wkt)
+ return crs_wkt.decode('utf-8')
+
+
def _read_sample_dimensions(bio: BytesIO) -> List[SampleDimension]:
    """Read the band count followed by one metadata record per band.

    Each record is a UTF-8 description, three doubles (offset, scale, nodata)
    and one length-prefixed Java object that is skipped.
    """
    band_count = struct.unpack("=i", bio.read(4))[0]
    bands_meta = []
    for _ in range(band_count):
        description = _read_utf8_string(bio)
        offset, scale, nodata = struct.unpack("=ddd", bio.read(8 * 3))
        _ignore_java_object(bio)  # trailing serialized object is not needed here
        band = SampleDimension(description, offset, scale, nodata)
        bands_meta.append(band)
    return bands_meta
+
+
def _read_awt_raster(bio: BytesIO) -> AWTRaster:
    """Read an image header, its sample model and its data buffer.

    The image origin is stored twice (for the image and for the AWT raster);
    both copies must agree.

    Raises:
        RuntimeError: if the two stored origins differ.
    """
    min_x, min_y, width, height = struct.unpack("=iiii", bio.read(4 * 4))
    # Two serialized Java objects follow: image properties, then color model.
    for _ in range(2):
        _ignore_java_object(bio)
    raster_origin = struct.unpack("=ii", bio.read(4 * 2))
    if raster_origin != (min_x, min_y):
        raise RuntimeError("malformed serialized raster: minx/miny of the image cannot match with minx/miny of the AWT raster")
    sample_model = _read_sample_model(bio)
    data_buffer = _read_data_buffer(bio)
    return AWTRaster(min_x, min_y, width, height, sample_model, data_buffer)
+
+
def _read_sample_model(bio: BytesIO) -> SampleModel:
    """Read a sample model: a common header (type, data type, width, height)
    followed by type-specific fields.

    Raises:
        RuntimeError: if the sample model type code is not recognised.
    """
    def read_ints(count):
        # Read `count` consecutive 32-bit integers as a tuple.
        return struct.unpack("=" + "i" * count, bio.read(4 * count))

    sample_model_type, data_type, width, height = read_ints(4)

    if sample_model_type == SampleModel.TYPE_BANDED:
        bank_indices = _read_int_array(bio)
        band_offsets = _read_int_array(bio)
        # Banded data: unit pixel stride, one scanline per row of `width` samples.
        return ComponentSampleModel(data_type, width, height, 1, width, bank_indices, band_offsets)

    if sample_model_type == SampleModel.TYPE_PIXEL_INTERLEAVED:
        pixel_stride, scanline_stride = read_ints(2)
        band_offsets = _read_int_array(bio)
        return PixelInterleavedSampleModel(data_type, width, height, pixel_stride, scanline_stride, band_offsets)

    if sample_model_type in (SampleModel.TYPE_COMPONENT, SampleModel.TYPE_COMPONENT_JAI):
        pixel_stride, scanline_stride = read_ints(2)
        bank_indices = _read_int_array(bio)
        band_offsets = _read_int_array(bio)
        return ComponentSampleModel(data_type, width, height, pixel_stride, scanline_stride, bank_indices, band_offsets)

    if sample_model_type == SampleModel.TYPE_SINGLE_PIXEL_PACKED:
        scanline_stride, = read_ints(1)
        bit_masks = _read_int_array(bio)
        return SinglePixelPackedSampleModel(data_type, width, height, scanline_stride, bit_masks)

    if sample_model_type == SampleModel.TYPE_MULTI_PIXEL_PACKED:
        num_bits, scanline_stride, data_bit_offset = read_ints(3)
        return MultiPixelPackedSampleModel(data_type, width, height, num_bits, scanline_stride, data_bit_offset)

    raise RuntimeError(f"Unsupported SampleModel type: {sample_model_type}")
+
+
def _read_data_buffer(bio: BytesIO) -> DataBuffer:
    """Read a data buffer: data type, offsets array, size, then the banks.

    Each bank is an element count followed by that many elements of the
    numpy dtype matching the data type.

    Raises:
        ValueError: if a bank is present and the data type code is unknown.
    """
    numpy_types = {
        DataBuffer.TYPE_BYTE: np.uint8,
        DataBuffer.TYPE_SHORT: np.int16,
        DataBuffer.TYPE_USHORT: np.uint16,
        DataBuffer.TYPE_INT: np.int32,
        DataBuffer.TYPE_FLOAT: np.float32,
        DataBuffer.TYPE_DOUBLE: np.float64,
    }

    data_type, = struct.unpack("=i", bio.read(4))
    offsets = _read_int_array(bio)
    size, = struct.unpack("=i", bio.read(4))
    num_banks, = struct.unpack("=i", bio.read(4))

    banks = []
    for _ in range(num_banks):
        bank_size, = struct.unpack("=i", bio.read(4))
        element_type = numpy_types.get(data_type)
        if element_type is None:
            raise ValueError("unknown data_type {}".format(data_type))
        byte_count = np.dtype(element_type).itemsize * bank_size
        banks.append(np.frombuffer(bio.read(byte_count), dtype=element_type))

    return DataBuffer(data_type, banks, size, offsets)
+
+
+def _read_utf8_string(bio: BytesIO) -> str:
+ size, = struct.unpack("=i", bio.read(4))
+ utf8_bytes = bio.read(size)
+ return utf8_bytes.decode('utf-8')
+
+
+def _ignore_java_object(bio: BytesIO):
+ size, = struct.unpack("=i", bio.read(4))
+ bio.read(size)
+
+
+def _read_int_array(bio: BytesIO) -> List[int]:
+ length, = struct.unpack("=i", bio.read(4))
+ return [struct.unpack("=i", bio.read(4))[0] for _ in range(length)]
+
+
def _read_utf8_string_map(bio: BytesIO) -> Optional[Dict[str, str]]:
    """Read a map of UTF-8 strings: an entry count followed by key/value pairs.

    A count of -1 encodes a missing map and yields ``None``.
    """
    entry_count = struct.unpack("=i", bio.read(4))[0]
    if entry_count == -1:
        return None

    params = {}
    remaining = entry_count
    while remaining > 0:
        # The key is stored before its value.
        key = _read_utf8_string(bio)
        params[key] = _read_utf8_string(bio)
        remaining -= 1
    return params
diff --git a/python/sedona/raster/sample_model.py b/python/sedona/raster/sample_model.py
new file mode 100644
index 000000000..4c5ac193e
--- /dev/null
+++ b/python/sedona/raster/sample_model.py
@@ -0,0 +1,193 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from typing import List
+from abc import ABC, abstractmethod
+import numpy as np
+
+from .data_buffer import DataBuffer
+
+
class SampleModel(ABC):
    """Abstract base of the Python counterparts of the Java AWT SampleModel
    classes; the subclasses follow the data structure of those Java classes.

    A sample model knows how to turn the banks of a DataBuffer into a numpy
    array (see ``as_numpy``).

    """
    # Codes identifying the kind of sample model.
    TYPE_BANDED = 1
    TYPE_PIXEL_INTERLEAVED = 2
    TYPE_SINGLE_PIXEL_PACKED = 3
    TYPE_MULTI_PIXEL_PACKED = 4
    TYPE_COMPONENT_JAI = 5
    TYPE_COMPONENT = 6

    sample_model_type: int
    data_type: int
    width: int
    height: int

    def __init__(self, sample_model_type, data_type, width, height):
        self.sample_model_type, self.data_type = sample_model_type, data_type
        self.width, self.height = width, height

    @abstractmethod
    def as_numpy(self, data_buffer: DataBuffer) -> np.ndarray:
        """Decode ``data_buffer`` into a numpy array; implemented by subclasses."""
        raise NotImplementedError("Abstract method as_numpy was not implemented by subclass")
+
+
class ComponentSampleModel(SampleModel):
    """Sample model in which every band is stored as whole data elements.

    Band ``k`` lives in bank ``bank_indices[k]``; the sample of pixel (x, y)
    is the element at
    ``band_offsets[k] + y * scanline_stride + x * pixel_stride``.
    """
    pixel_stride: int
    scanline_stride: int
    bank_indices: List[int]
    band_offsets: List[int]

    def __init__(self, data_type, width, height, pixel_stride, scanline_stride, bank_indices, band_offsets):
        super().__init__(SampleModel.TYPE_COMPONENT, data_type, width, height)
        self.pixel_stride = pixel_stride
        self.scanline_stride = scanline_stride
        self.bank_indices = bank_indices
        self.band_offsets = band_offsets

    def as_numpy(self, data_buffer: DataBuffer) -> np.ndarray:
        """Decode the buffer into an array of shape (bands, height, width).

        ``bank_indices`` and ``band_offsets`` are both indexed by band, so the
        offset of a band is looked up by the band's position, not by the bank
        number it points to.
        """
        num_pixels = self.width * self.height
        contiguous = self.scanline_stride == self.width and self.pixel_stride == 1
        if not contiguous:
            # Element index of every pixel relative to the band offset,
            # computed once and shared by all bands.
            rows = np.arange(self.height).reshape(-1, 1) * self.scanline_stride
            cols = np.arange(self.width).reshape(1, -1) * self.pixel_stride
            pixel_index = rows + cols

        band_arrs = []
        for band, bank_index in enumerate(self.bank_indices):
            bank_data = data_buffer.bank_data[bank_index]
            offset = self.band_offsets[band]
            if contiguous:
                # Fast path: no gaps between pixels. Always slice so that a
                # bank larger than one band (several bands sharing a bank, or
                # trailing padding) can still be reshaped.
                band_arr = bank_data[offset:offset + num_pixels].reshape(self.height, self.width)
            else:
                # Strided layout: gather the samples with one indexing call.
                band_arr = bank_data[pixel_index + offset]
            band_arrs.append(band_arr)

        return np.array(band_arrs)
+
+
class PixelInterleavedSampleModel(SampleModel):
    """Sample model whose bands are interleaved pixel by pixel in bank 0.

    The sample of band ``k`` at pixel (x, y) is the element at
    ``y * scanline_stride + x * pixel_stride + band_offsets[k]``.
    """
    pixel_stride: int
    scanline_stride: int
    band_offsets: List[int]

    def __init__(self, data_type, width, height, pixel_stride, scanline_stride, band_offsets):
        super().__init__(SampleModel.TYPE_PIXEL_INTERLEAVED, data_type, width, height)
        self.pixel_stride = pixel_stride
        self.scanline_stride = scanline_stride
        self.band_offsets = band_offsets

    def as_numpy(self, data_buffer: DataBuffer) -> np.ndarray:
        """Decode the buffer into an array of shape (bands, height, width)."""
        num_bands = len(self.band_offsets)
        bank_data = data_buffer.bank_data[0]
        if self.pixel_stride == num_bands and \
                self.scanline_stride == self.width * num_bands and \
                list(self.band_offsets) == list(range(0, num_bands)):
            # Fast path: no gapping in between band data, no band reordering.
            # Slice first so that trailing padding in the bank is tolerated.
            num_samples = self.height * self.width * num_bands
            arr = bank_data[:num_samples].reshape(self.height, self.width, num_bands)
            return np.transpose(arr, [2, 0, 1])

        # General path: compute the element index of every (band, y, x) sample.
        # Band offsets are added to the pixel position directly, so offsets
        # that are >= num_bands (band subsets of a wider pixel) work as well.
        bands = np.asarray(self.band_offsets, dtype=np.intp).reshape(-1, 1, 1)
        rows = np.arange(self.height).reshape(1, -1, 1) * self.scanline_stride
        cols = np.arange(self.width).reshape(1, 1, -1) * self.pixel_stride
        return bank_data[bands + rows + cols]
+
+
class SinglePixelPackedSampleModel(SampleModel):
    """Sample model packing all samples of a pixel into one data element of
    bank 0.

    Band ``k`` occupies the bits selected by ``bit_masks[k]``;
    ``bit_offsets[k]`` is the position of the lowest set bit of that mask.
    """
    scanline_stride: int
    bit_masks: List[int]
    bit_offsets: List[int]

    def __init__(self, data_type, width, height, scanline_stride, bit_masks):
        super().__init__(SampleModel.TYPE_SINGLE_PIXEL_PACKED, data_type, width, height)
        self.scanline_stride = scanline_stride
        self.bit_masks = bit_masks
        # `v & -v` isolates the lowest set bit. An all-zero mask has no set
        # bit and gets offset 0 instead of -1.
        self.bit_offsets = [max((v & -v).bit_length() - 1, 0) for v in self.bit_masks]

    def as_numpy(self, data_buffer: DataBuffer) -> np.ndarray:
        """Decode the buffer into an array of shape (bands, height, width).

        Elements and masks are treated as unsigned bit patterns, so a band
        whose mask includes the sign bit of the element type is not
        sign-extended when it is shifted down.
        """
        num_bands = len(self.bit_masks)
        bank_data = data_buffer.bank_data[0]
        # All-ones pattern as wide as one data element.
        # NOTE(review): assumes elements are at most 32 bits wide - confirm.
        word_mask = (1 << (bank_data.dtype.itemsize * 8)) - 1

        rows = np.arange(self.height).reshape(-1, 1) * self.scanline_stride
        cols = np.arange(self.width).reshape(1, -1)
        packed = bank_data[rows + cols].astype(np.int64) & word_mask

        bands = [(packed & (mask & word_mask)) >> bit_offset
                 for mask, bit_offset in zip(self.bit_masks, self.bit_offsets)]
        arr = np.array(bands, dtype=np.int64).reshape(num_bands, self.height, self.width)
        return arr.astype(bank_data.dtype)
+
+
class MultiPixelPackedSampleModel(SampleModel):
    """Single-band sample model packing several ``num_bits``-wide pixels into
    each data element of bank 0, most significant bits first.

    Pixel x of row y starts at bit ``data_bit_offset + x * num_bits`` counted
    from the first element of the row, which is ``y * scanline_stride``.
    """
    num_bits: int
    scanline_stride: int
    data_bit_offset: int

    def __init__(self, data_type, width, height, num_bits, scanline_stride, data_bit_offset):
        super().__init__(SampleModel.TYPE_MULTI_PIXEL_PACKED, data_type, width, height)
        self.num_bits = num_bits
        self.scanline_stride = scanline_stride
        self.data_bit_offset = data_bit_offset

    def as_numpy(self, data_buffer: DataBuffer) -> np.ndarray:
        """Decode the buffer into an array of shape (1, height, width).

        Only the elements that actually contain pixels are read, so a row
        ending exactly at the end of the buffer does not read past it.
        """
        bank_data = data_buffer.bank_data[0]
        bits_per_value = bank_data.dtype.itemsize * 8
        # NOTE(review): assumes elements are at most 32 bits wide and that no
        # pixel spans two data elements - confirm against the producer.
        word_mask = (1 << bits_per_value) - 1
        pixel_mask = (1 << self.num_bits) - 1

        # Bit position of every pixel of a row, relative to the row start.
        bit_pos = self.data_bit_offset + np.arange(self.width) * self.num_bits
        elem_in_row = (bit_pos // bits_per_value).reshape(1, -1)
        # Right shift that moves a pixel into the low bits of its element.
        shift = (bits_per_value - self.num_bits - bit_pos % bits_per_value).reshape(1, -1)

        row_start = np.arange(self.height).reshape(-1, 1) * self.scanline_stride
        values = bank_data[row_start + elem_in_row].astype(np.int64) & word_mask
        pixels = (values >> shift) & pixel_mask

        return pixels.astype(bank_data.dtype).reshape(1, self.height, self.width)
diff --git a/python/sedona/raster/sedona_raster.py b/python/sedona/raster/sedona_raster.py
new file mode 100644
index 000000000..e5ecb3723
--- /dev/null
+++ b/python/sedona/raster/sedona_raster.py
@@ -0,0 +1,261 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from typing import List, Dict, Optional
+from abc import ABC, abstractmethod
+from xml.etree.ElementTree import Element, SubElement, tostring
+
+import numpy as np
+import rasterio # type: ignore
+import rasterio.env # type: ignore
+from rasterio.transform import Affine # type: ignore
+from rasterio.io import MemoryFile # type: ignore
+from rasterio.io import DatasetReader # type: ignore
+
+try:
+ # for rasterio >= 1.3.0
+ from rasterio._path import _parse_path as parse_path # type: ignore
+except:
+ # for rasterio >= 1.2.0
+ from rasterio.path import parse_path # type: ignore
+
+from .awt_raster import AWTRaster
+from .data_buffer import DataBuffer
+from .meta import AffineTransform, PixelAnchor
+from .meta import SampleDimension
+
+
+def _rasterio_open(fp, driver=None):
+ """A variant of rasterio.open. This function skip setting up a new GDAL env
+ when there is already an environment. This saves us lots of overhead
+ introduced by GDAL env initialization.
+
+ """
+ if rasterio.env.hasenv():
+ # There is already an env, so we can get rid of the overhead of
+ # GDAL env initialization in rasterio.open().
+ return DatasetReader(parse_path(fp), driver=driver)
+ else:
+ return rasterio.open(fp, mode="r", driver=driver)
+
+
+def _generate_vrt_xml(src_path, data_type, width, height, geo_transform, crs_wkt, off_x, off_y, band_indices) -> bytes:
+ # Create root element
+ root = Element('VRTDataset')
+ root.set('rasterXSize', str(width))
+ root.set('rasterYSize', str(height))
+
+ # Add CRS
+ if crs_wkt is not None and crs_wkt != '':
+ srs = SubElement(root, 'SRS')
+ srs.text = crs_wkt
+
+ # Add GeoTransform
+ gt = SubElement(root, 'GeoTransform')
+ gt.text = geo_transform
+
+ # Add bands
+ for i, band_index in enumerate(band_indices, start=1):
+ band = SubElement(root, 'VRTRasterBand')
+ band.set('dataType', data_type)
+ band.set('band', str(i))
+
+ # Add source
+ source = SubElement(band, 'SimpleSource')
+ src_prop = SubElement(source, 'SourceFilename')
+ src_prop.text = src_path
+
+ # Set source properties
+ SubElement(source, 'SourceBand').text = str(band_index + 1)
+ SubElement(source, 'SrcRect', {'xOff': str(off_x), 'yOff': str(off_y), 'xSize': str(width), 'ySize': str(height)})
+ SubElement(source, 'DstRect', {'xOff': '0', 'yOff': '0', 'xSize': str(width), 'ySize': str(height)})
+
+ # Generate pretty XML
+ xml_bytes = tostring(root, encoding='utf-8')
+ return xml_bytes
+
+
+class SedonaRaster(ABC):
+ _width: int
+ _height: int
+ _bands_meta: List[SampleDimension]
+ _affine_trans: AffineTransform
+ _crs_wkt: str
+
+ def __init__(self, width: int, height: int, bands_meta: List[SampleDimension],
+ affine_trans: AffineTransform, crs_wkt: str):
+ self._width = width
+ self._height = height
+ self._bands_meta = bands_meta
+ self._affine_trans = affine_trans
+ self._crs_wkt = crs_wkt
+
+ @property
+ def width(self) -> int:
+ """Width of the raster in pixel"""
+ return self._width
+
+ @property
+ def height(self) -> int:
+ """Height of the raster in pixel"""
+ return self._height
+
+ @property
+ def crs_wkt(self) -> str:
+ """CRS of the raster as a WKT string"""
+ return self._crs_wkt
+
+ @property
+ def bands_meta(self) -> List[SampleDimension]:
+ """Metadata of bands, including nodata value for each band"""
+ return self._bands_meta
+
+ @property
+ def affine_trans(self) -> AffineTransform:
+ """Geo transform of the raster"""
+ return self._affine_trans
+
+ @abstractmethod
+ def as_numpy(self) -> np.ndarray:
+ """Get the bands data as an numpy array in CHW layout
+
+ """
+ raise NotImplementedError()
+
+ def as_numpy_masked(self) -> np.ndarray:
+ """Get the bands data as an numpy array in CHW layout, with nodata
+ values masked as nan.
+
+ """
+ arr = self.as_numpy()
+ nodata_values = np.array([bm.nodata for bm in self._bands_meta])
+ nodata_values_reshaped = nodata_values[:, None, None]
+ mask = arr == nodata_values_reshaped
+ masked_arr = np.where(mask, np.nan, arr)
+ return masked_arr
+
+ @abstractmethod
+ def as_rasterio(self) -> DatasetReader:
+ """Retrieve the raster as an rasterio DatasetReader
+
+ """
+ raise NotImplementedError()
+
+ @abstractmethod
+ def close(self):
+ """Release all resources allocated for this sedona raster. The rasterio
+ DatasetReader returned by as_rasterio() will also be closed.
+
+ """
+ raise NotImplementedError()
+
+ def __enter__(self):
+ return self
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ self.close()
+
+ def __del__(self):
+ self.close()
+
+
+class InDbSedonaRaster(SedonaRaster):
+ awt_raster: AWTRaster
+ rasterio_memfile: Optional[MemoryFile]
+ rasterio_dataset_reader: Optional[DatasetReader]
+
+ def __init__(self, width: int, height: int, bands_meta: List[SampleDimension],
+ affine_trans: AffineTransform, crs_wkt: str,
+ awt_raster: AWTRaster):
+ super().__init__(width, height, bands_meta, affine_trans, crs_wkt)
+ self.awt_raster = awt_raster
+ self.rasterio_memfile = None
+ self.rasterio_dataset_reader = None
+
+ def as_numpy(self) -> np.ndarray:
+ sm = self.awt_raster.sample_model
+ return sm.as_numpy(self.awt_raster.data_buffer)
+
+ def as_rasterio(self) -> DatasetReader:
+ if self.rasterio_dataset_reader is not None:
+ return self.rasterio_dataset_reader
+
+ affine = Affine.from_gdal(
+ self._affine_trans.ip_x, self._affine_trans.scale_x, self._affine_trans.skew_x,
+ self._affine_trans.ip_y, self._affine_trans.skew_y, self._affine_trans.scale_y)
+ num_bands = len(self._bands_meta)
+
+ data_array = np.ascontiguousarray(self.as_numpy())
+
+ dtype = data_array.dtype
+ if dtype == np.uint8:
+ data_type = 'Byte'
+ elif dtype == np.int8:
+ data_type = 'Int8'
+ elif dtype == np.uint16:
+ data_type = 'Uint16'
+ elif dtype == np.int16:
+ data_type = 'Int16'
+ elif dtype == np.uint32:
+ data_type = 'UInt32'
+ elif dtype == np.int32:
+ data_type = 'Int32'
+ elif dtype == np.float32:
+ data_type = 'Float32'
+ elif dtype == np.float64:
+ data_type = 'Float64'
+ elif dtype == np.int64:
+ data_type = 'Int64'
+ elif dtype == np.uint64:
+ data_type = 'Uint64'
+ else:
+ raise RuntimeError("unknown dtype: " + str(dtype))
+
+ arr_if = data_array.__array_interface__
+ data_pointer = arr_if['data'][0]
+ geotransform = (f"{self._affine_trans.ip_x}/{self._affine_trans.scale_x}/{self._affine_trans.skew_x}/" +
+ f"{self._affine_trans.ip_y}/{self._affine_trans.skew_y}/{self._affine_trans.scale_y}")
+ # FIXME: GDAL 3.6 shipped with rasterio does not support
+ # SPATIALREFERENCE parameter, so we have to workaround this issue in a
+ # hacky way. If newer versions of rasterio bundle GDAL 3.7 then this
+ # won't be a problem. See https://gdal.org/drivers/raster/mem.html
+ desc = (f"MEM:::DATAPOINTER={data_pointer},PIXELS={self._width},LINES={self._height},BANDS={num_bands}," +
+ f"DATATYPE={data_type},GEOTRANSFORM={geotransform}")
+
+ # construct a VRT to wrap this MEM dataset, with SRS set up properly
+ vrt_xml = _generate_vrt_xml(
+ desc, data_type, self._width, self._height, geotransform.replace('/', ','), self._crs_wkt,
+ 0, 0, list(range(num_bands)))
+
+ # dataset = _rasterio_open(desc, driver="MEM")
+ self.rasterio_memfile = MemoryFile(vrt_xml, ext='.vrt')
+ dataset = self.rasterio_memfile.open(driver='VRT')
+
+ # XXX: dataset does not copy the data held by data_array, so we set
+ # data_array as a property of dataset to make sure that the lifetime of
+ # data_array is as long as dataset, otherwise we may see band data
+ # corruption.
+ dataset.mem_data_array = data_array
+ return dataset
+
+ def close(self):
+ if self.rasterio_dataset_reader is not None:
+ self.rasterio_dataset_reader.close()
+ self.rasterio_dataset_reader = None
+ if self.rasterio_memfile is not None:
+ self.rasterio_memfile.close()
+ self.rasterio_memfile = None
diff --git a/python/sedona/sql/types.py b/python/sedona/sql/types.py
index 36e22e17f..239f19df8 100644
--- a/python/sedona/sql/types.py
+++ b/python/sedona/sql/types.py
@@ -18,6 +18,8 @@
from pyspark.sql.types import UserDefinedType, BinaryType
from ..utils import geometry_serde
+from ..raster import raster_serde
+from ..raster.sedona_raster import SedonaRaster
class GeometryType(UserDefinedType):
@@ -55,7 +57,7 @@ class RasterType(UserDefinedType):
raise NotImplementedError("RasterType.serialize is not implemented yet")
def deserialize(self, datum):
- raise NotImplementedError("RasterType.deserialize is not implemented yet")
+ return raster_serde.deserialize(datum)
@classmethod
def module(cls):
@@ -67,3 +69,6 @@ class RasterType(UserDefinedType):
@classmethod
def scalaUDT(cls):
return "org.apache.spark.sql.sedona_sql.UDT.RasterUDT"
+
+
+SedonaRaster.__UDT__ = RasterType()
diff --git a/python/setup.py b/python/setup.py
index 7576957d0..6429499cb 100644
--- a/python/setup.py
+++ b/python/setup.py
@@ -52,7 +52,7 @@ setup(
long_description=long_description,
long_description_content_type="text/markdown",
python_requires='>=3.6',
- install_requires=['attrs', "shapely>=1.7.0"],
+ install_requires=['attrs', "shapely>=1.7.0", "rasterio>=1.2.10"],
extras_require={
"spark": ["pyspark>=2.3.0"],
"pydeck-map": ["pandas<=1.3.5", "geopandas<=0.10.2", "pydeck==0.8.0"],
diff --git a/python/tests/raster/__init__.py b/python/tests/raster/__init__.py
new file mode 100644
index 000000000..a67d5ea25
--- /dev/null
+++ b/python/tests/raster/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/python/tests/raster/test_meta.py b/python/tests/raster/test_meta.py
new file mode 100644
index 000000000..68135ba25
--- /dev/null
+++ b/python/tests/raster/test_meta.py
@@ -0,0 +1,59 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import pytest
+from pytest import approx
+
+from sedona.raster.meta import AffineTransform
+from sedona.raster.meta import PixelAnchor
+
+
+class TestAffineTransform:
+
+ def test_change_anchor_to_upper_left(self):
+ scale_x = 10.0
+ skew_y = 1.0
+ skew_x = 2.0
+ scale_y = -8.0
+ ip_x = 100
+ ip_y = 200
+
+ trans = AffineTransform(scale_x, skew_y, skew_x, scale_y, ip_x, ip_y, PixelAnchor.CENTER)
+ trans_gdal = trans.with_anchor(PixelAnchor.UPPER_LEFT)
+ assert trans_gdal.scale_x == approx(scale_x)
+ assert trans_gdal.scale_y == approx(scale_y)
+ assert trans_gdal.skew_x == approx(skew_x)
+ assert trans_gdal.skew_y == approx(skew_y)
+ assert trans_gdal.ip_x == approx(94.0)
+ assert trans_gdal.ip_y == approx(203.5)
+
+ def test_change_anchor_to_center(self):
+ scale_x = 10.0
+ skew_y = 1.0
+ skew_x = 2.0
+ scale_y = -8.0
+ ip_x = 100
+ ip_y = 200
+
+ trans_gdal = AffineTransform(scale_x, skew_y, skew_x, scale_y, ip_x, ip_y, PixelAnchor.UPPER_LEFT)
+ trans = trans_gdal.with_anchor(PixelAnchor.CENTER)
+ assert trans.scale_x == approx(scale_x)
+ assert trans.scale_y == approx(scale_y)
+ assert trans.skew_x == approx(skew_x)
+ assert trans.skew_y == approx(skew_y)
+ assert trans.ip_x == approx(106.0)
+ assert trans.ip_y == approx(196.5)
diff --git a/python/tests/raster/test_pandas_udf.py b/python/tests/raster/test_pandas_udf.py
new file mode 100644
index 000000000..8e7304941
--- /dev/null
+++ b/python/tests/raster/test_pandas_udf.py
@@ -0,0 +1,76 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import pytest
+
+from tests.test_base import TestBase
+from pyspark.sql.functions import expr, pandas_udf
+from pyspark.sql.types import IntegerType
+import pyspark
+import pandas as pd
+import numpy as np
+import rasterio
+
+from tests import world_map_raster_input_location
+
+class TestRasterPandasUDF(TestBase):
+ @pytest.mark.skipif(pyspark.__version__ < '3.4', reason="requires Spark 3.4 or higher")
+ def test_raster_as_param(self):
+ spark = TestRasterPandasUDF.spark
+ df = spark.range(10).withColumn("rast", expr("RS_MakeRasterForTesting(1, 'I', 'PixelInterleavedSampleModel', 4, 3, 100, 100, 10, -10, 0, 0, 3857)"))
+
+ # A Python Pandas UDF that takes a raster as input
+ @pandas_udf(IntegerType())
+ def pandas_udf_raster_as_param(s: pd.Series) -> pd.Series:
+ from sedona.raster import raster_serde
+
+ def func(x):
+ with raster_serde.deserialize(x) as raster:
+ arr = raster.as_numpy()
+ return int(np.sum(arr))
+
+ return s.apply(func)
+
+ # A Python Pandas UDF that takes a raster as input
+ @pandas_udf(IntegerType())
+ def pandas_udf_raster_as_param_2(s: pd.Series) -> pd.Series:
+ from sedona.raster import raster_serde
+
+ def func(x):
+ with raster_serde.deserialize(x) as raster:
+ ds = raster.as_rasterio()
+ return int(np.sum(ds.read(1)))
+
+ # wrap s.apply() with a rasterio env to get rid of the overhead of repeated
+ # env initialization in as_rasterio()
+ with rasterio.Env():
+ return s.apply(func)
+
+ spark.udf.register("pandas_udf_raster_as_param", pandas_udf_raster_as_param)
+ spark.udf.register("pandas_udf_raster_as_param_2", pandas_udf_raster_as_param_2)
+
+ df_result = df.selectExpr("pandas_udf_raster_as_param(rast) as res")
+ rows = df_result.collect()
+ assert len(rows) == 10
+ for row in rows:
+ assert row['res'] == 66
+
+ df_result = df.selectExpr("pandas_udf_raster_as_param_2(rast) as res")
+ rows = df_result.collect()
+ assert len(rows) == 10
+ for row in rows:
+ assert row['res'] == 66
diff --git a/python/tests/raster/test_serde.py b/python/tests/raster/test_serde.py
new file mode 100644
index 000000000..dc94b0109
--- /dev/null
+++ b/python/tests/raster/test_serde.py
@@ -0,0 +1,121 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import pytest
+import rasterio
+import numpy as np
+
+from tests.test_base import TestBase
+from pyspark.sql.functions import expr
+from sedona.sql.types import RasterType
+
+from tests import world_map_raster_input_location
+
+class TestRasterSerde(TestBase):
+ def test_empty_raster(self):
+ df = TestRasterSerde.spark.sql("SELECT RS_MakeEmptyRaster(2, 100, 200, 1000, 2000, 1) as raster")
+ raster = df.first()[0]
+ assert raster.width == 100 and raster.height == 200 and len(raster.bands_meta) == 2
+ assert raster.affine_trans.ip_x == 1000 and raster.affine_trans.ip_y == 2000
+ assert raster.affine_trans.scale_x == 1 and raster.affine_trans.scale_y == -1
+
+ def test_banded_sample_model(self):
+ df = TestRasterSerde.spark.sql("SELECT RS_MakeRasterForTesting(3, 'I', 'BandedSampleModel', 10, 8, 100, 100, 10, -10, 0, 0, 3857) as raster")
+ raster = df.first()[0]
+ assert raster.width == 10 and raster.height == 8 and len(raster.bands_meta) == 3
+ self.validate_test_raster(raster)
+
+ def test_pixel_interleaved_sample_model(self):
+ df = TestRasterSerde.spark.sql("SELECT RS_MakeRasterForTesting(3, 'I', 'PixelInterleavedSampleModel', 10, 10, 100, 100, 10, -10, 0, 0, 3857) as raster")
+ raster = df.first()[0]
+ assert raster.width == 10 and raster.height == 10 and len(raster.bands_meta) == 3
+ self.validate_test_raster(raster)
+ df = TestRasterSerde.spark.sql("SELECT RS_MakeRasterForTesting(4, 'I', 'PixelInterleavedSampleModelComplex', 8, 10, 100, 100, 10, -10, 0, 0, 3857) as raster")
+ raster = df.first()[0]
+ assert raster.width == 8 and raster.height == 10 and len(raster.bands_meta) == 4
+ self.validate_test_raster(raster)
+
+ def test_component_sample_model(self):
+ for pixel_type in ['B', 'S', 'US', 'I', 'F', 'D']:
+ df = TestRasterSerde.spark.sql("SELECT RS_MakeRasterForTesting(4, '{}', 'ComponentSampleModel', 10, 10, 100, 100, 10, -10, 0, 0, 3857) as raster".format(pixel_type))
+ raster = df.first()[0]
+ assert raster.width == 10 and raster.height == 10 and len(raster.bands_meta) == 4
+ self.validate_test_raster(raster)
+
+ def test_multi_pixel_packed_sample_model(self):
+ df = TestRasterSerde.spark.sql("SELECT RS_MakeRasterForTesting(1, 'B', 'MultiPixelPackedSampleModel', 10, 10, 100, 100, 10, -10, 0, 0, 3857) as raster")
+ raster = df.first()[0]
+ assert raster.width == 10 and raster.height == 10 and len(raster.bands_meta) == 1
+ self.validate_test_raster(raster, packed=True)
+
+ def test_single_pixel_packed_sample_model(self):
+ df = TestRasterSerde.spark.sql("SELECT RS_MakeRasterForTesting(4, 'I', 'SinglePixelPackedSampleModel', 10, 10, 100, 100, 10, -10, 0, 0, 3857) as raster")
+ raster = df.first()[0]
+ assert raster.width == 10 and raster.height == 10 and len(raster.bands_meta) == 4
+ self.validate_test_raster(raster, packed=True)
+
+ def test_raster_read_from_geotiff(self):
+ raster_path = world_map_raster_input_location
+ r_orig = rasterio.open(raster_path)
+ band = r_orig.read(1)
+ band_masked = np.where(band == 0, np.nan, band)
+ df = TestRasterSerde.spark.read.format("binaryFile").load(raster_path).selectExpr("RS_FromGeoTiff(content) as raster")
+ raster = df.first()[0]
+ assert raster.width == r_orig.width
+ assert raster.height == r_orig.height
+ assert raster.bands_meta[0].nodata == 0
+
+ # test as_rasterio
+ assert (band == raster.as_numpy()[0, :, :]).all()
+ ds = raster.as_rasterio()
+ assert ds.crs is not None
+ band_actual = ds.read(1)
+ assert (band == band_actual).all()
+
+ # test as_numpy
+ arr = raster.as_numpy()
+ assert (arr[0, :, :] == band).all()
+
+ # test as_numpy_masked
+ arr = raster.as_numpy_masked()[0, :, :]
+ assert np.array_equal(arr, band_masked) or np.array_equal(np.isnan(arr), np.isnan(band_masked))
+
+ raster.close()
+ r_orig.close()
+
+ def test_to_pandas(self):
+ spark = TestRasterSerde.spark
+ df = spark.sql("SELECT RS_MakeRasterForTesting(3, 'I', 'BandedSampleModel', 10, 8, 100, 100, 10, -10, 0, 0, 3857) as raster")
+ pandas_df = df.toPandas()
+ raster = pandas_df.iloc[0]['raster']
+ self.validate_test_raster(raster)
+
+ def validate_test_raster(self, raster, packed = False):
+ arr = raster.as_numpy()
+ ds = raster.as_rasterio()
+ bands, height, width = arr.shape
+ assert bands > 0 and width > 0 and height > 0
+ assert ds.crs is not None
+ for b in range(bands):
+ band = ds.read(b + 1)
+ for y in range(height):
+ for x in range(width):
+ expected = b + y * width + x
+ if packed:
+ expected = expected % 16
+ assert arr[b, y, x] == expected
+ assert band[y, x] == expected
diff --git a/spark-shaded/pom.xml b/spark-shaded/pom.xml
index 064e543a2..b9855e876 100644
--- a/spark-shaded/pom.xml
+++ b/spark-shaded/pom.xml
@@ -63,30 +63,128 @@
<dependency>
<groupId>org.geotools</groupId>
<artifactId>gt-main</artifactId>
+ <exclusions>
+ <exclusion>
+ <groupId>javax.media</groupId>
+ <artifactId>jai_core</artifactId>
+ </exclusion>
+ <exclusion>
+ <groupId>javax.media</groupId>
+ <artifactId>jai_codec</artifactId>
+ </exclusion>
+ <exclusion>
+ <groupId>javax.media</groupId>
+ <artifactId>jai_imageio</artifactId>
+ </exclusion>
+ </exclusions>
</dependency>
<dependency>
<groupId>org.geotools</groupId>
<artifactId>gt-referencing</artifactId>
+ <exclusions>
+ <exclusion>
+ <groupId>javax.media</groupId>
+ <artifactId>jai_core</artifactId>
+ </exclusion>
+ <exclusion>
+ <groupId>javax.media</groupId>
+ <artifactId>jai_codec</artifactId>
+ </exclusion>
+ <exclusion>
+ <groupId>javax.media</groupId>
+ <artifactId>jai_imageio</artifactId>
+ </exclusion>
+ </exclusions>
</dependency>
<dependency>
<groupId>org.geotools</groupId>
<artifactId>gt-epsg-hsql</artifactId>
+ <exclusions>
+ <exclusion>
+ <groupId>javax.media</groupId>
+ <artifactId>jai_core</artifactId>
+ </exclusion>
+ <exclusion>
+ <groupId>javax.media</groupId>
+ <artifactId>jai_codec</artifactId>
+ </exclusion>
+ <exclusion>
+ <groupId>javax.media</groupId>
+ <artifactId>jai_imageio</artifactId>
+ </exclusion>
+ </exclusions>
</dependency>
<dependency>
<groupId>org.geotools</groupId>
<artifactId>gt-geotiff</artifactId>
+ <exclusions>
+ <exclusion>
+ <groupId>javax.media</groupId>
+ <artifactId>jai_core</artifactId>
+ </exclusion>
+ <exclusion>
+ <groupId>javax.media</groupId>
+ <artifactId>jai_codec</artifactId>
+ </exclusion>
+ <exclusion>
+ <groupId>javax.media</groupId>
+ <artifactId>jai_imageio</artifactId>
+ </exclusion>
+ </exclusions>
</dependency>
<dependency>
<groupId>org.geotools</groupId>
<artifactId>gt-process-feature</artifactId>
+ <exclusions>
+ <exclusion>
+ <groupId>javax.media</groupId>
+ <artifactId>jai_core</artifactId>
+ </exclusion>
+ <exclusion>
+ <groupId>javax.media</groupId>
+ <artifactId>jai_codec</artifactId>
+ </exclusion>
+ <exclusion>
+ <groupId>javax.media</groupId>
+ <artifactId>jai_imageio</artifactId>
+ </exclusion>
+ </exclusions>
</dependency>
<dependency>
<groupId>org.geotools</groupId>
<artifactId>gt-arcgrid</artifactId>
+ <exclusions>
+ <exclusion>
+ <groupId>javax.media</groupId>
+ <artifactId>jai_core</artifactId>
+ </exclusion>
+ <exclusion>
+ <groupId>javax.media</groupId>
+ <artifactId>jai_codec</artifactId>
+ </exclusion>
+ <exclusion>
+ <groupId>javax.media</groupId>
+ <artifactId>jai_imageio</artifactId>
+ </exclusion>
+ </exclusions>
</dependency>
<dependency>
<groupId>org.geotools</groupId>
<artifactId>gt-coverage</artifactId>
+ <exclusions>
+ <exclusion>
+ <groupId>javax.media</groupId>
+ <artifactId>jai_core</artifactId>
+ </exclusion>
+ <exclusion>
+ <groupId>javax.media</groupId>
+ <artifactId>jai_codec</artifactId>
+ </exclusion>
+ <exclusion>
+ <groupId>javax.media</groupId>
+ <artifactId>jai_imageio</artifactId>
+ </exclusion>
+ </exclusions>
</dependency>
</dependencies>
<build>
diff --git a/spark/common/src/main/scala/org/apache/sedona/sql/UDF/Catalog.scala b/spark/common/src/main/scala/org/apache/sedona/sql/UDF/Catalog.scala
index 3d06f84e3..f12976606 100644
--- a/spark/common/src/main/scala/org/apache/sedona/sql/UDF/Catalog.scala
+++ b/spark/common/src/main/scala/org/apache/sedona/sql/UDF/Catalog.scala
@@ -212,6 +212,7 @@ object Catalog {
function[RS_FromGeoTiff](),
function[RS_MakeEmptyRaster](),
function[RS_MakeRaster](),
+ function[RS_MakeRasterForTesting](),
function[RS_Tile](),
function[RS_TileExplode](),
function[RS_Envelope](),
diff --git a/spark/common/src/main/scala/org/apache/sedona/sql/utils/RasterSerializer.scala b/spark/common/src/main/scala/org/apache/sedona/sql/utils/RasterSerializer.scala
index 3ffc41c6c..016475307 100644
--- a/spark/common/src/main/scala/org/apache/sedona/sql/utils/RasterSerializer.scala
+++ b/spark/common/src/main/scala/org/apache/sedona/sql/utils/RasterSerializer.scala
@@ -19,7 +19,7 @@
package org.apache.sedona.sql.utils
-import org.apache.sedona.common.raster.Serde
+import org.apache.sedona.common.raster.serde.Serde
import org.geotools.coverage.grid.GridCoverage2D
object RasterSerializer {
diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/UDT/RasterUDT.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/UDT/RasterUDT.scala
index 7db42c353..f88d61ccd 100644
--- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/UDT/RasterUDT.scala
+++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/UDT/RasterUDT.scala
@@ -19,7 +19,7 @@
package org.apache.spark.sql.sedona_sql.UDT
-import org.apache.sedona.common.raster.Serde
+import org.apache.sedona.common.raster.serde.Serde
import org.apache.spark.sql.types.{BinaryType, DataType, UserDefinedType}
import org.geotools.coverage.grid.GridCoverage2D
diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/RasterConstructors.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/RasterConstructors.scala
index c977f0913..1b7751689 100644
--- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/RasterConstructors.scala
+++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/RasterConstructors.scala
@@ -18,7 +18,8 @@
*/
package org.apache.spark.sql.sedona_sql.expressions.raster
-import org.apache.sedona.common.raster.RasterConstructors
+import org.apache.sedona.common.raster.{RasterConstructors, RasterConstructorsForTesting}
+import org.apache.sedona.sql.utils.RasterSerializer
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.sql.catalyst.expressions.{CreateArray, Expression, Generator, Literal}
@@ -76,6 +77,13 @@ case class RS_MakeRaster(inputExpressions: Seq[Expression])
}
}
+/**
+ * Catalyst expression that delegates to `RasterConstructorsForTesting.makeRasterForTesting`.
+ *
+ * @param inputExpressions the argument expressions of the SQL function call
+ */
+case class RS_MakeRasterForTesting(inputExpressions: Seq[Expression])
+  extends InferredExpression(RasterConstructorsForTesting.makeRasterForTesting _) {
+
+  // Rebuilds this expression with the given children as its inputs.
+  protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) =
+    copy(inputExpressions = newChildren)
+}
+
case class RS_Tile(inputExpressions: Seq[Expression])
extends InferredExpression(
nullTolerantInferrableFunction3(RasterConstructors.rsTile),
diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/implicits.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/implicits.scala
index f1e1c6bf6..a26f8cd9f 100644
--- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/implicits.scala
+++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/implicits.scala
@@ -18,7 +18,7 @@
*/
package org.apache.spark.sql.sedona_sql.expressions.raster
-import org.apache.sedona.common.raster.Serde
+import org.apache.sedona.common.raster.serde.Serde
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.sedona_sql.expressions.SerdeAware
diff --git a/spark/common/src/test/scala/org/apache/sedona/sql/rasteralgebraTest.scala b/spark/common/src/test/scala/org/apache/sedona/sql/rasteralgebraTest.scala
index 0acc2ab5d..cc49cdacd 100644
--- a/spark/common/src/test/scala/org/apache/sedona/sql/rasteralgebraTest.scala
+++ b/spark/common/src/test/scala/org/apache/sedona/sql/rasteralgebraTest.scala
@@ -28,7 +28,7 @@ import org.junit.Assert.{assertEquals, assertNotNull, assertNull, assertTrue}
import org.locationtech.jts.geom.{Coordinate, Geometry}
import org.scalatest.{BeforeAndAfter, GivenWhenThen}
-import java.awt.image.DataBuffer
+import java.awt.image.{DataBuffer, SinglePixelPackedSampleModel}
import java.io.File
import java.net.URLConnection
import scala.collection.mutable
@@ -846,6 +846,13 @@ class rasteralgebraTest extends TestBaseScala with BeforeAndAfter with GivenWhen
}
}
+ it("Passed RS_MakeRasterForTesting") {
+   // Request a test raster with the SinglePixelPackedSampleModel layout, then check the type of
+   // the returned value and the sample model of its rendered image.
+   val df = sparkSession.sql("SELECT RS_MakeRasterForTesting(4, 'I', 'SinglePixelPackedSampleModel', 10, 10, 100, 100, 10, -10, 0, 0, 3857) as raster")
+   val result = df.first().get(0)
+   assert(result.isInstanceOf[GridCoverage2D])
+   val sampleModel = result.asInstanceOf[GridCoverage2D].getRenderedImage.getSampleModel
+   assert(sampleModel.isInstanceOf[SinglePixelPackedSampleModel])
+ }
+
it("Passed RS_BandAsArray") {
val df = sparkSession.read.format("binaryFile").load(resourceFolder + "raster/test1.tiff")
val metadata = df.selectExpr("RS_Metadata(RS_FromGeoTiff(content))").first().getSeq(0)
diff --git a/spark/common/src/test/scala/org/apache/sedona/sql/serdeAwareTest.scala b/spark/common/src/test/scala/org/apache/sedona/sql/serdeAwareTest.scala
index 786af2746..53a753c05 100644
--- a/spark/common/src/test/scala/org/apache/sedona/sql/serdeAwareTest.scala
+++ b/spark/common/src/test/scala/org/apache/sedona/sql/serdeAwareTest.scala
@@ -21,7 +21,7 @@ package org.apache.sedona.sql
import org.apache.sedona.common.geometrySerde.GeometrySerializer
import org.apache.sedona.common.raster.RasterConstructors.fromArcInfoAsciiGrid
-import org.apache.sedona.common.raster.Serde
+import org.apache.sedona.common.raster.serde.Serde
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.sedona_sql.expressions.{ST_Buffer, ST_GeomFromText, ST_Point, ST_Union}
import org.apache.spark.sql.sedona_sql.expressions.raster.{RS_FromArcInfoAsciiGrid, RS_NumBands}