You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by hv...@apache.org on 2016/11/08 11:09:44 UTC
spark git commit: [SPARK-18137][SQL] Fix RewriteDistinctAggregates
UnresolvedException when a UDAF has a foldable TypeCheck
Repository: spark
Updated Branches:
refs/heads/master 47731e186 -> c291bd274
[SPARK-18137][SQL] Fix RewriteDistinctAggregates UnresolvedException when a UDAF has a foldable TypeCheck
## What changes were proposed in this pull request?
In the RewriteDistinctAggregates rewrite function, the UDAF's children are mapped to AttributeReferences. If the UDAF (such as ApproximatePercentile) has a foldable TypeCheck on its input, that check fails because an AttributeReference is not foldable. The UDAF is therefore left unresolved, and calling nullify on the unresolved expression throws an exception.
This PR changes the rule in two ways:
- It maps only unfoldable children to AttributeReferences, which avoids tripping the UDAF's foldable TypeCheck.
- It only Expands unfoldable children, since there is no need to Expand a static (foldable) value.
**Before sql result**
> select percentile_approx(key, 0.99999), count(distinct key), sum(distinct key) from src limit 1
> org.apache.spark.sql.catalyst.analysis.UnresolvedException: Invalid call to dataType on unresolved object, tree: 'percentile_approx(CAST(src.`key` AS DOUBLE), CAST(0.99999BD AS DOUBLE), 10000)
> at org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute.dataType(unresolved.scala:92)
> at org.apache.spark.sql.catalyst.optimizer.RewriteDistinctAggregates$.org$apache$spark$sql$catalyst$optimizer$RewriteDistinctAggregates$$nullify(RewriteDistinctAggregates.scala:261)
**After sql result**
> select percentile_approx(key, 0.99999), count(distinct key), sum(distinct key) from src limit 1
> [498.0,309,79136]
## How was this patch tested?
Added a test case to HiveUDFSuite.
Author: root <root@iZbp1gsnrlfzjxh82cz80vZ.(none)>
Closes #15668 from windpiger/RewriteDistinctUDAFUnresolveExcep.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/c291bd27
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/c291bd27
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/c291bd27
Branch: refs/heads/master
Commit: c291bd2745a8a2e4ba91d8697879eb8da10287e2
Parents: 47731e1
Author: root <root@iZbp1gsnrlfzjxh82cz80vZ.(none)>
Authored: Tue Nov 8 12:09:32 2016 +0100
Committer: Herman van Hovell <hv...@databricks.com>
Committed: Tue Nov 8 12:09:32 2016 +0100
----------------------------------------------------------------------
.../optimizer/RewriteDistinctAggregates.scala | 35 +++++++++++++++-----
.../spark/sql/hive/execution/HiveUDFSuite.scala | 35 ++++++++++++++++++++
2 files changed, 61 insertions(+), 9 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/c291bd27/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala
index d6a39ec..cd8912f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala
@@ -115,9 +115,21 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {
}
// Extract distinct aggregate expressions.
- val distinctAggGroups = aggExpressions
- .filter(_.isDistinct)
- .groupBy(_.aggregateFunction.children.toSet)
+ val distinctAggGroups = aggExpressions.filter(_.isDistinct).groupBy { e =>
+ val unfoldableChildren = e.aggregateFunction.children.filter(!_.foldable).toSet
+ if (unfoldableChildren.nonEmpty) {
+ // Only expand the unfoldable children
+ unfoldableChildren
+ } else {
+ // If aggregateFunction's children are all foldable
+ // we must expand at least one of the children (here we take the first child),
+ // or If we don't, we will get the wrong result, for example:
+ // count(distinct 1) will be explained to count(1) after the rewrite function.
+ // Generally, the distinct aggregateFunction should not run
+ // foldable TypeCheck for the first child.
+ e.aggregateFunction.children.take(1).toSet
+ }
+ }
// Check if the aggregates contains functions that do not support partial aggregation.
val existsNonPartial = aggExpressions.exists(!_.aggregateFunction.supportsPartial)
@@ -136,8 +148,9 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {
def evalWithinGroup(id: Literal, e: Expression) = If(EqualTo(gid, id), e, nullify(e))
def patchAggregateFunctionChildren(
af: AggregateFunction)(
- attrs: Expression => Expression): AggregateFunction = {
- af.withNewChildren(af.children.map(attrs)).asInstanceOf[AggregateFunction]
+ attrs: Expression => Option[Expression]): AggregateFunction = {
+ val newChildren = af.children.map(c => attrs(c).getOrElse(c))
+ af.withNewChildren(newChildren).asInstanceOf[AggregateFunction]
}
// Setup unique distinct aggregate children.
@@ -161,7 +174,7 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {
val operators = expressions.map { e =>
val af = e.aggregateFunction
val naf = patchAggregateFunctionChildren(af) { x =>
- evalWithinGroup(id, distinctAggChildAttrLookup(x))
+ distinctAggChildAttrLookup.get(x).map(evalWithinGroup(id, _))
}
(e, e.copy(aggregateFunction = naf, isDistinct = false))
}
@@ -170,8 +183,12 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {
}
// Setup expand for the 'regular' aggregate expressions.
- val regularAggExprs = aggExpressions.filter(!_.isDistinct)
- val regularAggChildren = regularAggExprs.flatMap(_.aggregateFunction.children).distinct
+ // only expand unfoldable children
+ val regularAggExprs = aggExpressions
+ .filter(e => !e.isDistinct && e.children.exists(!_.foldable))
+ val regularAggChildren = regularAggExprs
+ .flatMap(_.aggregateFunction.children.filter(!_.foldable))
+ .distinct
val regularAggChildAttrMap = regularAggChildren.map(expressionAttributePair)
// Setup aggregates for 'regular' aggregate expressions.
@@ -179,7 +196,7 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {
val regularAggChildAttrLookup = regularAggChildAttrMap.toMap
val regularAggOperatorMap = regularAggExprs.map { e =>
// Perform the actual aggregation in the initial aggregate.
- val af = patchAggregateFunctionChildren(e.aggregateFunction)(regularAggChildAttrLookup)
+ val af = patchAggregateFunctionChildren(e.aggregateFunction)(regularAggChildAttrLookup.get)
val operator = Alias(e.copy(aggregateFunction = af), e.sql)()
// Select the result of the first aggregate in the last aggregate.
http://git-wip-us.apache.org/repos/asf/spark/blob/c291bd27/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala
index f690035..48adc83 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala
@@ -150,6 +150,41 @@ class HiveUDFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils {
}
test("Generic UDAF aggregates") {
+
+ checkAnswer(sql(
+ """
+ |SELECT percentile_approx(2, 0.99999),
+ | sum(distinct 1),
+ | count(distinct 1,2,3,4) FROM src LIMIT 1
+ """.stripMargin), sql("SELECT 2, 1, 1 FROM src LIMIT 1").collect().toSeq)
+
+ checkAnswer(sql(
+ """
+ |SELECT ceiling(percentile_approx(distinct key, 0.99999)),
+ | count(distinct key),
+ | sum(distinct key),
+ | count(distinct 1),
+ | sum(distinct 1),
+ | sum(1) FROM src LIMIT 1
+ """.stripMargin),
+ sql(
+ """
+ |SELECT max(key),
+ | count(distinct key),
+ | sum(distinct key),
+ | 1, 1, sum(1) FROM src LIMIT 1
+ """.stripMargin).collect().toSeq)
+
+ checkAnswer(sql(
+ """
+ |SELECT ceiling(percentile_approx(distinct key, 0.9 + 0.09999)),
+ | count(distinct key), sum(distinct key),
+ | count(distinct 1), sum(distinct 1),
+ | sum(1) FROM src LIMIT 1
+ """.stripMargin),
+ sql("SELECT max(key), count(distinct key), sum(distinct key), 1, 1, sum(1) FROM src LIMIT 1")
+ .collect().toSeq)
+
checkAnswer(sql("SELECT ceiling(percentile_approx(key, 0.99999D)) FROM src LIMIT 1"),
sql("SELECT max(key) FROM src LIMIT 1").collect().toSeq)
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org