You are viewing a plain-text version of this content; the link to the canonical version was not preserved in this copy.
Posted to commits@spark.apache.org by we...@apache.org on 2022/07/12 02:18:46 UTC

[spark] branch master updated: [SPARK-39737][SQL] `PERCENTILE_CONT` and `PERCENTILE_DISC` should support aggregate filter

This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new ab277e123c5 [SPARK-39737][SQL] `PERCENTILE_CONT` and `PERCENTILE_DISC` should support aggregate filter
ab277e123c5 is described below

commit ab277e123c5f8cdc9f147ae019dbb38bc0e50262
Author: Jiaan Geng <be...@163.com>
AuthorDate: Tue Jul 12 10:18:23 2022 +0800

    [SPARK-39737][SQL] `PERCENTILE_CONT` and `PERCENTILE_DISC` should support aggregate filter
    
    ### What changes were proposed in this pull request?
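    Currently, Spark supports the ANSI aggregate functions `percentile_cont` and `percentile_disc`,
    but these two functions do not support the aggregate `FILTER (WHERE ...)` clause. This PR allows
    `FILTER (WHERE ...)` to follow `WITHIN GROUP (ORDER BY ...)`, before the optional `OVER` window spec.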
    
    ### Why are the changes needed?
    The aggregate filter is very useful: it lets a query compute a percentile over a subset of rows
    without a separate subquery, which can also improve performance.
    
    ### Does this PR introduce _any_ user-facing change?
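    Yes, as a new feature: `PERCENTILE_CONT` and `PERCENTILE_DISC` now accept a `FILTER (WHERE ...)`
    clause. Existing queries are not affected.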
    
    ### How was this patch tested?
    New test cases in `PlanParserSuite` and `sql-tests/inputs/percentiles.sql`.
    
    Closes #37150 from beliefer/SPARK-39737.
    
    Authored-by: Jiaan Geng <be...@163.com>
    Signed-off-by: Wenchen Fan <we...@databricks.com>
---
 .../spark/sql/catalyst/parser/SqlBaseParser.g4     |  3 +-
 .../spark/sql/catalyst/parser/AstBuilder.scala     |  3 +-
 .../sql/catalyst/parser/PlanParserSuite.scala      | 14 ++++++-
 .../resources/sql-tests/inputs/percentiles.sql     | 16 ++++++--
 .../sql-tests/results/percentiles.sql.out          | 48 +++++++++++++---------
 5 files changed, 57 insertions(+), 27 deletions(-)

diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4
index ce37a09d5ba..f398ddd76f7 100644
--- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4
+++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4
@@ -849,7 +849,8 @@ primaryExpression
     | OVERLAY LEFT_PAREN input=valueExpression PLACING replace=valueExpression
       FROM position=valueExpression (FOR length=valueExpression)? RIGHT_PAREN                  #overlay
     | name=(PERCENTILE_CONT | PERCENTILE_DISC) LEFT_PAREN percentage=valueExpression RIGHT_PAREN
-      WITHIN GROUP LEFT_PAREN ORDER BY sortItem RIGHT_PAREN ( OVER windowSpec)?                #percentile
+        WITHIN GROUP LEFT_PAREN ORDER BY sortItem RIGHT_PAREN
+        (FILTER LEFT_PAREN WHERE where=booleanExpression RIGHT_PAREN)? ( OVER windowSpec)?     #percentile
     ;
 
 constant
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
index 46847411bf0..05b3ddca022 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
@@ -1865,7 +1865,8 @@ class AstBuilder extends SqlBaseParserBaseVisitor[AnyRef] with SQLConfHelper wit
           case Descending => PercentileDisc(sortOrder.child, percentage, true)
         }
     }
-    val aggregateExpression = percentile.toAggregateExpression()
+    val filter = Option(ctx.where).map(expression(_))
+    val aggregateExpression = percentile.toAggregateExpression(false, filter)
     ctx.windowSpec match {
       case spec: WindowRefContext =>
         UnresolvedWindowExpression(aggregateExpression, visitWindowRef(spec))
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
index 3c757442e13..6c0d970143b 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
@@ -1327,6 +1327,12 @@ class PlanParserSuite extends AnalysisTest {
         Literal(Decimal(0.1), DecimalType(1, 1)), true).toAggregateExpression()
     )
 
+    assertPercentilePlans(
+      "SELECT PERCENTILE_CONT(0.1) WITHIN GROUP (ORDER BY col) FILTER (WHERE id > 10)",
+      PercentileCont(UnresolvedAttribute("col"), Literal(Decimal(0.1), DecimalType(1, 1)))
+        .toAggregateExpression(false, Some(GreaterThan(UnresolvedAttribute("id"), Literal(10))))
+    )
+
     assertPercentilePlans(
       "SELECT PERCENTILE_DISC(0.1) WITHIN GROUP (ORDER BY col)",
       PercentileDisc(UnresolvedAttribute("col"), Literal(Decimal(0.1), DecimalType(1, 1)))
@@ -1335,8 +1341,14 @@ class PlanParserSuite extends AnalysisTest {
 
     assertPercentilePlans(
       "SELECT PERCENTILE_DISC(0.1) WITHIN GROUP (ORDER BY col DESC)",
-      new PercentileDisc(UnresolvedAttribute("col"),
+      PercentileDisc(UnresolvedAttribute("col"),
         Literal(Decimal(0.1), DecimalType(1, 1)), true).toAggregateExpression()
     )
+
+    assertPercentilePlans(
+      "SELECT PERCENTILE_DISC(0.1) WITHIN GROUP (ORDER BY col) FILTER (WHERE id > 10)",
+      PercentileDisc(UnresolvedAttribute("col"), Literal(Decimal(0.1), DecimalType(1, 1)))
+        .toAggregateExpression(false, Some(GreaterThan(UnresolvedAttribute("id"), Literal(10))))
+    )
   }
 }
diff --git a/sql/core/src/test/resources/sql-tests/inputs/percentiles.sql b/sql/core/src/test/resources/sql-tests/inputs/percentiles.sql
index 8cdba718622..f02b728e113 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/percentiles.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/percentiles.sql
@@ -25,26 +25,34 @@ AS basic_pays(employee_name, department, salary);
 
 SELECT
   percentile_cont(0.25) WITHIN GROUP (ORDER BY v),
-  percentile_cont(0.25) WITHIN GROUP (ORDER BY v DESC)
+  percentile_cont(0.25) WITHIN GROUP (ORDER BY v) FILTER (WHERE k > 0),
+  percentile_cont(0.25) WITHIN GROUP (ORDER BY v DESC),
+  percentile_cont(0.25) WITHIN GROUP (ORDER BY v DESC) FILTER (WHERE k > 0)
 FROM aggr;
 
 SELECT
   k,
   percentile_cont(0.25) WITHIN GROUP (ORDER BY v),
-  percentile_cont(0.25) WITHIN GROUP (ORDER BY v DESC)
+  percentile_cont(0.25) WITHIN GROUP (ORDER BY v) FILTER (WHERE k > 0),
+  percentile_cont(0.25) WITHIN GROUP (ORDER BY v DESC),
+  percentile_cont(0.25) WITHIN GROUP (ORDER BY v DESC) FILTER (WHERE k > 0)
 FROM aggr
 GROUP BY k
 ORDER BY k;
 
 SELECT
   percentile_disc(0.25) WITHIN GROUP (ORDER BY v),
-  percentile_disc(0.25) WITHIN GROUP (ORDER BY v DESC)
+  percentile_disc(0.25) WITHIN GROUP (ORDER BY v) FILTER (WHERE k > 0),
+  percentile_disc(0.25) WITHIN GROUP (ORDER BY v DESC),
+  percentile_disc(0.25) WITHIN GROUP (ORDER BY v DESC) FILTER (WHERE k > 0)
 FROM aggr;
 
 SELECT
   k,
   percentile_disc(0.25) WITHIN GROUP (ORDER BY v),
-  percentile_disc(0.25) WITHIN GROUP (ORDER BY v DESC)
+  percentile_disc(0.25) WITHIN GROUP (ORDER BY v) FILTER (WHERE k > 0),
+  percentile_disc(0.25) WITHIN GROUP (ORDER BY v DESC),
+  percentile_disc(0.25) WITHIN GROUP (ORDER BY v DESC) FILTER (WHERE k > 0)
 FROM aggr
 GROUP BY k
 ORDER BY k;
diff --git a/sql/core/src/test/resources/sql-tests/results/percentiles.sql.out b/sql/core/src/test/resources/sql-tests/results/percentiles.sql.out
index f36f4ac086c..f124dcc322e 100644
--- a/sql/core/src/test/resources/sql-tests/results/percentiles.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/percentiles.sql.out
@@ -38,59 +38,67 @@ struct<>
 -- !query
 SELECT
   percentile_cont(0.25) WITHIN GROUP (ORDER BY v),
-  percentile_cont(0.25) WITHIN GROUP (ORDER BY v DESC)
+  percentile_cont(0.25) WITHIN GROUP (ORDER BY v) FILTER (WHERE k > 0),
+  percentile_cont(0.25) WITHIN GROUP (ORDER BY v DESC),
+  percentile_cont(0.25) WITHIN GROUP (ORDER BY v DESC) FILTER (WHERE k > 0)
 FROM aggr
 -- !query schema
-struct<percentile_cont(0.25) WITHIN GROUP (ORDER BY v):double,percentile_cont(0.25) WITHIN GROUP (ORDER BY v DESC):double>
+struct<percentile_cont(0.25) WITHIN GROUP (ORDER BY v):double,percentile_cont(0.25) WITHIN GROUP (ORDER BY v) FILTER (WHERE (k > 0)):double,percentile_cont(0.25) WITHIN GROUP (ORDER BY v DESC):double,percentile_cont(0.25) WITHIN GROUP (ORDER BY v DESC) FILTER (WHERE (k > 0)):double>
 -- !query output
-10.0	30.0
+10.0	15.0	30.0	27.5
 
 
 -- !query
 SELECT
   k,
   percentile_cont(0.25) WITHIN GROUP (ORDER BY v),
-  percentile_cont(0.25) WITHIN GROUP (ORDER BY v DESC)
+  percentile_cont(0.25) WITHIN GROUP (ORDER BY v) FILTER (WHERE k > 0),
+  percentile_cont(0.25) WITHIN GROUP (ORDER BY v DESC),
+  percentile_cont(0.25) WITHIN GROUP (ORDER BY v DESC) FILTER (WHERE k > 0)
 FROM aggr
 GROUP BY k
 ORDER BY k
 -- !query schema
-struct<k:int,percentile_cont(0.25) WITHIN GROUP (ORDER BY v):double,percentile_cont(0.25) WITHIN GROUP (ORDER BY v DESC):double>
+struct<k:int,percentile_cont(0.25) WITHIN GROUP (ORDER BY v):double,percentile_cont(0.25) WITHIN GROUP (ORDER BY v) FILTER (WHERE (k > 0)):double,percentile_cont(0.25) WITHIN GROUP (ORDER BY v DESC):double,percentile_cont(0.25) WITHIN GROUP (ORDER BY v DESC) FILTER (WHERE (k > 0)):double>
 -- !query output
-0	10.0	30.0
-1	12.5	17.5
-2	17.5	26.25
-3	60.0	60.0
-4	NULL	NULL
+0	10.0	NULL	30.0	NULL
+1	12.5	12.5	17.5	17.5
+2	17.5	17.5	26.25	26.25
+3	60.0	60.0	60.0	60.0
+4	NULL	NULL	NULL	NULL
 
 
 -- !query
 SELECT
   percentile_disc(0.25) WITHIN GROUP (ORDER BY v),
-  percentile_disc(0.25) WITHIN GROUP (ORDER BY v DESC)
+  percentile_disc(0.25) WITHIN GROUP (ORDER BY v) FILTER (WHERE k > 0),
+  percentile_disc(0.25) WITHIN GROUP (ORDER BY v DESC),
+  percentile_disc(0.25) WITHIN GROUP (ORDER BY v DESC) FILTER (WHERE k > 0)
 FROM aggr
 -- !query schema
-struct<percentile_disc(0.25) WITHIN GROUP (ORDER BY v):double,percentile_disc(0.25) WITHIN GROUP (ORDER BY v DESC):double>
+struct<percentile_disc(0.25) WITHIN GROUP (ORDER BY v):double,percentile_disc(0.25) WITHIN GROUP (ORDER BY v) FILTER (WHERE (k > 0)):double,percentile_disc(0.25) WITHIN GROUP (ORDER BY v DESC):double,percentile_disc(0.25) WITHIN GROUP (ORDER BY v DESC) FILTER (WHERE (k > 0)):double>
 -- !query output
-10.0	30.0
+10.0	10.0	30.0	30.0
 
 
 -- !query
 SELECT
   k,
   percentile_disc(0.25) WITHIN GROUP (ORDER BY v),
-  percentile_disc(0.25) WITHIN GROUP (ORDER BY v DESC)
+  percentile_disc(0.25) WITHIN GROUP (ORDER BY v) FILTER (WHERE k > 0),
+  percentile_disc(0.25) WITHIN GROUP (ORDER BY v DESC),
+  percentile_disc(0.25) WITHIN GROUP (ORDER BY v DESC) FILTER (WHERE k > 0)
 FROM aggr
 GROUP BY k
 ORDER BY k
 -- !query schema
-struct<k:int,percentile_disc(0.25) WITHIN GROUP (ORDER BY v):double,percentile_disc(0.25) WITHIN GROUP (ORDER BY v DESC):double>
+struct<k:int,percentile_disc(0.25) WITHIN GROUP (ORDER BY v):double,percentile_disc(0.25) WITHIN GROUP (ORDER BY v) FILTER (WHERE (k > 0)):double,percentile_disc(0.25) WITHIN GROUP (ORDER BY v DESC):double,percentile_disc(0.25) WITHIN GROUP (ORDER BY v DESC) FILTER (WHERE (k > 0)):double>
 -- !query output
-0	10.0	30.0
-1	10.0	20.0
-2	10.0	30.0
-3	60.0	60.0
-4	NULL	NULL
+0	10.0	NULL	30.0	NULL
+1	10.0	10.0	20.0	20.0
+2	10.0	10.0	30.0	30.0
+3	60.0	60.0	60.0	60.0
+4	NULL	NULL	NULL	NULL
 
 
 -- !query


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org