You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@doris.apache.org by ja...@apache.org on 2023/04/12 04:13:15 UTC

[doris] branch master updated: [feat](Nereids): a new CBO rule: Eager Split/GroupByCount (#18556)

This is an automated email from the ASF dual-hosted git repository.

jakevin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git


The following commit(s) were added to refs/heads/master by this push:
     new 39a7a4cc55 [feat](Nereids): a new CBO rule: Eager Split/GroupByCount (#18556)
39a7a4cc55 is described below

commit 39a7a4cc55c9e3b00a6465d99f85db911b92d234
Author: jakevin <ja...@gmail.com>
AuthorDate: Wed Apr 12 12:13:06 2023 +0800

    [feat](Nereids): a new CBO rule: Eager Split/GroupByCount (#18556)
---
 .../org/apache/doris/nereids/rules/RuleType.java   |   4 +
 .../nereids/rules/exploration/EagerCount.java      |   4 +-
 .../nereids/rules/exploration/EagerGroupBy.java    |   4 +-
 .../rules/exploration/EagerGroupByCount.java       | 138 +++++++++++++++++
 .../nereids/rules/exploration/EagerSplit.java      | 164 +++++++++++++++++++++
 .../rules/exploration/EagerGroupByCountTest.java   | 101 +++++++++++++
 .../nereids/rules/exploration/EagerSplitTest.java  | 102 +++++++++++++
 7 files changed, 514 insertions(+), 3 deletions(-)

diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
index b14d8b5029..9f1234803d 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
@@ -241,6 +241,10 @@ public enum RuleType {
     PUSH_DOWN_PROJECT_THROUGH_INNER_JOIN(RuleTypeClass.EXPLORATION),
     EAGER_COUNT(RuleTypeClass.EXPLORATION),
     EAGER_GROUP_BY(RuleTypeClass.EXPLORATION),
+    EAGER_GROUP_BY_COUNT(RuleTypeClass.EXPLORATION),
+    EAGER_SPLIT(RuleTypeClass.EXPLORATION),
+
+    EXPLORATION_SENTINEL(RuleTypeClass.EXPLORATION),
 
     // implementation rules
     LOGICAL_ONE_ROW_RELATION_TO_PHYSICAL_ONE_ROW_RELATION(RuleTypeClass.IMPLEMENTATION),
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/EagerCount.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/EagerCount.java
index add6d9ccc2..cde94b33eb 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/EagerCount.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/EagerCount.java
@@ -49,7 +49,7 @@ import java.util.Set;
  * |    *
  * (x)
  * ->
- * aggregate: SUM(x * cnt)
+ * aggregate: SUM(x) * cnt
  * |
  * join
  * |   \
@@ -62,7 +62,7 @@ public class EagerCount extends OneExplorationRuleFactory {
 
     @Override
     public Rule build() {
-        return logicalAggregate(logicalJoin())
+        return logicalAggregate(innerLogicalJoin())
                 .when(agg -> agg.child().getOtherJoinConjuncts().size() == 0)
                 .when(agg -> agg.getGroupByExpressions().stream().allMatch(e -> e instanceof Slot))
                 .when(agg -> agg.getAggregateFunctions().stream()
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/EagerGroupBy.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/EagerGroupBy.java
index 0db10dd1ee..27fcd149b2 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/EagerGroupBy.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/EagerGroupBy.java
@@ -54,13 +54,15 @@ import java.util.stream.Collectors;
  * |    *
  * aggregate: SUM(x) as sum1
  * </pre>
+ * After Eager Group By, the new plan can also apply `Eager Count`;
+ * applying both rules is known as `Double Eager`.
  */
 public class EagerGroupBy extends OneExplorationRuleFactory {
     public static final EagerGroupBy INSTANCE = new EagerGroupBy();
 
     @Override
     public Rule build() {
-        return logicalAggregate(logicalJoin())
+        return logicalAggregate(innerLogicalJoin())
                 .when(agg -> agg.child().getOtherJoinConjuncts().size() == 0)
                 .when(agg -> agg.getAggregateFunctions().stream()
                         .allMatch(f -> f instanceof Sum
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/EagerGroupByCount.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/EagerGroupByCount.java
new file mode 100644
index 0000000000..c538250538
--- /dev/null
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/EagerGroupByCount.java
@@ -0,0 +1,160 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.rules.exploration;
+
+import org.apache.doris.nereids.rules.Rule;
+import org.apache.doris.nereids.rules.RuleType;
+import org.apache.doris.nereids.trees.expressions.Alias;
+import org.apache.doris.nereids.trees.expressions.Multiply;
+import org.apache.doris.nereids.trees.expressions.NamedExpression;
+import org.apache.doris.nereids.trees.expressions.Slot;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Count;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Sum;
+import org.apache.doris.nereids.trees.expressions.literal.Literal;
+import org.apache.doris.nereids.trees.plans.GroupPlan;
+import org.apache.doris.nereids.trees.plans.Plan;
+import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
+import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
+
+import com.google.common.collect.ImmutableList;
+
+import java.util.ArrayList;
+import java.util.LinkedHashSet;
+import java.util.List;
+import java.util.Set;
+
+/**
+ * Related paper "Eager aggregation and lazy aggregation".
+ * <pre>
+ * aggregate: SUM(x), SUM(y)
+ * |
+ * join
+ * |   \
+ * |   (y)
+ * (x)
+ * ->
+ * aggregate: SUM(sum1), SUM(y * cnt)
+ * |
+ * join
+ * |   \
+ * |   (y)
+ * aggregate: SUM(x) as sum1 , COUNT as cnt
+ * </pre>
+ * Each row of the bottom aggregate stands for {@code cnt} rows of the left child, so every right-side
+ * value is weighted by {@code cnt} before it is summed. {@code cnt} is not a group-by key of the top
+ * aggregate, so the multiplication has to happen inside the SUM.
+ */
+public class EagerGroupByCount extends OneExplorationRuleFactory {
+    public static final EagerGroupByCount INSTANCE = new EagerGroupByCount();
+
+    @Override
+    public Rule build() {
+        return logicalAggregate(innerLogicalJoin())
+                // other join conjuncts could reference left columns that the bottom aggregate drops
+                .when(agg -> agg.child().getOtherJoinConjuncts().size() == 0)
+                // group-by keys are cast to Slot when the bottom aggregate is built
+                .when(agg -> agg.getGroupByExpressions().stream().allMatch(e -> e instanceof Slot))
+                // SUM(DISTINCT x) cannot be pre-aggregated below the join
+                .when(agg -> agg.getAggregateFunctions().stream()
+                        .allMatch(f -> f instanceof Sum && !f.isDistinct() && ((Sum) f).child() instanceof Slot))
+                // any other output expression would still reference columns that the bottom aggregate drops
+                .when(agg -> agg.getOutputExpressions().stream().allMatch(EagerGroupByCount::isSlotOrAliasedSum))
+                .then(agg -> {
+                    LogicalJoin<GroupPlan, GroupPlan> join = agg.child();
+                    List<Slot> leftOutput = join.left().getOutput();
+                    if (!hasSumOnBothSides(agg.getOutputExpressions(), leftOutput)) {
+                        return null;
+                    }
+
+                    // top agg: rewrite every SUM in place, keeping the output order and ExprIds
+                    Alias cnt = new Alias(new Count(Literal.of(1)), "cnt");
+                    List<NamedExpression> bottomSums = new ArrayList<>();
+                    List<NamedExpression> newOutputExprs = new ArrayList<>();
+                    for (NamedExpression ne : agg.getOutputExpressions()) {
+                        if (!(ne instanceof Alias)) {
+                            // group-by slot: still produced by the bottom aggregate or the right child
+                            newOutputExprs.add(ne);
+                            continue;
+                        }
+                        Alias oldSum = (Alias) ne;
+                        Slot sumChild = (Slot) ((Sum) oldSum.child()).child();
+                        if (leftOutput.contains(sumChild)) {
+                            // SUM(x) -> SUM(sum_i), where the bottom aggregate computes sum_i = SUM(x)
+                            Alias bottomSum = new Alias(new Sum(sumChild), "sum" + bottomSums.size());
+                            bottomSums.add(bottomSum);
+                            newOutputExprs.add(new Alias(oldSum.getExprId(), new Sum(bottomSum.toSlot()),
+                                    oldSum.getName()));
+                        } else {
+                            // SUM(y) -> SUM(y * cnt): every joined row stands for cnt rows of the left child
+                            newOutputExprs.add(new Alias(oldSum.getExprId(),
+                                    new Sum(new Multiply(sumChild, cnt.toSlot())), oldSum.getName()));
+                        }
+                    }
+
+                    // left bottom agg: group by the left group-by keys and the left hash-join keys
+                    LogicalAggregate<GroupPlan> bottomAgg = buildBottomAgg(
+                            collectBottomGroupBy(agg, join, leftOutput), bottomSums, cnt, join.left());
+                    Plan newJoin = join.withChildren(bottomAgg, join.right());
+                    return agg.withAggOutput(newOutputExprs).withChildren(newJoin);
+                }).toRule(RuleType.EAGER_GROUP_BY_COUNT);
+    }
+
+    /** Whether {@code e} is a plain slot or an alias over a bare SUM: the only outputs this rule rewrites. */
+    private static boolean isSlotOrAliasedSum(NamedExpression e) {
+        return e instanceof Slot || (e instanceof Alias && ((Alias) e).child() instanceof Sum);
+    }
+
+    /** Whether the aliased SUMs in {@code outputs} read from both the left and the right child. */
+    private static boolean hasSumOnBothSides(List<NamedExpression> outputs, List<Slot> leftOutput) {
+        boolean hasLeft = false;
+        boolean hasRight = false;
+        for (NamedExpression ne : outputs) {
+            if (ne instanceof Alias) {
+                Slot sumChild = (Slot) ((Sum) ((Alias) ne).child()).child();
+                if (leftOutput.contains(sumChild)) {
+                    hasLeft = true;
+                } else {
+                    hasRight = true;
+                }
+            }
+        }
+        return hasLeft && hasRight;
+    }
+
+    /**
+     * Group-by keys of a bottom aggregate: the top group-by keys and the hash-join keys that come from
+     * {@code childOutput}, de-duplicated in a deterministic order.
+     */
+    private static List<Slot> collectBottomGroupBy(LogicalAggregate<?> agg, LogicalJoin<?, ?> join,
+            List<Slot> childOutput) {
+        Set<Slot> groupBy = new LinkedHashSet<>();
+        agg.getGroupByExpressions().stream().map(e -> (Slot) e).filter(childOutput::contains)
+                .forEach(groupBy::add);
+        join.getHashJoinConjuncts().forEach(e -> e.getInputSlots().stream()
+                .filter(childOutput::contains).forEach(groupBy::add));
+        return ImmutableList.copyOf(groupBy);
+    }
+
+    /** Builds {@code aggregate(groupBy..., sums..., cnt)} on top of {@code child}. */
+    private static LogicalAggregate<GroupPlan> buildBottomAgg(List<Slot> groupBy, List<NamedExpression> sums,
+            Alias cnt, GroupPlan child) {
+        List<NamedExpression> output = ImmutableList.<NamedExpression>builder()
+                .addAll(groupBy).addAll(sums).add(cnt).build();
+        return new LogicalAggregate<>(ImmutableList.copyOf(groupBy), output, child);
+    }
+}
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/EagerSplit.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/EagerSplit.java
new file mode 100644
index 0000000000..abf6dabad8
--- /dev/null
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/EagerSplit.java
@@ -0,0 +1,163 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.rules.exploration;
+
+import org.apache.doris.nereids.rules.Rule;
+import org.apache.doris.nereids.rules.RuleType;
+import org.apache.doris.nereids.trees.expressions.Alias;
+import org.apache.doris.nereids.trees.expressions.Multiply;
+import org.apache.doris.nereids.trees.expressions.NamedExpression;
+import org.apache.doris.nereids.trees.expressions.Slot;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Count;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Sum;
+import org.apache.doris.nereids.trees.expressions.literal.Literal;
+import org.apache.doris.nereids.trees.plans.GroupPlan;
+import org.apache.doris.nereids.trees.plans.Plan;
+import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
+import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
+
+import com.google.common.collect.ImmutableList;
+
+import java.util.ArrayList;
+import java.util.LinkedHashSet;
+import java.util.List;
+import java.util.Set;
+
+/**
+ * Related paper "Eager aggregation and lazy aggregation".
+ * <pre>
+ * aggregate: SUM(x), SUM(y)
+ * |
+ * join
+ * |   \
+ * |   (y)
+ * (x)
+ * ->
+ * aggregate: SUM(sum1 * cnt2), SUM(sum2 * cnt1)
+ * |
+ * join
+ * |   \
+ * |   aggregate: SUM(y) as sum2, COUNT: cnt2
+ * aggregate: SUM(x) as sum1, COUNT: cnt1
+ * </pre>
+ * A joined row pairs one left group with one right group and stands for {@code cnt1 * cnt2} rows of the
+ * original join, so each partial sum is weighted by the other side's count before it is summed. The
+ * counts are not group-by keys of the top aggregate, so the multiplication has to happen inside the SUM.
+ */
+public class EagerSplit extends OneExplorationRuleFactory {
+    public static final EagerSplit INSTANCE = new EagerSplit();
+
+    @Override
+    public Rule build() {
+        return logicalAggregate(innerLogicalJoin())
+                // other join conjuncts could reference columns that the bottom aggregates drop
+                .when(agg -> agg.child().getOtherJoinConjuncts().size() == 0)
+                // group-by keys are cast to Slot when the bottom aggregates are built
+                .when(agg -> agg.getGroupByExpressions().stream().allMatch(e -> e instanceof Slot))
+                // SUM(DISTINCT x) cannot be pre-aggregated below the join
+                .when(agg -> agg.getAggregateFunctions().stream()
+                        .allMatch(f -> f instanceof Sum && !f.isDistinct() && ((Sum) f).child() instanceof Slot))
+                // any other output expression would still reference columns that the bottom aggregates drop
+                .when(agg -> agg.getOutputExpressions().stream().allMatch(EagerSplit::isSlotOrAliasedSum))
+                .then(agg -> {
+                    LogicalJoin<GroupPlan, GroupPlan> join = agg.child();
+                    List<Slot> leftOutput = join.left().getOutput();
+                    List<Slot> rightOutput = join.right().getOutput();
+                    if (!hasSumOnBothSides(agg.getOutputExpressions(), leftOutput)) {
+                        return null;
+                    }
+
+                    // top agg: rewrite every SUM in place, keeping the output order and ExprIds
+                    List<NamedExpression> leftBottomSums = new ArrayList<>();
+                    List<NamedExpression> rightBottomSums = new ArrayList<>();
+                    List<NamedExpression> newOutputExprs = new ArrayList<>();
+                    Alias leftCnt = new Alias(new Count(Literal.of(1)), "left_cnt");
+                    Alias rightCnt = new Alias(new Count(Literal.of(1)), "right_cnt");
+                    for (NamedExpression ne : agg.getOutputExpressions()) {
+                        if (!(ne instanceof Alias)) {
+                            // group-by slot: still produced by one of the bottom aggregates
+                            newOutputExprs.add(ne);
+                            continue;
+                        }
+                        Alias oldSum = (Alias) ne;
+                        Slot sumChild = (Slot) ((Sum) oldSum.child()).child();
+                        boolean fromLeft = leftOutput.contains(sumChild);
+                        List<NamedExpression> bottomSums = fromLeft ? leftBottomSums : rightBottomSums;
+                        Alias bottomSum = new Alias(new Sum(sumChild),
+                                (fromLeft ? "left_sum" : "right_sum") + bottomSums.size());
+                        bottomSums.add(bottomSum);
+                        // SUM(x) -> SUM(sum_i * cnt), where cnt is the row count of the other side's group
+                        Slot otherCnt = fromLeft ? rightCnt.toSlot() : leftCnt.toSlot();
+                        newOutputExprs.add(new Alias(oldSum.getExprId(),
+                                new Sum(new Multiply(bottomSum.toSlot(), otherCnt)), oldSum.getName()));
+                    }
+
+                    // each bottom agg groups by its side's group-by keys and hash-join keys
+                    LogicalAggregate<GroupPlan> leftBottomAgg = buildBottomAgg(
+                            collectBottomGroupBy(agg, join, leftOutput), leftBottomSums, leftCnt, join.left());
+                    LogicalAggregate<GroupPlan> rightBottomAgg = buildBottomAgg(
+                            collectBottomGroupBy(agg, join, rightOutput), rightBottomSums, rightCnt, join.right());
+                    Plan newJoin = join.withChildren(leftBottomAgg, rightBottomAgg);
+                    return agg.withAggOutput(newOutputExprs).withChildren(newJoin);
+                }).toRule(RuleType.EAGER_SPLIT);
+    }
+
+    /** Whether {@code e} is a plain slot or an alias over a bare SUM: the only outputs this rule rewrites. */
+    private static boolean isSlotOrAliasedSum(NamedExpression e) {
+        return e instanceof Slot || (e instanceof Alias && ((Alias) e).child() instanceof Sum);
+    }
+
+    /** Whether the aliased SUMs in {@code outputs} read from both the left and the right child. */
+    private static boolean hasSumOnBothSides(List<NamedExpression> outputs, List<Slot> leftOutput) {
+        boolean hasLeft = false;
+        boolean hasRight = false;
+        for (NamedExpression ne : outputs) {
+            if (ne instanceof Alias) {
+                Slot sumChild = (Slot) ((Sum) ((Alias) ne).child()).child();
+                if (leftOutput.contains(sumChild)) {
+                    hasLeft = true;
+                } else {
+                    hasRight = true;
+                }
+            }
+        }
+        return hasLeft && hasRight;
+    }
+
+    /**
+     * Group-by keys of a bottom aggregate: the top group-by keys and the hash-join keys that come from
+     * {@code childOutput}, de-duplicated in a deterministic order.
+     */
+    private static List<Slot> collectBottomGroupBy(LogicalAggregate<?> agg, LogicalJoin<?, ?> join,
+            List<Slot> childOutput) {
+        Set<Slot> groupBy = new LinkedHashSet<>();
+        agg.getGroupByExpressions().stream().map(e -> (Slot) e).filter(childOutput::contains)
+                .forEach(groupBy::add);
+        join.getHashJoinConjuncts().forEach(e -> e.getInputSlots().stream()
+                .filter(childOutput::contains).forEach(groupBy::add));
+        return ImmutableList.copyOf(groupBy);
+    }
+
+    /** Builds {@code aggregate(groupBy..., sums..., cnt)} on top of {@code child}. */
+    private static LogicalAggregate<GroupPlan> buildBottomAgg(List<Slot> groupBy, List<NamedExpression> sums,
+            Alias cnt, GroupPlan child) {
+        List<NamedExpression> output = ImmutableList.<NamedExpression>builder()
+                .addAll(groupBy).addAll(sums).add(cnt).build();
+        return new LogicalAggregate<>(ImmutableList.copyOf(groupBy), output, child);
+    }
+}
diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/EagerGroupByCountTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/EagerGroupByCountTest.java
new file mode 100644
index 0000000000..de132d22d2
--- /dev/null
+++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/EagerGroupByCountTest.java
@@ -0,0 +1,130 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.rules.exploration;
+
+import org.apache.doris.common.Pair;
+import org.apache.doris.nereids.trees.expressions.Alias;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Sum;
+import org.apache.doris.nereids.trees.plans.JoinType;
+import org.apache.doris.nereids.trees.plans.algebra.Aggregate;
+import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
+import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
+import org.apache.doris.nereids.util.LogicalPlanBuilder;
+import org.apache.doris.nereids.util.MemoPatternMatchSupported;
+import org.apache.doris.nereids.util.MemoTestUtils;
+import org.apache.doris.nereids.util.PlanChecker;
+import org.apache.doris.nereids.util.PlanConstructor;
+
+import com.google.common.collect.ImmutableList;
+import org.junit.jupiter.api.Test;
+
+class EagerGroupByCountTest implements MemoPatternMatchSupported {
+
+    private final LogicalOlapScan scan1 = new LogicalOlapScan(PlanConstructor.getNextRelationId(),
+            PlanConstructor.student, ImmutableList.of(""));
+    private final LogicalOlapScan scan2 = new LogicalOlapScan(PlanConstructor.getNextRelationId(),
+            PlanConstructor.score, ImmutableList.of(""));
+
+    @Test
+    void singleSum() {
+        LogicalPlan agg = new LogicalPlanBuilder(scan1)
+                .join(scan2, JoinType.INNER_JOIN, Pair.of(0, 0))
+                .aggGroupUsingIndex(ImmutableList.of(0, 4),
+                        ImmutableList.of(
+                                new Alias(new Sum(scan1.getOutput().get(3)), "lsum0"),
+                                new Alias(new Sum(scan2.getOutput().get(2)), "rsum0")
+                        ))
+                .build();
+        PlanChecker.from(MemoTestUtils.createConnectContext(), agg)
+                .applyExploration(EagerGroupByCount.INSTANCE.build())
+                .printlnOrigin()
+                .printlnExploration()
+                .matchesExploration(
+                        logicalAggregate(
+                                logicalJoin(
+                                        logicalAggregate().when(
+                                                bottomAgg -> bottomAgg.getOutputExprsSql().equals("id, sum(age) AS `sum0`, count(1) AS `cnt`")),
+                                        logicalOlapScan()
+                                )
+                        ).when(newAgg ->
+                                newAgg.getGroupByExpressions().equals(((Aggregate) agg).getGroupByExpressions())
+                                        && newAgg.getOutputExprsSql().equals("sum(sum0) AS `lsum0`, sum((grade * cnt)) AS `rsum0`"))
+                );
+    }
+
+    @Test
+    void multiSum() {
+        LogicalPlan agg = new LogicalPlanBuilder(scan1)
+                .join(scan2, JoinType.INNER_JOIN, Pair.of(0, 0))
+                .aggGroupUsingIndex(ImmutableList.of(0, 4),
+                        ImmutableList.of(
+                                new Alias(new Sum(scan1.getOutput().get(1)), "lsum0"),
+                                new Alias(new Sum(scan1.getOutput().get(2)), "lsum1"),
+                                new Alias(new Sum(scan1.getOutput().get(3)), "lsum2"),
+                                new Alias(new Sum(scan2.getOutput().get(1)), "rsum0"),
+                                new Alias(new Sum(scan2.getOutput().get(2)), "rsum1")
+                        ))
+                .build();
+
+        PlanChecker.from(MemoTestUtils.createConnectContext(), agg)
+                .applyExploration(EagerGroupByCount.INSTANCE.build())
+                .printlnOrigin()
+                .printlnExploration()
+                .matchesExploration(
+                        logicalAggregate(
+                                logicalJoin(
+                                        logicalAggregate().when(cntAgg -> cntAgg.getOutputExprsSql()
+                                                .equals("id, sum(gender) AS `sum0`, sum(name) AS `sum1`, sum(age) AS `sum2`, count(1) AS `cnt`")),
+                                        logicalOlapScan()
+                                )
+                        ).when(newAgg ->
+                                newAgg.getGroupByExpressions().equals(((Aggregate) agg).getGroupByExpressions())
+                                        && newAgg.getOutputExprsSql()
+                                        .equals("sum(sum0) AS `lsum0`, sum(sum1) AS `lsum1`, sum(sum2) AS `lsum2`, sum((cid * cnt)) AS `rsum0`, sum((grade * cnt)) AS `rsum1`"))
+                );
+    }
+
+    @Test
+    void duplicateSum() {
+        // SUM(age) is requested twice: each output must still get its own bottom sum
+        LogicalPlan agg = new LogicalPlanBuilder(scan1)
+                .join(scan2, JoinType.INNER_JOIN, Pair.of(0, 0))
+                .aggGroupUsingIndex(ImmutableList.of(0, 4),
+                        ImmutableList.of(
+                                new Alias(new Sum(scan1.getOutput().get(3)), "lsum0"),
+                                new Alias(new Sum(scan1.getOutput().get(3)), "lsum1"),
+                                new Alias(new Sum(scan2.getOutput().get(2)), "rsum0")
+                        ))
+                .build();
+
+        PlanChecker.from(MemoTestUtils.createConnectContext(), agg)
+                .applyExploration(EagerGroupByCount.INSTANCE.build())
+                .matchesExploration(
+                        logicalAggregate(
+                                logicalJoin(
+                                        logicalAggregate().when(cntAgg -> cntAgg.getOutputExprsSql()
+                                                .equals("id, sum(age) AS `sum0`, sum(age) AS `sum1`, count(1) AS `cnt`")),
+                                        logicalOlapScan()
+                                )
+                        ).when(newAgg ->
+                                newAgg.getGroupByExpressions().equals(((Aggregate) agg).getGroupByExpressions())
+                                        && newAgg.getOutputExprsSql()
+                                        .equals("sum(sum0) AS `lsum0`, sum(sum1) AS `lsum1`, sum((grade * cnt)) AS `rsum0`"))
+                );
+    }
+}
diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/EagerSplitTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/EagerSplitTest.java
new file mode 100644
index 0000000000..37e347894f
--- /dev/null
+++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/EagerSplitTest.java
@@ -0,0 +1,102 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.rules.exploration;
+
+import org.apache.doris.common.Pair;
+import org.apache.doris.nereids.trees.expressions.Alias;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Sum;
+import org.apache.doris.nereids.trees.plans.JoinType;
+import org.apache.doris.nereids.trees.plans.algebra.Aggregate;
+import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
+import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
+import org.apache.doris.nereids.util.LogicalPlanBuilder;
+import org.apache.doris.nereids.util.MemoPatternMatchSupported;
+import org.apache.doris.nereids.util.MemoTestUtils;
+import org.apache.doris.nereids.util.PlanChecker;
+import org.apache.doris.nereids.util.PlanConstructor;
+
+import com.google.common.collect.ImmutableList;
+import org.junit.jupiter.api.Test;
+
+class EagerSplitTest implements MemoPatternMatchSupported {
+
+    private final LogicalOlapScan scan1 = new LogicalOlapScan(PlanConstructor.getNextRelationId(),
+            PlanConstructor.student, ImmutableList.of(""));
+    private final LogicalOlapScan scan2 = new LogicalOlapScan(PlanConstructor.getNextRelationId(),
+            PlanConstructor.score, ImmutableList.of(""));
+
+    @Test
+    void singleSum() {
+        LogicalPlan agg = new LogicalPlanBuilder(scan1)
+                .join(scan2, JoinType.INNER_JOIN, Pair.of(0, 0))
+                .aggGroupUsingIndex(ImmutableList.of(0, 4),
+                        ImmutableList.of(
+                                new Alias(new Sum(scan1.getOutput().get(3)), "lsum0"),
+                                new Alias(new Sum(scan2.getOutput().get(2)), "rsum0")
+                        ))
+                .build();
+        PlanChecker.from(MemoTestUtils.createConnectContext(), agg)
+                .applyExploration(EagerSplit.INSTANCE.build())
+                .printlnOrigin()
+                .printlnExploration()
+                .matchesExploration(
+                        logicalAggregate(
+                                logicalJoin(
+                                        logicalAggregate().when(
+                                                a -> a.getOutputExprsSql().equals("id, sum(age) AS `left_sum0`, count(1) AS `left_cnt`")),
+                                        logicalAggregate().when(
+                                                a -> a.getOutputExprsSql().equals("sid, sum(grade) AS `right_sum0`, count(1) AS `right_cnt`"))
+                                )
+                        ).when(newAgg ->
+                                newAgg.getGroupByExpressions().equals(((Aggregate) agg).getGroupByExpressions())
+                                        && newAgg.getOutputExprsSql().equals("sum((left_sum0 * right_cnt)) AS `lsum0`, sum((right_sum0 * left_cnt)) AS `rsum0`"))
+                );
+    }
+
+    @Test
+    void multiSum() {
+        LogicalPlan agg = new LogicalPlanBuilder(scan1)
+                .join(scan2, JoinType.INNER_JOIN, Pair.of(0, 0))
+                .aggGroupUsingIndex(ImmutableList.of(0, 4),
+                        ImmutableList.of(
+                                new Alias(new Sum(scan1.getOutput().get(1)), "lsum0"),
+                                new Alias(new Sum(scan1.getOutput().get(2)), "lsum1"),
+                                new Alias(new Sum(scan1.getOutput().get(3)), "lsum2"),
+                                new Alias(new Sum(scan2.getOutput().get(1)), "rsum0"),
+                                new Alias(new Sum(scan2.getOutput().get(2)), "rsum1")
+                        ))
+                .build();
+
+        PlanChecker.from(MemoTestUtils.createConnectContext(), agg)
+                .applyExploration(EagerSplit.INSTANCE.build())
+                .printlnExploration()
+                .matchesExploration(
+                        logicalAggregate(
+                                logicalJoin(
+                                        logicalAggregate().when(a -> a.getOutputExprsSql()
+                                                .equals("id, sum(gender) AS `left_sum0`, sum(name) AS `left_sum1`, sum(age) AS `left_sum2`, count(1) AS `left_cnt`")),
+                                        logicalAggregate().when(a -> a.getOutputExprsSql()
+                                                .equals("sid, sum(cid) AS `right_sum0`, sum(grade) AS `right_sum1`, count(1) AS `right_cnt`"))
+                                )
+                        ).when(newAgg ->
+                                newAgg.getGroupByExpressions().equals(((Aggregate) agg).getGroupByExpressions())
+                                        && newAgg.getOutputExprsSql()
+                                        .equals("sum((left_sum0 * right_cnt)) AS `lsum0`, sum((left_sum1 * right_cnt)) AS `lsum1`, sum((left_sum2 * right_cnt)) AS `lsum2`, sum((right_sum0 * left_cnt)) AS `rsum0`, sum((right_sum1 * left_cnt)) AS `rsum1`"))
+                );
+    }
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@doris.apache.org
For additional commands, e-mail: commits-help@doris.apache.org