You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@sedona.apache.org by ji...@apache.org on 2022/02/06 23:48:20 UTC
[incubator-sedona] branch master updated: [SEDONA-4] Handle nulls in SQL functions (#578)
This is an automated email from the ASF dual-hosted git repository.
jiayu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-sedona.git
The following commit(s) were added to refs/heads/master by this push:
new e1eab69 [SEDONA-4] Handle nulls in SQL functions (#578)
e1eab69 is described below
commit e1eab692a1edf41bba5988f56f88b60b0c9cdb3d
Author: Kurtis Seebaldt <ku...@srpatx.com>
AuthorDate: Sun Feb 6 17:48:13 2022 -0600
[SEDONA-4] Handle nulls in SQL functions (#578)
---
.../sql/sedona_sql/expressions/Functions.scala | 371 +++++++--------------
.../expressions/NullSafeExpressions.scala | 57 ++++
.../sql/sedona_sql/expressions/implicits.scala | 8 +
.../org/apache/sedona/sql/functionTestScala.scala | 105 ++++++
4 files changed, 295 insertions(+), 246 deletions(-)
diff --git a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala
index 77f92d2..7b5a011 100644
--- a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala
+++ b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala
@@ -49,8 +49,6 @@ import org.opengis.referencing.operation.MathTransform
import org.wololo.jts2geojson.GeoJSONWriter
import java.nio.ByteOrder
-import java.util
-import scala.collection.mutable.ArrayBuffer
import scala.util.{Failure, Success, Try}
/**
@@ -59,21 +57,12 @@ import scala.util.{Failure, Success, Try}
* @param inputExpressions This function takes two geometries and calculates the distance between two objects.
*/
case class ST_Distance(inputExpressions: Seq[Expression])
- extends Expression with CodegenFallback {
+ extends BinaryGeometryExpression with CodegenFallback {
assert(inputExpressions.length == 2)
- override def nullable: Boolean = false
-
override def toString: String = s" **${ST_Distance.getClass.getName}** "
- override def eval(inputRow: InternalRow): Any = {
- val leftArray = inputExpressions(0).eval(inputRow).asInstanceOf[ArrayData]
- val rightArray = inputExpressions(1).eval(inputRow).asInstanceOf[ArrayData]
-
- val leftGeometry = GeometrySerializer.deserialize(leftArray)
-
- val rightGeometry = GeometrySerializer.deserialize(rightArray)
-
+ override def nullSafeEval(leftGeometry: Geometry, rightGeometry: Geometry): Any = {
leftGeometry.distance(rightGeometry)
}
@@ -87,21 +76,13 @@ case class ST_Distance(inputExpressions: Seq[Expression])
}
case class ST_3DDistance(inputExpressions: Seq[Expression])
- extends Expression with CodegenFallback {
+ extends BinaryGeometryExpression with CodegenFallback {
assert(inputExpressions.length == 2)
- override def nullable: Boolean = true
-
override def toString: String = s" **${ST_3DDistance.getClass.getName}** "
- override def eval(inputRow: InternalRow): Any = {
- val leftGeometry = inputExpressions(0).toGeometry(inputRow)
- val rightGeometry = inputExpressions(1).toGeometry(inputRow)
-
- (leftGeometry, rightGeometry) match {
- case (leftGeometry: Geometry, rightGeometry: Geometry) => Distance3DOp.distance(leftGeometry, rightGeometry)
- case _ => null
- }
+ override def nullSafeEval(leftGeometry: Geometry, rightGeometry: Geometry): Any = {
+ Distance3DOp.distance(leftGeometry, rightGeometry)
}
override def dataType = DoubleType
@@ -119,13 +100,10 @@ case class ST_3DDistance(inputExpressions: Seq[Expression])
* @param inputExpressions
*/
case class ST_ConvexHull(inputExpressions: Seq[Expression])
- extends Expression with CodegenFallback {
+ extends UnaryGeometryExpression with CodegenFallback {
assert(inputExpressions.length == 1)
- override def nullable: Boolean = false
-
- override def eval(input: InternalRow): Any = {
- val geometry = GeometrySerializer.deserialize(inputExpressions.head.eval(input).asInstanceOf[ArrayData])
+ override def nullSafeEval(geometry: Geometry): Any = {
new GenericArrayData(GeometrySerializer.serialize(geometry.convexHull()))
}
@@ -144,16 +122,10 @@ case class ST_ConvexHull(inputExpressions: Seq[Expression])
* @param inputExpressions
*/
case class ST_NPoints(inputExpressions: Seq[Expression])
- extends Expression with CodegenFallback {
- override def nullable: Boolean = false
+ extends UnaryGeometryExpression with CodegenFallback {
- override def eval(input: InternalRow): Any = {
- inputExpressions.length match {
- case 1 =>
- val geometry = GeometrySerializer.deserialize(inputExpressions.head.eval(input).asInstanceOf[ArrayData])
- geometry.getCoordinates.length
- case _ => None
- }
+ override def nullSafeEval(geometry: Geometry): Any = {
+ geometry.getCoordinates.length
}
override def dataType: DataType = IntegerType
@@ -174,16 +146,18 @@ case class ST_Buffer(inputExpressions: Seq[Expression])
extends Expression with CodegenFallback {
assert(inputExpressions.length == 2)
- override def nullable: Boolean = false
+ override def nullable: Boolean = true
override def eval(input: InternalRow): Any = {
- val geometry = GeometrySerializer.deserialize(inputExpressions(0).eval(input).asInstanceOf[ArrayData])
val buffer: Double = inputExpressions(1).eval(input) match {
case a: Decimal => a.toDouble
case a: Double => a
case a: Int => a
}
- new GenericArrayData(GeometrySerializer.serialize(geometry.buffer(buffer)))
+ inputExpressions(0).toGeometry(input) match {
+ case geometry: Geometry => geometry.buffer(buffer).toGenericArrayData
+ case _ => null
+ }
}
override def dataType: DataType = GeometryUDT
@@ -202,13 +176,10 @@ case class ST_Buffer(inputExpressions: Seq[Expression])
* @param inputExpressions
*/
case class ST_Envelope(inputExpressions: Seq[Expression])
- extends Expression with CodegenFallback {
+ extends UnaryGeometryExpression with CodegenFallback {
assert(inputExpressions.length == 1)
- override def nullable: Boolean = false
-
- override def eval(input: InternalRow): Any = {
- val geometry = GeometrySerializer.deserialize(inputExpressions.head.eval(input).asInstanceOf[ArrayData])
+ override def nullSafeEval(geometry: Geometry): Any = {
new GenericArrayData(GeometrySerializer.serialize(geometry.getEnvelope()))
}
@@ -227,13 +198,10 @@ case class ST_Envelope(inputExpressions: Seq[Expression])
* @param inputExpressions
*/
case class ST_Length(inputExpressions: Seq[Expression])
- extends Expression with CodegenFallback {
+ extends UnaryGeometryExpression with CodegenFallback {
assert(inputExpressions.length == 1)
- override def nullable: Boolean = false
-
- override def eval(input: InternalRow): Any = {
- val geometry = GeometrySerializer.deserialize(inputExpressions.head.eval(input).asInstanceOf[ArrayData])
+ override def nullSafeEval(geometry: Geometry): Any = {
geometry.getLength
}
@@ -252,13 +220,10 @@ case class ST_Length(inputExpressions: Seq[Expression])
* @param inputExpressions
*/
case class ST_Area(inputExpressions: Seq[Expression])
- extends Expression with CodegenFallback {
+ extends UnaryGeometryExpression with CodegenFallback {
assert(inputExpressions.length == 1)
- override def nullable: Boolean = false
-
- override def eval(input: InternalRow): Any = {
- val geometry = GeometrySerializer.deserialize(inputExpressions.head.eval(input).asInstanceOf[ArrayData])
+ override def nullSafeEval(geometry: Geometry): Any = {
geometry.getArea
}
@@ -277,13 +242,10 @@ case class ST_Area(inputExpressions: Seq[Expression])
* @param inputExpressions
*/
case class ST_Centroid(inputExpressions: Seq[Expression])
- extends Expression with CodegenFallback {
+ extends UnaryGeometryExpression with CodegenFallback {
assert(inputExpressions.length == 1)
- override def nullable: Boolean = false
-
- override def eval(input: InternalRow): Any = {
- val geometry = GeometrySerializer.deserialize(inputExpressions(0).eval(input).asInstanceOf[ArrayData])
+ override def nullSafeEval(geometry: Geometry): Any = {
new GenericArrayData(GeometrySerializer.serialize(geometry.getCentroid()))
}
@@ -305,22 +267,27 @@ case class ST_Transform(inputExpressions: Seq[Expression])
extends Expression with CodegenFallback {
assert(inputExpressions.length >= 3 && inputExpressions.length <= 4)
- override def nullable: Boolean = false
+ override def nullable: Boolean = true
override def eval(input: InternalRow): Any = {
- val originalGeometry = GeometrySerializer.deserialize(inputExpressions(0).eval(input).asInstanceOf[ArrayData])
- val sourceCRScode = CRS.decode(inputExpressions(1).eval(input).asInstanceOf[UTF8String].toString)
- val targetCRScode = CRS.decode(inputExpressions(2).eval(input).asInstanceOf[UTF8String].toString)
-
- var transform: MathTransform = null
- if (inputExpressions.length == 4) {
- transform = CRS.findMathTransform(sourceCRScode, targetCRScode, inputExpressions(3).eval(input).asInstanceOf[Boolean])
- }
- else {
- transform = CRS.findMathTransform(sourceCRScode, targetCRScode, false)
+ val originalGeometry = inputExpressions(0).toGeometry(input)
+ val sourceCRS = inputExpressions(1).asString(input)
+ val targetCRS = inputExpressions(2).asString(input)
+
+ (originalGeometry, sourceCRS, targetCRS) match {
+ case (originalGeometry: Geometry, sourceCRS: String, targetCRS: String) =>
+ val sourceCRScode = CRS.decode(sourceCRS)
+ val targetCRScode = CRS.decode(targetCRS)
+ var transform: MathTransform = null
+ if (inputExpressions.length == 4) {
+ transform = CRS.findMathTransform(sourceCRScode, targetCRScode, inputExpressions(3).eval(input).asInstanceOf[Boolean])
+ }
+ else {
+ transform = CRS.findMathTransform(sourceCRScode, targetCRScode, false)
+ }
+ JTS.transform(originalGeometry, transform).toGenericArrayData
+ case (_, _, _) => null
}
- val geom = JTS.transform(originalGeometry, transform)
- new GenericArrayData(GeometrySerializer.serialize(geom))
}
override def dataType: DataType = GeometryUDT
@@ -339,35 +306,30 @@ case class ST_Transform(inputExpressions: Seq[Expression])
* @param inputExpressions
*/
case class ST_Intersection(inputExpressions: Seq[Expression])
- extends Expression with CodegenFallback {
+ extends BinaryGeometryExpression with CodegenFallback {
assert(inputExpressions.length == 2)
lazy val GeometryFactory = new GeometryFactory()
lazy val emptyPolygon = GeometryFactory.createPolygon(null, null)
- override def nullable: Boolean = false
-
- override def eval(inputRow: InternalRow): Any = {
- val leftgeometry = GeometrySerializer.deserialize(inputExpressions(0).eval(inputRow).asInstanceOf[ArrayData])
- val rightgeometry = GeometrySerializer.deserialize(inputExpressions(1).eval(inputRow).asInstanceOf[ArrayData])
-
- val isIntersects = leftgeometry.intersects(rightgeometry)
- lazy val isLeftContainsRight = leftgeometry.contains(rightgeometry)
- lazy val isRightContainsLeft = rightgeometry.contains(leftgeometry)
+ override def nullSafeEval(leftGeometry: Geometry, rightGeometry: Geometry): Any = {
+ val isIntersects = leftGeometry.intersects(rightGeometry)
+ lazy val isLeftContainsRight = leftGeometry.contains(rightGeometry)
+ lazy val isRightContainsLeft = rightGeometry.contains(leftGeometry)
if (!isIntersects) {
return new GenericArrayData(GeometrySerializer.serialize(emptyPolygon))
}
if (isIntersects && isLeftContainsRight) {
- return new GenericArrayData(GeometrySerializer.serialize(rightgeometry))
+ return new GenericArrayData(GeometrySerializer.serialize(rightGeometry))
}
if (isIntersects && isRightContainsLeft) {
- return new GenericArrayData(GeometrySerializer.serialize(leftgeometry))
+ return new GenericArrayData(GeometrySerializer.serialize(leftGeometry))
}
- return new GenericArrayData(GeometrySerializer.serialize(leftgeometry.intersection(rightgeometry)))
+ new GenericArrayData(GeometrySerializer.serialize(leftGeometry.intersection(rightGeometry)))
}
override def dataType: DataType = GeometryUDT
@@ -436,19 +398,12 @@ case class ST_MakeValid(inputExpressions: Seq[Expression])
* @param inputExpressions
*/
case class ST_IsValid(inputExpressions: Seq[Expression])
- extends Expression with CodegenFallback {
+ extends UnaryGeometryExpression with CodegenFallback {
assert(inputExpressions.length == 1)
- override def nullable: Boolean = true
- override def eval(input: InternalRow): Any = {
- val geometry = inputExpressions.head.eval(input).asInstanceOf[ArrayData]
- if (geometry == null) {
- null
- } else {
- val isvalidop = new IsValidOp(GeometrySerializer.deserialize(geometry))
- isvalidop.isValid
- }
+ override protected def nullSafeEval(geometry: Geometry): Any = {
+ new IsValidOp(geometry).isValid
}
override def dataType: DataType = BooleanType
@@ -466,15 +421,11 @@ case class ST_IsValid(inputExpressions: Seq[Expression])
* @param inputExpressions
*/
case class ST_IsSimple(inputExpressions: Seq[Expression])
- extends Expression with CodegenFallback {
+ extends UnaryGeometryExpression with CodegenFallback {
assert(inputExpressions.length == 1)
- override def nullable: Boolean = false
-
- override def eval(input: InternalRow): Any = {
- val geometry = GeometrySerializer.deserialize(inputExpressions.head.eval(input).asInstanceOf[ArrayData])
- val isSimpleop = new IsSimpleOp(geometry)
- isSimpleop.isSimple
+ override protected def nullSafeEval(geometry: Geometry): Any = {
+ new IsSimpleOp(geometry).isSimple
}
override def dataType: DataType = BooleanType
@@ -498,18 +449,18 @@ case class ST_SimplifyPreserveTopology(inputExpressions: Seq[Expression])
extends Expression with CodegenFallback {
assert(inputExpressions.length == 2)
- override def nullable: Boolean = false
+ override def nullable: Boolean = true
override def eval(input: InternalRow): Any = {
- val geometry = GeometrySerializer.deserialize(inputExpressions(0).eval(input).asInstanceOf[ArrayData])
val distanceTolerance = inputExpressions(1).eval(input) match {
case number: Decimal => number.toDouble
case number: Double => number
case number: Int => number.toDouble
}
- val simplifiedGeometry = TopologyPreservingSimplifier.simplify(geometry, distanceTolerance)
-
- new GenericArrayData(GeometrySerializer.serialize(simplifiedGeometry))
+ inputExpressions(0).toGeometry(input) match {
+ case geometry: Geometry => TopologyPreservingSimplifier.simplify(geometry, distanceTolerance).toGenericArrayData
+ case _ => null
+ }
}
override def dataType: DataType = GeometryUDT
@@ -529,13 +480,16 @@ case class ST_SimplifyPreserveTopology(inputExpressions: Seq[Expression])
*/
case class ST_PrecisionReduce(inputExpressions: Seq[Expression])
extends Expression with CodegenFallback {
- override def nullable: Boolean = false
+ override def nullable: Boolean = true
override def eval(input: InternalRow): Any = {
- val geometry = GeometrySerializer.deserialize(inputExpressions(0).eval(input).asInstanceOf[ArrayData])
val precisionScale = inputExpressions(1).eval(input).asInstanceOf[Int]
- val precisionReduce = new GeometryPrecisionReducer(new PrecisionModel(Math.pow(10, precisionScale)))
- new GenericArrayData(GeometrySerializer.serialize(precisionReduce.reduce(geometry)))
+ inputExpressions(0).toGeometry(input) match {
+ case geometry: Geometry =>
+ val precisionReduce =new GeometryPrecisionReducer(new PrecisionModel(Math.pow(10, precisionScale)))
+ precisionReduce.reduce(geometry).toGenericArrayData
+ case _ => null
+ }
}
override def dataType: DataType = GeometryUDT
@@ -548,13 +502,10 @@ case class ST_PrecisionReduce(inputExpressions: Seq[Expression])
}
case class ST_AsText(inputExpressions: Seq[Expression])
- extends Expression with CodegenFallback {
+ extends UnaryGeometryExpression with CodegenFallback {
assert(inputExpressions.length == 1)
- override def nullable: Boolean = false
-
- override def eval(input: InternalRow): Any = {
- val geometry = GeometrySerializer.deserialize(inputExpressions.head.eval(input).asInstanceOf[ArrayData])
+ override protected def nullSafeEval(geometry: Geometry): Any = {
val writer = new WKTWriter(GeometrySerializer.getDimension(geometry))
UTF8String.fromString(writer.write(geometry))
}
@@ -569,12 +520,9 @@ case class ST_AsText(inputExpressions: Seq[Expression])
}
case class ST_AsGeoJSON(inputExpressions: Seq[Expression])
- extends Expression with CodegenFallback {
- override def nullable: Boolean = false
-
- override def eval(input: InternalRow): Any = {
- val geometry = inputExpressions.head.toGeometry(input)
+ extends UnaryGeometryExpression with CodegenFallback {
+ override protected def nullSafeEval(geometry: Geometry): Any = {
val writer = new GeoJSONWriter()
UTF8String.fromString(writer.write(geometry).toString)
}
@@ -589,13 +537,10 @@ case class ST_AsGeoJSON(inputExpressions: Seq[Expression])
}
case class ST_AsBinary(inputExpressions: Seq[Expression])
- extends Expression with CodegenFallback {
+ extends UnaryGeometryExpression with CodegenFallback {
inputExpressions.validateLength(1)
- override def nullable: Boolean = false
-
- override def eval(input: InternalRow): Any = {
- val geometry = inputExpressions.head.toGeometry(input)
+ override protected def nullSafeEval(geometry: Geometry): Any = {
val dimensions = if (java.lang.Double.isNaN(geometry.getCoordinate.getZ)) 2 else 3
val endian = if (ByteOrder.nativeOrder() == ByteOrder.BIG_ENDIAN) ByteOrderValues.BIG_ENDIAN else ByteOrderValues.LITTLE_ENDIAN
val writer = new WKBWriter(dimensions, endian)
@@ -612,13 +557,10 @@ case class ST_AsBinary(inputExpressions: Seq[Expression])
}
case class ST_AsEWKB(inputExpressions: Seq[Expression])
- extends Expression with CodegenFallback {
+ extends UnaryGeometryExpression with CodegenFallback {
inputExpressions.validateLength(1)
- override def nullable: Boolean = false
-
- override def eval(input: InternalRow): Any = {
- val geometry = inputExpressions.head.toGeometry(input)
+ override protected def nullSafeEval(geometry: Geometry): Any = {
val dimensions = if (java.lang.Double.isNaN(geometry.getCoordinate.getZ)) 2 else 3
val endian = if (ByteOrder.nativeOrder() == ByteOrder.BIG_ENDIAN) ByteOrderValues.BIG_ENDIAN else ByteOrderValues.LITTLE_ENDIAN
val writer = new WKBWriter(dimensions, endian, geometry.getSRID != 0)
@@ -635,13 +577,10 @@ case class ST_AsEWKB(inputExpressions: Seq[Expression])
}
case class ST_SRID(inputExpressions: Seq[Expression])
- extends Expression with CodegenFallback {
+ extends UnaryGeometryExpression with CodegenFallback {
inputExpressions.validateLength(1)
- override def nullable: Boolean = false
-
- override def eval(input: InternalRow): Any = {
- val geometry = inputExpressions.head.toGeometry(input)
+ override protected def nullSafeEval(geometry: Geometry): Any = {
geometry.getSRID
}
@@ -658,13 +597,16 @@ case class ST_SetSRID(inputExpressions: Seq[Expression])
extends Expression with CodegenFallback {
inputExpressions.validateLength(2)
- override def nullable: Boolean = false
+ override def nullable: Boolean = true
override def eval(input: InternalRow): Any = {
- val geometry = inputExpressions.head.toGeometry(input)
val srid = inputExpressions(1).eval(input).asInstanceOf[Integer]
- val factory = new GeometryFactory(geometry.getPrecisionModel, srid, geometry.getFactory.getCoordinateSequenceFactory)
- new GenericArrayData(GeometrySerializer.serialize(factory.createGeometry(geometry)))
+ inputExpressions(0).toGeometry(input) match {
+ case geometry: Geometry =>
+ val factory = new GeometryFactory(geometry.getPrecisionModel, srid, geometry.getFactory.getCoordinateSequenceFactory)
+ factory.createGeometry(geometry).toGenericArrayData
+ case _ => null
+ }
}
override def dataType: DataType = GeometryUDT
@@ -677,13 +619,10 @@ case class ST_SetSRID(inputExpressions: Seq[Expression])
}
case class ST_GeometryType(inputExpressions: Seq[Expression])
- extends Expression with CodegenFallback {
+ extends UnaryGeometryExpression with CodegenFallback {
assert(inputExpressions.length == 1)
- override def nullable: Boolean = false
-
- override def eval(input: InternalRow): Any = {
- val geometry = GeometrySerializer.deserialize(inputExpressions.head.eval(input).asInstanceOf[ArrayData])
+ override protected def nullSafeEval(geometry: Geometry): Any = {
UTF8String.fromString("ST_" + geometry.getGeometryType)
}
@@ -704,18 +643,14 @@ case class ST_GeometryType(inputExpressions: Seq[Expression])
* @param inputExpressions Geometry
*/
case class ST_LineMerge(inputExpressions: Seq[Expression])
- extends Expression with CodegenFallback {
+ extends UnaryGeometryExpression with CodegenFallback {
assert(inputExpressions.length == 1)
// Definition of the Geometry Collection Empty
lazy val GeometryFactory = new GeometryFactory()
lazy val emptyGeometry = GeometryFactory.createGeometryCollection(null)
- override def nullable: Boolean = false
-
- override def eval(input: InternalRow): Any = {
- val geometry = GeometrySerializer.deserialize(inputExpressions.head.eval(input).asInstanceOf[ArrayData])
-
+ override protected def nullSafeEval(geometry: Geometry): Any = {
val merger = new LineMerger()
val output: Geometry = geometry match {
@@ -748,14 +683,12 @@ case class ST_LineMerge(inputExpressions: Seq[Expression])
}
case class ST_Azimuth(inputExpressions: Seq[Expression])
- extends Expression with CodegenFallback {
+ extends BinaryGeometryExpression with CodegenFallback {
assert(inputExpressions.length == 2)
- override def nullable: Boolean = false
- override def eval(input: InternalRow): Any = {
- val geometries = (inputExpressions(0).toGeometry(input), inputExpressions(1).toGeometry(input))
- geometries match {
- case (pointA: Point, pointB: Point) => calculateAzimuth(pointA, pointB)
+ override def nullSafeEval(leftGeometry: Geometry, rightGeometry: Geometry): Any = {
+ (leftGeometry, rightGeometry) match {
+ case (pointA: Point, pointB: Point) => calculateAzimuth(pointA, pointB)
}
}
@@ -776,14 +709,10 @@ case class ST_Azimuth(inputExpressions: Seq[Expression])
}
case class ST_X(inputExpressions: Seq[Expression])
- extends Expression with CodegenFallback {
+ extends UnaryGeometryExpression with CodegenFallback {
assert(inputExpressions.length == 1)
- override def nullable: Boolean = true
-
- override def eval(input: InternalRow): Any = {
- val geometry = inputExpressions.head.toGeometry(input)
-
+ override protected def nullSafeEval(geometry: Geometry): Any = {
geometry match {
case point: Point => point.getX
case _ => null
@@ -801,14 +730,10 @@ case class ST_X(inputExpressions: Seq[Expression])
case class ST_Y(inputExpressions: Seq[Expression])
- extends Expression with CodegenFallback {
+ extends UnaryGeometryExpression with CodegenFallback {
assert(inputExpressions.length == 1)
- override def nullable: Boolean = true
-
- override def eval(input: InternalRow): Any = {
- val geometry = inputExpressions.head.toGeometry(input)
-
+ override protected def nullSafeEval(geometry: Geometry): Any = {
geometry match {
case point: Point => point.getY
case _ => null
@@ -825,14 +750,10 @@ case class ST_Y(inputExpressions: Seq[Expression])
}
case class ST_Z(inputExpressions: Seq[Expression])
- extends Expression with CodegenFallback {
+ extends UnaryGeometryExpression with CodegenFallback {
assert(inputExpressions.length == 1)
- override def nullable: Boolean = true
-
- override def eval(input: InternalRow): Any = {
- val geometry = inputExpressions.head.toGeometry(input)
-
+ override protected def nullSafeEval(geometry: Geometry): Any = {
geometry match {
case point: Point => point.getCoordinate.getZ
case _ => null
@@ -849,13 +770,10 @@ case class ST_Z(inputExpressions: Seq[Expression])
}
case class ST_StartPoint(inputExpressions: Seq[Expression])
- extends Expression with CodegenFallback {
+ extends UnaryGeometryExpression with CodegenFallback {
assert(inputExpressions.length == 1)
- override def nullable: Boolean = true
-
- override def eval(input: InternalRow): Any = {
- val geometry = inputExpressions.head.toGeometry(input)
+ override protected def nullSafeEval(geometry: Geometry): Any = {
geometry match {
case line: LineString => line.getPointN(0).toGenericArrayData
case _ => null
@@ -873,14 +791,11 @@ case class ST_StartPoint(inputExpressions: Seq[Expression])
case class ST_Boundary(inputExpressions: Seq[Expression])
- extends Expression with CodegenFallback {
- override def nullable: Boolean = true
+ extends UnaryGeometryExpression with CodegenFallback {
- override def eval(input: InternalRow): Any = {
- val geometry = inputExpressions.head.toGeometry(input)
+ override protected def nullSafeEval(geometry: Geometry): Any = {
val geometryBoundary = geometry.getBoundary
geometryBoundary.toGenericArrayData
-
}
override def dataType: DataType = GeometryUDT
@@ -894,17 +809,11 @@ case class ST_Boundary(inputExpressions: Seq[Expression])
case class ST_MinimumBoundingRadius(inputExpressions: Seq[Expression])
- extends Expression with CodegenFallback {
+ extends UnaryGeometryExpression with CodegenFallback {
private val geometryFactory = new GeometryFactory()
- override def nullable: Boolean = true
-
- override def eval(input: InternalRow): Any = {
- val geometry = inputExpressions.head.toGeometry(input)
- geometry match {
- case geom: Geometry => getMinimumBoundingRadius(geom)
- case _ => null
- }
+ override protected def nullSafeEval(geometry: Geometry): Any = {
+ getMinimumBoundingRadius(geometry)
}
private def getMinimumBoundingRadius(geom: Geometry): InternalRow = {
@@ -1061,11 +970,9 @@ case class ST_LineInterpolatePoint(inputExpressions: Seq[Expression])
case class ST_EndPoint(inputExpressions: Seq[Expression])
- extends Expression with CodegenFallback {
- override def nullable: Boolean = true
+ extends UnaryGeometryExpression with CodegenFallback {
- override def eval(input: InternalRow): Any = {
- val geometry = inputExpressions.head.toGeometry(input)
+ override protected def nullSafeEval(geometry: Geometry): Any = {
geometry match {
case string: LineString => string.getEndPoint.toGenericArrayData
case _ => null
@@ -1083,11 +990,9 @@ case class ST_EndPoint(inputExpressions: Seq[Expression])
}
case class ST_ExteriorRing(inputExpressions: Seq[Expression])
- extends Expression with CodegenFallback {
- override def nullable: Boolean = true
+ extends UnaryGeometryExpression with CodegenFallback {
- override def eval(input: InternalRow): Any = {
- val geometry = inputExpressions.head.toGeometry(input)
+ override protected def nullSafeEval(geometry: Geometry): Any = {
geometry match {
case polygon: Polygon => polygon.getExteriorRing.toGenericArrayData
case _ => null
@@ -1165,13 +1070,10 @@ case class ST_InteriorRingN(inputExpressions: Seq[Expression])
}
case class ST_Dump(inputExpressions: Seq[Expression])
- extends Expression with CodegenFallback {
+ extends UnaryGeometryExpression with CodegenFallback {
assert(inputExpressions.length == 1)
- override def nullable: Boolean = true
-
- override def eval(input: InternalRow): Any = {
- val geometry = inputExpressions.head.toGeometry(input)
+ override protected def nullSafeEval(geometry: Geometry): Any = {
val geometryCollection = geometry match {
case collection: GeometryCollection => {
val numberOfGeometries = collection.getNumGeometries
@@ -1194,13 +1096,10 @@ case class ST_Dump(inputExpressions: Seq[Expression])
}
case class ST_DumpPoints(inputExpressions: Seq[Expression])
- extends Expression with CodegenFallback {
+ extends UnaryGeometryExpression with CodegenFallback {
assert(inputExpressions.length == 1)
- override def nullable: Boolean = true
-
- override def eval(input: InternalRow): Any = {
- val geometry = inputExpressions.head.toGeometry(input)
+ override protected def nullSafeEval(geometry: Geometry): Any = {
ArrayData.toArrayData(geometry.getPoints.map(geom => geom.toGenericArrayData))
}
@@ -1215,13 +1114,10 @@ case class ST_DumpPoints(inputExpressions: Seq[Expression])
case class ST_IsClosed(inputExpressions: Seq[Expression])
- extends Expression with CodegenFallback {
+ extends UnaryGeometryExpression with CodegenFallback {
assert(inputExpressions.length == 1)
- override def nullable: Boolean = true
-
- override def eval(input: InternalRow): Any = {
- val geometry = inputExpressions.head.toGeometry(input)
+ override protected def nullSafeEval(geometry: Geometry): Any = {
geometry match {
case circle: Circle => true
case point: MultiPoint => true
@@ -1245,13 +1141,10 @@ case class ST_IsClosed(inputExpressions: Seq[Expression])
}
case class ST_NumInteriorRings(inputExpressions: Seq[Expression])
- extends Expression with CodegenFallback {
+ extends UnaryGeometryExpression with CodegenFallback {
assert(inputExpressions.length == 1)
- override def nullable: Boolean = true
-
- override def eval(input: InternalRow): Any = {
- val geometry = inputExpressions.head.toGeometry(input)
+ override protected def nullSafeEval(geometry: Geometry): Any = {
geometry match {
case polygon: Polygon => polygon.getNumInteriorRing
case _: Geometry => null
@@ -1352,13 +1245,10 @@ case class ST_RemovePoint(inputExpressions: Seq[Expression])
}
case class ST_IsRing(inputExpressions: Seq[Expression])
- extends Expression with CodegenFallback {
+ extends UnaryGeometryExpression with CodegenFallback {
assert(inputExpressions.length == 1)
- override def nullable: Boolean = true
-
- override def eval(input: InternalRow): Any = {
- val geometry = inputExpressions.head.toGeometry(input)
+ override protected def nullSafeEval(geometry: Geometry): Any = {
geometry match {
case string: LineString => string.isSimple & string.isClosed
case _ => null
@@ -1383,13 +1273,10 @@ case class ST_IsRing(inputExpressions: Seq[Expression])
* @param inputExpressions Geometry
*/
case class ST_NumGeometries(inputExpressions: Seq[Expression])
- extends Expression with CodegenFallback {
+ extends UnaryGeometryExpression with CodegenFallback {
assert(inputExpressions.length == 1)
- override def nullable: Boolean = false
-
- override def eval(input: InternalRow): Any = {
- val geometry = GeometrySerializer.deserialize(inputExpressions.head.eval(input).asInstanceOf[ArrayData])
+ override protected def nullSafeEval(geometry: Geometry): Any = {
geometry.getNumGeometries()
}
@@ -1409,13 +1296,10 @@ case class ST_NumGeometries(inputExpressions: Seq[Expression])
* @param inputExpressions Geometry
*/
case class ST_FlipCoordinates(inputExpressions: Seq[Expression])
- extends Expression with CodegenFallback {
+ extends UnaryGeometryExpression with CodegenFallback {
assert(inputExpressions.length == 1)
- override def nullable: Boolean = false
-
- override def eval(input: InternalRow): Any = {
- val geometry = GeometrySerializer.deserialize(inputExpressions(0).eval(input).asInstanceOf[ArrayData])
+ override protected def nullSafeEval(geometry: Geometry): Any = {
GeomUtils.flipCoordinates(geometry)
geometry.toGenericArrayData
}
@@ -1561,18 +1445,13 @@ case class ST_GeoHash(inputExpressions: Seq[Expression])
}
case class ST_Difference(inputExpressions: Seq[Expression])
- extends Expression with CodegenFallback {
+ extends BinaryGeometryExpression with CodegenFallback {
assert(inputExpressions.length == 2)
lazy val GeometryFactory = new GeometryFactory()
lazy val emptyPolygon = GeometryFactory.createPolygon(null, null)
- override def nullable: Boolean = false
-
- override def eval(inputRow: InternalRow): Any = {
- val leftGeometry = GeometrySerializer.deserialize(inputExpressions(0).eval(inputRow).asInstanceOf[ArrayData])
- val rightGeometry = GeometrySerializer.deserialize(inputExpressions(1).eval(inputRow).asInstanceOf[ArrayData])
-
+ override protected def nullSafeEval(leftGeometry: Geometry, rightGeometry: Geometry): Any = {
val isIntersects = leftGeometry.intersects(rightGeometry)
lazy val isRightContainsLeft = rightGeometry.contains(leftGeometry)
diff --git a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/NullSafeExpressions.scala b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/NullSafeExpressions.scala
new file mode 100644
index 0000000..8b7af2b
--- /dev/null
+++ b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/NullSafeExpressions.scala
@@ -0,0 +1,57 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.spark.sql.sedona_sql.expressions
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.Expression
+import org.locationtech.jts.geom.Geometry
+import org.apache.spark.sql.sedona_sql.expressions.implicits._
+
+/**
+ * Base class for SQL expressions that operate on a single geometry argument.
+ *
+ * The first input expression is converted with `toGeometry`. When the converted value is not a
+ * [[Geometry]] instance (a null value never matches the type pattern), the expression evaluates
+ * to null and the subclass logic is not invoked.
+ */
+abstract class UnaryGeometryExpression extends Expression {
+  /** Argument expressions; only the first one is read by `eval`. */
+  def inputExpressions: Seq[Expression]
+
+  // `eval` can return null, so the expression always reports itself as nullable.
+  override def nullable: Boolean = true
+
+  override def eval(input: InternalRow): Any =
+    inputExpressions.head.toGeometry(input) match {
+      case geom: Geometry => nullSafeEval(geom)
+      case _ => null
+    }
+
+  /**
+   * Computes the expression value for a geometry that matched the `Geometry` pattern in `eval`.
+   *
+   * @param geometry geometry obtained from the first input expression
+   * @return the value of the expression for the current row
+   */
+  protected def nullSafeEval(geometry: Geometry): Any
+}
+
+/**
+ * Base class for SQL expressions that operate on two geometry arguments.
+ *
+ * The first two input expressions are converted with `toGeometry`. The subclass logic runs only
+ * when both converted values are [[Geometry]] instances (a null value never matches the type
+ * pattern); otherwise the expression evaluates to null.
+ */
+abstract class BinaryGeometryExpression extends Expression {
+  /** Argument expressions; `eval` reads the elements at index 0 and 1. */
+  def inputExpressions: Seq[Expression]
+
+  // `eval` can return null, so the expression always reports itself as nullable.
+  override def nullable: Boolean = true
+
+  override def eval(input: InternalRow): Any =
+    inputExpressions(0).toGeometry(input) match {
+      case leftGeometry: Geometry =>
+        // The right operand is evaluated only after the left one is known to be usable, so a
+        // null left input skips the cost of evaluating the second argument.
+        inputExpressions(1).toGeometry(input) match {
+          case rightGeometry: Geometry => nullSafeEval(leftGeometry, rightGeometry)
+          case _ => null
+        }
+      case _ => null
+    }
+
+  /**
+   * Computes the expression value once both arguments matched the `Geometry` pattern in `eval`.
+   *
+   * @param leftGeometry  geometry obtained from the first input expression
+   * @param rightGeometry geometry obtained from the second input expression
+   * @return the value of the expression for the current row
+   */
+  protected def nullSafeEval(leftGeometry: Geometry, rightGeometry: Geometry): Any
+}
\ No newline at end of file
diff --git a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/implicits.scala b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/implicits.scala
index e31efe3..62a9dbd 100644
--- a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/implicits.scala
+++ b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/implicits.scala
@@ -23,6 +23,7 @@ import org.apache.sedona.sql.utils.GeometrySerializer
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData}
+import org.apache.spark.unsafe.types.UTF8String
import org.locationtech.jts.geom.{Geometry, GeometryFactory, Point}
object implicits {
@@ -38,6 +39,13 @@ object implicits {
def toInt(input: InternalRow): Int = {
inputExpression.eval(input).asInstanceOf[Int]
}
+
+    /**
+     * Evaluates `inputExpression` against the given row and returns the result as a String.
+     *
+     * The evaluated value is cast to `UTF8String`, so a non-null value of any other type
+     * fails the cast.
+     *
+     * @param input row the expression is evaluated against
+     * @return the `toString` of the evaluated `UTF8String`, or null when the value is null
+     */
+    def asString(input: InternalRow): String = {
+      val utf8 = inputExpression.eval(input).asInstanceOf[UTF8String]
+      if (utf8 == null) null else utf8.toString
+    }
}
implicit class SequenceEnhancer[T](seq: Seq[T]) {
diff --git a/sql/src/test/scala/org/apache/sedona/sql/functionTestScala.scala b/sql/src/test/scala/org/apache/sedona/sql/functionTestScala.scala
index c6bcf30..f788543 100644
--- a/sql/src/test/scala/org/apache/sedona/sql/functionTestScala.scala
+++ b/sql/src/test/scala/org/apache/sedona/sql/functionTestScala.scala
@@ -1200,4 +1200,109 @@ class functionTestScala extends TestBaseScala with Matchers with GeometrySample
lineString2D.getFactory.createPoint(interPoint3D).toText
)
}
+
+
+  it("handles nulls") {
+    // Each query passes only null geometry arguments; the first column of the first result row
+    // must be null.
+    val nullResultQueries = Seq(
+      "select ST_Distance(null, null)",
+      "select ST_3DDistance(null, null)",
+      "select ST_ConvexHull(null)",
+      "select ST_NPoints(null)",
+      "select ST_Buffer(null, 0)",
+      "select ST_Envelope(null)",
+      "select ST_Length(null)",
+      "select ST_Area(null)",
+      "select ST_Centroid(null)",
+      "select ST_Transform(null, null, null)",
+      "select ST_Intersection(null, null)",
+      "select ST_IsValid(null)",
+      "select ST_IsSimple(null)",
+      "select ST_SimplifyPreserveTopology(null, 1)",
+      "select ST_PrecisionReduce(null, 1)",
+      "select ST_AsText(null)",
+      "select ST_AsGeoJSON(null)",
+      "select ST_AsBinary(null)",
+      "select ST_AsEWKB(null)",
+      "select ST_SRID(null)",
+      "select ST_SetSRID(null, 4326)",
+      "select ST_GeometryType(null)",
+      "select ST_LineMerge(null)",
+      "select ST_Azimuth(null, null)",
+      "select ST_X(null)",
+      "select ST_Y(null)",
+      "select ST_Z(null)",
+      "select ST_StartPoint(null)",
+      "select ST_Boundary(null)",
+      "select ST_MinimumBoundingRadius(null)",
+      "select ST_LineSubstring(null, 0, 0)",
+      "select ST_LineInterpolatePoint(null, 0)",
+      "select ST_EndPoint(null)",
+      "select ST_ExteriorRing(null)",
+      "select ST_GeometryN(null, 0)",
+      "select ST_InteriorRingN(null, 0)",
+      "select ST_Dump(null)",
+      "select ST_DumpPoints(null)",
+      "select ST_IsClosed(null)",
+      "select ST_NumInteriorRings(null)",
+      "select ST_AddPoint(null, null)",
+      "select ST_RemovePoint(null, null)",
+      "select ST_IsRing(null)",
+      "select ST_NumGeometries(null)",
+      "select ST_FlipCoordinates(null)",
+      "select ST_SubDivide(null, 0)",
+      "select ST_MakePolygon(null)",
+      "select ST_GeoHash(null, 1)",
+      "select ST_Difference(null, null)"
+    )
+    nullResultQueries.foreach { query =>
+      // All checks share this assertion, so the query is passed as the clue to identify failures.
+      assert(sparkSession.sql(query).first().get(0) == null, query)
+    }
+
+    // ST_SubDivideExplode is checked by row count instead: a null input must produce no rows.
+    val explodeQuery = "select ST_SubDivideExplode(null, 0)"
+    assert(sparkSession.sql(explodeQuery).count() == 0, explodeQuery)
+  }
}