Posted to commits@spark.apache.org by do...@apache.org on 2020/03/03 20:27:53 UTC

[spark] branch branch-3.0 updated: [SPARK-30997][SQL] Fix an analysis failure in generators with aggregate functions

This is an automated email from the ASF dual-hosted git repository.

dongjoon pushed a commit to branch branch-3.0
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.0 by this push:
     new 7d853ab  [SPARK-30997][SQL] Fix an analysis failure in generators with aggregate functions
7d853ab is described below

commit 7d853ab6eba479a7cc5d8839b4fc497bc6b6d4c8
Author: Takeshi Yamamuro <ya...@apache.org>
AuthorDate: Tue Mar 3 12:25:12 2020 -0800

    [SPARK-30997][SQL] Fix an analysis failure in generators with aggregate functions
    
    ### What changes were proposed in this pull request?
    
    SPARK-28782 added support for generators in SQL aggregate expressions.
    However, the equivalent generator (`explode`) query with aggregate functions fails in the DataFrame API, as follows:
    
    ```
    // SPARK-28782: Generator support in aggregate expressions
    scala> spark.range(3).toDF("id").createOrReplaceTempView("t")
    scala> sql("select explode(array(min(id), max(id))) from t").show()
    +---+
    |col|
    +---+
    |  0|
    |  2|
    +---+
    
    // The failure case fixed by this PR
    scala> spark.range(3).select(explode(array(min($"id"), max($"id")))).show()
    org.apache.spark.sql.AnalysisException:
    The query operator `Generate` contains one or more unsupported
    expression types Aggregate, Window or Generate.
    Invalid expressions: [min(`id`), max(`id`)];;
    Project [col#46L]
    +- Generate explode(array(min(id#42L), max(id#42L))), false, [col#46L]
       +- Range (0, 3, step=1, splits=Some(4))
    
      at org.apache.spark.sql.catalyst.analysis.CheckAnalysis.failAnalysis(CheckAnalysis.scala:49)
      at org.apache.spark.sql.catalyst.analysis.CheckAnalysis.failAnalysis$(CheckAnalysis.scala:48)
      at org.apache.spark.sql.catalyst.analysis.Analyzer.failAnalysis(Analyzer.scala:129)
    ```
    
    The root cause is that `ExtractGenerator` wrongly rewrites a `Project` containing aggregate functions
    into a `Generate` before `GlobalAggregates` can replace that `Project` with an `Aggregate`, as follows:
    
    ```
    scala> sql("SET spark.sql.optimizer.planChangeLog.level=warn")
    scala> spark.range(3).select(explode(array(min($"id"), max($"id")))).show()
    
    20/03/01 12:51:58 WARN HiveSessionStateBuilder$$anon$1:
    === Applying Rule org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveReferences ===
    !'Project [explode(array(min('id), max('id))) AS List()]   'Project [explode(array(min(id#72L), max(id#72L))) AS List()]
     +- Range (0, 3, step=1, splits=Some(4))                   +- Range (0, 3, step=1, splits=Some(4))
    
    20/03/01 12:51:58 WARN HiveSessionStateBuilder$$anon$1:
    === Applying Rule org.apache.spark.sql.catalyst.analysis.Analyzer$ExtractGenerator ===
    !'Project [explode(array(min(id#72L), max(id#72L))) AS List()]   Project [col#76L]
    !+- Range (0, 3, step=1, splits=Some(4))                         +- Generate explode(array(min(id#72L), max(id#72L))), false, [col#76L]
    !                                                                   +- Range (0, 3, step=1, splits=Some(4))
    
    20/03/01 12:51:58 WARN HiveSessionStateBuilder$$anon$1:
    === Result of Batch Resolution ===
    !'Project [explode(array(min('id), max('id))) AS List()]   Project [col#76L]
    !+- Range (0, 3, step=1, splits=Some(4))                   +- Generate explode(array(min(id#72L), max(id#72L))), false, [col#76L]
    !                                                             +- Range (0, 3, step=1, splits=Some(4))
    
    // the analysis failed here...
    ```
    
    To avoid this, the PR adds a condition to `ExtractGenerator` so that it skips generators containing aggregate functions until `GlobalAggregates` has run.
    The correct sequence of rules is as follows:
    
    ```
    20/03/01 13:19:06 WARN HiveSessionStateBuilder$$anon$1:
    === Applying Rule org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveReferences ===
    !'Project [explode(array(min('id), max('id))) AS List()]   'Project [explode(array(min(id#27L), max(id#27L))) AS List()]
     +- Range (0, 3, step=1, splits=Some(4))                   +- Range (0, 3, step=1, splits=Some(4))
    
    20/03/01 13:19:06 WARN HiveSessionStateBuilder$$anon$1:
    === Applying Rule org.apache.spark.sql.catalyst.analysis.Analyzer$GlobalAggregates ===
    !'Project [explode(array(min(id#27L), max(id#27L))) AS List()]   'Aggregate [explode(array(min(id#27L), max(id#27L))) AS List()]
     +- Range (0, 3, step=1, splits=Some(4))                         +- Range (0, 3, step=1, splits=Some(4))
    
    20/03/01 13:19:06 WARN HiveSessionStateBuilder$$anon$1:
    === Applying Rule org.apache.spark.sql.catalyst.analysis.Analyzer$ExtractGenerator ===
    !'Aggregate [explode(array(min(id#27L), max(id#27L))) AS List()]   'Project [explode(_gen_input_0#31) AS List()]
    !+- Range (0, 3, step=1, splits=Some(4))                           +- Aggregate [array(min(id#27L), max(id#27L)) AS _gen_input_0#31]
    !                                                                     +- Range (0, 3, step=1, splits=Some(4))
    
    ```
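
    The skip condition described above can be illustrated outside Spark with a minimal sketch.
    The classes below (`Expr`, `Attr`, `Min`, `Max`, `MakeArray`, `Explode`) and the `GeneratorGuard`
    object are hypothetical toy stand-ins, not Catalyst classes; only the shape of the check
    (a generator whose children contain an aggregate function) is taken from this change.

    ```scala
    // Toy expression tree. These are NOT the real Catalyst classes; they only
    // mimic the parts of the tree that the guard needs to inspect.
    sealed trait Expr { def children: Seq[Expr] }
    case class Attr(name: String) extends Expr { def children: Seq[Expr] = Nil }
    case class Min(child: Expr) extends Expr { def children: Seq[Expr] = Seq(child) }
    case class Max(child: Expr) extends Expr { def children: Seq[Expr] = Seq(child) }
    case class MakeArray(children: Seq[Expr]) extends Expr
    case class Explode(child: Expr) extends Expr { def children: Seq[Expr] = Seq(child) }

    object GeneratorGuard {
      // Pre-order search: returns the first node (the root included) matching `p`.
      def find(e: Expr)(p: Expr => Boolean): Option[Expr] =
        if (p(e)) Some(e)
        else e.children.iterator.map(c => find(c)(p)).collectFirst { case Some(x) => x }

      // In this toy model, only Min and Max count as aggregate functions.
      def isAggregate(e: Expr): Boolean = e match {
        case _: Min | _: Max => true
        case _ => false
      }

      // True if any project-list entry contains a generator (here: Explode)
      // with an aggregate function somewhere below it.
      def hasAggFunctionInGenerator(projectList: Seq[Expr]): Boolean =
        projectList.exists { e =>
          find(e) {
            case g: Explode => g.children.exists(c => find(c)(isAggregate).isDefined)
            case _ => false
          }.isDefined
        }
    }

    object GeneratorGuardDemo {
      def main(args: Array[String]): Unit = {
        val withAgg = Seq(Explode(MakeArray(Seq(Min(Attr("id")), Max(Attr("id"))))))
        val withoutAgg = Seq(Explode(MakeArray(Seq(Attr("id")))))
        println(GeneratorGuard.hasAggFunctionInGenerator(withAgg))    // true
        println(GeneratorGuard.hasAggFunctionInGenerator(withoutAgg)) // false
      }
    }
    ```

    When the check returns true, the rule leaves the `Project` untouched, which gives
    `GlobalAggregates` the chance to turn it into an `Aggregate` first.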
    
    ### Why are the changes needed?
    
    A bug fix.
    
    ### Does this PR introduce any user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    Added tests to `AnalysisErrorSuite` and `GeneratorFunctionSuite`.
    
    Closes #27749 from maropu/ExplodeInAggregate.
    
    Authored-by: Takeshi Yamamuro <ya...@apache.org>
    Signed-off-by: Dongjoon Hyun <dh...@apple.com>
    (cherry picked from commit 4a1d273a4aac66385b948c6130de0a26ef84bbb4)
    Signed-off-by: Dongjoon Hyun <dh...@apple.com>
---
 .../org/apache/spark/sql/catalyst/analysis/Analyzer.scala | 14 ++++++++++++++
 .../spark/sql/catalyst/analysis/AnalysisErrorSuite.scala  | 15 +++++++++++++++
 .../org/apache/spark/sql/GeneratorFunctionSuite.scala     |  5 +++++
 3 files changed, 34 insertions(+)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 486b952..254dd44 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -2182,6 +2182,15 @@ class Analyzer(
       }
     }
 
+    private def hasAggFunctionInGenerator(ne: Seq[NamedExpression]): Boolean = {
+      ne.exists(_.find {
+        case g: Generator =>
+          g.children.exists(_.find(_.isInstanceOf[AggregateFunction]).isDefined)
+        case _ =>
+          false
+      }.nonEmpty)
+    }
+
     private def trimAlias(expr: NamedExpression): Expression = expr match {
       case UnresolvedAlias(child, _) => child
       case Alias(child, _) => child
@@ -2268,6 +2277,11 @@ class Analyzer(
         val newAgg = Aggregate(groupList, newAggList, child)
         Project(projectExprs.toList, newAgg)
 
+      case p @ Project(projectList, _) if hasAggFunctionInGenerator(projectList) =>
+        // If a generator has any aggregate function, we need to apply the `GlobalAggregates` rule
+        // first for replacing `Project` with `Aggregate`.
+        p
+
       case p @ Project(projectList, child) =>
         // Holds the resolved generator, if one exists in the project list.
         var resolvedGenerator: Generate = null
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
index 3db1053..09e0d9c 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
@@ -453,6 +453,13 @@ class AnalysisErrorSuite extends AnalysisTest {
   )
 
   errorTest(
+    "generator nested in expressions for aggregates",
+    testRelation.select(Explode(CreateArray(min($"a") :: max($"a") :: Nil)) + 1),
+    "Generators are not supported when it's nested in expressions, but got: " +
+      "(explode(array(min(a), max(a))) + 1)" :: Nil
+  )
+
+  errorTest(
     "generator appears in operator which is not Project",
     listRelation.sortBy(Explode($"list").asc),
     "Generators are not supported outside the SELECT clause, but got: Sort" :: Nil
@@ -476,6 +483,14 @@ class AnalysisErrorSuite extends AnalysisTest {
     "Only one generator allowed per select clause but found 2: explode(list), explode(list)" :: Nil
   )
 
+  errorTest(
+    "more than one generators for aggregates in SELECT",
+    testRelation.select(Explode(CreateArray(min($"a") :: Nil)),
+      Explode(CreateArray(max($"a") :: Nil))),
+    "Only one generator allowed per select clause but found 2: " +
+      "explode(array(min(a))), explode(array(max(a)))" :: Nil
+  )
+
   test("SPARK-6452 regression test") {
     // CheckAnalysis should throw AnalysisException when Aggregate contains missing attribute(s)
     // Since we manually construct the logical plan at here and Sum only accept
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala
index 6785b31..8f44903 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala
@@ -351,6 +351,11 @@ class GeneratorFunctionSuite extends QueryTest with SharedSparkSession {
     assert(errMsg.contains("Generators are not supported when it's nested in expressions, " +
       "but got: explode(explode(v))"))
   }
+
+  test("SPARK-30997: generators in aggregate expressions for dataframe") {
+    val df = Seq(1, 2, 3).toDF("v")
+    checkAnswer(df.select(explode(array(min($"v"), max($"v")))), Row(1) :: Row(3) :: Nil)
+  }
 }
 
 case class EmptyGenerator() extends Generator {

