You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by ma...@apache.org on 2015/05/18 21:13:03 UTC
spark git commit: [SPARK-7269] [SQL] Incorrect analysis for aggregation(use semanticEquals)
Repository: spark
Updated Branches:
refs/heads/master fc2480ed1 -> 103c863c2
[SPARK-7269] [SQL] Incorrect analysis for aggregation(use semanticEquals)
A modified version of https://github.com/apache/spark/pull/6110 that uses `semanticEquals` to make it more efficient.
Author: Wenchen Fan <cl...@outlook.com>
Closes #6173 from cloud-fan/7269 and squashes the following commits:
e4a3cc7 [Wenchen Fan] address comments
cc02045 [Wenchen Fan] consider elements length equal
d7ff8f4 [Wenchen Fan] fix 7269
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/103c863c
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/103c863c
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/103c863c
Branch: refs/heads/master
Commit: 103c863c2ef3d9e6186cfc7d95251a9515e9f180
Parents: fc2480e
Author: Wenchen Fan <cl...@outlook.com>
Authored: Mon May 18 12:08:28 2015 -0700
Committer: Michael Armbrust <mi...@databricks.com>
Committed: Mon May 18 12:12:55 2015 -0700
----------------------------------------------------------------------
.../spark/sql/catalyst/analysis/Analyzer.scala | 29 +++++---------------
.../sql/catalyst/analysis/CheckAnalysis.scala | 4 +--
.../sql/catalyst/expressions/Expression.scala | 13 +++++++++
.../catalyst/expressions/namedExpressions.scala | 5 ++++
.../spark/sql/catalyst/planning/patterns.scala | 5 ++--
.../sql/hive/execution/SQLQuerySuite.scala | 18 ++++++++++++
6 files changed, 48 insertions(+), 26 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/103c863c/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 0b6e1d4..dfa4215 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -25,7 +25,6 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.types._
-import org.apache.spark.util.collection.OpenHashSet
/**
* A trivial [[Analyzer]] with an [[EmptyCatalog]] and [[EmptyFunctionRegistry]]. Used for testing
@@ -142,25 +141,6 @@ class Analyzer(
}
object ResolveGroupingAnalytics extends Rule[LogicalPlan] {
- /**
- * Extract attribute set according to the grouping id
- * @param bitmask bitmask to represent the selected of the attribute sequence
- * @param exprs the attributes in sequence
- * @return the attributes of non selected specified via bitmask (with the bit set to 1)
- */
- private def buildNonSelectExprSet(bitmask: Int, exprs: Seq[Expression])
- : OpenHashSet[Expression] = {
- val set = new OpenHashSet[Expression](2)
-
- var bit = exprs.length - 1
- while (bit >= 0) {
- if (((bitmask >> bit) & 1) == 0) set.add(exprs(bit))
- bit -= 1
- }
-
- set
- }
-
/*
* GROUP BY a, b, c WITH ROLLUP
* is equivalent to
@@ -197,10 +177,15 @@ class Analyzer(
g.bitmasks.foreach { bitmask =>
// get the non selected grouping attributes according to the bit mask
- val nonSelectedGroupExprSet = buildNonSelectExprSet(bitmask, g.groupByExprs)
+ val nonSelectedGroupExprs = ArrayBuffer.empty[Expression]
+ var bit = g.groupByExprs.length - 1
+ while (bit >= 0) {
+ if (((bitmask >> bit) & 1) == 0) nonSelectedGroupExprs += g.groupByExprs(bit)
+ bit -= 1
+ }
val substitution = (g.child.output :+ g.gid).map(expr => expr transformDown {
- case x: Expression if nonSelectedGroupExprSet.contains(x) =>
+ case x: Expression if nonSelectedGroupExprs.find(_ semanticEquals x).isDefined =>
// if the input attribute in the Invalid Grouping Expression set of for this group
// replace it with constant null
Literal.create(null, expr.dataType)
http://git-wip-us.apache.org/repos/asf/spark/blob/103c863c/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index f104e74..06a0504 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -86,12 +86,12 @@ trait CheckAnalysis {
case Aggregate(groupingExprs, aggregateExprs, child) =>
def checkValidAggregateExpression(expr: Expression): Unit = expr match {
case _: AggregateExpression => // OK
- case e: Attribute if !groupingExprs.contains(e) =>
+ case e: Attribute if groupingExprs.find(_ semanticEquals e).isEmpty =>
failAnalysis(
s"expression '${e.prettyString}' is neither present in the group by, " +
s"nor is it an aggregate function. " +
"Add to group by or wrap in first() if you don't care which value you get.")
- case e if groupingExprs.contains(e) => // OK
+ case e if groupingExprs.find(_ semanticEquals e).isDefined => // OK
case e if e.references.isEmpty => // OK
case e => e.children.foreach(checkValidAggregateExpression)
}
http://git-wip-us.apache.org/repos/asf/spark/blob/103c863c/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
index 0837a31..c7ae9da 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -76,6 +76,19 @@ abstract class Expression extends TreeNode[Expression] {
case u: UnresolvedAttribute => PrettyAttribute(u.name)
}.toString
}
+
+ /**
+ * Returns true when two expressions will always compute the same result, even if they differ
+ * cosmetically (i.e. capitalization of names in attributes may be different).
+ */
+ def semanticEquals(other: Expression): Boolean = this.getClass == other.getClass && {
+ val elements1 = this.productIterator.toSeq
+ val elements2 = other.asInstanceOf[Product].productIterator.toSeq
+ elements1.length == elements2.length && elements1.zip(elements2).forall {
+ case (e1: Expression, e2: Expression) => e1 semanticEquals e2
+ case (i1, i2) => i1 == i2
+ }
+ }
}
abstract class BinaryExpression extends Expression with trees.BinaryNode[Expression] {
http://git-wip-us.apache.org/repos/asf/spark/blob/103c863c/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
index a917058..50be26d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
@@ -181,6 +181,11 @@ case class AttributeReference(
case _ => false
}
+ override def semanticEquals(other: Expression): Boolean = other match {
+ case ar: AttributeReference => sameRef(ar)
+ case _ => false
+ }
+
override def hashCode: Int = {
// See http://stackoverflow.com/questions/113511/hash-code-implementation
var h = 17
http://git-wip-us.apache.org/repos/asf/spark/blob/103c863c/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
index cd54d04..1dd75a8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
@@ -159,9 +159,10 @@ object PartialAggregation {
// Should trim aliases around `GetField`s. These aliases are introduced while
// resolving struct field accesses, because `GetField` is not a `NamedExpression`.
// (Should we just turn `GetField` into a `NamedExpression`?)
+ val trimmed = e.transform { case Alias(g: ExtractValue, _) => g }
namedGroupingExpressions
- .get(e.transform { case Alias(g: ExtractValue, _) => g })
- .map(_.toAttribute)
+ .find { case (k, v) => k semanticEquals trimmed }
+ .map(_._2.toAttribute)
.getOrElse(e)
}).asInstanceOf[Seq[NamedExpression]]
http://git-wip-us.apache.org/repos/asf/spark/blob/103c863c/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
index ca2c4b4..e60d00e 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
@@ -773,4 +773,22 @@ class SQLQuerySuite extends QueryTest {
| select * from v2 order by key limit 1
""".stripMargin), Row(0, 3))
}
+
+ test("SPARK-7269 Check analysis failed in case in-sensitive") {
+ Seq(1, 2, 3).map { i =>
+ (i.toString, i.toString)
+ }.toDF("key", "value").registerTempTable("df_analysis")
+ sql("SELECT kEy from df_analysis group by key").collect()
+ sql("SELECT kEy+3 from df_analysis group by key+3").collect()
+ sql("SELECT kEy+3, a.kEy, A.kEy from df_analysis A group by key").collect()
+ sql("SELECT cast(kEy+1 as Int) from df_analysis A group by cast(key+1 as int)").collect()
+ sql("SELECT cast(kEy+1 as Int) from df_analysis A group by key+1").collect()
+ sql("SELECT 2 from df_analysis A group by key+1").collect()
+ intercept[AnalysisException] {
+ sql("SELECT kEy+1 from df_analysis group by key+3")
+ }
+ intercept[AnalysisException] {
+ sql("SELECT cast(key+2 as Int) from df_analysis A group by cast(key+1 as int)")
+ }
+ }
}
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org