You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@sedona.apache.org by ji...@apache.org on 2022/11/16 19:24:58 UTC

[incubator-sedona] branch master updated: [SEDONA-195] Add wkt validation and an optional srid to ST_GeomFromWKT/ST_GeomFromText (#714)

This is an automated email from the ASF dual-hosted git repository.

jiayu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-sedona.git


The following commit(s) were added to refs/heads/master by this push:
     new 01d314b3 [SEDONA-195] Add wkt validation and an optional srid to ST_GeomFromWKT/ST_GeomFromText (#714)
01d314b3 is described below

commit 01d314b3b0f93a67f3b2094dec5bf959166d2eb2
Author: Martin Andersson <u....@gmail.com>
AuthorDate: Wed Nov 16 20:24:53 2022 +0100

    [SEDONA-195] Add wkt validation and an optional srid to ST_GeomFromWKT/ST_GeomFromText (#714)
---
 .../org/apache/sedona/common/Constructors.java     | 18 +++++++
 .../org/apache/sedona/common/ConstructorsTest.java | 26 ++++++++++
 docs/api/sql/Constructor.md                        | 10 +++-
 .../sedona/flink/expressions/Constructors.java     |  4 +-
 .../scala/org/apache/sedona/sql/UDF/Catalog.scala  |  4 +-
 .../sql/sedona_sql/expressions/Constructors.scala  | 55 ++--------------------
 .../sedona_sql/expressions/st_constructors.scala   | 16 +++++--
 .../apache/sedona/sql/constructorTestScala.scala   |  8 ++++
 .../apache/sedona/sql/dataFrameAPITestScala.scala  | 14 ++++++
 9 files changed, 95 insertions(+), 60 deletions(-)

diff --git a/common/src/main/java/org/apache/sedona/common/Constructors.java b/common/src/main/java/org/apache/sedona/common/Constructors.java
new file mode 100644
index 00000000..f85a5b56
--- /dev/null
+++ b/common/src/main/java/org/apache/sedona/common/Constructors.java
@@ -0,0 +1,33 @@
+package org.apache.sedona.common;
+
+import org.locationtech.jts.geom.Geometry;
+import org.locationtech.jts.geom.GeometryFactory;
+import org.locationtech.jts.geom.PrecisionModel;
+import org.locationtech.jts.io.ParseException;
+import org.locationtech.jts.io.WKTReader;
+
+/**
+ * Geometry constructor functions shared by the Spark SQL and Flink bindings.
+ */
+public final class Constructors {
+
+    private Constructors() {
+        // static utility class; not meant to be instantiated
+    }
+
+    /**
+     * Parses a WKT string into a JTS {@link Geometry} carrying the given SRID.
+     *
+     * @param wkt  the Well-Known Text to parse; may be {@code null}
+     * @param srid the spatial reference id assigned to the result (0 means unknown)
+     * @return the parsed geometry, or {@code null} when {@code wkt} is {@code null}
+     * @throws ParseException if {@code wkt} is not valid WKT
+     */
+    public static Geometry geomFromWKT(String wkt, int srid) throws ParseException {
+        if (wkt == null) {
+            return null;
+        }
+        GeometryFactory geometryFactory = new GeometryFactory(new PrecisionModel(), srid);
+        return new WKTReader(geometryFactory).read(wkt);
+    }
+}
diff --git a/common/src/test/java/org/apache/sedona/common/ConstructorsTest.java b/common/src/test/java/org/apache/sedona/common/ConstructorsTest.java
new file mode 100644
index 00000000..9d2b3b4e
--- /dev/null
+++ b/common/src/test/java/org/apache/sedona/common/ConstructorsTest.java
@@ -0,0 +1,39 @@
+package org.apache.sedona.common;
+
+import org.junit.Test;
+import org.locationtech.jts.geom.Geometry;
+import org.locationtech.jts.io.ParseException;
+
+import static org.junit.Assert.*;
+
+/**
+ * Unit tests for {@link Constructors}.
+ */
+public class ConstructorsTest {
+
+    @Test
+    public void geomFromWKTReturnsNullForNullInput() throws ParseException {
+        assertNull(Constructors.geomFromWKT(null, 0));
+    }
+
+    @Test
+    public void geomFromWKTKeepsZeroSrid() throws ParseException {
+        Geometry point = Constructors.geomFromWKT("POINT (1 1)", 0);
+        assertEquals(0, point.getSRID());
+        assertEquals("POINT (1 1)", point.toText());
+    }
+
+    @Test
+    public void geomFromWKTAppliesGivenSrid() throws ParseException {
+        Geometry point = Constructors.geomFromWKT("POINT (1 1)", 3006);
+        assertEquals(3006, point.getSRID());
+        assertEquals("POINT (1 1)", point.toText());
+    }
+
+    @Test
+    public void geomFromWKTRejectsInvalidText() {
+        ParseException invalid = assertThrows(ParseException.class, () -> Constructors.geomFromWKT("not valid", 0));
+        // The wording is produced by the JTS WKTReader that geomFromWKT delegates to
+        assertEquals("Unknown geometry type: NOT (line 1)", invalid.getMessage());
+    }
+}
\ No newline at end of file
diff --git a/docs/api/sql/Constructor.md b/docs/api/sql/Constructor.md
index 903b4164..39ae94a8 100644
--- a/docs/api/sql/Constructor.md
+++ b/docs/api/sql/Constructor.md
@@ -120,13 +120,16 @@ SELECT ST_GeomFromKML('<LineString><coordinates>-71.1663,42.2614 -71.1667,42.261
 
 ## ST_GeomFromText
 
-Introduction: Construct a Geometry from Wkt. Alias of [ST_GeomFromWKT](#ST_GeomFromWKT)
+Introduction: Construct a Geometry from Wkt. If srid is not set, it defaults to 0 (unknown). Alias of [ST_GeomFromWKT](#ST_GeomFromWKT)
 
 Format:
 `ST_GeomFromText (Wkt:string)`
+`ST_GeomFromText (Wkt:string, srid:integer)`
 
 Since: `v1.0.0`
 
+The optional `srid` parameter was added in `v1.3.1`.
+
 Spark SQL example:
 ```SQL
 SELECT ST_GeomFromText('POINT(40.7128 -74.0060)') AS geometry
@@ -150,13 +153,16 @@ FROM polygontable
 
 ## ST_GeomFromWKT
 
-Introduction: Construct a Geometry from Wkt
+Introduction: Construct a Geometry from Wkt. If srid is not set, it defaults to 0 (unknown).
 
 Format:
 `ST_GeomFromWKT (Wkt:string)`
+`ST_GeomFromWKT (Wkt:string, srid:integer)`
 
 Since: `v1.0.0`
 
+The optional `srid` parameter was added in `v1.3.1`.
+
 Spark SQL example:
 ```SQL
 SELECT ST_GeomFromWKT(polygontable._c0) AS polygonshape
diff --git a/flink/src/main/java/org/apache/sedona/flink/expressions/Constructors.java b/flink/src/main/java/org/apache/sedona/flink/expressions/Constructors.java
index 0f62291d..f7d60000 100644
--- a/flink/src/main/java/org/apache/sedona/flink/expressions/Constructors.java
+++ b/flink/src/main/java/org/apache/sedona/flink/expressions/Constructors.java
@@ -120,14 +120,32 @@ public class Constructors {
     public static class ST_GeomFromWKT extends ScalarFunction {
         @DataTypeHint(value = "RAW", bridgedTo = org.locationtech.jts.geom.Geometry.class)
         public Geometry eval(@DataTypeHint("String") String wktString) throws ParseException {
-            return getGeometryByFileData(wktString, FileDataSplitter.WKT);
+            return org.apache.sedona.common.Constructors.geomFromWKT(wktString, 0);
         }
+
+        /**
+         * Parses {@code wktString} and assigns {@code srid} to the result; returns null when either argument is null.
+         */
+        @DataTypeHint(value = "RAW", bridgedTo = org.locationtech.jts.geom.Geometry.class)
+        public Geometry eval(@DataTypeHint("String") String wktString, @DataTypeHint("Int") Integer srid) throws ParseException {
+            // A null SRID propagates as null instead of failing while unboxing
+            return srid == null ? null : org.apache.sedona.common.Constructors.geomFromWKT(wktString, srid);
+        }
     }
 
     public static class ST_GeomFromText extends ScalarFunction {
         @DataTypeHint(value = "RAW", bridgedTo = org.locationtech.jts.geom.Geometry.class)
         public Geometry eval(@DataTypeHint("String") String wktString) throws ParseException {
-            return new ST_GeomFromWKT().eval(wktString);
+            return org.apache.sedona.common.Constructors.geomFromWKT(wktString, 0);
         }
+
+        /**
+         * Alias of the two-argument ST_GeomFromWKT eval: parses {@code wktString} with the given {@code srid}.
+         */
+        @DataTypeHint(value = "RAW", bridgedTo = org.locationtech.jts.geom.Geometry.class)
+        public Geometry eval(@DataTypeHint("String") String wktString, @DataTypeHint("Int") Integer srid) throws ParseException {
+            // A null SRID propagates as null instead of failing while unboxing
+            return srid == null ? null : org.apache.sedona.common.Constructors.geomFromWKT(wktString, srid);
+        }
     }
 
diff --git a/sql/src/main/scala/org/apache/sedona/sql/UDF/Catalog.scala b/sql/src/main/scala/org/apache/sedona/sql/UDF/Catalog.scala
index 15a6239b..66c0ae70 100644
--- a/sql/src/main/scala/org/apache/sedona/sql/UDF/Catalog.scala
+++ b/sql/src/main/scala/org/apache/sedona/sql/UDF/Catalog.scala
@@ -39,9 +39,9 @@ object Catalog {
     function[ST_PointFromText](),
     function[ST_PolygonFromText](),
     function[ST_LineStringFromText](),
-    function[ST_GeomFromText](),
+    function[ST_GeomFromText](0),
     function[ST_LineFromText](),
-    function[ST_GeomFromWKT](),
+    function[ST_GeomFromWKT](0),
     function[ST_GeomFromWKB](),
     function[ST_GeomFromGeoJSON](),
     function[ST_GeomFromGML](),
diff --git a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Constructors.scala b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Constructors.scala
index bee7a814..6e4909d0 100644
--- a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Constructors.scala
+++ b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Constructors.scala
@@ -34,9 +34,8 @@ import org.locationtech.jts.geom.{Coordinate, GeometryFactory}
 import org.locationtech.jts.io.WKBReader
 import org.locationtech.jts.io.gml2.GMLReader
 import org.locationtech.jts.io.kml.KMLReader
-import org.apache.spark.sql.catalyst.expressions.ExpectsInputTypes
 import org.apache.spark.sql.catalyst.expressions.ImplicitCastInputTypes
-import org.apache.commons.collections.collection.TypedCollection
+import org.apache.sedona.common.Constructors
 import org.apache.spark.sql.types.BinaryType
 
 /**
@@ -179,31 +178,10 @@ case class ST_LineStringFromText(inputExpressions: Seq[Expression])
 /**
   * Return a Geometry from a WKT string
   *
-  * @param inputExpressions This function takes 1 parameter which is the geometry string. The string format must be WKT.
+  * @param inputExpressions This function takes a geometry string and a srid. The string format must be WKT.
   */
 case class ST_GeomFromWKT(inputExpressions: Seq[Expression])
-  extends Expression with FoldableExpression with ImplicitCastInputTypes with CodegenFallback with UserDataGeneratator {
-  // This is an expression which takes one input expressions
-  assert(inputExpressions.length == 1)
-
-  override def nullable: Boolean = true
-
-  override def eval(inputRow: InternalRow): Any = {
-    (inputExpressions(0).eval(inputRow)) match {
-      case (geomString: UTF8String) => {
-        var fileDataSplitter = FileDataSplitter.WKT
-        var formatMapper = new FormatMapper(fileDataSplitter, false)
-        formatMapper.readGeometry(geomString.toString).toGenericArrayData
-      }
-      case null => null
-    }
-  }
-
-  override def dataType: DataType = GeometryUDT
-
-  override def inputTypes: Seq[AbstractDataType] = Seq(StringType)
-
-  override def children: Seq[Expression] = inputExpressions
+  extends InferredBinaryExpression(Constructors.geomFromWKT) with FoldableExpression {
 
   protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = {
     copy(inputExpressions = newChildren)
@@ -214,33 +192,10 @@ case class ST_GeomFromWKT(inputExpressions: Seq[Expression])
 /**
   * Return a Geometry from a WKT string
   *
-  * @param inputExpressions This function takes 1 parameter which is the geometry string. The string format must be WKT.
+  * @param inputExpressions This function takes a geometry string and a srid. The string format must be WKT.
   */
 case class ST_GeomFromText(inputExpressions: Seq[Expression])
-  extends Expression with FoldableExpression with ImplicitCastInputTypes with CodegenFallback with UserDataGeneratator with ExpectsInputTypes {
-  // This is an expression which takes one input expressions
-  assert(inputExpressions.length == 1)
-
-  override def nullable: Boolean = true
-
-  override def foldable: Boolean = inputExpressions.forall(_.foldable)
-
-  override def eval(inputRow: InternalRow): Any = {
-    (inputExpressions(0).eval(inputRow)) match {
-      case (geomString: UTF8String) => {
-        var fileDataSplitter = FileDataSplitter.WKT
-        var formatMapper = new FormatMapper(fileDataSplitter, false)
-        formatMapper.readGeometry(geomString.toString).toGenericArrayData
-      }
-      case null => null
-    }
-  }
-
-  override def dataType: DataType = GeometryUDT
-
-  override def inputTypes: Seq[AbstractDataType] = Seq(StringType)
-
-  override def children: Seq[Expression] = inputExpressions
+  extends InferredBinaryExpression(Constructors.geomFromWKT) with FoldableExpression {
 
   protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = {
     copy(inputExpressions = newChildren)
diff --git a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/st_constructors.scala b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/st_constructors.scala
index 9f1cee01..2a38b0a4 100644
--- a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/st_constructors.scala
+++ b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/st_constructors.scala
@@ -34,14 +34,22 @@ object st_constructors extends DataFrameAPI {
   def ST_GeomFromKML(kmlString: Column): Column = wrapExpression[ST_GeomFromKML](kmlString)
   def ST_GeomFromKML(kmlString: String): Column = wrapExpression[ST_GeomFromKML](kmlString)
 
-  def ST_GeomFromText(wkt: Column): Column = wrapExpression[ST_GeomFromText](wkt)
-  def ST_GeomFromText(wkt: String): Column = wrapExpression[ST_GeomFromText](wkt)
+  def ST_GeomFromText(wkt: Column): Column = wrapExpression[ST_GeomFromText](wkt, 0)
+  def ST_GeomFromText(wkt: String): Column = ST_GeomFromText(wkt, 0)
+
+  def ST_GeomFromText(wkt: Column, srid: Column): Column = wrapExpression[ST_GeomFromText](wkt, srid)
+
+  def ST_GeomFromText(wkt: String, srid: Int): Column = wrapExpression[ST_GeomFromText](wkt, srid)
 
   def ST_GeomFromWKB(wkb: Column): Column = wrapExpression[ST_GeomFromWKB](wkb)
   def ST_GeomFromWKB(wkb: String): Column = wrapExpression[ST_GeomFromWKB](wkb)
 
-  def ST_GeomFromWKT(wkt: Column): Column = wrapExpression[ST_GeomFromWKT](wkt)
-  def ST_GeomFromWKT(wkt: String): Column = wrapExpression[ST_GeomFromWKT](wkt)
+  def ST_GeomFromWKT(wkt: Column): Column = wrapExpression[ST_GeomFromWKT](wkt, 0)
+  def ST_GeomFromWKT(wkt: String): Column = ST_GeomFromWKT(wkt, 0)
+
+  def ST_GeomFromWKT(wkt: Column, srid: Column): Column = wrapExpression[ST_GeomFromWKT](wkt, srid)
+
+  def ST_GeomFromWKT(wkt: String, srid: Int): Column = wrapExpression[ST_GeomFromWKT](wkt, srid)
 
   def ST_LineFromText(wkt: Column): Column = wrapExpression[ST_LineFromText](wkt)
   def ST_LineFromText(wkt: String): Column = wrapExpression[ST_LineFromText](wkt)
diff --git a/sql/src/test/scala/org/apache/sedona/sql/constructorTestScala.scala b/sql/src/test/scala/org/apache/sedona/sql/constructorTestScala.scala
index 4a798cc2..8f794065 100644
--- a/sql/src/test/scala/org/apache/sedona/sql/constructorTestScala.scala
+++ b/sql/src/test/scala/org/apache/sedona/sql/constructorTestScala.scala
@@ -75,6 +75,14 @@ class constructorTestScala extends TestBaseScala {
       }
     }
 
+    it("Passed ST_GeomFromWKT invalid input") {
+      // Fail on non wkt strings
+      val thrown = intercept[Exception] {
+        sparkSession.sql("SELECT ST_GeomFromWKT('not wkt')").collect()
+      }
+      assert(thrown.getMessage == "Unknown geometry type: NOT (line 1)")
+    }
+
     it("Passed ST_LineFromText") {
       val geometryDf = Seq("Linestring(1 2, 3 4)").map(wkt => Tuple1(wkt)).toDF("geom")
       geometryDf.createOrReplaceTempView("linetable")
diff --git a/sql/src/test/scala/org/apache/sedona/sql/dataFrameAPITestScala.scala b/sql/src/test/scala/org/apache/sedona/sql/dataFrameAPITestScala.scala
index 88b23a9a..47790e55 100644
--- a/sql/src/test/scala/org/apache/sedona/sql/dataFrameAPITestScala.scala
+++ b/sql/src/test/scala/org/apache/sedona/sql/dataFrameAPITestScala.scala
@@ -78,6 +78,13 @@ class dataFrameAPITestScala extends TestBaseScala {
       assert(actualResult == expectedResult)
     }
 
+    it("passed st_geomfromwkt with srid") {
+      val df = sparkSession.sql("SELECT 'POINT(0.0 1.0)' AS wkt").select(ST_GeomFromWKT("wkt", 4326))
+      val actualResult = df.take(1)(0).get(0).asInstanceOf[Geometry]
+      assert(actualResult.toText == "POINT (0 1)")
+      assert(actualResult.getSRID == 4326)
+    }
+
     it("passed st_geomfromtext") {
       val df = sparkSession.sql("SELECT 'POINT(0.0 1.0)' AS wkt").select(ST_GeomFromText("wkt"))
       val actualResult = df.take(1)(0).get(0).asInstanceOf[Geometry].toText()
@@ -85,6 +92,13 @@ class dataFrameAPITestScala extends TestBaseScala {
       assert(actualResult == expectedResult)
     }
 
+    it("passed st_geomfromtext with srid") {
+      val df = sparkSession.sql("SELECT 'POINT(0.0 1.0)' AS wkt").select(ST_GeomFromText("wkt", 4326))
+      val actualResult = df.take(1)(0).get(0).asInstanceOf[Geometry]
+      assert(actualResult.toText == "POINT (0 1)")
+      assert(actualResult.getSRID == 4326)
+    }
+
     it("passed st_geomfromwkb") {
       val wkbSeq = Seq[Array[Byte]](Array[Byte](1, 2, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, -124, -42, 0, -64, 0, 0, 0, 0, -128, -75, -42, -65, 0, 0, 0, 96, -31, -17, -9, -65, 0, 0, 0, -128, 7, 93, -27, -65))
       val df = wkbSeq.toDF("wkb").select(ST_GeomFromWKB("wkb"))