You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@sedona.apache.org by ji...@apache.org on 2022/06/15 21:37:06 UTC
[incubator-sedona] branch master updated: [SEDONA-127] Add null safety to ST_GeomFromWKT/WKB/Text (#631)
This is an automated email from the ASF dual-hosted git repository.
jiayu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-sedona.git
The following commit(s) were added to refs/heads/master by this push:
new 308b8cfc [SEDONA-127] Add null safety to ST_GeomFromWKT/WKB/Text (#631)
308b8cfc is described below
commit 308b8cfcfa5e08491d0c9ab6fc9df6a863744dfa
Author: Martin Andersson <u....@gmail.com>
AuthorDate: Wed Jun 15 23:37:01 2022 +0200
[SEDONA-127] Add null safety to ST_GeomFromWKT/WKB/Text (#631)
---
.../sql/sedona_sql/expressions/Constructors.scala | 58 ++++++++++++----------
.../apache/sedona/sql/constructorTestScala.scala | 7 +++
2 files changed, 38 insertions(+), 27 deletions(-)
diff --git a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Constructors.scala b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Constructors.scala
index ff9aa605..af5eae44 100644
--- a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Constructors.scala
+++ b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Constructors.scala
@@ -172,14 +172,17 @@ case class ST_GeomFromWKT(inputExpressions: Seq[Expression])
// This is an expression which takes one input expressions
assert(inputExpressions.length == 1)
- override def nullable: Boolean = false
+ override def nullable: Boolean = true
override def eval(inputRow: InternalRow): Any = {
- val geomString = inputExpressions(0).eval(inputRow).asInstanceOf[UTF8String].toString
- var fileDataSplitter = FileDataSplitter.WKT
- var formatMapper = new FormatMapper(fileDataSplitter, false)
- var geometry = formatMapper.readGeometry(geomString)
- new GenericArrayData(GeometrySerializer.serialize(geometry))
+ (inputExpressions(0).eval(inputRow)) match {
+ case (geomString: UTF8String) => {
+ var fileDataSplitter = FileDataSplitter.WKT
+ var formatMapper = new FormatMapper(fileDataSplitter, false)
+ formatMapper.readGeometry(geomString.toString).toGenericArrayData
+ }
+ case _ => null
+ }
}
override def dataType: DataType = GeometryUDT
@@ -202,14 +205,17 @@ case class ST_GeomFromText(inputExpressions: Seq[Expression])
// This is an expression which takes one input expressions
assert(inputExpressions.length == 1)
- override def nullable: Boolean = false
+ override def nullable: Boolean = true
override def eval(inputRow: InternalRow): Any = {
- val geomString = inputExpressions(0).eval(inputRow).asInstanceOf[UTF8String].toString
- var fileDataSplitter = FileDataSplitter.WKT
- var formatMapper = new FormatMapper(fileDataSplitter, false)
- var geometry = formatMapper.readGeometry(geomString)
- new GenericArrayData(GeometrySerializer.serialize(geometry))
+ (inputExpressions(0).eval(inputRow)) match {
+ case (geomString: UTF8String) => {
+ var fileDataSplitter = FileDataSplitter.WKT
+ var formatMapper = new FormatMapper(fileDataSplitter, false)
+ formatMapper.readGeometry(geomString.toString).toGenericArrayData
+ }
+ case _ => null
+ }
}
override def dataType: DataType = GeometryUDT
@@ -232,23 +238,21 @@ case class ST_GeomFromWKB(inputExpressions: Seq[Expression])
// This is an expression which takes one input expressions
assert(inputExpressions.length == 1)
- override def nullable: Boolean = false
+ override def nullable: Boolean = true
override def eval(inputRow: InternalRow): Any = {
- if (inputExpressions.head.dataType.equals(StringType)) {
- // Parse UTF-8 encoded wkb string
- val geomString = inputExpressions.head.eval(inputRow).asInstanceOf[UTF8String].toString
- val fileDataSplitter = FileDataSplitter.WKB
- val formatMapper = new FormatMapper(fileDataSplitter, false)
- val geometry = formatMapper.readGeometry(geomString)
- new GenericArrayData(GeometrySerializer.serialize(geometry))
- }
- else if (inputExpressions.head.dataType.equals(BinaryType)) {
- // convert raw wkb byte array to geometry
- val wkbReader = new WKBReader()
- val wkb = inputExpressions.head.eval(inputRow).asInstanceOf[Array[Byte]]
- val geometry = wkbReader.read(wkb)
- new GenericArrayData(GeometrySerializer.serialize(geometry))
+ (inputExpressions.head.eval(inputRow)) match {
+ case (geomString: UTF8String) => {
+ // Parse UTF-8 encoded wkb string
+ val fileDataSplitter = FileDataSplitter.WKB
+ val formatMapper = new FormatMapper(fileDataSplitter, false)
+ formatMapper.readGeometry(geomString.toString).toGenericArrayData
+ }
+ case (wkb: Array[Byte]) => {
+ // convert raw wkb byte array to geometry
+ new WKBReader().read(wkb).toGenericArrayData
+ }
+ case _ => null
}
}
diff --git a/sql/src/test/scala/org/apache/sedona/sql/constructorTestScala.scala b/sql/src/test/scala/org/apache/sedona/sql/constructorTestScala.scala
index 3c754f3e..69e4e8d6 100644
--- a/sql/src/test/scala/org/apache/sedona/sql/constructorTestScala.scala
+++ b/sql/src/test/scala/org/apache/sedona/sql/constructorTestScala.scala
@@ -67,6 +67,8 @@ class constructorTestScala extends TestBaseScala {
polygonWktDf.createOrReplaceTempView("polygontable")
var polygonDf = sparkSession.sql("select ST_GeomFromWkt(polygontable._c0) as countyshape from polygontable")
assert(polygonDf.count() == 100)
+ val nullGeom = sparkSession.sql("select ST_GeomFromWKT(null)")
+ assert(nullGeom.first().isNullAt(0))
}
it("Passed ST_LineFromText") {
@@ -98,6 +100,8 @@ class constructorTestScala extends TestBaseScala {
polygonWktDf.createOrReplaceTempView("polygontable")
var polygonDf = sparkSession.sql("select ST_GeomFromText(polygontable._c0) as countyshape from polygontable")
assert(polygonDf.count() == 100)
+ val nullGeom = sparkSession.sql("select ST_GeomFromText(null)")
+ assert(nullGeom.first().isNullAt(0))
}
it("Passed ST_GeomFromWKT multipolygon read as polygon bug") {
@@ -126,6 +130,9 @@ class constructorTestScala extends TestBaseScala {
val geometries = sparkSession.sql("SELECT ST_GeomFromWKB(rawWKBTable.wkb) as countyshape from rawWKBTable")
val expectedGeom = "LINESTRING (-2.1047439575195312 -0.354827880859375, -1.49606454372406 -0.6676061153411865)";
assert(geometries.first().getAs[Geometry](0).toString.equals(expectedGeom))
+ // null input
+ val nullGeom = sparkSession.sql("SELECT ST_GeomFromWKB(null)")
+ assert(nullGeom.first().isNullAt(0))
}
it("Passed ST_GeomFromGeoJSON") {