You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@doris.apache.org by li...@apache.org on 2022/07/05 03:57:47 UTC

[doris] branch master updated: [Feature] [nereids] Agg rewrite rule of nereids optmizer (#10412)

This is an automated email from the ASF dual-hosted git repository.

lingmiao pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git


The following commit(s) were added to refs/heads/master by this push:
     new 680118c6b9 [Feature] [nereids] Agg rewrite rule of nereids optmizer (#10412)
680118c6b9 is described below

commit 680118c6b9573d4e483a147838db9e42b9720ea6
Author: Kikyou1997 <33...@users.noreply.github.com>
AuthorDate: Tue Jul 5 11:57:42 2022 +0800

    [Feature] [nereids] Agg rewrite rule of nereids optmizer (#10412)
    
    Add a rule to disassemble the logical aggregate node. This is necessary because our execution framework is distributed and an aggregate is always executed in two steps: first aggregate locally, then merge the partial results.
    
    Add some fields to the logical aggregate to record whether a logical aggregate operator has been disassembled and to mark the aggregate phase it belongs to, and add the logic that maps the new aggregate function to its stale definition to get the function's intermediate type.
---
 .../operators/plans/logical/LogicalAggregate.java  |  23 ++-
 .../org/apache/doris/nereids/rules/RuleSet.java    |   5 +
 .../org/apache/doris/nereids/rules/RuleType.java   |   1 +
 .../rules/rewrite/AggregateDisassemble.java        | 160 +++++++++++++++++++++
 .../doris/nereids/trees/expressions/Alias.java     |  14 +-
 .../nereids/trees/expressions/Expression.java      |   4 +
 .../nereids/trees/expressions/SlotReference.java   |   6 +
 .../expressions/functions/AggregateFunction.java   |  10 ++
 .../trees/expressions/functions/BoundFunction.java |  10 ++
 9 files changed, 227 insertions(+), 6 deletions(-)

diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/operators/plans/logical/LogicalAggregate.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/operators/plans/logical/LogicalAggregate.java
index 77fe7fc694..28ce1a62cf 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/operators/plans/logical/LogicalAggregate.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/operators/plans/logical/LogicalAggregate.java
@@ -43,11 +43,12 @@ import java.util.Objects;
  */
 public class LogicalAggregate extends LogicalUnaryOperator {
 
+    private final boolean disassembled;
     private final List<Expression> groupByExprList;
     private final List<NamedExpression> outputExpressionList;
     private List<Expression> partitionExprList;
 
-    private AggPhase aggPhase;
+    private final AggPhase aggPhase;
 
     /**
      * Desc: Constructor for LogicalAggregation.
@@ -56,6 +57,18 @@ public class LogicalAggregate extends LogicalUnaryOperator {
         super(OperatorType.LOGICAL_AGGREGATION);
         this.groupByExprList = groupByExprList;
         this.outputExpressionList = outputExpressionList;
+        this.disassembled = false;
+        this.aggPhase = AggPhase.FIRST;
+    }
+
+    public LogicalAggregate(List<Expression> groupByExprList,
+            List<NamedExpression> outputExpressionList,
+            boolean disassembled, AggPhase aggPhase) {
+        super(OperatorType.LOGICAL_AGGREGATION);
+        this.groupByExprList = groupByExprList;
+        this.outputExpressionList = outputExpressionList;
+        this.disassembled = disassembled;
+        this.aggPhase = aggPhase;
     }
 
     public List<Expression> getPartitionExprList() {
@@ -97,7 +110,13 @@ public class LogicalAggregate extends LogicalUnaryOperator {
         return new ImmutableList.Builder<Expression>().addAll(groupByExprList).addAll(outputExpressionList).build();
     }
 
-    @Override
+    public boolean isDisassembled() {
+        return disassembled;
+    }
+
+    /**
+     * Determine the equality with another operator
+     */
     public boolean equals(Object o) {
         if (this == o) {
             return true;
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleSet.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleSet.java
index c40a3bbb83..d20494cb80 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleSet.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleSet.java
@@ -25,6 +25,7 @@ import org.apache.doris.nereids.rules.implementation.LogicalJoinToHashJoin;
 import org.apache.doris.nereids.rules.implementation.LogicalOlapScanToPhysicalOlapScan;
 import org.apache.doris.nereids.rules.implementation.LogicalProjectToPhysicalProject;
 import org.apache.doris.nereids.rules.implementation.LogicalSortToPhysicalHeapSort;
+import org.apache.doris.nereids.rules.rewrite.AggregateDisassemble;
 import org.apache.doris.nereids.trees.TreeNode;
 import org.apache.doris.nereids.trees.plans.Plan;
 
@@ -42,6 +43,10 @@ public class RuleSet {
             .add(new JoinLeftAssociative())
             .build();
 
+    public static final List<Rule<Plan>> REWRITE_RULES = planRuleFactories()
+            .add(new AggregateDisassemble())
+            .build();
+
     public static final List<Rule<Plan>> IMPLEMENTATION_RULES = planRuleFactories()
             .add(new LogicalAggToPhysicalHashAgg())
             .add(new LogicalFilterToPhysicalFilter())
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
index 9cd9b42ccb..954f060009 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
@@ -38,6 +38,7 @@ public enum RuleType {
     PROJECT_TO_GLOBAL_AGGREGATE(RuleTypeClass.REWRITE),
 
     // rewrite rules
+    AGGREGATE_DISASSEMBLE(RuleTypeClass.REWRITE),
     COLUMN_PRUNE_PROJECTION(RuleTypeClass.REWRITE),
     PUSH_DOWN_PREDICATE_THROUGH_JOIN(RuleTypeClass.REWRITE),
 
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AggregateDisassemble.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AggregateDisassemble.java
new file mode 100644
index 0000000000..afd839a2e8
--- /dev/null
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AggregateDisassemble.java
@@ -0,0 +1,160 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//  http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.rules.rewrite;
+
+import org.apache.doris.analysis.FunctionName;
+import org.apache.doris.catalog.Catalog;
+import org.apache.doris.catalog.Function;
+import org.apache.doris.catalog.Function.CompareMode;
+import org.apache.doris.catalog.Type;
+import org.apache.doris.nereids.operators.Operator;
+import org.apache.doris.nereids.operators.plans.AggPhase;
+import org.apache.doris.nereids.operators.plans.logical.LogicalAggregate;
+import org.apache.doris.nereids.rules.Rule;
+import org.apache.doris.nereids.rules.RuleType;
+import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.NamedExpression;
+import org.apache.doris.nereids.trees.expressions.Slot;
+import org.apache.doris.nereids.trees.expressions.functions.AggregateFunction;
+import org.apache.doris.nereids.trees.plans.Plan;
+import org.apache.doris.nereids.types.DataType;
+
+import com.clearspring.analytics.util.Lists;
+import com.google.common.base.Preconditions;
+
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.stream.Collectors;
+
+/**
+ * Splits a {@link LogicalAggregate} into a local agg plus a merge agg on top of it, for distributed execution.
+ * TODO: if instance count is 1, shouldn't disassemble the agg operator
+ * Do this in following steps:
+ *  1. clone output expr list, find all agg function
+ *  2. set found agg function intermediate type
+ *  3. create new child plan rooted at new local agg
+ *  4. update the slot referenced by expr of merge agg
+ *  5. create plan rooted at merge agg, return it.
+ */
+public class AggregateDisassemble extends OneRewriteRuleFactory {
+
+    @Override
+    public Rule<Plan> build() {
+        // Both aggregates produced below are flagged as disassembled, so the rule cannot match its own output.
+        return logicalAggregate().when(p -> {
+            LogicalAggregate logicalAggregation = p.getOperator();
+            return !logicalAggregation.isDisassembled();
+        }).thenApply(ctx -> {
+            Plan plan = ctx.root;
+            Operator operator = plan.getOperator();
+            LogicalAggregate agg = (LogicalAggregate) operator;
+            List<NamedExpression> intermediateAggExpressionList = Lists.newArrayList();
+            // TODO: shouldn't extract agg function from this field.
+            for (NamedExpression outputExpression : agg.getOutputExpressionList()) {
+                NamedExpression cloned = (NamedExpression) outputExpression.clone();
+                // Match the nereids AggregateFunction (the element type), not the same-named catalog class.
+                List<AggregateFunction> functionCallList = cloned.collect(AggregateFunction.class::isInstance);
+                // TODO: we will have another mechanism to get corresponding stale agg func.
+                for (AggregateFunction functionCall : functionCallList) {
+                    org.apache.doris.catalog.AggregateFunction staleAggFunc = findAggFunc(functionCall);
+                    Type staleIntermediateType = staleAggFunc.getIntermediateType();
+                    Type staleRetType = staleAggFunc.getReturnType();
+                    if (staleIntermediateType != null && !staleIntermediateType.equals(staleRetType)) {
+                        functionCall.setIntermediate(DataType.convertFromCatalogDataType(staleIntermediateType));
+                    }
+                }
+                intermediateAggExpressionList.add(cloned);
+            }
+            List<Expression> localGroupByList =
+                    agg.getGroupByExprList().stream().map(Expression::clone).collect(Collectors.toList());
+            LogicalAggregate localAgg =
+                    new LogicalAggregate(localGroupByList, intermediateAggExpressionList, true, AggPhase.FIRST);
+
+            Plan childPlan = plan(localAgg, plan.child(0));
+            List<Slot> stalePlanOutputSlotList = plan.getOutput();
+            List<Slot> childOutputSlotList = childPlan.getOutput();
+            Preconditions.checkState(stalePlanOutputSlotList.size() == childOutputSlotList.size());
+            // Outputs are paired by position: the i-th stale slot is mapped to the i-th local agg slot.
+            Map<Slot, Slot> staleToNew = new HashMap<>();
+            for (int i = 0; i < stalePlanOutputSlotList.size(); i++) {
+                staleToNew.put(stalePlanOutputSlotList.get(i), childOutputSlotList.get(i));
+            }
+            // Work on copies so that top-level replacements do not write into the matched operator's own lists.
+            List<Expression> groupByExpressionList = Lists.newArrayList(agg.getGroupByExprList());
+            for (int i = 0; i < groupByExpressionList.size(); i++) {
+                replaceSlot(staleToNew, groupByExpressionList, groupByExpressionList.get(i), i);
+            }
+            List<NamedExpression> mergeOutputExpressionList = Lists.newArrayList(agg.getOutputExpressionList());
+            for (int i = 0; i < mergeOutputExpressionList.size(); i++) {
+                replaceSlot(staleToNew, mergeOutputExpressionList, mergeOutputExpressionList.get(i), i);
+            }
+            LogicalAggregate mergeAgg = new LogicalAggregate(
+                    groupByExpressionList, mergeOutputExpressionList, true, AggPhase.FIRST_MERGE);
+            return plan(mergeAgg, childPlan);
+        }).toRule(RuleType.AGGREGATE_DISASSEMBLE);
+    }
+
+    /**
+     * Looks up the catalog ("stale") aggregate function that is identical in name, argument types and
+     * return type to the given nereids call, so that its intermediate type can be read.
+     * Fails with an IllegalArgumentException when the catalog lookup yields no aggregate function.
+     */
+    private org.apache.doris.catalog.AggregateFunction findAggFunc(AggregateFunction functionCall) {
+        FunctionName functionName = new FunctionName(functionCall.getName());
+        List<Type> staleTypeList = functionCall.getArguments().stream().map(Expression::getDataType)
+                .map(DataType::toCatalogDataType).collect(Collectors.toList());
+        Function staleFuncDesc = new Function(functionName, staleTypeList,
+                functionCall.getDataType().toCatalogDataType(),
+                // I think an aggregate function will never have variable-length parameters
+                false);
+        Function staleFunc = Catalog.getCurrentCatalog().getFunction(staleFuncDesc, CompareMode.IS_IDENTICAL);
+        Preconditions.checkArgument(staleFunc instanceof org.apache.doris.catalog.AggregateFunction,
+                "no catalog aggregate function matches %s", functionCall.getName());
+        return (org.apache.doris.catalog.AggregateFunction) staleFunc;
+    }
+
+    /**
+     * Rewrites, in place, the slots reachable from {@code root} that have an entry in {@code staleToNew}.
+     * With {@code index != -1}, {@code root} is the element {@code expressionList.get(index)}; when it is a
+     * slot itself the list element is swapped. Otherwise mapped slots are swapped inside each children() list.
+     */
+    @SuppressWarnings("unchecked")
+    private <T extends Expression> void replaceSlot(Map<Slot, Slot> staleToNew,
+            List<T> expressionList, Expression root, int index) {
+        if (index != -1 && root instanceof Slot) {
+            Slot replacement = staleToNew.get(root);
+            if (replacement != null) {
+                expressionList.set(index, (T) replacement);
+            }
+            return;
+        }
+        List<Expression> children = root.children();
+        for (int i = 0; i < children.size(); i++) {
+            Expression cur = children.get(i);
+            if (!(cur instanceof Slot)) {
+                replaceSlot(staleToNew, expressionList, cur, -1);
+                continue;
+            }
+            Expression v = staleToNew.get(cur);
+            if (v != null) {
+                children.set(i, v);
+            }
+        }
+    }
+}
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Alias.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Alias.java
index 2dc5bce7e7..33e7f79807 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Alias.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Alias.java
@@ -85,13 +85,19 @@ public class Alias<CHILD_TYPE extends Expression> extends NamedExpression
     }
 
     @Override
+    public String toString() {
+        return child().toString() + " AS " + name;
+    }
+
+    @Override
+    public Alias<CHILD_TYPE> clone() {
+        CHILD_TYPE childType = (CHILD_TYPE) children.get(0).clone();
+        return new Alias<>(childType, name);
+    }
+
     public Expression withChildren(List<Expression> children) {
         Preconditions.checkArgument(children.size() == 1);
         return new Alias<>(children.get(0), name);
     }
 
-    @Override
-    public String toString() {
-        return child().toString() + " AS " + name;
-    }
 }
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Expression.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Expression.java
index 70d5a53805..49595aaa69 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Expression.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Expression.java
@@ -83,6 +83,10 @@ public abstract class Expression extends AbstractTreeNode<Expression> {
         return false;
     }
 
+    public Expression clone() { // default for expression types that do not provide their own copy logic
+        throw new UnsupportedOperationException("Unimplemented method");
+    }
+
     @Override
     public boolean equals(Object o) {
         if (this == o) {
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/SlotReference.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/SlotReference.java
index f4bb1269c3..789f131a20 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/SlotReference.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/SlotReference.java
@@ -22,6 +22,7 @@ import org.apache.doris.nereids.trees.NodeType;
 import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
 import org.apache.doris.nereids.types.DataType;
 
+import com.clearspring.analytics.util.Lists;
 import com.google.common.base.Preconditions;
 import org.apache.commons.lang.StringUtils;
 
@@ -139,4 +140,9 @@ public class SlotReference extends Slot {
         Preconditions.checkArgument(children.size() == 0);
         return this;
     }
+
+    @Override
+    public SlotReference clone() {
+        return new SlotReference(name, getDataType(), nullable, Lists.newArrayList(qualifier));
+    }
 }
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/AggregateFunction.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/AggregateFunction.java
index 8b581a475e..371aff83b6 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/AggregateFunction.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/AggregateFunction.java
@@ -24,6 +24,8 @@ import org.apache.doris.nereids.types.DataType;
 /** AggregateFunction. */
 public abstract class AggregateFunction extends BoundFunction {
 
+    private DataType intermediate;
+
     public AggregateFunction(String name, Expression... arguments) {
         super(name, arguments);
     }
@@ -34,4 +36,12 @@ public abstract class AggregateFunction extends BoundFunction {
     public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
         return visitor.visitAggregateFunction(this, context);
     }
+
+    public DataType getIntermediate() {
+        return intermediate;
+    }
+
+    public void setIntermediate(DataType intermediate) {
+        this.intermediate = intermediate;
+    }
 }
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/BoundFunction.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/BoundFunction.java
index 70a1d60669..c6f52a22b4 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/BoundFunction.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/BoundFunction.java
@@ -22,6 +22,7 @@ import org.apache.doris.nereids.trees.NodeType;
 import org.apache.doris.nereids.trees.expressions.Expression;
 import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
 
+import java.util.ArrayList;
 import java.util.List;
 import java.util.Objects;
 import java.util.stream.Collectors;
@@ -82,4 +83,13 @@ public class BoundFunction extends Expression {
                 .collect(Collectors.joining(", "));
         return name + "(" + args + ")";
     }
+
+    @Override
+    public BoundFunction clone() {
+        // NOTE(review): the result is constructed as a plain BoundFunction, not as the receiver's runtime
+        // class — presumably subclasses are expected to override this; confirm.
+        List<Expression> copiedArguments = new ArrayList<>();
+        getArguments().forEach(argument -> copiedArguments.add(argument.clone()));
+        return new BoundFunction(this.name, copiedArguments.toArray(new Expression[0]));
+    }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@doris.apache.org
For additional commands, e-mail: commits-help@doris.apache.org