You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@doris.apache.org by yi...@apache.org on 2024/01/23 02:13:45 UTC

(doris) 35/43: [fix](Nereids) result nullable of sum distinct in scalar agg is wrong (#30221)

This is an automated email from the ASF dual-hosted git repository.

yiguolei pushed a commit to branch branch-2.1
in repository https://gitbox.apache.org/repos/asf/doris.git

commit ce47354d598b099fb0ebe2784b38978c13a4c808
Author: morrySnow <10...@users.noreply.github.com>
AuthorDate: Mon Jan 22 17:53:25 2024 +0800

    [fix](Nereids) result nullable of sum distinct in scalar agg is wrong (#30221)
---
 .../AdjustAggregateNullableForEmptySet.java        |  3 +-
 .../rules/implementation/AggregateStrategies.java  | 44 ++++++++--------------
 .../trees/expressions/functions/agg/Sum.java       |  6 +++
 .../data/nereids_syntax_p0/agg_with_empty_set.out  |  3 ++
 .../nereids_syntax_p0/agg_with_empty_set.groovy    |  1 +
 5 files changed, 26 insertions(+), 31 deletions(-)

diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/AdjustAggregateNullableForEmptySet.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/AdjustAggregateNullableForEmptySet.java
index 75400a8e6b0..86a70d35ccc 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/AdjustAggregateNullableForEmptySet.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/AdjustAggregateNullableForEmptySet.java
@@ -90,8 +90,7 @@ public class AdjustAggregateNullableForEmptySet implements RewriteRuleFactory {
         @Override
         public Expression visitNullableAggregateFunction(NullableAggregateFunction nullableAggregateFunction,
                 Boolean alwaysNullable) {
-            return nullableAggregateFunction.isDistinct() ? nullableAggregateFunction
-                    : nullableAggregateFunction.withAlwaysNullable(alwaysNullable);
+            return nullableAggregateFunction.withAlwaysNullable(alwaysNullable);
         }
     }
 }
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/AggregateStrategies.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/AggregateStrategies.java
index a0eb011ba92..c9907ae7c3d 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/AggregateStrategies.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/AggregateStrategies.java
@@ -50,7 +50,6 @@ import org.apache.doris.nereids.trees.expressions.functions.agg.GroupConcat;
 import org.apache.doris.nereids.trees.expressions.functions.agg.Max;
 import org.apache.doris.nereids.trees.expressions.functions.agg.Min;
 import org.apache.doris.nereids.trees.expressions.functions.agg.MultiDistinctCount;
-import org.apache.doris.nereids.trees.expressions.functions.agg.MultiDistinctSum;
 import org.apache.doris.nereids.trees.expressions.functions.agg.Sum;
 import org.apache.doris.nereids.trees.expressions.functions.scalar.If;
 import org.apache.doris.nereids.trees.expressions.literal.Literal;
@@ -108,9 +107,9 @@ public class AggregateStrategies implements ImplementationRuleFactory {
                 logicalAggregate(
                     logicalFilter(
                         logicalOlapScan().when(this::isDupOrMowKeyTable).when(this::isInvertedIndexEnabledOnTable)
-                    ).when(filter -> filter.getConjuncts().size() > 0))
+                    ).when(filter -> !filter.getConjuncts().isEmpty()))
                     .when(agg -> enablePushDownCountOnIndex())
-                    .when(agg -> agg.getGroupByExpressions().size() == 0)
+                    .when(agg -> agg.getGroupByExpressions().isEmpty())
                     .when(agg -> {
                         Set<AggregateFunction> funcs = agg.getAggregateFunctions();
                         return !funcs.isEmpty() && funcs.stream()
@@ -128,9 +127,9 @@ public class AggregateStrategies implements ImplementationRuleFactory {
                     logicalProject(
                         logicalFilter(
                             logicalOlapScan().when(this::isDupOrMowKeyTable).when(this::isInvertedIndexEnabledOnTable)
-                        ).when(filter -> filter.getConjuncts().size() > 0)))
+                        ).when(filter -> !filter.getConjuncts().isEmpty())))
                     .when(agg -> enablePushDownCountOnIndex())
-                    .when(agg -> agg.getGroupByExpressions().size() == 0)
+                    .when(agg -> agg.getGroupByExpressions().isEmpty())
                     .when(agg -> {
                         Set<AggregateFunction> funcs = agg.getAggregateFunctions();
                         return !funcs.isEmpty() && funcs.stream().allMatch(f -> f instanceof Count && !f.isDistinct());
@@ -154,7 +153,7 @@ public class AggregateStrategies implements ImplementationRuleFactory {
                                     Expression childExpr = filter.getConjuncts().iterator().next().children().get(0);
                                     if (childExpr instanceof SlotReference) {
                                         Optional<Column> column = ((SlotReference) childExpr).getColumn();
-                                        return column.isPresent() ? column.get().isDeleteSignColumn() : false;
+                                        return column.map(Column::isDeleteSignColumn).orElse(false);
                                     }
                                     return false;
                                 })
@@ -187,8 +186,7 @@ public class AggregateStrategies implements ImplementationRuleFactory {
                                                         .children().get(0);
                                                 if (childExpr instanceof SlotReference) {
                                                     Optional<Column> column = ((SlotReference) childExpr).getColumn();
-                                                    return column.isPresent() ? column.get().isDeleteSignColumn()
-                                                            : false;
+                                                    return column.map(Column::isDeleteSignColumn).orElse(false);
                                                 }
                                                 return false;
                                             }))
@@ -253,12 +251,12 @@ public class AggregateStrategies implements ImplementationRuleFactory {
             ),
             RuleType.ONE_PHASE_AGGREGATE_WITHOUT_DISTINCT.build(
                 basePattern
-                    .when(agg -> agg.getDistinctArguments().size() == 0)
+                    .when(agg -> agg.getDistinctArguments().isEmpty())
                     .thenApplyMulti(ctx -> onePhaseAggregateWithoutDistinct(ctx.root, ctx.connectContext))
             ),
             RuleType.TWO_PHASE_AGGREGATE_WITHOUT_DISTINCT.build(
                 basePattern
-                    .when(agg -> agg.getDistinctArguments().size() == 0)
+                    .when(agg -> agg.getDistinctArguments().isEmpty())
                     .thenApplyMulti(ctx -> twoPhaseAggregateWithoutDistinct(ctx.root, ctx.connectContext))
             ),
             // RuleType.TWO_PHASE_AGGREGATE_WITH_COUNT_DISTINCT_MULTI.build(
@@ -435,12 +433,7 @@ public class AggregateStrategies implements ImplementationRuleFactory {
         boolean onlyContainsSlotOrNumericCastSlot = aggregateFunctions.stream()
                 .map(ExpressionTrait::getArguments)
                 .flatMap(List::stream)
-                .allMatch(argument -> {
-                    if (argument instanceof SlotReference) {
-                        return true;
-                    }
-                    return false;
-                });
+                .allMatch(argument -> argument instanceof SlotReference);
         if (!onlyContainsSlotOrNumericCastSlot) {
             return false;
         }
@@ -457,19 +450,13 @@ public class AggregateStrategies implements ImplementationRuleFactory {
         }
         onlyContainsSlotOrNumericCastSlot = argumentsOfAggregateFunction
                 .stream()
-                .allMatch(argument -> {
-                    if (argument instanceof SlotReference) {
-                        return true;
-                    }
-                    return false;
-                });
+                .allMatch(argument -> argument instanceof SlotReference);
         if (!onlyContainsSlotOrNumericCastSlot) {
             return false;
         }
         Set<SlotReference> aggUsedSlots = ExpressionUtils.collect(argumentsOfAggregateFunction,
                 SlotReference.class::isInstance);
-        List<SlotReference> usedSlotInTable = (List<SlotReference>) Project.findProject(aggUsedSlots,
-                outPutSlots);
+        List<SlotReference> usedSlotInTable = (List<SlotReference>) Project.findProject(aggUsedSlots, outPutSlots);
         for (SlotReference slot : usedSlotInTable) {
             Column column = slot.getColumn().get();
             PrimitiveType colType = column.getType().getPrimitiveType();
@@ -630,7 +617,7 @@ public class AggregateStrategies implements ImplementationRuleFactory {
         if (logicalScan instanceof LogicalOlapScan) {
             PhysicalOlapScan physicalScan = (PhysicalOlapScan) new LogicalOlapScanToPhysicalOlapScan()
                     .build()
-                    .transform((LogicalOlapScan) logicalScan, cascadesContext)
+                    .transform(logicalScan, cascadesContext)
                     .get(0);
 
             if (project != null) {
@@ -647,7 +634,7 @@ public class AggregateStrategies implements ImplementationRuleFactory {
         } else if (logicalScan instanceof LogicalFileScan) {
             PhysicalFileScan physicalScan = (PhysicalFileScan) new LogicalFileScanToPhysicalFileScan()
                     .build()
-                    .transform((LogicalFileScan) logicalScan, cascadesContext)
+                    .transform(logicalScan, cascadesContext)
                     .get(0);
             if (project != null) {
                 return aggregate.withChildren(ImmutableList.of(
@@ -1193,8 +1180,7 @@ public class AggregateStrategies implements ImplementationRuleFactory {
                             return new AggregateExpression(nonDistinct, AggregateParam.LOCAL_RESULT);
                         } else {
                             Alias alias = nonDistinctAggFunctionToAliasPhase1.get(outputChild);
-                            return new AggregateExpression(
-                                    aggregateFunction, bufferToResultParam, alias.toSlot());
+                            return new AggregateExpression(aggregateFunction, bufferToResultParam, alias.toSlot());
                         }
                     } else {
                         return outputChild;
@@ -1582,7 +1568,7 @@ public class AggregateStrategies implements ImplementationRuleFactory {
             return new MultiDistinctCount(function.getArgument(0),
                     function.getArguments().subList(1, function.arity()).toArray(new Expression[0]));
         } else if (function instanceof Sum && function.isDistinct()) {
-            return new MultiDistinctSum(function.getArgument(0));
+            return ((Sum) function).convertToMultiDistinct();
         } else if (function instanceof GroupConcat && function.isDistinct()) {
             return ((GroupConcat) function).convertToMultiDistinct();
         }
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Sum.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Sum.java
index f0dbd839583..0b00536d6a8 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Sum.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Sum.java
@@ -78,6 +78,12 @@ public class Sum extends NullableAggregateFunction
         super("sum", distinct, alwaysNullable, arg);
     }
 
+    public MultiDistinctSum convertToMultiDistinct() {
+        Preconditions.checkArgument(distinct,
+                "can't convert to multi_distinct_sum because there is no distinct args");
+        return new MultiDistinctSum(false, alwaysNullable, child());
+    }
+
     @Override
     public void checkLegalityBeforeTypeCoercion() {
         DataType argType = child().getDataType();
diff --git a/regression-test/data/nereids_syntax_p0/agg_with_empty_set.out b/regression-test/data/nereids_syntax_p0/agg_with_empty_set.out
index 1db851c70c1..ffe8f93eb58 100644
--- a/regression-test/data/nereids_syntax_p0/agg_with_empty_set.out
+++ b/regression-test/data/nereids_syntax_p0/agg_with_empty_set.out
@@ -18,3 +18,6 @@
 -- !select6 --
 0	\N	\N	\N	\N
 
+-- !ditinct_sum --
+\N
+
diff --git a/regression-test/suites/nereids_syntax_p0/agg_with_empty_set.groovy b/regression-test/suites/nereids_syntax_p0/agg_with_empty_set.groovy
index 5fc117445a4..bdcb4526a2c 100644
--- a/regression-test/suites/nereids_syntax_p0/agg_with_empty_set.groovy
+++ b/regression-test/suites/nereids_syntax_p0/agg_with_empty_set.groovy
@@ -29,4 +29,5 @@ suite("agg_with_empty_set") {
         (select min(c_custkey) from customer)"""
     qt_select6 """select count(c_custkey), max(c_custkey), min(c_custkey), avg(c_custkey), sum(c_custkey) from customer where c_custkey < 
         (select min(c_custkey) from customer) having min(c_custkey) is null"""
+    qt_ditinct_sum """select sum(distinct ifnull(c_custkey, 0)) from customer where 1 = 0"""
 }
\ No newline at end of file


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@doris.apache.org
For additional commands, e-mail: commits-help@doris.apache.org