You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@doris.apache.org by hu...@apache.org on 2022/07/27 07:16:20 UTC
[doris] branch master updated: [feature] (Nereids) add rule to push down predicate through aggregate (#11162)
This is an automated email from the ASF dual-hosted git repository.
huajianlan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/master by this push:
new daf2e27202 [feature] (Nereids) add rule to push down predicate through aggregate (#11162)
daf2e27202 is described below
commit daf2e2720250416c2908380794c51d8803b7a629
Author: minghong <en...@gmail.com>
AuthorDate: Wed Jul 27 15:16:15 2022 +0800
[feature] (Nereids) add rule to push down predicate through aggregate (#11162)
Add a rule that pushes filter predicates down through an aggregation node:
- add PushPredicateThroughAggregation.java
- add unit tests (PushDownPredicateThroughAggregationTest) for the rule
For example:
```
Logical plan tree:
any_node
|
filter (a>0 and b>0)
|
group by(a, c)
|
scan
```
transformed to:
```
any_node
|
upper filter (b>0)
|
group by(a, c)
|
bottom filter (a>0)
|
scan
```
Note:
'a>0' can be pushed down because 'a' is one of the group-by keys,
but 'b>0' cannot be pushed down because 'b' is not a group-by key.
---
.../org/apache/doris/nereids/rules/RuleType.java | 1 +
.../logical/PushPredicateThroughAggregation.java | 109 +++++++++++
.../PushDownPredicateThroughAggregationTest.java | 206 +++++++++++++++++++++
3 files changed, 316 insertions(+)
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
index 2485fad8a4..fa6df6d5f7 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
@@ -43,6 +43,7 @@ public enum RuleType {
COLUMN_PRUNE_PROJECTION(RuleTypeClass.REWRITE),
// predicate push down rules
PUSH_DOWN_PREDICATE_THROUGH_JOIN(RuleTypeClass.REWRITE),
+ PUSH_DOWN_PREDICATE_THROUGH_AGGREGATION(RuleTypeClass.REWRITE),
// column prune rules,
COLUMN_PRUNE_AGGREGATION_CHILD(RuleTypeClass.REWRITE),
COLUMN_PRUNE_FILTER_CHILD(RuleTypeClass.REWRITE),
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/PushPredicateThroughAggregation.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/PushPredicateThroughAggregation.java
new file mode 100644
index 0000000000..bc4155bdec
--- /dev/null
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/PushPredicateThroughAggregation.java
@@ -0,0 +1,109 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.rules.rewrite.logical;
+
+import org.apache.doris.nereids.rules.Rule;
+import org.apache.doris.nereids.rules.RuleType;
+import org.apache.doris.nereids.rules.rewrite.OneRewriteRuleFactory;
+import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.Slot;
+import org.apache.doris.nereids.trees.expressions.visitor.SlotExtractor;
+import org.apache.doris.nereids.trees.plans.GroupPlan;
+import org.apache.doris.nereids.trees.plans.Plan;
+import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
+import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
+import org.apache.doris.nereids.util.ExpressionUtils;
+
+import com.google.common.collect.Lists;
+
+import java.util.HashSet;
+import java.util.List;
+import java.util.Set;
+
+
+/**
+ * Rewrite rule that pushes the conjuncts of a LogicalFilter below its LogicalAggregate child.
+ *
+ * A conjunct is moved into a new filter under the aggregate only if it references at least
+ * one slot and every slot it references is a group-by expression that is a plain slot.
+ * All other conjuncts stay in a filter above the aggregate. For example:
+ *        any_node                          any_node
+ *           |                                 |
+ *   filter (a>0 and b>0)              upper filter (b>0)
+ *           |                                 |
+ *     group by(a, c)         ==>        group by(a, c)
+ *           |                                 |
+ *         scan                        bottom filter (a>0)
+ *                                             |
+ *                                           scan
+ * 'a>0' can be pushed down because 'a' is a group-by key; 'b>0' cannot, because 'b' is not.
+ *
+ * Conjuncts without any slot (e.g. '1 = 0') are never pushed down: an aggregate without
+ * group-by keys is expected to emit a row even for empty input, so filtering its input
+ * instead of its output would change the result.
+ */
+public class PushPredicateThroughAggregation extends OneRewriteRuleFactory {
+
+    /** Builds the rule; it matches a filter whose direct child is an aggregate. */
+    @Override
+    public Rule build() {
+        return logicalFilter(logicalAggregate()).then(filter -> {
+            LogicalAggregate<GroupPlan> aggregate = filter.child();
+            Set<Slot> groupBySlots = new HashSet<>();
+            for (Expression groupByExpression : aggregate.getGroupByExpressionList()) {
+                if (groupByExpression instanceof Slot) {
+                    groupBySlots.add((Slot) groupByExpression);
+                }
+            }
+            List<Expression> pushDownPredicates = Lists.newArrayList();
+            List<Expression> filterPredicates = Lists.newArrayList();
+            ExpressionUtils.extractConjunct(filter.getPredicates()).forEach(conjunct -> {
+                Set<Slot> conjunctSlots = SlotExtractor.extractSlot(conjunct);
+                // containsAll() is vacuously true for a slot-free conjunct, so exclude it explicitly
+                if (!conjunctSlots.isEmpty() && groupBySlots.containsAll(conjunctSlots)) {
+                    pushDownPredicates.add(conjunct);
+                } else {
+                    filterPredicates.add(conjunct);
+                }
+            });
+            return pushDownPredicate(filter, aggregate, pushDownPredicates, filterPredicates);
+        }).toRule(RuleType.PUSH_DOWN_PREDICATE_THROUGH_AGGREGATION);
+    }
+
+    /**
+     * Assembles the rewritten plan from the split conjuncts.
+     *
+     * @param filter the matched filter, returned as is when nothing can be pushed down
+     * @param aggregate the aggregate directly below the matched filter
+     * @param pushDownPredicates conjuncts that go into a new filter below the aggregate
+     * @param filterPredicates conjuncts that have to stay above the aggregate
+     * @return the matched filter, or the rewritten aggregate, optionally under a new upper filter
+     */
+    private Plan pushDownPredicate(LogicalFilter filter, LogicalAggregate aggregate,
+            List<Expression> pushDownPredicates, List<Expression> filterPredicates) {
+        if (pushDownPredicates.isEmpty()) {
+            // nothing pushed down, just return origin plan
+            return filter;
+        }
+        Plan bottomFilter = new LogicalFilter<>(ExpressionUtils.and(pushDownPredicates), (Plan) aggregate.child(0));
+        Plan newAggregate = aggregate.withChildren(Lists.newArrayList(bottomFilter));
+        // if all predicates are pushed down, filter and aggregate are simply exchanged
+        return filterPredicates.isEmpty() ? newAggregate
+                : new LogicalFilter<>(ExpressionUtils.and(filterPredicates), newAggregate);
+    }
+}
diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/PushDownPredicateThroughAggregationTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/PushDownPredicateThroughAggregationTest.java
new file mode 100644
index 0000000000..9980f246ad
--- /dev/null
+++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/PushDownPredicateThroughAggregationTest.java
@@ -0,0 +1,206 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.rules.rewrite.logical;
+
+
+
+import org.apache.doris.catalog.AggregateType;
+import org.apache.doris.catalog.Column;
+import org.apache.doris.catalog.Table;
+import org.apache.doris.catalog.Type;
+import org.apache.doris.nereids.memo.Group;
+import org.apache.doris.nereids.memo.GroupExpression;
+import org.apache.doris.nereids.memo.Memo;
+import org.apache.doris.nereids.trees.expressions.Add;
+import org.apache.doris.nereids.trees.expressions.And;
+import org.apache.doris.nereids.trees.expressions.EqualTo;
+import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.GreaterThan;
+import org.apache.doris.nereids.trees.expressions.LessThanEqual;
+import org.apache.doris.nereids.trees.expressions.Literal;
+import org.apache.doris.nereids.trees.expressions.NamedExpression;
+import org.apache.doris.nereids.trees.expressions.Slot;
+import org.apache.doris.nereids.trees.plans.Plan;
+import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
+import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
+import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
+import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
+import org.apache.doris.nereids.util.ExpressionUtils;
+import org.apache.doris.nereids.util.PlanRewriter;
+import org.apache.doris.qe.ConnectContext;
+
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.Lists;
+import org.junit.Assert;
+import org.junit.Test;
+
+import java.util.List;
+
+/**
+ * Unit tests for {@link PushPredicateThroughAggregation}.
+ */
+public class PushDownPredicateThroughAggregationTest {
+
+    /**
+     * The only conjunct references a group-by key, so the whole filter moves below the aggregate.
+     *
+     * origin plan:
+     *                project
+     *                  |
+     *           filter gender>1
+     *                  |
+     *    aggregation group by age, gender
+     *                  |
+     *            scan(student)
+     *
+     * transformed plan:
+     *                project
+     *                  |
+     *    aggregation group by age, gender
+     *                  |
+     *           filter gender>1
+     *                  |
+     *            scan(student)
+     */
+    @Test
+    public void pushDownPredicateOneFilterTest() {
+        Plan scan = newStudentScan();
+        Slot gender = scan.getOutput().get(1);
+        Slot age = scan.getOutput().get(3);
+
+        List<Expression> groupByKeys = Lists.newArrayList(age, gender);
+        List<NamedExpression> outputExpressionList = Lists.newArrayList(gender, age);
+        Plan aggregation = new LogicalAggregate<>(groupByKeys, outputExpressionList, scan);
+        Plan filter = new LogicalFilter<>(new GreaterThan(gender, Literal.of(1)), aggregation);
+        Plan root = new LogicalProject<>(Lists.newArrayList(gender), filter);
+
+        Memo memo = rewrite(root);
+
+        // project -> aggregate: the only conjunct was pushed down, so no upper filter remains
+        GroupExpression aggregateExpression = firstChild(memo.getRoot().getLogicalExpression());
+        Assert.assertTrue(aggregateExpression.getPlan() instanceof LogicalAggregate);
+
+        // aggregate -> bottom filter (gender > 1)
+        GroupExpression filterExpression = firstChild(aggregateExpression);
+        Plan bottomFilter = filterExpression.getPlan();
+        Assert.assertTrue(bottomFilter instanceof LogicalFilter);
+        Expression greater = ((LogicalFilter<?>) bottomFilter).getPredicates();
+        Assert.assertTrue(greater instanceof GreaterThan);
+        Assert.assertTrue(greater.child(0) instanceof Slot);
+        Assert.assertEquals("gender", ((Slot) greater.child(0)).getName());
+
+        // bottom filter -> scan
+        Assert.assertTrue(firstChild(filterExpression).getPlan() instanceof LogicalOlapScan);
+    }
+
+    /**
+     * Only the conjuncts on group-by keys are pushed down; name="abc" stays above the aggregate.
+     *
+     * origin plan:
+     *                project
+     *                  |
+     *   filter gender>1 and (gender+10)<=100 and name="abc"
+     *                  |
+     *    aggregation group by age, gender
+     *                  |
+     *            scan(student)
+     *
+     * transformed plan:
+     *                project
+     *                  |
+     *           filter name="abc"
+     *                  |
+     *    aggregation group by age, gender
+     *                  |
+     *   filter gender>1 and (gender+10)<=100
+     *                  |
+     *            scan(student)
+     */
+    @Test
+    public void pushDownPredicateTwoFilterTest() {
+        Plan scan = newStudentScan();
+        Slot gender = scan.getOutput().get(1);
+        Slot name = scan.getOutput().get(2);
+        Slot age = scan.getOutput().get(3);
+
+        List<Expression> groupByKeys = Lists.newArrayList(age, gender);
+        List<NamedExpression> outputExpressionList = Lists.newArrayList(gender, age);
+        Plan aggregation = new LogicalAggregate<>(groupByKeys, outputExpressionList, scan);
+        Expression filterPredicate = ExpressionUtils.and(
+                new GreaterThan(gender, Literal.of(1)),
+                new LessThanEqual(new Add(gender, Literal.of(10)), Literal.of(100)),
+                new EqualTo(name, Literal.of("abc")));
+        Plan filter = new LogicalFilter<>(filterPredicate, aggregation);
+        Plan root = new LogicalProject<>(Lists.newArrayList(gender), filter);
+
+        Memo memo = rewrite(root);
+
+        // project -> upper filter (name = "abc"): 'name' is not a group-by key
+        GroupExpression upperFilterExpression = firstChild(memo.getRoot().getLogicalExpression());
+        Plan upperFilter = upperFilterExpression.getPlan();
+        Assert.assertTrue(upperFilter instanceof LogicalFilter);
+        Expression upperPredicates = ((LogicalFilter<?>) upperFilter).getPredicates();
+        Assert.assertTrue(upperPredicates instanceof EqualTo);
+        Assert.assertTrue(upperPredicates.child(0) instanceof Slot);
+
+        // upper filter -> aggregate
+        GroupExpression aggregateExpression = firstChild(upperFilterExpression);
+        Assert.assertTrue(aggregateExpression.getPlan() instanceof LogicalAggregate);
+
+        // aggregate -> bottom filter (gender > 1 and gender + 10 <= 100)
+        GroupExpression bottomFilterExpression = firstChild(aggregateExpression);
+        Plan bottomFilter = bottomFilterExpression.getPlan();
+        Assert.assertTrue(bottomFilter instanceof LogicalFilter);
+        Expression bottomPredicates = ((LogicalFilter<?>) bottomFilter).getPredicates();
+        Assert.assertTrue(bottomPredicates instanceof And);
+        Assert.assertEquals(2, bottomPredicates.children().size());
+        Expression greater = bottomPredicates.child(0);
+        Assert.assertTrue(greater instanceof GreaterThan);
+        Assert.assertTrue(greater.child(0) instanceof Slot);
+        Assert.assertEquals("gender", ((Slot) greater.child(0)).getName());
+        Expression less = bottomPredicates.child(1);
+        Assert.assertTrue(less instanceof LessThanEqual);
+        Assert.assertTrue(less.child(0) instanceof Add);
+
+        // bottom filter -> scan
+        Assert.assertTrue(firstChild(bottomFilterExpression).getPlan() instanceof LogicalOlapScan);
+    }
+
+    /**
+     * Builds a scan of an OLAP table "student" with the columns id, gender, name and age.
+     */
+    private Plan newStudentScan() {
+        Table student = new Table(0L, "student", Table.TableType.OLAP,
+                ImmutableList.<Column>of(
+                        new Column("id", Type.INT, true, AggregateType.NONE, "0", ""),
+                        new Column("gender", Type.INT, false, AggregateType.NONE, "0", ""),
+                        new Column("name", Type.STRING, true, AggregateType.NONE, "", ""),
+                        new Column("age", Type.INT, true, AggregateType.NONE, "", "")));
+        return new LogicalOlapScan(student, ImmutableList.of("student"));
+    }
+
+    /** Returns the logical expression of the first child group of {@code parent}. */
+    private GroupExpression firstChild(GroupExpression parent) {
+        return parent.child(0).getLogicalExpression();
+    }
+
+    /** Applies {@link PushPredicateThroughAggregation} top-down and returns the resulting memo. */
+    private Memo rewrite(Plan plan) {
+        return PlanRewriter.topDownRewriteMemo(plan, new ConnectContext(), new PushPredicateThroughAggregation());
+    }
+}
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@doris.apache.org
For additional commands, e-mail: commits-help@doris.apache.org