You are viewing a plain-text version of this content; the hyperlink to its canonical version ("here") was not preserved in this copy.
Posted to commits@sedona.apache.org by ji...@apache.org on 2023/02/14 06:09:41 UTC
[sedona] branch master updated: [SEDONA-207] Fix ambiguity of empty multi-geometries and multi geometries containing only empty geometries (#766)
This is an automated email from the ASF dual-hosted git repository.
jiayu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/sedona.git
The following commit(s) were added to refs/heads/master by this push:
new ff561824 [SEDONA-207] Fix ambiguity of empty multi-geometries and multi geometries containing only empty geometries (#766)
ff561824 is described below
commit ff561824abf04b7515902e201d76e36c04d7abaa
Author: Kristin Cowalcijk <mo...@yeah.net>
AuthorDate: Tue Feb 14 14:09:34 2023 +0800
[SEDONA-207] Fix ambiguity of empty multi-geometries and multi geometries containing only empty geometries (#766)
---
.../common/geometrySerde/GeometrySerializer.java | 55 +++++-------
.../geometrySerde/GeometryCollectionSerdeTest.java | 18 +++-
.../geometrySerde/MultiLineStringSerdeTest.java | 16 ++++
.../common/geometrySerde/MultiPointSerdeTest.java | 15 ++++
.../geometrySerde/MultiPolygonSerdeTest.java | 15 ++++
python/sedona/utils/geometry_serde.py | 39 +++++----
python/tests/utils/test_geometry_serde.py | 98 +++++++++++++++++++++-
7 files changed, 202 insertions(+), 54 deletions(-)
diff --git a/common/src/main/java/org/apache/sedona/common/geometrySerde/GeometrySerializer.java b/common/src/main/java/org/apache/sedona/common/geometrySerde/GeometrySerializer.java
index 8c488a46..8b04902e 100644
--- a/common/src/main/java/org/apache/sedona/common/geometrySerde/GeometrySerializer.java
+++ b/common/src/main/java/org/apache/sedona/common/geometrySerde/GeometrySerializer.java
@@ -119,7 +119,7 @@ public class GeometrySerializer {
private static Point deserializePoint(GeometryBuffer buffer, int srid) {
CoordinateType coordType = buffer.getCoordinateType();
- int numCoordinates = getNonNegativeInt(buffer, 4);
+ int numCoordinates = getBoundedInt(buffer, 4);
Point point;
if (numCoordinates == 0) {
point = FACTORY.createPoint();
@@ -161,7 +161,7 @@ public class GeometrySerializer {
private static MultiPoint deserializeMultiPoint(GeometryBuffer buffer, int srid) {
CoordinateType coordType = buffer.getCoordinateType();
- int numPoints = getNonNegativeInt(buffer, 4);
+ int numPoints = getBoundedInt(buffer, 4);
int bufferSize = 8 + numPoints * coordType.bytes;
checkBufferSize(buffer, bufferSize);
Point[] points = new Point[numPoints];
@@ -203,7 +203,7 @@ public class GeometrySerializer {
private static LineString deserializeLineString(GeometryBuffer buffer, int srid) {
CoordinateType coordType = buffer.getCoordinateType();
- int numCoordinates = getNonNegativeInt(buffer, 4);
+ int numCoordinates = getBoundedInt(buffer, 4);
int bufferSize = 8 + numCoordinates * coordType.bytes;
checkBufferSize(buffer, bufferSize);
CoordinateSequence coordinates = buffer.getCoordinates(8, numCoordinates);
@@ -215,10 +215,6 @@ public class GeometrySerializer {
private static GeometryBuffer serializeMultiLineString(MultiLineString multiLineString) {
int numLineStrings = multiLineString.getNumGeometries();
- if (numLineStrings <= 0) {
- return createGeometryBuffer(
- WKBConstants.wkbMultiLineString, CoordinateType.XY, multiLineString.getSRID(), 8, 0);
- }
CoordinateType coordType = getCoordinateType(multiLineString);
int numCoordinates = multiLineString.getNumPoints();
int coordsOffset = 8;
@@ -243,17 +239,11 @@ public class GeometrySerializer {
private static MultiLineString deserializeMultiLineString(GeometryBuffer buffer, int srid) {
CoordinateType coordType = buffer.getCoordinateType();
- int numCoordinates = getNonNegativeInt(buffer, 4);
- if (numCoordinates == 0) {
- buffer.mark(8);
- MultiLineString multiLineString = FACTORY.createMultiLineString();
- multiLineString.setSRID(srid);
- return multiLineString;
- }
+ int numCoordinates = getBoundedInt(buffer, 4);
int coordsOffset = 8;
int numOffset = 8 + numCoordinates * coordType.bytes;
GeomPartSerializer serializer = new GeomPartSerializer(buffer, coordsOffset, numOffset);
- int numLineStrings = serializer.checkedReadNonNegativeInt();
+ int numLineStrings = serializer.checkedReadBoundedInt();
serializer.checkRemainingIntsAtLeast(numLineStrings);
LineString[] lineStrings = new LineString[numLineStrings];
for (int k = 0; k < numLineStrings; k++) {
@@ -291,7 +281,7 @@ public class GeometrySerializer {
private static Polygon deserializePolygon(GeometryBuffer buffer, int srid) {
CoordinateType coordType = buffer.getCoordinateType();
- int numCoordinates = getNonNegativeInt(buffer, 4);
+ int numCoordinates = getBoundedInt(buffer, 4);
if (numCoordinates == 0) {
buffer.mark(8);
Polygon polygon = FACTORY.createPolygon();
@@ -309,10 +299,6 @@ public class GeometrySerializer {
private static GeometryBuffer serializeMultiPolygon(MultiPolygon multiPolygon) {
int numPolygons = multiPolygon.getNumGeometries();
- if (numPolygons == 0) {
- return createGeometryBuffer(
- WKBConstants.wkbMultiPolygon, CoordinateType.XY, multiPolygon.getSRID(), 8, 0);
- }
int numCoordinates = 0;
CoordinateType coordType = getCoordinateType(multiPolygon);
int totalRings = 0;
@@ -346,17 +332,11 @@ public class GeometrySerializer {
private static MultiPolygon deserializeMultiPolygon(GeometryBuffer buffer, int srid) {
CoordinateType coordType = buffer.getCoordinateType();
- int numCoordinates = getNonNegativeInt(buffer, 4);
- if (numCoordinates == 0) {
- buffer.mark(8);
- MultiPolygon multiPolygon = FACTORY.createMultiPolygon();
- multiPolygon.setSRID(srid);
- return multiPolygon;
- }
+ int numCoordinates = getBoundedInt(buffer, 4);
int coordsOffset = 8;
int numPolygonsOffset = 8 + numCoordinates * coordType.bytes;
GeomPartSerializer serializer = new GeomPartSerializer(buffer, coordsOffset, numPolygonsOffset);
- int numPolygons = serializer.checkedReadNonNegativeInt();
+ int numPolygons = serializer.checkedReadBoundedInt();
Polygon[] polygons = new Polygon[numPolygons];
for (int k = 0; k < numPolygons; k++) {
Polygon polygon = serializer.readPolygon();
@@ -405,7 +385,7 @@ public class GeometrySerializer {
}
private static GeometryCollection deserializeGeometryCollection(GeometryBuffer buffer, int srid) {
- int numGeometries = getNonNegativeInt(buffer, 4);
+ int numGeometries = getBoundedInt(buffer, 4);
if (numGeometries == 0) {
buffer.mark(8);
GeometryCollection geometryCollection = FACTORY.createGeometryCollection();
@@ -454,11 +434,14 @@ public class GeometrySerializer {
}
}
- private static int getNonNegativeInt(GeometryBuffer buffer, int offset) {
+ private static int getBoundedInt(GeometryBuffer buffer, int offset) {
int value = buffer.getInt(offset);
if (value < 0) {
throw new IllegalArgumentException("Unexpected negative value encountered: " + value);
}
+ if (value > buffer.getLength()) {
+ throw new IllegalArgumentException("Unexpected large value encountered: " + value);
+ }
return value;
}
@@ -517,7 +500,7 @@ public class GeometrySerializer {
}
Polygon readPolygon() {
- int numRings = checkedReadNonNegativeInt();
+ int numRings = checkedReadBoundedInt();
if (numRings == 0) {
return FACTORY.createPolygon();
}
@@ -532,7 +515,7 @@ public class GeometrySerializer {
}
CoordinateSequence readCoordinates() {
- int numCoordinates = getNonNegativeInt(buffer, intsOffset);
+ int numCoordinates = getBoundedInt(buffer, intsOffset);
int newCoordsOffset = coordsOffset + buffer.getCoordinateType().bytes * numCoordinates;
if (newCoordsOffset > coordsEndOffset) {
throw new IllegalStateException(
@@ -544,15 +527,15 @@ public class GeometrySerializer {
return coordinates;
}
- int readNonNegativeInt() {
- int value = getNonNegativeInt(buffer, intsOffset);
+ int readBoundedInt() {
+ int value = getBoundedInt(buffer, intsOffset);
intsOffset += 4;
return value;
}
- int checkedReadNonNegativeInt() {
+ int checkedReadBoundedInt() {
checkBufferSize(buffer, intsOffset + 4);
- return readNonNegativeInt();
+ return readBoundedInt();
}
void checkRemainingIntsAtLeast(int num) {
diff --git a/common/src/test/java/org/apache/sedona/common/geometrySerde/GeometryCollectionSerdeTest.java b/common/src/test/java/org/apache/sedona/common/geometrySerde/GeometryCollectionSerdeTest.java
index ce84e69a..18309bf4 100644
--- a/common/src/test/java/org/apache/sedona/common/geometrySerde/GeometryCollectionSerdeTest.java
+++ b/common/src/test/java/org/apache/sedona/common/geometrySerde/GeometryCollectionSerdeTest.java
@@ -32,6 +32,8 @@ import org.locationtech.jts.geom.MultiPolygon;
import org.locationtech.jts.geom.Point;
import org.locationtech.jts.geom.Polygon;
+import javax.sound.sampled.Line;
+
public class GeometryCollectionSerdeTest {
private static final GeometryFactory gf = new GeometryFactory();
@@ -82,9 +84,23 @@ public class GeometryCollectionSerdeTest {
gf.createMultiPoint(),
gf.createMultiLineString(),
gf.createMultiPolygon(),
+ gf.createMultiPoint(new Point[] {
+ gf.createPoint(),
+ gf.createPoint()
+ }),
+ gf.createMultiLineString(new LineString[] {
+ gf.createLineString(),
+ gf.createLineString()
+ }),
+ gf.createMultiPolygon(new Polygon[] {
+ gf.createPolygon(),
+ gf.createPolygon(),
+ gf.createPolygon()
+ }),
multiPoint,
multiLineString,
- multiPolygon
+ multiPolygon,
+ point
});
geometryCollection.setSRID(4326);
byte[] bytes = GeometrySerializer.serialize(geometryCollection);
diff --git a/common/src/test/java/org/apache/sedona/common/geometrySerde/MultiLineStringSerdeTest.java b/common/src/test/java/org/apache/sedona/common/geometrySerde/MultiLineStringSerdeTest.java
index 72e5637f..8e392622 100644
--- a/common/src/test/java/org/apache/sedona/common/geometrySerde/MultiLineStringSerdeTest.java
+++ b/common/src/test/java/org/apache/sedona/common/geometrySerde/MultiLineStringSerdeTest.java
@@ -65,4 +65,20 @@ public class MultiLineStringSerdeTest {
Assert.assertEquals(3, multiLineString2.getNumGeometries());
Assert.assertEquals(multiLineString, multiLineString2);
}
+
+ @Test
+ public void testMultiLineStringContainingEmptyLineStrings() {
+ MultiLineString multiLineString = gf.createMultiLineString(
+ new LineString[] {
+ gf.createLineString(),
+ gf.createLineString(),
+ gf.createLineString()
+ }
+ );
+ multiLineString.setSRID(4326);
+ byte[] bytes = GeometrySerializer.serialize(multiLineString);
+ Geometry geom = GeometrySerializer.deserialize(bytes);
+ Assert.assertEquals(3, geom.getNumGeometries());
+ Assert.assertEquals(multiLineString, geom);
+ }
}
diff --git a/common/src/test/java/org/apache/sedona/common/geometrySerde/MultiPointSerdeTest.java b/common/src/test/java/org/apache/sedona/common/geometrySerde/MultiPointSerdeTest.java
index 9beb2590..4e814b32 100644
--- a/common/src/test/java/org/apache/sedona/common/geometrySerde/MultiPointSerdeTest.java
+++ b/common/src/test/java/org/apache/sedona/common/geometrySerde/MultiPointSerdeTest.java
@@ -86,6 +86,21 @@ public class MultiPointSerdeTest {
Assert.assertEquals(4, multiPoint2.getGeometryN(2).getCoordinate().y, 1e-6);
}
+ @Test
+ public void testMultiPointWithEmptyPointsOnly() {
+ Point[] points =
+ new Point[]{
+ gf.createPoint(),
+ gf.createPoint(),
+ gf.createPoint()
+ };
+ MultiPoint multiPoint = gf.createMultiPoint(points);
+ byte[] bytes = GeometrySerializer.serialize(multiPoint);
+ Geometry geom = GeometrySerializer.deserialize(bytes);
+ Assert.assertEquals(3, geom.getNumGeometries());
+ Assert.assertEquals(multiPoint, geom);
+ }
+
@Test
public void testMultiPointXYM() {
MultiPoint multiPoint =
diff --git a/common/src/test/java/org/apache/sedona/common/geometrySerde/MultiPolygonSerdeTest.java b/common/src/test/java/org/apache/sedona/common/geometrySerde/MultiPolygonSerdeTest.java
index 1fed8a47..f2098a1f 100644
--- a/common/src/test/java/org/apache/sedona/common/geometrySerde/MultiPolygonSerdeTest.java
+++ b/common/src/test/java/org/apache/sedona/common/geometrySerde/MultiPolygonSerdeTest.java
@@ -86,4 +86,19 @@ public class MultiPolygonSerdeTest {
Assert.assertEquals(3, mp.getNumGeometries());
Assert.assertEquals(multiPolygon, mp);
}
+
+ @Test
+ public void testMultiPolygonContainingEmptyPolygons() {
+ MultiPolygon multiPolygon =
+ gf.createMultiPolygon(
+ new Polygon[]{
+ gf.createPolygon(),
+ gf.createPolygon()
+ });
+ multiPolygon.setSRID(4326);
+ byte[] bytes = GeometrySerializer.serialize(multiPolygon);
+ Geometry geom = GeometrySerializer.deserialize(bytes);
+ Assert.assertEquals(2, geom.getNumGeometries());
+ Assert.assertEquals(multiPolygon, geom);
+ }
}
diff --git a/python/sedona/utils/geometry_serde.py b/python/sedona/utils/geometry_serde.py
index 174b6746..9efd33eb 100644
--- a/python/sedona/utils/geometry_serde.py
+++ b/python/sedona/utils/geometry_serde.py
@@ -66,7 +66,7 @@ class CoordinateType:
BYTES_PER_COORDINATE = [16, 24, 24, 32]
NUM_COORD_COMPONENTS = [2, 3, 3, 4]
- UNPACK_FORMAT = ['dd', 'ddd', 'dddd']
+ UNPACK_FORMAT = ['dd', 'ddd', 'ddxxxxxxxx', 'dddxxxxxxxx']
@staticmethod
def type_of(geom) -> int:
@@ -160,6 +160,8 @@ class GeometryBuffer:
def read_int(self) -> int:
value = struct.unpack_from("i", self.buffer, self.ints_offset)[0]
+ if value > len(self.buffer):
+ raise ValueError('Unexpected large integer in structural data')
self.ints_offset += 4
return value
@@ -207,6 +209,8 @@ def deserialize(buffer: bytes) -> Optional[BaseGeometry]:
geom_type = (preamble_byte >> 4) & 0x0F
coord_type = (preamble_byte >> 1) & 0x07
num_coords = struct.unpack_from('i', buffer, 4)[0]
+ if num_coords > len(buffer):
+ raise ValueError('num_coords cannot be larger than buffer size')
geom_buffer = GeometryBuffer(buffer, coord_type, 8, num_coords)
if geom_type == GeometryTypeID.POINT:
geom = deserialize_point(geom_buffer)
@@ -260,9 +264,6 @@ def put_coordinate(buffer: bytearray, offset: int, coord_type: int, coord: Coord
def get_coordinates(buffer: bytearray, offset: int, coord_type: int, num_coords: int) -> Union[np.ndarray, ListCoordType]:
- if coord_type == CoordinateType.XYM or coord_type == CoordinateType.XYZM:
- raise NotImplementedError("XYM or XYZM coordinates are not supported")
-
if num_coords < GET_COORDS_NUMPY_THRESHOLD:
coords = [
struct.unpack_from(CoordinateType.unpack_format(coord_type), buffer, offset + (i * CoordinateType.bytes_per_coord(coord_type)))
@@ -328,7 +329,16 @@ def serialize_multi_point(geom: MultiPoint) -> bytes:
coord_type = CoordinateType.type_of(geom)
header = generate_header_bytes(GeometryTypeID.MULTIPOINT, coord_type, num_points)
- body = array.array('d', (coord for point in points for coord in list(point.coords[0]))).tobytes()
+ coords = []
+ for point in points:
+ if point.coords:
+ for coord in point.coords[0]:
+ coords.append(coord)
+ else:
+ for k in range(geom._ndim):
+ coords.append(math.nan)
+
+ body = array.array('d', coords).tobytes()
return header + body
@@ -367,9 +377,6 @@ def deserialize_linestring(geom_buffer: GeometryBuffer) -> LineString:
def serialize_multi_linestring(geom: MultiLineString) -> bytes:
linestrings = list(geom.geoms)
- if not linestrings:
- return generate_header_bytes(GeometryTypeID.MULTILINESTRING, CoordinateType.XY, 0)
-
coord_type = CoordinateType.type_of(geom)
lines = [[list(coord) for coord in ls.coords] for ls in linestrings]
line_lengths = [len(l) for l in lines]
@@ -386,14 +393,14 @@ def serialize_multi_linestring(geom: MultiLineString) -> bytes:
def deserialize_multi_linestring(geom_buffer: GeometryBuffer) -> MultiLineString:
- if geom_buffer.num_coords == 0:
- return wkt_loads("MULTILINESTRING EMPTY")
num_linestrings = geom_buffer.read_int()
linestrings = []
for k in range(0, num_linestrings):
linestring = geom_buffer.read_linestring()
if not linestring.is_empty:
linestrings.append(linestring)
+ if not linestrings:
+ return wkt_loads("MULTILINESTRING EMPTY")
return MultiLineString(linestrings)
def serialize_polygon(geom: Polygon) -> bytes:
@@ -439,9 +446,6 @@ def serialize_multi_polygon(geom: MultiPolygon) -> bytes:
coords_for = lambda x: [y for y in list(x)]
polygons = [[coords_for(polygon.exterior.coords)] + [coords_for(ring.coords) for ring in polygon.interiors] for polygon in list(geom.geoms)]
- if not polygons:
- return generate_header_bytes(GeometryTypeID.MULTIPOLYGON, CoordinateType.XY, 0)
-
coord_type = CoordinateType.type_of(geom)
structure_data = array.array('i', [val for polygon in polygons for val in [len(polygon)] + [len(ring) for ring in polygon]]).tobytes()
@@ -457,14 +461,14 @@ def serialize_multi_polygon(geom: MultiPolygon) -> bytes:
def deserialize_multi_polygon(geom_buffer: GeometryBuffer) -> MultiPolygon:
- if geom_buffer.num_coords == 0:
- return wkt_loads("MULTIPOLYGON EMPTY")
num_polygons = geom_buffer.read_int()
polygons = []
for k in range(0, num_polygons):
polygon = geom_buffer.read_polygon()
if not polygon.is_empty:
polygons.append(polygon)
+ if not polygons:
+ return wkt_loads("MULTIPOLYGON EMPTY")
return MultiPolygon(polygons)
@@ -493,6 +497,7 @@ def serialize_geometry_collection(geom: GeometryCollection) -> bytearray:
def serialize_shapely_1_empty_geom(geom: BaseGeometry) -> bytearray:
+ total_size = 8
if isinstance(geom, Point):
geom_type = GeometryTypeID.POINT
elif isinstance(geom, LineString):
@@ -503,11 +508,13 @@ def serialize_shapely_1_empty_geom(geom: BaseGeometry) -> bytearray:
geom_type = GeometryTypeID.MULTIPOINT
elif isinstance(geom, MultiLineString):
geom_type = GeometryTypeID.MULTILINESTRING
+ total_size = 12
elif isinstance(geom, MultiPolygon):
geom_type = GeometryTypeID.MULTIPOLYGON
+ total_size = 12
else:
raise ValueError("Invalid empty geometry collection object: {}".format(geom))
- return create_buffer_for_geom(geom_type, CoordinateType.XY, 8, 0)
+ return create_buffer_for_geom(geom_type, CoordinateType.XY, total_size, 0)
def deserialize_geometry_collection(geom_buffer: GeometryBuffer) -> GeometryCollection:
diff --git a/python/tests/utils/test_geometry_serde.py b/python/tests/utils/test_geometry_serde.py
index 5818ebbd..e6c8fc8c 100644
--- a/python/tests/utils/test_geometry_serde.py
+++ b/python/tests/utils/test_geometry_serde.py
@@ -16,8 +16,9 @@
# under the License.
import pytest
-from pyspark.sql.types import StructType
+from pyspark.sql.types import (StructType, StringType)
from sedona.sql.types import GeometryType
+from pyspark.sql.functions import expr
from sedona.utils import geometry_serde
from shapely.geometry.base import BaseGeometry
@@ -56,6 +57,101 @@ class TestGeometrySerde(TestBase):
returned_geom = TestGeometrySerde.spark.createDataFrame([(geom,)], StructType().add("geom", GeometryType())).take(1)[0][0]
assert geom.equals_exact(returned_geom, 1e-6)
+ @pytest.mark.parametrize("wkt", [
+ # empty geometries
+ 'POINT EMPTY',
+ 'LINESTRING EMPTY',
+ 'POLYGON EMPTY',
+ 'MULTIPOINT EMPTY',
+ 'MULTILINESTRING EMPTY',
+ 'MULTIPOLYGON EMPTY',
+ 'GEOMETRYCOLLECTION EMPTY',
+ # non-empty geometries
+ 'POINT (10 20)',
+ 'POINT (10 20 30)',
+ 'LINESTRING (10 20, 30 40)',
+ 'LINESTRING (10 20 30, 40 50 60)',
+ 'POLYGON ((10 10, 20 20, 20 10, 10 10))',
+ 'POLYGON ((10 10 10, 20 20 10, 20 10 10, 10 10 10))',
+ 'POLYGON ((0 0, 0 10, 10 10, 10 0, 0 0), (1 1, 1 2, 2 2, 2 1, 1 1))',
+ # non-empty multi geometries
+ 'MULTIPOINT ((10 20), (30 40))',
+ 'MULTIPOINT ((10 20 30), (40 50 60))',
+ 'MULTILINESTRING ((10 20, 30 40), (50 60, 70 80))',
+ 'MULTILINESTRING ((10 20 30, 40 50 60), (70 80 90, 100 110 120))',
+ 'MULTIPOLYGON (((10 10, 20 20, 20 10, 10 10)), ((-10 -10, -20 -20, -20 -10, -10 -10)))',
+ 'MULTIPOLYGON (((10 10, 20 20, 20 10, 10 10)), ((0 0, 0 10, 10 10, 10 0, 0 0), (1 1, 1 2, 2 2, 2 1, 1 1)))',
+ 'GEOMETRYCOLLECTION (POINT (10 20), LINESTRING (10 20, 30 40))',
+ 'GEOMETRYCOLLECTION (POINT (10 20 30), LINESTRING (10 20 30, 40 50 60))',
+ 'GEOMETRYCOLLECTION (POINT (10 20), LINESTRING (10 20, 30 40), POLYGON ((10 10, 20 20, 20 10, 10 10)))',
+ # nested geometry collection
+ 'GEOMETRYCOLLECTION (GEOMETRYCOLLECTION (POINT (10 20), LINESTRING (10 20, 30 40)))',
+ 'GEOMETRYCOLLECTION (POINT (1 2), GEOMETRYCOLLECTION (POINT (10 20), LINESTRING (10 20, 30 40)))',
+ # multi geometries containing empty geometries
+ 'MULTIPOINT (EMPTY, (10 20))',
+ 'MULTIPOINT (EMPTY, EMPTY)',
+ 'MULTILINESTRING (EMPTY, (10 20, 30 40))',
+ 'MULTILINESTRING (EMPTY, EMPTY)',
+ 'MULTIPOLYGON (EMPTY, ((10 10, 20 20, 20 10, 10 10)))',
+ 'MULTIPOLYGON (EMPTY, EMPTY)',
+ 'GEOMETRYCOLLECTION (POINT (10 20), POINT EMPTY, LINESTRING (10 20, 30 40))',
+ 'GEOMETRYCOLLECTION (MULTIPOINT EMPTY, MULTILINESTRING EMPTY, MULTIPOLYGON EMPTY, GEOMETRYCOLLECTION EMPTY)',
+ ])
+ def test_spark_serde_compatibility_with_scala(self, wkt):
+ geom = wkt_loads(wkt)
+ schema = StructType().add("geom", GeometryType())
+ returned_geom = TestGeometrySerde.spark.createDataFrame([(geom,)], schema).take(1)[0][0]
+ assert geom.equals(returned_geom)
+
+ # serialized by python, deserialized by scala
+ returned_wkt = TestGeometrySerde.spark.createDataFrame([(geom,)], schema).selectExpr("ST_AsText(geom)").take(1)[0][0]
+ assert wkt_loads(returned_wkt).equals(geom)
+
+ # serialized by scala, deserialized by python
+ schema = StructType().add("wkt", StringType())
+ returned_geom = TestGeometrySerde.spark.createDataFrame([(wkt,)], schema).selectExpr("ST_GeomFromText(wkt)").take(1)[0][0]
+ assert geom.equals(returned_geom)
+
+ @pytest.mark.parametrize("wkt", [
+ 'POINT ZM (1 2 3 4)',
+ 'LINESTRING ZM (1 2 3 4, 5 6 7 8)',
+ 'POLYGON ZM ((10 10 10 1, 20 20 10 1, 20 10 10 1, 10 10 10 1))',
+ 'MULTIPOINT ZM ((10 20 30 1), (40 50 60 1))',
+ 'MULTILINESTRING ZM ((10 20 30 1, 40 50 60 1), (70 80 90 1, 100 110 120 1))',
+ 'MULTIPOLYGON ZM (((10 10 10 1, 20 20 10 1, 20 10 10 1, 10 10 10 1)), ' +
+ '((0 0 0 1, 0 10 0 1, 10 10 0 1, 10 0 0 1, 0 0 0 1), (1 1 0 1, 1 2 0 1, 2 2 0 1, 2 1 0 1, 1 1 0 1)))',
+ 'GEOMETRYCOLLECTION (POINT ZM (10 20 30 1), LINESTRING ZM (10 20 30 1, 40 50 60 1))',
+ ])
+ def test_spark_serde_on_4d_geoms(self, wkt):
+ geom = wkt_loads(wkt)
+ schema = StructType().add("wkt", StringType())
+ returned_geom, n_dims = TestGeometrySerde.spark.createDataFrame([(wkt,)], schema)\
+ .selectExpr("ST_GeomFromText(wkt)", "ST_NDims(ST_GeomFromText(wkt))")\
+ .take(1)[0]
+ assert n_dims == 4
+ assert geom.equals(returned_geom)
+
+ @pytest.mark.parametrize("wkt", [
+ 'POINT M (1 2 3)',
+ 'LINESTRING M (1 2 3, 5 6 7)',
+ 'POLYGON M ((10 10 10, 20 20 10, 20 10 10, 10 10 10))',
+ 'MULTIPOINT M ((10 20 30), (40 50 60))',
+ 'MULTILINESTRING M ((10 20 30, 40 50 60), (70 80 90, 100 110 120))',
+ 'MULTIPOLYGON M (((10 10 10, 20 20 10, 20 10 10, 10 10 10)), ' +
+ '((0 0 0, 0 10 0, 10 10 0, 10 0 0, 0 0 0), (1 1 0, 1 2 0, 2 2 0, 2 1 0, 1 1 0)))',
+ 'GEOMETRYCOLLECTION (POINT M (10 20 30), LINESTRING M (10 20 30, 40 50 60))',
+ ])
+ def test_spark_serde_on_xym_geoms(self, wkt):
+ geom = wkt_loads(wkt)
+ schema = StructType().add("wkt", StringType())
+ returned_geom, n_dims, z_min = TestGeometrySerde.spark.createDataFrame([(wkt,)], schema) \
+ .withColumn("geom", expr("ST_GeomFromText(wkt)")) \
+ .selectExpr("geom", "ST_NDims(geom)", "ST_ZMin(geom)") \
+ .take(1)[0]
+ assert n_dims == 3
+ assert z_min is None
+ assert geom.equals(returned_geom)
+
def test_point(self):
points = [
wkt_loads("POINT EMPTY"),