You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by yh...@apache.org on 2016/02/25 04:43:04 UTC

spark git commit: [SPARK-13092][SQL] Add ExpressionSet for constraint tracking

Repository: spark
Updated Branches:
  refs/heads/master 5a7af9e7a -> 2b042577f


[SPARK-13092][SQL] Add ExpressionSet for constraint tracking

This PR adds a new abstraction called an `ExpressionSet` which attempts to canonicalize expressions to remove cosmetic differences.  Deterministic expressions that are equal after canonicalization will always return the same answer given the same input (i.e. false positives should not be possible). However, it is possible that two canonical expressions that are not equal will in fact return the same answer given any input (i.e. false negatives are possible).

```scala
val set = ExpressionSet('a + 1 :: 1 + 'a :: Nil)

set.iterator => Iterator('a + 1)
set.contains('a + 1) => true
set.contains(1 + 'a) => true
set.contains('a + 2) => false
```

Other relevant changes include:
 - Since this concept overlaps with the existing `semanticEquals` and `semanticHash`, those functions are also ported to this new infrastructure.
 - A memoized `canonicalized` version of the expression is added as a `lazy val` to `Expression` and is used by both `semanticEquals` and `ExpressionSet`.
 - A set of unit tests for `ExpressionSet` is added
 - Tests which expect `semanticEquals` to be less intelligent than it now is are updated.

As a follow-up, we should consider auditing the places where we do `O(n)` `semanticEquals` operations and replace them with `ExpressionSet`.  We should also consider consolidating `AttributeSet` as a specialized factory for an `ExpressionSet`.

Author: Michael Armbrust <mi...@databricks.com>

Closes #11338 from marmbrus/expressionSet.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/2b042577
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/2b042577
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/2b042577

Branch: refs/heads/master
Commit: 2b042577fb077865c3fce69c9d4eda22fde92673
Parents: 5a7af9e
Author: Michael Armbrust <mi...@databricks.com>
Authored: Wed Feb 24 19:43:00 2016 -0800
Committer: Yin Huai <yh...@databricks.com>
Committed: Wed Feb 24 19:43:00 2016 -0800

----------------------------------------------------------------------
 .../apache/spark/sql/catalyst/dsl/package.scala |  2 +-
 .../sql/catalyst/expressions/Canonicalize.scala | 81 ++++++++++++++++++
 .../sql/catalyst/expressions/Expression.scala   | 56 +++++-------
 .../catalyst/expressions/ExpressionSet.scala    | 87 +++++++++++++++++++
 .../spark/sql/catalyst/plans/QueryPlan.scala    | 12 +--
 .../expressions/ExpressionSetSuite.scala        | 89 ++++++++++++++++++++
 .../org/apache/spark/sql/SQLQuerySuite.scala    |  3 +-
 7 files changed, 285 insertions(+), 45 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/2b042577/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
index 1a2ec7e..a12f739 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
@@ -235,7 +235,7 @@ package object dsl {
 
     implicit class DslAttribute(a: AttributeReference) {
       def notNull: AttributeReference = a.withNullability(false)
-      def nullable: AttributeReference = a.withNullability(true)
+      def canBeNull: AttributeReference = a.withNullability(true)
       def at(ordinal: Int): BoundReference = BoundReference(ordinal, a.dataType, a.nullable)
     }
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/2b042577/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Canonicalize.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Canonicalize.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Canonicalize.scala
new file mode 100644
index 0000000..b58a527
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Canonicalize.scala
@@ -0,0 +1,81 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions
+
+import org.apache.spark.sql.catalyst.rules._
+
+/**
+ * Rewrites an expression using rules that are guaranteed to preserve the result while attempting
+ * to remove cosmetic variations. Deterministic expressions that are `equal` after canonicalization
+ * will always return the same answer given the same input (i.e. false positives should not be
+ * possible). However, it is possible that two canonical expressions that are not equal will in fact
+ * return the same answer given any input (i.e. false negatives are possible).
+ *
+ * The following rules are applied:
+ *  - Names and nullability hints for [[org.apache.spark.sql.types.DataType]]s are stripped.
+ *  - Commutative and associative operations ([[Add]] and [[Multiply]]) have their children ordered
+ *    by `hashCode`.
+ *  - [[EqualTo]] and [[EqualNullSafe]] are reordered by `hashCode`.
+ *  - Other comparisons ([[GreaterThan]], [[LessThan]]) are reversed by `hashCode`.
+ */
+object Canonicalize extends RuleExecutor[Expression] {
+  override protected def batches: Seq[Batch] =
+    Batch(
+      "Expression Canonicalization", FixedPoint(100),
+      IgnoreNamesTypes,
+      Reorder) :: Nil
+
+  /** Remove names and nullability from types. */
+  protected object IgnoreNamesTypes extends Rule[Expression] {
+    override def apply(e: Expression): Expression = e transformUp {
+      case a: AttributeReference =>
+        AttributeReference("none", a.dataType.asNullable)(exprId = a.exprId)
+    }
+  }
+
+  /** Collects adjacent commutative operations. */
+  protected def gatherCommutative(
+      e: Expression,
+      f: PartialFunction[Expression, Seq[Expression]]): Seq[Expression] = e match {
+    case c if f.isDefinedAt(c) => f(c).flatMap(gatherCommutative(_, f))
+    case other => other :: Nil
+  }
+
+  /** Orders a set of commutative operations by their hash code. */
+  protected def orderCommutative(
+      e: Expression,
+      f: PartialFunction[Expression, Seq[Expression]]): Seq[Expression] =
+    gatherCommutative(e, f).sortBy(_.hashCode())
+
+  /** Rearrange expressions that are commutative or associative. */
+  protected object Reorder extends Rule[Expression] {
+    override def apply(e: Expression): Expression = e transformUp {
+      case a: Add => orderCommutative(a, { case Add(l, r) => Seq(l, r) }).reduce(Add)
+      case m: Multiply => orderCommutative(m, { case Multiply(l, r) => Seq(l, r) }).reduce(Multiply)
+
+      case EqualTo(l, r) if l.hashCode() > r.hashCode() => EqualTo(r, l)
+      case EqualNullSafe(l, r) if l.hashCode() > r.hashCode() => EqualNullSafe(r, l)
+
+      case GreaterThan(l, r) if l.hashCode() > r.hashCode() => LessThan(r, l)
+      case LessThan(l, r) if l.hashCode() > r.hashCode() => GreaterThan(r, l)
+
+      case GreaterThanOrEqual(l, r) if l.hashCode() > r.hashCode() => LessThanOrEqual(r, l)
+      case LessThanOrEqual(l, r) if l.hashCode() > r.hashCode() => GreaterThanOrEqual(r, l)
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/2b042577/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
index 119496c..692c160 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -145,48 +145,31 @@ abstract class Expression extends TreeNode[Expression] {
   def childrenResolved: Boolean = children.forall(_.resolved)
 
   /**
+   * Returns an expression where a best effort attempt has been made to transform `this` in a way
+   * that preserves the result but removes cosmetic variations (case sensitivity, ordering for
+   * commutative operations, etc.)  See [[Canonicalize]] for more details.
+   *
+   * `deterministic` expressions where `this.canonicalized == other.canonicalized` will always
+   * evaluate to the same result.
+   */
+   lazy val canonicalized: Expression = Canonicalize.execute(this)
+
+  /**
    * Returns true when two expressions will always compute the same result, even if they differ
    * cosmetically (i.e. capitalization of names in attributes may be different).
+   *
+   * See [[Canonicalize]] for more details.
    */
-  def semanticEquals(other: Expression): Boolean = this.getClass == other.getClass && {
-    def checkSemantic(elements1: Seq[Any], elements2: Seq[Any]): Boolean = {
-      elements1.length == elements2.length && elements1.zip(elements2).forall {
-        case (e1: Expression, e2: Expression) => e1 semanticEquals e2
-        case (Some(e1: Expression), Some(e2: Expression)) => e1 semanticEquals e2
-        case (t1: Traversable[_], t2: Traversable[_]) => checkSemantic(t1.toSeq, t2.toSeq)
-        case (i1, i2) => i1 == i2
-      }
-    }
-    // Non-deterministic expressions cannot be semantic equal
-    if (!deterministic || !other.deterministic) return false
-    val elements1 = this.productIterator.toSeq
-    val elements2 = other.asInstanceOf[Product].productIterator.toSeq
-    checkSemantic(elements1, elements2)
-  }
+  def semanticEquals(other: Expression): Boolean =
+    deterministic && other.deterministic  && canonicalized == other.canonicalized
 
   /**
-   * Returns the hash for this expression. Expressions that compute the same result, even if
-   * they differ cosmetically should return the same hash.
+   * Returns a `hashCode` for the calculation performed by this expression. Unlike the standard
+   * `hashCode`, an attempt has been made to eliminate cosmetic differences.
+   *
+   * See [[Canonicalize]] for more details.
    */
-  def semanticHash() : Int = {
-    def computeHash(e: Seq[Any]): Int = {
-      // See http://stackoverflow.com/questions/113511/hash-code-implementation
-      var hash: Int = 17
-      e.foreach(i => {
-        val h: Int = i match {
-          case e: Expression => e.semanticHash()
-          case Some(e: Expression) => e.semanticHash()
-          case t: Traversable[_] => computeHash(t.toSeq)
-          case null => 0
-          case other => other.hashCode()
-        }
-        hash = hash * 37 + h
-      })
-      hash
-    }
-
-    computeHash(this.productIterator.toSeq)
-  }
+  def semanticHash(): Int = canonicalized.hashCode()
 
   /**
    * Checks the input data types, returns `TypeCheckResult.success` if it's valid,
@@ -369,7 +352,6 @@ abstract class UnaryExpression extends Expression {
   }
 }
 
-
 /**
  * An expression with two inputs and one output. The output is by default evaluated to null
  * if any input is evaluated to null.

http://git-wip-us.apache.org/repos/asf/spark/blob/2b042577/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSet.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSet.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSet.scala
new file mode 100644
index 0000000..acea049
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSet.scala
@@ -0,0 +1,87 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions
+
+import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
+
+object ExpressionSet {
+  /** Constructs a new [[ExpressionSet]] by applying [[Canonicalize]] to `expressions`. */
+  def apply(expressions: TraversableOnce[Expression]): ExpressionSet = {
+    val set = new ExpressionSet()
+    expressions.foreach(set.add)
+    set
+  }
+}
+
+/**
+ * A [[Set]] where membership is determined based on a canonical representation of an [[Expression]]
+ * (i.e. one that attempts to ignore cosmetic differences).  See [[Canonicalize]] for more details.
+ *
+ * Internally this set uses the canonical representation, but also keeps track of the original
+ * expressions to ease debugging.  Since different expressions can share the same canonical
+ * representation, this means that operations that extract expressions from this set are only
+ * guaranteed to see at least one such expression.  For example:
+ *
+ * {{{
+ *   val set = ExpressionSet(a + 1 :: 1 + a :: Nil)
+ *
+ *   set.iterator => Iterator(a + 1)
+ *   set.contains(a + 1) => true
+ *   set.contains(1 + a) => true
+ *   set.contains(a + 2) => false
+ * }}}
+ */
+class ExpressionSet protected(
+    protected val baseSet: mutable.Set[Expression] = new mutable.HashSet,
+    protected val originals: mutable.Buffer[Expression] = new ArrayBuffer)
+  extends Set[Expression] {
+
+  protected def add(e: Expression): Unit = {
+    if (!baseSet.contains(e.canonicalized)) {
+      baseSet.add(e.canonicalized)
+      originals.append(e)
+    }
+  }
+
+  override def contains(elem: Expression): Boolean = baseSet.contains(elem.canonicalized)
+
+  override def +(elem: Expression): ExpressionSet = {
+    val newSet = new ExpressionSet(baseSet.clone(), originals.clone())
+    newSet.add(elem)
+    newSet
+  }
+
+  override def -(elem: Expression): ExpressionSet = {
+    val newBaseSet = baseSet.clone().filterNot(_ == elem.canonicalized)
+    val newOriginals = originals.clone().filterNot(_.canonicalized == elem.canonicalized)
+    new ExpressionSet(newBaseSet, newOriginals)
+  }
+
+  override def iterator: Iterator[Expression] = originals.iterator
+
+  /**
+   * Returns a string containing both the post [[Canonicalize]] expressions and the original
+   * expressions in this set.
+   */
+  def toDebugString: String =
+    s"""
+       |baseSet: ${baseSet.mkString(", ")}
+       |originals: ${originals.mkString(", ")}
+     """.stripMargin
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/2b042577/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
index 5e7d144..a74b288 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
@@ -63,17 +63,19 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy
   }
 
   /**
-   * A sequence of expressions that describes the data property of the output rows of this
-   * operator. For example, if the output of this operator is column `a`, an example `constraints`
-   * can be `Set(a > 10, a < 20)`.
+   * An [[ExpressionSet]] that contains invariants about the rows output by this operator. For
+   * example, if this set contains the expression `a = 2` then that expression is guaranteed to
+   * evaluate to `true` for all rows produced.
    */
-  lazy val constraints: Set[Expression] = getRelevantConstraints(validConstraints)
+  lazy val constraints: ExpressionSet = ExpressionSet(getRelevantConstraints(validConstraints))
 
   /**
    * This method can be overridden by any child class of QueryPlan to specify a set of constraints
    * based on the given operator's constraint propagation logic. These constraints are then
    * canonicalized and filtered automatically to contain only those attributes that appear in the
-   * [[outputSet]]
+   * [[outputSet]].
+   *
+   * See [[Canonicalize]] for more details.
    */
   protected def validConstraints: Set[Expression] = Set.empty
 

http://git-wip-us.apache.org/repos/asf/spark/blob/2b042577/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSetSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSetSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSetSuite.scala
new file mode 100644
index 0000000..ce42e57
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSetSuite.scala
@@ -0,0 +1,89 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.types.IntegerType
+
+class ExpressionSetSuite extends SparkFunSuite {
+
+  val aUpper = AttributeReference("A", IntegerType)(exprId = ExprId(1))
+  val aLower = AttributeReference("a", IntegerType)(exprId = ExprId(1))
+  val fakeA = AttributeReference("a", IntegerType)(exprId = ExprId(3))
+
+  val bUpper = AttributeReference("B", IntegerType)(exprId = ExprId(2))
+  val bLower = AttributeReference("b", IntegerType)(exprId = ExprId(2))
+
+  val aAndBSet = AttributeSet(aUpper :: bUpper :: Nil)
+
+  def setTest(size: Int, exprs: Expression*): Unit = {
+    test(s"expect $size: ${exprs.mkString(", ")}") {
+      val set = ExpressionSet(exprs)
+      if (set.size != size) {
+        fail(set.toDebugString)
+      }
+    }
+  }
+
+  def setTestIgnore(size: Int, exprs: Expression*): Unit =
+    ignore(s"expect $size: ${exprs.mkString(", ")}") {}
+
+  // Commutative
+  setTest(1, aUpper + 1, aLower + 1)
+  setTest(2, aUpper + 1, aLower + 2)
+  setTest(2, aUpper + 1, fakeA + 1)
+  setTest(2, aUpper + 1, bUpper + 1)
+
+  setTest(1, aUpper + aLower, aLower + aUpper)
+  setTest(1, aUpper + bUpper, bUpper + aUpper)
+  setTest(1,
+    aUpper + bUpper + 3,
+    bUpper + 3 + aUpper,
+    bUpper + aUpper + 3,
+    Literal(3) + aUpper + bUpper)
+  setTest(1,
+    aUpper * bUpper * 3,
+    bUpper * 3 * aUpper,
+    bUpper * aUpper * 3,
+    Literal(3) * aUpper * bUpper)
+  setTest(1, aUpper === bUpper, bUpper === aUpper)
+
+  setTest(1, aUpper + 1 === bUpper, bUpper === Literal(1) + aUpper)
+
+
+  // Not commutative
+  setTest(2, aUpper - bUpper, bUpper - aUpper)
+
+  // Reversible
+  setTest(1, aUpper > bUpper, bUpper < aUpper)
+  setTest(1, aUpper >= bUpper, bUpper <= aUpper)
+
+  test("add to / remove from set") {
+    val initialSet = ExpressionSet(aUpper + 1 :: Nil)
+
+    assert((initialSet + (aUpper + 1)).size == 1)
+    assert((initialSet + (aUpper + 2)).size == 2)
+    assert((initialSet - (aUpper + 1)).size == 0)
+    assert((initialSet - (aUpper + 2)).size == 1)
+
+    assert((initialSet + (aLower + 1)).size == 1)
+    assert((initialSet - (aLower + 1)).size == 0)
+
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/2b042577/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 1d9db27..13ff4a2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -1980,9 +1980,8 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
       verifyCallCount(
         df.groupBy().agg(sum(testUdf($"b") + testUdf($"b") + testUdf($"b"))), Row(3.0), 1)
 
-      // Would be nice if semantic equals for `+` understood commutative
       verifyCallCount(
-        df.selectExpr("testUdf(a + 1) + testUdf(1 + a)", "testUdf(a + 1)"), Row(4, 2), 2)
+        df.selectExpr("testUdf(a + 1) + testUdf(1 + a)", "testUdf(a + 1)"), Row(4, 2), 1)
 
       // Try disabling it via configuration.
       sqlContext.setConf("spark.sql.subexpressionElimination.enabled", "false")


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org