You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by we...@apache.org on 2022/07/12 02:18:46 UTC
[spark] branch master updated: [SPARK-39737][SQL] `PERCENTILE_CONT` and `PERCENTILE_DISC` should support aggregate filter
This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new ab277e123c5 [SPARK-39737][SQL] `PERCENTILE_CONT` and `PERCENTILE_DISC` should support aggregate filter
ab277e123c5 is described below
commit ab277e123c5f8cdc9f147ae019dbb38bc0e50262
Author: Jiaan Geng <be...@163.com>
AuthorDate: Tue Jul 12 10:18:23 2022 +0800
[SPARK-39737][SQL] `PERCENTILE_CONT` and `PERCENTILE_DISC` should support aggregate filter
### What changes were proposed in this pull request?
Currently, Spark supports the ANSI aggregate functions percentile_cont and percentile_disc.
However, these two aggregate functions do not support the aggregate FILTER clause.
### Why are the changes needed?
Aggregate filters can improve performance and are very useful.
### Does this PR introduce _any_ user-facing change?
'No'.
New feature.
### How was this patch tested?
New test cases.
Closes #37150 from beliefer/SPARK-39737.
Authored-by: Jiaan Geng <be...@163.com>
Signed-off-by: Wenchen Fan <we...@databricks.com>
---
.../spark/sql/catalyst/parser/SqlBaseParser.g4 | 3 +-
.../spark/sql/catalyst/parser/AstBuilder.scala | 3 +-
.../sql/catalyst/parser/PlanParserSuite.scala | 14 ++++++-
.../resources/sql-tests/inputs/percentiles.sql | 16 ++++++--
.../sql-tests/results/percentiles.sql.out | 48 +++++++++++++---------
5 files changed, 57 insertions(+), 27 deletions(-)
diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4
index ce37a09d5ba..f398ddd76f7 100644
--- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4
+++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4
@@ -849,7 +849,8 @@ primaryExpression
| OVERLAY LEFT_PAREN input=valueExpression PLACING replace=valueExpression
FROM position=valueExpression (FOR length=valueExpression)? RIGHT_PAREN #overlay
| name=(PERCENTILE_CONT | PERCENTILE_DISC) LEFT_PAREN percentage=valueExpression RIGHT_PAREN
- WITHIN GROUP LEFT_PAREN ORDER BY sortItem RIGHT_PAREN ( OVER windowSpec)? #percentile
+ WITHIN GROUP LEFT_PAREN ORDER BY sortItem RIGHT_PAREN
+ (FILTER LEFT_PAREN WHERE where=booleanExpression RIGHT_PAREN)? ( OVER windowSpec)? #percentile
;
constant
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
index 46847411bf0..05b3ddca022 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
@@ -1865,7 +1865,8 @@ class AstBuilder extends SqlBaseParserBaseVisitor[AnyRef] with SQLConfHelper wit
case Descending => PercentileDisc(sortOrder.child, percentage, true)
}
}
- val aggregateExpression = percentile.toAggregateExpression()
+ val filter = Option(ctx.where).map(expression(_))
+ val aggregateExpression = percentile.toAggregateExpression(false, filter)
ctx.windowSpec match {
case spec: WindowRefContext =>
UnresolvedWindowExpression(aggregateExpression, visitWindowRef(spec))
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
index 3c757442e13..6c0d970143b 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
@@ -1327,6 +1327,12 @@ class PlanParserSuite extends AnalysisTest {
Literal(Decimal(0.1), DecimalType(1, 1)), true).toAggregateExpression()
)
+ assertPercentilePlans(
+ "SELECT PERCENTILE_CONT(0.1) WITHIN GROUP (ORDER BY col) FILTER (WHERE id > 10)",
+ PercentileCont(UnresolvedAttribute("col"), Literal(Decimal(0.1), DecimalType(1, 1)))
+ .toAggregateExpression(false, Some(GreaterThan(UnresolvedAttribute("id"), Literal(10))))
+ )
+
assertPercentilePlans(
"SELECT PERCENTILE_DISC(0.1) WITHIN GROUP (ORDER BY col)",
PercentileDisc(UnresolvedAttribute("col"), Literal(Decimal(0.1), DecimalType(1, 1)))
@@ -1335,8 +1341,14 @@ class PlanParserSuite extends AnalysisTest {
assertPercentilePlans(
"SELECT PERCENTILE_DISC(0.1) WITHIN GROUP (ORDER BY col DESC)",
- new PercentileDisc(UnresolvedAttribute("col"),
+ PercentileDisc(UnresolvedAttribute("col"),
Literal(Decimal(0.1), DecimalType(1, 1)), true).toAggregateExpression()
)
+
+ assertPercentilePlans(
+ "SELECT PERCENTILE_DISC(0.1) WITHIN GROUP (ORDER BY col) FILTER (WHERE id > 10)",
+ PercentileDisc(UnresolvedAttribute("col"), Literal(Decimal(0.1), DecimalType(1, 1)))
+ .toAggregateExpression(false, Some(GreaterThan(UnresolvedAttribute("id"), Literal(10))))
+ )
}
}
diff --git a/sql/core/src/test/resources/sql-tests/inputs/percentiles.sql b/sql/core/src/test/resources/sql-tests/inputs/percentiles.sql
index 8cdba718622..f02b728e113 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/percentiles.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/percentiles.sql
@@ -25,26 +25,34 @@ AS basic_pays(employee_name, department, salary);
SELECT
percentile_cont(0.25) WITHIN GROUP (ORDER BY v),
- percentile_cont(0.25) WITHIN GROUP (ORDER BY v DESC)
+ percentile_cont(0.25) WITHIN GROUP (ORDER BY v) FILTER (WHERE k > 0),
+ percentile_cont(0.25) WITHIN GROUP (ORDER BY v DESC),
+ percentile_cont(0.25) WITHIN GROUP (ORDER BY v DESC) FILTER (WHERE k > 0)
FROM aggr;
SELECT
k,
percentile_cont(0.25) WITHIN GROUP (ORDER BY v),
- percentile_cont(0.25) WITHIN GROUP (ORDER BY v DESC)
+ percentile_cont(0.25) WITHIN GROUP (ORDER BY v) FILTER (WHERE k > 0),
+ percentile_cont(0.25) WITHIN GROUP (ORDER BY v DESC),
+ percentile_cont(0.25) WITHIN GROUP (ORDER BY v DESC) FILTER (WHERE k > 0)
FROM aggr
GROUP BY k
ORDER BY k;
SELECT
percentile_disc(0.25) WITHIN GROUP (ORDER BY v),
- percentile_disc(0.25) WITHIN GROUP (ORDER BY v DESC)
+ percentile_disc(0.25) WITHIN GROUP (ORDER BY v) FILTER (WHERE k > 0),
+ percentile_disc(0.25) WITHIN GROUP (ORDER BY v DESC),
+ percentile_disc(0.25) WITHIN GROUP (ORDER BY v DESC) FILTER (WHERE k > 0)
FROM aggr;
SELECT
k,
percentile_disc(0.25) WITHIN GROUP (ORDER BY v),
- percentile_disc(0.25) WITHIN GROUP (ORDER BY v DESC)
+ percentile_disc(0.25) WITHIN GROUP (ORDER BY v) FILTER (WHERE k > 0),
+ percentile_disc(0.25) WITHIN GROUP (ORDER BY v DESC),
+ percentile_disc(0.25) WITHIN GROUP (ORDER BY v DESC) FILTER (WHERE k > 0)
FROM aggr
GROUP BY k
ORDER BY k;
diff --git a/sql/core/src/test/resources/sql-tests/results/percentiles.sql.out b/sql/core/src/test/resources/sql-tests/results/percentiles.sql.out
index f36f4ac086c..f124dcc322e 100644
--- a/sql/core/src/test/resources/sql-tests/results/percentiles.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/percentiles.sql.out
@@ -38,59 +38,67 @@ struct<>
-- !query
SELECT
percentile_cont(0.25) WITHIN GROUP (ORDER BY v),
- percentile_cont(0.25) WITHIN GROUP (ORDER BY v DESC)
+ percentile_cont(0.25) WITHIN GROUP (ORDER BY v) FILTER (WHERE k > 0),
+ percentile_cont(0.25) WITHIN GROUP (ORDER BY v DESC),
+ percentile_cont(0.25) WITHIN GROUP (ORDER BY v DESC) FILTER (WHERE k > 0)
FROM aggr
-- !query schema
-struct<percentile_cont(0.25) WITHIN GROUP (ORDER BY v):double,percentile_cont(0.25) WITHIN GROUP (ORDER BY v DESC):double>
+struct<percentile_cont(0.25) WITHIN GROUP (ORDER BY v):double,percentile_cont(0.25) WITHIN GROUP (ORDER BY v) FILTER (WHERE (k > 0)):double,percentile_cont(0.25) WITHIN GROUP (ORDER BY v DESC):double,percentile_cont(0.25) WITHIN GROUP (ORDER BY v DESC) FILTER (WHERE (k > 0)):double>
-- !query output
-10.0 30.0
+10.0 15.0 30.0 27.5
-- !query
SELECT
k,
percentile_cont(0.25) WITHIN GROUP (ORDER BY v),
- percentile_cont(0.25) WITHIN GROUP (ORDER BY v DESC)
+ percentile_cont(0.25) WITHIN GROUP (ORDER BY v) FILTER (WHERE k > 0),
+ percentile_cont(0.25) WITHIN GROUP (ORDER BY v DESC),
+ percentile_cont(0.25) WITHIN GROUP (ORDER BY v DESC) FILTER (WHERE k > 0)
FROM aggr
GROUP BY k
ORDER BY k
-- !query schema
-struct<k:int,percentile_cont(0.25) WITHIN GROUP (ORDER BY v):double,percentile_cont(0.25) WITHIN GROUP (ORDER BY v DESC):double>
+struct<k:int,percentile_cont(0.25) WITHIN GROUP (ORDER BY v):double,percentile_cont(0.25) WITHIN GROUP (ORDER BY v) FILTER (WHERE (k > 0)):double,percentile_cont(0.25) WITHIN GROUP (ORDER BY v DESC):double,percentile_cont(0.25) WITHIN GROUP (ORDER BY v DESC) FILTER (WHERE (k > 0)):double>
-- !query output
-0 10.0 30.0
-1 12.5 17.5
-2 17.5 26.25
-3 60.0 60.0
-4 NULL NULL
+0 10.0 NULL 30.0 NULL
+1 12.5 12.5 17.5 17.5
+2 17.5 17.5 26.25 26.25
+3 60.0 60.0 60.0 60.0
+4 NULL NULL NULL NULL
-- !query
SELECT
percentile_disc(0.25) WITHIN GROUP (ORDER BY v),
- percentile_disc(0.25) WITHIN GROUP (ORDER BY v DESC)
+ percentile_disc(0.25) WITHIN GROUP (ORDER BY v) FILTER (WHERE k > 0),
+ percentile_disc(0.25) WITHIN GROUP (ORDER BY v DESC),
+ percentile_disc(0.25) WITHIN GROUP (ORDER BY v DESC) FILTER (WHERE k > 0)
FROM aggr
-- !query schema
-struct<percentile_disc(0.25) WITHIN GROUP (ORDER BY v):double,percentile_disc(0.25) WITHIN GROUP (ORDER BY v DESC):double>
+struct<percentile_disc(0.25) WITHIN GROUP (ORDER BY v):double,percentile_disc(0.25) WITHIN GROUP (ORDER BY v) FILTER (WHERE (k > 0)):double,percentile_disc(0.25) WITHIN GROUP (ORDER BY v DESC):double,percentile_disc(0.25) WITHIN GROUP (ORDER BY v DESC) FILTER (WHERE (k > 0)):double>
-- !query output
-10.0 30.0
+10.0 10.0 30.0 30.0
-- !query
SELECT
k,
percentile_disc(0.25) WITHIN GROUP (ORDER BY v),
- percentile_disc(0.25) WITHIN GROUP (ORDER BY v DESC)
+ percentile_disc(0.25) WITHIN GROUP (ORDER BY v) FILTER (WHERE k > 0),
+ percentile_disc(0.25) WITHIN GROUP (ORDER BY v DESC),
+ percentile_disc(0.25) WITHIN GROUP (ORDER BY v DESC) FILTER (WHERE k > 0)
FROM aggr
GROUP BY k
ORDER BY k
-- !query schema
-struct<k:int,percentile_disc(0.25) WITHIN GROUP (ORDER BY v):double,percentile_disc(0.25) WITHIN GROUP (ORDER BY v DESC):double>
+struct<k:int,percentile_disc(0.25) WITHIN GROUP (ORDER BY v):double,percentile_disc(0.25) WITHIN GROUP (ORDER BY v) FILTER (WHERE (k > 0)):double,percentile_disc(0.25) WITHIN GROUP (ORDER BY v DESC):double,percentile_disc(0.25) WITHIN GROUP (ORDER BY v DESC) FILTER (WHERE (k > 0)):double>
-- !query output
-0 10.0 30.0
-1 10.0 20.0
-2 10.0 30.0
-3 60.0 60.0
-4 NULL NULL
+0 10.0 NULL 30.0 NULL
+1 10.0 10.0 20.0 20.0
+2 10.0 10.0 30.0 30.0
+3 60.0 60.0 60.0 60.0
+4 NULL NULL NULL NULL
-- !query
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org