You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@doris.apache.org by mo...@apache.org on 2022/12/30 15:07:43 UTC

[doris] branch master updated: [fix](nereids) fix some arrgregate bugs in Nereids (#15326)

This is an automated email from the ASF dual-hosted git repository.

morrysnow pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git


The following commit(s) were added to refs/heads/master by this push:
     new 100834df8b [fix](nereids) fix some arrgregate bugs in Nereids (#15326)
100834df8b is described below

commit 100834df8b91b4aad0ef7c96a30b1379dac968d7
Author: starocean999 <40...@users.noreply.github.com>
AuthorDate: Fri Dec 30 23:07:37 2022 +0800

    [fix](nereids) fix some arrgregate bugs in Nereids (#15326)
    
    1. the agg function without distinct keyword should be a "merge" funcion in threePhaseAggregateWithDistinct
    2. use aggregateParam.aggMode.consumeAggregateBuffer instead of aggregateParam.aggPhase.isGlobal() to indicate if a agg function is a "merge" function
    3. add an AvgDistinctToSumDivCount rule to support avg(distinct xxx) in some case
    4. AggregateExpression's nullable method should call inner function's nullable method.
    5. add a bind slot rule to bind pattern "logicalSort(logicalHaving(logicalProject()))"
    6. don't remove project node in PhysicalPlanTranslator
    7. add a cast to bigint expr when count( distinct datelike type )
    8. fallback to old optimizer if bitmap runtime filter is enabled.
    9. fix exchange node mem leak
---
 be/src/exec/exec_node.cpp                          |  9 ++-
 be/src/vec/exec/vexchange_node.cpp                 |  1 +
 .../glue/translator/ExpressionTranslator.java      |  3 +-
 .../glue/translator/PhysicalPlanTranslator.java    | 30 ++--------
 .../doris/nereids/jobs/batch/AnalyzeRulesJob.java  |  2 +
 .../org/apache/doris/nereids/rules/RuleType.java   |  1 +
 .../rules/analysis/AvgDistinctToSumDivCount.java   | 70 ++++++++++++++++++++++
 .../nereids/rules/analysis/BindSlotReference.java  | 16 +++++
 .../nereids/rules/rewrite/AggregateStrategies.java | 10 ++--
 .../rules/rewrite/logical/InApplyToJoin.java       | 14 ++++-
 .../trees/expressions/AggregateExpression.java     |  8 ++-
 .../functions/agg/MultiDistinctCount.java          | 14 ++++-
 .../rewrite/logical/AggregateStrategiesTest.java   |  2 +-
 .../logical/EliminateUnnecessaryProjectTest.java   | 34 +++++------
 .../data/query_p0/aggregate/aggregate.out          |  6 ++
 .../query_p0/join/test_runtimefilter_on_datev2.out | 10 ++++
 .../data/query_p0/keyword/test_keyword.out         | 12 ++++
 .../suites/query_p0/aggregate/aggregate.groovy     |  6 ++
 .../join/test_runtimefilter_on_datev2.groovy       |  7 +++
 .../suites/query_p0/keyword/test_keyword.groovy    |  4 ++
 20 files changed, 202 insertions(+), 57 deletions(-)

diff --git a/be/src/exec/exec_node.cpp b/be/src/exec/exec_node.cpp
index ab9f52f4af..1fa98ef2b3 100644
--- a/be/src/exec/exec_node.cpp
+++ b/be/src/exec/exec_node.cpp
@@ -762,7 +762,14 @@ Status ExecNode::do_projections(vectorized::Block* origin_block, vectorized::Blo
             RETURN_IF_ERROR(_projections[i]->execute(origin_block, &result_column_id));
             auto column_ptr = origin_block->get_by_position(result_column_id)
                                       .column->convert_to_full_column_if_const();
-            mutable_columns[i]->insert_range_from(*column_ptr, 0, rows);
+            //TODO: this is a quick fix, we need a new function like "change_to_nullable" to do it
+            if (mutable_columns[i]->is_nullable() xor column_ptr->is_nullable()) {
+                DCHECK(mutable_columns[i]->is_nullable() && !column_ptr->is_nullable());
+                reinterpret_cast<ColumnNullable*>(mutable_columns[i].get())
+                        ->insert_range_from_not_nullable(*column_ptr, 0, rows);
+            } else {
+                mutable_columns[i]->insert_range_from(*column_ptr, 0, rows);
+            }
         }
 
         if (!is_mem_reuse) output_block->swap(mutable_block.to_block());
diff --git a/be/src/vec/exec/vexchange_node.cpp b/be/src/vec/exec/vexchange_node.cpp
index 4e3b596988..83bb606bf8 100644
--- a/be/src/vec/exec/vexchange_node.cpp
+++ b/be/src/vec/exec/vexchange_node.cpp
@@ -133,6 +133,7 @@ void VExchangeNode::release_resource(RuntimeState* state) {
     if (_is_merging) {
         _vsort_exec_exprs.close(state);
     }
+    ExecNode::release_resource(state);
 }
 
 Status VExchangeNode::collect_query_statistics(QueryStatistics* statistics) {
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/ExpressionTranslator.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/ExpressionTranslator.java
index a55d382838..32b531f334 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/ExpressionTranslator.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/ExpressionTranslator.java
@@ -444,8 +444,7 @@ public class ExpressionTranslator extends DefaultExpressionVisitor<Expr, PlanTra
                 true, true, nullableMode
         );
 
-        boolean isMergeFn = aggregateParam.aggPhase.isGlobal();
-
+        boolean isMergeFn = aggregateParam.aggMode.consumeAggregateBuffer;
         // create catalog FunctionCallExpr without analyze again
         return new FunctionCallExpr(catalogFunction, fnParams, aggFnParams, isMergeFn, catalogArguments);
     }
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java
index 6ee4db75d4..aac80b95f4 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java
@@ -127,7 +127,6 @@ import com.google.common.collect.ImmutableSet;
 import com.google.common.collect.Lists;
 import com.google.common.collect.Maps;
 import com.google.common.collect.Sets;
-import org.apache.commons.collections.CollectionUtils;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -184,13 +183,6 @@ public class PhysicalPlanTranslator extends DefaultPlanVisitor<PlanFragment, Pla
             rootFragment = exchangeToMergeFragment(rootFragment, context);
         }
         List<Expr> outputExprs = Lists.newArrayList();
-        if (physicalPlan instanceof PhysicalProject) {
-            PhysicalProject project = (PhysicalProject) physicalPlan;
-            if (isUnnecessaryProject(project) && !projectOnAgg(project)) {
-                List<Slot> slotReferences = removeAlias(project);
-                physicalPlan = (PhysicalPlan) physicalPlan.child(0).withOutput(slotReferences);
-            }
-        }
         physicalPlan.getOutput().stream().map(Slot::getExprId)
                 .forEach(exprId -> outputExprs.add(context.findSlotRef(exprId)));
         rootFragment.setOutputExprs(outputExprs);
@@ -1079,7 +1071,7 @@ public class PhysicalPlanTranslator extends DefaultPlanVisitor<PlanFragment, Pla
         for (Expr expr : predicateList) {
             extractExecSlot(expr, requiredSlotIdList);
         }
-        boolean nonPredicate = CollectionUtils.isEmpty(requiredSlotIdList);
+
         for (Expr expr : execExprList) {
             extractExecSlot(expr, requiredSlotIdList);
         }
@@ -1087,21 +1079,11 @@ public class PhysicalPlanTranslator extends DefaultPlanVisitor<PlanFragment, Pla
             TableFunctionNode tableFunctionNode = (TableFunctionNode) inputPlanNode;
             tableFunctionNode.setOutputSlotIds(Lists.newArrayList(requiredSlotIdList));
         }
-        if (!hasExprCalc(project) && (!hasPrune(project) || nonPredicate) && !projectOnAgg(project)) {
-            List<NamedExpression> namedExpressions = project.getProjects();
-            for (int i = 0; i < namedExpressions.size(); i++) {
-                NamedExpression n = namedExpressions.get(i);
-                for (Expression e : n.children()) {
-                    SlotReference slotReference = (SlotReference) e;
-                    SlotRef slotRef = context.findSlotRef(slotReference.getExprId());
-                    context.addExprIdSlotRefPair(slotList.get(i).getExprId(), slotRef);
-                }
-            }
-        } else {
-            TupleDescriptor tupleDescriptor = generateTupleDesc(slotList, null, context);
-            inputPlanNode.setProjectList(execExprList);
-            inputPlanNode.setOutputTupleDesc(tupleDescriptor);
-        }
+
+        TupleDescriptor tupleDescriptor = generateTupleDesc(slotList, null, context);
+        inputPlanNode.setProjectList(execExprList);
+        inputPlanNode.setOutputTupleDesc(tupleDescriptor);
+
         if (inputPlanNode instanceof OlapScanNode) {
             updateChildSlotsMaterialization(inputPlanNode, requiredSlotIdList, context);
             return inputFragment;
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/batch/AnalyzeRulesJob.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/batch/AnalyzeRulesJob.java
index 9e7cebac75..e430757b80 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/batch/AnalyzeRulesJob.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/batch/AnalyzeRulesJob.java
@@ -18,6 +18,7 @@
 package org.apache.doris.nereids.jobs.batch;
 
 import org.apache.doris.nereids.CascadesContext;
+import org.apache.doris.nereids.rules.analysis.AvgDistinctToSumDivCount;
 import org.apache.doris.nereids.rules.analysis.BindFunction;
 import org.apache.doris.nereids.rules.analysis.BindRelation;
 import org.apache.doris.nereids.rules.analysis.BindSlotReference;
@@ -69,6 +70,7 @@ public class AnalyzeRulesJob extends BatchRulesJob {
                     // should make sure isDisinct property is correctly passed around.
                     // please see rule BindSlotReference or BindFunction for example
                     new ProjectWithDistinctToAggregate(),
+                    new AvgDistinctToSumDivCount(),
                     new ResolveOrdinalInOrderByAndGroupBy(),
                     new ReplaceExpressionByChildOutput(),
                     new HideOneRowRelationUnderUnion(),
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
index 71149d8d83..754abd189a 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
@@ -68,6 +68,7 @@ public enum RuleType {
     RESOLVE_AGGREGATE_ALIAS(RuleTypeClass.REWRITE),
     PROJECT_TO_GLOBAL_AGGREGATE(RuleTypeClass.REWRITE),
     PROJECT_WITH_DISTINCT_TO_AGGREGATE(RuleTypeClass.REWRITE),
+    AVG_DISTINCT_TO_SUM_DIV_COUNT(RuleTypeClass.REWRITE),
     REGISTER_CTE(RuleTypeClass.REWRITE),
 
     RELATION_AUTHENTICATION(RuleTypeClass.VALIDATION),
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/AvgDistinctToSumDivCount.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/AvgDistinctToSumDivCount.java
new file mode 100644
index 0000000000..8813d6fec5
--- /dev/null
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/AvgDistinctToSumDivCount.java
@@ -0,0 +1,70 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.rules.analysis;
+
+import org.apache.doris.nereids.rules.Rule;
+import org.apache.doris.nereids.rules.RuleType;
+import org.apache.doris.nereids.trees.expressions.Divide;
+import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.NamedExpression;
+import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Avg;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Count;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Sum;
+import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
+import org.apache.doris.nereids.util.ExpressionUtils;
+
+import com.google.common.collect.ImmutableMap;
+
+import java.util.List;
+import java.util.Map;
+import java.util.stream.Collectors;
+
+/**
+ * AvgDistinctToSumDivCount.
+ *
+ * change avg( distinct a ) into sum( distinct a ) / count( distinct a ) if there are more than 1 distinct arguments
+ */
+public class AvgDistinctToSumDivCount extends OneAnalysisRuleFactory {
+    @Override
+    public Rule build() {
+        return RuleType.AVG_DISTINCT_TO_SUM_DIV_COUNT.build(
+                logicalAggregate().when(agg -> agg.getDistinctArguments().size() > 1).then(agg -> {
+                    Map<AggregateFunction, Expression> avgToSumDivCount = agg.getAggregateFunctions()
+                            .stream()
+                            .filter(function -> function instanceof Avg && function.isDistinct())
+                            .collect(ImmutableMap.toImmutableMap(function -> function, function -> {
+                                Sum sum = new Sum(true, ((Avg) function).child());
+                                Count count = new Count(true, ((Avg) function).child());
+                                Divide divide = new Divide(sum, count);
+                                return divide;
+                            }));
+                    if (!avgToSumDivCount.isEmpty()) {
+                        List<NamedExpression> newOutput = agg.getOutputExpressions().stream()
+                                .map(expr -> (NamedExpression) ExpressionUtils.replace(expr, avgToSumDivCount))
+                                .collect(Collectors.toList());
+                        return new LogicalAggregate<>(agg.getGroupByExpressions(), newOutput,
+                                agg.child());
+                    } else {
+                        return agg;
+                    }
+                })
+        );
+    }
+}
+
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindSlotReference.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindSlotReference.java
index 003967757b..bf171d0541 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindSlotReference.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindSlotReference.java
@@ -278,6 +278,22 @@ public class BindSlotReference implements AnalysisRuleFactory {
                     return bindSortWithAggregateFunction(sort, aggregate, ctx.cascadesContext);
                 })
             ),
+            RuleType.BINDING_SORT_SLOT.build(
+                logicalSort(logicalHaving(logicalProject())).when(Plan::canBind).thenApply(ctx -> {
+                    LogicalSort<LogicalHaving<LogicalProject<GroupPlan>>> sort = ctx.root;
+                    List<OrderKey> sortItemList = sort.getOrderKeys()
+                            .stream()
+                            .map(orderKey -> {
+                                Expression item = bind(orderKey.getExpr(), sort.children(), sort, ctx.cascadesContext);
+                                if (item.containsType(UnboundSlot.class)) {
+                                    item = bind(item, sort.child().children(), sort, ctx.cascadesContext);
+                                }
+                                return new OrderKey(item, orderKey.isAsc(), orderKey.isNullFirst());
+                            }).collect(Collectors.toList());
+
+                    return new LogicalSort<>(sortItemList, sort.child());
+                })
+            ),
             RuleType.BINDING_SORT_SLOT.build(
                 logicalSort(logicalProject()).when(Plan::canBind).thenApply(ctx -> {
                     LogicalSort<LogicalProject<GroupPlan>> sort = ctx.root;
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AggregateStrategies.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AggregateStrategies.java
index 05bd79c65d..7194c89cc8 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AggregateStrategies.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AggregateStrategies.java
@@ -787,7 +787,7 @@ public class AggregateStrategies implements ImplementationRuleFactory {
                 false, Optional.empty(), logicalAgg.getLogicalProperties(),
                 requireGather, logicalAgg.child());
 
-        AggregateParam inputToResultParam = new AggregateParam(AggPhase.GLOBAL, AggMode.INPUT_TO_RESULT);
+        AggregateParam bufferToResultParam = new AggregateParam(AggPhase.GLOBAL, AggMode.BUFFER_TO_RESULT);
         List<NamedExpression> globalOutput = ExpressionUtils.rewriteDownShortCircuit(
                 logicalAgg.getOutputExpressions(), outputChild -> {
                     if (outputChild instanceof AggregateFunction) {
@@ -800,7 +800,7 @@ public class AggregateStrategies implements ImplementationRuleFactory {
                         } else {
                             Alias alias = nonDistinctAggFunctionToAliasPhase1.get(outputChild);
                             return new AggregateExpression(
-                                    aggregateFunction, inputToResultParam, alias.toSlot());
+                                    aggregateFunction, bufferToResultParam, alias.toSlot());
                         }
                     } else {
                         return outputChild;
@@ -809,7 +809,7 @@ public class AggregateStrategies implements ImplementationRuleFactory {
 
         PhysicalHashAggregate<Plan> gatherLocalGatherGlobalAgg
                 = new PhysicalHashAggregate<>(logicalAgg.getGroupByExpressions(), globalOutput,
-                Optional.empty(), inputToResultParam, false,
+                Optional.empty(), bufferToResultParam, false,
                 logicalAgg.getLogicalProperties(), requireGather, gatherLocalAgg);
 
         if (logicalAgg.getGroupByExpressions().isEmpty()) {
@@ -949,7 +949,9 @@ public class AggregateStrategies implements ImplementationRuleFactory {
                                     bufferToResultParam, aggregateFunction.child(0));
                         } else {
                             Alias alias = nonDistinctAggFunctionToAliasPhase2.get(expr);
-                            return new AggregateExpression(aggregateFunction, bufferToResultParam, alias.toSlot());
+                            return new AggregateExpression(aggregateFunction,
+                                    new AggregateParam(AggPhase.DISTINCT_LOCAL, AggMode.BUFFER_TO_RESULT),
+                                    alias.toSlot());
                         }
                     }
                     return expr;
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/InApplyToJoin.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/InApplyToJoin.java
index ba6b1fa325..11377b80e0 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/InApplyToJoin.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/InApplyToJoin.java
@@ -17,6 +17,7 @@
 
 package org.apache.doris.nereids.rules.rewrite.logical;
 
+import org.apache.doris.nereids.exceptions.AnalysisException;
 import org.apache.doris.nereids.rules.Rule;
 import org.apache.doris.nereids.rules.RuleType;
 import org.apache.doris.nereids.rules.rewrite.OneRewriteRuleFactory;
@@ -27,10 +28,13 @@ import org.apache.doris.nereids.trees.plans.JoinHint;
 import org.apache.doris.nereids.trees.plans.JoinType;
 import org.apache.doris.nereids.trees.plans.logical.LogicalApply;
 import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
+import org.apache.doris.nereids.types.BitmapType;
 import org.apache.doris.nereids.util.ExpressionUtils;
 
 import com.google.common.collect.Lists;
 
+import java.util.List;
+
 /**
  * Convert InApply to LogicalJoin.
  * <p>
@@ -52,14 +56,20 @@ public class InApplyToJoin extends OneRewriteRuleFactory {
                         apply.right().getOutput().get(0));
             }
 
+            //TODO nereids should support bitmap runtime filter in future
+            List<Expression> conjuncts = ExpressionUtils.extractConjunction(predicate);
+            if (conjuncts.stream().anyMatch(expression -> expression.children().stream()
+                    .anyMatch(expr -> expr.getDataType() == BitmapType.INSTANCE))) {
+                throw new AnalysisException("nereids don't support bitmap runtime filter");
+            }
             if (((InSubquery) apply.getSubqueryExpr()).isNot()) {
                 return new LogicalJoin<>(JoinType.NULL_AWARE_LEFT_ANTI_JOIN, Lists.newArrayList(),
-                        ExpressionUtils.extractConjunction(predicate),
+                        conjuncts,
                         JoinHint.NONE,
                         apply.left(), apply.right());
             } else {
                 return new LogicalJoin<>(JoinType.LEFT_SEMI_JOIN, Lists.newArrayList(),
-                        ExpressionUtils.extractConjunction(predicate),
+                        conjuncts,
                         JoinHint.NONE,
                         apply.left(), apply.right());
             }
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/AggregateExpression.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/AggregateExpression.java
index 6b51100c9f..b88ff0d8bf 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/AggregateExpression.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/AggregateExpression.java
@@ -17,7 +17,6 @@
 
 package org.apache.doris.nereids.trees.expressions;
 
-import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable;
 import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
 import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateParam;
 import org.apache.doris.nereids.trees.expressions.shape.UnaryExpression;
@@ -38,7 +37,7 @@ import java.util.Objects;
  * so the aggregate function don't need to care about the phase of
  * aggregate.
  */
-public class AggregateExpression extends Expression implements UnaryExpression, PropagateNullable {
+public class AggregateExpression extends Expression implements UnaryExpression {
     private final AggregateFunction function;
 
     private final AggregateParam aggregateParam;
@@ -143,4 +142,9 @@ public class AggregateExpression extends Expression implements UnaryExpression,
     public int hashCode() {
         return Objects.hash(super.hashCode(), function, aggregateParam, child());
     }
+
+    @Override
+    public boolean nullable() {
+        return function.nullable();
+    }
 }
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/MultiDistinctCount.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/MultiDistinctCount.java
index 9ca2a203ff..c9fe9c2884 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/MultiDistinctCount.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/MultiDistinctCount.java
@@ -18,28 +18,38 @@
 package org.apache.doris.nereids.trees.expressions.functions.agg;
 
 import org.apache.doris.catalog.FunctionSignature;
+import org.apache.doris.nereids.trees.expressions.Cast;
 import org.apache.doris.nereids.trees.expressions.Expression;
 import org.apache.doris.nereids.trees.expressions.functions.AlwaysNotNullable;
 import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature;
 import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
 import org.apache.doris.nereids.types.BigIntType;
 import org.apache.doris.nereids.types.DataType;
+import org.apache.doris.nereids.types.coercion.DateLikeType;
 import org.apache.doris.nereids.util.ExpressionUtils;
 
 import com.google.common.base.Preconditions;
 import com.google.common.collect.ImmutableList;
 
 import java.util.List;
+import java.util.stream.Collectors;
 
 /** MultiDistinctCount */
 public class MultiDistinctCount extends AggregateFunction
         implements AlwaysNotNullable, ExplicitlyCastableSignature {
+    // MultiDistinctCount is created in AggregateStrategies phase
+    // can't change getSignatures to use type coercion rule to add a cast expr
+    // because AggregateStrategies phase is after type coercion
     public MultiDistinctCount(Expression arg0, Expression... varArgs) {
-        super("multi_distinct_count", true, ExpressionUtils.mergeArguments(arg0, varArgs));
+        super("multi_distinct_count", true, ExpressionUtils.mergeArguments(arg0, varArgs).stream()
+                .map(arg -> arg.getDataType() instanceof DateLikeType ? new Cast(arg, BigIntType.INSTANCE) : arg)
+                .collect(Collectors.toList()));
     }
 
     public MultiDistinctCount(boolean isDistinct, Expression arg0, Expression... varArgs) {
-        super("multi_distinct_count", true, ExpressionUtils.mergeArguments(arg0, varArgs));
+        super("multi_distinct_count", true, ExpressionUtils.mergeArguments(arg0, varArgs).stream()
+                .map(arg -> arg.getDataType() instanceof DateLikeType ? new Cast(arg, BigIntType.INSTANCE) : arg)
+                .collect(Collectors.toList()));
     }
 
     @Override
diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/AggregateStrategiesTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/AggregateStrategiesTest.java
index 0376aa23a5..df2ca10ed7 100644
--- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/AggregateStrategiesTest.java
+++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/AggregateStrategiesTest.java
@@ -265,7 +265,7 @@ public class AggregateStrategiesTest implements PatternMatchSupported {
         // id
         AggregateParam phaseTwoCountAggParam = new AggregateParam(AggPhase.LOCAL, AggMode.INPUT_TO_RESULT);
         AggregateParam phaseOneSumAggParam = new AggregateParam(AggPhase.LOCAL, AggMode.INPUT_TO_BUFFER);
-        AggregateParam phaseTwoSumAggParam = new AggregateParam(AggPhase.GLOBAL, AggMode.INPUT_TO_RESULT);
+        AggregateParam phaseTwoSumAggParam = new AggregateParam(AggPhase.GLOBAL, AggMode.BUFFER_TO_RESULT);
         // sum
         Sum sumId = new Sum(false, id.toSlot());
 
diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/EliminateUnnecessaryProjectTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/EliminateUnnecessaryProjectTest.java
index b46ea7dcef..73369adfdc 100644
--- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/EliminateUnnecessaryProjectTest.java
+++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/EliminateUnnecessaryProjectTest.java
@@ -26,20 +26,15 @@ import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
 import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
 import org.apache.doris.nereids.util.LogicalPlanBuilder;
 import org.apache.doris.nereids.util.MemoTestUtils;
-import org.apache.doris.nereids.util.PlanChecker;
 import org.apache.doris.nereids.util.PlanConstructor;
-import org.apache.doris.planner.OlapScanNode;
-import org.apache.doris.planner.PlanFragment;
 import org.apache.doris.utframe.TestWithFeService;
 
 import com.google.common.collect.ImmutableList;
 import com.google.common.collect.Lists;
-import com.google.common.collect.Sets;
 import org.junit.jupiter.api.Assertions;
 import org.junit.jupiter.api.Test;
 
 import java.util.List;
-import java.util.Set;
 
 /**
  * test ELIMINATE_UNNECESSARY_PROJECT rule.
@@ -103,18 +98,19 @@ public class EliminateUnnecessaryProjectTest extends TestWithFeService {
         Assertions.assertTrue(actual instanceof LogicalProject);
     }
 
-    @Test
-    public void testEliminationForThoseNeitherDoPruneNorDoExprCalc() {
-        PlanChecker.from(connectContext).checkPlannerResult("SELECT col1 FROM t1",
-                p -> {
-                    List<PlanFragment> fragments = p.getFragments();
-                    Assertions.assertTrue(fragments.stream()
-                            .flatMap(fragment -> {
-                                Set<OlapScanNode> scans = Sets.newHashSet();
-                                fragment.getPlanRoot().collect(OlapScanNode.class, scans);
-                                return scans.stream();
-                            })
-                            .noneMatch(s -> s.getProjectList() != null));
-                });
-    }
+    // TODO: uncomment this after the Elimination project rule is correctly implemented
+    // @Test
+    // public void testEliminationForThoseNeitherDoPruneNorDoExprCalc() {
+    //     PlanChecker.from(connectContext).checkPlannerResult("SELECT col1 FROM t1",
+    //             p -> {
+    //                 List<PlanFragment> fragments = p.getFragments();
+    //                 Assertions.assertTrue(fragments.stream()
+    //                         .flatMap(fragment -> {
+    //                             Set<OlapScanNode> scans = Sets.newHashSet();
+    //                             fragment.getPlanRoot().collect(OlapScanNode.class, scans);
+    //                             return scans.stream();
+    //                         })
+    //                         .noneMatch(s -> s.getProjectList() != null));
+    //             });
+    // }
 }
diff --git a/regression-test/data/query_p0/aggregate/aggregate.out b/regression-test/data/query_p0/aggregate/aggregate.out
index cd06915a8c..d9234ed13d 100644
--- a/regression-test/data/query_p0/aggregate/aggregate.out
+++ b/regression-test/data/query_p0/aggregate/aggregate.out
@@ -683,3 +683,9 @@ TESTING	AGAIN
 8	255
 9	1991
 
+-- !aggregate --
+3566.3333333333335	50.65743333333333
+
+-- !aggregate --
+9	9	10	8	7	7	7	2
+
diff --git a/regression-test/data/query_p0/join/test_runtimefilter_on_datev2.out b/regression-test/data/query_p0/join/test_runtimefilter_on_datev2.out
index d979124cfd..a4ab505348 100644
--- a/regression-test/data/query_p0/join/test_runtimefilter_on_datev2.out
+++ b/regression-test/data/query_p0/join/test_runtimefilter_on_datev2.out
@@ -351,3 +351,13 @@
 1	2022-01-01	1	2022-01-01
 1	2022-01-01	1	2022-01-01
 
+-- !join1 --
+1	2022-01-01	1	2022-01-01
+1	2022-01-01	1	2022-01-01
+1	2022-01-01	1	2022-01-01
+1	2022-01-01	1	2022-01-01
+1	2022-01-01	1	2022-01-01
+1	2022-01-01	1	2022-01-01
+1	2022-01-01	1	2022-01-01
+1	2022-01-01	1	2022-01-01
+1	2022-01-01	1	2022-01-01
\ No newline at end of file
diff --git a/regression-test/data/query_p0/keyword/test_keyword.out b/regression-test/data/query_p0/keyword/test_keyword.out
index b5a4681ac5..6e3d301db7 100644
--- a/regression-test/data/query_p0/keyword/test_keyword.out
+++ b/regression-test/data/query_p0/keyword/test_keyword.out
@@ -651,3 +651,15 @@ false	1	1989	1001	11011902	123.123	true	1989-03-21	1989-03-21T13:00	wangjuoo4	0.
 false	2	1986	1001	11011903	1243.500	false	1901-12-31	1989-03-21T13:00	wangynnsf	20.268	789.25	string12345	-170141183460469231731687303715884105727
 false	3	1989	1002	11011905	24453.325	false	2012-03-14	2000-01-01T00:00	yunlj8@nk	78945.0	3654.0	string12345	0
 
+-- !having2 --
+3	1989
+6	32767
+9	1991
+12	32767
+15	1992
+
+-- !distinct25 --
+2.0	2.0
+
+-- !distinct26 --
+2
\ No newline at end of file
diff --git a/regression-test/suites/query_p0/aggregate/aggregate.groovy b/regression-test/suites/query_p0/aggregate/aggregate.groovy
index f9fca691f1..025606d88e 100644
--- a/regression-test/suites/query_p0/aggregate/aggregate.groovy
+++ b/regression-test/suites/query_p0/aggregate/aggregate.groovy
@@ -291,4 +291,10 @@ suite("aggregate") {
     sql""" DROP TABLE IF EXISTS tempbaseall """
     sql"""create table tempbaseall PROPERTIES("replication_num" = "1")  as select k1, k2 from baseall where k1 is not null;"""
     qt_aggregate32"select k1, k2 from (select k1, max(k2) as k2 from tempbaseall where k1 > 0 group by k1 order by k1)a where k1 > 0 and k1 < 10 order by k1;"
+
+    sql 'set enable_vectorized_engine=true;'
+    sql 'set enable_fallback_to_original_planner=false'
+    sql 'set enable_nereids_planner=true'
+    qt_aggregate """ select avg(distinct c_bigint), avg(distinct c_double) from regression_test_query_p0_aggregate.${tableName} """
+    qt_aggregate """ select count(distinct c_bigint),count(distinct c_double),count(distinct c_string),count(distinct c_date_1),count(distinct c_timestamp_1),count(distinct c_timestamp_2),count(distinct c_timestamp_3),count(distinct c_boolean) from regression_test_query_p0_aggregate.${tableName} """
 }
diff --git a/regression-test/suites/query_p0/join/test_runtimefilter_on_datev2.groovy b/regression-test/suites/query_p0/join/test_runtimefilter_on_datev2.groovy
index de2636a170..3247f331ef 100644
--- a/regression-test/suites/query_p0/join/test_runtimefilter_on_datev2.groovy
+++ b/regression-test/suites/query_p0/join/test_runtimefilter_on_datev2.groovy
@@ -215,4 +215,11 @@ suite("test_runtimefilter_on_datev2", "query_p0") {
     qt_join8 """
         SELECT * FROM ${dateV2Table} a, ${dateV2Table2} b WHERE a.date = b.date;
     """
+
+    sql 'set enable_vectorized_engine=true'
+    sql 'set enable_fallback_to_original_planner=false'
+    sql 'set enable_nereids_planner=true'
+    qt_join1 """
+        SELECT * FROM ${dateTable} a, ${dateV2Table} b WHERE a.date = b.date;
+    """
 }
diff --git a/regression-test/suites/query_p0/keyword/test_keyword.groovy b/regression-test/suites/query_p0/keyword/test_keyword.groovy
index c30308ac88..8ce23e4ee4 100644
--- a/regression-test/suites/query_p0/keyword/test_keyword.groovy
+++ b/regression-test/suites/query_p0/keyword/test_keyword.groovy
@@ -116,4 +116,8 @@ suite("test_keyword", "query,p0") {
 
     qt_distinct "select distinct upper(k6) from ${tableName1} order by upper(k6)"
     qt_distinct "select distinct * from ${tableName1} where k1<20 order by k1, k2, k3, k4"
+    qt_having2 "select k1, k2 from ${tableName2} having k1 % 3 = 0 order by k1, k2"
+    qt_distinct25 "select avg(distinct k1), avg(k1) from ${tableName1}"
+    qt_distinct26 "select count(*) from (select count(distinct k1) from ${tableName1} group by k2) v \
+		    order by count(*)"
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@doris.apache.org
For additional commands, e-mail: commits-help@doris.apache.org