You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@doris.apache.org by da...@apache.org on 2024/04/27 06:38:15 UTC

(doris) 23/30: [Fix](nereids) fix rule merge_aggregate when has project (#33892)

This is an automated email from the ASF dual-hosted git repository.

dataroaring pushed a commit to branch branch-4.0-preview
in repository https://gitbox.apache.org/repos/asf/doris.git

commit 4e6af32545f44c04a711bd614f806c259d8a2365
Author: feiniaofeiafei <53...@users.noreply.github.com>
AuthorDate: Fri Apr 26 12:34:24 2024 +0800

    [Fix](nereids) fix rule merge_aggregate when has project (#33892)
---
 .../doris/nereids/jobs/executor/Rewriter.java      |  4 +-
 .../nereids/rules/rewrite/MergeAggregate.java      | 23 ++++---
 .../merge_aggregate/merge_aggregate.out            | 51 ++++++++++++++
 .../merge_aggregate/merge_aggregate.groovy         | 80 ++++++++++++++++++++++
 4 files changed, 148 insertions(+), 10 deletions(-)

diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java
index e8223524367..80da080daf6 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java
@@ -297,7 +297,9 @@ public class Rewriter extends AbstractBatchJobExecutor {
 
             topic("Eliminate GroupBy",
                     topDown(new EliminateGroupBy(),
-                            new MergeAggregate())
+                            new MergeAggregate(),
+                            // need to adjust min/max/sum nullable attribute after merge aggregate
+                            new AdjustAggregateNullableForEmptySet())
             ),
 
             topic("Eager aggregation",
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/MergeAggregate.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/MergeAggregate.java
index 9a0b9f8b5e0..a2c23dd9b41 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/MergeAggregate.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/MergeAggregate.java
@@ -34,10 +34,12 @@ import org.apache.doris.nereids.util.PlanUtils;
 import com.google.common.collect.ImmutableList;
 import com.google.common.collect.ImmutableSet;
 
+import java.util.ArrayList;
 import java.util.HashMap;
 import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
+import java.util.Optional;
 import java.util.Set;
 import java.util.stream.Collectors;
 
@@ -87,15 +89,14 @@ public class MergeAggregate implements RewriteRuleFactory {
     private Plan mergeAggProjectAgg(LogicalAggregate<LogicalProject<LogicalAggregate<Plan>>> outerAgg) {
         LogicalProject<LogicalAggregate<Plan>> project = outerAgg.child();
         LogicalAggregate<Plan> innerAgg = project.child();
-
+        List<NamedExpression> outputExpressions = outerAgg.getOutputExpressions();
+        List<NamedExpression> replacedOutputExpressions = PlanUtils.replaceExpressionByProjections(
+                                project.getProjects(), (List) outputExpressions);
         // rewrite agg function. e.g. max(max)
-        List<NamedExpression> aggFunc = outerAgg.getOutputExpressions().stream()
+        List<NamedExpression> replacedAggFunc = replacedOutputExpressions.stream()
                 .filter(expr -> (expr instanceof Alias) && (expr.child(0) instanceof AggregateFunction))
                 .map(e -> rewriteAggregateFunction(e, innerAggExprIdToAggFunc))
                 .collect(Collectors.toList());
-        // rewrite agg function directly refer to the slot below the project
-        List<Expression> replacedAggFunc = PlanUtils.replaceExpressionByProjections(project.getProjects(),
-                (List) aggFunc);
         // replace groupByKeys directly refer to the slot below the project
         List<Expression> replacedGroupBy = PlanUtils.replaceExpressionByProjections(project.getProjects(),
                 outerAgg.getGroupByExpressions());
@@ -138,13 +139,17 @@ public class MergeAggregate implements RewriteRuleFactory {
     }
 
     boolean commonCheck(LogicalAggregate<? extends Plan> outerAgg, LogicalAggregate<Plan> innerAgg,
-            boolean sameGroupBy) {
+            boolean sameGroupBy, Optional<LogicalProject> projectOptional) {
         innerAggExprIdToAggFunc = innerAgg.getOutputExpressions().stream()
                 .filter(expr -> (expr instanceof Alias) && (expr.child(0) instanceof AggregateFunction))
                 .collect(Collectors.toMap(NamedExpression::getExprId, value -> (AggregateFunction) value.child(0),
                         (existValue, newValue) -> existValue));
         Set<AggregateFunction> aggregateFunctions = outerAgg.getAggregateFunctions();
-        for (AggregateFunction outerFunc : aggregateFunctions) {
+        List<AggregateFunction> replacedAggFunctions = projectOptional.map(project ->
+                (List<AggregateFunction>) PlanUtils.replaceExpressionByProjections(
+                projectOptional.get().getProjects(), new ArrayList<>(aggregateFunctions)))
+                .orElse(new ArrayList<>(aggregateFunctions));
+        for (AggregateFunction outerFunc : replacedAggFunctions) {
             if (!(ALLOW_MERGE_AGGREGATE_FUNCTIONS.contains(outerFunc.getName()))) {
                 return false;
             }
@@ -188,7 +193,7 @@ public class MergeAggregate implements RewriteRuleFactory {
         }
         boolean sameGroupBy = (innerAgg.getGroupByExpressions().size() == outerAgg.getGroupByExpressions().size());
 
-        return commonCheck(outerAgg, innerAgg, sameGroupBy);
+        return commonCheck(outerAgg, innerAgg, sameGroupBy, Optional.empty());
     }
 
     private boolean canMergeAggregateWithProject(LogicalAggregate<LogicalProject<LogicalAggregate<Plan>>> outerAgg) {
@@ -206,6 +211,6 @@ public class MergeAggregate implements RewriteRuleFactory {
             return false;
         }
         boolean sameGroupBy = (innerAgg.getGroupByExpressions().size() == outerAgg.getGroupByExpressions().size());
-        return commonCheck(outerAgg, innerAgg, sameGroupBy);
+        return commonCheck(outerAgg, innerAgg, sameGroupBy, Optional.of(project));
     }
 }
diff --git a/regression-test/data/nereids_rules_p0/merge_aggregate/merge_aggregate.out b/regression-test/data/nereids_rules_p0/merge_aggregate/merge_aggregate.out
index ba5b127a56f..fba17e8d7b9 100644
--- a/regression-test/data/nereids_rules_p0/merge_aggregate/merge_aggregate.out
+++ b/regression-test/data/nereids_rules_p0/merge_aggregate/merge_aggregate.out
@@ -246,3 +246,54 @@ PhysicalResultSink
 --------------------PhysicalProject
 ----------------------PhysicalOlapScan[mal_test1]
 
+-- !test_has_project_distinct_cant_transform --
+1
+
+-- !test_has_project_distinct_cant_transform_shape --
+PhysicalResultSink
+--hashAgg[GLOBAL]
+----PhysicalDistribute[DistributionSpecGather]
+------hashAgg[LOCAL]
+--------PhysicalProject
+----------hashAgg[GLOBAL]
+------------PhysicalDistribute[DistributionSpecHash]
+--------------hashAgg[LOCAL]
+----------------PhysicalProject
+------------------PhysicalOlapScan[mal_test_merge_agg]
+
+-- !test_distinct_expr_transform --
+-1
+
+-- !test_distinct_expr_transform_shape --
+PhysicalResultSink
+--hashAgg[GLOBAL]
+----PhysicalDistribute[DistributionSpecGather]
+------hashAgg[LOCAL]
+--------PhysicalProject
+----------PhysicalOlapScan[mal_test_merge_agg]
+
+-- !test_has_project_distinct_expr_transform --
+1
+1
+1
+
+-- !test_has_project_distinct_expr_transform --
+PhysicalResultSink
+--PhysicalDistribute[DistributionSpecGather]
+----PhysicalProject
+------hashAgg[GLOBAL]
+--------PhysicalDistribute[DistributionSpecHash]
+----------hashAgg[LOCAL]
+------------PhysicalProject
+--------------PhysicalOlapScan[mal_test_merge_agg]
+
+-- !test_sum_empty_table --
+\N	\N	\N
+
+-- !test_sum_empty_table_shape --
+PhysicalResultSink
+--hashAgg[GLOBAL]
+----PhysicalDistribute[DistributionSpecGather]
+------hashAgg[LOCAL]
+--------PhysicalOlapScan[mal_test2]
+
diff --git a/regression-test/suites/nereids_rules_p0/merge_aggregate/merge_aggregate.groovy b/regression-test/suites/nereids_rules_p0/merge_aggregate/merge_aggregate.groovy
index 44c256e2f57..46cd4a0a9b7 100644
--- a/regression-test/suites/nereids_rules_p0/merge_aggregate/merge_aggregate.groovy
+++ b/regression-test/suites/nereids_rules_p0/merge_aggregate/merge_aggregate.groovy
@@ -174,4 +174,84 @@ suite("merge_aggregate") {
         group by a order by 1,2;
     """
 
+    sql "drop table if exists mal_test_merge_agg"
+    sql """
+         create table mal_test_merge_agg(
+            k1 int null,
+            k2 int not null,
+            k3 string null,
+            k4 varchar(100) null
+        )
+        duplicate key (k1,k2)
+        distributed BY hash(k1) buckets 3
+        properties("replication_num" = "1");
+    """
+    sql "insert into mal_test_merge_agg select 1,1,'1','a';"
+    sql "insert into mal_test_merge_agg select 2,2,'2','b';"
+    sql "insert into mal_test_merge_agg select 3,-3,null,'c';"
+    sql "sync"
+
+    qt_test_has_project_distinct_cant_transform """
+        select max(count_col)
+        from (
+            select k4,
+            count(distinct case when k3 is null then 1 else 0 end) as count_col
+            from mal_test_merge_agg group by k4
+        ) t ;
+    """
+    qt_test_has_project_distinct_cant_transform_shape """
+        explain shape plan
+        select max(count_col)
+        from (
+            select k4,
+            count(distinct case when k3 is null then 1 else 0 end) as count_col
+            from mal_test_merge_agg group by k4
+        ) t ;
+    """
+
+    qt_test_distinct_expr_transform """
+        select max(count_col)
+        from (
+            select k4,
+            max(-abs(k1)) as count_col
+            from mal_test_merge_agg group by k4
+        ) t ;
+    """
+    qt_test_distinct_expr_transform_shape """
+        explain shape plan
+        select max(count_col)
+        from (
+            select k4,
+            max(-abs(k1)) as count_col
+            from mal_test_merge_agg group by k4
+        ) t ;
+    """
+
+    qt_test_has_project_distinct_expr_transform """
+        select sum(count_col)
+        from (
+            select k4,
+            count(distinct case when k3 is null then 1 else 0 end) as count_col
+            from mal_test_merge_agg group by k4
+        ) t  group by k4;
+    """
+
+    qt_test_has_project_distinct_expr_transform """
+        explain shape plan
+        select sum(count_col)
+        from (
+            select k4,
+            count(distinct case when k3 is null then 1 else 0 end) as count_col
+            from mal_test_merge_agg group by k4
+        ) t  group by k4;
+    """
+
+    qt_test_sum_empty_table """
+        select sum(col1),min(col2),max(col3) from (select sum(a) col1, min(b) col2, max(pk) col3 from mal_test2 group by a) t;
+    """
+
+    qt_test_sum_empty_table_shape """
+        explain shape plan
+        select sum(col1),min(col2),max(col3) from (select sum(a) col1, min(b) col2, max(pk) col3 from mal_test2 group by a) t;
+    """
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@doris.apache.org
For additional commands, e-mail: commits-help@doris.apache.org