You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by we...@apache.org on 2020/05/29 07:30:19 UTC

[spark] branch master updated: [SPARK-28481][SQL] More expressions should extend NullIntolerant

This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 91148f4  [SPARK-28481][SQL] More expressions should extend NullIntolerant
91148f4 is described below

commit 91148f428b61b5e44c17bd65ceff74e0a8b4b3f5
Author: Yuming Wang <yu...@ebay.com>
AuthorDate: Fri May 29 07:28:57 2020 +0000

    [SPARK-28481][SQL] More expressions should extend NullIntolerant
    
    ### What changes were proposed in this pull request?
    
    1. Make more expressions extend `NullIntolerant`.
    2. Add a checker (in `ExpressionInfoSuite`) to identify whether the expression is `NullIntolerant`.
    
    ### Why are the changes needed?
    
    This avoids skewed joins when the join column has many null values, and can improve query performance. For example:
    ```sql
    CREATE TABLE t1(c1 string, c2 string) USING parquet;
    CREATE TABLE t2(c1 string, c2 string) USING parquet;
    EXPLAIN SELECT t1.* FROM t1 JOIN t2 ON upper(t1.c1) = upper(t2.c1);
    ```
    
    Before and after this PR:
    ```sql
    == Physical Plan ==
    *(2) Project [c1#5, c2#6]
    +- *(2) BroadcastHashJoin [upper(c1#5)], [upper(c1#7)], Inner, BuildLeft
       :- BroadcastExchange HashedRelationBroadcastMode(List(upper(input[0, string, true]))), [id=#41]
       :  +- *(1) ColumnarToRow
       :     +- FileScan parquet default.t1[c1#5,c2#6]
       +- *(2) ColumnarToRow
          +- FileScan parquet default.t2[c1#7]
    
    == Physical Plan ==
    *(2) Project [c1#5, c2#6]
    +- *(2) BroadcastHashJoin [upper(c1#5)], [upper(c1#7)], Inner, BuildRight
       :- *(2) Project [c1#5, c2#6]
       :  +- *(2) Filter isnotnull(c1#5)
       :     +- *(2) ColumnarToRow
       :        +- FileScan parquet default.t1[c1#5,c2#6]
       +- BroadcastExchange HashedRelationBroadcastMode(List(upper(input[0, string, true]))), [id=#59]
          +- *(1) Project [c1#7]
             +- *(1) Filter isnotnull(c1#7)
                +- *(1) ColumnarToRow
                   +- FileScan parquet default.t2[c1#7]
    
    ```
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    Unit test.
    
    Closes #28626 from wangyum/SPARK-28481.
    
    Authored-by: Yuming Wang <yu...@ebay.com>
    Signed-off-by: Wenchen Fan <we...@databricks.com>
---
 .../sql/catalyst/expressions/TimeWindow.scala      |  2 +-
 .../catalyst/expressions/bitwiseExpressions.scala  |  6 +-
 .../expressions/collectionOperations.scala         | 39 +++++++-----
 .../catalyst/expressions/complexTypeCreator.scala  |  4 +-
 .../sql/catalyst/expressions/csvExpressions.scala  |  3 +-
 .../catalyst/expressions/datetimeExpressions.scala | 73 +++++++++++++---------
 .../catalyst/expressions/decimalExpressions.scala  |  4 +-
 .../spark/sql/catalyst/expressions/hash.scala      | 11 ++--
 .../catalyst/expressions/intervalExpressions.scala |  6 +-
 .../sql/catalyst/expressions/jsonExpressions.scala |  6 +-
 .../sql/catalyst/expressions/mathExpressions.scala | 23 ++++---
 .../catalyst/expressions/regexpExpressions.scala   |  6 +-
 .../catalyst/expressions/stringExpressions.scala   | 58 ++++++++++-------
 .../spark/sql/catalyst/expressions/xml/xpath.scala |  3 +-
 .../sql/expressions/ExpressionInfoSuite.scala      | 39 +++++++++++-
 15 files changed, 180 insertions(+), 103 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala
index 82d6894..f7fe467 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala
@@ -144,7 +144,7 @@ object TimeWindow {
 case class PreciseTimestampConversion(
     child: Expression,
     fromType: DataType,
-    toType: DataType) extends UnaryExpression with ExpectsInputTypes {
+    toType: DataType) extends UnaryExpression with ExpectsInputTypes with NullIntolerant {
   override def inputTypes: Seq[AbstractDataType] = Seq(fromType)
   override def dataType: DataType = toType
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala
index 7b819db..342b14e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala
@@ -127,7 +127,8 @@ case class BitwiseXor(left: Expression, right: Expression) extends BinaryArithme
       > SELECT _FUNC_ 0;
        -1
   """)
-case class BitwiseNot(child: Expression) extends UnaryExpression with ExpectsInputTypes {
+case class BitwiseNot(child: Expression)
+  extends UnaryExpression with ExpectsInputTypes with NullIntolerant {
 
   override def inputTypes: Seq[AbstractDataType] = Seq(IntegralType)
 
@@ -164,7 +165,8 @@ case class BitwiseNot(child: Expression) extends UnaryExpression with ExpectsInp
        0
   """,
   since = "3.0.0")
-case class BitwiseCount(child: Expression) extends UnaryExpression with ExpectsInputTypes {
+case class BitwiseCount(child: Expression)
+  extends UnaryExpression with ExpectsInputTypes with NullIntolerant {
 
   override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(IntegralType, BooleanType))
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index 4fd68dc..b32e9ee 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -141,7 +141,7 @@ object Size {
   """,
   group = "map_funcs")
 case class MapKeys(child: Expression)
-  extends UnaryExpression with ExpectsInputTypes {
+  extends UnaryExpression with ExpectsInputTypes with NullIntolerant {
 
   override def inputTypes: Seq[AbstractDataType] = Seq(MapType)
 
@@ -332,7 +332,7 @@ case class ArraysZip(children: Seq[Expression]) extends Expression with ExpectsI
   """,
   group = "map_funcs")
 case class MapValues(child: Expression)
-  extends UnaryExpression with ExpectsInputTypes {
+  extends UnaryExpression with ExpectsInputTypes with NullIntolerant {
 
   override def inputTypes: Seq[AbstractDataType] = Seq(MapType)
 
@@ -361,7 +361,8 @@ case class MapValues(child: Expression)
   """,
   group = "map_funcs",
   since = "3.0.0")
-case class MapEntries(child: Expression) extends UnaryExpression with ExpectsInputTypes {
+case class MapEntries(child: Expression)
+  extends UnaryExpression with ExpectsInputTypes with NullIntolerant {
 
   override def inputTypes: Seq[AbstractDataType] = Seq(MapType)
 
@@ -649,7 +650,7 @@ case class MapConcat(children: Seq[Expression]) extends ComplexTypeMergingExpres
   """,
   group = "map_funcs",
   since = "2.4.0")
-case class MapFromEntries(child: Expression) extends UnaryExpression {
+case class MapFromEntries(child: Expression) extends UnaryExpression with NullIntolerant {
 
   @transient
   private lazy val dataTypeDetails: Option[(MapType, Boolean, Boolean)] = child.dataType match {
@@ -873,7 +874,7 @@ object ArraySortLike {
   group = "array_funcs")
 // scalastyle:on line.size.limit
 case class SortArray(base: Expression, ascendingOrder: Expression)
-  extends BinaryExpression with ArraySortLike {
+  extends BinaryExpression with ArraySortLike with NullIntolerant {
 
   def this(e: Expression) = this(e, Literal(true))
 
@@ -1017,7 +1018,8 @@ case class Shuffle(child: Expression, randomSeed: Option[Long] = None)
     Reverse logic for arrays is available since 2.4.0.
   """
 )
-case class Reverse(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
+case class Reverse(child: Expression)
+  extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant {
 
   // Input types are utilized by type coercion in ImplicitTypeCasts.
   override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(StringType, ArrayType))
@@ -1086,7 +1088,7 @@ case class Reverse(child: Expression) extends UnaryExpression with ImplicitCastI
   """,
   group = "array_funcs")
 case class ArrayContains(left: Expression, right: Expression)
-  extends BinaryExpression with ImplicitCastInputTypes {
+  extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant {
 
   override def dataType: DataType = BooleanType
 
@@ -1185,7 +1187,7 @@ case class ArrayContains(left: Expression, right: Expression)
   since = "2.4.0")
 // scalastyle:off line.size.limit
 case class ArraysOverlap(left: Expression, right: Expression)
-  extends BinaryArrayExpressionWithImplicitCast {
+  extends BinaryArrayExpressionWithImplicitCast with NullIntolerant {
 
   override def checkInputDataTypes(): TypeCheckResult = super.checkInputDataTypes() match {
     case TypeCheckResult.TypeCheckSuccess =>
@@ -1410,7 +1412,7 @@ case class ArraysOverlap(left: Expression, right: Expression)
   since = "2.4.0")
 // scalastyle:on line.size.limit
 case class Slice(x: Expression, start: Expression, length: Expression)
-  extends TernaryExpression with ImplicitCastInputTypes {
+  extends TernaryExpression with ImplicitCastInputTypes with NullIntolerant {
 
   override def dataType: DataType = x.dataType
 
@@ -1688,7 +1690,8 @@ case class ArrayJoin(
   """,
   group = "array_funcs",
   since = "2.4.0")
-case class ArrayMin(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
+case class ArrayMin(child: Expression)
+  extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant {
 
   override def nullable: Boolean = true
 
@@ -1755,7 +1758,8 @@ case class ArrayMin(child: Expression) extends UnaryExpression with ImplicitCast
   """,
   group = "array_funcs",
   since = "2.4.0")
-case class ArrayMax(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
+case class ArrayMax(child: Expression)
+  extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant {
 
   override def nullable: Boolean = true
 
@@ -1831,7 +1835,7 @@ case class ArrayMax(child: Expression) extends UnaryExpression with ImplicitCast
   group = "array_funcs",
   since = "2.4.0")
 case class ArrayPosition(left: Expression, right: Expression)
-  extends BinaryExpression with ImplicitCastInputTypes {
+  extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant {
 
   @transient private lazy val ordering: Ordering[Any] =
     TypeUtils.getInterpretedOrdering(right.dataType)
@@ -1909,7 +1913,7 @@ case class ArrayPosition(left: Expression, right: Expression)
   """,
   since = "2.4.0")
 case class ElementAt(left: Expression, right: Expression)
-  extends GetMapValueUtil with GetArrayItemUtil {
+  extends GetMapValueUtil with GetArrayItemUtil with NullIntolerant {
 
   @transient private lazy val mapKeyType = left.dataType.asInstanceOf[MapType].keyType
 
@@ -2245,7 +2249,7 @@ case class Concat(children: Seq[Expression]) extends ComplexTypeMergingExpressio
   """,
   group = "array_funcs",
   since = "2.4.0")
-case class Flatten(child: Expression) extends UnaryExpression {
+case class Flatten(child: Expression) extends UnaryExpression with NullIntolerant {
 
   private def childDataType: ArrayType = child.dataType.asInstanceOf[ArrayType]
 
@@ -2884,7 +2888,7 @@ case class ArrayRepeat(left: Expression, right: Expression)
   group = "array_funcs",
   since = "2.4.0")
 case class ArrayRemove(left: Expression, right: Expression)
-  extends BinaryExpression with ImplicitCastInputTypes {
+  extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant {
 
   override def dataType: DataType = left.dataType
 
@@ -3081,7 +3085,7 @@ trait ArraySetLike {
   group = "array_funcs",
   since = "2.4.0")
 case class ArrayDistinct(child: Expression)
-  extends UnaryExpression with ArraySetLike with ExpectsInputTypes {
+  extends UnaryExpression with ArraySetLike with ExpectsInputTypes with NullIntolerant {
 
   override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType)
 
@@ -3219,7 +3223,8 @@ case class ArrayDistinct(child: Expression)
 /**
  * Will become common base class for [[ArrayUnion]], [[ArrayIntersect]], and [[ArrayExcept]].
  */
-trait ArrayBinaryLike extends BinaryArrayExpressionWithImplicitCast with ArraySetLike {
+trait ArrayBinaryLike
+  extends BinaryArrayExpressionWithImplicitCast with ArraySetLike with NullIntolerant {
   override protected def dt: DataType = dataType
   override protected def et: DataType = elementType
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
index 5212ef3..1b4a705 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
@@ -255,7 +255,7 @@ object CreateMap {
        {1.0:"2",3.0:"4"}
   """, since = "2.4.0")
 case class MapFromArrays(left: Expression, right: Expression)
-  extends BinaryExpression with ExpectsInputTypes {
+  extends BinaryExpression with ExpectsInputTypes with NullIntolerant {
 
   override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, ArrayType)
 
@@ -476,7 +476,7 @@ case class CreateNamedStruct(children: Seq[Expression]) extends Expression {
   since = "2.0.1")
 // scalastyle:on line.size.limit
 case class StringToMap(text: Expression, pairDelim: Expression, keyValueDelim: Expression)
-  extends TernaryExpression with ExpectsInputTypes {
+  extends TernaryExpression with ExpectsInputTypes with NullIntolerant {
 
   def this(child: Expression, pairDelim: Expression) = {
     this(child, pairDelim, Literal(":"))
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala
index 5140db9..f9ccf3c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala
@@ -211,7 +211,8 @@ case class StructsToCsv(
      options: Map[String, String],
      child: Expression,
      timeZoneId: Option[String] = None)
-  extends UnaryExpression with TimeZoneAwareExpression with CodegenFallback with ExpectsInputTypes {
+  extends UnaryExpression with TimeZoneAwareExpression with CodegenFallback with ExpectsInputTypes
+    with NullIntolerant {
   override def nullable: Boolean = true
 
   def this(options: Map[String, String], child: Expression) = this(options, child, None)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala
index 7dc008a..4f3db1b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala
@@ -198,7 +198,7 @@ case class CurrentBatchTimestamp(
   group = "datetime_funcs",
   since = "1.5.0")
 case class DateAdd(startDate: Expression, days: Expression)
-  extends BinaryExpression with ExpectsInputTypes {
+  extends BinaryExpression with ExpectsInputTypes with NullIntolerant {
 
   override def left: Expression = startDate
   override def right: Expression = days
@@ -234,7 +234,7 @@ case class DateAdd(startDate: Expression, days: Expression)
   group = "datetime_funcs",
   since = "1.5.0")
 case class DateSub(startDate: Expression, days: Expression)
-  extends BinaryExpression with ExpectsInputTypes {
+  extends BinaryExpression with ExpectsInputTypes with NullIntolerant {
   override def left: Expression = startDate
   override def right: Expression = days
 
@@ -266,7 +266,8 @@ case class DateSub(startDate: Expression, days: Expression)
   group = "datetime_funcs",
   since = "1.5.0")
 case class Hour(child: Expression, timeZoneId: Option[String] = None)
-  extends UnaryExpression with TimeZoneAwareExpression with ImplicitCastInputTypes {
+  extends UnaryExpression with TimeZoneAwareExpression with ImplicitCastInputTypes
+    with NullIntolerant {
 
   def this(child: Expression) = this(child, None)
 
@@ -298,7 +299,8 @@ case class Hour(child: Expression, timeZoneId: Option[String] = None)
   group = "datetime_funcs",
   since = "1.5.0")
 case class Minute(child: Expression, timeZoneId: Option[String] = None)
-  extends UnaryExpression with TimeZoneAwareExpression with ImplicitCastInputTypes {
+  extends UnaryExpression with TimeZoneAwareExpression with ImplicitCastInputTypes
+    with NullIntolerant {
 
   def this(child: Expression) = this(child, None)
 
@@ -330,7 +332,8 @@ case class Minute(child: Expression, timeZoneId: Option[String] = None)
   group = "datetime_funcs",
   since = "1.5.0")
 case class Second(child: Expression, timeZoneId: Option[String] = None)
-  extends UnaryExpression with TimeZoneAwareExpression with ImplicitCastInputTypes {
+  extends UnaryExpression with TimeZoneAwareExpression with ImplicitCastInputTypes
+    with NullIntolerant {
 
   def this(child: Expression) = this(child, None)
 
@@ -353,7 +356,8 @@ case class Second(child: Expression, timeZoneId: Option[String] = None)
 }
 
 case class SecondWithFraction(child: Expression, timeZoneId: Option[String] = None)
-  extends UnaryExpression with TimeZoneAwareExpression with ImplicitCastInputTypes {
+  extends UnaryExpression with TimeZoneAwareExpression with ImplicitCastInputTypes
+    with NullIntolerant {
 
   def this(child: Expression) = this(child, None)
 
@@ -385,7 +389,8 @@ case class SecondWithFraction(child: Expression, timeZoneId: Option[String] = No
   """,
   group = "datetime_funcs",
   since = "1.5.0")
-case class DayOfYear(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
+case class DayOfYear(child: Expression)
+  extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant {
 
   override def inputTypes: Seq[AbstractDataType] = Seq(DateType)
 
@@ -402,7 +407,7 @@ case class DayOfYear(child: Expression) extends UnaryExpression with ImplicitCas
 }
 
 abstract class NumberToTimestampBase extends UnaryExpression
-  with ExpectsInputTypes {
+  with ExpectsInputTypes with NullIntolerant {
 
   protected def upScaleFactor: Long
 
@@ -487,7 +492,8 @@ case class MicrosToTimestamp(child: Expression)
   """,
   group = "datetime_funcs",
   since = "1.5.0")
-case class Year(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
+case class Year(child: Expression)
+  extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant {
 
   override def inputTypes: Seq[AbstractDataType] = Seq(DateType)
 
@@ -503,7 +509,8 @@ case class Year(child: Expression) extends UnaryExpression with ImplicitCastInpu
   }
 }
 
-case class YearOfWeek(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
+case class YearOfWeek(child: Expression)
+  extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant {
 
   override def inputTypes: Seq[AbstractDataType] = Seq(DateType)
 
@@ -528,7 +535,8 @@ case class YearOfWeek(child: Expression) extends UnaryExpression with ImplicitCa
   """,
   group = "datetime_funcs",
   since = "1.5.0")
-case class Quarter(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
+case class Quarter(child: Expression)
+  extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant {
 
   override def inputTypes: Seq[AbstractDataType] = Seq(DateType)
 
@@ -553,7 +561,8 @@ case class Quarter(child: Expression) extends UnaryExpression with ImplicitCastI
   """,
   group = "datetime_funcs",
   since = "1.5.0")
-case class Month(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
+case class Month(child: Expression)
+  extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant {
 
   override def inputTypes: Seq[AbstractDataType] = Seq(DateType)
 
@@ -577,7 +586,8 @@ case class Month(child: Expression) extends UnaryExpression with ImplicitCastInp
        30
   """,
   since = "1.5.0")
-case class DayOfMonth(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
+case class DayOfMonth(child: Expression)
+  extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant {
 
   override def inputTypes: Seq[AbstractDataType] = Seq(DateType)
 
@@ -647,7 +657,7 @@ case class WeekDay(child: Expression) extends DayWeek {
   }
 }
 
-abstract class DayWeek extends UnaryExpression with ImplicitCastInputTypes {
+abstract class DayWeek extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant {
 
   override def inputTypes: Seq[AbstractDataType] = Seq(DateType)
 
@@ -665,7 +675,8 @@ abstract class DayWeek extends UnaryExpression with ImplicitCastInputTypes {
   group = "datetime_funcs",
   since = "1.5.0")
 // scalastyle:on line.size.limit
-case class WeekOfYear(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
+case class WeekOfYear(child: Expression)
+  extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant {
 
   override def inputTypes: Seq[AbstractDataType] = Seq(DateType)
 
@@ -704,7 +715,8 @@ case class WeekOfYear(child: Expression) extends UnaryExpression with ImplicitCa
   since = "1.5.0")
 // scalastyle:on line.size.limit
 case class DateFormatClass(left: Expression, right: Expression, timeZoneId: Option[String] = None)
-  extends BinaryExpression with TimeZoneAwareExpression with ImplicitCastInputTypes {
+  extends BinaryExpression with TimeZoneAwareExpression with ImplicitCastInputTypes
+    with NullIntolerant {
 
   def this(left: Expression, right: Expression) = this(left, right, None)
 
@@ -1154,7 +1166,8 @@ case class FromUnixTime(sec: Expression, format: Expression, timeZoneId: Option[
   """,
   group = "datetime_funcs",
   since = "1.5.0")
-case class LastDay(startDate: Expression) extends UnaryExpression with ImplicitCastInputTypes {
+case class LastDay(startDate: Expression)
+  extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant {
   override def child: Expression = startDate
 
   override def inputTypes: Seq[AbstractDataType] = Seq(DateType)
@@ -1192,7 +1205,7 @@ case class LastDay(startDate: Expression) extends UnaryExpression with ImplicitC
   since = "1.5.0")
 // scalastyle:on line.size.limit
 case class NextDay(startDate: Expression, dayOfWeek: Expression)
-  extends BinaryExpression with ImplicitCastInputTypes {
+  extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant {
 
   override def left: Expression = startDate
   override def right: Expression = dayOfWeek
@@ -1248,7 +1261,7 @@ case class NextDay(startDate: Expression, dayOfWeek: Expression)
  * Adds an interval to timestamp.
  */
 case class TimeAdd(start: Expression, interval: Expression, timeZoneId: Option[String] = None)
-  extends BinaryExpression with TimeZoneAwareExpression with ExpectsInputTypes {
+  extends BinaryExpression with TimeZoneAwareExpression with ExpectsInputTypes with NullIntolerant {
 
   def this(start: Expression, interval: Expression) = this(start, interval, None)
 
@@ -1306,7 +1319,7 @@ case class DateAddInterval(
     interval: Expression,
     timeZoneId: Option[String] = None,
     ansiEnabled: Boolean = SQLConf.get.ansiEnabled)
-  extends BinaryExpression with ExpectsInputTypes with TimeZoneAwareExpression {
+  extends BinaryExpression with ExpectsInputTypes with TimeZoneAwareExpression with NullIntolerant {
 
   override def left: Expression = start
   override def right: Expression = interval
@@ -1380,7 +1393,7 @@ case class DateAddInterval(
   since = "1.5.0")
 // scalastyle:on line.size.limit
 case class FromUTCTimestamp(left: Expression, right: Expression)
-  extends BinaryExpression with ImplicitCastInputTypes {
+  extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant {
 
   override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType, StringType)
   override def dataType: DataType = TimestampType
@@ -1440,7 +1453,7 @@ case class FromUTCTimestamp(left: Expression, right: Expression)
   since = "1.5.0")
 // scalastyle:on line.size.limit
 case class AddMonths(startDate: Expression, numMonths: Expression)
-  extends BinaryExpression with ImplicitCastInputTypes {
+  extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant {
 
   override def left: Expression = startDate
   override def right: Expression = numMonths
@@ -1494,7 +1507,8 @@ case class MonthsBetween(
     date2: Expression,
     roundOff: Expression,
     timeZoneId: Option[String] = None)
-  extends TernaryExpression with TimeZoneAwareExpression with ImplicitCastInputTypes {
+  extends TernaryExpression with TimeZoneAwareExpression with ImplicitCastInputTypes
+    with NullIntolerant {
 
   def this(date1: Expression, date2: Expression) = this(date1, date2, Literal.TrueLiteral, None)
 
@@ -1552,7 +1566,7 @@ case class MonthsBetween(
   since = "1.5.0")
 // scalastyle:on line.size.limit
 case class ToUTCTimestamp(left: Expression, right: Expression)
-  extends BinaryExpression with ImplicitCastInputTypes {
+  extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant {
 
   override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType, StringType)
   override def dataType: DataType = TimestampType
@@ -1906,7 +1920,7 @@ case class TruncTimestamp(
   group = "datetime_funcs",
   since = "1.5.0")
 case class DateDiff(endDate: Expression, startDate: Expression)
-  extends BinaryExpression with ImplicitCastInputTypes {
+  extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant {
 
   override def left: Expression = endDate
   override def right: Expression = startDate
@@ -1960,7 +1974,7 @@ private case class GetTimestamp(
   group = "datetime_funcs",
   since = "3.0.0")
 case class MakeDate(year: Expression, month: Expression, day: Expression)
-  extends TernaryExpression with ImplicitCastInputTypes {
+  extends TernaryExpression with ImplicitCastInputTypes with NullIntolerant {
 
   override def children: Seq[Expression] = Seq(year, month, day)
   override def inputTypes: Seq[AbstractDataType] = Seq(IntegerType, IntegerType, IntegerType)
@@ -2031,7 +2045,8 @@ case class MakeTimestamp(
     sec: Expression,
     timezone: Option[Expression] = None,
     timeZoneId: Option[String] = None)
-  extends SeptenaryExpression with TimeZoneAwareExpression with ImplicitCastInputTypes {
+  extends SeptenaryExpression with TimeZoneAwareExpression with ImplicitCastInputTypes
+    with NullIntolerant {
 
   def this(
       year: Expression,
@@ -2307,7 +2322,7 @@ case class Extract(field: Expression, source: Expression, child: Expression)
  * between the given timestamps.
  */
 case class SubtractTimestamps(endTimestamp: Expression, startTimestamp: Expression)
-  extends BinaryExpression with ExpectsInputTypes {
+  extends BinaryExpression with ExpectsInputTypes with NullIntolerant {
 
   override def left: Expression = endTimestamp
   override def right: Expression = startTimestamp
@@ -2328,7 +2343,7 @@ case class SubtractTimestamps(endTimestamp: Expression, startTimestamp: Expressi
  * Returns the interval from the `left` date (inclusive) to the `right` date (exclusive).
  */
 case class SubtractDates(left: Expression, right: Expression)
-  extends BinaryExpression with ImplicitCastInputTypes {
+  extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant {
 
   override def inputTypes: Seq[AbstractDataType] = Seq(DateType, DateType)
   override def dataType: DataType = CalendarIntervalType
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala
index 9014ebf..c2c70b2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala
@@ -27,7 +27,7 @@ import org.apache.spark.sql.types._
  * Note: this expression is internal and created only by the optimizer,
  * we don't need to do type check for it.
  */
-case class UnscaledValue(child: Expression) extends UnaryExpression {
+case class UnscaledValue(child: Expression) extends UnaryExpression with NullIntolerant {
 
   override def dataType: DataType = LongType
   override def toString: String = s"UnscaledValue($child)"
@@ -49,7 +49,7 @@ case class MakeDecimal(
     child: Expression,
     precision: Int,
     scale: Int,
-    nullOnOverflow: Boolean) extends UnaryExpression {
+    nullOnOverflow: Boolean) extends UnaryExpression with NullIntolerant {
 
   def this(child: Expression, precision: Int, scale: Int) = {
     this(child, precision, scale, !SQLConf.get.ansiEnabled)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala
index 4c8c58a..5e21b58 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala
@@ -53,7 +53,8 @@ import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
       > SELECT _FUNC_('Spark');
        8cde774d6f7333752ed72cacddb05126
   """)
-case class Md5(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
+case class Md5(child: Expression)
+  extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant {
 
   override def dataType: DataType = StringType
 
@@ -89,7 +90,7 @@ case class Md5(child: Expression) extends UnaryExpression with ImplicitCastInput
   """)
 // scalastyle:on line.size.limit
 case class Sha2(left: Expression, right: Expression)
-  extends BinaryExpression with Serializable with ImplicitCastInputTypes {
+  extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant with Serializable {
 
   override def dataType: DataType = StringType
   override def nullable: Boolean = true
@@ -160,7 +161,8 @@ case class Sha2(left: Expression, right: Expression)
       > SELECT _FUNC_('Spark');
        85f5955f4b27a9a4c2aab6ffe5d7189fc298b92c
   """)
-case class Sha1(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
+case class Sha1(child: Expression)
+  extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant {
 
   override def dataType: DataType = StringType
 
@@ -187,7 +189,8 @@ case class Sha1(child: Expression) extends UnaryExpression with ImplicitCastInpu
       > SELECT _FUNC_('Spark');
        1557323817
   """)
-case class Crc32(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
+case class Crc32(child: Expression)
+  extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant {
 
   override def dataType: DataType = LongType
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala
index 1a569a7..baab224 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala
@@ -31,7 +31,7 @@ abstract class ExtractIntervalPart(
     val dataType: DataType,
     func: CalendarInterval => Any,
     funcName: String)
-  extends UnaryExpression with ExpectsInputTypes with Serializable {
+  extends UnaryExpression with ExpectsInputTypes with NullIntolerant with Serializable {
 
   override def inputTypes: Seq[AbstractDataType] = Seq(CalendarIntervalType)
 
@@ -82,7 +82,7 @@ object ExtractIntervalPart {
 abstract class IntervalNumOperation(
     interval: Expression,
     num: Expression)
-  extends BinaryExpression with ImplicitCastInputTypes with Serializable {
+  extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant with Serializable {
   override def left: Expression = interval
   override def right: Expression = num
 
@@ -160,7 +160,7 @@ case class MakeInterval(
     hours: Expression,
     mins: Expression,
     secs: Expression)
-  extends SeptenaryExpression with ImplicitCastInputTypes {
+  extends SeptenaryExpression with ImplicitCastInputTypes with NullIntolerant {
 
   def this(
       years: Expression,
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala
index 205e527..f4568f8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala
@@ -519,7 +519,8 @@ case class JsonToStructs(
     options: Map[String, String],
     child: Expression,
     timeZoneId: Option[String] = None)
-  extends UnaryExpression with TimeZoneAwareExpression with CodegenFallback with ExpectsInputTypes {
+  extends UnaryExpression with TimeZoneAwareExpression with CodegenFallback with ExpectsInputTypes
+    with NullIntolerant {
 
   // The JSON input data might be missing certain fields. We force the nullability
   // of the user-provided schema to avoid data corruptions. In particular, the parquet-mr encoder
@@ -638,7 +639,8 @@ case class StructsToJson(
     options: Map[String, String],
     child: Expression,
     timeZoneId: Option[String] = None)
-  extends UnaryExpression with TimeZoneAwareExpression with CodegenFallback with ExpectsInputTypes {
+  extends UnaryExpression with TimeZoneAwareExpression with CodegenFallback
+    with ExpectsInputTypes with NullIntolerant {
   override def nullable: Boolean = true
 
   def this(options: Map[String, String], child: Expression) = this(options, child, None)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
index 8c6fbc0..fe8ea2a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
@@ -57,7 +57,7 @@ abstract class LeafMathExpression(c: Double, name: String)
  * @param name The short name of the function
  */
 abstract class UnaryMathExpression(val f: Double => Double, name: String)
-  extends UnaryExpression with Serializable with ImplicitCastInputTypes {
+  extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant with Serializable {
 
   override def inputTypes: Seq[AbstractDataType] = Seq(DoubleType)
   override def dataType: DataType = DoubleType
@@ -111,7 +111,7 @@ abstract class UnaryLogExpression(f: Double => Double, name: String)
  * @param name The short name of the function
  */
 abstract class BinaryMathExpression(f: (Double, Double) => Double, name: String)
-  extends BinaryExpression with Serializable with ImplicitCastInputTypes {
+  extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant with Serializable {
 
   override def inputTypes: Seq[DataType] = Seq(DoubleType, DoubleType)
 
@@ -324,7 +324,7 @@ case class Acosh(child: Expression)
        -16
   """)
 case class Conv(numExpr: Expression, fromBaseExpr: Expression, toBaseExpr: Expression)
-  extends TernaryExpression with ImplicitCastInputTypes {
+  extends TernaryExpression with ImplicitCastInputTypes with NullIntolerant {
 
   override def children: Seq[Expression] = Seq(numExpr, fromBaseExpr, toBaseExpr)
   override def inputTypes: Seq[AbstractDataType] = Seq(StringType, IntegerType, IntegerType)
@@ -452,7 +452,8 @@ object Factorial {
       > SELECT _FUNC_(5);
        120
   """)
-case class Factorial(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
+case class Factorial(child: Expression)
+  extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant {
 
   override def inputTypes: Seq[DataType] = Seq(IntegerType)
 
@@ -735,7 +736,7 @@ case class ToRadians(child: Expression) extends UnaryMathExpression(math.toRadia
   """)
 // scalastyle:on line.size.limit
 case class Bin(child: Expression)
-  extends UnaryExpression with Serializable with ImplicitCastInputTypes {
+  extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant with Serializable {
 
   override def inputTypes: Seq[DataType] = Seq(LongType)
   override def dataType: DataType = StringType
@@ -834,7 +835,8 @@ object Hex {
       > SELECT _FUNC_('Spark SQL');
        537061726B2053514C
   """)
-case class Hex(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
+case class Hex(child: Expression)
+  extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant {
 
   override def inputTypes: Seq[AbstractDataType] =
     Seq(TypeCollection(LongType, BinaryType, StringType))
@@ -869,7 +871,8 @@ case class Hex(child: Expression) extends UnaryExpression with ImplicitCastInput
       > SELECT decode(_FUNC_('537061726B2053514C'), 'UTF-8');
        Spark SQL
   """)
-case class Unhex(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
+case class Unhex(child: Expression)
+  extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant {
 
   override def inputTypes: Seq[AbstractDataType] = Seq(StringType)
 
@@ -955,7 +958,7 @@ case class Pow(left: Expression, right: Expression)
        4
   """)
 case class ShiftLeft(left: Expression, right: Expression)
-  extends BinaryExpression with ImplicitCastInputTypes {
+  extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant {
 
   override def inputTypes: Seq[AbstractDataType] =
     Seq(TypeCollection(IntegerType, LongType), IntegerType)
@@ -989,7 +992,7 @@ case class ShiftLeft(left: Expression, right: Expression)
        2
   """)
 case class ShiftRight(left: Expression, right: Expression)
-  extends BinaryExpression with ImplicitCastInputTypes {
+  extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant {
 
   override def inputTypes: Seq[AbstractDataType] =
     Seq(TypeCollection(IntegerType, LongType), IntegerType)
@@ -1023,7 +1026,7 @@ case class ShiftRight(left: Expression, right: Expression)
        2
   """)
 case class ShiftRightUnsigned(left: Expression, right: Expression)
-  extends BinaryExpression with ImplicitCastInputTypes {
+  extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant {
 
   override def inputTypes: Seq[AbstractDataType] =
     Seq(TypeCollection(IntegerType, LongType), IntegerType)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala
index 3f60ca3..28924fa 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala
@@ -283,7 +283,7 @@ case class RLike(left: Expression, right: Expression) extends StringRegexExpress
   """,
   since = "1.5.0")
 case class StringSplit(str: Expression, regex: Expression, limit: Expression)
-  extends TernaryExpression with ImplicitCastInputTypes {
+  extends TernaryExpression with ImplicitCastInputTypes with NullIntolerant {
 
   override def dataType: DataType = ArrayType(StringType)
   override def inputTypes: Seq[DataType] = Seq(StringType, StringType, IntegerType)
@@ -325,7 +325,7 @@ case class StringSplit(str: Expression, regex: Expression, limit: Expression)
   since = "1.5.0")
 // scalastyle:on line.size.limit
 case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expression)
-  extends TernaryExpression with ImplicitCastInputTypes {
+  extends TernaryExpression with ImplicitCastInputTypes with NullIntolerant {
 
   // last regex in string, we will update the pattern iff regexp value changed.
   @transient private var lastRegex: UTF8String = _
@@ -433,7 +433,7 @@ object RegExpExtract {
   """,
   since = "1.5.0")
 case class RegExpExtract(subject: Expression, regexp: Expression, idx: Expression)
-  extends TernaryExpression with ImplicitCastInputTypes {
+  extends TernaryExpression with ImplicitCastInputTypes with NullIntolerant {
   def this(s: Expression, r: Expression) = this(s, r, Literal(1))
 
   // last regex in string, we will update the pattern iff regexp value changed.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
index 876588e..334a079 100755
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
@@ -334,7 +334,7 @@ trait String2StringExpression extends ImplicitCastInputTypes {
   """,
   since = "1.0.1")
 case class Upper(child: Expression)
-  extends UnaryExpression with String2StringExpression {
+  extends UnaryExpression with String2StringExpression with NullIntolerant {
 
   // scalastyle:off caselocale
   override def convert(v: UTF8String): UTF8String = v.toUpperCase
@@ -356,7 +356,8 @@ case class Upper(child: Expression)
        sparksql
   """,
   since = "1.0.1")
-case class Lower(child: Expression) extends UnaryExpression with String2StringExpression {
+case class Lower(child: Expression)
+  extends UnaryExpression with String2StringExpression with NullIntolerant {
 
   // scalastyle:off caselocale
   override def convert(v: UTF8String): UTF8String = v.toLowerCase
@@ -435,7 +436,7 @@ case class EndsWith(left: Expression, right: Expression) extends StringPredicate
   since = "2.3.0")
 // scalastyle:on line.size.limit
 case class StringReplace(srcExpr: Expression, searchExpr: Expression, replaceExpr: Expression)
-  extends TernaryExpression with ImplicitCastInputTypes {
+  extends TernaryExpression with ImplicitCastInputTypes with NullIntolerant {
 
   def this(srcExpr: Expression, searchExpr: Expression) = {
     this(srcExpr, searchExpr, Literal(""))
@@ -601,7 +602,7 @@ object StringTranslate {
   since = "1.5.0")
 // scalastyle:on line.size.limit
 case class StringTranslate(srcExpr: Expression, matchingExpr: Expression, replaceExpr: Expression)
-  extends TernaryExpression with ImplicitCastInputTypes {
+  extends TernaryExpression with ImplicitCastInputTypes with NullIntolerant {
 
   @transient private var lastMatching: UTF8String = _
   @transient private var lastReplace: UTF8String = _
@@ -666,7 +667,7 @@ case class StringTranslate(srcExpr: Expression, matchingExpr: Expression, replac
   since = "1.5.0")
 // scalastyle:on line.size.limit
 case class FindInSet(left: Expression, right: Expression) extends BinaryExpression
-    with ImplicitCastInputTypes {
+    with ImplicitCastInputTypes with NullIntolerant {
 
   override def inputTypes: Seq[AbstractDataType] = Seq(StringType, StringType)
 
@@ -1038,7 +1039,7 @@ case class StringTrimRight(
   since = "1.5.0")
 // scalastyle:on line.size.limit
 case class StringInstr(str: Expression, substr: Expression)
-  extends BinaryExpression with ImplicitCastInputTypes {
+  extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant {
 
   override def left: Expression = str
   override def right: Expression = substr
@@ -1080,7 +1081,7 @@ case class StringInstr(str: Expression, substr: Expression)
   since = "1.5.0")
 // scalastyle:on line.size.limit
 case class SubstringIndex(strExpr: Expression, delimExpr: Expression, countExpr: Expression)
- extends TernaryExpression with ImplicitCastInputTypes {
+ extends TernaryExpression with ImplicitCastInputTypes with NullIntolerant {
 
   override def dataType: DataType = StringType
   override def inputTypes: Seq[DataType] = Seq(StringType, StringType, IntegerType)
@@ -1209,7 +1210,7 @@ case class StringLocate(substr: Expression, str: Expression, start: Expression)
   """,
   since = "1.5.0")
 case class StringLPad(str: Expression, len: Expression, pad: Expression = Literal(" "))
-  extends TernaryExpression with ImplicitCastInputTypes {
+  extends TernaryExpression with ImplicitCastInputTypes with NullIntolerant {
 
   def this(str: Expression, len: Expression) = {
     this(str, len, Literal(" "))
@@ -1250,7 +1251,7 @@ case class StringLPad(str: Expression, len: Expression, pad: Expression = Litera
   """,
   since = "1.5.0")
 case class StringRPad(str: Expression, len: Expression, pad: Expression = Literal(" "))
-  extends TernaryExpression with ImplicitCastInputTypes {
+  extends TernaryExpression with ImplicitCastInputTypes with NullIntolerant {
 
   def this(str: Expression, len: Expression) = {
     this(str, len, Literal(" "))
@@ -1540,7 +1541,8 @@ case class FormatString(children: Expression*) extends Expression with ImplicitC
        Spark Sql
   """,
   since = "1.5.0")
-case class InitCap(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
+case class InitCap(child: Expression)
+  extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant {
 
   override def inputTypes: Seq[DataType] = Seq(StringType)
   override def dataType: DataType = StringType
@@ -1567,7 +1569,7 @@ case class InitCap(child: Expression) extends UnaryExpression with ImplicitCastI
   """,
   since = "1.5.0")
 case class StringRepeat(str: Expression, times: Expression)
-  extends BinaryExpression with ImplicitCastInputTypes {
+  extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant {
 
   override def left: Expression = str
   override def right: Expression = times
@@ -1597,7 +1599,7 @@ case class StringRepeat(str: Expression, times: Expression)
   """,
   since = "1.5.0")
 case class StringSpace(child: Expression)
-  extends UnaryExpression with ImplicitCastInputTypes {
+  extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant {
 
   override def dataType: DataType = StringType
   override def inputTypes: Seq[DataType] = Seq(IntegerType)
@@ -1742,7 +1744,8 @@ case class Left(str: Expression, len: Expression, child: Expression) extends Run
   """,
   since = "1.5.0")
 // scalastyle:on line.size.limit
-case class Length(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
+case class Length(child: Expression)
+  extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant {
   override def dataType: DataType = IntegerType
   override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(StringType, BinaryType))
 
@@ -1770,7 +1773,8 @@ case class Length(child: Expression) extends UnaryExpression with ImplicitCastIn
        72
   """,
   since = "2.3.0")
-case class BitLength(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
+case class BitLength(child: Expression)
+  extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant {
   override def dataType: DataType = IntegerType
   override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(StringType, BinaryType))
 
@@ -1801,7 +1805,8 @@ case class BitLength(child: Expression) extends UnaryExpression with ImplicitCas
        9
   """,
   since = "2.3.0")
-case class OctetLength(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
+case class OctetLength(child: Expression)
+  extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant {
   override def dataType: DataType = IntegerType
   override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(StringType, BinaryType))
 
@@ -1832,7 +1837,7 @@ case class OctetLength(child: Expression) extends UnaryExpression with ImplicitC
   """,
   since = "1.5.0")
 case class Levenshtein(left: Expression, right: Expression) extends BinaryExpression
-    with ImplicitCastInputTypes {
+    with ImplicitCastInputTypes with NullIntolerant {
 
   override def inputTypes: Seq[AbstractDataType] = Seq(StringType, StringType)
 
@@ -1857,7 +1862,8 @@ case class Levenshtein(left: Expression, right: Expression) extends BinaryExpres
        M460
   """,
   since = "1.5.0")
-case class SoundEx(child: Expression) extends UnaryExpression with ExpectsInputTypes {
+case class SoundEx(child: Expression)
+  extends UnaryExpression with ExpectsInputTypes with NullIntolerant {
 
   override def dataType: DataType = StringType
 
@@ -1883,7 +1889,8 @@ case class SoundEx(child: Expression) extends UnaryExpression with ExpectsInputT
        50
   """,
   since = "1.5.0")
-case class Ascii(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
+case class Ascii(child: Expression)
+  extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant {
 
   override def dataType: DataType = IntegerType
   override def inputTypes: Seq[DataType] = Seq(StringType)
@@ -1925,7 +1932,8 @@ case class Ascii(child: Expression) extends UnaryExpression with ImplicitCastInp
   """,
   since = "2.3.0")
 // scalastyle:on line.size.limit
-case class Chr(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
+case class Chr(child: Expression)
+  extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant {
 
   override def dataType: DataType = StringType
   override def inputTypes: Seq[DataType] = Seq(LongType)
@@ -1968,7 +1976,8 @@ case class Chr(child: Expression) extends UnaryExpression with ImplicitCastInput
        U3BhcmsgU1FM
   """,
   since = "1.5.0")
-case class Base64(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
+case class Base64(child: Expression)
+  extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant {
 
   override def dataType: DataType = StringType
   override def inputTypes: Seq[DataType] = Seq(BinaryType)
@@ -1996,7 +2005,8 @@ case class Base64(child: Expression) extends UnaryExpression with ImplicitCastIn
        Spark SQL
   """,
   since = "1.5.0")
-case class UnBase64(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
+case class UnBase64(child: Expression)
+  extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant {
 
   override def dataType: DataType = BinaryType
   override def inputTypes: Seq[DataType] = Seq(StringType)
@@ -2028,7 +2038,7 @@ case class UnBase64(child: Expression) extends UnaryExpression with ImplicitCast
   since = "1.5.0")
 // scalastyle:on line.size.limit
 case class Decode(bin: Expression, charset: Expression)
-  extends BinaryExpression with ImplicitCastInputTypes {
+  extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant {
 
   override def left: Expression = bin
   override def right: Expression = charset
@@ -2068,7 +2078,7 @@ case class Decode(bin: Expression, charset: Expression)
   since = "1.5.0")
 // scalastyle:on line.size.limit
 case class Encode(value: Expression, charset: Expression)
-  extends BinaryExpression with ImplicitCastInputTypes {
+  extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant {
 
   override def left: Expression = value
   override def right: Expression = charset
@@ -2112,7 +2122,7 @@ case class Encode(value: Expression, charset: Expression)
   """,
   since = "1.5.0")
 case class FormatNumber(x: Expression, d: Expression)
-  extends BinaryExpression with ExpectsInputTypes {
+  extends BinaryExpression with ExpectsInputTypes with NullIntolerant {
 
   override def left: Expression = x
   override def right: Expression = d
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xml/xpath.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xml/xpath.scala
index 55e06cb..e08a10e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xml/xpath.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xml/xpath.scala
@@ -30,7 +30,8 @@ import org.apache.spark.unsafe.types.UTF8String
  *
  * This is not the world's most efficient implementation due to type conversion, but works.
  */
-abstract class XPathExtract extends BinaryExpression with ExpectsInputTypes with CodegenFallback {
+abstract class XPathExtract
+  extends BinaryExpression with ExpectsInputTypes with CodegenFallback with NullIntolerant {
   override def left: Expression = xml
   override def right: Expression = path
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/expressions/ExpressionInfoSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/expressions/ExpressionInfoSuite.scala
index e18514c..53f9757 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/expressions/ExpressionInfoSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/expressions/ExpressionInfoSuite.scala
@@ -20,11 +20,12 @@ package org.apache.spark.sql.expressions
 import scala.collection.parallel.immutable.ParVector
 
 import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.catalyst.FunctionIdentifier
-import org.apache.spark.sql.catalyst.expressions.ExpressionInfo
+import org.apache.spark.sql.catalyst.{FunctionIdentifier, InternalRow}
+import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.execution.HiveResult.hiveResultString
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSparkSession
+import org.apache.spark.util.Utils
 
 class ExpressionInfoSuite extends SparkFunSuite with SharedSparkSession {
 
@@ -156,4 +157,38 @@ class ExpressionInfoSuite extends SparkFunSuite with SharedSparkSession {
       }
     }
   }
+
+  test("Check whether SQL expressions should extend NullIntolerant") {
+    // Only check expressions that extend one of these base classes, because the default
+    // `eval` inherited from them is NullIntolerant.
+    val exprTypesToCheck = Seq(classOf[UnaryExpression], classOf[BinaryExpression],
+      classOf[TernaryExpression], classOf[QuaternaryExpression], classOf[SeptenaryExpression])
+
+    // Skip these expressions: they extend NullIntolerant but also override the eval method
+    // (to avoid evaluating input1 if input2 is 0), which the check below would reject.
+    val ignoreSet = Set(classOf[IntegralDivide], classOf[Divide], classOf[Remainder], classOf[Pmod])
+
+    val candidateExprsToCheck = spark.sessionState.functionRegistry.listFunction()
+      .map(spark.sessionState.catalog.lookupFunctionInfo).map(_.getClassName)
+      .filterNot(c => ignoreSet.exists(_.getName.equals(c)))
+      .map(name => Utils.classForName(name))
+      .filterNot(classOf[NonSQLExpression].isAssignableFrom)
+
+    exprTypesToCheck.foreach { superClass =>
+      candidateExprsToCheck.filter(superClass.isAssignableFrom).foreach { clazz =>
+        val isEvalOverrode = clazz.getMethod("eval", classOf[InternalRow]) !=
+          superClass.getMethod("eval", classOf[InternalRow])
+        val isNullIntolerantMixedIn = classOf[NullIntolerant].isAssignableFrom(clazz)
+        if (isEvalOverrode && isNullIntolerantMixedIn) {
+          fail(s"${clazz.getName} should not extend ${classOf[NullIntolerant].getSimpleName}, " +
+            s"or add ${clazz.getName} in the ignoreSet of this test.")
+        } else if (!isEvalOverrode && !isNullIntolerantMixedIn) {
+          fail(s"${clazz.getName} should extend ${classOf[NullIntolerant].getSimpleName}.")
+        } else {
+          assert((!isEvalOverrode && isNullIntolerantMixedIn) ||
+            (isEvalOverrode && !isNullIntolerantMixedIn))
+        }
+      }
+    }
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org