You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@sedona.apache.org by ji...@apache.org on 2022/06/15 21:37:06 UTC

[incubator-sedona] branch master updated: [SEDONA-127] Add null safety to ST_GeomFromWKT/WKB/Text (#631)

This is an automated email from the ASF dual-hosted git repository.

jiayu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-sedona.git


The following commit(s) were added to refs/heads/master by this push:
     new 308b8cfc [SEDONA-127] Add null safety to ST_GeomFromWKT/WKB/Text (#631)
308b8cfc is described below

commit 308b8cfcfa5e08491d0c9ab6fc9df6a863744dfa
Author: Martin Andersson <u....@gmail.com>
AuthorDate: Wed Jun 15 23:37:01 2022 +0200

    [SEDONA-127] Add null safety to ST_GeomFromWKT/WKB/Text (#631)
---
 .../sql/sedona_sql/expressions/Constructors.scala  | 58 ++++++++++++----------
 .../apache/sedona/sql/constructorTestScala.scala   |  7 +++
 2 files changed, 38 insertions(+), 27 deletions(-)

diff --git a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Constructors.scala b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Constructors.scala
index ff9aa605..af5eae44 100644
--- a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Constructors.scala
+++ b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Constructors.scala
@@ -172,14 +172,17 @@ case class ST_GeomFromWKT(inputExpressions: Seq[Expression])
   // This is an expression which takes one input expressions
   assert(inputExpressions.length == 1)
 
-  override def nullable: Boolean = false
+  override def nullable: Boolean = true
 
   override def eval(inputRow: InternalRow): Any = {
-    val geomString = inputExpressions(0).eval(inputRow).asInstanceOf[UTF8String].toString
-    var fileDataSplitter = FileDataSplitter.WKT
-    var formatMapper = new FormatMapper(fileDataSplitter, false)
-    var geometry = formatMapper.readGeometry(geomString)
-    new GenericArrayData(GeometrySerializer.serialize(geometry))
+    (inputExpressions(0).eval(inputRow)) match {
+      case (geomString: UTF8String) => {
+        var fileDataSplitter = FileDataSplitter.WKT
+        var formatMapper = new FormatMapper(fileDataSplitter, false)
+        formatMapper.readGeometry(geomString.toString).toGenericArrayData
+      }
+      case _ => null
+    }
   }
 
   override def dataType: DataType = GeometryUDT
@@ -202,14 +205,17 @@ case class ST_GeomFromText(inputExpressions: Seq[Expression])
   // This is an expression which takes one input expressions
   assert(inputExpressions.length == 1)
 
-  override def nullable: Boolean = false
+  override def nullable: Boolean = true
 
   override def eval(inputRow: InternalRow): Any = {
-    val geomString = inputExpressions(0).eval(inputRow).asInstanceOf[UTF8String].toString
-    var fileDataSplitter = FileDataSplitter.WKT
-    var formatMapper = new FormatMapper(fileDataSplitter, false)
-    var geometry = formatMapper.readGeometry(geomString)
-    new GenericArrayData(GeometrySerializer.serialize(geometry))
+    (inputExpressions(0).eval(inputRow)) match {
+      case (geomString: UTF8String) => {
+        var fileDataSplitter = FileDataSplitter.WKT
+        var formatMapper = new FormatMapper(fileDataSplitter, false)
+        formatMapper.readGeometry(geomString.toString).toGenericArrayData
+      }
+      case _ => null
+    }
   }
 
   override def dataType: DataType = GeometryUDT
@@ -232,23 +238,21 @@ case class ST_GeomFromWKB(inputExpressions: Seq[Expression])
   // This is an expression which takes one input expressions
   assert(inputExpressions.length == 1)
 
-  override def nullable: Boolean = false
+  override def nullable: Boolean = true
 
   override def eval(inputRow: InternalRow): Any = {
-    if (inputExpressions.head.dataType.equals(StringType)) {
-      // Parse UTF-8 encoded wkb string
-      val geomString = inputExpressions.head.eval(inputRow).asInstanceOf[UTF8String].toString
-      val fileDataSplitter = FileDataSplitter.WKB
-      val formatMapper = new FormatMapper(fileDataSplitter, false)
-      val geometry = formatMapper.readGeometry(geomString)
-      new GenericArrayData(GeometrySerializer.serialize(geometry))
-    }
-    else if (inputExpressions.head.dataType.equals(BinaryType)) {
-      // convert raw wkb byte array to geometry
-      val wkbReader = new WKBReader()
-      val wkb = inputExpressions.head.eval(inputRow).asInstanceOf[Array[Byte]]
-      val geometry = wkbReader.read(wkb)
-      new GenericArrayData(GeometrySerializer.serialize(geometry))
+    (inputExpressions.head.eval(inputRow)) match {
+      case (geomString: UTF8String) => {
+        // Parse UTF-8 encoded wkb string
+        val fileDataSplitter = FileDataSplitter.WKB
+        val formatMapper = new FormatMapper(fileDataSplitter, false)
+        formatMapper.readGeometry(geomString.toString).toGenericArrayData
+      }
+      case (wkb: Array[Byte]) => {
+        // convert raw wkb byte array to geometry
+        new WKBReader().read(wkb).toGenericArrayData
+      }
+      case _ => null
     }
   }
 
diff --git a/sql/src/test/scala/org/apache/sedona/sql/constructorTestScala.scala b/sql/src/test/scala/org/apache/sedona/sql/constructorTestScala.scala
index 3c754f3e..69e4e8d6 100644
--- a/sql/src/test/scala/org/apache/sedona/sql/constructorTestScala.scala
+++ b/sql/src/test/scala/org/apache/sedona/sql/constructorTestScala.scala
@@ -67,6 +67,8 @@ class constructorTestScala extends TestBaseScala {
       polygonWktDf.createOrReplaceTempView("polygontable")
       var polygonDf = sparkSession.sql("select ST_GeomFromWkt(polygontable._c0) as countyshape from polygontable")
       assert(polygonDf.count() == 100)
+      val nullGeom = sparkSession.sql("select ST_GeomFromWKT(null)")
+      assert(nullGeom.first().isNullAt(0))
     }
 
     it("Passed ST_LineFromText") {
@@ -98,6 +100,8 @@ class constructorTestScala extends TestBaseScala {
       polygonWktDf.createOrReplaceTempView("polygontable")
       var polygonDf = sparkSession.sql("select ST_GeomFromText(polygontable._c0) as countyshape from polygontable")
       assert(polygonDf.count() == 100)
+      val nullGeom = sparkSession.sql("select ST_GeomFromText(null)")
+      assert(nullGeom.first().isNullAt(0))
     }
 
     it("Passed ST_GeomFromWKT multipolygon read as polygon bug") {
@@ -126,6 +130,9 @@ class constructorTestScala extends TestBaseScala {
       val geometries = sparkSession.sql("SELECT ST_GeomFromWKB(rawWKBTable.wkb) as countyshape from rawWKBTable")
       val expectedGeom = "LINESTRING (-2.1047439575195312 -0.354827880859375, -1.49606454372406 -0.6676061153411865)";
       assert(geometries.first().getAs[Geometry](0).toString.equals(expectedGeom))
+      // null input
+      val nullGeom = sparkSession.sql("SELECT ST_GeomFromWKB(null)")
+      assert(nullGeom.first().isNullAt(0))
     }
 
     it("Passed ST_GeomFromGeoJSON") {