You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by li...@apache.org on 2017/03/30 20:36:40 UTC

spark git commit: [SPARK-20121][SQL] simplify NullPropagation with NullIntolerant

Repository: spark
Updated Branches:
  refs/heads/master 5e00a5de1 -> c734fc504


[SPARK-20121][SQL] simplify NullPropagation with NullIntolerant

## What changes were proposed in this pull request?

Instead of enumerating, case by case, every expression that returns null when any of its inputs is null, `NullPropagation` can simply check whether the expression is `NullIntolerant`.

## How was this patch tested?

existing tests

Author: Wenchen Fan <we...@databricks.com>

Closes #17450 from cloud-fan/null.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/c734fc50
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/c734fc50
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/c734fc50

Branch: refs/heads/master
Commit: c734fc504a3f6a3d3b0bd90ff54604b17df2b413
Parents: 5e00a5d
Author: Wenchen Fan <we...@databricks.com>
Authored: Thu Mar 30 13:36:36 2017 -0700
Committer: Xiao Li <ga...@gmail.com>
Committed: Thu Mar 30 13:36:36 2017 -0700

----------------------------------------------------------------------
 .../sql/catalyst/expressions/arithmetic.scala   | 18 +++---
 .../expressions/complexTypeExtractors.scala     |  8 +--
 .../sql/catalyst/expressions/package.scala      |  2 +-
 .../expressions/regexpExpressions.scala         | 10 ++--
 .../expressions/stringExpressions.scala         | 15 ++---
 .../sql/catalyst/optimizer/expressions.scala    | 59 ++++++--------------
 6 files changed, 39 insertions(+), 73 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/c734fc50/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
index 4870093..f2b2522 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
@@ -113,7 +113,7 @@ case class Abs(child: Expression)
   protected override def nullSafeEval(input: Any): Any = numeric.abs(input)
 }
 
-abstract class BinaryArithmetic extends BinaryOperator {
+abstract class BinaryArithmetic extends BinaryOperator with NullIntolerant {
 
   override def dataType: DataType = left.dataType
 
@@ -146,7 +146,7 @@ object BinaryArithmetic {
       > SELECT 1 _FUNC_ 2;
        3
   """)
-case class Add(left: Expression, right: Expression) extends BinaryArithmetic with NullIntolerant {
+case class Add(left: Expression, right: Expression) extends BinaryArithmetic {
 
   override def inputType: AbstractDataType = TypeCollection.NumericAndInterval
 
@@ -182,8 +182,7 @@ case class Add(left: Expression, right: Expression) extends BinaryArithmetic wit
       > SELECT 2 _FUNC_ 1;
        1
   """)
-case class Subtract(left: Expression, right: Expression)
-    extends BinaryArithmetic with NullIntolerant {
+case class Subtract(left: Expression, right: Expression) extends BinaryArithmetic {
 
   override def inputType: AbstractDataType = TypeCollection.NumericAndInterval
 
@@ -219,8 +218,7 @@ case class Subtract(left: Expression, right: Expression)
       > SELECT 2 _FUNC_ 3;
        6
   """)
-case class Multiply(left: Expression, right: Expression)
-    extends BinaryArithmetic with NullIntolerant {
+case class Multiply(left: Expression, right: Expression) extends BinaryArithmetic {
 
   override def inputType: AbstractDataType = NumericType
 
@@ -243,8 +241,7 @@ case class Multiply(left: Expression, right: Expression)
        1.0
   """)
 // scalastyle:on line.size.limit
-case class Divide(left: Expression, right: Expression)
-    extends BinaryArithmetic with NullIntolerant {
+case class Divide(left: Expression, right: Expression) extends BinaryArithmetic {
 
   override def inputType: AbstractDataType = TypeCollection(DoubleType, DecimalType)
 
@@ -324,8 +321,7 @@ case class Divide(left: Expression, right: Expression)
       > SELECT 2 _FUNC_ 1.8;
        0.2
   """)
-case class Remainder(left: Expression, right: Expression)
-    extends BinaryArithmetic with NullIntolerant {
+case class Remainder(left: Expression, right: Expression) extends BinaryArithmetic {
 
   override def inputType: AbstractDataType = NumericType
 
@@ -412,7 +408,7 @@ case class Remainder(left: Expression, right: Expression)
       > SELECT _FUNC_(-10, 3);
        2
   """)
-case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic with NullIntolerant {
+case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic {
 
   override def toString: String = s"pmod($left, $right)"
 

http://git-wip-us.apache.org/repos/asf/spark/blob/c734fc50/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
index 0c256c3..de1594d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
@@ -104,7 +104,7 @@ trait ExtractValue extends Expression
  * For example, when get field `yEAr` from `<year: int, month: int>`, we should pass in `yEAr`.
  */
 case class GetStructField(child: Expression, ordinal: Int, name: Option[String] = None)
-  extends UnaryExpression with ExtractValue {
+  extends UnaryExpression with ExtractValue with NullIntolerant {
 
   lazy val childSchema = child.dataType.asInstanceOf[StructType]
 
@@ -152,7 +152,7 @@ case class GetArrayStructFields(
     field: StructField,
     ordinal: Int,
     numFields: Int,
-    containsNull: Boolean) extends UnaryExpression with ExtractValue {
+    containsNull: Boolean) extends UnaryExpression with ExtractValue with NullIntolerant {
 
   override def dataType: DataType = ArrayType(field.dataType, containsNull)
   override def toString: String = s"$child.${field.name}"
@@ -213,7 +213,7 @@ case class GetArrayStructFields(
  * We need to do type checking here as `ordinal` expression maybe unresolved.
  */
 case class GetArrayItem(child: Expression, ordinal: Expression)
-  extends BinaryExpression with ExpectsInputTypes with ExtractValue {
+  extends BinaryExpression with ExpectsInputTypes with ExtractValue with NullIntolerant {
 
   // We have done type checking for child in `ExtractValue`, so only need to check the `ordinal`.
   override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType, IntegralType)
@@ -260,7 +260,7 @@ case class GetArrayItem(child: Expression, ordinal: Expression)
  * We need to do type checking here as `key` expression maybe unresolved.
  */
 case class GetMapValue(child: Expression, key: Expression)
-  extends BinaryExpression with ImplicitCastInputTypes with ExtractValue {
+  extends BinaryExpression with ImplicitCastInputTypes with ExtractValue with NullIntolerant {
 
   private def keyType = child.dataType.asInstanceOf[MapType].keyType
 

http://git-wip-us.apache.org/repos/asf/spark/blob/c734fc50/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala
index 1b00c9e..4c8b177 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala
@@ -138,5 +138,5 @@ package object expressions  {
    * input will result in null output). We will use this information during constructing IsNotNull
    * constraints.
    */
-  trait NullIntolerant
+  trait NullIntolerant extends Expression
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/c734fc50/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala
index 4896a62..b23da53 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala
@@ -27,8 +27,8 @@ import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
 
 
-trait StringRegexExpression extends ImplicitCastInputTypes {
-  self: BinaryExpression =>
+abstract class StringRegexExpression extends BinaryExpression
+  with ImplicitCastInputTypes with NullIntolerant {
 
   def escape(v: String): String
   def matches(regex: Pattern, str: String): Boolean
@@ -69,8 +69,7 @@ trait StringRegexExpression extends ImplicitCastInputTypes {
  */
 @ExpressionDescription(
   usage = "str _FUNC_ pattern - Returns true if `str` matches `pattern`, or false otherwise.")
-case class Like(left: Expression, right: Expression)
-  extends BinaryExpression with StringRegexExpression {
+case class Like(left: Expression, right: Expression) extends StringRegexExpression {
 
   override def escape(v: String): String = StringUtils.escapeLikeRegex(v)
 
@@ -122,8 +121,7 @@ case class Like(left: Expression, right: Expression)
 
 @ExpressionDescription(
   usage = "str _FUNC_ regexp - Returns true if `str` matches `regexp`, or false otherwise.")
-case class RLike(left: Expression, right: Expression)
-  extends BinaryExpression with StringRegexExpression {
+case class RLike(left: Expression, right: Expression) extends StringRegexExpression {
 
   override def escape(v: String): String = v
   override def matches(regex: Pattern, str: String): Boolean = regex.matcher(str).find(0)

http://git-wip-us.apache.org/repos/asf/spark/blob/c734fc50/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
index 908aa44..5598a14 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
@@ -297,8 +297,8 @@ case class Lower(child: Expression) extends UnaryExpression with String2StringEx
 }
 
 /** A base trait for functions that compare two strings, returning a boolean. */
-trait StringPredicate extends Predicate with ImplicitCastInputTypes {
-  self: BinaryExpression =>
+abstract class StringPredicate extends BinaryExpression
+  with Predicate with ImplicitCastInputTypes with NullIntolerant {
 
   def compare(l: UTF8String, r: UTF8String): Boolean
 
@@ -313,8 +313,7 @@ trait StringPredicate extends Predicate with ImplicitCastInputTypes {
 /**
  * A function that returns true if the string `left` contains the string `right`.
  */
-case class Contains(left: Expression, right: Expression)
-    extends BinaryExpression with StringPredicate {
+case class Contains(left: Expression, right: Expression) extends StringPredicate {
   override def compare(l: UTF8String, r: UTF8String): Boolean = l.contains(r)
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     defineCodeGen(ctx, ev, (c1, c2) => s"($c1).contains($c2)")
@@ -324,8 +323,7 @@ case class Contains(left: Expression, right: Expression)
 /**
  * A function that returns true if the string `left` starts with the string `right`.
  */
-case class StartsWith(left: Expression, right: Expression)
-    extends BinaryExpression with StringPredicate {
+case class StartsWith(left: Expression, right: Expression) extends StringPredicate {
   override def compare(l: UTF8String, r: UTF8String): Boolean = l.startsWith(r)
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     defineCodeGen(ctx, ev, (c1, c2) => s"($c1).startsWith($c2)")
@@ -335,8 +333,7 @@ case class StartsWith(left: Expression, right: Expression)
 /**
  * A function that returns true if the string `left` ends with the string `right`.
  */
-case class EndsWith(left: Expression, right: Expression)
-    extends BinaryExpression with StringPredicate {
+case class EndsWith(left: Expression, right: Expression) extends StringPredicate {
   override def compare(l: UTF8String, r: UTF8String): Boolean = l.endsWith(r)
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     defineCodeGen(ctx, ev, (c1, c2) => s"($c1).endsWith($c2)")
@@ -1122,7 +1119,7 @@ case class StringSpace(child: Expression)
   """)
 // scalastyle:on line.size.limit
 case class Substring(str: Expression, pos: Expression, len: Expression)
-  extends TernaryExpression with ImplicitCastInputTypes {
+  extends TernaryExpression with ImplicitCastInputTypes with NullIntolerant {
 
   def this(str: Expression, pos: Expression) = {
     this(str, pos, Literal(Integer.MAX_VALUE))

http://git-wip-us.apache.org/repos/asf/spark/blob/c734fc50/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
index 21d1cd5..3303912 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
@@ -347,35 +347,30 @@ object LikeSimplification extends Rule[LogicalPlan] {
  * Null value propagation from bottom to top of the expression tree.
  */
 case class NullPropagation(conf: CatalystConf) extends Rule[LogicalPlan] {
-  private def nonNullLiteral(e: Expression): Boolean = e match {
-    case Literal(null, _) => false
-    case _ => true
+  private def isNullLiteral(e: Expression): Boolean = e match {
+    case Literal(null, _) => true
+    case _ => false
   }
 
   def apply(plan: LogicalPlan): LogicalPlan = plan transform {
     case q: LogicalPlan => q transformExpressionsUp {
       case e @ WindowExpression(Cast(Literal(0L, _), _, _), _) =>
         Cast(Literal(0L), e.dataType, Option(conf.sessionLocalTimeZone))
-      case e @ AggregateExpression(Count(exprs), _, _, _) if !exprs.exists(nonNullLiteral) =>
+      case e @ AggregateExpression(Count(exprs), _, _, _) if exprs.forall(isNullLiteral) =>
         Cast(Literal(0L), e.dataType, Option(conf.sessionLocalTimeZone))
-      case e @ IsNull(c) if !c.nullable => Literal.create(false, BooleanType)
-      case e @ IsNotNull(c) if !c.nullable => Literal.create(true, BooleanType)
-      case e @ GetArrayItem(Literal(null, _), _) => Literal.create(null, e.dataType)
-      case e @ GetArrayItem(_, Literal(null, _)) => Literal.create(null, e.dataType)
-      case e @ GetMapValue(Literal(null, _), _) => Literal.create(null, e.dataType)
-      case e @ GetMapValue(_, Literal(null, _)) => Literal.create(null, e.dataType)
-      case e @ GetStructField(Literal(null, _), _, _) => Literal.create(null, e.dataType)
-      case e @ GetArrayStructFields(Literal(null, _), _, _, _, _) =>
-        Literal.create(null, e.dataType)
-      case e @ EqualNullSafe(Literal(null, _), r) => IsNull(r)
-      case e @ EqualNullSafe(l, Literal(null, _)) => IsNull(l)
       case ae @ AggregateExpression(Count(exprs), _, false, _) if !exprs.exists(_.nullable) =>
         // This rule should be only triggered when isDistinct field is false.
         ae.copy(aggregateFunction = Count(Literal(1)))
 
+      case IsNull(c) if !c.nullable => Literal.create(false, BooleanType)
+      case IsNotNull(c) if !c.nullable => Literal.create(true, BooleanType)
+
+      case EqualNullSafe(Literal(null, _), r) => IsNull(r)
+      case EqualNullSafe(l, Literal(null, _)) => IsNull(l)
+
       // For Coalesce, remove null literals.
       case e @ Coalesce(children) =>
-        val newChildren = children.filter(nonNullLiteral)
+        val newChildren = children.filterNot(isNullLiteral)
         if (newChildren.isEmpty) {
           Literal.create(null, e.dataType)
         } else if (newChildren.length == 1) {
@@ -384,33 +379,13 @@ case class NullPropagation(conf: CatalystConf) extends Rule[LogicalPlan] {
           Coalesce(newChildren)
         }
 
-      case e @ Substring(Literal(null, _), _, _) => Literal.create(null, e.dataType)
-      case e @ Substring(_, Literal(null, _), _) => Literal.create(null, e.dataType)
-      case e @ Substring(_, _, Literal(null, _)) => Literal.create(null, e.dataType)
-
-      // Put exceptional cases above if any
-      case e @ BinaryArithmetic(Literal(null, _), _) => Literal.create(null, e.dataType)
-      case e @ BinaryArithmetic(_, Literal(null, _)) => Literal.create(null, e.dataType)
-
-      case e @ BinaryComparison(Literal(null, _), _) => Literal.create(null, e.dataType)
-      case e @ BinaryComparison(_, Literal(null, _)) => Literal.create(null, e.dataType)
-
-      case e: StringRegexExpression => e.children match {
-        case Literal(null, _) :: right :: Nil => Literal.create(null, e.dataType)
-        case left :: Literal(null, _) :: Nil => Literal.create(null, e.dataType)
-        case _ => e
-      }
-
-      case e: StringPredicate => e.children match {
-        case Literal(null, _) :: right :: Nil => Literal.create(null, e.dataType)
-        case left :: Literal(null, _) :: Nil => Literal.create(null, e.dataType)
-        case _ => e
-      }
-
-      // If the value expression is NULL then transform the In expression to
-      // Literal(null)
-      case In(Literal(null, _), list) => Literal.create(null, BooleanType)
+      // If the value expression is NULL then transform the In expression to a null literal.
+      case In(Literal(null, _), _) => Literal.create(null, BooleanType)
 
+      // A non-leaf NullIntolerant expression returns null if at least one of its children is
+      // a null literal.
+      case e: NullIntolerant if e.children.exists(isNullLiteral) =>
+        Literal.create(null, e.dataType)
     }
   }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org