You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@sedona.apache.org by ji...@apache.org on 2023/02/14 06:09:41 UTC

[sedona] branch master updated: [SEDONA-207] Fix ambiguity of empty multi-geometries and multi geometries containing only empty geometries (#766)

This is an automated email from the ASF dual-hosted git repository.

jiayu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/sedona.git


The following commit(s) were added to refs/heads/master by this push:
     new ff561824 [SEDONA-207] Fix ambiguity of empty multi-geometries and multi geometries containing only empty geometries (#766)
ff561824 is described below

commit ff561824abf04b7515902e201d76e36c04d7abaa
Author: Kristin Cowalcijk <mo...@yeah.net>
AuthorDate: Tue Feb 14 14:09:34 2023 +0800

    [SEDONA-207] Fix ambiguity of empty multi-geometries and multi geometries containing only empty geometries (#766)
---
 .../common/geometrySerde/GeometrySerializer.java   | 55 +++++-------
 .../geometrySerde/GeometryCollectionSerdeTest.java | 18 +++-
 .../geometrySerde/MultiLineStringSerdeTest.java    | 16 ++++
 .../common/geometrySerde/MultiPointSerdeTest.java  | 15 ++++
 .../geometrySerde/MultiPolygonSerdeTest.java       | 15 ++++
 python/sedona/utils/geometry_serde.py              | 39 +++++----
 python/tests/utils/test_geometry_serde.py          | 98 +++++++++++++++++++++-
 7 files changed, 202 insertions(+), 54 deletions(-)

diff --git a/common/src/main/java/org/apache/sedona/common/geometrySerde/GeometrySerializer.java b/common/src/main/java/org/apache/sedona/common/geometrySerde/GeometrySerializer.java
index 8c488a46..8b04902e 100644
--- a/common/src/main/java/org/apache/sedona/common/geometrySerde/GeometrySerializer.java
+++ b/common/src/main/java/org/apache/sedona/common/geometrySerde/GeometrySerializer.java
@@ -119,7 +119,7 @@ public class GeometrySerializer {
 
     private static Point deserializePoint(GeometryBuffer buffer, int srid) {
         CoordinateType coordType = buffer.getCoordinateType();
-        int numCoordinates = getNonNegativeInt(buffer, 4);
+        int numCoordinates = getBoundedInt(buffer, 4);
         Point point;
         if (numCoordinates == 0) {
             point = FACTORY.createPoint();
@@ -161,7 +161,7 @@ public class GeometrySerializer {
 
     private static MultiPoint deserializeMultiPoint(GeometryBuffer buffer, int srid) {
         CoordinateType coordType = buffer.getCoordinateType();
-        int numPoints = getNonNegativeInt(buffer, 4);
+        int numPoints = getBoundedInt(buffer, 4);
         int bufferSize = 8 + numPoints * coordType.bytes;
         checkBufferSize(buffer, bufferSize);
         Point[] points = new Point[numPoints];
@@ -203,7 +203,7 @@ public class GeometrySerializer {
 
     private static LineString deserializeLineString(GeometryBuffer buffer, int srid) {
         CoordinateType coordType = buffer.getCoordinateType();
-        int numCoordinates = getNonNegativeInt(buffer, 4);
+        int numCoordinates = getBoundedInt(buffer, 4);
         int bufferSize = 8 + numCoordinates * coordType.bytes;
         checkBufferSize(buffer, bufferSize);
         CoordinateSequence coordinates = buffer.getCoordinates(8, numCoordinates);
@@ -215,10 +215,6 @@ public class GeometrySerializer {
 
     private static GeometryBuffer serializeMultiLineString(MultiLineString multiLineString) {
         int numLineStrings = multiLineString.getNumGeometries();
-        if (numLineStrings <= 0) {
-            return createGeometryBuffer(
-                    WKBConstants.wkbMultiLineString, CoordinateType.XY, multiLineString.getSRID(), 8, 0);
-        }
         CoordinateType coordType = getCoordinateType(multiLineString);
         int numCoordinates = multiLineString.getNumPoints();
         int coordsOffset = 8;
@@ -243,17 +239,11 @@ public class GeometrySerializer {
 
     private static MultiLineString deserializeMultiLineString(GeometryBuffer buffer, int srid) {
         CoordinateType coordType = buffer.getCoordinateType();
-        int numCoordinates = getNonNegativeInt(buffer, 4);
-        if (numCoordinates == 0) {
-            buffer.mark(8);
-            MultiLineString multiLineString = FACTORY.createMultiLineString();
-            multiLineString.setSRID(srid);
-            return multiLineString;
-        }
+        int numCoordinates = getBoundedInt(buffer, 4);
         int coordsOffset = 8;
         int numOffset = 8 + numCoordinates * coordType.bytes;
         GeomPartSerializer serializer = new GeomPartSerializer(buffer, coordsOffset, numOffset);
-        int numLineStrings = serializer.checkedReadNonNegativeInt();
+        int numLineStrings = serializer.checkedReadBoundedInt();
         serializer.checkRemainingIntsAtLeast(numLineStrings);
         LineString[] lineStrings = new LineString[numLineStrings];
         for (int k = 0; k < numLineStrings; k++) {
@@ -291,7 +281,7 @@ public class GeometrySerializer {
 
     private static Polygon deserializePolygon(GeometryBuffer buffer, int srid) {
         CoordinateType coordType = buffer.getCoordinateType();
-        int numCoordinates = getNonNegativeInt(buffer, 4);
+        int numCoordinates = getBoundedInt(buffer, 4);
         if (numCoordinates == 0) {
             buffer.mark(8);
             Polygon polygon = FACTORY.createPolygon();
@@ -309,10 +299,6 @@ public class GeometrySerializer {
 
     private static GeometryBuffer serializeMultiPolygon(MultiPolygon multiPolygon) {
         int numPolygons = multiPolygon.getNumGeometries();
-        if (numPolygons == 0) {
-            return createGeometryBuffer(
-                    WKBConstants.wkbMultiPolygon, CoordinateType.XY, multiPolygon.getSRID(), 8, 0);
-        }
         int numCoordinates = 0;
         CoordinateType coordType = getCoordinateType(multiPolygon);
         int totalRings = 0;
@@ -346,17 +332,11 @@ public class GeometrySerializer {
 
     private static MultiPolygon deserializeMultiPolygon(GeometryBuffer buffer, int srid) {
         CoordinateType coordType = buffer.getCoordinateType();
-        int numCoordinates = getNonNegativeInt(buffer, 4);
-        if (numCoordinates == 0) {
-            buffer.mark(8);
-            MultiPolygon multiPolygon = FACTORY.createMultiPolygon();
-            multiPolygon.setSRID(srid);
-            return multiPolygon;
-        }
+        int numCoordinates = getBoundedInt(buffer, 4);
         int coordsOffset = 8;
         int numPolygonsOffset = 8 + numCoordinates * coordType.bytes;
         GeomPartSerializer serializer = new GeomPartSerializer(buffer, coordsOffset, numPolygonsOffset);
-        int numPolygons = serializer.checkedReadNonNegativeInt();
+        int numPolygons = serializer.checkedReadBoundedInt();
         Polygon[] polygons = new Polygon[numPolygons];
         for (int k = 0; k < numPolygons; k++) {
             Polygon polygon = serializer.readPolygon();
@@ -405,7 +385,7 @@ public class GeometrySerializer {
     }
 
     private static GeometryCollection deserializeGeometryCollection(GeometryBuffer buffer, int srid) {
-        int numGeometries = getNonNegativeInt(buffer, 4);
+        int numGeometries = getBoundedInt(buffer, 4);
         if (numGeometries == 0) {
             buffer.mark(8);
             GeometryCollection geometryCollection = FACTORY.createGeometryCollection();
@@ -454,11 +434,14 @@ public class GeometrySerializer {
         }
     }
 
-    private static int getNonNegativeInt(GeometryBuffer buffer, int offset) {
+    private static int getBoundedInt(GeometryBuffer buffer, int offset) {
         int value = buffer.getInt(offset);
         if (value < 0) {
             throw new IllegalArgumentException("Unexpected negative value encountered: " + value);
         }
+        if (value > buffer.getLength()) {
+            throw new IllegalArgumentException("Unexpected large value encountered: " + value);
+        }
         return value;
     }
 
@@ -517,7 +500,7 @@ public class GeometrySerializer {
         }
 
         Polygon readPolygon() {
-            int numRings = checkedReadNonNegativeInt();
+            int numRings = checkedReadBoundedInt();
             if (numRings == 0) {
                 return FACTORY.createPolygon();
             }
@@ -532,7 +515,7 @@ public class GeometrySerializer {
         }
 
         CoordinateSequence readCoordinates() {
-            int numCoordinates = getNonNegativeInt(buffer, intsOffset);
+            int numCoordinates = getBoundedInt(buffer, intsOffset);
             int newCoordsOffset = coordsOffset + buffer.getCoordinateType().bytes * numCoordinates;
             if (newCoordsOffset > coordsEndOffset) {
                 throw new IllegalStateException(
@@ -544,15 +527,15 @@ public class GeometrySerializer {
             return coordinates;
         }
 
-        int readNonNegativeInt() {
-            int value = getNonNegativeInt(buffer, intsOffset);
+        int readBoundedInt() {
+            int value = getBoundedInt(buffer, intsOffset);
             intsOffset += 4;
             return value;
         }
 
-        int checkedReadNonNegativeInt() {
+        int checkedReadBoundedInt() {
             checkBufferSize(buffer, intsOffset + 4);
-            return readNonNegativeInt();
+            return readBoundedInt();
         }
 
         void checkRemainingIntsAtLeast(int num) {
diff --git a/common/src/test/java/org/apache/sedona/common/geometrySerde/GeometryCollectionSerdeTest.java b/common/src/test/java/org/apache/sedona/common/geometrySerde/GeometryCollectionSerdeTest.java
index ce84e69a..18309bf4 100644
--- a/common/src/test/java/org/apache/sedona/common/geometrySerde/GeometryCollectionSerdeTest.java
+++ b/common/src/test/java/org/apache/sedona/common/geometrySerde/GeometryCollectionSerdeTest.java
@@ -32,6 +32,8 @@ import org.locationtech.jts.geom.MultiPolygon;
 import org.locationtech.jts.geom.Point;
 import org.locationtech.jts.geom.Polygon;
 
+import javax.sound.sampled.Line;
+
 public class GeometryCollectionSerdeTest {
     private static final GeometryFactory gf = new GeometryFactory();
 
@@ -82,9 +84,23 @@ public class GeometryCollectionSerdeTest {
                                 gf.createMultiPoint(),
                                 gf.createMultiLineString(),
                                 gf.createMultiPolygon(),
+                                gf.createMultiPoint(new Point[] {
+                                        gf.createPoint(),
+                                        gf.createPoint()
+                                }),
+                                gf.createMultiLineString(new LineString[] {
+                                        gf.createLineString(),
+                                        gf.createLineString()
+                                }),
+                                gf.createMultiPolygon(new Polygon[] {
+                                        gf.createPolygon(),
+                                        gf.createPolygon(),
+                                        gf.createPolygon()
+                                }),
                                 multiPoint,
                                 multiLineString,
-                                multiPolygon
+                                multiPolygon,
+                                point
                         });
         geometryCollection.setSRID(4326);
         byte[] bytes = GeometrySerializer.serialize(geometryCollection);
diff --git a/common/src/test/java/org/apache/sedona/common/geometrySerde/MultiLineStringSerdeTest.java b/common/src/test/java/org/apache/sedona/common/geometrySerde/MultiLineStringSerdeTest.java
index 72e5637f..8e392622 100644
--- a/common/src/test/java/org/apache/sedona/common/geometrySerde/MultiLineStringSerdeTest.java
+++ b/common/src/test/java/org/apache/sedona/common/geometrySerde/MultiLineStringSerdeTest.java
@@ -65,4 +65,20 @@ public class MultiLineStringSerdeTest {
         Assert.assertEquals(3, multiLineString2.getNumGeometries());
         Assert.assertEquals(multiLineString, multiLineString2);
     }
+
+    @Test
+    public void testMultiLineStringContainingEmptyLineStrings() {
+        MultiLineString multiLineString = gf.createMultiLineString(
+                new LineString[] {
+                        gf.createLineString(),
+                        gf.createLineString(),
+                        gf.createLineString()
+                }
+        );
+        multiLineString.setSRID(4326);
+        byte[] bytes = GeometrySerializer.serialize(multiLineString);
+        Geometry geom = GeometrySerializer.deserialize(bytes);
+        Assert.assertEquals(3, geom.getNumGeometries());
+        Assert.assertEquals(multiLineString, geom);
+    }
 }
diff --git a/common/src/test/java/org/apache/sedona/common/geometrySerde/MultiPointSerdeTest.java b/common/src/test/java/org/apache/sedona/common/geometrySerde/MultiPointSerdeTest.java
index 9beb2590..4e814b32 100644
--- a/common/src/test/java/org/apache/sedona/common/geometrySerde/MultiPointSerdeTest.java
+++ b/common/src/test/java/org/apache/sedona/common/geometrySerde/MultiPointSerdeTest.java
@@ -86,6 +86,21 @@ public class MultiPointSerdeTest {
         Assert.assertEquals(4, multiPoint2.getGeometryN(2).getCoordinate().y, 1e-6);
     }
 
+    @Test
+    public void testMultiPointWithEmptyPointsOnly() {
+        Point[] points =
+                new Point[]{
+                        gf.createPoint(),
+                        gf.createPoint(),
+                        gf.createPoint()
+                };
+        MultiPoint multiPoint = gf.createMultiPoint(points);
+        byte[] bytes = GeometrySerializer.serialize(multiPoint);
+        Geometry geom = GeometrySerializer.deserialize(bytes);
+        Assert.assertEquals(3, geom.getNumGeometries());
+        Assert.assertEquals(multiPoint, geom);
+    }
+
     @Test
     public void testMultiPointXYM() {
         MultiPoint multiPoint =
diff --git a/common/src/test/java/org/apache/sedona/common/geometrySerde/MultiPolygonSerdeTest.java b/common/src/test/java/org/apache/sedona/common/geometrySerde/MultiPolygonSerdeTest.java
index 1fed8a47..f2098a1f 100644
--- a/common/src/test/java/org/apache/sedona/common/geometrySerde/MultiPolygonSerdeTest.java
+++ b/common/src/test/java/org/apache/sedona/common/geometrySerde/MultiPolygonSerdeTest.java
@@ -86,4 +86,19 @@ public class MultiPolygonSerdeTest {
         Assert.assertEquals(3, mp.getNumGeometries());
         Assert.assertEquals(multiPolygon, mp);
     }
+
+    @Test
+    public void testMultiPolygonContainingEmptyPolygons() {
+        MultiPolygon multiPolygon =
+                gf.createMultiPolygon(
+                        new Polygon[]{
+                                gf.createPolygon(),
+                                gf.createPolygon()
+                        });
+        multiPolygon.setSRID(4326);
+        byte[] bytes = GeometrySerializer.serialize(multiPolygon);
+        Geometry geom = GeometrySerializer.deserialize(bytes);
+        Assert.assertEquals(2, geom.getNumGeometries());
+        Assert.assertEquals(multiPolygon, geom);
+    }
 }
diff --git a/python/sedona/utils/geometry_serde.py b/python/sedona/utils/geometry_serde.py
index 174b6746..9efd33eb 100644
--- a/python/sedona/utils/geometry_serde.py
+++ b/python/sedona/utils/geometry_serde.py
@@ -66,7 +66,7 @@ class CoordinateType:
 
     BYTES_PER_COORDINATE = [16, 24, 24, 32]
     NUM_COORD_COMPONENTS = [2, 3, 3, 4]
-    UNPACK_FORMAT = ['dd', 'ddd', 'dddd']
+    UNPACK_FORMAT = ['dd', 'ddd', 'ddxxxxxxxx', 'dddxxxxxxxx']
 
     @staticmethod
     def type_of(geom) -> int:
@@ -160,6 +160,8 @@ class GeometryBuffer:
 
     def read_int(self) -> int:
         value = struct.unpack_from("i", self.buffer, self.ints_offset)[0]
+        if value > len(self.buffer):
+            raise ValueError('Unexpected large integer in structural data')
         self.ints_offset += 4
         return value
 
@@ -207,6 +209,8 @@ def deserialize(buffer: bytes) -> Optional[BaseGeometry]:
     geom_type = (preamble_byte >> 4) & 0x0F
     coord_type = (preamble_byte >> 1) & 0x07
     num_coords = struct.unpack_from('i', buffer, 4)[0]
+    if num_coords > len(buffer):
+        raise ValueError('num_coords cannot be larger than buffer size')
     geom_buffer = GeometryBuffer(buffer, coord_type, 8, num_coords)
     if geom_type == GeometryTypeID.POINT:
         geom = deserialize_point(geom_buffer)
@@ -260,9 +264,6 @@ def put_coordinate(buffer: bytearray, offset: int, coord_type: int, coord: Coord
 
 
 def get_coordinates(buffer: bytearray, offset: int, coord_type: int, num_coords: int) -> Union[np.ndarray, ListCoordType]:
-    if coord_type == CoordinateType.XYM or coord_type == CoordinateType.XYZM:
-        raise NotImplementedError("XYM or XYZM coordinates are not supported")
-
     if num_coords < GET_COORDS_NUMPY_THRESHOLD:
         coords = [
             struct.unpack_from(CoordinateType.unpack_format(coord_type), buffer, offset + (i * CoordinateType.bytes_per_coord(coord_type)))
@@ -328,7 +329,16 @@ def serialize_multi_point(geom: MultiPoint) -> bytes:
     coord_type = CoordinateType.type_of(geom)
 
     header = generate_header_bytes(GeometryTypeID.MULTIPOINT, coord_type, num_points)
-    body = array.array('d', (coord for point in points for coord in list(point.coords[0]))).tobytes()
+    coords = []
+    for point in points:
+        if point.coords:
+            for coord in point.coords[0]:
+                coords.append(coord)
+        else:
+            for k in range(geom._ndim):
+                coords.append(math.nan)
+
+    body = array.array('d', coords).tobytes()
 
     return header + body
 
@@ -367,9 +377,6 @@ def deserialize_linestring(geom_buffer: GeometryBuffer) -> LineString:
 def serialize_multi_linestring(geom: MultiLineString) -> bytes:
     linestrings = list(geom.geoms)
 
-    if not linestrings:
-        return generate_header_bytes(GeometryTypeID.MULTILINESTRING, CoordinateType.XY, 0)
-
     coord_type = CoordinateType.type_of(geom)
     lines = [[list(coord) for coord in ls.coords] for ls in linestrings]
     line_lengths = [len(l) for l in lines]
@@ -386,14 +393,14 @@ def serialize_multi_linestring(geom: MultiLineString) -> bytes:
 
 
 def deserialize_multi_linestring(geom_buffer: GeometryBuffer) -> MultiLineString:
-    if geom_buffer.num_coords == 0:
-        return wkt_loads("MULTILINESTRING EMPTY")
     num_linestrings = geom_buffer.read_int()
     linestrings = []
     for k in range(0, num_linestrings):
         linestring = geom_buffer.read_linestring()
         if not linestring.is_empty:
             linestrings.append(linestring)
+    if not linestrings:
+        return wkt_loads("MULTILINESTRING EMPTY")
     return MultiLineString(linestrings)
 
 def serialize_polygon(geom: Polygon) -> bytes:
@@ -439,9 +446,6 @@ def serialize_multi_polygon(geom: MultiPolygon) -> bytes:
     coords_for = lambda x: [y for y in list(x)]
     polygons = [[coords_for(polygon.exterior.coords)] + [coords_for(ring.coords) for ring in polygon.interiors] for polygon in list(geom.geoms)]
 
-    if not polygons:
-        return generate_header_bytes(GeometryTypeID.MULTIPOLYGON, CoordinateType.XY, 0)
-
     coord_type = CoordinateType.type_of(geom)
 
     structure_data = array.array('i', [val for polygon in polygons for val in [len(polygon)] + [len(ring) for ring in polygon]]).tobytes()
@@ -457,14 +461,14 @@ def serialize_multi_polygon(geom: MultiPolygon) -> bytes:
 
 
 def deserialize_multi_polygon(geom_buffer: GeometryBuffer) -> MultiPolygon:
-    if geom_buffer.num_coords == 0:
-        return wkt_loads("MULTIPOLYGON EMPTY")
     num_polygons = geom_buffer.read_int()
     polygons = []
     for k in range(0, num_polygons):
         polygon = geom_buffer.read_polygon()
         if not polygon.is_empty:
             polygons.append(polygon)
+    if not polygons:
+        return wkt_loads("MULTIPOLYGON EMPTY")
     return MultiPolygon(polygons)
 
 
@@ -493,6 +497,7 @@ def serialize_geometry_collection(geom: GeometryCollection) -> bytearray:
 
 
 def serialize_shapely_1_empty_geom(geom: BaseGeometry) -> bytearray:
+    total_size = 8
     if isinstance(geom, Point):
         geom_type = GeometryTypeID.POINT
     elif isinstance(geom, LineString):
@@ -503,11 +508,13 @@ def serialize_shapely_1_empty_geom(geom: BaseGeometry) -> bytearray:
         geom_type = GeometryTypeID.MULTIPOINT
     elif isinstance(geom, MultiLineString):
         geom_type = GeometryTypeID.MULTILINESTRING
+        total_size = 12
     elif isinstance(geom, MultiPolygon):
         geom_type = GeometryTypeID.MULTIPOLYGON
+        total_size = 12
     else:
         raise ValueError("Invalid empty geometry collection object: {}".format(geom))
-    return create_buffer_for_geom(geom_type, CoordinateType.XY, 8, 0)
+    return create_buffer_for_geom(geom_type, CoordinateType.XY, total_size, 0)
 
 
 def deserialize_geometry_collection(geom_buffer: GeometryBuffer) -> GeometryCollection:
diff --git a/python/tests/utils/test_geometry_serde.py b/python/tests/utils/test_geometry_serde.py
index 5818ebbd..e6c8fc8c 100644
--- a/python/tests/utils/test_geometry_serde.py
+++ b/python/tests/utils/test_geometry_serde.py
@@ -16,8 +16,9 @@
 #  under the License.
 import pytest
 
-from pyspark.sql.types import StructType
+from pyspark.sql.types import (StructType, StringType)
 from sedona.sql.types import GeometryType
+from pyspark.sql.functions import expr
 from sedona.utils import geometry_serde
 
 from shapely.geometry.base import BaseGeometry
@@ -56,6 +57,101 @@ class TestGeometrySerde(TestBase):
         returned_geom = TestGeometrySerde.spark.createDataFrame([(geom,)], StructType().add("geom", GeometryType())).take(1)[0][0]
         assert geom.equals_exact(returned_geom, 1e-6)
 
+    @pytest.mark.parametrize("wkt", [
+        # empty geometries
+        'POINT EMPTY',
+        'LINESTRING EMPTY',
+        'POLYGON EMPTY',
+        'MULTIPOINT EMPTY',
+        'MULTILINESTRING EMPTY',
+        'MULTIPOLYGON EMPTY',
+        'GEOMETRYCOLLECTION EMPTY',
+        # non-empty geometries
+        'POINT (10 20)',
+        'POINT (10 20 30)',
+        'LINESTRING (10 20, 30 40)',
+        'LINESTRING (10 20 30, 40 50 60)',
+        'POLYGON ((10 10, 20 20, 20 10, 10 10))',
+        'POLYGON ((10 10 10, 20 20 10, 20 10 10, 10 10 10))',
+        'POLYGON ((0 0, 0 10, 10 10, 10 0, 0 0), (1 1, 1 2, 2 2, 2 1, 1 1))',
+        # non-empty multi geometries
+        'MULTIPOINT ((10 20), (30 40))',
+        'MULTIPOINT ((10 20 30), (40 50 60))',
+        'MULTILINESTRING ((10 20, 30 40), (50 60, 70 80))',
+        'MULTILINESTRING ((10 20 30, 40 50 60), (70 80 90, 100 110 120))',
+        'MULTIPOLYGON (((10 10, 20 20, 20 10, 10 10)), ((-10 -10, -20 -20, -20 -10, -10 -10)))',
+        'MULTIPOLYGON (((10 10, 20 20, 20 10, 10 10)), ((0 0, 0 10, 10 10, 10 0, 0 0), (1 1, 1 2, 2 2, 2 1, 1 1)))',
+        'GEOMETRYCOLLECTION (POINT (10 20), LINESTRING (10 20, 30 40))',
+        'GEOMETRYCOLLECTION (POINT (10 20 30), LINESTRING (10 20 30, 40 50 60))',
+        'GEOMETRYCOLLECTION (POINT (10 20), LINESTRING (10 20, 30 40), POLYGON ((10 10, 20 20, 20 10, 10 10)))',
+        # nested geometry collection
+        'GEOMETRYCOLLECTION (GEOMETRYCOLLECTION (POINT (10 20), LINESTRING (10 20, 30 40)))',
+        'GEOMETRYCOLLECTION (POINT (1 2), GEOMETRYCOLLECTION (POINT (10 20), LINESTRING (10 20, 30 40)))',
+        # multi geometries containing empty geometries
+        'MULTIPOINT (EMPTY, (10 20))',
+        'MULTIPOINT (EMPTY, EMPTY)',
+        'MULTILINESTRING (EMPTY, (10 20, 30 40))',
+        'MULTILINESTRING (EMPTY, EMPTY)',
+        'MULTIPOLYGON (EMPTY, ((10 10, 20 20, 20 10, 10 10)))',
+        'MULTIPOLYGON (EMPTY, EMPTY)',
+        'GEOMETRYCOLLECTION (POINT (10 20), POINT EMPTY, LINESTRING (10 20, 30 40))',
+        'GEOMETRYCOLLECTION (MULTIPOINT EMPTY, MULTILINESTRING EMPTY, MULTIPOLYGON EMPTY, GEOMETRYCOLLECTION EMPTY)',
+    ])
+    def test_spark_serde_compatibility_with_scala(self, wkt):
+        geom = wkt_loads(wkt)
+        schema = StructType().add("geom", GeometryType())
+        returned_geom = TestGeometrySerde.spark.createDataFrame([(geom,)], schema).take(1)[0][0]
+        assert geom.equals(returned_geom)
+
+        # serialized by python, deserialized by scala
+        returned_wkt = TestGeometrySerde.spark.createDataFrame([(geom,)], schema).selectExpr("ST_AsText(geom)").take(1)[0][0]
+        assert wkt_loads(returned_wkt).equals(geom)
+
+        # serialized by scala, deserialized by python
+        schema = StructType().add("wkt", StringType())
+        returned_geom = TestGeometrySerde.spark.createDataFrame([(wkt,)], schema).selectExpr("ST_GeomFromText(wkt)").take(1)[0][0]
+        assert geom.equals(returned_geom)
+
+    @pytest.mark.parametrize("wkt", [
+        'POINT ZM (1 2 3 4)',
+        'LINESTRING ZM (1 2 3 4, 5 6 7 8)',
+        'POLYGON ZM ((10 10 10 1, 20 20 10 1, 20 10 10 1, 10 10 10 1))',
+        'MULTIPOINT ZM ((10 20 30 1), (40 50 60 1))',
+        'MULTILINESTRING ZM ((10 20 30 1, 40 50 60 1), (70 80 90 1, 100 110 120 1))',
+        'MULTIPOLYGON ZM (((10 10 10 1, 20 20 10 1, 20 10 10 1, 10 10 10 1)), ' +
+        '((0 0 0 1, 0 10 0 1, 10 10 0 1, 10 0 0 1, 0 0 0 1), (1 1 0 1, 1 2 0 1, 2 2 0 1, 2 1 0 1, 1 1 0 1)))',
+        'GEOMETRYCOLLECTION (POINT ZM (10 20 30 1), LINESTRING ZM (10 20 30 1, 40 50 60 1))',
+    ])
+    def test_spark_serde_on_4d_geoms(self, wkt):
+        geom = wkt_loads(wkt)
+        schema = StructType().add("wkt", StringType())
+        returned_geom, n_dims = TestGeometrySerde.spark.createDataFrame([(wkt,)], schema)\
+            .selectExpr("ST_GeomFromText(wkt)", "ST_NDims(ST_GeomFromText(wkt))")\
+            .take(1)[0]
+        assert n_dims == 4
+        assert geom.equals(returned_geom)
+
+    @pytest.mark.parametrize("wkt", [
+        'POINT M (1 2 3)',
+        'LINESTRING M (1 2 3, 5 6 7)',
+        'POLYGON M ((10 10 10, 20 20 10, 20 10 10, 10 10 10))',
+        'MULTIPOINT M ((10 20 30), (40 50 60))',
+        'MULTILINESTRING M ((10 20 30, 40 50 60), (70 80 90, 100 110 120))',
+        'MULTIPOLYGON M (((10 10 10, 20 20 10, 20 10 10, 10 10 10)), ' +
+        '((0 0 0, 0 10 0, 10 10 0, 10 0 0, 0 0 0), (1 1 0, 1 2 0, 2 2 0, 2 1 0, 1 1 0)))',
+        'GEOMETRYCOLLECTION (POINT M (10 20 30), LINESTRING M (10 20 30, 40 50 60))',
+        ])
+    def test_spark_serde_on_xym_geoms(self, wkt):
+        geom = wkt_loads(wkt)
+        schema = StructType().add("wkt", StringType())
+        returned_geom, n_dims, z_min = TestGeometrySerde.spark.createDataFrame([(wkt,)], schema) \
+            .withColumn("geom", expr("ST_GeomFromText(wkt)")) \
+            .selectExpr("geom", "ST_NDims(geom)", "ST_ZMin(geom)") \
+            .take(1)[0]
+        assert n_dims == 3
+        assert z_min is None
+        assert geom.equals(returned_geom)
+
     def test_point(self):
         points = [
             wkt_loads("POINT EMPTY"),