You are viewing a plain-text version of this content. The canonical link to the original message was a hyperlink that is not preserved in this plain-text copy.
Posted to commits@sedona.apache.org by ji...@apache.org on 2022/02/06 23:48:20 UTC

[incubator-sedona] branch master updated: [SEDONA-4] Handle nulls in SQL functions (#578)

This is an automated email from the ASF dual-hosted git repository.

jiayu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-sedona.git


The following commit(s) were added to refs/heads/master by this push:
     new e1eab69  [SEDONA-4] Handle nulls in SQL functions (#578)
e1eab69 is described below

commit e1eab692a1edf41bba5988f56f88b60b0c9cdb3d
Author: Kurtis Seebaldt <ku...@srpatx.com>
AuthorDate: Sun Feb 6 17:48:13 2022 -0600

    [SEDONA-4] Handle nulls in SQL functions (#578)
---
 .../sql/sedona_sql/expressions/Functions.scala     | 371 +++++++--------------
 .../expressions/NullSafeExpressions.scala          |  57 ++++
 .../sql/sedona_sql/expressions/implicits.scala     |   8 +
 .../org/apache/sedona/sql/functionTestScala.scala  | 105 ++++++
 4 files changed, 295 insertions(+), 246 deletions(-)

diff --git a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala
index 77f92d2..7b5a011 100644
--- a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala
+++ b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala
@@ -49,8 +49,6 @@ import org.opengis.referencing.operation.MathTransform
 import org.wololo.jts2geojson.GeoJSONWriter
 
 import java.nio.ByteOrder
-import java.util
-import scala.collection.mutable.ArrayBuffer
 import scala.util.{Failure, Success, Try}
 
 /**
@@ -59,21 +57,12 @@ import scala.util.{Failure, Success, Try}
   * @param inputExpressions This function takes two geometries and calculates the distance between two objects.
   */
 case class ST_Distance(inputExpressions: Seq[Expression])
-  extends Expression with CodegenFallback {
+  extends BinaryGeometryExpression with CodegenFallback {
   assert(inputExpressions.length == 2)
 
-  override def nullable: Boolean = false
-
   override def toString: String = s" **${ST_Distance.getClass.getName}**  "
 
-  override def eval(inputRow: InternalRow): Any = {
-    val leftArray = inputExpressions(0).eval(inputRow).asInstanceOf[ArrayData]
-    val rightArray = inputExpressions(1).eval(inputRow).asInstanceOf[ArrayData]
-
-    val leftGeometry = GeometrySerializer.deserialize(leftArray)
-
-    val rightGeometry = GeometrySerializer.deserialize(rightArray)
-
+  override def nullSafeEval(leftGeometry: Geometry, rightGeometry: Geometry): Any = {
     leftGeometry.distance(rightGeometry)
   }
 
@@ -87,21 +76,13 @@ case class ST_Distance(inputExpressions: Seq[Expression])
 }
 
 case class ST_3DDistance(inputExpressions: Seq[Expression])
-  extends Expression with CodegenFallback {
+  extends BinaryGeometryExpression with CodegenFallback {
   assert(inputExpressions.length == 2)
 
-  override def nullable: Boolean = true
-
   override def toString: String = s" **${ST_3DDistance.getClass.getName}**  "
 
-  override def eval(inputRow: InternalRow): Any = {
-    val leftGeometry = inputExpressions(0).toGeometry(inputRow)
-    val rightGeometry = inputExpressions(1).toGeometry(inputRow)
-
-    (leftGeometry, rightGeometry) match {
-      case (leftGeometry: Geometry, rightGeometry: Geometry) => Distance3DOp.distance(leftGeometry, rightGeometry)
-      case _ => null
-    }
+  override def nullSafeEval(leftGeometry: Geometry, rightGeometry: Geometry): Any = {
+    Distance3DOp.distance(leftGeometry, rightGeometry)
   }
 
   override def dataType = DoubleType
@@ -119,13 +100,10 @@ case class ST_3DDistance(inputExpressions: Seq[Expression])
   * @param inputExpressions
   */
 case class ST_ConvexHull(inputExpressions: Seq[Expression])
-  extends Expression with CodegenFallback {
+  extends UnaryGeometryExpression with CodegenFallback {
   assert(inputExpressions.length == 1)
 
-  override def nullable: Boolean = false
-
-  override def eval(input: InternalRow): Any = {
-    val geometry = GeometrySerializer.deserialize(inputExpressions.head.eval(input).asInstanceOf[ArrayData])
+  override def nullSafeEval(geometry: Geometry): Any = {
     new GenericArrayData(GeometrySerializer.serialize(geometry.convexHull()))
   }
 
@@ -144,16 +122,10 @@ case class ST_ConvexHull(inputExpressions: Seq[Expression])
   * @param inputExpressions
   */
 case class ST_NPoints(inputExpressions: Seq[Expression])
-  extends Expression with CodegenFallback {
-  override def nullable: Boolean = false
+  extends UnaryGeometryExpression with CodegenFallback {
 
-  override def eval(input: InternalRow): Any = {
-    inputExpressions.length match {
-      case 1 =>
-        val geometry = GeometrySerializer.deserialize(inputExpressions.head.eval(input).asInstanceOf[ArrayData])
-        geometry.getCoordinates.length
-      case _ => None
-    }
+  override def nullSafeEval(geometry: Geometry): Any = {
+    geometry.getCoordinates.length
   }
 
   override def dataType: DataType = IntegerType
@@ -174,16 +146,18 @@ case class ST_Buffer(inputExpressions: Seq[Expression])
   extends Expression with CodegenFallback {
   assert(inputExpressions.length == 2)
 
-  override def nullable: Boolean = false
+  override def nullable: Boolean = true
 
   override def eval(input: InternalRow): Any = {
-    val geometry = GeometrySerializer.deserialize(inputExpressions(0).eval(input).asInstanceOf[ArrayData])
     val buffer: Double = inputExpressions(1).eval(input) match {
       case a: Decimal => a.toDouble
       case a: Double => a
       case a: Int => a
     }
-    new GenericArrayData(GeometrySerializer.serialize(geometry.buffer(buffer)))
+    inputExpressions(0).toGeometry(input) match {
+      case geometry: Geometry => geometry.buffer(buffer).toGenericArrayData
+      case _ => null
+    }
   }
 
   override def dataType: DataType = GeometryUDT
@@ -202,13 +176,10 @@ case class ST_Buffer(inputExpressions: Seq[Expression])
   * @param inputExpressions
   */
 case class ST_Envelope(inputExpressions: Seq[Expression])
-  extends Expression with CodegenFallback {
+  extends UnaryGeometryExpression with CodegenFallback {
   assert(inputExpressions.length == 1)
 
-  override def nullable: Boolean = false
-
-  override def eval(input: InternalRow): Any = {
-    val geometry = GeometrySerializer.deserialize(inputExpressions.head.eval(input).asInstanceOf[ArrayData])
+  override def nullSafeEval(geometry: Geometry): Any = {
     new GenericArrayData(GeometrySerializer.serialize(geometry.getEnvelope()))
   }
 
@@ -227,13 +198,10 @@ case class ST_Envelope(inputExpressions: Seq[Expression])
   * @param inputExpressions
   */
 case class ST_Length(inputExpressions: Seq[Expression])
-  extends Expression with CodegenFallback {
+  extends UnaryGeometryExpression with CodegenFallback {
   assert(inputExpressions.length == 1)
 
-  override def nullable: Boolean = false
-
-  override def eval(input: InternalRow): Any = {
-    val geometry = GeometrySerializer.deserialize(inputExpressions.head.eval(input).asInstanceOf[ArrayData])
+  override def nullSafeEval(geometry: Geometry): Any = {
     geometry.getLength
   }
 
@@ -252,13 +220,10 @@ case class ST_Length(inputExpressions: Seq[Expression])
   * @param inputExpressions
   */
 case class ST_Area(inputExpressions: Seq[Expression])
-  extends Expression with CodegenFallback {
+  extends UnaryGeometryExpression with CodegenFallback {
   assert(inputExpressions.length == 1)
 
-  override def nullable: Boolean = false
-
-  override def eval(input: InternalRow): Any = {
-    val geometry = GeometrySerializer.deserialize(inputExpressions.head.eval(input).asInstanceOf[ArrayData])
+  override def nullSafeEval(geometry: Geometry): Any = {
     geometry.getArea
   }
 
@@ -277,13 +242,10 @@ case class ST_Area(inputExpressions: Seq[Expression])
   * @param inputExpressions
   */
 case class ST_Centroid(inputExpressions: Seq[Expression])
-  extends Expression with CodegenFallback {
+  extends UnaryGeometryExpression with CodegenFallback {
   assert(inputExpressions.length == 1)
 
-  override def nullable: Boolean = false
-
-  override def eval(input: InternalRow): Any = {
-    val geometry = GeometrySerializer.deserialize(inputExpressions(0).eval(input).asInstanceOf[ArrayData])
+  override def nullSafeEval(geometry: Geometry): Any = {
     new GenericArrayData(GeometrySerializer.serialize(geometry.getCentroid()))
   }
 
@@ -305,22 +267,27 @@ case class ST_Transform(inputExpressions: Seq[Expression])
   extends Expression with CodegenFallback {
   assert(inputExpressions.length >= 3 && inputExpressions.length <= 4)
 
-  override def nullable: Boolean = false
+  override def nullable: Boolean = true
 
   override def eval(input: InternalRow): Any = {
-    val originalGeometry = GeometrySerializer.deserialize(inputExpressions(0).eval(input).asInstanceOf[ArrayData])
-    val sourceCRScode = CRS.decode(inputExpressions(1).eval(input).asInstanceOf[UTF8String].toString)
-    val targetCRScode = CRS.decode(inputExpressions(2).eval(input).asInstanceOf[UTF8String].toString)
-
-    var transform: MathTransform = null
-    if (inputExpressions.length == 4) {
-      transform = CRS.findMathTransform(sourceCRScode, targetCRScode, inputExpressions(3).eval(input).asInstanceOf[Boolean])
-    }
-    else {
-      transform = CRS.findMathTransform(sourceCRScode, targetCRScode, false)
+    val originalGeometry = inputExpressions(0).toGeometry(input)
+    val sourceCRS = inputExpressions(1).asString(input)
+    val targetCRS = inputExpressions(2).asString(input)
+
+    (originalGeometry, sourceCRS, targetCRS) match {
+      case (originalGeometry: Geometry, sourceCRS: String, targetCRS: String) =>
+        val sourceCRScode = CRS.decode(sourceCRS)
+        val targetCRScode = CRS.decode(targetCRS)
+        var transform: MathTransform = null
+        if (inputExpressions.length == 4) {
+          transform = CRS.findMathTransform(sourceCRScode, targetCRScode, inputExpressions(3).eval(input).asInstanceOf[Boolean])
+        }
+        else {
+          transform = CRS.findMathTransform(sourceCRScode, targetCRScode, false)
+        }
+        JTS.transform(originalGeometry, transform).toGenericArrayData
+      case (_, _, _) => null
     }
-    val geom = JTS.transform(originalGeometry, transform)
-    new GenericArrayData(GeometrySerializer.serialize(geom))
   }
 
   override def dataType: DataType = GeometryUDT
@@ -339,35 +306,30 @@ case class ST_Transform(inputExpressions: Seq[Expression])
   * @param inputExpressions
   */
 case class ST_Intersection(inputExpressions: Seq[Expression])
-  extends Expression with CodegenFallback {
+  extends BinaryGeometryExpression with CodegenFallback {
   assert(inputExpressions.length == 2)
 
   lazy val GeometryFactory = new GeometryFactory()
   lazy val emptyPolygon = GeometryFactory.createPolygon(null, null)
 
-  override def nullable: Boolean = false
-
-  override def eval(inputRow: InternalRow): Any = {
-    val leftgeometry = GeometrySerializer.deserialize(inputExpressions(0).eval(inputRow).asInstanceOf[ArrayData])
-    val rightgeometry = GeometrySerializer.deserialize(inputExpressions(1).eval(inputRow).asInstanceOf[ArrayData])
-
-    val isIntersects = leftgeometry.intersects(rightgeometry)
-    lazy val isLeftContainsRight = leftgeometry.contains(rightgeometry)
-    lazy val isRightContainsLeft = rightgeometry.contains(leftgeometry)
+  override def nullSafeEval(leftGeometry: Geometry, rightGeometry: Geometry): Any = {
+    val isIntersects = leftGeometry.intersects(rightGeometry)
+    lazy val isLeftContainsRight = leftGeometry.contains(rightGeometry)
+    lazy val isRightContainsLeft = rightGeometry.contains(leftGeometry)
 
     if (!isIntersects) {
       return new GenericArrayData(GeometrySerializer.serialize(emptyPolygon))
     }
 
     if (isIntersects && isLeftContainsRight) {
-      return new GenericArrayData(GeometrySerializer.serialize(rightgeometry))
+      return new GenericArrayData(GeometrySerializer.serialize(rightGeometry))
     }
 
     if (isIntersects && isRightContainsLeft) {
-      return new GenericArrayData(GeometrySerializer.serialize(leftgeometry))
+      return new GenericArrayData(GeometrySerializer.serialize(leftGeometry))
     }
 
-    return new GenericArrayData(GeometrySerializer.serialize(leftgeometry.intersection(rightgeometry)))
+    new GenericArrayData(GeometrySerializer.serialize(leftGeometry.intersection(rightGeometry)))
   }
 
   override def dataType: DataType = GeometryUDT
@@ -436,19 +398,12 @@ case class ST_MakeValid(inputExpressions: Seq[Expression])
   * @param inputExpressions
   */
 case class ST_IsValid(inputExpressions: Seq[Expression])
-  extends Expression with CodegenFallback {
+  extends UnaryGeometryExpression with CodegenFallback {
   assert(inputExpressions.length == 1)
 
-  override def nullable: Boolean = true
 
-  override def eval(input: InternalRow): Any = {
-    val geometry = inputExpressions.head.eval(input).asInstanceOf[ArrayData]
-    if (geometry == null) {
-      null
-    } else {
-      val isvalidop = new IsValidOp(GeometrySerializer.deserialize(geometry))
-      isvalidop.isValid
-    }
+  override protected def nullSafeEval(geometry: Geometry): Any = {
+    new IsValidOp(geometry).isValid
   }
 
   override def dataType: DataType = BooleanType
@@ -466,15 +421,11 @@ case class ST_IsValid(inputExpressions: Seq[Expression])
   * @param inputExpressions
   */
 case class ST_IsSimple(inputExpressions: Seq[Expression])
-  extends Expression with CodegenFallback {
+  extends UnaryGeometryExpression with CodegenFallback {
   assert(inputExpressions.length == 1)
 
-  override def nullable: Boolean = false
-
-  override def eval(input: InternalRow): Any = {
-    val geometry = GeometrySerializer.deserialize(inputExpressions.head.eval(input).asInstanceOf[ArrayData])
-    val isSimpleop = new IsSimpleOp(geometry)
-    isSimpleop.isSimple
+  override protected def nullSafeEval(geometry: Geometry): Any = {
+    new IsSimpleOp(geometry).isSimple
   }
 
   override def dataType: DataType = BooleanType
@@ -498,18 +449,18 @@ case class ST_SimplifyPreserveTopology(inputExpressions: Seq[Expression])
   extends Expression with CodegenFallback {
   assert(inputExpressions.length == 2)
 
-  override def nullable: Boolean = false
+  override def nullable: Boolean = true
 
   override def eval(input: InternalRow): Any = {
-    val geometry = GeometrySerializer.deserialize(inputExpressions(0).eval(input).asInstanceOf[ArrayData])
     val distanceTolerance = inputExpressions(1).eval(input) match {
       case number: Decimal => number.toDouble
       case number: Double => number
       case number: Int => number.toDouble
     }
-    val simplifiedGeometry = TopologyPreservingSimplifier.simplify(geometry, distanceTolerance)
-
-    new GenericArrayData(GeometrySerializer.serialize(simplifiedGeometry))
+    inputExpressions(0).toGeometry(input) match {
+      case geometry: Geometry => TopologyPreservingSimplifier.simplify(geometry, distanceTolerance).toGenericArrayData
+      case _ => null
+    }
   }
 
   override def dataType: DataType = GeometryUDT
@@ -529,13 +480,16 @@ case class ST_SimplifyPreserveTopology(inputExpressions: Seq[Expression])
   */
 case class ST_PrecisionReduce(inputExpressions: Seq[Expression])
   extends Expression with CodegenFallback {
-  override def nullable: Boolean = false
+  override def nullable: Boolean = true
 
   override def eval(input: InternalRow): Any = {
-    val geometry = GeometrySerializer.deserialize(inputExpressions(0).eval(input).asInstanceOf[ArrayData])
     val precisionScale = inputExpressions(1).eval(input).asInstanceOf[Int]
-    val precisionReduce = new GeometryPrecisionReducer(new PrecisionModel(Math.pow(10, precisionScale)))
-    new GenericArrayData(GeometrySerializer.serialize(precisionReduce.reduce(geometry)))
+    inputExpressions(0).toGeometry(input) match {
+      case geometry: Geometry =>
+        val precisionReduce =new GeometryPrecisionReducer(new PrecisionModel(Math.pow(10, precisionScale)))
+        precisionReduce.reduce(geometry).toGenericArrayData
+      case _ => null
+    }
   }
 
   override def dataType: DataType = GeometryUDT
@@ -548,13 +502,10 @@ case class ST_PrecisionReduce(inputExpressions: Seq[Expression])
 }
 
 case class ST_AsText(inputExpressions: Seq[Expression])
-  extends Expression with CodegenFallback {
+  extends UnaryGeometryExpression with CodegenFallback {
   assert(inputExpressions.length == 1)
 
-  override def nullable: Boolean = false
-
-  override def eval(input: InternalRow): Any = {
-    val geometry = GeometrySerializer.deserialize(inputExpressions.head.eval(input).asInstanceOf[ArrayData])
+  override protected def nullSafeEval(geometry: Geometry): Any = {
     val writer = new WKTWriter(GeometrySerializer.getDimension(geometry))
     UTF8String.fromString(writer.write(geometry))
   }
@@ -569,12 +520,9 @@ case class ST_AsText(inputExpressions: Seq[Expression])
 }
 
 case class ST_AsGeoJSON(inputExpressions: Seq[Expression])
-  extends Expression with CodegenFallback {
-  override def nullable: Boolean = false
-
-  override def eval(input: InternalRow): Any = {
-    val geometry = inputExpressions.head.toGeometry(input)
+  extends UnaryGeometryExpression with CodegenFallback {
 
+  override protected def nullSafeEval(geometry: Geometry): Any = {
     val writer = new GeoJSONWriter()
     UTF8String.fromString(writer.write(geometry).toString)
   }
@@ -589,13 +537,10 @@ case class ST_AsGeoJSON(inputExpressions: Seq[Expression])
 }
 
 case class ST_AsBinary(inputExpressions: Seq[Expression])
-  extends Expression with CodegenFallback {
+  extends UnaryGeometryExpression with CodegenFallback {
   inputExpressions.validateLength(1)
 
-  override def nullable: Boolean = false
-
-  override def eval(input: InternalRow): Any = {
-    val geometry = inputExpressions.head.toGeometry(input)
+  override protected def nullSafeEval(geometry: Geometry): Any = {
     val dimensions = if (java.lang.Double.isNaN(geometry.getCoordinate.getZ)) 2 else 3
     val endian = if (ByteOrder.nativeOrder() == ByteOrder.BIG_ENDIAN) ByteOrderValues.BIG_ENDIAN else ByteOrderValues.LITTLE_ENDIAN
     val writer = new WKBWriter(dimensions, endian)
@@ -612,13 +557,10 @@ case class ST_AsBinary(inputExpressions: Seq[Expression])
 }
 
 case class ST_AsEWKB(inputExpressions: Seq[Expression])
-  extends Expression with CodegenFallback {
+  extends UnaryGeometryExpression with CodegenFallback {
   inputExpressions.validateLength(1)
 
-  override def nullable: Boolean = false
-
-  override def eval(input: InternalRow): Any = {
-    val geometry = inputExpressions.head.toGeometry(input)
+  override protected def nullSafeEval(geometry: Geometry): Any = {
     val dimensions = if (java.lang.Double.isNaN(geometry.getCoordinate.getZ)) 2 else 3
     val endian = if (ByteOrder.nativeOrder() == ByteOrder.BIG_ENDIAN) ByteOrderValues.BIG_ENDIAN else ByteOrderValues.LITTLE_ENDIAN
     val writer = new WKBWriter(dimensions, endian, geometry.getSRID != 0)
@@ -635,13 +577,10 @@ case class ST_AsEWKB(inputExpressions: Seq[Expression])
 }
 
 case class ST_SRID(inputExpressions: Seq[Expression])
-  extends Expression with CodegenFallback {
+  extends UnaryGeometryExpression with CodegenFallback {
   inputExpressions.validateLength(1)
 
-  override def nullable: Boolean = false
-
-  override def eval(input: InternalRow): Any = {
-    val geometry = inputExpressions.head.toGeometry(input)
+  override protected def nullSafeEval(geometry: Geometry): Any = {
     geometry.getSRID
   }
 
@@ -658,13 +597,16 @@ case class ST_SetSRID(inputExpressions: Seq[Expression])
   extends Expression with CodegenFallback {
   inputExpressions.validateLength(2)
 
-  override def nullable: Boolean = false
+  override def nullable: Boolean = true
 
   override def eval(input: InternalRow): Any = {
-    val geometry = inputExpressions.head.toGeometry(input)
     val srid = inputExpressions(1).eval(input).asInstanceOf[Integer]
-    val factory = new GeometryFactory(geometry.getPrecisionModel, srid, geometry.getFactory.getCoordinateSequenceFactory)
-    new GenericArrayData(GeometrySerializer.serialize(factory.createGeometry(geometry)))
+    inputExpressions(0).toGeometry(input) match {
+      case geometry: Geometry =>
+        val factory = new GeometryFactory(geometry.getPrecisionModel, srid, geometry.getFactory.getCoordinateSequenceFactory)
+        factory.createGeometry(geometry).toGenericArrayData
+      case _ => null
+    }
   }
 
   override def dataType: DataType = GeometryUDT
@@ -677,13 +619,10 @@ case class ST_SetSRID(inputExpressions: Seq[Expression])
 }
 
 case class ST_GeometryType(inputExpressions: Seq[Expression])
-  extends Expression with CodegenFallback {
+  extends UnaryGeometryExpression with CodegenFallback {
   assert(inputExpressions.length == 1)
 
-  override def nullable: Boolean = false
-
-  override def eval(input: InternalRow): Any = {
-    val geometry = GeometrySerializer.deserialize(inputExpressions.head.eval(input).asInstanceOf[ArrayData])
+  override protected def nullSafeEval(geometry: Geometry): Any = {
     UTF8String.fromString("ST_" + geometry.getGeometryType)
   }
 
@@ -704,18 +643,14 @@ case class ST_GeometryType(inputExpressions: Seq[Expression])
   * @param inputExpressions Geometry
   */
 case class ST_LineMerge(inputExpressions: Seq[Expression])
-  extends Expression with CodegenFallback {
+  extends UnaryGeometryExpression with CodegenFallback {
   assert(inputExpressions.length == 1)
 
   // Definition of the Geometry Collection Empty
   lazy val GeometryFactory = new GeometryFactory()
   lazy val emptyGeometry = GeometryFactory.createGeometryCollection(null)
 
-  override def nullable: Boolean = false
-
-  override def eval(input: InternalRow): Any = {
-    val geometry = GeometrySerializer.deserialize(inputExpressions.head.eval(input).asInstanceOf[ArrayData])
-
+  override protected def nullSafeEval(geometry: Geometry): Any = {
     val merger = new LineMerger()
 
     val output: Geometry = geometry match {
@@ -748,14 +683,12 @@ case class ST_LineMerge(inputExpressions: Seq[Expression])
 }
 
 case class ST_Azimuth(inputExpressions: Seq[Expression])
-  extends Expression with CodegenFallback {
+  extends BinaryGeometryExpression with CodegenFallback {
   assert(inputExpressions.length == 2)
-  override def nullable: Boolean = false
 
-  override def eval(input: InternalRow): Any = {
-    val geometries = (inputExpressions(0).toGeometry(input), inputExpressions(1).toGeometry(input))
-    geometries match {
-       case (pointA: Point, pointB: Point) => calculateAzimuth(pointA, pointB)
+  override def nullSafeEval(leftGeometry: Geometry, rightGeometry: Geometry): Any = {
+    (leftGeometry, rightGeometry) match {
+      case (pointA: Point, pointB: Point) => calculateAzimuth(pointA, pointB)
     }
   }
 
@@ -776,14 +709,10 @@ case class ST_Azimuth(inputExpressions: Seq[Expression])
 }
 
 case class ST_X(inputExpressions: Seq[Expression])
-  extends Expression with CodegenFallback {
+  extends UnaryGeometryExpression with CodegenFallback {
   assert(inputExpressions.length == 1)
 
-  override def nullable: Boolean = true
-
-  override def eval(input: InternalRow): Any = {
-    val geometry = inputExpressions.head.toGeometry(input)
-
+  override protected def nullSafeEval(geometry: Geometry): Any = {
     geometry match {
       case point: Point => point.getX
       case _ => null
@@ -801,14 +730,10 @@ case class ST_X(inputExpressions: Seq[Expression])
 
 
 case class ST_Y(inputExpressions: Seq[Expression])
-  extends Expression with CodegenFallback {
+  extends UnaryGeometryExpression with CodegenFallback {
   assert(inputExpressions.length == 1)
 
-  override def nullable: Boolean = true
-
-  override def eval(input: InternalRow): Any = {
-    val geometry = inputExpressions.head.toGeometry(input)
-
+  override protected def nullSafeEval(geometry: Geometry): Any = {
     geometry match {
       case point: Point => point.getY
       case _ => null
@@ -825,14 +750,10 @@ case class ST_Y(inputExpressions: Seq[Expression])
 }
 
 case class ST_Z(inputExpressions: Seq[Expression])
-  extends Expression with CodegenFallback {
+  extends UnaryGeometryExpression with CodegenFallback {
   assert(inputExpressions.length == 1)
 
-  override def nullable: Boolean = true
-
-  override def eval(input: InternalRow): Any = {
-    val geometry = inputExpressions.head.toGeometry(input)
-
+  override protected def nullSafeEval(geometry: Geometry): Any = {
     geometry match {
       case point: Point => point.getCoordinate.getZ
       case _ => null
@@ -849,13 +770,10 @@ case class ST_Z(inputExpressions: Seq[Expression])
 }
 
 case class ST_StartPoint(inputExpressions: Seq[Expression])
-  extends Expression with CodegenFallback {
+  extends UnaryGeometryExpression with CodegenFallback {
   assert(inputExpressions.length == 1)
 
-  override def nullable: Boolean = true
-
-  override def eval(input: InternalRow): Any = {
-    val geometry = inputExpressions.head.toGeometry(input)
+  override protected def nullSafeEval(geometry: Geometry): Any = {
     geometry match {
       case line: LineString => line.getPointN(0).toGenericArrayData
       case _ => null
@@ -873,14 +791,11 @@ case class ST_StartPoint(inputExpressions: Seq[Expression])
 
 
 case class ST_Boundary(inputExpressions: Seq[Expression])
-  extends Expression with CodegenFallback {
-  override def nullable: Boolean = true
+  extends UnaryGeometryExpression with CodegenFallback {
 
-  override def eval(input: InternalRow): Any = {
-    val geometry = inputExpressions.head.toGeometry(input)
+  override protected def nullSafeEval(geometry: Geometry): Any = {
     val geometryBoundary = geometry.getBoundary
     geometryBoundary.toGenericArrayData
-
   }
 
   override def dataType: DataType = GeometryUDT
@@ -894,17 +809,11 @@ case class ST_Boundary(inputExpressions: Seq[Expression])
 
 
 case class ST_MinimumBoundingRadius(inputExpressions: Seq[Expression])
-  extends Expression with CodegenFallback {
+  extends UnaryGeometryExpression with CodegenFallback {
   private val geometryFactory = new GeometryFactory()
 
-  override def nullable: Boolean = true
-
-  override def eval(input: InternalRow): Any = {
-    val geometry = inputExpressions.head.toGeometry(input)
-    geometry match {
-      case geom: Geometry => getMinimumBoundingRadius(geom)
-      case _ => null
-    }
+  override protected def nullSafeEval(geometry: Geometry): Any = {
+    getMinimumBoundingRadius(geometry)
   }
 
   private def getMinimumBoundingRadius(geom: Geometry): InternalRow = {
@@ -1061,11 +970,9 @@ case class ST_LineInterpolatePoint(inputExpressions: Seq[Expression])
 
 
 case class ST_EndPoint(inputExpressions: Seq[Expression])
-  extends Expression with CodegenFallback {
-  override def nullable: Boolean = true
+  extends UnaryGeometryExpression with CodegenFallback {
 
-  override def eval(input: InternalRow): Any = {
-    val geometry = inputExpressions.head.toGeometry(input)
+  override protected def nullSafeEval(geometry: Geometry): Any = {
     geometry match {
       case string: LineString => string.getEndPoint.toGenericArrayData
       case _ => null
@@ -1083,11 +990,9 @@ case class ST_EndPoint(inputExpressions: Seq[Expression])
 }
 
 case class ST_ExteriorRing(inputExpressions: Seq[Expression])
-  extends Expression with CodegenFallback {
-  override def nullable: Boolean = true
+  extends UnaryGeometryExpression with CodegenFallback {
 
-  override def eval(input: InternalRow): Any = {
-    val geometry = inputExpressions.head.toGeometry(input)
+  override protected def nullSafeEval(geometry: Geometry): Any = {
     geometry match {
       case polygon: Polygon => polygon.getExteriorRing.toGenericArrayData
       case _ => null
@@ -1165,13 +1070,10 @@ case class ST_InteriorRingN(inputExpressions: Seq[Expression])
 }
 
 case class ST_Dump(inputExpressions: Seq[Expression])
-  extends Expression with CodegenFallback {
+  extends UnaryGeometryExpression with CodegenFallback {
   assert(inputExpressions.length == 1)
 
-  override def nullable: Boolean = true
-
-  override def eval(input: InternalRow): Any = {
-    val geometry = inputExpressions.head.toGeometry(input)
+  override protected def nullSafeEval(geometry: Geometry): Any = {
     val geometryCollection = geometry match {
       case collection: GeometryCollection => {
         val numberOfGeometries = collection.getNumGeometries
@@ -1194,13 +1096,10 @@ case class ST_Dump(inputExpressions: Seq[Expression])
 }
 
 case class ST_DumpPoints(inputExpressions: Seq[Expression])
-  extends Expression with CodegenFallback {
+  extends UnaryGeometryExpression with CodegenFallback {
   assert(inputExpressions.length == 1)
 
-  override def nullable: Boolean = true
-
-  override def eval(input: InternalRow): Any = {
-    val geometry = inputExpressions.head.toGeometry(input)
+  override protected def nullSafeEval(geometry: Geometry): Any = {
     ArrayData.toArrayData(geometry.getPoints.map(geom => geom.toGenericArrayData))
   }
 
@@ -1215,13 +1114,10 @@ case class ST_DumpPoints(inputExpressions: Seq[Expression])
 
 
 case class ST_IsClosed(inputExpressions: Seq[Expression])
-  extends Expression with CodegenFallback {
+  extends UnaryGeometryExpression with CodegenFallback {
   assert(inputExpressions.length == 1)
 
-  override def nullable: Boolean = true
-
-  override def eval(input: InternalRow): Any = {
-    val geometry = inputExpressions.head.toGeometry(input)
+  override protected def nullSafeEval(geometry: Geometry): Any = {
     geometry match {
       case circle: Circle => true
       case point: MultiPoint => true
@@ -1245,13 +1141,10 @@ case class ST_IsClosed(inputExpressions: Seq[Expression])
 }
 
 case class ST_NumInteriorRings(inputExpressions: Seq[Expression])
-  extends Expression with CodegenFallback {
+  extends UnaryGeometryExpression with CodegenFallback {
   assert(inputExpressions.length == 1)
 
-  override def nullable: Boolean = true
-
-  override def eval(input: InternalRow): Any = {
-    val geometry = inputExpressions.head.toGeometry(input)
+  override protected def nullSafeEval(geometry: Geometry): Any = {
     geometry match {
       case polygon: Polygon => polygon.getNumInteriorRing
       case _: Geometry => null
@@ -1352,13 +1245,10 @@ case class ST_RemovePoint(inputExpressions: Seq[Expression])
 }
 
 case class ST_IsRing(inputExpressions: Seq[Expression])
-  extends Expression with CodegenFallback {
+  extends UnaryGeometryExpression with CodegenFallback {
   assert(inputExpressions.length == 1)
 
-  override def nullable: Boolean = true
-
-  override def eval(input: InternalRow): Any = {
-    val geometry = inputExpressions.head.toGeometry(input)
+  override protected def nullSafeEval(geometry: Geometry): Any = {
     geometry match {
       case string: LineString => string.isSimple & string.isClosed
       case _ => null
@@ -1383,13 +1273,10 @@ case class ST_IsRing(inputExpressions: Seq[Expression])
   * @param inputExpressions Geometry
   */
 case class ST_NumGeometries(inputExpressions: Seq[Expression])
-  extends Expression with CodegenFallback {
+  extends UnaryGeometryExpression with CodegenFallback {
   assert(inputExpressions.length == 1)
 
-  override def nullable: Boolean = false
-
-  override def eval(input: InternalRow): Any = {
-    val geometry = GeometrySerializer.deserialize(inputExpressions.head.eval(input).asInstanceOf[ArrayData])
+  override protected def nullSafeEval(geometry: Geometry): Any = {
     geometry.getNumGeometries()
   }
 
@@ -1409,13 +1296,10 @@ case class ST_NumGeometries(inputExpressions: Seq[Expression])
   * @param inputExpressions Geometry
   */
 case class ST_FlipCoordinates(inputExpressions: Seq[Expression])
-  extends Expression with CodegenFallback {
+  extends UnaryGeometryExpression with CodegenFallback {
   assert(inputExpressions.length == 1)
 
-  override def nullable: Boolean = false
-
-  override def eval(input: InternalRow): Any = {
-    val geometry = GeometrySerializer.deserialize(inputExpressions(0).eval(input).asInstanceOf[ArrayData])
+  override protected def nullSafeEval(geometry: Geometry): Any = {
     GeomUtils.flipCoordinates(geometry)
     geometry.toGenericArrayData
   }
@@ -1561,18 +1445,13 @@ case class ST_GeoHash(inputExpressions: Seq[Expression])
 }
 
 case class ST_Difference(inputExpressions: Seq[Expression])
-  extends Expression with CodegenFallback {
+  extends BinaryGeometryExpression with CodegenFallback {
   assert(inputExpressions.length == 2)
 
   lazy val GeometryFactory = new GeometryFactory()
   lazy val emptyPolygon = GeometryFactory.createPolygon(null, null)
 
-  override def nullable: Boolean = false
-
-  override def eval(inputRow: InternalRow): Any = {
-    val leftGeometry = GeometrySerializer.deserialize(inputExpressions(0).eval(inputRow).asInstanceOf[ArrayData])
-    val rightGeometry = GeometrySerializer.deserialize(inputExpressions(1).eval(inputRow).asInstanceOf[ArrayData])
-
+  override protected def nullSafeEval(leftGeometry: Geometry, rightGeometry: Geometry): Any = {
     val isIntersects = leftGeometry.intersects(rightGeometry)
     lazy val isRightContainsLeft = rightGeometry.contains(leftGeometry)
 
diff --git a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/NullSafeExpressions.scala b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/NullSafeExpressions.scala
new file mode 100644
index 0000000..8b7af2b
--- /dev/null
+++ b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/NullSafeExpressions.scala
@@ -0,0 +1,57 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.spark.sql.sedona_sql.expressions
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.Expression
+import org.locationtech.jts.geom.Geometry
+import org.apache.spark.sql.sedona_sql.expressions.implicits._
+
+/** Base for single-geometry functions; yields null when the argument is not a Geometry (e.g. NULL). */
+abstract class UnaryGeometryExpression extends Expression {
+  def inputExpressions: Seq[Expression]
+
+  // Always nullable: null input propagates, and some nullSafeEval implementations also return null.
+  override def nullable: Boolean = true
+
+  override def eval(input: InternalRow): Any = inputExpressions.head.toGeometry(input) match {
+    case geometry: Geometry => nullSafeEval(geometry)
+    case _ => null
+  }
+
+  /** Computes the function's value for a non-null geometry argument. */
+  protected def nullSafeEval(geometry: Geometry): Any
+}
+
+/** Base for two-geometry functions; yields null unless both arguments evaluate to a Geometry. */
+abstract class BinaryGeometryExpression extends Expression {
+  def inputExpressions: Seq[Expression]
+  override def nullable: Boolean = true
+
+  // Like Spark's BinaryExpression.eval, the right argument is only evaluated if the left is non-null.
+  override def eval(input: InternalRow): Any = inputExpressions(0).toGeometry(input) match {
+    case leftGeometry: Geometry => inputExpressions(1).toGeometry(input) match {
+      case rightGeometry: Geometry => nullSafeEval(leftGeometry, rightGeometry)
+      case _ => null
+    }
+    case _ => null
+  }
+
+  protected def nullSafeEval(leftGeometry: Geometry, rightGeometry: Geometry): Any
+}
\ No newline at end of file
diff --git a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/implicits.scala b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/implicits.scala
index e31efe3..62a9dbd 100644
--- a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/implicits.scala
+++ b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/implicits.scala
@@ -23,6 +23,7 @@ import org.apache.sedona.sql.utils.GeometrySerializer
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.Expression
 import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData}
+import org.apache.spark.unsafe.types.UTF8String
 import org.locationtech.jts.geom.{Geometry, GeometryFactory, Point}
 
 object implicits {
@@ -38,6 +39,13 @@ object implicits {
     def toInt(input: InternalRow): Int = {
       inputExpression.eval(input).asInstanceOf[Int]
     }
+
+    /** Evaluates to a String (null for SQL NULL); the expression must produce a UTF8String. */
+    def asString(input: InternalRow): String =
+      inputExpression.eval(input) match {
+        case null => null
+        case value => value.asInstanceOf[UTF8String].toString
+      }
   }
 
   implicit class SequenceEnhancer[T](seq: Seq[T]) {
diff --git a/sql/src/test/scala/org/apache/sedona/sql/functionTestScala.scala b/sql/src/test/scala/org/apache/sedona/sql/functionTestScala.scala
index c6bcf30..f788543 100644
--- a/sql/src/test/scala/org/apache/sedona/sql/functionTestScala.scala
+++ b/sql/src/test/scala/org/apache/sedona/sql/functionTestScala.scala
@@ -1200,4 +1200,66 @@ class functionTestScala extends TestBaseScala with Matchers with GeometrySample
         lineString2D.getFactory.createPoint(interPoint3D).toText
       )
   }
+
+  it("handles nulls") {
+    // Each of these functions must evaluate to NULL when its geometry argument(s) are NULL.
+    val nullResultQueries = Seq(
+      "select ST_Distance(null, null)",
+      "select ST_3DDistance(null, null)",
+      "select ST_ConvexHull(null)",
+      "select ST_NPoints(null)",
+      "select ST_Buffer(null, 0)",
+      "select ST_Envelope(null)",
+      "select ST_Length(null)",
+      "select ST_Area(null)",
+      "select ST_Centroid(null)",
+      "select ST_Transform(null, null, null)",
+      "select ST_Intersection(null, null)",
+      "select ST_IsValid(null)",
+      "select ST_IsSimple(null)",
+      "select ST_SimplifyPreserveTopology(null, 1)",
+      "select ST_PrecisionReduce(null, 1)",
+      "select ST_AsText(null)",
+      "select ST_AsGeoJSON(null)",
+      "select ST_AsBinary(null)",
+      "select ST_AsEWKB(null)",
+      "select ST_SRID(null)",
+      "select ST_SetSRID(null, 4326)",
+      "select ST_GeometryType(null)",
+      "select ST_LineMerge(null)",
+      "select ST_Azimuth(null, null)",
+      "select ST_X(null)",
+      "select ST_Y(null)",
+      "select ST_Z(null)",
+      "select ST_StartPoint(null)",
+      "select ST_Boundary(null)",
+      "select ST_MinimumBoundingRadius(null)",
+      "select ST_LineSubstring(null, 0, 0)",
+      "select ST_LineInterpolatePoint(null, 0)",
+      "select ST_EndPoint(null)",
+      "select ST_ExteriorRing(null)",
+      "select ST_GeometryN(null, 0)",
+      "select ST_InteriorRingN(null, 0)",
+      "select ST_Dump(null)",
+      "select ST_DumpPoints(null)",
+      "select ST_IsClosed(null)",
+      "select ST_NumInteriorRings(null)",
+      "select ST_AddPoint(null, null)",
+      "select ST_RemovePoint(null, null)",
+      "select ST_IsRing(null)",
+      "select ST_NumGeometries(null)",
+      "select ST_FlipCoordinates(null)",
+      "select ST_SubDivide(null, 0)",
+      "select ST_MakePolygon(null)",
+      "select ST_GeoHash(null, 1)",
+      "select ST_Difference(null, null)"
+    )
+    nullResultQueries.foreach { query =>
+      withClue(s"$query: ") {
+        assert(sparkSession.sql(query).first().get(0) == null)
+      }
+    }
+    // ST_SubDivideExplode emits rows rather than a single value, so a NULL input yields zero rows.
+    assert(sparkSession.sql("select ST_SubDivideExplode(null, 0)").count() == 0)
+  }
 }