You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@doris.apache.org by hu...@apache.org on 2022/07/27 07:16:20 UTC

[doris] branch master updated: [feature] (Nereids) add rule to push down predicate through aggregate (#11162)

This is an automated email from the ASF dual-hosted git repository.

huajianlan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git


The following commit(s) were added to refs/heads/master by this push:
     new daf2e27202 [feature] (Nereids) add rule to push down predicate through aggregate (#11162)
daf2e27202 is described below

commit daf2e2720250416c2908380794c51d8803b7a629
Author: minghong <en...@gmail.com>
AuthorDate: Wed Jul 27 15:16:15 2022 +0800

    [feature] (Nereids) add rule to push down predicate through aggregate (#11162)
    
    add rule to push predicates down to aggregation node
    
    add PushPredicateThroughAggregation.java
    add ut for PushPredicateThroughAggregation
    For example:
    
    ```
      Logical plan tree:
                      any_node
                        |
                     filter (a>0 and b>0)
                        |
                     group by(a, c)
                        |
                      scan
    ```
    transformed to:
    ```
                      any_node
                        |
                   upper filter (b>0)
                        |
                     group by(a, c)
                        |
                   bottom filter (a>0)
                        |
                      scan
    ```
    
    Note:
    'a>0' can be pushed down, because 'a' is in the group by keys;
    but 'b>0' cannot be pushed down, because 'b' is not in the group by keys.
---
 .../org/apache/doris/nereids/rules/RuleType.java   |   1 +
 .../logical/PushPredicateThroughAggregation.java   | 109 +++++++++++
 .../PushDownPredicateThroughAggregationTest.java   | 206 +++++++++++++++++++++
 3 files changed, 316 insertions(+)

diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
index 2485fad8a4..fa6df6d5f7 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
@@ -43,6 +43,7 @@ public enum RuleType {
     COLUMN_PRUNE_PROJECTION(RuleTypeClass.REWRITE),
     // predicate push down rules
     PUSH_DOWN_PREDICATE_THROUGH_JOIN(RuleTypeClass.REWRITE),
+    PUSH_DOWN_PREDICATE_THROUGH_AGGREGATION(RuleTypeClass.REWRITE),
     // column prune rules,
     COLUMN_PRUNE_AGGREGATION_CHILD(RuleTypeClass.REWRITE),
     COLUMN_PRUNE_FILTER_CHILD(RuleTypeClass.REWRITE),
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/PushPredicateThroughAggregation.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/PushPredicateThroughAggregation.java
new file mode 100644
index 0000000000..bc4155bdec
--- /dev/null
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/PushPredicateThroughAggregation.java
@@ -0,0 +1,109 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.rules.rewrite.logical;
+
+import org.apache.doris.nereids.rules.Rule;
+import org.apache.doris.nereids.rules.RuleType;
+import org.apache.doris.nereids.rules.rewrite.OneRewriteRuleFactory;
+import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.Slot;
+import org.apache.doris.nereids.trees.expressions.visitor.SlotExtractor;
+import org.apache.doris.nereids.trees.plans.GroupPlan;
+import org.apache.doris.nereids.trees.plans.Plan;
+import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
+import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
+import org.apache.doris.nereids.util.ExpressionUtils;
+
+import com.google.common.collect.Lists;
+
+import java.util.HashSet;
+import java.util.List;
+import java.util.Set;
+
+
+/**
+ * Push the predicate in the LogicalFilter to the aggregate child.
+ * For example:
+ * Logical plan tree:
+ *                 any_node
+ *                   |
+ *                filter (a>0 and b>0)
+ *                   |
+ *                group by(a, c)
+ *                   |
+ *                 scan
+ * transformed to:
+ *                 project
+ *                   |
+ *              upper filter (b>0)
+ *                   |
+ *                group by(a, c)
+ *                   |
+ *              bottom filter (a>0)
+ *                   |
+ *                 scan
+ * Note:
+ *    'a>0' could be push down, because 'a' is in group by keys;
+ *    but 'b>0' could not push down, because 'b' is not in group by keys.
+ *
+ */
+
+public class PushPredicateThroughAggregation extends OneRewriteRuleFactory {
+
+    @Override
+    public Rule build() {
+        return logicalFilter(logicalAggregate()).then(filter -> {
+            LogicalAggregate<GroupPlan> aggregate = filter.child();
+            Set<Slot> groupBySlots = new HashSet<>();
+            for (Expression groupByExpression : aggregate.getGroupByExpressionList()) {
+                if (groupByExpression instanceof Slot) {
+                    groupBySlots.add((Slot) groupByExpression);
+                }
+            }
+            List<Expression> pushDownPredicates = Lists.newArrayList();
+            List<Expression> filterPredicates = Lists.newArrayList();
+            ExpressionUtils.extractConjunct(filter.getPredicates()).forEach(conjunct -> {
+                Set<Slot> conjunctSlots = SlotExtractor.extractSlot(conjunct);
+                if (groupBySlots.containsAll(conjunctSlots)) {
+                    pushDownPredicates.add(conjunct);
+                } else {
+                    filterPredicates.add(conjunct);
+                }
+            });
+
+            return pushDownPredicate(filter, aggregate, pushDownPredicates, filterPredicates);
+        }).toRule(RuleType.PUSH_DOWN_PREDICATE_THROUGH_AGGREGATION);
+    }
+
+    private Plan pushDownPredicate(LogicalFilter filter, LogicalAggregate aggregate,
+                                   List<Expression> pushDownPredicates, List<Expression> filterPredicates) {
+        if (pushDownPredicates.size() == 0) {
+            //nothing pushed down, just return origin plan
+            return filter;
+        }
+        LogicalFilter bottomFilter = new LogicalFilter(ExpressionUtils.and(pushDownPredicates),
+                (Plan) aggregate.child(0));
+        if (filterPredicates.isEmpty()) {
+            //all predicates are pushed down, just exchange filter and aggregate
+            return aggregate.withChildren(Lists.newArrayList(bottomFilter));
+        } else {
+            aggregate = aggregate.withChildren(Lists.newArrayList(bottomFilter));
+            return new LogicalFilter<>(ExpressionUtils.and(filterPredicates), aggregate);
+        }
+    }
+}
diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/PushDownPredicateThroughAggregationTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/PushDownPredicateThroughAggregationTest.java
new file mode 100644
index 0000000000..9980f246ad
--- /dev/null
+++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/PushDownPredicateThroughAggregationTest.java
@@ -0,0 +1,206 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.rules.rewrite.logical;
+
+
+
+import org.apache.doris.catalog.AggregateType;
+import org.apache.doris.catalog.Column;
+import org.apache.doris.catalog.Table;
+import org.apache.doris.catalog.Type;
+import org.apache.doris.nereids.memo.Group;
+import org.apache.doris.nereids.memo.GroupExpression;
+import org.apache.doris.nereids.memo.Memo;
+import org.apache.doris.nereids.trees.expressions.Add;
+import org.apache.doris.nereids.trees.expressions.And;
+import org.apache.doris.nereids.trees.expressions.EqualTo;
+import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.GreaterThan;
+import org.apache.doris.nereids.trees.expressions.LessThanEqual;
+import org.apache.doris.nereids.trees.expressions.Literal;
+import org.apache.doris.nereids.trees.expressions.NamedExpression;
+import org.apache.doris.nereids.trees.expressions.Slot;
+import org.apache.doris.nereids.trees.plans.Plan;
+import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
+import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
+import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
+import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
+import org.apache.doris.nereids.util.ExpressionUtils;
+import org.apache.doris.nereids.util.PlanRewriter;
+import org.apache.doris.qe.ConnectContext;
+
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.Lists;
+import org.junit.Assert;
+import org.junit.Test;
+
+import java.util.List;
+
+public class PushDownPredicateThroughAggregationTest {
+
+    /**
+     * origin plan:
+     *                project
+     *                  |
+     *                filter gender>1
+     *                  |
+     *               aggregation group by age, gender
+     *                  |
+     *               scan(student)
+     *
+     *  transformed plan:
+     *                project
+     *                  |
+     *               aggregation group by age, gender
+     *                  |
+     *               filter gender>1
+     *                  |
+     *               scan(student)
+     */
+    @Test
+    public void pushDownPredicateOneFilterTest() {
+        Plan scan = newStudentScan();
+        Slot gender = scan.getOutput().get(1);
+        Slot age = scan.getOutput().get(3);
+
+        List<Expression> groupByKeys = Lists.newArrayList(age, gender);
+        List<NamedExpression> outputExpressionList = Lists.newArrayList(gender, age);
+        Plan aggregation = new LogicalAggregate<>(groupByKeys, outputExpressionList, scan);
+        Expression filterPredicate = new GreaterThan(gender, Literal.of(1));
+        LogicalFilter<Plan> filter = new LogicalFilter<>(filterPredicate, aggregation);
+        Plan root = new LogicalProject<>(Lists.newArrayList(gender), filter);
+
+        Memo memo = rewrite(root);
+        Group rootGroup = memo.getRoot();
+
+        // the only conjunct was pushed down, so the aggregate sits directly below the project
+        GroupExpression groupExpression = rootGroup
+                .getLogicalExpression().child(0)
+                .getLogicalExpression();
+        aggregation = groupExpression.getPlan();
+        Assert.assertTrue(aggregation instanceof LogicalAggregate);
+
+        // the bottom filter carries the pushed down gender>1
+        groupExpression = groupExpression.child(0).getLogicalExpression();
+        Plan bottomFilter = groupExpression.getPlan();
+        Assert.assertTrue(bottomFilter instanceof LogicalFilter);
+        Expression greater = ((LogicalFilter<?>) bottomFilter).getPredicates();
+        Assert.assertTrue(greater instanceof GreaterThan);
+        Assert.assertTrue(greater.child(0) instanceof Slot);
+        Assert.assertEquals("gender", ((Slot) greater.child(0)).getName());
+
+        // the scan stays at the bottom of the plan
+        groupExpression = groupExpression.child(0).getLogicalExpression();
+        Plan scan2 = groupExpression.getPlan();
+        Assert.assertTrue(scan2 instanceof LogicalOlapScan);
+    }
+
+    /**
+     * origin plan:
+     *                project
+     *                  |
+     *                filter gender>1 and (gender+10)<=100 and name="abc"
+     *                  |
+     *               aggregation group by age, gender
+     *                  |
+     *               scan(student)
+     *
+     *  transformed plan:
+     *                project
+     *                  |
+     *                filter name="abc"
+     *                  |
+     *               aggregation group by age, gender
+     *                  |
+     *               filter gender>1 and (gender+10)<=100
+     *                  |
+     *               scan(student)
+     */
+    @Test
+    public void pushDownPredicateTwoFilterTest() {
+        Plan scan = newStudentScan();
+        Slot gender = scan.getOutput().get(1);
+        Slot name = scan.getOutput().get(2);
+        Slot age = scan.getOutput().get(3);
+
+        List<Expression> groupByKeys = Lists.newArrayList(age, gender);
+        List<NamedExpression> outputExpressionList = Lists.newArrayList(gender, age);
+        Plan aggregation = new LogicalAggregate<>(groupByKeys, outputExpressionList, scan);
+        Expression filterPredicate = ExpressionUtils.and(
+                new GreaterThan(gender, Literal.of(1)),
+                new LessThanEqual(
+                        new Add(
+                                gender,
+                                Literal.of(10)
+                        ),
+                        Literal.of(100)
+                ),
+                new EqualTo(name, Literal.of("abc"))
+        );
+        LogicalFilter<Plan> filter = new LogicalFilter<>(filterPredicate, aggregation);
+        Plan root = new LogicalProject<>(Lists.newArrayList(gender), filter);
+
+        Memo memo = rewrite(root);
+        Group rootGroup = memo.getRoot();
+        // name="abc" stays in the upper filter because name is not a group by key
+        GroupExpression groupExpression = rootGroup.getLogicalExpression().child(0).getLogicalExpression();
+        Plan upperFilter = groupExpression.getPlan();
+        Assert.assertTrue(upperFilter instanceof LogicalFilter);
+        Expression upperPredicates = ((LogicalFilter<?>) upperFilter).getPredicates();
+        Assert.assertTrue(upperPredicates instanceof EqualTo);
+        Assert.assertTrue(upperPredicates.child(0) instanceof Slot);
+
+        groupExpression = groupExpression.child(0).getLogicalExpression();
+        aggregation = groupExpression.getPlan();
+        Assert.assertTrue(aggregation instanceof LogicalAggregate);
+
+        // both gender conjuncts were pushed below the aggregate, in their original order
+        groupExpression = groupExpression.child(0).getLogicalExpression();
+        Plan bottomFilter = groupExpression.getPlan();
+        Assert.assertTrue(bottomFilter instanceof LogicalFilter);
+        Expression bottomPredicates = ((LogicalFilter<?>) bottomFilter).getPredicates();
+        Assert.assertTrue(bottomPredicates instanceof And);
+        Assert.assertEquals(2, bottomPredicates.children().size());
+        Expression greater = bottomPredicates.child(0);
+        Assert.assertTrue(greater instanceof GreaterThan);
+        Assert.assertTrue(greater.child(0) instanceof Slot);
+        Assert.assertEquals("gender", ((Slot) greater.child(0)).getName());
+        Expression less = bottomPredicates.child(1);
+        Assert.assertTrue(less instanceof LessThanEqual);
+        Assert.assertTrue(less.child(0) instanceof Add);
+
+        groupExpression = groupExpression.child(0).getLogicalExpression();
+        Plan scan2 = groupExpression.getPlan();
+        Assert.assertTrue(scan2 instanceof LogicalOlapScan);
+    }
+
+    /** Creates a scan over a new 'student' table whose columns are id, gender, name and age. */
+    private Plan newStudentScan() {
+        Table student = new Table(0L, "student", Table.TableType.OLAP,
+                ImmutableList.<Column>of(new Column("id", Type.INT, true, AggregateType.NONE, "0", ""),
+                        new Column("gender", Type.INT, false, AggregateType.NONE, "0", ""),
+                        new Column("name", Type.STRING, true, AggregateType.NONE, "", ""),
+                        new Column("age", Type.INT, true, AggregateType.NONE, "", "")));
+        return new LogicalOlapScan(student, ImmutableList.of("student"));
+    }
+
+    /** Rewrites the plan with PushPredicateThroughAggregation and returns the resulting memo. */
+    private Memo rewrite(Plan plan) {
+        return PlanRewriter.topDownRewriteMemo(plan, new ConnectContext(), new PushPredicateThroughAggregation());
+    }
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@doris.apache.org
For additional commands, e-mail: commits-help@doris.apache.org