You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@doris.apache.org by mo...@apache.org on 2022/12/30 15:07:43 UTC
[doris] branch master updated: [fix](nereids) fix some aggregate bugs in Nereids (#15326)
This is an automated email from the ASF dual-hosted git repository.
morrysnow pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/master by this push:
new 100834df8b [fix](nereids) fix some aggregate bugs in Nereids (#15326)
100834df8b is described below
commit 100834df8b91b4aad0ef7c96a30b1379dac968d7
Author: starocean999 <40...@users.noreply.github.com>
AuthorDate: Fri Dec 30 23:07:37 2022 +0800
[fix](nereids) fix some aggregate bugs in Nereids (#15326)
1. the agg function without distinct keyword should be a "merge" function in threePhaseAggregateWithDistinct
2. use aggregateParam.aggMode.consumeAggregateBuffer instead of aggregateParam.aggPhase.isGlobal() to indicate whether an agg function is a "merge" function
3. add an AvgDistinctToSumDivCount rule to support avg(distinct xxx) in some cases
4. AggregateExpression's nullable method should call inner function's nullable method.
5. add a bind slot rule to bind pattern "logicalSort(logicalHaving(logicalProject()))"
6. don't remove project node in PhysicalPlanTranslator
7. add a cast to bigint expr when count( distinct datelike type )
8. fallback to old optimizer if bitmap runtime filter is enabled.
9. fix exchange node mem leak
---
be/src/exec/exec_node.cpp | 9 ++-
be/src/vec/exec/vexchange_node.cpp | 1 +
.../glue/translator/ExpressionTranslator.java | 3 +-
.../glue/translator/PhysicalPlanTranslator.java | 30 ++--------
.../doris/nereids/jobs/batch/AnalyzeRulesJob.java | 2 +
.../org/apache/doris/nereids/rules/RuleType.java | 1 +
.../rules/analysis/AvgDistinctToSumDivCount.java | 70 ++++++++++++++++++++++
.../nereids/rules/analysis/BindSlotReference.java | 16 +++++
.../nereids/rules/rewrite/AggregateStrategies.java | 10 ++--
.../rules/rewrite/logical/InApplyToJoin.java | 14 ++++-
.../trees/expressions/AggregateExpression.java | 8 ++-
.../functions/agg/MultiDistinctCount.java | 14 ++++-
.../rewrite/logical/AggregateStrategiesTest.java | 2 +-
.../logical/EliminateUnnecessaryProjectTest.java | 34 +++++------
.../data/query_p0/aggregate/aggregate.out | 6 ++
.../query_p0/join/test_runtimefilter_on_datev2.out | 10 ++++
.../data/query_p0/keyword/test_keyword.out | 12 ++++
.../suites/query_p0/aggregate/aggregate.groovy | 6 ++
.../join/test_runtimefilter_on_datev2.groovy | 7 +++
.../suites/query_p0/keyword/test_keyword.groovy | 4 ++
20 files changed, 202 insertions(+), 57 deletions(-)
diff --git a/be/src/exec/exec_node.cpp b/be/src/exec/exec_node.cpp
index ab9f52f4af..1fa98ef2b3 100644
--- a/be/src/exec/exec_node.cpp
+++ b/be/src/exec/exec_node.cpp
@@ -762,7 +762,14 @@ Status ExecNode::do_projections(vectorized::Block* origin_block, vectorized::Blo
RETURN_IF_ERROR(_projections[i]->execute(origin_block, &result_column_id));
auto column_ptr = origin_block->get_by_position(result_column_id)
.column->convert_to_full_column_if_const();
- mutable_columns[i]->insert_range_from(*column_ptr, 0, rows);
+ //TODO: this is a quick fix, we need a new function like "change_to_nullable" to do it
+ if (mutable_columns[i]->is_nullable() xor column_ptr->is_nullable()) {
+ DCHECK(mutable_columns[i]->is_nullable() && !column_ptr->is_nullable());
+ reinterpret_cast<ColumnNullable*>(mutable_columns[i].get())
+ ->insert_range_from_not_nullable(*column_ptr, 0, rows);
+ } else {
+ mutable_columns[i]->insert_range_from(*column_ptr, 0, rows);
+ }
}
if (!is_mem_reuse) output_block->swap(mutable_block.to_block());
diff --git a/be/src/vec/exec/vexchange_node.cpp b/be/src/vec/exec/vexchange_node.cpp
index 4e3b596988..83bb606bf8 100644
--- a/be/src/vec/exec/vexchange_node.cpp
+++ b/be/src/vec/exec/vexchange_node.cpp
@@ -133,6 +133,7 @@ void VExchangeNode::release_resource(RuntimeState* state) {
if (_is_merging) {
_vsort_exec_exprs.close(state);
}
+ ExecNode::release_resource(state);
}
Status VExchangeNode::collect_query_statistics(QueryStatistics* statistics) {
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/ExpressionTranslator.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/ExpressionTranslator.java
index a55d382838..32b531f334 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/ExpressionTranslator.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/ExpressionTranslator.java
@@ -444,8 +444,7 @@ public class ExpressionTranslator extends DefaultExpressionVisitor<Expr, PlanTra
true, true, nullableMode
);
- boolean isMergeFn = aggregateParam.aggPhase.isGlobal();
-
+ boolean isMergeFn = aggregateParam.aggMode.consumeAggregateBuffer;
// create catalog FunctionCallExpr without analyze again
return new FunctionCallExpr(catalogFunction, fnParams, aggFnParams, isMergeFn, catalogArguments);
}
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java
index 6ee4db75d4..aac80b95f4 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java
@@ -127,7 +127,6 @@ import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.google.common.collect.Sets;
-import org.apache.commons.collections.CollectionUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -184,13 +183,6 @@ public class PhysicalPlanTranslator extends DefaultPlanVisitor<PlanFragment, Pla
rootFragment = exchangeToMergeFragment(rootFragment, context);
}
List<Expr> outputExprs = Lists.newArrayList();
- if (physicalPlan instanceof PhysicalProject) {
- PhysicalProject project = (PhysicalProject) physicalPlan;
- if (isUnnecessaryProject(project) && !projectOnAgg(project)) {
- List<Slot> slotReferences = removeAlias(project);
- physicalPlan = (PhysicalPlan) physicalPlan.child(0).withOutput(slotReferences);
- }
- }
physicalPlan.getOutput().stream().map(Slot::getExprId)
.forEach(exprId -> outputExprs.add(context.findSlotRef(exprId)));
rootFragment.setOutputExprs(outputExprs);
@@ -1079,7 +1071,7 @@ public class PhysicalPlanTranslator extends DefaultPlanVisitor<PlanFragment, Pla
for (Expr expr : predicateList) {
extractExecSlot(expr, requiredSlotIdList);
}
- boolean nonPredicate = CollectionUtils.isEmpty(requiredSlotIdList);
+
for (Expr expr : execExprList) {
extractExecSlot(expr, requiredSlotIdList);
}
@@ -1087,21 +1079,11 @@ public class PhysicalPlanTranslator extends DefaultPlanVisitor<PlanFragment, Pla
TableFunctionNode tableFunctionNode = (TableFunctionNode) inputPlanNode;
tableFunctionNode.setOutputSlotIds(Lists.newArrayList(requiredSlotIdList));
}
- if (!hasExprCalc(project) && (!hasPrune(project) || nonPredicate) && !projectOnAgg(project)) {
- List<NamedExpression> namedExpressions = project.getProjects();
- for (int i = 0; i < namedExpressions.size(); i++) {
- NamedExpression n = namedExpressions.get(i);
- for (Expression e : n.children()) {
- SlotReference slotReference = (SlotReference) e;
- SlotRef slotRef = context.findSlotRef(slotReference.getExprId());
- context.addExprIdSlotRefPair(slotList.get(i).getExprId(), slotRef);
- }
- }
- } else {
- TupleDescriptor tupleDescriptor = generateTupleDesc(slotList, null, context);
- inputPlanNode.setProjectList(execExprList);
- inputPlanNode.setOutputTupleDesc(tupleDescriptor);
- }
+
+ TupleDescriptor tupleDescriptor = generateTupleDesc(slotList, null, context);
+ inputPlanNode.setProjectList(execExprList);
+ inputPlanNode.setOutputTupleDesc(tupleDescriptor);
+
if (inputPlanNode instanceof OlapScanNode) {
updateChildSlotsMaterialization(inputPlanNode, requiredSlotIdList, context);
return inputFragment;
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/batch/AnalyzeRulesJob.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/batch/AnalyzeRulesJob.java
index 9e7cebac75..e430757b80 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/batch/AnalyzeRulesJob.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/batch/AnalyzeRulesJob.java
@@ -18,6 +18,7 @@
package org.apache.doris.nereids.jobs.batch;
import org.apache.doris.nereids.CascadesContext;
+import org.apache.doris.nereids.rules.analysis.AvgDistinctToSumDivCount;
import org.apache.doris.nereids.rules.analysis.BindFunction;
import org.apache.doris.nereids.rules.analysis.BindRelation;
import org.apache.doris.nereids.rules.analysis.BindSlotReference;
@@ -69,6 +70,7 @@ public class AnalyzeRulesJob extends BatchRulesJob {
// should make sure isDisinct property is correctly passed around.
// please see rule BindSlotReference or BindFunction for example
new ProjectWithDistinctToAggregate(),
+ new AvgDistinctToSumDivCount(),
new ResolveOrdinalInOrderByAndGroupBy(),
new ReplaceExpressionByChildOutput(),
new HideOneRowRelationUnderUnion(),
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
index 71149d8d83..754abd189a 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
@@ -68,6 +68,7 @@ public enum RuleType {
RESOLVE_AGGREGATE_ALIAS(RuleTypeClass.REWRITE),
PROJECT_TO_GLOBAL_AGGREGATE(RuleTypeClass.REWRITE),
PROJECT_WITH_DISTINCT_TO_AGGREGATE(RuleTypeClass.REWRITE),
+ AVG_DISTINCT_TO_SUM_DIV_COUNT(RuleTypeClass.REWRITE),
REGISTER_CTE(RuleTypeClass.REWRITE),
RELATION_AUTHENTICATION(RuleTypeClass.VALIDATION),
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/AvgDistinctToSumDivCount.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/AvgDistinctToSumDivCount.java
new file mode 100644
index 0000000000..8813d6fec5
--- /dev/null
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/AvgDistinctToSumDivCount.java
@@ -0,0 +1,70 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.rules.analysis;
+
+import org.apache.doris.nereids.rules.Rule;
+import org.apache.doris.nereids.rules.RuleType;
+import org.apache.doris.nereids.trees.expressions.Divide;
+import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.NamedExpression;
+import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Avg;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Count;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Sum;
+import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
+import org.apache.doris.nereids.util.ExpressionUtils;
+
+import com.google.common.collect.ImmutableMap;
+
+import java.util.List;
+import java.util.Map;
+import java.util.stream.Collectors;
+
+/**
+ * AvgDistinctToSumDivCount.
+ *
+ * change avg( distinct a ) into sum( distinct a ) / count( distinct a ) if there are more than 1 distinct arguments
+ */
+public class AvgDistinctToSumDivCount extends OneAnalysisRuleFactory {
+ @Override
+ public Rule build() {
+ return RuleType.AVG_DISTINCT_TO_SUM_DIV_COUNT.build(
+ logicalAggregate().when(agg -> agg.getDistinctArguments().size() > 1).then(agg -> {
+ Map<AggregateFunction, Expression> avgToSumDivCount = agg.getAggregateFunctions()
+ .stream()
+ .filter(function -> function instanceof Avg && function.isDistinct())
+ .collect(ImmutableMap.toImmutableMap(function -> function, function -> {
+ Sum sum = new Sum(true, ((Avg) function).child());
+ Count count = new Count(true, ((Avg) function).child());
+ Divide divide = new Divide(sum, count);
+ return divide;
+ }));
+ if (!avgToSumDivCount.isEmpty()) {
+ List<NamedExpression> newOutput = agg.getOutputExpressions().stream()
+ .map(expr -> (NamedExpression) ExpressionUtils.replace(expr, avgToSumDivCount))
+ .collect(Collectors.toList());
+ return new LogicalAggregate<>(agg.getGroupByExpressions(), newOutput,
+ agg.child());
+ } else {
+ return agg;
+ }
+ })
+ );
+ }
+}
+
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindSlotReference.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindSlotReference.java
index 003967757b..bf171d0541 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindSlotReference.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindSlotReference.java
@@ -278,6 +278,22 @@ public class BindSlotReference implements AnalysisRuleFactory {
return bindSortWithAggregateFunction(sort, aggregate, ctx.cascadesContext);
})
),
+ RuleType.BINDING_SORT_SLOT.build(
+ logicalSort(logicalHaving(logicalProject())).when(Plan::canBind).thenApply(ctx -> {
+ LogicalSort<LogicalHaving<LogicalProject<GroupPlan>>> sort = ctx.root;
+ List<OrderKey> sortItemList = sort.getOrderKeys()
+ .stream()
+ .map(orderKey -> {
+ Expression item = bind(orderKey.getExpr(), sort.children(), sort, ctx.cascadesContext);
+ if (item.containsType(UnboundSlot.class)) {
+ item = bind(item, sort.child().children(), sort, ctx.cascadesContext);
+ }
+ return new OrderKey(item, orderKey.isAsc(), orderKey.isNullFirst());
+ }).collect(Collectors.toList());
+
+ return new LogicalSort<>(sortItemList, sort.child());
+ })
+ ),
RuleType.BINDING_SORT_SLOT.build(
logicalSort(logicalProject()).when(Plan::canBind).thenApply(ctx -> {
LogicalSort<LogicalProject<GroupPlan>> sort = ctx.root;
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AggregateStrategies.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AggregateStrategies.java
index 05bd79c65d..7194c89cc8 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AggregateStrategies.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AggregateStrategies.java
@@ -787,7 +787,7 @@ public class AggregateStrategies implements ImplementationRuleFactory {
false, Optional.empty(), logicalAgg.getLogicalProperties(),
requireGather, logicalAgg.child());
- AggregateParam inputToResultParam = new AggregateParam(AggPhase.GLOBAL, AggMode.INPUT_TO_RESULT);
+ AggregateParam bufferToResultParam = new AggregateParam(AggPhase.GLOBAL, AggMode.BUFFER_TO_RESULT);
List<NamedExpression> globalOutput = ExpressionUtils.rewriteDownShortCircuit(
logicalAgg.getOutputExpressions(), outputChild -> {
if (outputChild instanceof AggregateFunction) {
@@ -800,7 +800,7 @@ public class AggregateStrategies implements ImplementationRuleFactory {
} else {
Alias alias = nonDistinctAggFunctionToAliasPhase1.get(outputChild);
return new AggregateExpression(
- aggregateFunction, inputToResultParam, alias.toSlot());
+ aggregateFunction, bufferToResultParam, alias.toSlot());
}
} else {
return outputChild;
@@ -809,7 +809,7 @@ public class AggregateStrategies implements ImplementationRuleFactory {
PhysicalHashAggregate<Plan> gatherLocalGatherGlobalAgg
= new PhysicalHashAggregate<>(logicalAgg.getGroupByExpressions(), globalOutput,
- Optional.empty(), inputToResultParam, false,
+ Optional.empty(), bufferToResultParam, false,
logicalAgg.getLogicalProperties(), requireGather, gatherLocalAgg);
if (logicalAgg.getGroupByExpressions().isEmpty()) {
@@ -949,7 +949,9 @@ public class AggregateStrategies implements ImplementationRuleFactory {
bufferToResultParam, aggregateFunction.child(0));
} else {
Alias alias = nonDistinctAggFunctionToAliasPhase2.get(expr);
- return new AggregateExpression(aggregateFunction, bufferToResultParam, alias.toSlot());
+ return new AggregateExpression(aggregateFunction,
+ new AggregateParam(AggPhase.DISTINCT_LOCAL, AggMode.BUFFER_TO_RESULT),
+ alias.toSlot());
}
}
return expr;
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/InApplyToJoin.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/InApplyToJoin.java
index ba6b1fa325..11377b80e0 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/InApplyToJoin.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/InApplyToJoin.java
@@ -17,6 +17,7 @@
package org.apache.doris.nereids.rules.rewrite.logical;
+import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.rules.rewrite.OneRewriteRuleFactory;
@@ -27,10 +28,13 @@ import org.apache.doris.nereids.trees.plans.JoinHint;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.logical.LogicalApply;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
+import org.apache.doris.nereids.types.BitmapType;
import org.apache.doris.nereids.util.ExpressionUtils;
import com.google.common.collect.Lists;
+import java.util.List;
+
/**
* Convert InApply to LogicalJoin.
* <p>
@@ -52,14 +56,20 @@ public class InApplyToJoin extends OneRewriteRuleFactory {
apply.right().getOutput().get(0));
}
+ //TODO nereids should support bitmap runtime filter in future
+ List<Expression> conjuncts = ExpressionUtils.extractConjunction(predicate);
+ if (conjuncts.stream().anyMatch(expression -> expression.children().stream()
+ .anyMatch(expr -> expr.getDataType() == BitmapType.INSTANCE))) {
+ throw new AnalysisException("nereids don't support bitmap runtime filter");
+ }
if (((InSubquery) apply.getSubqueryExpr()).isNot()) {
return new LogicalJoin<>(JoinType.NULL_AWARE_LEFT_ANTI_JOIN, Lists.newArrayList(),
- ExpressionUtils.extractConjunction(predicate),
+ conjuncts,
JoinHint.NONE,
apply.left(), apply.right());
} else {
return new LogicalJoin<>(JoinType.LEFT_SEMI_JOIN, Lists.newArrayList(),
- ExpressionUtils.extractConjunction(predicate),
+ conjuncts,
JoinHint.NONE,
apply.left(), apply.right());
}
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/AggregateExpression.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/AggregateExpression.java
index 6b51100c9f..b88ff0d8bf 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/AggregateExpression.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/AggregateExpression.java
@@ -17,7 +17,6 @@
package org.apache.doris.nereids.trees.expressions;
-import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable;
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateParam;
import org.apache.doris.nereids.trees.expressions.shape.UnaryExpression;
@@ -38,7 +37,7 @@ import java.util.Objects;
* so the aggregate function don't need to care about the phase of
* aggregate.
*/
-public class AggregateExpression extends Expression implements UnaryExpression, PropagateNullable {
+public class AggregateExpression extends Expression implements UnaryExpression {
private final AggregateFunction function;
private final AggregateParam aggregateParam;
@@ -143,4 +142,9 @@ public class AggregateExpression extends Expression implements UnaryExpression,
public int hashCode() {
return Objects.hash(super.hashCode(), function, aggregateParam, child());
}
+
+ @Override
+ public boolean nullable() {
+ return function.nullable();
+ }
}
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/MultiDistinctCount.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/MultiDistinctCount.java
index 9ca2a203ff..c9fe9c2884 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/MultiDistinctCount.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/MultiDistinctCount.java
@@ -18,28 +18,38 @@
package org.apache.doris.nereids.trees.expressions.functions.agg;
import org.apache.doris.catalog.FunctionSignature;
+import org.apache.doris.nereids.trees.expressions.Cast;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.functions.AlwaysNotNullable;
import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.types.BigIntType;
import org.apache.doris.nereids.types.DataType;
+import org.apache.doris.nereids.types.coercion.DateLikeType;
import org.apache.doris.nereids.util.ExpressionUtils;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import java.util.List;
+import java.util.stream.Collectors;
/** MultiDistinctCount */
public class MultiDistinctCount extends AggregateFunction
implements AlwaysNotNullable, ExplicitlyCastableSignature {
+ // MultiDistinctCount is created in AggregateStrategies phase
+ // can't change getSignatures to use type coercion rule to add a cast expr
+ // because AggregateStrategies phase is after type coercion
public MultiDistinctCount(Expression arg0, Expression... varArgs) {
- super("multi_distinct_count", true, ExpressionUtils.mergeArguments(arg0, varArgs));
+ super("multi_distinct_count", true, ExpressionUtils.mergeArguments(arg0, varArgs).stream()
+ .map(arg -> arg.getDataType() instanceof DateLikeType ? new Cast(arg, BigIntType.INSTANCE) : arg)
+ .collect(Collectors.toList()));
}
public MultiDistinctCount(boolean isDistinct, Expression arg0, Expression... varArgs) {
- super("multi_distinct_count", true, ExpressionUtils.mergeArguments(arg0, varArgs));
+ super("multi_distinct_count", true, ExpressionUtils.mergeArguments(arg0, varArgs).stream()
+ .map(arg -> arg.getDataType() instanceof DateLikeType ? new Cast(arg, BigIntType.INSTANCE) : arg)
+ .collect(Collectors.toList()));
}
@Override
diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/AggregateStrategiesTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/AggregateStrategiesTest.java
index 0376aa23a5..df2ca10ed7 100644
--- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/AggregateStrategiesTest.java
+++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/AggregateStrategiesTest.java
@@ -265,7 +265,7 @@ public class AggregateStrategiesTest implements PatternMatchSupported {
// id
AggregateParam phaseTwoCountAggParam = new AggregateParam(AggPhase.LOCAL, AggMode.INPUT_TO_RESULT);
AggregateParam phaseOneSumAggParam = new AggregateParam(AggPhase.LOCAL, AggMode.INPUT_TO_BUFFER);
- AggregateParam phaseTwoSumAggParam = new AggregateParam(AggPhase.GLOBAL, AggMode.INPUT_TO_RESULT);
+ AggregateParam phaseTwoSumAggParam = new AggregateParam(AggPhase.GLOBAL, AggMode.BUFFER_TO_RESULT);
// sum
Sum sumId = new Sum(false, id.toSlot());
diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/EliminateUnnecessaryProjectTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/EliminateUnnecessaryProjectTest.java
index b46ea7dcef..73369adfdc 100644
--- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/EliminateUnnecessaryProjectTest.java
+++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/EliminateUnnecessaryProjectTest.java
@@ -26,20 +26,15 @@ import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.util.LogicalPlanBuilder;
import org.apache.doris.nereids.util.MemoTestUtils;
-import org.apache.doris.nereids.util.PlanChecker;
import org.apache.doris.nereids.util.PlanConstructor;
-import org.apache.doris.planner.OlapScanNode;
-import org.apache.doris.planner.PlanFragment;
import org.apache.doris.utframe.TestWithFeService;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;
-import com.google.common.collect.Sets;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import java.util.List;
-import java.util.Set;
/**
* test ELIMINATE_UNNECESSARY_PROJECT rule.
@@ -103,18 +98,19 @@ public class EliminateUnnecessaryProjectTest extends TestWithFeService {
Assertions.assertTrue(actual instanceof LogicalProject);
}
- @Test
- public void testEliminationForThoseNeitherDoPruneNorDoExprCalc() {
- PlanChecker.from(connectContext).checkPlannerResult("SELECT col1 FROM t1",
- p -> {
- List<PlanFragment> fragments = p.getFragments();
- Assertions.assertTrue(fragments.stream()
- .flatMap(fragment -> {
- Set<OlapScanNode> scans = Sets.newHashSet();
- fragment.getPlanRoot().collect(OlapScanNode.class, scans);
- return scans.stream();
- })
- .noneMatch(s -> s.getProjectList() != null));
- });
- }
+ // TODO: uncomment this after the Elimination project rule is correctly implemented
+ // @Test
+ // public void testEliminationForThoseNeitherDoPruneNorDoExprCalc() {
+ // PlanChecker.from(connectContext).checkPlannerResult("SELECT col1 FROM t1",
+ // p -> {
+ // List<PlanFragment> fragments = p.getFragments();
+ // Assertions.assertTrue(fragments.stream()
+ // .flatMap(fragment -> {
+ // Set<OlapScanNode> scans = Sets.newHashSet();
+ // fragment.getPlanRoot().collect(OlapScanNode.class, scans);
+ // return scans.stream();
+ // })
+ // .noneMatch(s -> s.getProjectList() != null));
+ // });
+ // }
}
diff --git a/regression-test/data/query_p0/aggregate/aggregate.out b/regression-test/data/query_p0/aggregate/aggregate.out
index cd06915a8c..d9234ed13d 100644
--- a/regression-test/data/query_p0/aggregate/aggregate.out
+++ b/regression-test/data/query_p0/aggregate/aggregate.out
@@ -683,3 +683,9 @@ TESTING AGAIN
8 255
9 1991
+-- !aggregate --
+3566.3333333333335 50.65743333333333
+
+-- !aggregate --
+9 9 10 8 7 7 7 2
+
diff --git a/regression-test/data/query_p0/join/test_runtimefilter_on_datev2.out b/regression-test/data/query_p0/join/test_runtimefilter_on_datev2.out
index d979124cfd..a4ab505348 100644
--- a/regression-test/data/query_p0/join/test_runtimefilter_on_datev2.out
+++ b/regression-test/data/query_p0/join/test_runtimefilter_on_datev2.out
@@ -351,3 +351,13 @@
1 2022-01-01 1 2022-01-01
1 2022-01-01 1 2022-01-01
+-- !join1 --
+1 2022-01-01 1 2022-01-01
+1 2022-01-01 1 2022-01-01
+1 2022-01-01 1 2022-01-01
+1 2022-01-01 1 2022-01-01
+1 2022-01-01 1 2022-01-01
+1 2022-01-01 1 2022-01-01
+1 2022-01-01 1 2022-01-01
+1 2022-01-01 1 2022-01-01
+1 2022-01-01 1 2022-01-01
\ No newline at end of file
diff --git a/regression-test/data/query_p0/keyword/test_keyword.out b/regression-test/data/query_p0/keyword/test_keyword.out
index b5a4681ac5..6e3d301db7 100644
--- a/regression-test/data/query_p0/keyword/test_keyword.out
+++ b/regression-test/data/query_p0/keyword/test_keyword.out
@@ -651,3 +651,15 @@ false 1 1989 1001 11011902 123.123 true 1989-03-21 1989-03-21T13:00 wangjuoo4 0.
false 2 1986 1001 11011903 1243.500 false 1901-12-31 1989-03-21T13:00 wangynnsf 20.268 789.25 string12345 -170141183460469231731687303715884105727
false 3 1989 1002 11011905 24453.325 false 2012-03-14 2000-01-01T00:00 yunlj8@nk 78945.0 3654.0 string12345 0
+-- !having2 --
+3 1989
+6 32767
+9 1991
+12 32767
+15 1992
+
+-- !distinct25 --
+2.0 2.0
+
+-- !distinct26 --
+2
\ No newline at end of file
diff --git a/regression-test/suites/query_p0/aggregate/aggregate.groovy b/regression-test/suites/query_p0/aggregate/aggregate.groovy
index f9fca691f1..025606d88e 100644
--- a/regression-test/suites/query_p0/aggregate/aggregate.groovy
+++ b/regression-test/suites/query_p0/aggregate/aggregate.groovy
@@ -291,4 +291,10 @@ suite("aggregate") {
sql""" DROP TABLE IF EXISTS tempbaseall """
sql"""create table tempbaseall PROPERTIES("replication_num" = "1") as select k1, k2 from baseall where k1 is not null;"""
qt_aggregate32"select k1, k2 from (select k1, max(k2) as k2 from tempbaseall where k1 > 0 group by k1 order by k1)a where k1 > 0 and k1 < 10 order by k1;"
+
+ sql 'set enable_vectorized_engine=true;'
+ sql 'set enable_fallback_to_original_planner=false'
+ sql 'set enable_nereids_planner=true'
+ qt_aggregate """ select avg(distinct c_bigint), avg(distinct c_double) from regression_test_query_p0_aggregate.${tableName} """
+ qt_aggregate """ select count(distinct c_bigint),count(distinct c_double),count(distinct c_string),count(distinct c_date_1),count(distinct c_timestamp_1),count(distinct c_timestamp_2),count(distinct c_timestamp_3),count(distinct c_boolean) from regression_test_query_p0_aggregate.${tableName} """
}
diff --git a/regression-test/suites/query_p0/join/test_runtimefilter_on_datev2.groovy b/regression-test/suites/query_p0/join/test_runtimefilter_on_datev2.groovy
index de2636a170..3247f331ef 100644
--- a/regression-test/suites/query_p0/join/test_runtimefilter_on_datev2.groovy
+++ b/regression-test/suites/query_p0/join/test_runtimefilter_on_datev2.groovy
@@ -215,4 +215,11 @@ suite("test_runtimefilter_on_datev2", "query_p0") {
qt_join8 """
SELECT * FROM ${dateV2Table} a, ${dateV2Table2} b WHERE a.date = b.date;
"""
+
+ sql 'set enable_vectorized_engine=true'
+ sql 'set enable_fallback_to_original_planner=false'
+ sql 'set enable_nereids_planner=true'
+ qt_join1 """
+ SELECT * FROM ${dateTable} a, ${dateV2Table} b WHERE a.date = b.date;
+ """
}
diff --git a/regression-test/suites/query_p0/keyword/test_keyword.groovy b/regression-test/suites/query_p0/keyword/test_keyword.groovy
index c30308ac88..8ce23e4ee4 100644
--- a/regression-test/suites/query_p0/keyword/test_keyword.groovy
+++ b/regression-test/suites/query_p0/keyword/test_keyword.groovy
@@ -116,4 +116,8 @@ suite("test_keyword", "query,p0") {
qt_distinct "select distinct upper(k6) from ${tableName1} order by upper(k6)"
qt_distinct "select distinct * from ${tableName1} where k1<20 order by k1, k2, k3, k4"
+ qt_having2 "select k1, k2 from ${tableName2} having k1 % 3 = 0 order by k1, k2"
+ qt_distinct25 "select avg(distinct k1), avg(k1) from ${tableName1}"
+ qt_distinct26 "select count(*) from (select count(distinct k1) from ${tableName1} group by k2) v \
+ order by count(*)"
}
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@doris.apache.org
For additional commands, e-mail: commits-help@doris.apache.org