You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@doris.apache.org by da...@apache.org on 2024/04/27 06:38:15 UTC
(doris) 23/30: [Fix](nereids) fix rule merge_aggregate when a project is present (#33892)
This is an automated email from the ASF dual-hosted git repository.
dataroaring pushed a commit to branch branch-4.0-preview
in repository https://gitbox.apache.org/repos/asf/doris.git
commit 4e6af32545f44c04a711bd614f806c259d8a2365
Author: feiniaofeiafei <53...@users.noreply.github.com>
AuthorDate: Fri Apr 26 12:34:24 2024 +0800
[Fix](nereids) fix rule merge_aggregate when a project is present (#33892)
---
.../doris/nereids/jobs/executor/Rewriter.java | 4 +-
.../nereids/rules/rewrite/MergeAggregate.java | 23 ++++---
.../merge_aggregate/merge_aggregate.out | 51 ++++++++++++++
.../merge_aggregate/merge_aggregate.groovy | 80 ++++++++++++++++++++++
4 files changed, 148 insertions(+), 10 deletions(-)
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java
index e8223524367..80da080daf6 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java
@@ -297,7 +297,9 @@ public class Rewriter extends AbstractBatchJobExecutor {
topic("Eliminate GroupBy",
topDown(new EliminateGroupBy(),
- new MergeAggregate())
+ new MergeAggregate(),
+ // need to adjust min/max/sum nullable attribute after merge aggregate
+ new AdjustAggregateNullableForEmptySet())
),
topic("Eager aggregation",
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/MergeAggregate.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/MergeAggregate.java
index 9a0b9f8b5e0..a2c23dd9b41 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/MergeAggregate.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/MergeAggregate.java
@@ -34,10 +34,12 @@ import org.apache.doris.nereids.util.PlanUtils;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
+import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
+import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
@@ -87,15 +89,14 @@ public class MergeAggregate implements RewriteRuleFactory {
private Plan mergeAggProjectAgg(LogicalAggregate<LogicalProject<LogicalAggregate<Plan>>> outerAgg) {
LogicalProject<LogicalAggregate<Plan>> project = outerAgg.child();
LogicalAggregate<Plan> innerAgg = project.child();
-
+ List<NamedExpression> outputExpressions = outerAgg.getOutputExpressions();
+ List<NamedExpression> replacedOutputExpressions = PlanUtils.replaceExpressionByProjections(
+ project.getProjects(), (List) outputExpressions);
// rewrite agg function. e.g. max(max)
- List<NamedExpression> aggFunc = outerAgg.getOutputExpressions().stream()
+ List<NamedExpression> replacedAggFunc = replacedOutputExpressions.stream()
.filter(expr -> (expr instanceof Alias) && (expr.child(0) instanceof AggregateFunction))
.map(e -> rewriteAggregateFunction(e, innerAggExprIdToAggFunc))
.collect(Collectors.toList());
- // rewrite agg function directly refer to the slot below the project
- List<Expression> replacedAggFunc = PlanUtils.replaceExpressionByProjections(project.getProjects(),
- (List) aggFunc);
// replace groupByKeys directly refer to the slot below the project
List<Expression> replacedGroupBy = PlanUtils.replaceExpressionByProjections(project.getProjects(),
outerAgg.getGroupByExpressions());
@@ -138,13 +139,17 @@ public class MergeAggregate implements RewriteRuleFactory {
}
boolean commonCheck(LogicalAggregate<? extends Plan> outerAgg, LogicalAggregate<Plan> innerAgg,
- boolean sameGroupBy) {
+ boolean sameGroupBy, Optional<LogicalProject> projectOptional) {
innerAggExprIdToAggFunc = innerAgg.getOutputExpressions().stream()
.filter(expr -> (expr instanceof Alias) && (expr.child(0) instanceof AggregateFunction))
.collect(Collectors.toMap(NamedExpression::getExprId, value -> (AggregateFunction) value.child(0),
(existValue, newValue) -> existValue));
Set<AggregateFunction> aggregateFunctions = outerAgg.getAggregateFunctions();
- for (AggregateFunction outerFunc : aggregateFunctions) {
+ List<AggregateFunction> replacedAggFunctions = projectOptional.map(project ->
+ (List<AggregateFunction>) PlanUtils.replaceExpressionByProjections(
+ projectOptional.get().getProjects(), new ArrayList<>(aggregateFunctions)))
+ .orElse(new ArrayList<>(aggregateFunctions));
+ for (AggregateFunction outerFunc : replacedAggFunctions) {
if (!(ALLOW_MERGE_AGGREGATE_FUNCTIONS.contains(outerFunc.getName()))) {
return false;
}
@@ -188,7 +193,7 @@ public class MergeAggregate implements RewriteRuleFactory {
}
boolean sameGroupBy = (innerAgg.getGroupByExpressions().size() == outerAgg.getGroupByExpressions().size());
- return commonCheck(outerAgg, innerAgg, sameGroupBy);
+ return commonCheck(outerAgg, innerAgg, sameGroupBy, Optional.empty());
}
private boolean canMergeAggregateWithProject(LogicalAggregate<LogicalProject<LogicalAggregate<Plan>>> outerAgg) {
@@ -206,6 +211,6 @@ public class MergeAggregate implements RewriteRuleFactory {
return false;
}
boolean sameGroupBy = (innerAgg.getGroupByExpressions().size() == outerAgg.getGroupByExpressions().size());
- return commonCheck(outerAgg, innerAgg, sameGroupBy);
+ return commonCheck(outerAgg, innerAgg, sameGroupBy, Optional.of(project));
}
}
diff --git a/regression-test/data/nereids_rules_p0/merge_aggregate/merge_aggregate.out b/regression-test/data/nereids_rules_p0/merge_aggregate/merge_aggregate.out
index ba5b127a56f..fba17e8d7b9 100644
--- a/regression-test/data/nereids_rules_p0/merge_aggregate/merge_aggregate.out
+++ b/regression-test/data/nereids_rules_p0/merge_aggregate/merge_aggregate.out
@@ -246,3 +246,54 @@ PhysicalResultSink
--------------------PhysicalProject
----------------------PhysicalOlapScan[mal_test1]
+-- !test_has_project_distinct_cant_transform --
+1
+
+-- !test_has_project_distinct_cant_transform_shape --
+PhysicalResultSink
+--hashAgg[GLOBAL]
+----PhysicalDistribute[DistributionSpecGather]
+------hashAgg[LOCAL]
+--------PhysicalProject
+----------hashAgg[GLOBAL]
+------------PhysicalDistribute[DistributionSpecHash]
+--------------hashAgg[LOCAL]
+----------------PhysicalProject
+------------------PhysicalOlapScan[mal_test_merge_agg]
+
+-- !test_distinct_expr_transform --
+-1
+
+-- !test_distinct_expr_transform_shape --
+PhysicalResultSink
+--hashAgg[GLOBAL]
+----PhysicalDistribute[DistributionSpecGather]
+------hashAgg[LOCAL]
+--------PhysicalProject
+----------PhysicalOlapScan[mal_test_merge_agg]
+
+-- !test_has_project_distinct_expr_transform --
+1
+1
+1
+
+-- !test_has_project_distinct_expr_transform --
+PhysicalResultSink
+--PhysicalDistribute[DistributionSpecGather]
+----PhysicalProject
+------hashAgg[GLOBAL]
+--------PhysicalDistribute[DistributionSpecHash]
+----------hashAgg[LOCAL]
+------------PhysicalProject
+--------------PhysicalOlapScan[mal_test_merge_agg]
+
+-- !test_sum_empty_table --
+\N \N \N
+
+-- !test_sum_empty_table_shape --
+PhysicalResultSink
+--hashAgg[GLOBAL]
+----PhysicalDistribute[DistributionSpecGather]
+------hashAgg[LOCAL]
+--------PhysicalOlapScan[mal_test2]
+
diff --git a/regression-test/suites/nereids_rules_p0/merge_aggregate/merge_aggregate.groovy b/regression-test/suites/nereids_rules_p0/merge_aggregate/merge_aggregate.groovy
index 44c256e2f57..46cd4a0a9b7 100644
--- a/regression-test/suites/nereids_rules_p0/merge_aggregate/merge_aggregate.groovy
+++ b/regression-test/suites/nereids_rules_p0/merge_aggregate/merge_aggregate.groovy
@@ -174,4 +174,84 @@ suite("merge_aggregate") {
group by a order by 1,2;
"""
+ sql "drop table if exists mal_test_merge_agg"
+ sql """
+ create table mal_test_merge_agg(
+ k1 int null,
+ k2 int not null,
+ k3 string null,
+ k4 varchar(100) null
+ )
+ duplicate key (k1,k2)
+ distributed BY hash(k1) buckets 3
+ properties("replication_num" = "1");
+ """
+ sql "insert into mal_test_merge_agg select 1,1,'1','a';"
+ sql "insert into mal_test_merge_agg select 2,2,'2','b';"
+ sql "insert into mal_test_merge_agg select 3,-3,null,'c';"
+ sql "sync"
+
+ qt_test_has_project_distinct_cant_transform """
+ select max(count_col)
+ from (
+ select k4,
+ count(distinct case when k3 is null then 1 else 0 end) as count_col
+ from mal_test_merge_agg group by k4
+ ) t ;
+ """
+ qt_test_has_project_distinct_cant_transform_shape """
+ explain shape plan
+ select max(count_col)
+ from (
+ select k4,
+ count(distinct case when k3 is null then 1 else 0 end) as count_col
+ from mal_test_merge_agg group by k4
+ ) t ;
+ """
+
+ qt_test_distinct_expr_transform """
+ select max(count_col)
+ from (
+ select k4,
+ max(-abs(k1)) as count_col
+ from mal_test_merge_agg group by k4
+ ) t ;
+ """
+ qt_test_distinct_expr_transform_shape """
+ explain shape plan
+ select max(count_col)
+ from (
+ select k4,
+ max(-abs(k1)) as count_col
+ from mal_test_merge_agg group by k4
+ ) t ;
+ """
+
+ qt_test_has_project_distinct_expr_transform """
+ select sum(count_col)
+ from (
+ select k4,
+ count(distinct case when k3 is null then 1 else 0 end) as count_col
+ from mal_test_merge_agg group by k4
+ ) t group by k4;
+ """
+
+ qt_test_has_project_distinct_expr_transform """
+ explain shape plan
+ select sum(count_col)
+ from (
+ select k4,
+ count(distinct case when k3 is null then 1 else 0 end) as count_col
+ from mal_test_merge_agg group by k4
+ ) t group by k4;
+ """
+
+ qt_test_sum_empty_table """
+ select sum(col1),min(col2),max(col3) from (select sum(a) col1, min(b) col2, max(pk) col3 from mal_test2 group by a) t;
+ """
+
+ qt_test_sum_empty_table_shape """
+ explain shape plan
+ select sum(col1),min(col2),max(col3) from (select sum(a) col1, min(b) col2, max(pk) col3 from mal_test2 group by a) t;
+ """
}
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@doris.apache.org
For additional commands, e-mail: commits-help@doris.apache.org