You are viewing a plain-text version of this content. The canonical version was linked from the original archive page; that hyperlink is not preserved in this plain-text rendering.
Posted to commits@sedona.apache.org by ji...@apache.org on 2022/02/23 06:39:20 UTC

[incubator-sedona] branch master updated: [SEDONA-82] Fixes in ST_Difference and ST_SymDifference (#584)

This is an automated email from the ASF dual-hosted git repository.

jiayu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-sedona.git


The following commit(s) were added to refs/heads/master by this push:
     new 754836e  [SEDONA-82] Fixes in ST_Difference and ST_SymDifference (#584)
754836e is described below

commit 754836e2da847ab5a187546fafe8f5e46980a7dd
Author: Magdalena <69...@users.noreply.github.com>
AuthorDate: Wed Feb 23 07:39:10 2022 +0100

    [SEDONA-82] Fixes in ST_Difference and ST_SymDifference (#584)
---
 .../sql/sedona_sql/expressions/Functions.scala     | 26 +++++++++-------------
 .../org/apache/sedona/sql/functionTestScala.scala  | 12 ++++++++--
 2 files changed, 21 insertions(+), 17 deletions(-)

diff --git a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala
index 08916f9..4f29dbf 100644
--- a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala
+++ b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala
@@ -1444,6 +1444,11 @@ case class ST_GeoHash(inputExpressions: Seq[Expression])
   }
 }
 
+/**
+ * Return the difference between geometry A and B
+ *
+ * @param inputExpressions
+ */
 case class ST_Difference(inputExpressions: Seq[Expression])
   extends BinaryGeometryExpression with CodegenFallback {
   assert(inputExpressions.length == 2)
@@ -1456,14 +1461,14 @@ case class ST_Difference(inputExpressions: Seq[Expression])
     lazy val isRightContainsLeft = rightGeometry.contains(leftGeometry)
 
     if (!isIntersects) {
-      return new GenericArrayData(GeometrySerializer.serialize(leftGeometry))
+      new GenericArrayData(GeometrySerializer.serialize(leftGeometry))
     }
 
     if (isIntersects && isRightContainsLeft) {
-      return new GenericArrayData(GeometrySerializer.serialize(emptyPolygon))
+      new GenericArrayData(GeometrySerializer.serialize(emptyPolygon))
     }
 
-    return new GenericArrayData(GeometrySerializer.serialize(leftGeometry.difference(rightGeometry)))
+    new GenericArrayData(GeometrySerializer.serialize(leftGeometry.difference(rightGeometry)))
   }
 
   override def dataType: DataType = GeometryUDT
@@ -1481,20 +1486,11 @@ case class ST_Difference(inputExpressions: Seq[Expression])
  * @param inputExpressions
  */
 case class ST_SymDifference(inputExpressions: Seq[Expression])
-  extends Expression with CodegenFallback {
+  extends BinaryGeometryExpression with CodegenFallback {
   assert(inputExpressions.length == 2)
 
-  override def nullable: Boolean = true
-
-  override def eval(input: InternalRow): Any = {
-    val leftGeometry = inputExpressions(0).toGeometry(input)
-    val rightGeometry = inputExpressions(1).toGeometry(input)
-
-    (leftGeometry, rightGeometry) match {
-      case (leftGeometry: Geometry, rightGeometry: Geometry)
-      => new GenericArrayData(GeometrySerializer.serialize(leftGeometry.symDifference(rightGeometry)))
-      case _ => null
-    }
+  override protected def nullSafeEval(leftGeometry: Geometry, rightGeometry: Geometry): Any = {
+    new GenericArrayData(GeometrySerializer.serialize(leftGeometry.symDifference(rightGeometry)))
   }
 
   override def dataType: DataType = GeometryUDT
diff --git a/sql/src/test/scala/org/apache/sedona/sql/functionTestScala.scala b/sql/src/test/scala/org/apache/sedona/sql/functionTestScala.scala
index 779b022..109628c 100644
--- a/sql/src/test/scala/org/apache/sedona/sql/functionTestScala.scala
+++ b/sql/src/test/scala/org/apache/sedona/sql/functionTestScala.scala
@@ -357,10 +357,10 @@ class functionTestScala extends TestBaseScala with Matchers with GeometrySample
 
     it("Passed ST_Difference - right not overlaps left") {
 
-      val testtable = sparkSession.sql("select ST_GeomFromWKT('POLYGON ((-3 -3, 3 -3, 3 3, -3 3, -3 -3))') as a,ST_GeomFromWKT('POLYGON ((5 -3, 7 -3, 7 -1, 5 -1, 5 -3))') as b")
+      val testtable = sparkSession.sql("select ST_GeomFromWKT('POLYGON ((-3 -3, -3 3, 3 3, 3 -3, -3 -3))') as a,ST_GeomFromWKT('POLYGON ((5 -3, 7 -3, 7 -1, 5 -1, 5 -3))') as b")
       testtable.createOrReplaceTempView("testtable")
       val diff = sparkSession.sql("select ST_Difference(a,b) from testtable")
-      assert(diff.take(1)(0).get(0).asInstanceOf[Geometry].toText.equals("POLYGON ((-3 -3, 3 -3, 3 3, -3 3, -3 -3))"))
+      assert(diff.take(1)(0).get(0).asInstanceOf[Geometry].toText.equals("POLYGON ((-3 -3, -3 3, 3 3, 3 -3, -3 -3))"))
     }
 
     it("Passed ST_Difference - left contains right") {
@@ -379,6 +379,14 @@ class functionTestScala extends TestBaseScala with Matchers with GeometrySample
       assert(diff.take(1)(0).get(0).asInstanceOf[Geometry].toText.equals("POLYGON EMPTY"))
     }
 
+    it("Passed ST_Difference - one null") {
+
+      val testtable = sparkSession.sql("select ST_GeomFromWKT('POLYGON ((-3 -3, 3 -3, 3 3, -3 3, -3 -3))') as a")
+      testtable.createOrReplaceTempView("testtable")
+      val diff = sparkSession.sql("select ST_Difference(a,null) from testtable")
+      assert(diff.first().get(0) == null)
+    }
+
     it("Passed ST_SymDifference - part of right overlaps left") {
 
       val testtable = sparkSession.sql("select ST_GeomFromWKT('POLYGON ((-1 -1, 1 -1, 1 1, -1 1, -1 -1))') as a,ST_GeomFromWKT('POLYGON ((0 -2, 2 -2, 2 0, 0 0, 0 -2))') as b")