You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by we...@apache.org on 2020/05/29 07:30:19 UTC
[spark] branch master updated: [SPARK-28481][SQL] More expressions
should extend NullIntolerant
This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 91148f4 [SPARK-28481][SQL] More expressions should extend NullIntolerant
91148f4 is described below
commit 91148f428b61b5e44c17bd65ceff74e0a8b4b3f5
Author: Yuming Wang <yu...@ebay.com>
AuthorDate: Fri May 29 07:28:57 2020 +0000
[SPARK-28481][SQL] More expressions should extend NullIntolerant
### What changes were proposed in this pull request?
1. Make more expressions extend `NullIntolerant`.
2. Add a checker (in `ExpressionInfoSuite`) to identify whether an expression is `NullIntolerant`.
### Why are the changes needed?
Avoid skewed joins when the join column has many null values, which can improve query performance. For example:
```sql
CREATE TABLE t1(c1 string, c2 string) USING parquet;
CREATE TABLE t2(c1 string, c2 string) USING parquet;
EXPLAIN SELECT t1.* FROM t1 JOIN t2 ON upper(t1.c1) = upper(t2.c1);
```
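Before and after this PR (the first plan is before; the second, which adds `isnotnull` filters on the join columns, is after):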
```sql
== Physical Plan ==
*(2) Project [c1#5, c2#6]
+- *(2) BroadcastHashJoin [upper(c1#5)], [upper(c1#7)], Inner, BuildLeft
:- BroadcastExchange HashedRelationBroadcastMode(List(upper(input[0, string, true]))), [id=#41]
: +- *(1) ColumnarToRow
: +- FileScan parquet default.t1[c1#5,c2#6]
+- *(2) ColumnarToRow
+- FileScan parquet default.t2[c1#7]
== Physical Plan ==
*(2) Project [c1#5, c2#6]
+- *(2) BroadcastHashJoin [upper(c1#5)], [upper(c1#7)], Inner, BuildRight
:- *(2) Project [c1#5, c2#6]
: +- *(2) Filter isnotnull(c1#5)
: +- *(2) ColumnarToRow
: +- FileScan parquet default.t1[c1#5,c2#6]
+- BroadcastExchange HashedRelationBroadcastMode(List(upper(input[0, string, true]))), [id=#59]
+- *(1) Project [c1#7]
+- *(1) Filter isnotnull(c1#7)
+- *(1) ColumnarToRow
+- FileScan parquet default.t2[c1#7]
```
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
Unit test.
Closes #28626 from wangyum/SPARK-28481.
Authored-by: Yuming Wang <yu...@ebay.com>
Signed-off-by: Wenchen Fan <we...@databricks.com>
---
.../sql/catalyst/expressions/TimeWindow.scala | 2 +-
.../catalyst/expressions/bitwiseExpressions.scala | 6 +-
.../expressions/collectionOperations.scala | 39 +++++++-----
.../catalyst/expressions/complexTypeCreator.scala | 4 +-
.../sql/catalyst/expressions/csvExpressions.scala | 3 +-
.../catalyst/expressions/datetimeExpressions.scala | 73 +++++++++++++---------
.../catalyst/expressions/decimalExpressions.scala | 4 +-
.../spark/sql/catalyst/expressions/hash.scala | 11 ++--
.../catalyst/expressions/intervalExpressions.scala | 6 +-
.../sql/catalyst/expressions/jsonExpressions.scala | 6 +-
.../sql/catalyst/expressions/mathExpressions.scala | 23 ++++---
.../catalyst/expressions/regexpExpressions.scala | 6 +-
.../catalyst/expressions/stringExpressions.scala | 58 ++++++++++-------
.../spark/sql/catalyst/expressions/xml/xpath.scala | 3 +-
.../sql/expressions/ExpressionInfoSuite.scala | 39 +++++++++++-
15 files changed, 180 insertions(+), 103 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala
index 82d6894..f7fe467 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala
@@ -144,7 +144,7 @@ object TimeWindow {
case class PreciseTimestampConversion(
child: Expression,
fromType: DataType,
- toType: DataType) extends UnaryExpression with ExpectsInputTypes {
+ toType: DataType) extends UnaryExpression with ExpectsInputTypes with NullIntolerant {
override def inputTypes: Seq[AbstractDataType] = Seq(fromType)
override def dataType: DataType = toType
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala
index 7b819db..342b14e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala
@@ -127,7 +127,8 @@ case class BitwiseXor(left: Expression, right: Expression) extends BinaryArithme
> SELECT _FUNC_ 0;
-1
""")
-case class BitwiseNot(child: Expression) extends UnaryExpression with ExpectsInputTypes {
+case class BitwiseNot(child: Expression)
+ extends UnaryExpression with ExpectsInputTypes with NullIntolerant {
override def inputTypes: Seq[AbstractDataType] = Seq(IntegralType)
@@ -164,7 +165,8 @@ case class BitwiseNot(child: Expression) extends UnaryExpression with ExpectsInp
0
""",
since = "3.0.0")
-case class BitwiseCount(child: Expression) extends UnaryExpression with ExpectsInputTypes {
+case class BitwiseCount(child: Expression)
+ extends UnaryExpression with ExpectsInputTypes with NullIntolerant {
override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(IntegralType, BooleanType))
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index 4fd68dc..b32e9ee 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -141,7 +141,7 @@ object Size {
""",
group = "map_funcs")
case class MapKeys(child: Expression)
- extends UnaryExpression with ExpectsInputTypes {
+ extends UnaryExpression with ExpectsInputTypes with NullIntolerant {
override def inputTypes: Seq[AbstractDataType] = Seq(MapType)
@@ -332,7 +332,7 @@ case class ArraysZip(children: Seq[Expression]) extends Expression with ExpectsI
""",
group = "map_funcs")
case class MapValues(child: Expression)
- extends UnaryExpression with ExpectsInputTypes {
+ extends UnaryExpression with ExpectsInputTypes with NullIntolerant {
override def inputTypes: Seq[AbstractDataType] = Seq(MapType)
@@ -361,7 +361,8 @@ case class MapValues(child: Expression)
""",
group = "map_funcs",
since = "3.0.0")
-case class MapEntries(child: Expression) extends UnaryExpression with ExpectsInputTypes {
+case class MapEntries(child: Expression)
+ extends UnaryExpression with ExpectsInputTypes with NullIntolerant {
override def inputTypes: Seq[AbstractDataType] = Seq(MapType)
@@ -649,7 +650,7 @@ case class MapConcat(children: Seq[Expression]) extends ComplexTypeMergingExpres
""",
group = "map_funcs",
since = "2.4.0")
-case class MapFromEntries(child: Expression) extends UnaryExpression {
+case class MapFromEntries(child: Expression) extends UnaryExpression with NullIntolerant {
@transient
private lazy val dataTypeDetails: Option[(MapType, Boolean, Boolean)] = child.dataType match {
@@ -873,7 +874,7 @@ object ArraySortLike {
group = "array_funcs")
// scalastyle:on line.size.limit
case class SortArray(base: Expression, ascendingOrder: Expression)
- extends BinaryExpression with ArraySortLike {
+ extends BinaryExpression with ArraySortLike with NullIntolerant {
def this(e: Expression) = this(e, Literal(true))
@@ -1017,7 +1018,8 @@ case class Shuffle(child: Expression, randomSeed: Option[Long] = None)
Reverse logic for arrays is available since 2.4.0.
"""
)
-case class Reverse(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
+case class Reverse(child: Expression)
+ extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant {
// Input types are utilized by type coercion in ImplicitTypeCasts.
override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(StringType, ArrayType))
@@ -1086,7 +1088,7 @@ case class Reverse(child: Expression) extends UnaryExpression with ImplicitCastI
""",
group = "array_funcs")
case class ArrayContains(left: Expression, right: Expression)
- extends BinaryExpression with ImplicitCastInputTypes {
+ extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant {
override def dataType: DataType = BooleanType
@@ -1185,7 +1187,7 @@ case class ArrayContains(left: Expression, right: Expression)
since = "2.4.0")
// scalastyle:off line.size.limit
case class ArraysOverlap(left: Expression, right: Expression)
- extends BinaryArrayExpressionWithImplicitCast {
+ extends BinaryArrayExpressionWithImplicitCast with NullIntolerant {
override def checkInputDataTypes(): TypeCheckResult = super.checkInputDataTypes() match {
case TypeCheckResult.TypeCheckSuccess =>
@@ -1410,7 +1412,7 @@ case class ArraysOverlap(left: Expression, right: Expression)
since = "2.4.0")
// scalastyle:on line.size.limit
case class Slice(x: Expression, start: Expression, length: Expression)
- extends TernaryExpression with ImplicitCastInputTypes {
+ extends TernaryExpression with ImplicitCastInputTypes with NullIntolerant {
override def dataType: DataType = x.dataType
@@ -1688,7 +1690,8 @@ case class ArrayJoin(
""",
group = "array_funcs",
since = "2.4.0")
-case class ArrayMin(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
+case class ArrayMin(child: Expression)
+ extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant {
override def nullable: Boolean = true
@@ -1755,7 +1758,8 @@ case class ArrayMin(child: Expression) extends UnaryExpression with ImplicitCast
""",
group = "array_funcs",
since = "2.4.0")
-case class ArrayMax(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
+case class ArrayMax(child: Expression)
+ extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant {
override def nullable: Boolean = true
@@ -1831,7 +1835,7 @@ case class ArrayMax(child: Expression) extends UnaryExpression with ImplicitCast
group = "array_funcs",
since = "2.4.0")
case class ArrayPosition(left: Expression, right: Expression)
- extends BinaryExpression with ImplicitCastInputTypes {
+ extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant {
@transient private lazy val ordering: Ordering[Any] =
TypeUtils.getInterpretedOrdering(right.dataType)
@@ -1909,7 +1913,7 @@ case class ArrayPosition(left: Expression, right: Expression)
""",
since = "2.4.0")
case class ElementAt(left: Expression, right: Expression)
- extends GetMapValueUtil with GetArrayItemUtil {
+ extends GetMapValueUtil with GetArrayItemUtil with NullIntolerant {
@transient private lazy val mapKeyType = left.dataType.asInstanceOf[MapType].keyType
@@ -2245,7 +2249,7 @@ case class Concat(children: Seq[Expression]) extends ComplexTypeMergingExpressio
""",
group = "array_funcs",
since = "2.4.0")
-case class Flatten(child: Expression) extends UnaryExpression {
+case class Flatten(child: Expression) extends UnaryExpression with NullIntolerant {
private def childDataType: ArrayType = child.dataType.asInstanceOf[ArrayType]
@@ -2884,7 +2888,7 @@ case class ArrayRepeat(left: Expression, right: Expression)
group = "array_funcs",
since = "2.4.0")
case class ArrayRemove(left: Expression, right: Expression)
- extends BinaryExpression with ImplicitCastInputTypes {
+ extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant {
override def dataType: DataType = left.dataType
@@ -3081,7 +3085,7 @@ trait ArraySetLike {
group = "array_funcs",
since = "2.4.0")
case class ArrayDistinct(child: Expression)
- extends UnaryExpression with ArraySetLike with ExpectsInputTypes {
+ extends UnaryExpression with ArraySetLike with ExpectsInputTypes with NullIntolerant {
override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType)
@@ -3219,7 +3223,8 @@ case class ArrayDistinct(child: Expression)
/**
* Will become common base class for [[ArrayUnion]], [[ArrayIntersect]], and [[ArrayExcept]].
*/
-trait ArrayBinaryLike extends BinaryArrayExpressionWithImplicitCast with ArraySetLike {
+trait ArrayBinaryLike
+ extends BinaryArrayExpressionWithImplicitCast with ArraySetLike with NullIntolerant {
override protected def dt: DataType = dataType
override protected def et: DataType = elementType
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
index 5212ef3..1b4a705 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
@@ -255,7 +255,7 @@ object CreateMap {
{1.0:"2",3.0:"4"}
""", since = "2.4.0")
case class MapFromArrays(left: Expression, right: Expression)
- extends BinaryExpression with ExpectsInputTypes {
+ extends BinaryExpression with ExpectsInputTypes with NullIntolerant {
override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, ArrayType)
@@ -476,7 +476,7 @@ case class CreateNamedStruct(children: Seq[Expression]) extends Expression {
since = "2.0.1")
// scalastyle:on line.size.limit
case class StringToMap(text: Expression, pairDelim: Expression, keyValueDelim: Expression)
- extends TernaryExpression with ExpectsInputTypes {
+ extends TernaryExpression with ExpectsInputTypes with NullIntolerant {
def this(child: Expression, pairDelim: Expression) = {
this(child, pairDelim, Literal(":"))
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala
index 5140db9..f9ccf3c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala
@@ -211,7 +211,8 @@ case class StructsToCsv(
options: Map[String, String],
child: Expression,
timeZoneId: Option[String] = None)
- extends UnaryExpression with TimeZoneAwareExpression with CodegenFallback with ExpectsInputTypes {
+ extends UnaryExpression with TimeZoneAwareExpression with CodegenFallback with ExpectsInputTypes
+ with NullIntolerant {
override def nullable: Boolean = true
def this(options: Map[String, String], child: Expression) = this(options, child, None)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala
index 7dc008a..4f3db1b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala
@@ -198,7 +198,7 @@ case class CurrentBatchTimestamp(
group = "datetime_funcs",
since = "1.5.0")
case class DateAdd(startDate: Expression, days: Expression)
- extends BinaryExpression with ExpectsInputTypes {
+ extends BinaryExpression with ExpectsInputTypes with NullIntolerant {
override def left: Expression = startDate
override def right: Expression = days
@@ -234,7 +234,7 @@ case class DateAdd(startDate: Expression, days: Expression)
group = "datetime_funcs",
since = "1.5.0")
case class DateSub(startDate: Expression, days: Expression)
- extends BinaryExpression with ExpectsInputTypes {
+ extends BinaryExpression with ExpectsInputTypes with NullIntolerant {
override def left: Expression = startDate
override def right: Expression = days
@@ -266,7 +266,8 @@ case class DateSub(startDate: Expression, days: Expression)
group = "datetime_funcs",
since = "1.5.0")
case class Hour(child: Expression, timeZoneId: Option[String] = None)
- extends UnaryExpression with TimeZoneAwareExpression with ImplicitCastInputTypes {
+ extends UnaryExpression with TimeZoneAwareExpression with ImplicitCastInputTypes
+ with NullIntolerant {
def this(child: Expression) = this(child, None)
@@ -298,7 +299,8 @@ case class Hour(child: Expression, timeZoneId: Option[String] = None)
group = "datetime_funcs",
since = "1.5.0")
case class Minute(child: Expression, timeZoneId: Option[String] = None)
- extends UnaryExpression with TimeZoneAwareExpression with ImplicitCastInputTypes {
+ extends UnaryExpression with TimeZoneAwareExpression with ImplicitCastInputTypes
+ with NullIntolerant {
def this(child: Expression) = this(child, None)
@@ -330,7 +332,8 @@ case class Minute(child: Expression, timeZoneId: Option[String] = None)
group = "datetime_funcs",
since = "1.5.0")
case class Second(child: Expression, timeZoneId: Option[String] = None)
- extends UnaryExpression with TimeZoneAwareExpression with ImplicitCastInputTypes {
+ extends UnaryExpression with TimeZoneAwareExpression with ImplicitCastInputTypes
+ with NullIntolerant {
def this(child: Expression) = this(child, None)
@@ -353,7 +356,8 @@ case class Second(child: Expression, timeZoneId: Option[String] = None)
}
case class SecondWithFraction(child: Expression, timeZoneId: Option[String] = None)
- extends UnaryExpression with TimeZoneAwareExpression with ImplicitCastInputTypes {
+ extends UnaryExpression with TimeZoneAwareExpression with ImplicitCastInputTypes
+ with NullIntolerant {
def this(child: Expression) = this(child, None)
@@ -385,7 +389,8 @@ case class SecondWithFraction(child: Expression, timeZoneId: Option[String] = No
""",
group = "datetime_funcs",
since = "1.5.0")
-case class DayOfYear(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
+case class DayOfYear(child: Expression)
+ extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant {
override def inputTypes: Seq[AbstractDataType] = Seq(DateType)
@@ -402,7 +407,7 @@ case class DayOfYear(child: Expression) extends UnaryExpression with ImplicitCas
}
abstract class NumberToTimestampBase extends UnaryExpression
- with ExpectsInputTypes {
+ with ExpectsInputTypes with NullIntolerant {
protected def upScaleFactor: Long
@@ -487,7 +492,8 @@ case class MicrosToTimestamp(child: Expression)
""",
group = "datetime_funcs",
since = "1.5.0")
-case class Year(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
+case class Year(child: Expression)
+ extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant {
override def inputTypes: Seq[AbstractDataType] = Seq(DateType)
@@ -503,7 +509,8 @@ case class Year(child: Expression) extends UnaryExpression with ImplicitCastInpu
}
}
-case class YearOfWeek(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
+case class YearOfWeek(child: Expression)
+ extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant {
override def inputTypes: Seq[AbstractDataType] = Seq(DateType)
@@ -528,7 +535,8 @@ case class YearOfWeek(child: Expression) extends UnaryExpression with ImplicitCa
""",
group = "datetime_funcs",
since = "1.5.0")
-case class Quarter(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
+case class Quarter(child: Expression)
+ extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant {
override def inputTypes: Seq[AbstractDataType] = Seq(DateType)
@@ -553,7 +561,8 @@ case class Quarter(child: Expression) extends UnaryExpression with ImplicitCastI
""",
group = "datetime_funcs",
since = "1.5.0")
-case class Month(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
+case class Month(child: Expression)
+ extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant {
override def inputTypes: Seq[AbstractDataType] = Seq(DateType)
@@ -577,7 +586,8 @@ case class Month(child: Expression) extends UnaryExpression with ImplicitCastInp
30
""",
since = "1.5.0")
-case class DayOfMonth(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
+case class DayOfMonth(child: Expression)
+ extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant {
override def inputTypes: Seq[AbstractDataType] = Seq(DateType)
@@ -647,7 +657,7 @@ case class WeekDay(child: Expression) extends DayWeek {
}
}
-abstract class DayWeek extends UnaryExpression with ImplicitCastInputTypes {
+abstract class DayWeek extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant {
override def inputTypes: Seq[AbstractDataType] = Seq(DateType)
@@ -665,7 +675,8 @@ abstract class DayWeek extends UnaryExpression with ImplicitCastInputTypes {
group = "datetime_funcs",
since = "1.5.0")
// scalastyle:on line.size.limit
-case class WeekOfYear(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
+case class WeekOfYear(child: Expression)
+ extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant {
override def inputTypes: Seq[AbstractDataType] = Seq(DateType)
@@ -704,7 +715,8 @@ case class WeekOfYear(child: Expression) extends UnaryExpression with ImplicitCa
since = "1.5.0")
// scalastyle:on line.size.limit
case class DateFormatClass(left: Expression, right: Expression, timeZoneId: Option[String] = None)
- extends BinaryExpression with TimeZoneAwareExpression with ImplicitCastInputTypes {
+ extends BinaryExpression with TimeZoneAwareExpression with ImplicitCastInputTypes
+ with NullIntolerant {
def this(left: Expression, right: Expression) = this(left, right, None)
@@ -1154,7 +1166,8 @@ case class FromUnixTime(sec: Expression, format: Expression, timeZoneId: Option[
""",
group = "datetime_funcs",
since = "1.5.0")
-case class LastDay(startDate: Expression) extends UnaryExpression with ImplicitCastInputTypes {
+case class LastDay(startDate: Expression)
+ extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant {
override def child: Expression = startDate
override def inputTypes: Seq[AbstractDataType] = Seq(DateType)
@@ -1192,7 +1205,7 @@ case class LastDay(startDate: Expression) extends UnaryExpression with ImplicitC
since = "1.5.0")
// scalastyle:on line.size.limit
case class NextDay(startDate: Expression, dayOfWeek: Expression)
- extends BinaryExpression with ImplicitCastInputTypes {
+ extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant {
override def left: Expression = startDate
override def right: Expression = dayOfWeek
@@ -1248,7 +1261,7 @@ case class NextDay(startDate: Expression, dayOfWeek: Expression)
* Adds an interval to timestamp.
*/
case class TimeAdd(start: Expression, interval: Expression, timeZoneId: Option[String] = None)
- extends BinaryExpression with TimeZoneAwareExpression with ExpectsInputTypes {
+ extends BinaryExpression with TimeZoneAwareExpression with ExpectsInputTypes with NullIntolerant {
def this(start: Expression, interval: Expression) = this(start, interval, None)
@@ -1306,7 +1319,7 @@ case class DateAddInterval(
interval: Expression,
timeZoneId: Option[String] = None,
ansiEnabled: Boolean = SQLConf.get.ansiEnabled)
- extends BinaryExpression with ExpectsInputTypes with TimeZoneAwareExpression {
+ extends BinaryExpression with ExpectsInputTypes with TimeZoneAwareExpression with NullIntolerant {
override def left: Expression = start
override def right: Expression = interval
@@ -1380,7 +1393,7 @@ case class DateAddInterval(
since = "1.5.0")
// scalastyle:on line.size.limit
case class FromUTCTimestamp(left: Expression, right: Expression)
- extends BinaryExpression with ImplicitCastInputTypes {
+ extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant {
override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType, StringType)
override def dataType: DataType = TimestampType
@@ -1440,7 +1453,7 @@ case class FromUTCTimestamp(left: Expression, right: Expression)
since = "1.5.0")
// scalastyle:on line.size.limit
case class AddMonths(startDate: Expression, numMonths: Expression)
- extends BinaryExpression with ImplicitCastInputTypes {
+ extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant {
override def left: Expression = startDate
override def right: Expression = numMonths
@@ -1494,7 +1507,8 @@ case class MonthsBetween(
date2: Expression,
roundOff: Expression,
timeZoneId: Option[String] = None)
- extends TernaryExpression with TimeZoneAwareExpression with ImplicitCastInputTypes {
+ extends TernaryExpression with TimeZoneAwareExpression with ImplicitCastInputTypes
+ with NullIntolerant {
def this(date1: Expression, date2: Expression) = this(date1, date2, Literal.TrueLiteral, None)
@@ -1552,7 +1566,7 @@ case class MonthsBetween(
since = "1.5.0")
// scalastyle:on line.size.limit
case class ToUTCTimestamp(left: Expression, right: Expression)
- extends BinaryExpression with ImplicitCastInputTypes {
+ extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant {
override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType, StringType)
override def dataType: DataType = TimestampType
@@ -1906,7 +1920,7 @@ case class TruncTimestamp(
group = "datetime_funcs",
since = "1.5.0")
case class DateDiff(endDate: Expression, startDate: Expression)
- extends BinaryExpression with ImplicitCastInputTypes {
+ extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant {
override def left: Expression = endDate
override def right: Expression = startDate
@@ -1960,7 +1974,7 @@ private case class GetTimestamp(
group = "datetime_funcs",
since = "3.0.0")
case class MakeDate(year: Expression, month: Expression, day: Expression)
- extends TernaryExpression with ImplicitCastInputTypes {
+ extends TernaryExpression with ImplicitCastInputTypes with NullIntolerant {
override def children: Seq[Expression] = Seq(year, month, day)
override def inputTypes: Seq[AbstractDataType] = Seq(IntegerType, IntegerType, IntegerType)
@@ -2031,7 +2045,8 @@ case class MakeTimestamp(
sec: Expression,
timezone: Option[Expression] = None,
timeZoneId: Option[String] = None)
- extends SeptenaryExpression with TimeZoneAwareExpression with ImplicitCastInputTypes {
+ extends SeptenaryExpression with TimeZoneAwareExpression with ImplicitCastInputTypes
+ with NullIntolerant {
def this(
year: Expression,
@@ -2307,7 +2322,7 @@ case class Extract(field: Expression, source: Expression, child: Expression)
* between the given timestamps.
*/
case class SubtractTimestamps(endTimestamp: Expression, startTimestamp: Expression)
- extends BinaryExpression with ExpectsInputTypes {
+ extends BinaryExpression with ExpectsInputTypes with NullIntolerant {
override def left: Expression = endTimestamp
override def right: Expression = startTimestamp
@@ -2328,7 +2343,7 @@ case class SubtractTimestamps(endTimestamp: Expression, startTimestamp: Expressi
* Returns the interval from the `left` date (inclusive) to the `right` date (exclusive).
*/
case class SubtractDates(left: Expression, right: Expression)
- extends BinaryExpression with ImplicitCastInputTypes {
+ extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant {
override def inputTypes: Seq[AbstractDataType] = Seq(DateType, DateType)
override def dataType: DataType = CalendarIntervalType
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala
index 9014ebf..c2c70b2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala
@@ -27,7 +27,7 @@ import org.apache.spark.sql.types._
* Note: this expression is internal and created only by the optimizer,
* we don't need to do type check for it.
*/
-case class UnscaledValue(child: Expression) extends UnaryExpression {
+case class UnscaledValue(child: Expression) extends UnaryExpression with NullIntolerant {
override def dataType: DataType = LongType
override def toString: String = s"UnscaledValue($child)"
@@ -49,7 +49,7 @@ case class MakeDecimal(
child: Expression,
precision: Int,
scale: Int,
- nullOnOverflow: Boolean) extends UnaryExpression {
+ nullOnOverflow: Boolean) extends UnaryExpression with NullIntolerant {
def this(child: Expression, precision: Int, scale: Int) = {
this(child, precision, scale, !SQLConf.get.ansiEnabled)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala
index 4c8c58a..5e21b58 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala
@@ -53,7 +53,8 @@ import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
> SELECT _FUNC_('Spark');
8cde774d6f7333752ed72cacddb05126
""")
-case class Md5(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
+case class Md5(child: Expression)
+ extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant {
override def dataType: DataType = StringType
@@ -89,7 +90,7 @@ case class Md5(child: Expression) extends UnaryExpression with ImplicitCastInput
""")
// scalastyle:on line.size.limit
case class Sha2(left: Expression, right: Expression)
- extends BinaryExpression with Serializable with ImplicitCastInputTypes {
+ extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant with Serializable {
override def dataType: DataType = StringType
override def nullable: Boolean = true
@@ -160,7 +161,8 @@ case class Sha2(left: Expression, right: Expression)
> SELECT _FUNC_('Spark');
85f5955f4b27a9a4c2aab6ffe5d7189fc298b92c
""")
-case class Sha1(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
+case class Sha1(child: Expression)
+ extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant {
override def dataType: DataType = StringType
@@ -187,7 +189,8 @@ case class Sha1(child: Expression) extends UnaryExpression with ImplicitCastInpu
> SELECT _FUNC_('Spark');
1557323817
""")
-case class Crc32(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
+case class Crc32(child: Expression)
+ extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant {
override def dataType: DataType = LongType
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala
index 1a569a7..baab224 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala
@@ -31,7 +31,7 @@ abstract class ExtractIntervalPart(
val dataType: DataType,
func: CalendarInterval => Any,
funcName: String)
- extends UnaryExpression with ExpectsInputTypes with Serializable {
+ extends UnaryExpression with ExpectsInputTypes with NullIntolerant with Serializable {
override def inputTypes: Seq[AbstractDataType] = Seq(CalendarIntervalType)
@@ -82,7 +82,7 @@ object ExtractIntervalPart {
abstract class IntervalNumOperation(
interval: Expression,
num: Expression)
- extends BinaryExpression with ImplicitCastInputTypes with Serializable {
+ extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant with Serializable {
override def left: Expression = interval
override def right: Expression = num
@@ -160,7 +160,7 @@ case class MakeInterval(
hours: Expression,
mins: Expression,
secs: Expression)
- extends SeptenaryExpression with ImplicitCastInputTypes {
+ extends SeptenaryExpression with ImplicitCastInputTypes with NullIntolerant {
def this(
years: Expression,
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala
index 205e527..f4568f8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala
@@ -519,7 +519,8 @@ case class JsonToStructs(
options: Map[String, String],
child: Expression,
timeZoneId: Option[String] = None)
- extends UnaryExpression with TimeZoneAwareExpression with CodegenFallback with ExpectsInputTypes {
+ extends UnaryExpression with TimeZoneAwareExpression with CodegenFallback with ExpectsInputTypes
+ with NullIntolerant {
// The JSON input data might be missing certain fields. We force the nullability
// of the user-provided schema to avoid data corruptions. In particular, the parquet-mr encoder
@@ -638,7 +639,8 @@ case class StructsToJson(
options: Map[String, String],
child: Expression,
timeZoneId: Option[String] = None)
- extends UnaryExpression with TimeZoneAwareExpression with CodegenFallback with ExpectsInputTypes {
+ extends UnaryExpression with TimeZoneAwareExpression with CodegenFallback
+ with ExpectsInputTypes with NullIntolerant {
override def nullable: Boolean = true
def this(options: Map[String, String], child: Expression) = this(options, child, None)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
index 8c6fbc0..fe8ea2a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
@@ -57,7 +57,7 @@ abstract class LeafMathExpression(c: Double, name: String)
* @param name The short name of the function
*/
abstract class UnaryMathExpression(val f: Double => Double, name: String)
- extends UnaryExpression with Serializable with ImplicitCastInputTypes {
+ extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant with Serializable {
override def inputTypes: Seq[AbstractDataType] = Seq(DoubleType)
override def dataType: DataType = DoubleType
@@ -111,7 +111,7 @@ abstract class UnaryLogExpression(f: Double => Double, name: String)
* @param name The short name of the function
*/
abstract class BinaryMathExpression(f: (Double, Double) => Double, name: String)
- extends BinaryExpression with Serializable with ImplicitCastInputTypes {
+ extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant with Serializable {
override def inputTypes: Seq[DataType] = Seq(DoubleType, DoubleType)
@@ -324,7 +324,7 @@ case class Acosh(child: Expression)
-16
""")
case class Conv(numExpr: Expression, fromBaseExpr: Expression, toBaseExpr: Expression)
- extends TernaryExpression with ImplicitCastInputTypes {
+ extends TernaryExpression with ImplicitCastInputTypes with NullIntolerant {
override def children: Seq[Expression] = Seq(numExpr, fromBaseExpr, toBaseExpr)
override def inputTypes: Seq[AbstractDataType] = Seq(StringType, IntegerType, IntegerType)
@@ -452,7 +452,8 @@ object Factorial {
> SELECT _FUNC_(5);
120
""")
-case class Factorial(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
+case class Factorial(child: Expression)
+ extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant {
override def inputTypes: Seq[DataType] = Seq(IntegerType)
@@ -735,7 +736,7 @@ case class ToRadians(child: Expression) extends UnaryMathExpression(math.toRadia
""")
// scalastyle:on line.size.limit
case class Bin(child: Expression)
- extends UnaryExpression with Serializable with ImplicitCastInputTypes {
+ extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant with Serializable {
override def inputTypes: Seq[DataType] = Seq(LongType)
override def dataType: DataType = StringType
@@ -834,7 +835,8 @@ object Hex {
> SELECT _FUNC_('Spark SQL');
537061726B2053514C
""")
-case class Hex(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
+case class Hex(child: Expression)
+ extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant {
override def inputTypes: Seq[AbstractDataType] =
Seq(TypeCollection(LongType, BinaryType, StringType))
@@ -869,7 +871,8 @@ case class Hex(child: Expression) extends UnaryExpression with ImplicitCastInput
> SELECT decode(_FUNC_('537061726B2053514C'), 'UTF-8');
Spark SQL
""")
-case class Unhex(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
+case class Unhex(child: Expression)
+ extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant {
override def inputTypes: Seq[AbstractDataType] = Seq(StringType)
@@ -955,7 +958,7 @@ case class Pow(left: Expression, right: Expression)
4
""")
case class ShiftLeft(left: Expression, right: Expression)
- extends BinaryExpression with ImplicitCastInputTypes {
+ extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant {
override def inputTypes: Seq[AbstractDataType] =
Seq(TypeCollection(IntegerType, LongType), IntegerType)
@@ -989,7 +992,7 @@ case class ShiftLeft(left: Expression, right: Expression)
2
""")
case class ShiftRight(left: Expression, right: Expression)
- extends BinaryExpression with ImplicitCastInputTypes {
+ extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant {
override def inputTypes: Seq[AbstractDataType] =
Seq(TypeCollection(IntegerType, LongType), IntegerType)
@@ -1023,7 +1026,7 @@ case class ShiftRight(left: Expression, right: Expression)
2
""")
case class ShiftRightUnsigned(left: Expression, right: Expression)
- extends BinaryExpression with ImplicitCastInputTypes {
+ extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant {
override def inputTypes: Seq[AbstractDataType] =
Seq(TypeCollection(IntegerType, LongType), IntegerType)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala
index 3f60ca3..28924fa 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala
@@ -283,7 +283,7 @@ case class RLike(left: Expression, right: Expression) extends StringRegexExpress
""",
since = "1.5.0")
case class StringSplit(str: Expression, regex: Expression, limit: Expression)
- extends TernaryExpression with ImplicitCastInputTypes {
+ extends TernaryExpression with ImplicitCastInputTypes with NullIntolerant {
override def dataType: DataType = ArrayType(StringType)
override def inputTypes: Seq[DataType] = Seq(StringType, StringType, IntegerType)
@@ -325,7 +325,7 @@ case class StringSplit(str: Expression, regex: Expression, limit: Expression)
since = "1.5.0")
// scalastyle:on line.size.limit
case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expression)
- extends TernaryExpression with ImplicitCastInputTypes {
+ extends TernaryExpression with ImplicitCastInputTypes with NullIntolerant {
// last regex in string, we will update the pattern iff regexp value changed.
@transient private var lastRegex: UTF8String = _
@@ -433,7 +433,7 @@ object RegExpExtract {
""",
since = "1.5.0")
case class RegExpExtract(subject: Expression, regexp: Expression, idx: Expression)
- extends TernaryExpression with ImplicitCastInputTypes {
+ extends TernaryExpression with ImplicitCastInputTypes with NullIntolerant {
def this(s: Expression, r: Expression) = this(s, r, Literal(1))
// last regex in string, we will update the pattern iff regexp value changed.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
index 876588e..334a079 100755
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
@@ -334,7 +334,7 @@ trait String2StringExpression extends ImplicitCastInputTypes {
""",
since = "1.0.1")
case class Upper(child: Expression)
- extends UnaryExpression with String2StringExpression {
+ extends UnaryExpression with String2StringExpression with NullIntolerant {
// scalastyle:off caselocale
override def convert(v: UTF8String): UTF8String = v.toUpperCase
@@ -356,7 +356,8 @@ case class Upper(child: Expression)
sparksql
""",
since = "1.0.1")
-case class Lower(child: Expression) extends UnaryExpression with String2StringExpression {
+case class Lower(child: Expression)
+ extends UnaryExpression with String2StringExpression with NullIntolerant {
// scalastyle:off caselocale
override def convert(v: UTF8String): UTF8String = v.toLowerCase
@@ -435,7 +436,7 @@ case class EndsWith(left: Expression, right: Expression) extends StringPredicate
since = "2.3.0")
// scalastyle:on line.size.limit
case class StringReplace(srcExpr: Expression, searchExpr: Expression, replaceExpr: Expression)
- extends TernaryExpression with ImplicitCastInputTypes {
+ extends TernaryExpression with ImplicitCastInputTypes with NullIntolerant {
def this(srcExpr: Expression, searchExpr: Expression) = {
this(srcExpr, searchExpr, Literal(""))
@@ -601,7 +602,7 @@ object StringTranslate {
since = "1.5.0")
// scalastyle:on line.size.limit
case class StringTranslate(srcExpr: Expression, matchingExpr: Expression, replaceExpr: Expression)
- extends TernaryExpression with ImplicitCastInputTypes {
+ extends TernaryExpression with ImplicitCastInputTypes with NullIntolerant {
@transient private var lastMatching: UTF8String = _
@transient private var lastReplace: UTF8String = _
@@ -666,7 +667,7 @@ case class StringTranslate(srcExpr: Expression, matchingExpr: Expression, replac
since = "1.5.0")
// scalastyle:on line.size.limit
case class FindInSet(left: Expression, right: Expression) extends BinaryExpression
- with ImplicitCastInputTypes {
+ with ImplicitCastInputTypes with NullIntolerant {
override def inputTypes: Seq[AbstractDataType] = Seq(StringType, StringType)
@@ -1038,7 +1039,7 @@ case class StringTrimRight(
since = "1.5.0")
// scalastyle:on line.size.limit
case class StringInstr(str: Expression, substr: Expression)
- extends BinaryExpression with ImplicitCastInputTypes {
+ extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant {
override def left: Expression = str
override def right: Expression = substr
@@ -1080,7 +1081,7 @@ case class StringInstr(str: Expression, substr: Expression)
since = "1.5.0")
// scalastyle:on line.size.limit
case class SubstringIndex(strExpr: Expression, delimExpr: Expression, countExpr: Expression)
- extends TernaryExpression with ImplicitCastInputTypes {
+ extends TernaryExpression with ImplicitCastInputTypes with NullIntolerant {
override def dataType: DataType = StringType
override def inputTypes: Seq[DataType] = Seq(StringType, StringType, IntegerType)
@@ -1209,7 +1210,7 @@ case class StringLocate(substr: Expression, str: Expression, start: Expression)
""",
since = "1.5.0")
case class StringLPad(str: Expression, len: Expression, pad: Expression = Literal(" "))
- extends TernaryExpression with ImplicitCastInputTypes {
+ extends TernaryExpression with ImplicitCastInputTypes with NullIntolerant {
def this(str: Expression, len: Expression) = {
this(str, len, Literal(" "))
@@ -1250,7 +1251,7 @@ case class StringLPad(str: Expression, len: Expression, pad: Expression = Litera
""",
since = "1.5.0")
case class StringRPad(str: Expression, len: Expression, pad: Expression = Literal(" "))
- extends TernaryExpression with ImplicitCastInputTypes {
+ extends TernaryExpression with ImplicitCastInputTypes with NullIntolerant {
def this(str: Expression, len: Expression) = {
this(str, len, Literal(" "))
@@ -1540,7 +1541,8 @@ case class FormatString(children: Expression*) extends Expression with ImplicitC
Spark Sql
""",
since = "1.5.0")
-case class InitCap(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
+case class InitCap(child: Expression)
+ extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant {
override def inputTypes: Seq[DataType] = Seq(StringType)
override def dataType: DataType = StringType
@@ -1567,7 +1569,7 @@ case class InitCap(child: Expression) extends UnaryExpression with ImplicitCastI
""",
since = "1.5.0")
case class StringRepeat(str: Expression, times: Expression)
- extends BinaryExpression with ImplicitCastInputTypes {
+ extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant {
override def left: Expression = str
override def right: Expression = times
@@ -1597,7 +1599,7 @@ case class StringRepeat(str: Expression, times: Expression)
""",
since = "1.5.0")
case class StringSpace(child: Expression)
- extends UnaryExpression with ImplicitCastInputTypes {
+ extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant {
override def dataType: DataType = StringType
override def inputTypes: Seq[DataType] = Seq(IntegerType)
@@ -1742,7 +1744,8 @@ case class Left(str: Expression, len: Expression, child: Expression) extends Run
""",
since = "1.5.0")
// scalastyle:on line.size.limit
-case class Length(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
+case class Length(child: Expression)
+ extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant {
override def dataType: DataType = IntegerType
override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(StringType, BinaryType))
@@ -1770,7 +1773,8 @@ case class Length(child: Expression) extends UnaryExpression with ImplicitCastIn
72
""",
since = "2.3.0")
-case class BitLength(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
+case class BitLength(child: Expression)
+ extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant {
override def dataType: DataType = IntegerType
override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(StringType, BinaryType))
@@ -1801,7 +1805,8 @@ case class BitLength(child: Expression) extends UnaryExpression with ImplicitCas
9
""",
since = "2.3.0")
-case class OctetLength(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
+case class OctetLength(child: Expression)
+ extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant {
override def dataType: DataType = IntegerType
override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(StringType, BinaryType))
@@ -1832,7 +1837,7 @@ case class OctetLength(child: Expression) extends UnaryExpression with ImplicitC
""",
since = "1.5.0")
case class Levenshtein(left: Expression, right: Expression) extends BinaryExpression
- with ImplicitCastInputTypes {
+ with ImplicitCastInputTypes with NullIntolerant {
override def inputTypes: Seq[AbstractDataType] = Seq(StringType, StringType)
@@ -1857,7 +1862,8 @@ case class Levenshtein(left: Expression, right: Expression) extends BinaryExpres
M460
""",
since = "1.5.0")
-case class SoundEx(child: Expression) extends UnaryExpression with ExpectsInputTypes {
+case class SoundEx(child: Expression)
+ extends UnaryExpression with ExpectsInputTypes with NullIntolerant {
override def dataType: DataType = StringType
@@ -1883,7 +1889,8 @@ case class SoundEx(child: Expression) extends UnaryExpression with ExpectsInputT
50
""",
since = "1.5.0")
-case class Ascii(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
+case class Ascii(child: Expression)
+ extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant {
override def dataType: DataType = IntegerType
override def inputTypes: Seq[DataType] = Seq(StringType)
@@ -1925,7 +1932,8 @@ case class Ascii(child: Expression) extends UnaryExpression with ImplicitCastInp
""",
since = "2.3.0")
// scalastyle:on line.size.limit
-case class Chr(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
+case class Chr(child: Expression)
+ extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant {
override def dataType: DataType = StringType
override def inputTypes: Seq[DataType] = Seq(LongType)
@@ -1968,7 +1976,8 @@ case class Chr(child: Expression) extends UnaryExpression with ImplicitCastInput
U3BhcmsgU1FM
""",
since = "1.5.0")
-case class Base64(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
+case class Base64(child: Expression)
+ extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant {
override def dataType: DataType = StringType
override def inputTypes: Seq[DataType] = Seq(BinaryType)
@@ -1996,7 +2005,8 @@ case class Base64(child: Expression) extends UnaryExpression with ImplicitCastIn
Spark SQL
""",
since = "1.5.0")
-case class UnBase64(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
+case class UnBase64(child: Expression)
+ extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant {
override def dataType: DataType = BinaryType
override def inputTypes: Seq[DataType] = Seq(StringType)
@@ -2028,7 +2038,7 @@ case class UnBase64(child: Expression) extends UnaryExpression with ImplicitCast
since = "1.5.0")
// scalastyle:on line.size.limit
case class Decode(bin: Expression, charset: Expression)
- extends BinaryExpression with ImplicitCastInputTypes {
+ extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant {
override def left: Expression = bin
override def right: Expression = charset
@@ -2068,7 +2078,7 @@ case class Decode(bin: Expression, charset: Expression)
since = "1.5.0")
// scalastyle:on line.size.limit
case class Encode(value: Expression, charset: Expression)
- extends BinaryExpression with ImplicitCastInputTypes {
+ extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant {
override def left: Expression = value
override def right: Expression = charset
@@ -2112,7 +2122,7 @@ case class Encode(value: Expression, charset: Expression)
""",
since = "1.5.0")
case class FormatNumber(x: Expression, d: Expression)
- extends BinaryExpression with ExpectsInputTypes {
+ extends BinaryExpression with ExpectsInputTypes with NullIntolerant {
override def left: Expression = x
override def right: Expression = d
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xml/xpath.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xml/xpath.scala
index 55e06cb..e08a10e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xml/xpath.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xml/xpath.scala
@@ -30,7 +30,8 @@ import org.apache.spark.unsafe.types.UTF8String
*
* This is not the world's most efficient implementation due to type conversion, but works.
*/
-abstract class XPathExtract extends BinaryExpression with ExpectsInputTypes with CodegenFallback {
+abstract class XPathExtract
+ extends BinaryExpression with ExpectsInputTypes with CodegenFallback with NullIntolerant {
override def left: Expression = xml
override def right: Expression = path
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/expressions/ExpressionInfoSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/expressions/ExpressionInfoSuite.scala
index e18514c..53f9757 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/expressions/ExpressionInfoSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/expressions/ExpressionInfoSuite.scala
@@ -20,11 +20,12 @@ package org.apache.spark.sql.expressions
import scala.collection.parallel.immutable.ParVector
import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.catalyst.FunctionIdentifier
-import org.apache.spark.sql.catalyst.expressions.ExpressionInfo
+import org.apache.spark.sql.catalyst.{FunctionIdentifier, InternalRow}
+import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.execution.HiveResult.hiveResultString
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSparkSession
+import org.apache.spark.util.Utils
class ExpressionInfoSuite extends SparkFunSuite with SharedSparkSession {
@@ -156,4 +157,38 @@ class ExpressionInfoSuite extends SparkFunSuite with SharedSparkSession {
}
}
}
+
+  test("Check whether SQL expressions should extend NullIntolerant") {
+    // Exempt from the check: these extend NullIntolerant but override eval so that input1
+    // is not evaluated when input2 is 0.
+    val ignoreSet = Set(classOf[IntegralDivide], classOf[Divide], classOf[Remainder], classOf[Pmod])
+    val ignoredNames = ignoreSet.map(_.getName)
+
+    // Classes backing every registered function, minus exemptions and NonSQLExpression subtypes.
+    val catalog = spark.sessionState.catalog
+    val candidateExprsToCheck = spark.sessionState.functionRegistry.listFunction()
+      .map(ident => catalog.lookupFunctionInfo(ident).getClassName)
+      .filterNot(ignoredNames.contains)
+      .map(className => Utils.classForName(className))
+      .filterNot(cls => classOf[NonSQLExpression].isAssignableFrom(cls))
+
+    // Only subclasses of these base expressions are checked; they are NullIntolerant by default.
+    val exprTypesToCheck = Seq(
+      classOf[UnaryExpression], classOf[BinaryExpression], classOf[TernaryExpression],
+      classOf[QuaternaryExpression], classOf[SeptenaryExpression])
+    val rowCls = classOf[InternalRow]
+    for (baseClass <- exprTypesToCheck; clazz <- candidateExprsToCheck
+        if baseClass.isAssignableFrom(clazz)) {
+      val overridesEval = clazz.getMethod("eval", rowCls) != baseClass.getMethod("eval", rowCls)
+      val mixesInNullIntolerant = classOf[NullIntolerant].isAssignableFrom(clazz)
+      (overridesEval, mixesInNullIntolerant) match {
+        case (true, true) =>
+          fail(s"${clazz.getName} should not extend ${classOf[NullIntolerant].getSimpleName}, " +
+            s"or add ${clazz.getName} in the ignoreSet of this test.")
+        case (false, false) =>
+          fail(s"${clazz.getName} should extend ${classOf[NullIntolerant].getSimpleName}.")
+        case _ => assert(overridesEval != mixesInNullIntolerant)
+      }
+    }
+  }
}
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org