Posted to commits@spark.apache.org by ma...@apache.org on 2016/04/01 22:00:59 UTC

spark git commit: [SPARK-13995][SQL] Extract correct IsNotNull constraints for Expression

Repository: spark
Updated Branches:
  refs/heads/master 381358fbe -> df68beb85


[SPARK-13995][SQL] Extract correct IsNotNull constraints for Expression

## What changes were proposed in this pull request?

JIRA: https://issues.apache.org/jira/browse/SPARK-13995

We currently infer `IsNotNull` constraints from a logical plan's expressions in `constructIsNotNullConstraints`. However, we don't handle the case of (nested) `Cast`.

For example:

    val tr = LocalRelation('a.int, 'b.long)
    val plan = tr.where('a.attr === 'b.attr).analyze

The plan's constraints will then contain `IsNotNull(Cast(resolveColumn(tr, "a"), LongType))` instead of `IsNotNull(resolveColumn(tr, "a"))`. This PR fixes that.
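
Because `Cast` is null intolerant (a null input always produces a null output), `IsNotNull(Cast(a, LongType))` holds exactly when `IsNotNull(a)` does. A minimal sketch of that peeling step, assuming the Catalyst classes are on the classpath (`stripCasts` itself is illustrative, not a Spark API):

    import org.apache.spark.sql.catalyst.expressions.{Cast, Expression}

    // Peel nested Casts off an expression: each Cast is null intolerant,
    // so stripping it preserves the IsNotNull relationship.
    def stripCasts(e: Expression): Expression = e match {
      case Cast(child, _) => stripCasts(child)
      case other => other
    }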

Moreover, since `IsNotNull` constraints are most useful on `Attribute`s, we should recurse through any `Expression` that is null intolerant and construct `IsNotNull` constraints for every `Attribute` found under such expressions.

For example, consider the following constraints:

    val df = Seq((1,2,3)).toDF("a", "b", "c")
    df.where("a + b = c").queryExecution.analyzed.constraints

The inferred isnotnull constraints should be isnotnull(a), isnotnull(b), and isnotnull(c), instead of isnotnull(a + b) and isnotnull(c).
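
Every node in `a + b = c` (an `EqualTo` over an `Add`) is null intolerant, which is why the recursion reaches all three attributes. A self-contained sketch of that recursion over simplified stand-ins (these classes are illustrative, not Spark's actual Catalyst expressions):

    sealed trait Expr { def children: Seq[Expr] }
    trait NullIntolerant
    case class Attr(name: String) extends Expr with NullIntolerant {
      def children: Seq[Expr] = Nil
    }
    case class Add(l: Expr, r: Expr) extends Expr with NullIntolerant {
      def children: Seq[Expr] = Seq(l, r)
    }
    case class EqualTo(l: Expr, r: Expr) extends Expr with NullIntolerant {
      def children: Seq[Expr] = Seq(l, r)
    }
    // Null tolerant: coalesce(a, b) can be non-null even when a is null,
    // so the recursion must stop at this node.
    case class Coalesce(args: Seq[Expr]) extends Expr {
      def children: Seq[Expr] = args
    }

    // Collect every attribute reachable through null-intolerant nodes only.
    def scanNullIntolerant(e: Expr): Seq[Attr] = e match {
      case a: Attr => Seq(a)
      case _: NullIntolerant => e.children.flatMap(scanNullIntolerant)
      case _ => Seq.empty
    }

    scanNullIntolerant(EqualTo(Add(Attr("a"), Attr("b")), Attr("c")))
    // => List(Attr(a), Attr(b), Attr(c))
    scanNullIntolerant(EqualTo(Coalesce(Seq(Attr("a"))), Attr("b")))
    // => List(Attr(b)): nothing is inferred through the null-tolerant Coalesce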

## How was this patch tested?

Tests are added to `ConstraintPropagationSuite`.
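
One plausible way to run just this suite locally (assuming the standard Spark sbt build; the exact invocation may vary by branch):

    build/sbt "catalyst/testOnly *ConstraintPropagationSuite"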

Author: Liang-Chi Hsieh <si...@tw.ibm.com>

Closes #11809 from viirya/constraint-cast.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/df68beb8
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/df68beb8
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/df68beb8

Branch: refs/heads/master
Commit: df68beb85de59bb6d35b2a8a3b85dbc447798bf5
Parents: 381358f
Author: Liang-Chi Hsieh <si...@tw.ibm.com>
Authored: Fri Apr 1 13:00:55 2016 -0700
Committer: Michael Armbrust <mi...@databricks.com>
Committed: Fri Apr 1 13:00:55 2016 -0700

----------------------------------------------------------------------
 .../spark/sql/catalyst/expressions/Cast.scala   |  2 +-
 .../sql/catalyst/expressions/arithmetic.scala   | 25 +++---
 .../catalyst/expressions/namedExpressions.scala |  2 +-
 .../sql/catalyst/expressions/package.scala      |  7 ++
 .../sql/catalyst/expressions/predicates.scala   | 17 ++--
 .../spark/sql/catalyst/plans/QueryPlan.scala    | 33 ++++----
 .../plans/ConstraintPropagationSuite.scala      | 85 +++++++++++++++++++-
 7 files changed, 134 insertions(+), 37 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/df68beb8/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index a965cc8..d842ffd 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -112,7 +112,7 @@ object Cast {
 }
 
 /** Cast the child expression to the target data type. */
-case class Cast(child: Expression, dataType: DataType) extends UnaryExpression {
+case class Cast(child: Expression, dataType: DataType) extends UnaryExpression with NullIntolerant {
 
   override def toString: String = s"cast($child as ${dataType.simpleString})"
 

http://git-wip-us.apache.org/repos/asf/spark/blob/df68beb8/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
index 1e9c971..b388091 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
@@ -24,7 +24,8 @@ import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.CalendarInterval
 
 
-case class UnaryMinus(child: Expression) extends UnaryExpression with ExpectsInputTypes {
+case class UnaryMinus(child: Expression) extends UnaryExpression
+    with ExpectsInputTypes with NullIntolerant {
 
   override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection.NumericAndInterval)
 
@@ -58,7 +59,8 @@ case class UnaryMinus(child: Expression) extends UnaryExpression with ExpectsInp
   override def sql: String = s"(-${child.sql})"
 }
 
-case class UnaryPositive(child: Expression) extends UnaryExpression with ExpectsInputTypes {
+case class UnaryPositive(child: Expression)
+    extends UnaryExpression with ExpectsInputTypes with NullIntolerant {
   override def prettyName: String = "positive"
 
   override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection.NumericAndInterval)
@@ -79,7 +81,8 @@ case class UnaryPositive(child: Expression) extends UnaryExpression with Expects
 @ExpressionDescription(
   usage = "_FUNC_(expr) - Returns the absolute value of the numeric value",
   extended = "> SELECT _FUNC_('-1');\n1")
-case class Abs(child: Expression) extends UnaryExpression with ExpectsInputTypes {
+case class Abs(child: Expression)
+    extends UnaryExpression with ExpectsInputTypes with NullIntolerant {
 
   override def inputTypes: Seq[AbstractDataType] = Seq(NumericType)
 
@@ -123,7 +126,7 @@ private[sql] object BinaryArithmetic {
   def unapply(e: BinaryArithmetic): Option[(Expression, Expression)] = Some((e.left, e.right))
 }
 
-case class Add(left: Expression, right: Expression) extends BinaryArithmetic {
+case class Add(left: Expression, right: Expression) extends BinaryArithmetic with NullIntolerant {
 
   override def inputType: AbstractDataType = TypeCollection.NumericAndInterval
 
@@ -152,7 +155,8 @@ case class Add(left: Expression, right: Expression) extends BinaryArithmetic {
   }
 }
 
-case class Subtract(left: Expression, right: Expression) extends BinaryArithmetic {
+case class Subtract(left: Expression, right: Expression)
+    extends BinaryArithmetic with NullIntolerant {
 
   override def inputType: AbstractDataType = TypeCollection.NumericAndInterval
 
@@ -181,7 +185,8 @@ case class Subtract(left: Expression, right: Expression) extends BinaryArithmeti
   }
 }
 
-case class Multiply(left: Expression, right: Expression) extends BinaryArithmetic {
+case class Multiply(left: Expression, right: Expression)
+    extends BinaryArithmetic with NullIntolerant {
 
   override def inputType: AbstractDataType = NumericType
 
@@ -193,7 +198,8 @@ case class Multiply(left: Expression, right: Expression) extends BinaryArithmeti
   protected override def nullSafeEval(input1: Any, input2: Any): Any = numeric.times(input1, input2)
 }
 
-case class Divide(left: Expression, right: Expression) extends BinaryArithmetic {
+case class Divide(left: Expression, right: Expression)
+    extends BinaryArithmetic with NullIntolerant {
 
   override def inputType: AbstractDataType = NumericType
 
@@ -269,7 +275,8 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic
   }
 }
 
-case class Remainder(left: Expression, right: Expression) extends BinaryArithmetic {
+case class Remainder(left: Expression, right: Expression)
+    extends BinaryArithmetic with NullIntolerant {
 
   override def inputType: AbstractDataType = NumericType
 
@@ -457,7 +464,7 @@ case class MinOf(left: Expression, right: Expression)
   override def symbol: String = "min"
 }
 
-case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic {
+case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic with NullIntolerant {
 
   override def toString: String = s"pmod($left, $right)"
 

http://git-wip-us.apache.org/repos/asf/spark/blob/df68beb8/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
index a5b5758..262582c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
@@ -97,7 +97,7 @@ trait NamedExpression extends Expression {
     }
 }
 
-abstract class Attribute extends LeafExpression with NamedExpression {
+abstract class Attribute extends LeafExpression with NamedExpression with NullIntolerant {
 
   override def references: AttributeSet = AttributeSet(this)
 

http://git-wip-us.apache.org/repos/asf/spark/blob/df68beb8/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala
index f1fa13d..23baa6f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala
@@ -92,4 +92,11 @@ package object expressions  {
       StructType(attrs.map(a => StructField(a.name, a.dataType, a.nullable)))
     }
   }
+
+  /**
+   * An expression that inherits this trait is null intolerant, i.e. any null input will
+   * result in null output. We use this information when constructing IsNotNull
+   * constraints.
+   */
+  trait NullIntolerant
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/df68beb8/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
index e23ad55..4eb3325 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
@@ -90,7 +90,7 @@ trait PredicateHelper {
 
 
 case class Not(child: Expression)
-  extends UnaryExpression with Predicate with ImplicitCastInputTypes {
+  extends UnaryExpression with Predicate with ImplicitCastInputTypes with NullIntolerant {
 
   override def toString: String = s"NOT $child"
 
@@ -402,7 +402,8 @@ private[sql] object Equality {
 }
 
 
-case class EqualTo(left: Expression, right: Expression) extends BinaryComparison {
+case class EqualTo(left: Expression, right: Expression)
+    extends BinaryComparison with NullIntolerant {
 
   override def inputType: AbstractDataType = AnyDataType
 
@@ -467,7 +468,8 @@ case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComp
 }
 
 
-case class LessThan(left: Expression, right: Expression) extends BinaryComparison {
+case class LessThan(left: Expression, right: Expression)
+    extends BinaryComparison with NullIntolerant {
 
   override def inputType: AbstractDataType = TypeCollection.Ordered
 
@@ -479,7 +481,8 @@ case class LessThan(left: Expression, right: Expression) extends BinaryCompariso
 }
 
 
-case class LessThanOrEqual(left: Expression, right: Expression) extends BinaryComparison {
+case class LessThanOrEqual(left: Expression, right: Expression)
+    extends BinaryComparison with NullIntolerant {
 
   override def inputType: AbstractDataType = TypeCollection.Ordered
 
@@ -491,7 +494,8 @@ case class LessThanOrEqual(left: Expression, right: Expression) extends BinaryCo
 }
 
 
-case class GreaterThan(left: Expression, right: Expression) extends BinaryComparison {
+case class GreaterThan(left: Expression, right: Expression)
+    extends BinaryComparison with NullIntolerant {
 
   override def inputType: AbstractDataType = TypeCollection.Ordered
 
@@ -503,7 +507,8 @@ case class GreaterThan(left: Expression, right: Expression) extends BinaryCompar
 }
 
 
-case class GreaterThanOrEqual(left: Expression, right: Expression) extends BinaryComparison {
+case class GreaterThanOrEqual(left: Expression, right: Expression)
+    extends BinaryComparison with NullIntolerant {
 
   override def inputType: AbstractDataType = TypeCollection.Ordered
 

http://git-wip-us.apache.org/repos/asf/spark/blob/df68beb8/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
index d31164f..22a4461 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
@@ -44,25 +44,9 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT
    * returns a constraint of the form `isNotNull(a)`
    */
   private def constructIsNotNullConstraints(constraints: Set[Expression]): Set[Expression] = {
-    var isNotNullConstraints = Set.empty[Expression]
-
-    // First, we propagate constraints if the condition consists of equality and ranges. For all
-    // other cases, we return an empty set of constraints
-    constraints.foreach {
-      case EqualTo(l, r) =>
-        isNotNullConstraints ++= Set(IsNotNull(l), IsNotNull(r))
-      case GreaterThan(l, r) =>
-        isNotNullConstraints ++= Set(IsNotNull(l), IsNotNull(r))
-      case GreaterThanOrEqual(l, r) =>
-        isNotNullConstraints ++= Set(IsNotNull(l), IsNotNull(r))
-      case LessThan(l, r) =>
-        isNotNullConstraints ++= Set(IsNotNull(l), IsNotNull(r))
-      case LessThanOrEqual(l, r) =>
-        isNotNullConstraints ++= Set(IsNotNull(l), IsNotNull(r))
-      case Not(EqualTo(l, r)) =>
-        isNotNullConstraints ++= Set(IsNotNull(l), IsNotNull(r))
-      case _ => // No inference
-    }
+    // First, we propagate constraints from the null intolerant expressions.
+    var isNotNullConstraints: Set[Expression] =
+      constraints.flatMap(scanNullIntolerantExpr).map(IsNotNull(_))
 
     // Second, we infer additional constraints from non-nullable attributes that are part of the
     // operator's output
@@ -73,6 +57,17 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT
   }
 
   /**
+   * Recursively explores the expressions which are null intolerant and returns all attributes
+   * in these expressions.
+   */
+  private def scanNullIntolerantExpr(expr: Expression): Seq[Attribute] = expr match {
+    case a: Attribute => Seq(a)
+    case _: NullIntolerant | IsNotNull(_: NullIntolerant) =>
+      expr.children.flatMap(scanNullIntolerantExpr)
+    case _ => Seq.empty[Attribute]
+  }
+
+  /**
    * Infers an additional set of constraints from a given set of equality constraints.
    * For e.g., if an operator has constraints of the form (`a = 5`, `a = b`), this returns an
    * additional constraint of the form `b = 5`

http://git-wip-us.apache.org/repos/asf/spark/blob/df68beb8/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala
index e506359..5cbb889 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala
@@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.logical._
-import org.apache.spark.sql.types.{IntegerType, StringType}
+import org.apache.spark.sql.types.{DoubleType, IntegerType, LongType, StringType}
 
 class ConstraintPropagationSuite extends SparkFunSuite {
 
@@ -219,6 +219,89 @@ class ConstraintPropagationSuite extends SparkFunSuite {
         IsNotNull(resolveColumn(tr, "b")))))
   }
 
+  test("infer constraints on cast") {
+    val tr = LocalRelation('a.int, 'b.long, 'c.int, 'd.long, 'e.int)
+    verifyConstraints(
+      tr.where('a.attr === 'b.attr &&
+        'c.attr + 100 > 'd.attr &&
+        IsNotNull(Cast(Cast(resolveColumn(tr, "e"), LongType), LongType))).analyze.constraints,
+      ExpressionSet(Seq(Cast(resolveColumn(tr, "a"), LongType) === resolveColumn(tr, "b"),
+        Cast(resolveColumn(tr, "c") + 100, LongType) > resolveColumn(tr, "d"),
+        IsNotNull(resolveColumn(tr, "a")),
+        IsNotNull(resolveColumn(tr, "b")),
+        IsNotNull(resolveColumn(tr, "c")),
+        IsNotNull(resolveColumn(tr, "d")),
+        IsNotNull(resolveColumn(tr, "e")),
+        IsNotNull(Cast(Cast(resolveColumn(tr, "e"), LongType), LongType)))))
+  }
+
+  test("infer isnotnull constraints from compound expressions") {
+    val tr = LocalRelation('a.int, 'b.long, 'c.int, 'd.long, 'e.int)
+    verifyConstraints(
+      tr.where('a.attr + 'b.attr === 'c.attr &&
+        IsNotNull(
+          Cast(
+            Cast(Cast(resolveColumn(tr, "e"), LongType), LongType), LongType))).analyze.constraints,
+      ExpressionSet(Seq(
+        Cast(resolveColumn(tr, "a"), LongType) + resolveColumn(tr, "b") ===
+          Cast(resolveColumn(tr, "c"), LongType),
+        IsNotNull(resolveColumn(tr, "a")),
+        IsNotNull(resolveColumn(tr, "b")),
+        IsNotNull(resolveColumn(tr, "c")),
+        IsNotNull(resolveColumn(tr, "e")),
+        IsNotNull(Cast(Cast(Cast(resolveColumn(tr, "e"), LongType), LongType), LongType)))))
+
+    verifyConstraints(
+      tr.where(('a.attr * 'b.attr + 100) === 'c.attr && 'd / 10 === 'e).analyze.constraints,
+      ExpressionSet(Seq(
+        Cast(resolveColumn(tr, "a"), LongType) * resolveColumn(tr, "b") + Cast(100, LongType) ===
+          Cast(resolveColumn(tr, "c"), LongType),
+        Cast(resolveColumn(tr, "d"), DoubleType) /
+          Cast(Cast(10, LongType), DoubleType) ===
+            Cast(resolveColumn(tr, "e"), DoubleType),
+        IsNotNull(resolveColumn(tr, "a")),
+        IsNotNull(resolveColumn(tr, "b")),
+        IsNotNull(resolveColumn(tr, "c")),
+        IsNotNull(resolveColumn(tr, "d")),
+        IsNotNull(resolveColumn(tr, "e")))))
+
+    verifyConstraints(
+      tr.where(('a.attr * 'b.attr - 10) >= 'c.attr && 'd / 10 < 'e).analyze.constraints,
+      ExpressionSet(Seq(
+        Cast(resolveColumn(tr, "a"), LongType) * resolveColumn(tr, "b") - Cast(10, LongType) >=
+          Cast(resolveColumn(tr, "c"), LongType),
+        Cast(resolveColumn(tr, "d"), DoubleType) /
+          Cast(Cast(10, LongType), DoubleType) <
+            Cast(resolveColumn(tr, "e"), DoubleType),
+        IsNotNull(resolveColumn(tr, "a")),
+        IsNotNull(resolveColumn(tr, "b")),
+        IsNotNull(resolveColumn(tr, "c")),
+        IsNotNull(resolveColumn(tr, "d")),
+        IsNotNull(resolveColumn(tr, "e")))))
+
+    verifyConstraints(
+      tr.where('a.attr + 'b.attr - 'c.attr * 'd.attr > 'e.attr * 1000).analyze.constraints,
+      ExpressionSet(Seq(
+        (Cast(resolveColumn(tr, "a"), LongType) + resolveColumn(tr, "b")) -
+          (Cast(resolveColumn(tr, "c"), LongType) * resolveColumn(tr, "d")) >
+            Cast(resolveColumn(tr, "e") * 1000, LongType),
+        IsNotNull(resolveColumn(tr, "a")),
+        IsNotNull(resolveColumn(tr, "b")),
+        IsNotNull(resolveColumn(tr, "c")),
+        IsNotNull(resolveColumn(tr, "d")),
+        IsNotNull(resolveColumn(tr, "e")))))
+
+    // The constraint IsNotNull(IsNotNull(expr)) doesn't guarantee expr is not null.
+    verifyConstraints(
+      tr.where('a.attr === 'c.attr &&
+        IsNotNull(IsNotNull(resolveColumn(tr, "b")))).analyze.constraints,
+      ExpressionSet(Seq(
+        resolveColumn(tr, "a") === resolveColumn(tr, "c"),
+        IsNotNull(IsNotNull(resolveColumn(tr, "b"))),
+        IsNotNull(resolveColumn(tr, "a")),
+        IsNotNull(resolveColumn(tr, "c")))))
+  }
+
   test("infer IsNotNull constraints from non-nullable attributes") {
     val tr = LocalRelation('a.int, AttributeReference("b", IntegerType, nullable = false)(),
       AttributeReference("c", StringType, nullable = false)())

