You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@doris.apache.org by ja...@apache.org on 2023/04/12 04:13:15 UTC
[doris] branch master updated: [feat](Nereids): a new CBO rule: Eager Split/GroupByCount (#18556)
This is an automated email from the ASF dual-hosted git repository.
jakevin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/master by this push:
new 39a7a4cc55 [feat](Nereids): a new CBO rule: Eager Split/GroupByCount (#18556)
39a7a4cc55 is described below
commit 39a7a4cc55c9e3b00a6465d99f85db911b92d234
Author: jakevin <ja...@gmail.com>
AuthorDate: Wed Apr 12 12:13:06 2023 +0800
[feat](Nereids): a new CBO rule: Eager Split/GroupByCount (#18556)
---
.../org/apache/doris/nereids/rules/RuleType.java | 4 +
.../nereids/rules/exploration/EagerCount.java | 4 +-
.../nereids/rules/exploration/EagerGroupBy.java | 4 +-
.../rules/exploration/EagerGroupByCount.java | 138 +++++++++++++++++
.../nereids/rules/exploration/EagerSplit.java | 164 +++++++++++++++++++++
.../rules/exploration/EagerGroupByCountTest.java | 101 +++++++++++++
.../nereids/rules/exploration/EagerSplitTest.java | 102 +++++++++++++
7 files changed, 514 insertions(+), 3 deletions(-)
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
index b14d8b5029..9f1234803d 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
@@ -241,6 +241,10 @@ public enum RuleType {
PUSH_DOWN_PROJECT_THROUGH_INNER_JOIN(RuleTypeClass.EXPLORATION),
EAGER_COUNT(RuleTypeClass.EXPLORATION),
EAGER_GROUP_BY(RuleTypeClass.EXPLORATION),
+ EAGER_GROUP_BY_COUNT(RuleTypeClass.EXPLORATION),
+ EAGER_SPLIT(RuleTypeClass.EXPLORATION),
+
+ EXPLORATION_SENTINEL(RuleTypeClass.EXPLORATION),
// implementation rules
LOGICAL_ONE_ROW_RELATION_TO_PHYSICAL_ONE_ROW_RELATION(RuleTypeClass.IMPLEMENTATION),
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/EagerCount.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/EagerCount.java
index add6d9ccc2..cde94b33eb 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/EagerCount.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/EagerCount.java
@@ -49,7 +49,7 @@ import java.util.Set;
* | *
* (x)
* ->
- * aggregate: SUM(x * cnt)
+ * aggregate: SUM(x) * cnt
* |
* join
* | \
@@ -62,7 +62,7 @@ public class EagerCount extends OneExplorationRuleFactory {
@Override
public Rule build() {
- return logicalAggregate(logicalJoin())
+ return logicalAggregate(innerLogicalJoin())
.when(agg -> agg.child().getOtherJoinConjuncts().size() == 0)
.when(agg -> agg.getGroupByExpressions().stream().allMatch(e -> e instanceof Slot))
.when(agg -> agg.getAggregateFunctions().stream()
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/EagerGroupBy.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/EagerGroupBy.java
index 0db10dd1ee..27fcd149b2 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/EagerGroupBy.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/EagerGroupBy.java
@@ -54,13 +54,15 @@ import java.util.stream.Collectors;
* | *
* aggregate: SUM(x) as sum1
* </pre>
+ * After Eager Group By, the new plan can also apply `Eager Count`;
+ * the combination is called `Double Eager`.
*/
public class EagerGroupBy extends OneExplorationRuleFactory {
public static final EagerGroupBy INSTANCE = new EagerGroupBy();
@Override
public Rule build() {
- return logicalAggregate(logicalJoin())
+ return logicalAggregate(innerLogicalJoin())
.when(agg -> agg.child().getOtherJoinConjuncts().size() == 0)
.when(agg -> agg.getAggregateFunctions().stream()
.allMatch(f -> f instanceof Sum
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/EagerGroupByCount.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/EagerGroupByCount.java
new file mode 100644
index 0000000000..c538250538
--- /dev/null
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/EagerGroupByCount.java
@@ -0,0 +1,178 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.rules.exploration;
+
+import org.apache.doris.nereids.rules.Rule;
+import org.apache.doris.nereids.rules.RuleType;
+import org.apache.doris.nereids.trees.expressions.Alias;
+import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.Multiply;
+import org.apache.doris.nereids.trees.expressions.NamedExpression;
+import org.apache.doris.nereids.trees.expressions.Slot;
+import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Count;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Sum;
+import org.apache.doris.nereids.trees.expressions.literal.Literal;
+import org.apache.doris.nereids.trees.plans.GroupPlan;
+import org.apache.doris.nereids.trees.plans.Plan;
+import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
+import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
+
+import com.google.common.collect.ImmutableList;
+
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.HashSet;
+import java.util.LinkedHashMap;
+import java.util.LinkedHashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+
+/**
+ * Eager Group-By-Count, from the paper "Eager aggregation and lazy aggregation".
+ * <pre>
+ * aggregate: SUM(x), SUM(y)
+ * |
+ * join
+ * |   \
+ * |    (y)
+ * (x)
+ * ->
+ * aggregate: SUM(sum1), SUM(y * cnt)
+ * |
+ * join
+ * |   \
+ * |    (y)
+ * aggregate: SUM(x) as sum1, COUNT as cnt
+ * </pre>
+ * The bottom aggregate groups the left child by its group-by and hash join slots, so every left row
+ * that reaches the join stands for {@code cnt} original rows. Right-side values must therefore be
+ * weighted row by row: {@code SUM(y) * cnt} would be wrong whenever one top group spans several
+ * bottom groups with different counts.
+ *
+ * <p>The rule only fires on an inner join without other join conjuncts whose aggregate functions are
+ * all non-distinct {@code SUM(slot)} used directly as outputs, with SUMs over both children.
+ */
+public class EagerGroupByCount extends OneExplorationRuleFactory {
+    public static final EagerGroupByCount INSTANCE = new EagerGroupByCount();
+
+    @Override
+    public Rule build() {
+        return logicalAggregate(innerLogicalJoin())
+                .when(agg -> agg.child().getOtherJoinConjuncts().size() == 0)
+                .when(agg -> agg.getGroupByExpressions().stream().allMatch(e -> e instanceof Slot))
+                .when(agg -> agg.getAggregateFunctions().stream()
+                        .allMatch(f -> f instanceof Sum && !f.isDistinct() && ((Sum) f).child() instanceof Slot))
+                .when(agg -> agg.getOutputExpressions().stream().allMatch(EagerGroupByCount::isRewritableOutput))
+                .then(agg -> {
+                    LogicalJoin<GroupPlan, GroupPlan> join = agg.child();
+                    Set<Slot> leftOutput = new HashSet<>(join.left().getOutput());
+
+                    // One partial SUM per distinct left-side argument, named in first-appearance order.
+                    Map<Slot, Alias> bottomSums = new LinkedHashMap<>();
+                    boolean hasRightSum = false;
+                    for (NamedExpression ne : agg.getOutputExpressions()) {
+                        if (!isSumAlias(ne)) {
+                            continue;
+                        }
+                        Slot arg = sumArgument((Alias) ne);
+                        if (leftOutput.contains(arg)) {
+                            addPartialSum(bottomSums, arg, "sum");
+                        } else {
+                            hasRightSum = true;
+                        }
+                    }
+                    // The rewrite needs SUMs over both children.
+                    if (bottomSums.isEmpty() || !hasRightSum) {
+                        return null;
+                    }
+
+                    // left bottom agg
+                    Alias cnt = new Alias(new Count(Literal.of(1)), "cnt");
+                    LogicalAggregate<GroupPlan> bottomAgg = buildBottomAgg(join.left(), leftOutput,
+                            agg.getGroupByExpressions(), join.getHashJoinConjuncts(),
+                            bottomSums.values(), cnt);
+                    Plan newJoin = join.withChildren(bottomAgg, join.right());
+
+                    // top agg: rewrite every SUM in place so the output order is unchanged
+                    Slot cntSlot = cnt.toSlot();
+                    List<NamedExpression> newOutputExprs = new ArrayList<>();
+                    for (NamedExpression ne : agg.getOutputExpressions()) {
+                        if (!isSumAlias(ne)) {
+                            newOutputExprs.add(ne);
+                            continue;
+                        }
+                        Alias oldSum = (Alias) ne;
+                        Slot arg = sumArgument(oldSum);
+                        // Left SUMs re-aggregate the partial sums; right values are weighted by the
+                        // number of original left rows that their joined bottom row stands for.
+                        Sum newSum = leftOutput.contains(arg)
+                                ? new Sum(bottomSums.get(arg).toSlot())
+                                : new Sum(new Multiply(arg, cntSlot));
+                        newOutputExprs.add(new Alias(oldSum.getExprId(), newSum, oldSum.getName()));
+                    }
+                    return agg.withAggOutput(newOutputExprs).withChildren(newJoin);
+                }).toRule(RuleType.EAGER_GROUP_BY_COUNT);
+    }
+
+    /** Adds a partial SUM for {@code arg} unless one exists, naming it {@code prefix} + current count. */
+    private static void addPartialSum(Map<Slot, Alias> partialSums, Slot arg, String prefix) {
+        if (!partialSums.containsKey(arg)) {
+            partialSums.put(arg, new Alias(new Sum(arg), prefix + partialSums.size()));
+        }
+    }
+
+    /**
+     * Builds the partial aggregate placed on one join child. It groups by the child's slots that appear
+     * in the top group-by or in a hash join conjunct, so the join condition can still be evaluated above
+     * it, and outputs those keys followed by the partial sums and the row count.
+     */
+    private static LogicalAggregate<GroupPlan> buildBottomAgg(GroupPlan child, Set<Slot> childOutput,
+            List<? extends Expression> topGroupBy, List<? extends Expression> hashJoinConjuncts,
+            Collection<Alias> partialSums, Alias cnt) {
+        // LinkedHashSet keeps the generated group-by order deterministic.
+        Set<Slot> groupBy = new LinkedHashSet<>();
+        topGroupBy.stream().map(Slot.class::cast).filter(childOutput::contains).forEach(groupBy::add);
+        hashJoinConjuncts.forEach(conjunct -> conjunct.getInputSlots().stream()
+                .filter(childOutput::contains).forEach(groupBy::add));
+        List<NamedExpression> output = ImmutableList.<NamedExpression>builder()
+                .addAll(groupBy).addAll(partialSums).add(cnt).build();
+        return new LogicalAggregate<>(ImmutableList.<Expression>copyOf(groupBy), output, child);
+    }
+
+    /** Returns whether {@code ne} has the form {@code SUM(...) AS name}. */
+    private static boolean isSumAlias(NamedExpression ne) {
+        return ne instanceof Alias && ((Alias) ne).child() instanceof Sum;
+    }
+
+    /** Returns the slot aggregated by a {@code SUM(slot) AS name} output. */
+    private static Slot sumArgument(Alias sumOutput) {
+        return (Slot) ((Sum) sumOutput.child()).child();
+    }
+
+    /**
+     * Returns whether an aggregate output can be rewritten: either a bare {@code SUM} alias or an
+     * expression without any aggregate function (such as a group-by key). A SUM nested in a larger
+     * expression would be left untouched while its argument is no longer available above the bottom
+     * aggregate, or would miss the count weighting.
+     */
+    private static boolean isRewritableOutput(NamedExpression ne) {
+        return isSumAlias(ne) || !ne.containsType(AggregateFunction.class);
+    }
+}
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/EagerSplit.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/EagerSplit.java
new file mode 100644
index 0000000000..abf6dabad8
--- /dev/null
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/EagerSplit.java
@@ -0,0 +1,184 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.rules.exploration;
+
+import org.apache.doris.nereids.rules.Rule;
+import org.apache.doris.nereids.rules.RuleType;
+import org.apache.doris.nereids.trees.expressions.Alias;
+import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.Multiply;
+import org.apache.doris.nereids.trees.expressions.NamedExpression;
+import org.apache.doris.nereids.trees.expressions.Slot;
+import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Count;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Sum;
+import org.apache.doris.nereids.trees.expressions.literal.Literal;
+import org.apache.doris.nereids.trees.plans.GroupPlan;
+import org.apache.doris.nereids.trees.plans.Plan;
+import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
+import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
+
+import com.google.common.collect.ImmutableList;
+
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.HashSet;
+import java.util.LinkedHashMap;
+import java.util.LinkedHashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+
+/**
+ * Eager Split, from the paper "Eager aggregation and lazy aggregation".
+ * <pre>
+ * aggregate: SUM(x), SUM(y)
+ * |
+ * join
+ * |   \
+ * |    (y)
+ * (x)
+ * ->
+ * aggregate: SUM(sum1 * cnt2), SUM(sum2 * cnt1)
+ * |
+ * join
+ * |   \
+ * |    aggregate: SUM(y) as sum2, COUNT as cnt2
+ * aggregate: SUM(x) as sum1, COUNT as cnt1
+ * </pre>
+ * Each child is pre-aggregated by its group-by and hash join slots, so a joined row stands for
+ * {@code cnt1} left rows and {@code cnt2} right rows. Every partial sum is multiplied by the other
+ * side's count before the final SUM: {@code SUM(sum1) * cnt2} would be wrong whenever one top group
+ * spans several bottom groups with different counts.
+ *
+ * <p>The rule only fires on an inner join without other join conjuncts whose aggregate functions are
+ * all non-distinct {@code SUM(slot)} used directly as outputs, with SUMs over both children.
+ */
+public class EagerSplit extends OneExplorationRuleFactory {
+    public static final EagerSplit INSTANCE = new EagerSplit();
+
+    @Override
+    public Rule build() {
+        return logicalAggregate(innerLogicalJoin())
+                // other join conjuncts may reference slots that the bottom aggregates do not output
+                .when(agg -> agg.child().getOtherJoinConjuncts().size() == 0)
+                .when(agg -> agg.getGroupByExpressions().stream().allMatch(e -> e instanceof Slot))
+                .when(agg -> agg.getAggregateFunctions().stream()
+                        .allMatch(f -> f instanceof Sum && !f.isDistinct() && ((Sum) f).child() instanceof Slot))
+                .when(agg -> agg.getOutputExpressions().stream().allMatch(EagerSplit::isRewritableOutput))
+                .then(agg -> {
+                    LogicalJoin<GroupPlan, GroupPlan> join = agg.child();
+                    Set<Slot> leftOutput = new HashSet<>(join.left().getOutput());
+                    Set<Slot> rightOutput = new HashSet<>(join.right().getOutput());
+
+                    // One partial SUM per distinct argument on each side, named in first-appearance order.
+                    Map<Slot, Alias> leftBottomSums = new LinkedHashMap<>();
+                    Map<Slot, Alias> rightBottomSums = new LinkedHashMap<>();
+                    for (NamedExpression ne : agg.getOutputExpressions()) {
+                        if (!isSumAlias(ne)) {
+                            continue;
+                        }
+                        Slot arg = sumArgument((Alias) ne);
+                        if (leftOutput.contains(arg)) {
+                            addPartialSum(leftBottomSums, arg, "left_sum");
+                        } else {
+                            addPartialSum(rightBottomSums, arg, "right_sum");
+                        }
+                    }
+                    // The rewrite needs SUMs over both children.
+                    if (leftBottomSums.isEmpty() || rightBottomSums.isEmpty()) {
+                        return null;
+                    }
+
+                    // bottom aggs
+                    Alias leftCnt = new Alias(new Count(Literal.of(1)), "left_cnt");
+                    Alias rightCnt = new Alias(new Count(Literal.of(1)), "right_cnt");
+                    LogicalAggregate<GroupPlan> leftBottomAgg = buildBottomAgg(join.left(), leftOutput,
+                            agg.getGroupByExpressions(), join.getHashJoinConjuncts(),
+                            leftBottomSums.values(), leftCnt);
+                    LogicalAggregate<GroupPlan> rightBottomAgg = buildBottomAgg(join.right(), rightOutput,
+                            agg.getGroupByExpressions(), join.getHashJoinConjuncts(),
+                            rightBottomSums.values(), rightCnt);
+                    Plan newJoin = join.withChildren(leftBottomAgg, rightBottomAgg);
+
+                    // top agg: rewrite every SUM in place so the output order is unchanged
+                    Slot leftCntSlot = leftCnt.toSlot();
+                    Slot rightCntSlot = rightCnt.toSlot();
+                    List<NamedExpression> newOutputExprs = new ArrayList<>();
+                    for (NamedExpression ne : agg.getOutputExpressions()) {
+                        if (!isSumAlias(ne)) {
+                            newOutputExprs.add(ne);
+                            continue;
+                        }
+                        Alias oldSum = (Alias) ne;
+                        Slot arg = sumArgument(oldSum);
+                        // Weight each partial sum by the other side's row count before summing.
+                        Sum newSum = leftOutput.contains(arg)
+                                ? new Sum(new Multiply(leftBottomSums.get(arg).toSlot(), rightCntSlot))
+                                : new Sum(new Multiply(rightBottomSums.get(arg).toSlot(), leftCntSlot));
+                        newOutputExprs.add(new Alias(oldSum.getExprId(), newSum, oldSum.getName()));
+                    }
+                    return agg.withAggOutput(newOutputExprs).withChildren(newJoin);
+                }).toRule(RuleType.EAGER_SPLIT);
+    }
+
+    /** Adds a partial SUM for {@code arg} unless one exists, naming it {@code prefix} + current count. */
+    private static void addPartialSum(Map<Slot, Alias> partialSums, Slot arg, String prefix) {
+        if (!partialSums.containsKey(arg)) {
+            partialSums.put(arg, new Alias(new Sum(arg), prefix + partialSums.size()));
+        }
+    }
+
+    /**
+     * Builds the partial aggregate placed on one join child. It groups by the child's slots that appear
+     * in the top group-by or in a hash join conjunct, so the join condition can still be evaluated above
+     * it, and outputs those keys followed by the partial sums and the row count.
+     */
+    private static LogicalAggregate<GroupPlan> buildBottomAgg(GroupPlan child, Set<Slot> childOutput,
+            List<? extends Expression> topGroupBy, List<? extends Expression> hashJoinConjuncts,
+            Collection<Alias> partialSums, Alias cnt) {
+        // LinkedHashSet keeps the generated group-by order deterministic.
+        Set<Slot> groupBy = new LinkedHashSet<>();
+        topGroupBy.stream().map(Slot.class::cast).filter(childOutput::contains).forEach(groupBy::add);
+        hashJoinConjuncts.forEach(conjunct -> conjunct.getInputSlots().stream()
+                .filter(childOutput::contains).forEach(groupBy::add));
+        List<NamedExpression> output = ImmutableList.<NamedExpression>builder()
+                .addAll(groupBy).addAll(partialSums).add(cnt).build();
+        return new LogicalAggregate<>(ImmutableList.<Expression>copyOf(groupBy), output, child);
+    }
+
+    /** Returns whether {@code ne} has the form {@code SUM(...) AS name}. */
+    private static boolean isSumAlias(NamedExpression ne) {
+        return ne instanceof Alias && ((Alias) ne).child() instanceof Sum;
+    }
+
+    /** Returns the slot aggregated by a {@code SUM(slot) AS name} output. */
+    private static Slot sumArgument(Alias sumOutput) {
+        return (Slot) ((Sum) sumOutput.child()).child();
+    }
+
+    /**
+     * Returns whether an aggregate output can be rewritten: either a bare {@code SUM} alias or an
+     * expression without any aggregate function (such as a group-by key). A SUM nested in a larger
+     * expression would be left untouched while its argument is no longer available above the bottom
+     * aggregates, or would miss the count weighting.
+     */
+    private static boolean isRewritableOutput(NamedExpression ne) {
+        return isSumAlias(ne) || !ne.containsType(AggregateFunction.class);
+    }
+}
diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/EagerGroupByCountTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/EagerGroupByCountTest.java
new file mode 100644
index 0000000000..de132d22d2
--- /dev/null
+++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/EagerGroupByCountTest.java
@@ -0,0 +1,132 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.rules.exploration;
+
+import org.apache.doris.common.Pair;
+import org.apache.doris.nereids.trees.expressions.Alias;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Sum;
+import org.apache.doris.nereids.trees.plans.JoinType;
+import org.apache.doris.nereids.trees.plans.algebra.Aggregate;
+import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
+import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
+import org.apache.doris.nereids.util.LogicalPlanBuilder;
+import org.apache.doris.nereids.util.MemoPatternMatchSupported;
+import org.apache.doris.nereids.util.MemoTestUtils;
+import org.apache.doris.nereids.util.PlanChecker;
+import org.apache.doris.nereids.util.PlanConstructor;
+
+import com.google.common.collect.ImmutableList;
+import org.junit.jupiter.api.Test;
+
+class EagerGroupByCountTest implements MemoPatternMatchSupported {
+
+    private final LogicalOlapScan scan1 = new LogicalOlapScan(PlanConstructor.getNextRelationId(),
+            PlanConstructor.student, ImmutableList.of(""));
+    private final LogicalOlapScan scan2 = new LogicalOlapScan(PlanConstructor.getNextRelationId(),
+            PlanConstructor.score, ImmutableList.of(""));
+
+    @Test
+    void singleSum() {
+        LogicalPlan agg = new LogicalPlanBuilder(scan1)
+                .join(scan2, JoinType.INNER_JOIN, Pair.of(0, 0))
+                .aggGroupUsingIndex(ImmutableList.of(0, 4),
+                        ImmutableList.of(
+                                new Alias(new Sum(scan1.getOutput().get(3)), "lsum0"),
+                                new Alias(new Sum(scan2.getOutput().get(2)), "rsum0")
+                        ))
+                .build();
+        PlanChecker.from(MemoTestUtils.createConnectContext(), agg)
+                .applyExploration(EagerGroupByCount.INSTANCE.build())
+                .printlnOrigin()
+                .printlnExploration()
+                .matchesExploration(
+                        logicalAggregate(
+                                logicalJoin(
+                                        logicalAggregate().when(bottomAgg -> bottomAgg.getOutputExprsSql()
+                                                .equals("id, sum(age) AS `sum0`, count(1) AS `cnt`")),
+                                        logicalOlapScan()
+                                )
+                        ).when(newAgg ->
+                                newAgg.getGroupByExpressions().equals(((Aggregate) agg).getGroupByExpressions())
+                                        && newAgg.getOutputExprsSql()
+                                        .equals("sum(sum0) AS `lsum0`, sum((grade * cnt)) AS `rsum0`"))
+                );
+    }
+
+    @Test
+    void multiSum() {
+        LogicalPlan agg = new LogicalPlanBuilder(scan1)
+                .join(scan2, JoinType.INNER_JOIN, Pair.of(0, 0))
+                .aggGroupUsingIndex(ImmutableList.of(0, 4),
+                        ImmutableList.of(
+                                new Alias(new Sum(scan1.getOutput().get(1)), "lsum0"),
+                                new Alias(new Sum(scan1.getOutput().get(2)), "lsum1"),
+                                new Alias(new Sum(scan1.getOutput().get(3)), "lsum2"),
+                                new Alias(new Sum(scan2.getOutput().get(1)), "rsum0"),
+                                new Alias(new Sum(scan2.getOutput().get(2)), "rsum1")
+                        ))
+                .build();
+
+        PlanChecker.from(MemoTestUtils.createConnectContext(), agg)
+                .applyExploration(EagerGroupByCount.INSTANCE.build())
+                .printlnOrigin()
+                .printlnExploration()
+                .matchesExploration(
+                        logicalAggregate(
+                                logicalJoin(
+                                        logicalAggregate().when(cntAgg -> cntAgg.getOutputExprsSql()
+                                                .equals("id, sum(gender) AS `sum0`, sum(name) AS `sum1`, "
+                                                        + "sum(age) AS `sum2`, count(1) AS `cnt`")),
+                                        logicalOlapScan()
+                                )
+                        ).when(newAgg ->
+                                newAgg.getGroupByExpressions().equals(((Aggregate) agg).getGroupByExpressions())
+                                        && newAgg.getOutputExprsSql()
+                                        .equals("sum(sum0) AS `lsum0`, sum(sum1) AS `lsum1`, "
+                                                + "sum(sum2) AS `lsum2`, sum((cid * cnt)) AS `rsum0`, "
+                                                + "sum((grade * cnt)) AS `rsum1`"))
+                );
+    }
+
+    @Test
+    void duplicateSum() {
+        // The same SUM appears twice; both outputs must reuse a single partial sum.
+        LogicalPlan agg = new LogicalPlanBuilder(scan1)
+                .join(scan2, JoinType.INNER_JOIN, Pair.of(0, 0))
+                .aggGroupUsingIndex(ImmutableList.of(0, 4),
+                        ImmutableList.of(
+                                new Alias(new Sum(scan1.getOutput().get(3)), "lsum0"),
+                                new Alias(new Sum(scan1.getOutput().get(3)), "lsum1"),
+                                new Alias(new Sum(scan2.getOutput().get(2)), "rsum0")
+                        ))
+                .build();
+        PlanChecker.from(MemoTestUtils.createConnectContext(), agg)
+                .applyExploration(EagerGroupByCount.INSTANCE.build())
+                .matchesExploration(
+                        logicalAggregate(
+                                logicalJoin(
+                                        logicalAggregate().when(bottomAgg -> bottomAgg.getOutputExprsSql()
+                                                .equals("id, sum(age) AS `sum0`, count(1) AS `cnt`")),
+                                        logicalOlapScan()
+                                )
+                        ).when(newAgg -> newAgg.getOutputExprsSql()
+                                .equals("sum(sum0) AS `lsum0`, sum(sum0) AS `lsum1`, "
+                                        + "sum((grade * cnt)) AS `rsum0`"))
+                );
+    }
+}
diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/EagerSplitTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/EagerSplitTest.java
new file mode 100644
index 0000000000..37e347894f
--- /dev/null
+++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/EagerSplitTest.java
@@ -0,0 +1,110 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.rules.exploration;
+
+import org.apache.doris.common.Pair;
+import org.apache.doris.nereids.trees.expressions.Alias;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Sum;
+import org.apache.doris.nereids.trees.plans.JoinType;
+import org.apache.doris.nereids.trees.plans.algebra.Aggregate;
+import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
+import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
+import org.apache.doris.nereids.util.LogicalPlanBuilder;
+import org.apache.doris.nereids.util.MemoPatternMatchSupported;
+import org.apache.doris.nereids.util.MemoTestUtils;
+import org.apache.doris.nereids.util.PlanChecker;
+import org.apache.doris.nereids.util.PlanConstructor;
+
+import com.google.common.collect.ImmutableList;
+import org.junit.jupiter.api.Test;
+
+class EagerSplitTest implements MemoPatternMatchSupported {
+
+    private final LogicalOlapScan scan1 = new LogicalOlapScan(PlanConstructor.getNextRelationId(),
+            PlanConstructor.student, ImmutableList.of(""));
+    private final LogicalOlapScan scan2 = new LogicalOlapScan(PlanConstructor.getNextRelationId(),
+            PlanConstructor.score, ImmutableList.of(""));
+
+    @Test
+    void singleSum() {
+        LogicalPlan agg = new LogicalPlanBuilder(scan1)
+                .join(scan2, JoinType.INNER_JOIN, Pair.of(0, 0))
+                .aggGroupUsingIndex(ImmutableList.of(0, 4),
+                        ImmutableList.of(
+                                new Alias(new Sum(scan1.getOutput().get(3)), "lsum0"),
+                                new Alias(new Sum(scan2.getOutput().get(2)), "rsum0")
+                        ))
+                .build();
+        PlanChecker.from(MemoTestUtils.createConnectContext(), agg)
+                .applyExploration(EagerSplit.INSTANCE.build())
+                .printlnOrigin()
+                .printlnExploration()
+                .matchesExploration(
+                        logicalAggregate(
+                                logicalJoin(
+                                        logicalAggregate().when(a -> a.getOutputExprsSql()
+                                                .equals("id, sum(age) AS `left_sum0`, count(1) AS `left_cnt`")),
+                                        logicalAggregate().when(a -> a.getOutputExprsSql()
+                                                .equals("sid, sum(grade) AS `right_sum0`, count(1) AS `right_cnt`"))
+                                )
+                        ).when(newAgg ->
+                                newAgg.getGroupByExpressions().equals(((Aggregate) agg).getGroupByExpressions())
+                                        && newAgg.getOutputExprsSql()
+                                        .equals("sum((left_sum0 * right_cnt)) AS `lsum0`, "
+                                                + "sum((right_sum0 * left_cnt)) AS `rsum0`"))
+                );
+    }
+
+    @Test
+    void multiSum() {
+        LogicalPlan agg = new LogicalPlanBuilder(scan1)
+                .join(scan2, JoinType.INNER_JOIN, Pair.of(0, 0))
+                .aggGroupUsingIndex(ImmutableList.of(0, 4),
+                        ImmutableList.of(
+                                new Alias(new Sum(scan1.getOutput().get(1)), "lsum0"),
+                                new Alias(new Sum(scan1.getOutput().get(2)), "lsum1"),
+                                new Alias(new Sum(scan1.getOutput().get(3)), "lsum2"),
+                                new Alias(new Sum(scan2.getOutput().get(1)), "rsum0"),
+                                new Alias(new Sum(scan2.getOutput().get(2)), "rsum1")
+                        ))
+                .build();
+
+        PlanChecker.from(MemoTestUtils.createConnectContext(), agg)
+                .applyExploration(EagerSplit.INSTANCE.build())
+                .printlnExploration()
+                .matchesExploration(
+                        logicalAggregate(
+                                logicalJoin(
+                                        logicalAggregate().when(a -> a.getOutputExprsSql()
+                                                .equals("id, sum(gender) AS `left_sum0`, sum(name) AS `left_sum1`, "
+                                                        + "sum(age) AS `left_sum2`, count(1) AS `left_cnt`")),
+                                        logicalAggregate().when(a -> a.getOutputExprsSql()
+                                                .equals("sid, sum(cid) AS `right_sum0`, sum(grade) AS `right_sum1`, "
+                                                        + "count(1) AS `right_cnt`"))
+                                )
+                        ).when(newAgg ->
+                                newAgg.getGroupByExpressions().equals(((Aggregate) agg).getGroupByExpressions())
+                                        && newAgg.getOutputExprsSql()
+                                        .equals("sum((left_sum0 * right_cnt)) AS `lsum0`, "
+                                                + "sum((left_sum1 * right_cnt)) AS `lsum1`, "
+                                                + "sum((left_sum2 * right_cnt)) AS `lsum2`, "
+                                                + "sum((right_sum0 * left_cnt)) AS `rsum0`, "
+                                                + "sum((right_sum1 * left_cnt)) AS `rsum1`"))
+                );
+    }
+}
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@doris.apache.org
For additional commands, e-mail: commits-help@doris.apache.org