Posted to commits@spark.apache.org by hv...@apache.org on 2016/11/08 11:09:44 UTC

spark git commit: [SPARK-18137][SQL] Fix RewriteDistinctAggregates UnresolvedException when a UDAF has a foldable TypeCheck

Repository: spark
Updated Branches:
  refs/heads/master 47731e186 -> c291bd274


[SPARK-18137][SQL] Fix RewriteDistinctAggregates UnresolvedException when a UDAF has a foldable TypeCheck

## What changes were proposed in this pull request?

In the RewriteDistinctAggregates rewrite function, after a UDAF's children are mapped to AttributeReferences, a UDAF (such as ApproximatePercentile) that runs a foldable type check on its input fails, because an AttributeReference is not foldable. The UDAF is then left unresolved, and calling nullify on the unresolved object throws an exception.
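
For illustration, here is a minimal sketch of this kind of foldable input check (using the Catalyst API; `checkPercentageInput` is a hypothetical stand-in, not ApproximatePercentile's actual check) and how it breaks once a literal child has been replaced with an AttributeReference:

```scala
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression, Literal}
import org.apache.spark.sql.types.DoubleType

// Hypothetical stand-in for a UDAF input check: the percentage argument must be
// foldable, i.e. reducible to a constant at analysis time.
def checkPercentageInput(percentage: Expression): TypeCheckResult = {
  if (!percentage.foldable) {
    // Under the old rewrite the literal is replaced with an AttributeReference,
    // so this branch fires, the UDAF stays unresolved, and nullify later throws
    // an UnresolvedException when it asks the unresolved object for its dataType.
    TypeCheckFailure("The percentage argument must be a constant literal.")
  } else {
    TypeCheckSuccess
  }
}

checkPercentageInput(Literal(0.99999))                       // TypeCheckSuccess
checkPercentageInput(AttributeReference("p", DoubleType)())  // TypeCheckFailure
```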

This PR maps only unfoldable children to AttributeReferences, which keeps the UDAF's foldable type check passing. Likewise, only unfoldable children are expanded: there is no need to expand a static (foldable) value.
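
As a minimal sketch of the new grouping behaviour (simplified from the patch below, using Catalyst expressions directly):

```scala
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression, Literal}
import org.apache.spark.sql.types.DoubleType

val key = AttributeReference("key", DoubleType)()
val children: Seq[Expression] = Seq(key, Literal(0.99999))

// Only unfoldable children are used as the grouping key and expanded; foldable
// children (constants) are left in place, so a UDAF's foldable type check still
// sees a literal after the rewrite.
val unfoldableChildren = children.filter(!_.foldable).toSet
val groupingKey =
  if (unfoldableChildren.nonEmpty) {
    unfoldableChildren
  } else {
    // e.g. count(distinct 1): expand at least one child to keep the semantics
    children.take(1).toSet
  }

assert(groupingKey == Set(key))
```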

**Before sql result**

> select percentile_approx(key,0.99999),count(distinct key),sum(distinct key) from src limit 1
> org.apache.spark.sql.catalyst.analysis.UnresolvedException: Invalid call to dataType on unresolved object, tree: 'percentile_approx(CAST(src.`key` AS DOUBLE), CAST(0.99999BD AS DOUBLE), 10000)
>     at org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute.dataType(unresolved.scala:92)
>     at org.apache.spark.sql.catalyst.optimizer.RewriteDistinctAggregates$.org$apache$spark$sql$catalyst$optimizer$RewriteDistinctAggregates$$nullify(RewriteDistinctAggregates.scala:261)

**After sql result**

> select percentile_approx(key,0.99999),count(distinct key),sum(distinct key) from src limit 1
> [498.0,309,79136]

## How was this patch tested?

Added a test case in HiveUDFSuite.

Author: root <root@iZbp1gsnrlfzjxh82cz80vZ.(none)>

Closes #15668 from windpiger/RewriteDistinctUDAFUnresolveExcep.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/c291bd27
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/c291bd27
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/c291bd27

Branch: refs/heads/master
Commit: c291bd2745a8a2e4ba91d8697879eb8da10287e2
Parents: 47731e1
Author: root <root@iZbp1gsnrlfzjxh82cz80vZ.(none)>
Authored: Tue Nov 8 12:09:32 2016 +0100
Committer: Herman van Hovell <hv...@databricks.com>
Committed: Tue Nov 8 12:09:32 2016 +0100

----------------------------------------------------------------------
 .../optimizer/RewriteDistinctAggregates.scala   | 35 +++++++++++++++-----
 .../spark/sql/hive/execution/HiveUDFSuite.scala | 35 ++++++++++++++++++++
 2 files changed, 61 insertions(+), 9 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/c291bd27/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala
index d6a39ec..cd8912f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala
@@ -115,9 +115,21 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {
     }
 
     // Extract distinct aggregate expressions.
-    val distinctAggGroups = aggExpressions
-      .filter(_.isDistinct)
-      .groupBy(_.aggregateFunction.children.toSet)
+    val distinctAggGroups = aggExpressions.filter(_.isDistinct).groupBy { e =>
+        val unfoldableChildren = e.aggregateFunction.children.filter(!_.foldable).toSet
+        if (unfoldableChildren.nonEmpty) {
+          // Only expand the unfoldable children
+          unfoldableChildren
+        } else {
+          // If all of the aggregateFunction's children are foldable,
+          // we must still expand at least one of them (here we take the first child);
+          // otherwise we would get a wrong result, for example:
+          // count(distinct 1) would be rewritten to count(1) by this rule.
+          // In general, a distinct aggregateFunction should not run a
+          // foldable type check against its first child.
+          e.aggregateFunction.children.take(1).toSet
+        }
+    }
 
     // Check if the aggregates contains functions that do not support partial aggregation.
     val existsNonPartial = aggExpressions.exists(!_.aggregateFunction.supportsPartial)
@@ -136,8 +148,9 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {
       def evalWithinGroup(id: Literal, e: Expression) = If(EqualTo(gid, id), e, nullify(e))
       def patchAggregateFunctionChildren(
           af: AggregateFunction)(
-          attrs: Expression => Expression): AggregateFunction = {
-        af.withNewChildren(af.children.map(attrs)).asInstanceOf[AggregateFunction]
+          attrs: Expression => Option[Expression]): AggregateFunction = {
+        val newChildren = af.children.map(c => attrs(c).getOrElse(c))
+        af.withNewChildren(newChildren).asInstanceOf[AggregateFunction]
       }
 
       // Setup unique distinct aggregate children.
@@ -161,7 +174,7 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {
           val operators = expressions.map { e =>
             val af = e.aggregateFunction
             val naf = patchAggregateFunctionChildren(af) { x =>
-              evalWithinGroup(id, distinctAggChildAttrLookup(x))
+              distinctAggChildAttrLookup.get(x).map(evalWithinGroup(id, _))
             }
             (e, e.copy(aggregateFunction = naf, isDistinct = false))
           }
@@ -170,8 +183,12 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {
       }
 
       // Setup expand for the 'regular' aggregate expressions.
-      val regularAggExprs = aggExpressions.filter(!_.isDistinct)
-      val regularAggChildren = regularAggExprs.flatMap(_.aggregateFunction.children).distinct
+      // Only expand the unfoldable children.
+      val regularAggExprs = aggExpressions
+        .filter(e => !e.isDistinct && e.children.exists(!_.foldable))
+      val regularAggChildren = regularAggExprs
+        .flatMap(_.aggregateFunction.children.filter(!_.foldable))
+        .distinct
       val regularAggChildAttrMap = regularAggChildren.map(expressionAttributePair)
 
       // Setup aggregates for 'regular' aggregate expressions.
@@ -179,7 +196,7 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {
       val regularAggChildAttrLookup = regularAggChildAttrMap.toMap
       val regularAggOperatorMap = regularAggExprs.map { e =>
         // Perform the actual aggregation in the initial aggregate.
-        val af = patchAggregateFunctionChildren(e.aggregateFunction)(regularAggChildAttrLookup)
+        val af = patchAggregateFunctionChildren(e.aggregateFunction)(regularAggChildAttrLookup.get)
         val operator = Alias(e.copy(aggregateFunction = af), e.sql)()
 
         // Select the result of the first aggregate in the last aggregate.

http://git-wip-us.apache.org/repos/asf/spark/blob/c291bd27/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala
index f690035..48adc83 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala
@@ -150,6 +150,41 @@ class HiveUDFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils {
   }
 
   test("Generic UDAF aggregates") {
+
+    checkAnswer(sql(
+     """
+       |SELECT percentile_approx(2, 0.99999),
+       |       sum(distinct 1),
+       |       count(distinct 1,2,3,4) FROM src LIMIT 1
+     """.stripMargin), sql("SELECT 2, 1, 1 FROM src LIMIT 1").collect().toSeq)
+
+    checkAnswer(sql(
+      """
+        |SELECT ceiling(percentile_approx(distinct key, 0.99999)),
+        |       count(distinct key),
+        |       sum(distinct key),
+        |       count(distinct 1),
+        |       sum(distinct 1),
+        |       sum(1) FROM src LIMIT 1
+      """.stripMargin),
+      sql(
+        """
+          |SELECT max(key),
+          |       count(distinct key),
+          |       sum(distinct key),
+          |       1, 1, sum(1) FROM src LIMIT 1
+        """.stripMargin).collect().toSeq)
+
+    checkAnswer(sql(
+      """
+        |SELECT ceiling(percentile_approx(distinct key, 0.9 + 0.09999)),
+        |       count(distinct key), sum(distinct key),
+        |       count(distinct 1), sum(distinct 1),
+        |       sum(1) FROM src LIMIT 1
+      """.stripMargin),
+      sql("SELECT max(key), count(distinct key), sum(distinct key), 1, 1, sum(1) FROM src LIMIT 1")
+        .collect().toSeq)
+
     checkAnswer(sql("SELECT ceiling(percentile_approx(key, 0.99999D)) FROM src LIMIT 1"),
       sql("SELECT max(key) FROM src LIMIT 1").collect().toSeq)
 

