Posted to commits@spark.apache.org by ma...@apache.org on 2015/05/18 21:13:03 UTC

spark git commit: [SPARK-7269] [SQL] Incorrect analysis for aggregation (use semanticEquals)

Repository: spark
Updated Branches:
  refs/heads/master fc2480ed1 -> 103c863c2


[SPARK-7269] [SQL] Incorrect analysis for aggregation (use semanticEquals)

A modified version of https://github.com/apache/spark/pull/6110 that uses `semanticEquals` to make the lookup more efficient.
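
As a minimal illustration (not part of the commit; `key`/`kEy` are made-up attributes), plain equality is too strict because the analyzer can produce two attribute references that differ only in name capitalization yet share the same ExprId; `semanticEquals` matches them where `==` does not:

    import org.apache.spark.sql.catalyst.expressions._
    import org.apache.spark.sql.types.IntegerType

    // Two references to the same underlying attribute, spelled differently.
    val id  = NamedExpression.newExprId
    val key = AttributeReference("key", IntegerType)(exprId = id)
    val kEy = AttributeReference("kEy", IntegerType)(exprId = id)

    key == kEy              // false: equals compares the case-sensitive name
    key semanticEquals kEy  // true: compares by ExprId via sameRef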

Author: Wenchen Fan <cl...@outlook.com>

Closes #6173 from cloud-fan/7269 and squashes the following commits:

e4a3cc7 [Wenchen Fan] address comments
cc02045 [Wenchen Fan] consider elements length equal
d7ff8f4 [Wenchen Fan] fix 7269


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/103c863c
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/103c863c
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/103c863c

Branch: refs/heads/master
Commit: 103c863c2ef3d9e6186cfc7d95251a9515e9f180
Parents: fc2480e
Author: Wenchen Fan <cl...@outlook.com>
Authored: Mon May 18 12:08:28 2015 -0700
Committer: Michael Armbrust <mi...@databricks.com>
Committed: Mon May 18 12:12:55 2015 -0700

----------------------------------------------------------------------
 .../spark/sql/catalyst/analysis/Analyzer.scala  | 29 +++++---------------
 .../sql/catalyst/analysis/CheckAnalysis.scala   |  4 +--
 .../sql/catalyst/expressions/Expression.scala   | 13 +++++++++
 .../catalyst/expressions/namedExpressions.scala |  5 ++++
 .../spark/sql/catalyst/planning/patterns.scala  |  5 ++--
 .../sql/hive/execution/SQLQuerySuite.scala      | 18 ++++++++++++
 6 files changed, 48 insertions(+), 26 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/103c863c/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 0b6e1d4..dfa4215 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -25,7 +25,6 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules._
 import org.apache.spark.sql.types._
-import org.apache.spark.util.collection.OpenHashSet
 
 /**
  * A trivial [[Analyzer]] with an [[EmptyCatalog]] and [[EmptyFunctionRegistry]]. Used for testing
@@ -142,25 +141,6 @@ class Analyzer(
   }
 
   object ResolveGroupingAnalytics extends Rule[LogicalPlan] {
-    /**
-     * Extract attribute set according to the grouping id
-     * @param bitmask bitmask to represent the selected of the attribute sequence
-     * @param exprs the attributes in sequence
-     * @return the attributes of non selected specified via bitmask (with the bit set to 1)
-     */
-    private def buildNonSelectExprSet(bitmask: Int, exprs: Seq[Expression])
-    : OpenHashSet[Expression] = {
-      val set = new OpenHashSet[Expression](2)
-
-      var bit = exprs.length - 1
-      while (bit >= 0) {
-        if (((bitmask >> bit) & 1) == 0) set.add(exprs(bit))
-        bit -= 1
-      }
-
-      set
-    }
-
     /*
      *  GROUP BY a, b, c WITH ROLLUP
      *  is equivalent to
@@ -197,10 +177,15 @@ class Analyzer(
 
       g.bitmasks.foreach { bitmask =>
         // get the non selected grouping attributes according to the bit mask
-        val nonSelectedGroupExprSet = buildNonSelectExprSet(bitmask, g.groupByExprs)
+        val nonSelectedGroupExprs = ArrayBuffer.empty[Expression]
+        var bit = g.groupByExprs.length - 1
+        while (bit >= 0) {
+          if (((bitmask >> bit) & 1) == 0) nonSelectedGroupExprs += g.groupByExprs(bit)
+          bit -= 1
+        }
 
         val substitution = (g.child.output :+ g.gid).map(expr => expr transformDown {
-          case x: Expression if nonSelectedGroupExprSet.contains(x) =>
+          case x: Expression if nonSelectedGroupExprs.find(_ semanticEquals x).isDefined =>
            // If the input attribute is in the invalid grouping expression set for this
            // group, replace it with a constant null.
             Literal.create(null, expr.dataType)
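
For illustration only (not part of the commit; plain strings stand in for grouping expressions), the bitmask walk above can be sketched standalone: a clear bit at position i marks groupByExprs(i) as not selected for the current grouping set.

    import scala.collection.mutable.ArrayBuffer

    val groupByExprs = Seq("a", "b", "c")  // stand-ins for grouping expressions
    val bitmask = 5                        // binary 101: bit 1 is clear

    val nonSelected = ArrayBuffer.empty[String]
    var bit = groupByExprs.length - 1
    while (bit >= 0) {
      // a clear bit means the expression is not selected in this grouping set
      if (((bitmask >> bit) & 1) == 0) nonSelected += groupByExprs(bit)
      bit -= 1
    }
    // nonSelected == ArrayBuffer("b"): b is replaced by null for this grouping set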

http://git-wip-us.apache.org/repos/asf/spark/blob/103c863c/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index f104e74..06a0504 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -86,12 +86,12 @@ trait CheckAnalysis {
           case Aggregate(groupingExprs, aggregateExprs, child) =>
             def checkValidAggregateExpression(expr: Expression): Unit = expr match {
               case _: AggregateExpression => // OK
-              case e: Attribute if !groupingExprs.contains(e) =>
+              case e: Attribute if groupingExprs.find(_ semanticEquals e).isEmpty =>
                 failAnalysis(
                   s"expression '${e.prettyString}' is neither present in the group by, " +
                     s"nor is it an aggregate function. " +
                     "Add to group by or wrap in first() if you don't care which value you get.")
-              case e if groupingExprs.contains(e) => // OK
+              case e if groupingExprs.find(_ semanticEquals e).isDefined => // OK
               case e if e.references.isEmpty => // OK
               case e => e.children.foreach(checkValidAggregateExpression)
             }

http://git-wip-us.apache.org/repos/asf/spark/blob/103c863c/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
index 0837a31..c7ae9da 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -76,6 +76,19 @@ abstract class Expression extends TreeNode[Expression] {
       case u: UnresolvedAttribute => PrettyAttribute(u.name)
     }.toString
   }
+
+  /**
+   * Returns true when two expressions will always compute the same result, even if they differ
+   * cosmetically (e.g. the capitalization of attribute names may differ).
+   */
+  def semanticEquals(other: Expression): Boolean = this.getClass == other.getClass && {
+    val elements1 = this.productIterator.toSeq
+    val elements2 = other.asInstanceOf[Product].productIterator.toSeq
+    elements1.length == elements2.length && elements1.zip(elements2).forall {
+      case (e1: Expression, e2: Expression) => e1 semanticEquals e2
+      case (i1, i2) => i1 == i2
+    }
+  }
 }
 
 abstract class BinaryExpression extends Expression with trees.BinaryNode[Expression] {
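
A small usage sketch of the default implementation above (illustrative only): it first compares the concrete classes, then zips the case-class fields, recursing into child expressions, so structurally identical trees over semantically equal attributes compare as equal.

    import org.apache.spark.sql.catalyst.expressions._
    import org.apache.spark.sql.types.IntegerType

    val id = NamedExpression.newExprId
    val a1 = AttributeReference("key", IntegerType)(exprId = id)
    val a2 = AttributeReference("kEy", IntegerType)(exprId = id)

    // Add's fields are its two children, which are compared with semanticEquals.
    Add(a1, Literal(3)) semanticEquals Add(a2, Literal(3))  // true
    Add(a1, Literal(3)) == Add(a2, Literal(3))              // false: names differ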

http://git-wip-us.apache.org/repos/asf/spark/blob/103c863c/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
index a917058..50be26d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
@@ -181,6 +181,11 @@ case class AttributeReference(
     case _ => false
   }
 
+  override def semanticEquals(other: Expression): Boolean = other match {
+    case ar: AttributeReference => sameRef(ar)
+    case _ => false
+  }
+
   override def hashCode: Int = {
     // See http://stackoverflow.com/questions/113511/hash-code-implementation
     var h = 17

http://git-wip-us.apache.org/repos/asf/spark/blob/103c863c/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
index cd54d04..1dd75a8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
@@ -159,9 +159,10 @@ object PartialAggregation {
             // Should trim aliases around `GetField`s. These aliases are introduced while
             // resolving struct field accesses, because `GetField` is not a `NamedExpression`.
             // (Should we just turn `GetField` into a `NamedExpression`?)
+            val trimmed = e.transform { case Alias(g: ExtractValue, _) => g }
             namedGroupingExpressions
-              .get(e.transform { case Alias(g: ExtractValue, _) => g })
-              .map(_.toAttribute)
+              .find { case (k, v) => k semanticEquals trimmed }
+              .map(_._2.toAttribute)
               .getOrElse(e)
         }).asInstanceOf[Seq[NamedExpression]]
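
For context (illustrative only, not part of the commit): namedGroupingExpressions is a map keyed by expression, and `Map.get` relies on hashCode/equals, which is case-sensitive on attribute names, hence the switch to a linear scan with `semanticEquals`. A minimal sketch:

    import org.apache.spark.sql.catalyst.expressions._
    import org.apache.spark.sql.types.IntegerType

    val id = NamedExpression.newExprId
    val k1 = AttributeReference("key", IntegerType)(exprId = id)
    val k2 = AttributeReference("kEy", IntegerType)(exprId = id)

    val grouping: Map[Expression, NamedExpression] = Map(k1 -> Alias(k1, "key")())
    grouping.get(k2)                                      // None: equals compares names
    grouping.find { case (k, _) => k semanticEquals k2 }  // Some(...): semantic match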
 

http://git-wip-us.apache.org/repos/asf/spark/blob/103c863c/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
index ca2c4b4..e60d00e 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
@@ -773,4 +773,22 @@ class SQLQuerySuite extends QueryTest {
         | select * from v2 order by key limit 1
       """.stripMargin), Row(0, 3))
   }
+
+  test("SPARK-7269 Check analysis failed in case in-sensitive") {
+    Seq(1, 2, 3).map { i =>
+      (i.toString, i.toString)
+    }.toDF("key", "value").registerTempTable("df_analysis")
+    sql("SELECT kEy from df_analysis group by key").collect()
+    sql("SELECT kEy+3 from df_analysis group by key+3").collect()
+    sql("SELECT kEy+3, a.kEy, A.kEy from df_analysis A group by key").collect()
+    sql("SELECT cast(kEy+1 as Int) from df_analysis A group by cast(key+1 as int)").collect()
+    sql("SELECT cast(kEy+1 as Int) from df_analysis A group by key+1").collect()
+    sql("SELECT 2 from df_analysis A group by key+1").collect()
+    intercept[AnalysisException] {
+      sql("SELECT kEy+1 from df_analysis group by key+3")
+    }
+    intercept[AnalysisException] {
+      sql("SELECT cast(key+2 as Int) from df_analysis A group by cast(key+1 as int)")
+    }
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org