You are viewing a plain-text version of this content. The link to the canonical version was lost in this plain-text rendering; see the original mailing-list archive page.
Posted to commits@doris.apache.org by yi...@apache.org on 2024/01/23 02:13:45 UTC
(doris) 35/43: [fix](Nereids) result nullable of sum distinct in scalar agg is wrong (#30221)
This is an automated email from the ASF dual-hosted git repository.
yiguolei pushed a commit to branch branch-2.1
in repository https://gitbox.apache.org/repos/asf/doris.git
commit ce47354d598b099fb0ebe2784b38978c13a4c808
Author: morrySnow <10...@users.noreply.github.com>
AuthorDate: Mon Jan 22 17:53:25 2024 +0800
[fix](Nereids) result nullable of sum distinct in scalar agg is wrong (#30221)
---
.../AdjustAggregateNullableForEmptySet.java | 3 +-
.../rules/implementation/AggregateStrategies.java | 44 ++++++++--------------
.../trees/expressions/functions/agg/Sum.java | 6 +++
.../data/nereids_syntax_p0/agg_with_empty_set.out | 3 ++
.../nereids_syntax_p0/agg_with_empty_set.groovy | 1 +
5 files changed, 26 insertions(+), 31 deletions(-)
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/AdjustAggregateNullableForEmptySet.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/AdjustAggregateNullableForEmptySet.java
index 75400a8e6b0..86a70d35ccc 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/AdjustAggregateNullableForEmptySet.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/AdjustAggregateNullableForEmptySet.java
@@ -90,8 +90,7 @@ public class AdjustAggregateNullableForEmptySet implements RewriteRuleFactory {
@Override
public Expression visitNullableAggregateFunction(NullableAggregateFunction nullableAggregateFunction,
Boolean alwaysNullable) {
- return nullableAggregateFunction.isDistinct() ? nullableAggregateFunction
- : nullableAggregateFunction.withAlwaysNullable(alwaysNullable);
+ return nullableAggregateFunction.withAlwaysNullable(alwaysNullable);
}
}
}
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/AggregateStrategies.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/AggregateStrategies.java
index a0eb011ba92..c9907ae7c3d 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/AggregateStrategies.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/AggregateStrategies.java
@@ -50,7 +50,6 @@ import org.apache.doris.nereids.trees.expressions.functions.agg.GroupConcat;
import org.apache.doris.nereids.trees.expressions.functions.agg.Max;
import org.apache.doris.nereids.trees.expressions.functions.agg.Min;
import org.apache.doris.nereids.trees.expressions.functions.agg.MultiDistinctCount;
-import org.apache.doris.nereids.trees.expressions.functions.agg.MultiDistinctSum;
import org.apache.doris.nereids.trees.expressions.functions.agg.Sum;
import org.apache.doris.nereids.trees.expressions.functions.scalar.If;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
@@ -108,9 +107,9 @@ public class AggregateStrategies implements ImplementationRuleFactory {
logicalAggregate(
logicalFilter(
logicalOlapScan().when(this::isDupOrMowKeyTable).when(this::isInvertedIndexEnabledOnTable)
- ).when(filter -> filter.getConjuncts().size() > 0))
+ ).when(filter -> !filter.getConjuncts().isEmpty()))
.when(agg -> enablePushDownCountOnIndex())
- .when(agg -> agg.getGroupByExpressions().size() == 0)
+ .when(agg -> agg.getGroupByExpressions().isEmpty())
.when(agg -> {
Set<AggregateFunction> funcs = agg.getAggregateFunctions();
return !funcs.isEmpty() && funcs.stream()
@@ -128,9 +127,9 @@ public class AggregateStrategies implements ImplementationRuleFactory {
logicalProject(
logicalFilter(
logicalOlapScan().when(this::isDupOrMowKeyTable).when(this::isInvertedIndexEnabledOnTable)
- ).when(filter -> filter.getConjuncts().size() > 0)))
+ ).when(filter -> !filter.getConjuncts().isEmpty())))
.when(agg -> enablePushDownCountOnIndex())
- .when(agg -> agg.getGroupByExpressions().size() == 0)
+ .when(agg -> agg.getGroupByExpressions().isEmpty())
.when(agg -> {
Set<AggregateFunction> funcs = agg.getAggregateFunctions();
return !funcs.isEmpty() && funcs.stream().allMatch(f -> f instanceof Count && !f.isDistinct());
@@ -154,7 +153,7 @@ public class AggregateStrategies implements ImplementationRuleFactory {
Expression childExpr = filter.getConjuncts().iterator().next().children().get(0);
if (childExpr instanceof SlotReference) {
Optional<Column> column = ((SlotReference) childExpr).getColumn();
- return column.isPresent() ? column.get().isDeleteSignColumn() : false;
+ return column.map(Column::isDeleteSignColumn).orElse(false);
}
return false;
})
@@ -187,8 +186,7 @@ public class AggregateStrategies implements ImplementationRuleFactory {
.children().get(0);
if (childExpr instanceof SlotReference) {
Optional<Column> column = ((SlotReference) childExpr).getColumn();
- return column.isPresent() ? column.get().isDeleteSignColumn()
- : false;
+ return column.map(Column::isDeleteSignColumn).orElse(false);
}
return false;
}))
@@ -253,12 +251,12 @@ public class AggregateStrategies implements ImplementationRuleFactory {
),
RuleType.ONE_PHASE_AGGREGATE_WITHOUT_DISTINCT.build(
basePattern
- .when(agg -> agg.getDistinctArguments().size() == 0)
+ .when(agg -> agg.getDistinctArguments().isEmpty())
.thenApplyMulti(ctx -> onePhaseAggregateWithoutDistinct(ctx.root, ctx.connectContext))
),
RuleType.TWO_PHASE_AGGREGATE_WITHOUT_DISTINCT.build(
basePattern
- .when(agg -> agg.getDistinctArguments().size() == 0)
+ .when(agg -> agg.getDistinctArguments().isEmpty())
.thenApplyMulti(ctx -> twoPhaseAggregateWithoutDistinct(ctx.root, ctx.connectContext))
),
// RuleType.TWO_PHASE_AGGREGATE_WITH_COUNT_DISTINCT_MULTI.build(
@@ -435,12 +433,7 @@ public class AggregateStrategies implements ImplementationRuleFactory {
boolean onlyContainsSlotOrNumericCastSlot = aggregateFunctions.stream()
.map(ExpressionTrait::getArguments)
.flatMap(List::stream)
- .allMatch(argument -> {
- if (argument instanceof SlotReference) {
- return true;
- }
- return false;
- });
+ .allMatch(argument -> argument instanceof SlotReference);
if (!onlyContainsSlotOrNumericCastSlot) {
return false;
}
@@ -457,19 +450,13 @@ public class AggregateStrategies implements ImplementationRuleFactory {
}
onlyContainsSlotOrNumericCastSlot = argumentsOfAggregateFunction
.stream()
- .allMatch(argument -> {
- if (argument instanceof SlotReference) {
- return true;
- }
- return false;
- });
+ .allMatch(argument -> argument instanceof SlotReference);
if (!onlyContainsSlotOrNumericCastSlot) {
return false;
}
Set<SlotReference> aggUsedSlots = ExpressionUtils.collect(argumentsOfAggregateFunction,
SlotReference.class::isInstance);
- List<SlotReference> usedSlotInTable = (List<SlotReference>) Project.findProject(aggUsedSlots,
- outPutSlots);
+ List<SlotReference> usedSlotInTable = (List<SlotReference>) Project.findProject(aggUsedSlots, outPutSlots);
for (SlotReference slot : usedSlotInTable) {
Column column = slot.getColumn().get();
PrimitiveType colType = column.getType().getPrimitiveType();
@@ -630,7 +617,7 @@ public class AggregateStrategies implements ImplementationRuleFactory {
if (logicalScan instanceof LogicalOlapScan) {
PhysicalOlapScan physicalScan = (PhysicalOlapScan) new LogicalOlapScanToPhysicalOlapScan()
.build()
- .transform((LogicalOlapScan) logicalScan, cascadesContext)
+ .transform(logicalScan, cascadesContext)
.get(0);
if (project != null) {
@@ -647,7 +634,7 @@ public class AggregateStrategies implements ImplementationRuleFactory {
} else if (logicalScan instanceof LogicalFileScan) {
PhysicalFileScan physicalScan = (PhysicalFileScan) new LogicalFileScanToPhysicalFileScan()
.build()
- .transform((LogicalFileScan) logicalScan, cascadesContext)
+ .transform(logicalScan, cascadesContext)
.get(0);
if (project != null) {
return aggregate.withChildren(ImmutableList.of(
@@ -1193,8 +1180,7 @@ public class AggregateStrategies implements ImplementationRuleFactory {
return new AggregateExpression(nonDistinct, AggregateParam.LOCAL_RESULT);
} else {
Alias alias = nonDistinctAggFunctionToAliasPhase1.get(outputChild);
- return new AggregateExpression(
- aggregateFunction, bufferToResultParam, alias.toSlot());
+ return new AggregateExpression(aggregateFunction, bufferToResultParam, alias.toSlot());
}
} else {
return outputChild;
@@ -1582,7 +1568,7 @@ public class AggregateStrategies implements ImplementationRuleFactory {
return new MultiDistinctCount(function.getArgument(0),
function.getArguments().subList(1, function.arity()).toArray(new Expression[0]));
} else if (function instanceof Sum && function.isDistinct()) {
- return new MultiDistinctSum(function.getArgument(0));
+ return ((Sum) function).convertToMultiDistinct();
} else if (function instanceof GroupConcat && function.isDistinct()) {
return ((GroupConcat) function).convertToMultiDistinct();
}
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Sum.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Sum.java
index f0dbd839583..0b00536d6a8 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Sum.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Sum.java
@@ -78,6 +78,12 @@ public class Sum extends NullableAggregateFunction
super("sum", distinct, alwaysNullable, arg);
}
+ public MultiDistinctSum convertToMultiDistinct() {
+ Preconditions.checkArgument(distinct,
+ "can't convert to multi_distinct_sum because there is no distinct args");
+ return new MultiDistinctSum(false, alwaysNullable, child());
+ }
+
@Override
public void checkLegalityBeforeTypeCoercion() {
DataType argType = child().getDataType();
diff --git a/regression-test/data/nereids_syntax_p0/agg_with_empty_set.out b/regression-test/data/nereids_syntax_p0/agg_with_empty_set.out
index 1db851c70c1..ffe8f93eb58 100644
--- a/regression-test/data/nereids_syntax_p0/agg_with_empty_set.out
+++ b/regression-test/data/nereids_syntax_p0/agg_with_empty_set.out
@@ -18,3 +18,6 @@
-- !select6 --
0 \N \N \N \N
+-- !ditinct_sum --
+\N
+
diff --git a/regression-test/suites/nereids_syntax_p0/agg_with_empty_set.groovy b/regression-test/suites/nereids_syntax_p0/agg_with_empty_set.groovy
index 5fc117445a4..bdcb4526a2c 100644
--- a/regression-test/suites/nereids_syntax_p0/agg_with_empty_set.groovy
+++ b/regression-test/suites/nereids_syntax_p0/agg_with_empty_set.groovy
@@ -29,4 +29,5 @@ suite("agg_with_empty_set") {
(select min(c_custkey) from customer)"""
qt_select6 """select count(c_custkey), max(c_custkey), min(c_custkey), avg(c_custkey), sum(c_custkey) from customer where c_custkey <
(select min(c_custkey) from customer) having min(c_custkey) is null"""
+ qt_ditinct_sum """select sum(distinct ifnull(c_custkey, 0)) from customer where 1 = 0"""
}
\ No newline at end of file
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@doris.apache.org
For additional commands, e-mail: commits-help@doris.apache.org