Posted to commits@spark.apache.org by hv...@apache.org on 2016/10/26 15:09:51 UTC

spark git commit: [SPARK-17733][SQL] InferFiltersFromConstraints rule never terminates for query

Repository: spark
Updated Branches:
  refs/heads/master 402205ddf -> 3c023570b


[SPARK-17733][SQL] InferFiltersFromConstraints rule never terminates for query

## What changes were proposed in this pull request?

The functions `QueryPlan.inferAdditionalConstraints` and `UnaryNode.getAliasedConstraints` can produce a non-converging set of constraints when recursive expressions are involved. For instance, if we have two constraints of the form (where `a` is an alias):
`a = b, a = f(b, c)`
Applying both rules in the next iteration would infer:
`f(b, c) = f(f(b, c), c)`
As this process repeats, the following round infers `f(f(b, c), c) = f(f(f(b, c), c), c)`, and so on: the iteration never converges, and the set of constraints keeps growing until the driver runs out of memory.

~~To fix this problem, we collect aliases from expressions and skip inferring a constraint whenever the transformation would turn an `Expression` into another expression that contains it.~~
To fix this problem, we add an extra check in `inferAdditionalConstraints`: whenever a substitution would generate a recursive constraint, we skip generating it.
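
For reference, the problematic query shape (it mirrors the new `InferFiltersFromConstraintsSuite` tests below) can be reproduced in `spark-shell`. This is only an illustrative sketch: the table setup and names here are assumptions, not part of the patch.

```scala
// Before this patch, optimizing this self-join could run forever in
// InferFiltersFromConstraints: `int_col` aliases COALESCE(a, b), so the
// join condition yields constraints of the shape `a = f(a, b)`.
val t = spark.range(10).selectExpr("id AS a", "id + 1 AS b")
t.createOrReplaceTempView("t")
spark.sql("""
  SELECT *
  FROM (SELECT a, COALESCE(a, b) AS int_col FROM t) tt
  JOIN t t2 ON tt.a = t2.a AND tt.int_col = t2.a
""").explain(true)
```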

## How was this patch tested?

Added new test cases in `SQLQuerySuite` and `InferFiltersFromConstraintsSuite`.

Author: jiangxingbo <ji...@gmail.com>

Closes #15319 from jiangxb1987/constraints.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/3c023570
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/3c023570
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/3c023570

Branch: refs/heads/master
Commit: 3c023570b28bc1ed24f5b2448311130fd1777fd3
Parents: 402205d
Author: jiangxingbo <ji...@gmail.com>
Authored: Wed Oct 26 17:09:48 2016 +0200
Committer: Herman van Hovell <hv...@databricks.com>
Committed: Wed Oct 26 17:09:48 2016 +0200

----------------------------------------------------------------------
 .../spark/sql/catalyst/plans/QueryPlan.scala    | 88 ++++++++++++++++++--
 .../InferFiltersFromConstraintsSuite.scala      | 87 ++++++++++++++++++-
 .../spark/sql/catalyst/plans/PlanTest.scala     | 25 +++++-
 .../org/apache/spark/sql/SQLQuerySuite.scala    |  5 +-
 4 files changed, 191 insertions(+), 14 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/3c023570/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
index 0fb6e7d..45ee296 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
@@ -68,26 +68,104 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT
     case _ => Seq.empty[Attribute]
   }
 
+  // Collect aliases from expressions, so we may avoid producing recursive constraints.
+  private lazy val aliasMap = AttributeMap(
+    (expressions ++ children.flatMap(_.expressions)).collect {
+      case a: Alias => (a.toAttribute, a.child)
+    })
+
   /**
    * Infers an additional set of constraints from a given set of equality constraints.
    * For e.g., if an operator has constraints of the form (`a = 5`, `a = b`), this returns an
-   * additional constraint of the form `b = 5`
+   * additional constraint of the form `b = 5`.
+   *
+   * [SPARK-17733] We explicitly prevent producing recursive constraints of the form `a = f(a, b)`
+   * as they are often useless and can lead to a non-converging set of constraints.
    */
   private def inferAdditionalConstraints(constraints: Set[Expression]): Set[Expression] = {
+    val constraintClasses = generateEquivalentConstraintClasses(constraints)
+
     var inferredConstraints = Set.empty[Expression]
     constraints.foreach {
       case eq @ EqualTo(l: Attribute, r: Attribute) =>
-        inferredConstraints ++= (constraints - eq).map(_ transform {
-          case a: Attribute if a.semanticEquals(l) => r
+        val candidateConstraints = constraints - eq
+        inferredConstraints ++= candidateConstraints.map(_ transform {
+          case a: Attribute if a.semanticEquals(l) &&
+            !isRecursiveDeduction(r, constraintClasses) => r
         })
-        inferredConstraints ++= (constraints - eq).map(_ transform {
-          case a: Attribute if a.semanticEquals(r) => l
+        inferredConstraints ++= candidateConstraints.map(_ transform {
+          case a: Attribute if a.semanticEquals(r) &&
+            !isRecursiveDeduction(l, constraintClasses) => l
         })
       case _ => // No inference
     }
     inferredConstraints -- constraints
   }
 
+  /*
+   * Generate a sequence of expression sets from constraints, where each set stores an equivalence
+   * class of expressions. For example, Set(`a = b`, `b = c`, `e = f`) will generate the following
+   * expression sets: (Set(a, b, c), Set(e, f)). This will be used to find all expressions equal
+   * to a selected attribute.
+   */
+  private def generateEquivalentConstraintClasses(
+      constraints: Set[Expression]): Seq[Set[Expression]] = {
+    var constraintClasses = Seq.empty[Set[Expression]]
+    constraints.foreach {
+      case eq @ EqualTo(l: Attribute, r: Attribute) =>
+        // Transform [[Alias]] to its child.
+        val left = aliasMap.getOrElse(l, l)
+        val right = aliasMap.getOrElse(r, r)
+        // Get the expression set for an equivalence constraint class.
+        val leftConstraintClass = getConstraintClass(left, constraintClasses)
+        val rightConstraintClass = getConstraintClass(right, constraintClasses)
+        if (leftConstraintClass.nonEmpty && rightConstraintClass.nonEmpty) {
+          // Combine the two sets.
+          constraintClasses = constraintClasses
+            .diff(leftConstraintClass :: rightConstraintClass :: Nil) :+
+            (leftConstraintClass ++ rightConstraintClass)
+        } else if (leftConstraintClass.nonEmpty) { // && rightConstraintClass.isEmpty
+          // Update equivalence class of `left` expression.
+          constraintClasses = constraintClasses
+            .diff(leftConstraintClass :: Nil) :+ (leftConstraintClass + right)
+        } else if (rightConstraintClass.nonEmpty) { // && leftConstraintClass.isEmpty
+          // Update equivalence class of `right` expression.
+          constraintClasses = constraintClasses
+            .diff(rightConstraintClass :: Nil) :+ (rightConstraintClass + left)
+        } else { // leftConstraintClass.isEmpty && rightConstraintClass.isEmpty
+          // Create a new equivalence constraint class since neither expression is present
+          // in any existing class.
+          constraintClasses = constraintClasses :+ Set(left, right)
+        }
+      case _ => // Skip
+    }
+
+    constraintClasses
+  }
+
+  /*
+   * Get all expressions equivalent to the selected expression.
+   */
+  private def getConstraintClass(
+      expr: Expression,
+      constraintClasses: Seq[Set[Expression]]): Set[Expression] =
+    constraintClasses.find(_.contains(expr)).getOrElse(Set.empty[Expression])
+
+  /*
+   * Check whether replacing an expression with the given [[Attribute]] would cause a recursive
+   * deduction, which generally has the form `a -> f(a, b)`, where `a` and `b` are expressions and
+   * `f` is a function. Here we first get all expressions equal to `attr` and then check whether
+   * at least one of them is a child of the referenced expression.
+   */
+  private def isRecursiveDeduction(
+      attr: Attribute,
+      constraintClasses: Seq[Set[Expression]]): Boolean = {
+    val expr = aliasMap.getOrElse(attr, attr)
+    getConstraintClass(expr, constraintClasses).exists { e =>
+      expr.children.exists(_.semanticEquals(e))
+    }
+  }
+
   /**
    * An [[ExpressionSet]] that contains invariants about the rows output by this operator. For
    * example, if this set contains the expression `a = 2` then that expression is guaranteed to

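For intuition, the grouping performed by `generateEquivalentConstraintClasses` above can be sketched outside Catalyst, with plain strings standing in for attributes. This is a toy model under that assumption, not code from the patch:

```scala
object EquivalenceClassesSketch {
  // Merge equality pairs into equivalence classes, mirroring the loop in
  // generateEquivalentConstraintClasses: Set(a = b, b = c, e = f) becomes
  // Seq(Set(a, b, c), Set(e, f)).
  def classes(equalities: Seq[(String, String)]): Seq[Set[String]] =
    equalities.foldLeft(Seq.empty[Set[String]]) { case (acc, (l, r)) =>
      val leftClass = acc.find(_.contains(l)).getOrElse(Set.empty[String])
      val rightClass = acc.find(_.contains(r)).getOrElse(Set.empty[String])
      if (leftClass.nonEmpty && rightClass.nonEmpty) {
        // Both sides already classified: union the two classes.
        acc.diff(Seq(leftClass, rightClass)) :+ (leftClass ++ rightClass)
      } else if (leftClass.nonEmpty) {
        acc.diff(Seq(leftClass)) :+ (leftClass + r) // extend left's class
      } else if (rightClass.nonEmpty) {
        acc.diff(Seq(rightClass)) :+ (rightClass + l) // extend right's class
      } else {
        acc :+ Set(l, r) // start a new class
      }
    }

  def main(args: Array[String]): Unit = {
    // Prints: List(Set(a, b, c), Set(e, f))
    println(classes(Seq("a" -> "b", "b" -> "c", "e" -> "f")))
  }
}
```

With the classes in hand, `isRecursiveDeduction` only has to check whether any member of the substituting attribute's class is a child of that attribute's aliased expression.
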
http://git-wip-us.apache.org/repos/asf/spark/blob/3c023570/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala
index e7fdd5a..9f57f66 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala
@@ -27,9 +27,12 @@ import org.apache.spark.sql.catalyst.rules._
 class InferFiltersFromConstraintsSuite extends PlanTest {
 
   object Optimize extends RuleExecutor[LogicalPlan] {
-    val batches = Batch("InferFilters", FixedPoint(5), InferFiltersFromConstraints) ::
-      Batch("PredicatePushdown", FixedPoint(5), PushPredicateThroughJoin) ::
-      Batch("CombineFilters", FixedPoint(5), CombineFilters) :: Nil
+    val batches =
+      Batch("InferAndPushDownFilters", FixedPoint(100),
+        PushPredicateThroughJoin,
+        PushDownPredicate,
+        InferFiltersFromConstraints,
+        CombineFilters) :: Nil
   }
 
   val testRelation = LocalRelation('a.int, 'b.int, 'c.int)
@@ -120,4 +123,82 @@ class InferFiltersFromConstraintsSuite extends PlanTest {
     val optimized = Optimize.execute(originalQuery)
     comparePlans(optimized, correctAnswer)
   }
+
+  test("inner join with alias: alias contains multiple attributes") {
+    val t1 = testRelation.subquery('t1)
+    val t2 = testRelation.subquery('t2)
+
+    val originalQuery = t1.select('a, Coalesce(Seq('a, 'b)).as('int_col)).as("t")
+      .join(t2, Inner, Some("t.a".attr === "t2.a".attr && "t.int_col".attr === "t2.a".attr))
+      .analyze
+    val correctAnswer = t1
+      .where(IsNotNull('a) && IsNotNull(Coalesce(Seq('a, 'b))) && 'a === Coalesce(Seq('a, 'b)))
+      .select('a, Coalesce(Seq('a, 'b)).as('int_col)).as("t")
+      .join(t2.where(IsNotNull('a)), Inner,
+        Some("t.a".attr === "t2.a".attr && "t.int_col".attr === "t2.a".attr))
+      .analyze
+    val optimized = Optimize.execute(originalQuery)
+    comparePlans(optimized, correctAnswer)
+  }
+
+  test("inner join with alias: alias contains single attributes") {
+    val t1 = testRelation.subquery('t1)
+    val t2 = testRelation.subquery('t2)
+
+    val originalQuery = t1.select('a, 'b.as('d)).as("t")
+      .join(t2, Inner, Some("t.a".attr === "t2.a".attr && "t.d".attr === "t2.a".attr))
+      .analyze
+    val correctAnswer = t1
+      .where(IsNotNull('a) && IsNotNull('b) && 'a <=> 'a && 'b <=> 'b && 'a === 'b)
+      .select('a, 'b.as('d)).as("t")
+      .join(t2.where(IsNotNull('a) && 'a <=> 'a), Inner,
+        Some("t.a".attr === "t2.a".attr && "t.d".attr === "t2.a".attr))
+      .analyze
+    val optimized = Optimize.execute(originalQuery)
+    comparePlans(optimized, correctAnswer)
+  }
+
+  test("inner join with alias: don't generate constraints for recursive functions") {
+    val t1 = testRelation.subquery('t1)
+    val t2 = testRelation.subquery('t2)
+
+    val originalQuery = t1.select('a, 'b.as('d), Coalesce(Seq('a, 'b)).as('int_col)).as("t")
+      .join(t2, Inner,
+        Some("t.a".attr === "t2.a".attr
+          && "t.d".attr === "t2.a".attr
+          && "t.int_col".attr === "t2.a".attr))
+      .analyze
+    val correctAnswer = t1
+      .where(IsNotNull('a) && IsNotNull(Coalesce(Seq('a, 'a)))
+        && 'a === Coalesce(Seq('a, 'a)) && 'a <=> Coalesce(Seq('a, 'a)) && 'a <=> 'a
+        && Coalesce(Seq('a, 'a)) <=> 'b && Coalesce(Seq('a, 'a)) <=> Coalesce(Seq('a, 'a))
+        && 'a === 'b && IsNotNull(Coalesce(Seq('a, 'b))) && 'a === Coalesce(Seq('a, 'b))
+        && Coalesce(Seq('a, 'b)) <=> Coalesce(Seq('b, 'b)) && Coalesce(Seq('a, 'b)) === 'b
+        && IsNotNull('b) && IsNotNull(Coalesce(Seq('b, 'b)))
+        && 'b === Coalesce(Seq('b, 'b)) && 'b <=> Coalesce(Seq('b, 'b))
+        && Coalesce(Seq('b, 'b)) <=> Coalesce(Seq('b, 'b)) && 'b <=> 'b)
+      .select('a, 'b.as('d), Coalesce(Seq('a, 'b)).as('int_col)).as("t")
+      .join(t2
+        .where(IsNotNull('a) && IsNotNull(Coalesce(Seq('a, 'a)))
+          && 'a === Coalesce(Seq('a, 'a)) && 'a <=> Coalesce(Seq('a, 'a)) && 'a <=> 'a
+          && Coalesce(Seq('a, 'a)) <=> Coalesce(Seq('a, 'a))), Inner,
+        Some("t.a".attr === "t2.a".attr
+          && "t.d".attr === "t2.a".attr
+          && "t.int_col".attr === "t2.a".attr
+          && Coalesce(Seq("t.d".attr, "t.d".attr)) <=> "t.int_col".attr))
+      .analyze
+    val optimized = Optimize.execute(originalQuery)
+    comparePlans(optimized, correctAnswer)
+  }
+
+  test("generate correct filters for alias that don't produce recursive constraints") {
+    val t1 = testRelation.subquery('t1)
+
+    val originalQuery = t1.select('a.as('x), 'b.as('y)).where('x === 1 && 'x === 'y).analyze
+    val correctAnswer =
+      t1.where('a === 1 && 'b === 1 && 'a === 'b && IsNotNull('a) && IsNotNull('b))
+        .select('a.as('x), 'b.as('y)).analyze
+    val optimized = Optimize.execute(originalQuery)
+    comparePlans(optimized, correctAnswer)
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/3c023570/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala
index 6310f0c..64e2687 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.plans
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
-import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, OneRowRelation, Sample}
+import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.util._
 
 /**
@@ -56,16 +56,37 @@ abstract class PlanTest extends SparkFunSuite with PredicateHelper {
   *   ((expr 1 && expr 2) && expr 3), (expr 1 && expr 2 && expr 3), (expr 3 && (expr 1 && expr 2))
    *   etc., will all now be equivalent.
   * - Sample: the seed will be replaced by 0L.
+   * - Join conditions will be resorted by hashCode.
    */
   private def normalizePlan(plan: LogicalPlan): LogicalPlan = {
     plan transform {
       case filter @ Filter(condition: Expression, child: LogicalPlan) =>
-        Filter(splitConjunctivePredicates(condition).sortBy(_.hashCode()).reduce(And), child)
+        Filter(splitConjunctivePredicates(condition).map(rewriteEqual(_)).sortBy(_.hashCode())
+          .reduce(And), child)
       case sample: Sample =>
         sample.copy(seed = 0L)(true)
+      case join @ Join(left, right, joinType, condition) if condition.isDefined =>
+        val newCondition =
+          splitConjunctivePredicates(condition.get).map(rewriteEqual(_)).sortBy(_.hashCode())
+            .reduce(And)
+        Join(left, right, joinType, Some(newCondition))
     }
   }
 
+  /**
+   * Rewrite [[EqualTo]] and [[EqualNullSafe]] so that operand order is canonical. The following
+   * cases will then be equivalent:
+   * 1. (a = b), (b = a);
+   * 2. (a <=> b), (b <=> a).
+   */
+  private def rewriteEqual(condition: Expression): Expression = condition match {
+    case eq @ EqualTo(l: Expression, r: Expression) =>
+      Seq(l, r).sortBy(_.hashCode()).reduce(EqualTo)
+    case eq @ EqualNullSafe(l: Expression, r: Expression) =>
+      Seq(l, r).sortBy(_.hashCode()).reduce(EqualNullSafe)
+    case _ => condition // Don't reorder.
+  }
+
   /** Fails the test if the two plans do not match */
   protected def comparePlans(plan1: LogicalPlan, plan2: LogicalPlan) {
     val normalized1 = normalizePlan(normalizeExprIds(plan1))

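The operand ordering performed by `rewriteEqual` can likewise be shown in isolation. Below is a toy sketch with strings in place of Catalyst expressions (the real code sorts the operands of `EqualTo`/`EqualNullSafe` by `hashCode` before plans are compared):

```scala
object NormalizeEqualitySketch {
  // Model an equality predicate as an ordered pair of operands.
  final case class Eq(left: String, right: String)

  // Sort the two operands by hashCode so that Eq(a, b) and Eq(b, a)
  // normalize to the same value, mirroring PlanTest.rewriteEqual.
  def normalize(eq: Eq): Eq = {
    val sorted = Seq(eq.left, eq.right).sortBy(_.hashCode)
    Eq(sorted(0), sorted(1))
  }

  def main(args: Array[String]): Unit = {
    // Prints: true -- both spellings compare equal after normalization.
    println(normalize(Eq("t1.a", "t2.a")) == normalize(Eq("t2.a", "t1.a")))
  }
}
```
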
http://git-wip-us.apache.org/repos/asf/spark/blob/3c023570/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 60978ef..bd4c253 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -19,12 +19,9 @@ package org.apache.spark.sql
 
 import java.io.File
 import java.math.MathContext
-import java.sql.{Date, Timestamp}
+import java.sql.Timestamp
 
 import org.apache.spark.{AccumulatorSuite, SparkException}
-import org.apache.spark.sql.catalyst.analysis.UnresolvedException
-import org.apache.spark.sql.catalyst.expressions.SortOrder
-import org.apache.spark.sql.catalyst.plans.logical.Aggregate
 import org.apache.spark.sql.catalyst.util.StringUtils
 import org.apache.spark.sql.execution.aggregate
 import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, CartesianProductExec, SortMergeJoinExec}

