Posted to commits@doris.apache.org by GitBox <gi...@apache.org> on 2022/07/06 18:44:25 UTC

[GitHub] [doris] morrySnow opened a new pull request, #10659: [enhancement](nereids) make aggregate works

morrySnow opened a new pull request, #10659:
URL: https://github.com/apache/doris/pull/10659

   # Proposed changes
   
   enhancement
   - refactor how output expressions are computed on the root fragment in the Nereids planner
   - refactor the aggregate plan translator
   - refactor the aggregate disassemble rule
   - add an exchange node on top of the plan node tree
   - add a contains interface to TreeNode
   
   fix
   - a SlotDescriptor should not be reused across TupleDescriptors
   - expression nullability now works correctly
   
   known issues
   - an aggregate function must be the top-level expression in an output expression (needs project support in ExecNode in BE)
   - the first-phase aggregate cannot be converted to stream mode
   - OlapScanNode does not set the data partition
   
   ## Checklist(Required)
   
   1. Does it affect the original behavior: No
   2. Has unit tests been added: Yes
   3. Has document been added or modified: No Need
   4. Does it need to update dependencies: No
   5. Are there any changes that cannot be rolled back: No
   


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@doris.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@doris.apache.org
For additional commands, e-mail: commits-help@doris.apache.org


[GitHub] [doris] 924060929 commented on a diff in pull request #10659: [enhancement](nereids) make aggregate works

Posted by GitBox <gi...@apache.org>.
924060929 commented on code in PR #10659:
URL: https://github.com/apache/doris/pull/10659#discussion_r915532556


##########
fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AggregateDisassemble.java:
##########
@@ -64,49 +65,69 @@ public Rule<Plan> build() {
             Operator operator = plan.getOperator();
             LogicalAggregate agg = (LogicalAggregate) operator;
             List<NamedExpression> outputExpressionList = agg.getOutputExpressionList();
-            List<NamedExpression> intermediateAggExpressionList = Lists.newArrayList();
-            // TODO: shouldn't extract agg function from this field.
-            for (NamedExpression namedExpression : outputExpressionList) {
-                namedExpression = (NamedExpression) namedExpression.clone();
-                List<AggregateFunction> functionCallList =
-                        namedExpression.collect(org.apache.doris.catalog.AggregateFunction.class::isInstance);
-                // TODO: we will have another mechanism to get corresponding stale agg func.
-                for (AggregateFunction functionCall : functionCallList) {
-                    org.apache.doris.catalog.AggregateFunction staleAggFunc = findAggFunc(functionCall);
-                    Type staleIntermediateType = staleAggFunc.getIntermediateType();
-                    Type staleRetType = staleAggFunc.getReturnType();
-                    if (staleIntermediateType != null && !staleIntermediateType.equals(staleRetType)) {
-                        functionCall.setIntermediate(DataType.convertFromCatalogDataType(staleIntermediateType));
+            List<Expression> groupByExpressionList = agg.getGroupByExpressionList();
+
+            Map<AggregateFunction, NamedExpression> aggregateFunctionAliasMap = Maps.newHashMap();
+            for (NamedExpression outputExpression : outputExpressionList) {
+                outputExpression.foreach(e -> {
+                    if (e instanceof AggregateFunction) {
+                        AggregateFunction a = (AggregateFunction) e;
+                        aggregateFunctionAliasMap.put(a, new Alias<>(a, a.sql()));
+                    }
+                });
+            }
+
+            List<Expression> updateGroupByExpressionList = groupByExpressionList;
+            List<NamedExpression> updateGroupByAliasList = updateGroupByExpressionList.stream()
+                    .map(g -> new Alias<>(g, g.sql()))
+                    .collect(Collectors.toList());
+
+            List<NamedExpression> updateOutputExpressionList = Lists.newArrayList();
+            updateOutputExpressionList.addAll(updateGroupByAliasList);
+            updateOutputExpressionList.addAll(aggregateFunctionAliasMap.values());
+
+            List<Expression> mergeGroupByExpressionList = updateGroupByAliasList.stream()
+                    .map(NamedExpression::toSlot).collect(Collectors.toList());
+
+            List<NamedExpression> mergeOutputExpressionList = Lists.newArrayList();
+            for (NamedExpression o : outputExpressionList) {
+                if (o.contains(AggregateFunction.class::isInstance)) {
+                    mergeOutputExpressionList.add((NamedExpression) new AggregateFunctionParamsRewriter()
+                            .visit(o, aggregateFunctionAliasMap));
+                } else {
+                    for (int i = 0; i < updateGroupByAliasList.size(); i++) {
+                        // TODO: we need to do sub tree match and replace. but we do not have semanticEquals now.
+                        //    e.g. a + 1 + 2 in output expression should be replaced by
+                        //    (slot reference to update phase out (a + 1)) + 2, if we do group by a + 1
+                        //   currently, we could only handle output expression same with group by expression
+                        if (o instanceof SlotReference) {
+                            // a in output expression will be SLotReference
+                            if (o.equals(updateGroupByExpressionList.get(i))) {
+                                mergeOutputExpressionList.add(updateGroupByAliasList.get(i).toSlot());
+                                break;
+                            }
+                        } else if (o instanceof Alias) {
+                            // a + 1 in output expression will be Alias
+                            if (o.child(0).equals(updateGroupByExpressionList.get(i))) {
+                                mergeOutputExpressionList.add(updateGroupByAliasList.get(i).toSlot());
+                                break;
+                            }
+                        }
                     }
                 }
-                intermediateAggExpressionList.add(namedExpression);
             }
+
             LogicalAggregate localAgg = new LogicalAggregate(
-                    agg.getGroupByExprList().stream().map(Expression::clone).collect(Collectors.toList()),
-                    intermediateAggExpressionList,
+                    updateGroupByExpressionList,
+                    updateOutputExpressionList,

Review Comment:
   I think these variable names and the computation logic are confusing. How about this:
   1. localGroupByExprs = originGroupByExprs
   2. localOutputExprs = originOutput.withAlias
   3. globalGroupByWithAlias = originGroupByExprs.withAlias
   4. globalOutputWithAlias = originOutput.replaceAggregateFunctionArgumentsToAlias
   
   The advantages are:
   1. each variable name carries position and member information
   2. each assignment statement contains the simplest possible computation logic
   





[GitHub] [doris] morrySnow commented on a diff in pull request #10659: [enhancement](nereids) make aggregate works

Posted by GitBox <gi...@apache.org>.
morrySnow commented on code in PR #10659:
URL: https://github.com/apache/doris/pull/10659#discussion_r916499550


##########
fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java:
##########
@@ -263,23 +313,18 @@ public PlanFragment visitPhysicalHashJoin(
         // NOTICE: We must visit from right to left, to ensure the last fragment is root fragment
         PlanFragment rightFragment = visit(hashJoin.child(1), context);
         PlanFragment leftFragment = visit(hashJoin.child(0), context);
-        PhysicalHashJoin physicalHashJoin = hashJoin.getOperator();
-
-        //        Expression predicateExpr = physicalHashJoin.getCondition().get();
-        //        List<Expression> eqExprList = Utils.getEqConjuncts(hashJoin.child(0).getOutput(),
-        //                hashJoin.child(1).getOutput(), predicateExpr);
-        JoinType joinType = physicalHashJoin.getJoinType();
-
         PlanNode leftFragmentPlanRoot = leftFragment.getPlanRoot();
         PlanNode rightFragmentPlanRoot = rightFragment.getPlanRoot();
+        PhysicalHashJoin physicalHashJoin = hashJoin.getOperator();
+        JoinType joinType = physicalHashJoin.getJoinType();
 
         if (joinType.equals(JoinType.CROSS_JOIN)

Review Comment:
   This PR only changes the join translation from shuffled join to broadcast join. By the way, the translator should handle all situations and should not depend on the implementation and exploration jobs in Cascades, because when we test the Cascades framework we will disable some rules.





[GitHub] [doris] morrySnow commented on a diff in pull request #10659: [enhancement](nereids) make aggregate works

Posted by GitBox <gi...@apache.org>.
morrySnow commented on code in PR #10659:
URL: https://github.com/apache/doris/pull/10659#discussion_r916505299


##########
fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PlanTranslatorContext.java:
##########
@@ -66,32 +65,28 @@ public TupleDescriptor generateTupleDesc() {
         return descTable.createTupleDescriptor();
     }
 
-    public PlanNodeId nextNodeId() {
+    public PlanFragmentId nextFragmentId() {
+        return fragmentIdGenerator.getNextId();
+    }
+
+    public PlanNodeId nextPlanNodeId() {
         return nodeIdGenerator.getNextId();
     }
 
     public SlotDescriptor addSlotDesc(TupleDescriptor t) {
         return descTable.addSlotDescriptor(t);
     }
 
-    public SlotDescriptor addSlotDesc(TupleDescriptor t, int id) {
-        return descTable.addSlotDescriptor(t, id);
-    }
-
-    public PlanFragmentId nextFragmentId() {
-        return fragmentIdGenerator.getNextId();
-    }
-
     public void addPlanFragment(PlanFragment planFragment) {
         this.planFragmentList.add(planFragment);
     }
 
-    public void addSlotRefMapping(Expression expression, Expr expr) {
-        expressionToExecExpr.put(expression, expr);
+    public void addIdPair(ExprId exprId, SlotRef slotRef) {

Review Comment:
   done





[GitHub] [doris] morrySnow commented on a diff in pull request #10659: [enhancement](nereids) make aggregate works

Posted by GitBox <gi...@apache.org>.
morrySnow commented on code in PR #10659:
URL: https://github.com/apache/doris/pull/10659#discussion_r915540665


##########
fe/fe-core/src/test/java/org/apache/doris/nereids/AnalyzeSSBTest.java:
##########
@@ -161,10 +161,7 @@ private void executeRewriteBottomUpJob(PlannerContext plannerContext, RuleFactor
     }
 
     private boolean checkBound(LogicalPlan root) {
-        if (!checkPlanBound(root))  {
-            return false;
-        }
-        return true;
+        return checkPlanBound(root);

Review Comment:
   I think so.





[GitHub] [doris] github-actions[bot] commented on pull request #10659: [enhancement](nereids) make SSB works

Posted by GitBox <gi...@apache.org>.
github-actions[bot] commented on PR #10659:
URL: https://github.com/apache/doris/pull/10659#issuecomment-1179929020

   PR approved by anyone and no changes requested.




[GitHub] [doris] 924060929 commented on a diff in pull request #10659: [enhancement](nereids) make aggregate works

Posted by GitBox <gi...@apache.org>.
924060929 commented on code in PR #10659:
URL: https://github.com/apache/doris/pull/10659#discussion_r915575646


##########
fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AggregateDisassemble.java:
##########
@@ -64,49 +65,69 @@ public Rule<Plan> build() {
             Operator operator = plan.getOperator();
             LogicalAggregate agg = (LogicalAggregate) operator;
             List<NamedExpression> outputExpressionList = agg.getOutputExpressionList();
-            List<NamedExpression> intermediateAggExpressionList = Lists.newArrayList();
-            // TODO: shouldn't extract agg function from this field.
-            for (NamedExpression namedExpression : outputExpressionList) {
-                namedExpression = (NamedExpression) namedExpression.clone();
-                List<AggregateFunction> functionCallList =
-                        namedExpression.collect(org.apache.doris.catalog.AggregateFunction.class::isInstance);
-                // TODO: we will have another mechanism to get corresponding stale agg func.
-                for (AggregateFunction functionCall : functionCallList) {
-                    org.apache.doris.catalog.AggregateFunction staleAggFunc = findAggFunc(functionCall);
-                    Type staleIntermediateType = staleAggFunc.getIntermediateType();
-                    Type staleRetType = staleAggFunc.getReturnType();
-                    if (staleIntermediateType != null && !staleIntermediateType.equals(staleRetType)) {
-                        functionCall.setIntermediate(DataType.convertFromCatalogDataType(staleIntermediateType));
+            List<Expression> groupByExpressionList = agg.getGroupByExpressionList();
+
+            Map<AggregateFunction, NamedExpression> aggregateFunctionAliasMap = Maps.newHashMap();
+            for (NamedExpression outputExpression : outputExpressionList) {
+                outputExpression.foreach(e -> {
+                    if (e instanceof AggregateFunction) {
+                        AggregateFunction a = (AggregateFunction) e;
+                        aggregateFunctionAliasMap.put(a, new Alias<>(a, a.sql()));
+                    }
+                });
+            }
+
+            List<Expression> updateGroupByExpressionList = groupByExpressionList;
+            List<NamedExpression> updateGroupByAliasList = updateGroupByExpressionList.stream()
+                    .map(g -> new Alias<>(g, g.sql()))
+                    .collect(Collectors.toList());
+
+            List<NamedExpression> updateOutputExpressionList = Lists.newArrayList();
+            updateOutputExpressionList.addAll(updateGroupByAliasList);
+            updateOutputExpressionList.addAll(aggregateFunctionAliasMap.values());
+
+            List<Expression> mergeGroupByExpressionList = updateGroupByAliasList.stream()
+                    .map(NamedExpression::toSlot).collect(Collectors.toList());
+
+            List<NamedExpression> mergeOutputExpressionList = Lists.newArrayList();
+            for (NamedExpression o : outputExpressionList) {
+                if (o.contains(AggregateFunction.class::isInstance)) {
+                    mergeOutputExpressionList.add((NamedExpression) new AggregateFunctionParamsRewriter()
+                            .visit(o, aggregateFunctionAliasMap));
+                } else {
+                    for (int i = 0; i < updateGroupByAliasList.size(); i++) {
+                        // TODO: we need to do sub tree match and replace. but we do not have semanticEquals now.
+                        //    e.g. a + 1 + 2 in output expression should be replaced by
+                        //    (slot reference to update phase out (a + 1)) + 2, if we do group by a + 1
+                        //   currently, we could only handle output expression same with group by expression
+                        if (o instanceof SlotReference) {
+                            // a in output expression will be SLotReference
+                            if (o.equals(updateGroupByExpressionList.get(i))) {
+                                mergeOutputExpressionList.add(updateGroupByAliasList.get(i).toSlot());
+                                break;
+                            }
+                        } else if (o instanceof Alias) {
+                            // a + 1 in output expression will be Alias
+                            if (o.child(0).equals(updateGroupByExpressionList.get(i))) {
+                                mergeOutputExpressionList.add(updateGroupByAliasList.get(i).toSlot());
+                                break;
+                            }
+                        }
                     }
                 }
-                intermediateAggExpressionList.add(namedExpression);
             }
+
             LogicalAggregate localAgg = new LogicalAggregate(
-                    agg.getGroupByExprList().stream().map(Expression::clone).collect(Collectors.toList()),
-                    intermediateAggExpressionList,
+                    updateGroupByExpressionList,
+                    updateOutputExpressionList,

Review Comment:
   Okay, I remember these terms. Keeping this logic is reasonable; if we retain these terms, the names could be:
   
   updateGroupByExprs = originGroupByExprs
   updateOutputExprs = originOutput.withAlias
   mergeGroupByWithAlias = originGroupByExprs.withAlias
   mergeOutputWithAlias = originOutput.replaceAggregateFunctionArgumentsToAliasReference
   
   





[GitHub] [doris] 924060929 commented on a diff in pull request #10659: [enhancement](nereids) make SSB works

Posted by GitBox <gi...@apache.org>.
924060929 commented on code in PR #10659:
URL: https://github.com/apache/doris/pull/10659#discussion_r917338424


##########
fe/fe-core/src/main/java/org/apache/doris/nereids/NereidsPlanner.java:
##########
@@ -81,33 +74,14 @@ public void plan(StatementBase queryStmt,
 
         PhysicalPlanTranslator physicalPlanTranslator = new PhysicalPlanTranslator();
         PlanTranslatorContext planTranslatorContext = new PlanTranslatorContext();
-        physicalPlanTranslator.translatePlan(physicalPlan, planTranslatorContext);
+        PlanFragment root = physicalPlanTranslator.translatePlan(physicalPlan, planTranslatorContext);
 
         scanNodeList = planTranslatorContext.getScanNodeList();
         descTable = planTranslatorContext.getDescTable();
         fragments = new ArrayList<>(planTranslatorContext.getPlanFragmentList());
-        for (PlanFragment fragment : fragments) {
-            fragment.finalize(queryStmt);
-        }
-        Collections.reverse(fragments);
-        PlanFragment root = fragments.get(0);
-
-        // compute output exprs
-        Map<Integer, Expr> outputCandidates = Maps.newHashMap();
-        List<Expr> outputExprs = Lists.newArrayList();
-        for (TupleId tupleId : root.getPlanRoot().getTupleIds()) {
-            TupleDescriptor tupleDescriptor = descTable.getTupleDesc(tupleId);
-            for (SlotDescriptor slotDescriptor : tupleDescriptor.getSlots()) {
-                SlotRef slotRef = new SlotRef(slotDescriptor);
-                outputCandidates.put(slotDescriptor.getId().asInt(), slotRef);
-            }
-        }
-        physicalPlan.getOutput().stream()
-                .forEach(i -> outputExprs.add(planTranslatorContext.findExpr(i)));
-        root.setOutputExprs(outputExprs);
-        root.getPlanRoot().convertToVectoriezd();
 
-        logicalPlanAdapter.setResultExprs(outputExprs);
+        // set output exprs
+        logicalPlanAdapter.setResultExprs(root.getOutputExprs());

Review Comment:
   great refactor



##########
fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AggregateDisassemble.java:
##########
@@ -17,144 +17,156 @@
 
 package org.apache.doris.nereids.rules.rewrite;
 
-import org.apache.doris.analysis.FunctionName;
-import org.apache.doris.catalog.Catalog;
-import org.apache.doris.catalog.Function;
-import org.apache.doris.catalog.Function.CompareMode;
-import org.apache.doris.catalog.Type;
-import org.apache.doris.nereids.operators.Operator;
 import org.apache.doris.nereids.operators.plans.AggPhase;
 import org.apache.doris.nereids.operators.plans.logical.LogicalAggregate;
 import org.apache.doris.nereids.rules.Rule;
 import org.apache.doris.nereids.rules.RuleType;
+import org.apache.doris.nereids.trees.expressions.Alias;
 import org.apache.doris.nereids.trees.expressions.Expression;
 import org.apache.doris.nereids.trees.expressions.NamedExpression;
-import org.apache.doris.nereids.trees.expressions.Slot;
+import org.apache.doris.nereids.trees.expressions.SlotReference;
 import org.apache.doris.nereids.trees.expressions.functions.AggregateFunction;
+import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
+import org.apache.doris.nereids.trees.plans.GroupPlan;
 import org.apache.doris.nereids.trees.plans.Plan;
-import org.apache.doris.nereids.types.DataType;
+import org.apache.doris.nereids.trees.plans.logical.LogicalUnaryPlan;
 
-import com.clearspring.analytics.util.Lists;
-import com.google.common.base.Preconditions;
+import com.google.common.collect.Lists;
+import com.google.common.collect.Maps;
 
-import java.util.HashMap;
+import java.util.ArrayList;
 import java.util.List;
 import java.util.Map;
 import java.util.stream.Collectors;
 
 /**
- * TODO: if instance count is 1, shouldn't disassemble the agg operator
  * Used to generate the merge agg node for distributed execution.
- * Do this in following steps:
- *  1. clone output expr list, find all agg function
- *  2. set found agg function intermediaType
- *  3. create new child plan rooted at new local agg
- *  4. update the slot referenced by expr of merge agg
- *  5. create plan rooted at merge agg, return it.
+ * NOTICE: GLOBAL output expressions' ExprId should SAME with ORIGIN output expressions' ExprId.
+ * If we have a query: SELECT SUM(v1 * v2) + 1 FROM t GROUP BY k + 1
+ * the initial plan is:
+ *   Aggregate(phase: [GLOBAL], outputExpr: [Alias(k + 1) #1, Alias(SUM(v1 * v2) + 1) #2], groupByExpr: [k + 1])
+ *   +-- childPlan
+ * we should rewrite to:
+ *   Aggregate(phase: [GLOBAL], outputExpr: [Alias(b) #1, Alias(SUM(a) + 1) #2], groupByExpr: [b])
+ *   +-- Aggregate(phase: [LOCAL], outputExpr: [SUM(v1 * v2) as a, (k + 1) as b], groupByExpr: [k + 1])
+ *       +-- childPlan

Review Comment:
   great comment
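
To make the rewrite in the Javadoc above concrete, here is a minimal sketch of what the LOCAL and GLOBAL phases compute for `SELECT SUM(v1 * v2) + 1 FROM t GROUP BY k + 1`. It is plain Java over in-memory rows and does not use any Doris class; the class name `TwoPhaseSumSketch`, the row layout `{k, v1, v2}`, and the two partitions are assumptions made only for illustration.

```java
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;

public class TwoPhaseSumSketch {

    // LOCAL phase (hypothetical model): for one partition, group by b = k + 1
    // and compute the partial sum a = SUM(v1 * v2). Each row is {k, v1, v2}.
    static Map<Long, Long> localAggregate(List<long[]> rows) {
        Map<Long, Long> partial = new TreeMap<>();
        for (long[] row : rows) {
            long b = row[0] + 1;
            long product = row[1] * row[2];
            partial.merge(b, product, Long::sum);
        }
        return partial;
    }

    // GLOBAL phase (hypothetical model): group by b, compute SUM(a) over the
    // partial results, then apply the remaining "+ 1" of the output expression.
    static Map<Long, Long> globalAggregate(List<Map<Long, Long>> partials) {
        Map<Long, Long> merged = new TreeMap<>();
        for (Map<Long, Long> partial : partials) {
            partial.forEach((b, a) -> merged.merge(b, a, Long::sum));
        }
        merged.replaceAll((b, sum) -> sum + 1);
        return merged;
    }

    public static void main(String[] args) {
        // Two made-up partitions of table t.
        List<long[]> partition1 = Arrays.asList(new long[]{1, 2, 3}, new long[]{2, 1, 1});
        List<long[]> partition2 = Arrays.asList(new long[]{1, 4, 5});

        Map<Long, Long> result = globalAggregate(Arrays.asList(
                localAggregate(partition1), localAggregate(partition2)));
        System.out.println(result); // prints {2=27, 3=2}
    }
}
```

The "+ 1" is applied only in the GLOBAL phase, which mirrors why the rewritten GLOBAL output is `SUM(a) + 1` while the LOCAL output is the bare `SUM(v1 * v2) as a`.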



##########
fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AggregateDisassemble.java:
##########
@@ -17,144 +17,156 @@
 
 package org.apache.doris.nereids.rules.rewrite;
 
-import org.apache.doris.analysis.FunctionName;
-import org.apache.doris.catalog.Catalog;
-import org.apache.doris.catalog.Function;
-import org.apache.doris.catalog.Function.CompareMode;
-import org.apache.doris.catalog.Type;
-import org.apache.doris.nereids.operators.Operator;
 import org.apache.doris.nereids.operators.plans.AggPhase;
 import org.apache.doris.nereids.operators.plans.logical.LogicalAggregate;
 import org.apache.doris.nereids.rules.Rule;
 import org.apache.doris.nereids.rules.RuleType;
+import org.apache.doris.nereids.trees.expressions.Alias;
 import org.apache.doris.nereids.trees.expressions.Expression;
 import org.apache.doris.nereids.trees.expressions.NamedExpression;
-import org.apache.doris.nereids.trees.expressions.Slot;
+import org.apache.doris.nereids.trees.expressions.SlotReference;
 import org.apache.doris.nereids.trees.expressions.functions.AggregateFunction;
+import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
+import org.apache.doris.nereids.trees.plans.GroupPlan;
 import org.apache.doris.nereids.trees.plans.Plan;
-import org.apache.doris.nereids.types.DataType;
+import org.apache.doris.nereids.trees.plans.logical.LogicalUnaryPlan;
 
-import com.clearspring.analytics.util.Lists;
-import com.google.common.base.Preconditions;
+import com.google.common.collect.Lists;
+import com.google.common.collect.Maps;
 
-import java.util.HashMap;
+import java.util.ArrayList;
 import java.util.List;
 import java.util.Map;
 import java.util.stream.Collectors;
 
 /**
- * TODO: if instance count is 1, shouldn't disassemble the agg operator
  * Used to generate the merge agg node for distributed execution.
- * Do this in following steps:
- *  1. clone output expr list, find all agg function
- *  2. set found agg function intermediaType
- *  3. create new child plan rooted at new local agg
- *  4. update the slot referenced by expr of merge agg
- *  5. create plan rooted at merge agg, return it.
+ * NOTICE: GLOBAL output expressions' ExprId should SAME with ORIGIN output expressions' ExprId.
+ * If we have a query: SELECT SUM(v1 * v2) + 1 FROM t GROUP BY k + 1
+ * the initial plan is:
+ *   Aggregate(phase: [GLOBAL], outputExpr: [Alias(k + 1) #1, Alias(SUM(v1 * v2) + 1) #2], groupByExpr: [k + 1])
+ *   +-- childPlan
+ * we should rewrite to:
+ *   Aggregate(phase: [GLOBAL], outputExpr: [Alias(b) #1, Alias(SUM(a) + 1) #2], groupByExpr: [b])
+ *   +-- Aggregate(phase: [LOCAL], outputExpr: [SUM(v1 * v2) as a, (k + 1) as b], groupByExpr: [k + 1])
+ *       +-- childPlan
+ *
+ * TODO:
+ *     1. use different class represent different phase aggregate
+ *     2. if instance count is 1, shouldn't disassemble the agg operator
+ *     3. we need another rule to removing duplicated expressions in group by expression list
  */
 public class AggregateDisassemble extends OneRewriteRuleFactory {
 
     @Override
     public Rule<Plan> build() {
         return logicalAggregate().when(p -> {
-            LogicalAggregate logicalAggregation = p.getOperator();
-            return !logicalAggregation.isDisassembled();
+            LogicalAggregate logicalAggregate = p.getOperator();
+            return !logicalAggregate.isDisassembled();
         }).thenApply(ctx -> {
-            Plan plan = ctx.root;
-            Operator operator = plan.getOperator();
-            LogicalAggregate agg = (LogicalAggregate) operator;
-            List<NamedExpression> outputExpressionList = agg.getOutputExpressionList();
-            List<NamedExpression> intermediateAggExpressionList = Lists.newArrayList();
-            // TODO: shouldn't extract agg function from this field.
-            for (NamedExpression namedExpression : outputExpressionList) {
-                namedExpression = (NamedExpression) namedExpression.clone();
-                List<AggregateFunction> functionCallList =
-                        namedExpression.collect(org.apache.doris.catalog.AggregateFunction.class::isInstance);
-                // TODO: we will have another mechanism to get corresponding stale agg func.
-                for (AggregateFunction functionCall : functionCallList) {
-                    org.apache.doris.catalog.AggregateFunction staleAggFunc = findAggFunc(functionCall);
-                    Type staleIntermediateType = staleAggFunc.getIntermediateType();
-                    Type staleRetType = staleAggFunc.getReturnType();
-                    if (staleIntermediateType != null && !staleIntermediateType.equals(staleRetType)) {
-                        functionCall.setIntermediate(DataType.convertFromCatalogDataType(staleIntermediateType));
+            LogicalUnaryPlan<LogicalAggregate, GroupPlan> plan = ctx.root;
+            LogicalAggregate aggregate = plan.getOperator();
+            List<NamedExpression> originOutputExprs = aggregate.getOutputExpressionList();
+            List<Expression> originGroupByExprs = aggregate.getGroupByExpressionList();
+
+            // 1. generate a map from local aggregate output to global aggregate expr substitution.
+            //    inputSubstitutionMap use for replacing expression in global aggregate
+            //    replace rule is:
+            //        a: Expression is a group by key and is a slot reference. e.g. group by k1
+            //        b. Expression is a group by key and is an expression. e.g. group by k1 + 1
+            //        c. Expression is an aggregate function. e.g. sum(v1) in select list
+            //    +-----------+---------------------+-------------------------+--------------------------------+
+            //    | situation | origin expression   | local output expression | expression in global aggregate |
+            //    +-----------+---------------------+-------------------------+--------------------------------+
+            //    | a         | Ref(k1)#1           | Ref(k1)#1               | Ref(k1)#1                      |
+            //    +-----------+---------------------+-------------------------+--------------------------------+
+            //    | b         | Ref(k1)#1 + 1       | A(Ref(k1)#1 + 1, key)#2 | Ref(key)#2                     |
+            //    +-----------+---------------------+-------------------------+--------------------------------+
+            //    | c         | A(AF(v1#1), 'af')#2 | A(AF(v1#1), 'af')#3     | AF(af#3)                       |
+            //    +-----------+---------------------+-------------------------+--------------------------------+
+            //    NOTICE: Ref: SlotReference, A: Alias, AF: AggregateFunction, #x: ExprId x
+            // 2. collect local aggregate output expressions and local aggregate group by expression list

Review Comment:
   great comment
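
   For readers skimming the table, the three substitution cases can be sketched roughly as below. This is a simplified illustration, not the code in this PR: `SubstitutionSketch`, the toy `Expr` node, and the string-based naming of slots and aliases are all hypothetical stand-ins for the Nereids `Expression`, `Alias`, and `SlotReference` classes, and ExprIds are not modelled at all.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;

class SubstitutionSketch {
    /** Hypothetical toy expression node with structural equality; not the Nereids Expression. */
    static final class Expr {
        final String op;              // e.g. "slot:k1", "lit:1", "add", "sum", "alias:key1"
        final List<Expr> children;

        Expr(String op, Expr... children) {
            this.op = op;
            this.children = Arrays.asList(children);
        }

        boolean isSlot() {
            return op.startsWith("slot:");
        }

        @Override
        public boolean equals(Object o) {
            return o instanceof Expr && op.equals(((Expr) o).op) && children.equals(((Expr) o).children);
        }

        @Override
        public int hashCode() {
            return Objects.hash(op, children);
        }

        @Override
        public String toString() {
            return children.isEmpty() ? op : op + children;
        }
    }

    /** Appends to localOutputs and returns the map used to rewrite the global aggregate. */
    static Map<Expr, Expr> build(List<Expr> groupBys, List<Expr> aggFunctions, List<Expr> localOutputs) {
        Map<Expr, Expr> substitution = new LinkedHashMap<>();
        for (Expr key : groupBys) {
            if (substitution.containsKey(key)) {
                continue; // duplicated group by key, already handled
            }
            if (key.isSlot()) {
                // situation a: a plain slot passes through both phases unchanged
                substitution.put(key, key);
                localOutputs.add(key);
            } else {
                // situation b: alias the expression locally, refer to the alias's slot globally
                String name = "key" + localOutputs.size();
                localOutputs.add(new Expr("alias:" + name, key));
                substitution.put(key, new Expr("slot:" + name));
            }
        }
        for (Expr agg : aggFunctions) {
            if (substitution.containsKey(agg)) {
                continue;
            }
            // situation c: alias the function locally, apply the same function to that slot globally
            String name = "af" + localOutputs.size();
            localOutputs.add(new Expr("alias:" + name, agg));
            substitution.put(agg, new Expr(agg.op, new Expr("slot:" + name)));
        }
        return substitution;
    }

    public static void main(String[] args) {
        Expr k1 = new Expr("slot:k1");
        Expr k1Plus1 = new Expr("add", k1, new Expr("lit:1"));
        Expr sumV1 = new Expr("sum", new Expr("slot:v1"));
        List<Expr> localOutputs = new ArrayList<>();
        Map<Expr, Expr> substitution = build(Arrays.asList(k1, k1Plus1), Arrays.asList(sumV1), localOutputs);
        System.out.println("local outputs: " + localOutputs);
        substitution.forEach((from, to) -> System.out.println(from + " -> " + to));
    }
}
```

   The sketch keeps insertion order with a `LinkedHashMap` only to make the printed output easy to follow; the rule in the diff uses a plain hash map.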



##########
fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AggregateDisassemble.java:
##########
@@ -17,144 +17,156 @@
 
 package org.apache.doris.nereids.rules.rewrite;
 
-import org.apache.doris.analysis.FunctionName;
-import org.apache.doris.catalog.Catalog;
-import org.apache.doris.catalog.Function;
-import org.apache.doris.catalog.Function.CompareMode;
-import org.apache.doris.catalog.Type;
-import org.apache.doris.nereids.operators.Operator;
 import org.apache.doris.nereids.operators.plans.AggPhase;
 import org.apache.doris.nereids.operators.plans.logical.LogicalAggregate;
 import org.apache.doris.nereids.rules.Rule;
 import org.apache.doris.nereids.rules.RuleType;
+import org.apache.doris.nereids.trees.expressions.Alias;
 import org.apache.doris.nereids.trees.expressions.Expression;
 import org.apache.doris.nereids.trees.expressions.NamedExpression;
-import org.apache.doris.nereids.trees.expressions.Slot;
+import org.apache.doris.nereids.trees.expressions.SlotReference;
 import org.apache.doris.nereids.trees.expressions.functions.AggregateFunction;
+import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
+import org.apache.doris.nereids.trees.plans.GroupPlan;
 import org.apache.doris.nereids.trees.plans.Plan;
-import org.apache.doris.nereids.types.DataType;
+import org.apache.doris.nereids.trees.plans.logical.LogicalUnaryPlan;
 
-import com.clearspring.analytics.util.Lists;
-import com.google.common.base.Preconditions;
+import com.google.common.collect.Lists;
+import com.google.common.collect.Maps;
 
-import java.util.HashMap;
+import java.util.ArrayList;
 import java.util.List;
 import java.util.Map;
 import java.util.stream.Collectors;
 
 /**
- * TODO: if instance count is 1, shouldn't disassemble the agg operator
  * Used to generate the merge agg node for distributed execution.
- * Do this in following steps:
- *  1. clone output expr list, find all agg function
- *  2. set found agg function intermediaType
- *  3. create new child plan rooted at new local agg
- *  4. update the slot referenced by expr of merge agg
- *  5. create plan rooted at merge agg, return it.
+ * NOTICE: GLOBAL output expressions must keep the SAME ExprIds as the ORIGIN output expressions.
+ * If we have a query: SELECT SUM(v1 * v2) + 1 FROM t GROUP BY k + 1
+ * the initial plan is:
+ *   Aggregate(phase: [GLOBAL], outputExpr: [Alias(k + 1) #1, Alias(SUM(v1 * v2) + 1) #2], groupByExpr: [k + 1])
+ *   +-- childPlan
+ * we should rewrite to:
+ *   Aggregate(phase: [GLOBAL], outputExpr: [Alias(b) #1, Alias(SUM(a) + 1) #2], groupByExpr: [b])
+ *   +-- Aggregate(phase: [LOCAL], outputExpr: [SUM(v1 * v2) as a, (k + 1) as b], groupByExpr: [k + 1])
+ *       +-- childPlan
+ *
+ * TODO:
+ *     1. use different classes to represent the different aggregate phases
+ *     2. if instance count is 1, shouldn't disassemble the agg operator
+ *     3. we need another rule to remove duplicated expressions from the group by expression list
  */
 public class AggregateDisassemble extends OneRewriteRuleFactory {
 
     @Override
     public Rule<Plan> build() {
         return logicalAggregate().when(p -> {
-            LogicalAggregate logicalAggregation = p.getOperator();
-            return !logicalAggregation.isDisassembled();
+            LogicalAggregate logicalAggregate = p.getOperator();
+            return !logicalAggregate.isDisassembled();
         }).thenApply(ctx -> {
-            Plan plan = ctx.root;
-            Operator operator = plan.getOperator();
-            LogicalAggregate agg = (LogicalAggregate) operator;
-            List<NamedExpression> outputExpressionList = agg.getOutputExpressionList();
-            List<NamedExpression> intermediateAggExpressionList = Lists.newArrayList();
-            // TODO: shouldn't extract agg function from this field.
-            for (NamedExpression namedExpression : outputExpressionList) {
-                namedExpression = (NamedExpression) namedExpression.clone();
-                List<AggregateFunction> functionCallList =
-                        namedExpression.collect(org.apache.doris.catalog.AggregateFunction.class::isInstance);
-                // TODO: we will have another mechanism to get corresponding stale agg func.
-                for (AggregateFunction functionCall : functionCallList) {
-                    org.apache.doris.catalog.AggregateFunction staleAggFunc = findAggFunc(functionCall);
-                    Type staleIntermediateType = staleAggFunc.getIntermediateType();
-                    Type staleRetType = staleAggFunc.getReturnType();
-                    if (staleIntermediateType != null && !staleIntermediateType.equals(staleRetType)) {
-                        functionCall.setIntermediate(DataType.convertFromCatalogDataType(staleIntermediateType));
+            LogicalUnaryPlan<LogicalAggregate, GroupPlan> plan = ctx.root;
+            LogicalAggregate aggregate = plan.getOperator();
+            List<NamedExpression> originOutputExprs = aggregate.getOutputExpressionList();
+            List<Expression> originGroupByExprs = aggregate.getGroupByExpressionList();
+
+            // 1. generate a map from local aggregate output to global aggregate expr substitution.
+            //    inputSubstitutionMap use for replacing expression in global aggregate
+            //    replace rule is:
+            //        a: Expression is a group by key and is a slot reference. e.g. group by k1
+            //        b. Expression is a group by key and is an expression. e.g. group by k1 + 1
+            //        c. Expression is an aggregate function. e.g. sum(v1) in select list
+            //    +-----------+---------------------+-------------------------+--------------------------------+
+            //    | situation | origin expression   | local output expression | expression in global aggregate |
+            //    +-----------+---------------------+-------------------------+--------------------------------+
+            //    | a         | Ref(k1)#1           | Ref(k1)#1               | Ref(k1)#1                      |
+            //    +-----------+---------------------+-------------------------+--------------------------------+
+            //    | b         | Ref(k1)#1 + 1       | A(Ref(k1)#1 + 1, key)#2 | Ref(key)#2                     |
+            //    +-----------+---------------------+-------------------------+--------------------------------+
+            //    | c         | A(AF(v1#1), 'af')#2 | A(AF(v1#1), 'af')#3     | AF(af#3)                       |
+            //    +-----------+---------------------+-------------------------+--------------------------------+
+            //    NOTICE: Ref: SlotReference, A: Alias, AF: AggregateFunction, #x: ExprId x
+            // 2. collect local aggregate output expressions and local aggregate group by expression list
+            Map<Expression, Expression> inputSubstitutionMap = Maps.newHashMap();
+            List<Expression> localGroupByExprs = aggregate.getGroupByExpressionList();
+            List<NamedExpression> localOutputExprs = Lists.newArrayList();
+            for (Expression originGroupByExpr : originGroupByExprs) {
+                if (inputSubstitutionMap.containsKey(originGroupByExpr)) {
+                    continue;
+                }
+                if (originGroupByExpr instanceof SlotReference) {
+                    inputSubstitutionMap.put(originGroupByExpr, originGroupByExpr);
+                    localOutputExprs.add((SlotReference) originGroupByExpr);
+                } else {
+                    NamedExpression localOutputExpr = new Alias<>(originGroupByExpr, originGroupByExpr.toSql());
+                    inputSubstitutionMap.put(originGroupByExpr, localOutputExpr.toSlot());
+                    localOutputExprs.add(localOutputExpr);
+                }
+            }
+            for (NamedExpression originOutputExpr : originOutputExprs) {
+                List<AggregateFunction> aggregateFunctions
+                        = originOutputExpr.collect(AggregateFunction.class::isInstance);
+                for (AggregateFunction aggregateFunction : aggregateFunctions) {
+                    if (inputSubstitutionMap.containsKey(aggregateFunction)) {
+                        continue;
                     }
+                    NamedExpression localOutputExpr = new Alias<>(aggregateFunction, aggregateFunction.toSql());
+                    Expression substitutionValue = aggregateFunction.withChildren(
+                            Lists.newArrayList(localOutputExpr.toSlot()));
+                    inputSubstitutionMap.put(aggregateFunction, substitutionValue);
+                    localOutputExprs.add(localOutputExpr);
                 }
-                intermediateAggExpressionList.add(namedExpression);
             }
-            LogicalAggregate localAgg = new LogicalAggregate(
-                    agg.getGroupByExprList().stream().map(Expression::clone).collect(Collectors.toList()),
-                    intermediateAggExpressionList,
+
+            // 3. replace expression in globalOutputExprs and globalGroupByExprs
+            List<NamedExpression> globalOutputExprs = aggregate.getOutputExpressionList().stream()
+                    .map(e -> ExpressionReplacer.INSTANCE.visit(e, inputSubstitutionMap))
+                    .map(NamedExpression.class::cast)
+                    .collect(Collectors.toList());
+            List<Expression> globalGroupByExprs = localGroupByExprs.stream()
+                    .map(e -> ExpressionReplacer.INSTANCE.visit(e, inputSubstitutionMap)).collect(Collectors.toList());
+
+            // 4. generate new plan
+            LogicalAggregate localAggregate = new LogicalAggregate(
+                    localGroupByExprs,
+                    localOutputExprs,
                     true,
-                    AggPhase.FIRST
+                    AggPhase.LOCAL
             );
-
-            Plan childPlan = plan(localAgg, plan.child(0));
-            List<Slot> stalePlanOutputSlotList = plan.getOutput();
-            List<Slot> childOutputSlotList = childPlan.getOutput();
-            int childOutputSize = stalePlanOutputSlotList.size();
-            Preconditions.checkState(childOutputSize == childOutputSlotList.size());
-            Map<Slot, Slot> staleToNew = new HashMap<>();
-            for (int i = 0; i < stalePlanOutputSlotList.size(); i++) {
-                staleToNew.put(stalePlanOutputSlotList.get(i), childOutputSlotList.get(i));
-            }
-            List<Expression> groupByExpressionList = agg.getGroupByExprList();
-            for (int i = 0; i < groupByExpressionList.size(); i++) {
-                replaceSlot(staleToNew, groupByExpressionList, groupByExpressionList.get(i), i);
-            }
-            List<NamedExpression> mergeOutputExpressionList = agg.getOutputExpressionList();
-            for (int i = 0; i < mergeOutputExpressionList.size(); i++) {
-                replaceSlot(staleToNew, mergeOutputExpressionList, mergeOutputExpressionList.get(i), i);
-            }
-            LogicalAggregate mergeAgg = new LogicalAggregate(
-                    groupByExpressionList,
-                    mergeOutputExpressionList,
+            LogicalAggregate globalAggregate = new LogicalAggregate(
+                    globalGroupByExprs,
+                    globalOutputExprs,
                     true,
-                    AggPhase.FIRST_MERGE
+                    AggPhase.GLOBAL
             );
-            return plan(mergeAgg, childPlan);
+            return plan(globalAggregate, plan(localAggregate, plan.child(0)));
         }).toRule(RuleType.AGGREGATE_DISASSEMBLE);
     }
 
-    private org.apache.doris.catalog.AggregateFunction findAggFunc(AggregateFunction functionCall) {
-        FunctionName functionName = new FunctionName(functionCall.getName());
-        List<Expression> expressionList = functionCall.getArguments();
-        List<Type> staleTypeList = expressionList.stream().map(Expression::getDataType)
-                .map(DataType::toCatalogDataType).collect(Collectors.toList());
-        Function staleFuncDesc = new Function(functionName, staleTypeList,
-                functionCall.getDataType().toCatalogDataType(),
-                // I think an aggregate function will never have a variable length parameters
-                false);
-        Function staleFunc = Catalog.getCurrentCatalog()
-                .getFunction(staleFuncDesc, CompareMode.IS_IDENTICAL);
-        Preconditions.checkArgument(staleFunc instanceof org.apache.doris.catalog.AggregateFunction);
-        return  (org.apache.doris.catalog.AggregateFunction) staleFunc;
-    }
+    @SuppressWarnings("InnerClassMayBeStatic")
+    private static class ExpressionReplacer
+            extends ExpressionVisitor<Expression, Map<Expression, Expression>> {
+        private static final ExpressionReplacer INSTANCE = new ExpressionReplacer();
 
-    @SuppressWarnings("unchecked")
-    private <T extends Expression> void replaceSlot(Map<Slot, Slot> staleToNew,
-            List<T> expressionList, Expression root, int index) {
-        if (index != -1) {
-            if (root instanceof Slot) {
-                Slot v = staleToNew.get(root);
-                if (v == null) {
-                    return;
+        @Override
+        public Expression visit(Expression expr, Map<Expression, Expression> substitutionMap) {
+            // TODO: we need to do sub tree match and replace. but we do not have semanticEquals now.
+            //    e.g. a + 1 + 2 in output expression should be replaced by
+            //    (slot reference to update phase out (a + 1)) + 2, if we do group by a + 1
+            //   currently, we could only handle output expression same with group by expression
+            if (substitutionMap.containsKey(expr)) {
+                return substitutionMap.get(expr);
+            } else {
+                List<Expression> newChildren = new ArrayList<>();

Review Comment:
   I think you can extend DefaultExpressionRewriter and replace this line with `return super.visit(expr, substitutionMap)`
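
   A rough sketch of the shape this suggestion leads to, assuming the default rewriter recursively rewrites children and rebuilds a node only when a child changed. `RewriterSketch`, `Expr`, `DefaultRewriter`, and `Replacer` are hypothetical stand-ins written for this illustration; the real `DefaultExpressionRewriter` may differ in signature and behavior.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;

class RewriterSketch {
    /** Hypothetical toy expression node with structural equality. */
    static final class Expr {
        final String op;
        final List<Expr> children;

        Expr(String op, Expr... children) {
            this(op, Arrays.asList(children));
        }

        private Expr(String op, List<Expr> children) {
            this.op = op;
            this.children = children;
        }

        Expr withChildren(List<Expr> newChildren) {
            return new Expr(op, newChildren);
        }

        @Override
        public boolean equals(Object o) {
            return o instanceof Expr && op.equals(((Expr) o).op) && children.equals(((Expr) o).children);
        }

        @Override
        public int hashCode() {
            return Objects.hash(op, children);
        }

        @Override
        public String toString() {
            return children.isEmpty() ? op : op + children;
        }
    }

    /** Stand-in for a default rewriter: visits children, rebuilds the node only if one changed. */
    static class DefaultRewriter<C> {
        Expr visit(Expr expr, C context) {
            List<Expr> newChildren = new ArrayList<>();
            boolean changed = false;
            for (Expr child : expr.children) {
                Expr newChild = visit(child, context); // dispatches to the subclass override
                changed |= newChild != child;
                newChildren.add(newChild);
            }
            return changed ? expr.withChildren(newChildren) : expr;
        }
    }

    /** The replacer only handles the map hit; everything else is delegated to the default traversal. */
    static final class Replacer extends DefaultRewriter<Map<Expr, Expr>> {
        @Override
        Expr visit(Expr expr, Map<Expr, Expr> substitutionMap) {
            Expr replacement = substitutionMap.get(expr);
            return replacement != null ? replacement : super.visit(expr, substitutionMap);
        }
    }

    public static void main(String[] args) {
        Expr sumV1 = new Expr("sum", new Expr("slot:v1"));
        Expr output = new Expr("add", sumV1, new Expr("lit:1"));
        Map<Expr, Expr> substitutionMap = new HashMap<>();
        substitutionMap.put(sumV1, new Expr("sum", new Expr("slot:af")));
        System.out.println(new Replacer().visit(output, substitutionMap));
    }
}
```

   With this split the replacer no longer needs its own child-collection loop, which is the part of the `else` branch the comment refers to.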



##########
fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/AggregateDisassembleTest.java:
##########
@@ -0,0 +1,322 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.rules.rewrite.logical;
+
+import org.apache.doris.catalog.AggregateType;
+import org.apache.doris.catalog.Column;
+import org.apache.doris.catalog.Table;
+import org.apache.doris.catalog.Type;
+import org.apache.doris.nereids.PlannerContext;
+import org.apache.doris.nereids.jobs.JobContext;
+import org.apache.doris.nereids.jobs.rewrite.RewriteTopDownJob;
+import org.apache.doris.nereids.memo.Memo;
+import org.apache.doris.nereids.operators.plans.AggPhase;
+import org.apache.doris.nereids.operators.plans.logical.LogicalAggregate;
+import org.apache.doris.nereids.operators.plans.logical.LogicalOlapScan;
+import org.apache.doris.nereids.properties.PhysicalProperties;
+import org.apache.doris.nereids.rules.rewrite.AggregateDisassemble;
+import org.apache.doris.nereids.trees.expressions.Add;
+import org.apache.doris.nereids.trees.expressions.Alias;
+import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.Literal;
+import org.apache.doris.nereids.trees.expressions.NamedExpression;
+import org.apache.doris.nereids.trees.expressions.SlotReference;
+import org.apache.doris.nereids.trees.expressions.functions.Sum;
+import org.apache.doris.nereids.trees.plans.Plan;
+import org.apache.doris.nereids.trees.plans.Plans;
+import org.apache.doris.nereids.trees.plans.logical.LogicalUnaryPlan;
+import org.apache.doris.qe.ConnectContext;
+
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.Lists;
+import org.junit.jupiter.api.Assertions;
+import org.junit.jupiter.api.BeforeAll;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.TestInstance;
+
+import java.util.List;
+
+@TestInstance(TestInstance.Lifecycle.PER_CLASS)
+public class AggregateDisassembleTest implements Plans {
+    private Plan rStudent;
+
+    @BeforeAll
+    public final void beforeAll() {
+        Table student = new Table(0L, "student", Table.TableType.OLAP,
+                ImmutableList.of(new Column("id", Type.INT, true, AggregateType.NONE, true, "0", ""),
+                        new Column("name", Type.STRING, true, AggregateType.NONE, true, "", ""),
+                        new Column("age", Type.INT, true, AggregateType.NONE, true, "", "")));
+        rStudent = plan(new LogicalOlapScan(student, ImmutableList.of("student")));
+    }
+
+    /**
+     * the initial plan is:
+     *   Aggregate(phase: [GLOBAL], outputExpr: [age, SUM(id) as sum], groupByExpr: [age])
+     *   +--childPlan(id, name, age)
+     * we should rewrite to:
+     *   Aggregate(phase: [GLOBAL], outputExpr: [a, SUM(b) as c], groupByExpr: [a])
+     *   +--Aggregate(phase: [LOCAL], outputExpr: [age as a, SUM(id) as b], groupByExpr: [age])
+     *       +--childPlan(id, name, age)
+     */
+    @Test
+    public void slotReferenceGroupBy() {
+        List<Expression> groupExpressionList = Lists.newArrayList(
+                rStudent.getOutput().get(2).toSlot());
+        List<NamedExpression> outputExpressionList = Lists.newArrayList(
+                rStudent.getOutput().get(2).toSlot(),
+                new Alias<>(new Sum(rStudent.getOutput().get(0).toSlot()), "sum"));
+        Plan root = plan(new LogicalAggregate(groupExpressionList, outputExpressionList), rStudent);
+
+        Memo memo = new Memo();
+        memo.initialize(root);
+
+        PlannerContext plannerContext = new PlannerContext(memo, new ConnectContext());
+        JobContext jobContext = new JobContext(plannerContext, new PhysicalProperties(), 0);
+        RewriteTopDownJob rewriteTopDownJob = new RewriteTopDownJob(memo.getRoot(),
+                ImmutableList.of(new AggregateDisassemble().build()), jobContext);
+        plannerContext.pushJob(rewriteTopDownJob);
+        plannerContext.getJobScheduler().executeJobPool(plannerContext);
+
+        Plan after = memo.copyOut();
+
+        Assertions.assertTrue(after instanceof LogicalUnaryPlan);
+        Assertions.assertTrue(after.getOperator() instanceof LogicalAggregate);
+        Assertions.assertTrue(after.child(0) instanceof LogicalUnaryPlan);
+        LogicalAggregate global = (LogicalAggregate) after.getOperator();
+        LogicalAggregate local = (LogicalAggregate) after.child(0).getOperator();
+        Assertions.assertEquals(AggPhase.GLOBAL, global.getAggPhase());
+        Assertions.assertEquals(AggPhase.LOCAL, local.getAggPhase());
+
+        Expression localOutput0 = rStudent.getOutput().get(2).toSlot();
+        Expression localOutput1 = new Sum(rStudent.getOutput().get(0).toSlot());
+        Expression localGroupBy = rStudent.getOutput().get(2).toSlot();
+
+        Assertions.assertEquals(2, local.getOutputExpressionList().size());
+        Assertions.assertTrue(local.getOutputExpressionList().get(0) instanceof SlotReference);
+        Assertions.assertEquals(localOutput0, local.getOutputExpressionList().get(0));
+        Assertions.assertTrue(local.getOutputExpressionList().get(1) instanceof Alias);
+        Assertions.assertEquals(localOutput1, local.getOutputExpressionList().get(1).child(0));
+        Assertions.assertEquals(1, local.getGroupByExpressionList().size());
+        Assertions.assertEquals(localGroupBy, local.getGroupByExpressionList().get(0));
+
+        Expression globalOutput0 = local.getOutputExpressionList().get(0).toSlot();
+        Expression globalOutput1 = new Sum(local.getOutputExpressionList().get(1).toSlot());
+        Expression globalGroupBy = local.getOutputExpressionList().get(0).toSlot();
+
+        Assertions.assertEquals(2, global.getOutputExpressionList().size());
+        Assertions.assertTrue(global.getOutputExpressionList().get(0) instanceof SlotReference);
+        Assertions.assertEquals(globalOutput0, global.getOutputExpressionList().get(0));
+        Assertions.assertTrue(global.getOutputExpressionList().get(1) instanceof Alias);
+        Assertions.assertEquals(globalOutput1, global.getOutputExpressionList().get(1).child(0));
+        Assertions.assertEquals(1, global.getGroupByExpressionList().size());
+        Assertions.assertEquals(globalGroupBy, global.getGroupByExpressionList().get(0));
+
+        // check id:
+        Assertions.assertEquals(outputExpressionList.get(0).getExprId(),
+                global.getOutputExpressionList().get(0).getExprId());
+        Assertions.assertEquals(outputExpressionList.get(1).getExprId(),
+                global.getOutputExpressionList().get(1).getExprId());
+    }
+
+    /**
+     * the initial plan is:
+     *   Aggregate(phase: [GLOBAL], outputExpr: [(age + 1) as key, SUM(id) as sum], groupByExpr: [age + 1])
+     *   +--childPlan(id, name, age)
+     * we should rewrite to:
+     *   Aggregate(phase: [GLOBAL], outputExpr: [a, SUM(b) as c], groupByExpr: [a])
+     *   +--Aggregate(phase: [LOCAL], outputExpr: [(age + 1) as a, SUM(id) as b], groupByExpr: [age + 1])
+     *       +--childPlan(id, name, age)
+     */
+    @Test
+    public void aliasGroupBy() {
+        List<Expression> groupExpressionList = Lists.newArrayList(
+                new Add<>(rStudent.getOutput().get(2).toSlot(), new Literal(1)));
+        List<NamedExpression> outputExpressionList = Lists.newArrayList(
+                new Alias<>(new Add<>(rStudent.getOutput().get(2).toSlot(), new Literal(1)), "key"),
+                new Alias<>(new Sum(rStudent.getOutput().get(0).toSlot()), "sum"));
+        Plan root = plan(new LogicalAggregate(groupExpressionList, outputExpressionList), rStudent);
+
+        Memo memo = new Memo();
+        memo.initialize(root);
+
+        PlannerContext plannerContext = new PlannerContext(memo, new ConnectContext());
+        JobContext jobContext = new JobContext(plannerContext, new PhysicalProperties(), 0);
+        RewriteTopDownJob rewriteTopDownJob = new RewriteTopDownJob(memo.getRoot(),
+                ImmutableList.of(new AggregateDisassemble().build()), jobContext);
+        plannerContext.pushJob(rewriteTopDownJob);
+        plannerContext.getJobScheduler().executeJobPool(plannerContext);
+
+        Plan after = memo.copyOut();
+
+        Assertions.assertTrue(after instanceof LogicalUnaryPlan);
+        Assertions.assertTrue(after.getOperator() instanceof LogicalAggregate);
+        Assertions.assertTrue(after.child(0) instanceof LogicalUnaryPlan);
+        LogicalAggregate global = (LogicalAggregate) after.getOperator();
+        LogicalAggregate local = (LogicalAggregate) after.child(0).getOperator();
+        Assertions.assertEquals(AggPhase.GLOBAL, global.getAggPhase());
+        Assertions.assertEquals(AggPhase.LOCAL, local.getAggPhase());
+
+        Expression localOutput0 = new Add<>(rStudent.getOutput().get(2).toSlot(), new Literal(1));
+        Expression localOutput1 = new Sum(rStudent.getOutput().get(0).toSlot());
+        Expression localGroupBy = new Add<>(rStudent.getOutput().get(2).toSlot(), new Literal(1));
+
+        Assertions.assertEquals(2, local.getOutputExpressionList().size());
+        Assertions.assertTrue(local.getOutputExpressionList().get(0) instanceof Alias);
+        Assertions.assertEquals(localOutput0, local.getOutputExpressionList().get(0).child(0));
+        Assertions.assertTrue(local.getOutputExpressionList().get(1) instanceof Alias);
+        Assertions.assertEquals(localOutput1, local.getOutputExpressionList().get(1).child(0));
+        Assertions.assertEquals(1, local.getGroupByExpressionList().size());
+        Assertions.assertEquals(localGroupBy, local.getGroupByExpressionList().get(0));
+
+        Expression globalOutput0 = local.getOutputExpressionList().get(0).toSlot();
+        Expression globalOutput1 = new Sum(local.getOutputExpressionList().get(1).toSlot());
+        Expression globalGroupBy = local.getOutputExpressionList().get(0).toSlot();
+
+        Assertions.assertEquals(2, global.getOutputExpressionList().size());
+        Assertions.assertTrue(global.getOutputExpressionList().get(0) instanceof Alias);
+        Assertions.assertEquals(globalOutput0, global.getOutputExpressionList().get(0).child(0));
+        Assertions.assertTrue(global.getOutputExpressionList().get(1) instanceof Alias);
+        Assertions.assertEquals(globalOutput1, global.getOutputExpressionList().get(1).child(0));
+        Assertions.assertEquals(1, global.getGroupByExpressionList().size());
+        Assertions.assertEquals(globalGroupBy, global.getGroupByExpressionList().get(0));
+
+        // check id:
+        Assertions.assertEquals(outputExpressionList.get(0).getExprId(),
+                global.getOutputExpressionList().get(0).getExprId());
+        Assertions.assertEquals(outputExpressionList.get(1).getExprId(),
+                global.getOutputExpressionList().get(1).getExprId());
+    }
+
+    /**
+     * the initial plan is:
+     *   Aggregate(phase: [GLOBAL], outputExpr: [SUM(id) as sum], groupByExpr: [])
+     *   +--childPlan(id, name, age)
+     * we should rewrite to:
+     *   Aggregate(phase: [GLOBAL], outputExpr: [SUM(b) as b], groupByExpr: [])
+     *   +--Aggregate(phase: [LOCAL], outputExpr: [SUM(id) as a], groupByExpr: [])
+     *       +--childPlan(id, name, age)
+     */
+    @Test
+    public void globalAggregate() {
+        List<Expression> groupExpressionList = Lists.newArrayList();
+        List<NamedExpression> outputExpressionList = Lists.newArrayList(
+                new Alias<>(new Sum(rStudent.getOutput().get(0).toSlot()), "sum"));
+        Plan root = plan(new LogicalAggregate(groupExpressionList, outputExpressionList), rStudent);
+
+        Memo memo = new Memo();
+        memo.initialize(root);
+
+        PlannerContext plannerContext = new PlannerContext(memo, new ConnectContext());
+        JobContext jobContext = new JobContext(plannerContext, new PhysicalProperties(), 0);
+        RewriteTopDownJob rewriteTopDownJob = new RewriteTopDownJob(memo.getRoot(),
+                ImmutableList.of(new AggregateDisassemble().build()), jobContext);
+        plannerContext.pushJob(rewriteTopDownJob);
+        plannerContext.getJobScheduler().executeJobPool(plannerContext);
+
+        Plan after = memo.copyOut();
+
+        Assertions.assertTrue(after instanceof LogicalUnaryPlan);
+        Assertions.assertTrue(after.getOperator() instanceof LogicalAggregate);
+        Assertions.assertTrue(after.child(0) instanceof LogicalUnaryPlan);
+        LogicalAggregate global = (LogicalAggregate) after.getOperator();
+        LogicalAggregate local = (LogicalAggregate) after.child(0).getOperator();
+        Assertions.assertEquals(AggPhase.GLOBAL, global.getAggPhase());
+        Assertions.assertEquals(AggPhase.LOCAL, local.getAggPhase());
+
+        Expression localOutput0 = new Sum(rStudent.getOutput().get(0).toSlot());
+
+        Assertions.assertEquals(1, local.getOutputExpressionList().size());
+        Assertions.assertTrue(local.getOutputExpressionList().get(0) instanceof Alias);
+        Assertions.assertEquals(localOutput0, local.getOutputExpressionList().get(0).child(0));
+        Assertions.assertEquals(0, local.getGroupByExpressionList().size());
+
+        Expression globalOutput0 = new Sum(local.getOutputExpressionList().get(0).toSlot());
+
+        Assertions.assertEquals(1, global.getOutputExpressionList().size());
+        Assertions.assertTrue(global.getOutputExpressionList().get(0) instanceof Alias);
+        Assertions.assertEquals(globalOutput0, global.getOutputExpressionList().get(0).child(0));
+        Assertions.assertEquals(0, global.getGroupByExpressionList().size());
+
+        // check id:
+        Assertions.assertEquals(outputExpressionList.get(0).getExprId(),
+                global.getOutputExpressionList().get(0).getExprId());
+    }
+
+    /**
+     * the initial plan is:
+     *   Aggregate(phase: [GLOBAL], outputExpr: [SUM(id) as sum], groupByExpr: [age])
+     *   +--childPlan(id, name, age)
+     * we should rewrite to:
+     *   Aggregate(phase: [GLOBAL], outputExpr: [SUM(b) as c], groupByExpr: [a])

Review Comment:
   Shouldn't alias `c` stay the same as the original alias `sum`?
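
For readers following the Javadoc above, here is a minimal sketch of the two-phase SUM it describes, written with plain Java collections rather than Nereids classes. All names (`TwoPhaseSumSketch`, `Row`, `localAggregate`, `globalAggregate`) are hypothetical. The point is only that the LOCAL phase produces partial sums per `age` and the GLOBAL phase sums those partials, so the final column carries the same value the original `sum` alias named.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;

// Hypothetical illustration only; not Doris/Nereids code.
final class TwoPhaseSumSketch {
    static final class Row {
        final long id;
        final int age;

        Row(long id, int age) {
            this.id = id;
            this.age = age;
        }
    }

    // LOCAL phase: each partition computes SUM(id) per age over its own rows.
    static Map<Integer, Long> localAggregate(List<Row> partition) {
        Map<Integer, Long> partial = new TreeMap<>();
        for (Row r : partition) {
            partial.merge(r.age, r.id, Long::sum);
        }
        return partial;
    }

    // GLOBAL phase: sum the partial sums that share a group key.
    static Map<Integer, Long> globalAggregate(List<Map<Integer, Long>> partials) {
        Map<Integer, Long> result = new TreeMap<>();
        for (Map<Integer, Long> partial : partials) {
            partial.forEach((age, sum) -> result.merge(age, sum, Long::sum));
        }
        return result;
    }

    public static void main(String[] args) {
        List<Row> p1 = Arrays.asList(new Row(1, 20), new Row(2, 21));
        List<Row> p2 = Arrays.asList(new Row(3, 20), new Row(4, 21), new Row(5, 22));
        List<Map<Integer, Long>> partials = new ArrayList<>();
        partials.add(localAggregate(p1));
        partials.add(localAggregate(p2));
        System.out.println(globalAggregate(partials));
    }
}
```

Merging partial sums gives the same result as summing all rows in one pass, which is why the rewrite can keep the outer alias and its id unchanged.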





[GitHub] [doris] morrySnow commented on a diff in pull request #10659: [enhancement](nereids) make aggregate works

Posted by GitBox <gi...@apache.org>.
morrySnow commented on code in PR #10659:
URL: https://github.com/apache/doris/pull/10659#discussion_r915540173


##########
fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Add.java:
##########
@@ -57,6 +57,6 @@ public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
 
 
     public String toString() {
-        return sql();
+        return left().toString() + ' ' + getArithmeticOperator().toString() + ' ' + right().sql();

Review Comment:
   so sleepy



##########
fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/Sum.java:
##########
@@ -51,7 +51,7 @@ public DataType getDataType() {
 
     @Override
     public boolean nullable() {
-        return false;
+        return child().nullable();

Review Comment:
   yes





[GitHub] [doris] morrySnow commented on a diff in pull request #10659: [enhancement](nereids) make aggregate works

Posted by GitBox <gi...@apache.org>.
morrySnow commented on code in PR #10659:
URL: https://github.com/apache/doris/pull/10659#discussion_r915560690


##########
fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java:
##########
@@ -114,60 +134,96 @@ public PlanFragment visit(Plan plan, PlanTranslatorContext context) {
      * Translate Agg.
      */
     @Override
-    public PlanFragment visitPhysicalAggregation(
-            PhysicalUnaryPlan<PhysicalAggregation, Plan> agg, PlanTranslatorContext context) {
-
+    public PlanFragment visitPhysicalAggregate(
+            PhysicalUnaryPlan<PhysicalAggregate, Plan> agg, PlanTranslatorContext context) {
         PlanFragment inputPlanFragment = visit(agg.child(0), context);
-
-        AggregationNode aggregationNode;
-        List<Slot> slotList = new ArrayList<>();
-        PhysicalAggregation physicalAggregation = agg.getOperator();
-        AggregateInfo.AggPhase phase = physicalAggregation.getAggPhase().toExec();
-
-        List<Expression> groupByExpressionList = physicalAggregation.getGroupByExprList();
+        PhysicalAggregate physicalAggregate = agg.getOperator();
+
+        // TODO: stale planner generate aggregate tuple in a special way. tuple include 2 parts:
+        //    1. group by expressions: removing duplicate expressions add to tuple
+        //    2. agg functions: only removing duplicate agg functions in output expression should appear in tuple.
+        //       e.g. select sum(v1) + 1, sum(v1) + 2 from t1 should only generate one sum(v1) in tuple
+        //    We need:
+        //    1. add a project after agg, if output expressions include agg function as a expression tree leaf.
+        //    2. introduce canonicalized, semanticEquals and deterministic in Expression
+        //       for removing duplicate.
+        List<Expression> groupByExpressionList = physicalAggregate.getGroupByExprList();
+        List<NamedExpression> outputExpressionList = physicalAggregate.getOutputExpressionList();
+
+        // 1. generate slot reference for each group expression
+        List<SlotReference> groupSlotList = Lists.newArrayList();
+        for (Expression e : groupByExpressionList) {
+            if (e instanceof SlotReference && outputExpressionList.stream().anyMatch(o -> o.contains(e::equals))) {
+                groupSlotList.add((SlotReference) e);
+            } else {
+                groupSlotList.add(new SlotReference(e.sql(), e.getDataType(), e.nullable(), Collections.emptyList()));
+            }
+        }
         ArrayList<Expr> execGroupingExpressions = groupByExpressionList.stream()
-                // Since output of plan doesn't contain the slots of groupBy, which is actually needed by
-                // the BE execution, so we have to collect them and add to the slotList to generate corresponding
-                // TupleDesc.
-                .peek(x -> slotList.addAll(x.collect(SlotReference.class::isInstance)))
                 .map(e -> ExpressionTranslator.translate(e, context)).collect(Collectors.toCollection(ArrayList::new));
-        slotList.addAll(agg.getOutput());
-        TupleDescriptor outputTupleDesc = generateTupleDesc(slotList, context, null);
-
-        List<NamedExpression> outputExpressionList = physicalAggregation.getOutputExpressionList();
-        ArrayList<FunctionCallExpr> execAggExpressions = outputExpressionList.stream()
-                .map(e -> e.<List<AggregateFunction>>collect(AggregateFunction.class::isInstance))
+        // 2. collect agg functions and generate agg function to slot reference map
+        List<Slot> aggFunctionOutput = Lists.newArrayList();
+        List<AggregateFunction> aggregateFunctionList = outputExpressionList.stream()
+                .filter(o -> o.contains(AggregateFunction.class::isInstance))
+                .peek(o -> aggFunctionOutput.add(o.toSlot()))
+                .map(o -> (List<AggregateFunction>) o.collect(AggregateFunction.class::isInstance))
                 .flatMap(List::stream)
+                .collect(Collectors.toList());
+        ArrayList<FunctionCallExpr> execAggExpressions = aggregateFunctionList.stream()
                 .map(x -> (FunctionCallExpr) ExpressionTranslator.translate(x, context))
                 .collect(Collectors.toCollection(ArrayList::new));
 
-        List<Expression> partitionExpressionList = physicalAggregation.getPartitionExprList();
+        // 3. generate output tuple
+        // TODO: currently, we only support sum(a), if we want to support sum(a) + 1, we need to
+        //  split merge agg to project(agg) and generate tuple like what first phase agg do.
+        List<Slot> slotList = Lists.newArrayList();
+        TupleDescriptor outputTupleDesc;
+        if (agg.getOperator().getAggPhase() == AggPhase.FIRST_MERGE) {
+            slotList.addAll(groupSlotList);
+            slotList.addAll(aggFunctionOutput);
+            outputTupleDesc = generateTupleDesc(slotList, null, context);
+        } else {
+            outputTupleDesc = generateTupleDesc(agg.getOutput(), null, context);
+        }
+
+        // process partition list
+        List<Expression> partitionExpressionList = physicalAggregate.getPartitionExprList();
         List<Expr> execPartitionExpressions = partitionExpressionList.stream()
-                .map(e -> (FunctionCallExpr) ExpressionTranslator.translate(e, context)).collect(Collectors.toList());
+                .map(e -> ExpressionTranslator.translate(e, context)).collect(Collectors.toList());
+        DataPartition mergePartition = DataPartition.UNPARTITIONED;
+        if (CollectionUtils.isNotEmpty(execPartitionExpressions)) {

Review Comment:
   Yes, you are right.





[GitHub] [doris] morrySnow commented on a diff in pull request #10659: [enhancement](nereids) make aggregate works

Posted by GitBox <gi...@apache.org>.
morrySnow commented on code in PR #10659:
URL: https://github.com/apache/doris/pull/10659#discussion_r916497641


##########
fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java:
##########
@@ -112,62 +131,92 @@ public PlanFragment visit(Plan plan, PlanTranslatorContext context) {
 
     /**
      * Translate Agg.
+     * todo: support DISTINCT
      */
     @Override
-    public PlanFragment visitPhysicalAggregation(
-            PhysicalUnaryPlan<PhysicalAggregation, Plan> agg, PlanTranslatorContext context) {
-
-        PlanFragment inputPlanFragment = visit(agg.child(0), context);
-
-        AggregationNode aggregationNode;
-        List<Slot> slotList = new ArrayList<>();
-        PhysicalAggregation physicalAggregation = agg.getOperator();
-        AggregateInfo.AggPhase phase = physicalAggregation.getAggPhase().toExec();
-
-        List<Expression> groupByExpressionList = physicalAggregation.getGroupByExprList();
+    public PlanFragment visitPhysicalAggregate(
+            PhysicalUnaryPlan<PhysicalAggregate, Plan> aggregate, PlanTranslatorContext context) {
+        PlanFragment inputPlanFragment = visit(aggregate.child(0), context);
+        PhysicalAggregate physicalAggregate = aggregate.getOperator();
+
+        // TODO: stale planner generate aggregate tuple in a special way. tuple include 2 parts:
+        //    1. group by expressions: removing duplicate expressions add to tuple
+        //    2. agg functions: only removing duplicate agg functions in output expression should appear in tuple.
+        //       e.g. select sum(v1) + 1, sum(v1) + 2 from t1 should only generate one sum(v1) in tuple
+        //    We need:
+        //    1. add a project after agg, if agg function is not the top output expression.
+        //    2. introduce canonicalized, semanticEquals and deterministic in Expression
+        //       for removing duplicate.
+        List<Expression> groupByExpressionList = physicalAggregate.getGroupByExprList();
+        List<NamedExpression> outputExpressionList = physicalAggregate.getOutputExpressionList();
+
+        // 1. generate slot reference for each group expression
+        List<SlotReference> groupSlotList = Lists.newArrayList();
+        for (Expression e : groupByExpressionList) {
+            if (e instanceof SlotReference && outputExpressionList.stream().anyMatch(o -> o.anyMatch(e::equals))) {
+                groupSlotList.add((SlotReference) e);
+            } else {
+                groupSlotList.add(new SlotReference(e.sql(), e.getDataType(), e.nullable(), Collections.emptyList()));
+            }
+        }
         ArrayList<Expr> execGroupingExpressions = groupByExpressionList.stream()
-                // Since output of plan doesn't contain the slots of groupBy, which is actually needed by
-                // the BE execution, so we have to collect them and add to the slotList to generate corresponding
-                // TupleDesc.
-                .peek(x -> slotList.addAll(x.collect(SlotReference.class::isInstance)))
                 .map(e -> ExpressionTranslator.translate(e, context)).collect(Collectors.toCollection(ArrayList::new));
-        slotList.addAll(agg.getOutput());
-        TupleDescriptor outputTupleDesc = generateTupleDesc(slotList, context, null);
-
-        List<NamedExpression> outputExpressionList = physicalAggregation.getOutputExpressionList();
-        ArrayList<FunctionCallExpr> execAggExpressions = outputExpressionList.stream()
-                .map(e -> e.<List<AggregateFunction>>collect(AggregateFunction.class::isInstance))
+        // 2. collect agg functions and generate agg function to slot reference map

Review Comment:
   In the translator, this cannot happen at all.





[GitHub] [doris] EmmyMiao87 commented on a diff in pull request #10659: [enhancement](nereids) make aggregate works

Posted by GitBox <gi...@apache.org>.
EmmyMiao87 commented on code in PR #10659:
URL: https://github.com/apache/doris/pull/10659#discussion_r916687709


##########
fe/fe-core/src/main/java/org/apache/doris/nereids/analyzer/UnboundFunction.java:
##########
@@ -52,6 +54,14 @@ public List<Expression> getArguments() {
         return children();
     }
 
+    @Override
+    public String sql() throws UnboundException {

Review Comment:
   The three names toString, toSql, and toDigest should probably be unified.





[GitHub] [doris] EmmyMiao87 commented on a diff in pull request #10659: [enhancement](nereids) make aggregate works

Posted by GitBox <gi...@apache.org>.
EmmyMiao87 commented on code in PR #10659:
URL: https://github.com/apache/doris/pull/10659#discussion_r916467348


##########
fe/fe-core/src/main/java/org/apache/doris/analysis/FunctionCallExpr.java:
##########
@@ -1257,6 +1257,16 @@ public String forJSON(String str) {
 
     @Override
     public void finalizeImplForNereids() throws AnalysisException {
-        super.finalizeImplForNereids();
+        if (fnName.getFunction().equalsIgnoreCase("sum")) {
+            // Prevent the cast type in vector exec engine
+            Type childType = getChild(0).type.getMaxResolutionType();
+            fn = getBuiltinFunction(fnName.getFunction(), new Type[]{childType},
+                    Function.CompareMode.IS_NONSTRICT_SUPERTYPE_OF);
+            type = fn.getReturnType();
+        }
+    }
+
+    public void setMergeAggFn(boolean mergeAggFn) {

Review Comment:
   It seems that this value should be set when expr is translated



##########
fe/fe-core/src/main/java/org/apache/doris/nereids/glue/LogicalPlanAdapter.java:
##########
@@ -77,4 +77,8 @@ public void setResultExprs(List<Expr> resultExprs) {
     public void setColLabels(ArrayList<String> colLabels) {
         this.colLabels = colLabels;
     }
+
+    public String toDigest() {

Review Comment:
   // Add todo comment



##########
fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PlanTranslatorContext.java:
##########
@@ -66,32 +65,28 @@ public TupleDescriptor generateTupleDesc() {
         return descTable.createTupleDescriptor();
     }
 
-    public PlanNodeId nextNodeId() {
+    public PlanFragmentId nextFragmentId() {
+        return fragmentIdGenerator.getNextId();
+    }
+
+    public PlanNodeId nextPlanNodeId() {
         return nodeIdGenerator.getNextId();
     }
 
     public SlotDescriptor addSlotDesc(TupleDescriptor t) {
         return descTable.addSlotDescriptor(t);
     }
 
-    public SlotDescriptor addSlotDesc(TupleDescriptor t, int id) {
-        return descTable.addSlotDescriptor(t, id);
-    }
-
-    public PlanFragmentId nextFragmentId() {
-        return fragmentIdGenerator.getNextId();
-    }
-
     public void addPlanFragment(PlanFragment planFragment) {
         this.planFragmentList.add(planFragment);
     }
 
-    public void addSlotRefMapping(Expression expression, Expr expr) {
-        expressionToExecExpr.put(expression, expr);
+    public void addIdPair(ExprId exprId, SlotRef slotRef) {

Review Comment:
   ```suggestion
       public void addExprIdPair(ExprId exprId, SlotRef slotRef) {
   ```



##########
fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java:
##########
@@ -288,23 +333,24 @@ public PlanFragment visitPhysicalHashJoin(
             leftFragment.setPlanRoot(crossJoinNode);
             context.addPlanFragment(leftFragment);
             return leftFragment;
+        } else {
+            Expression eqJoinExpression = physicalHashJoin.getCondition().get();
+            List<Expr> execEqConjunctList = ExpressionUtils.extractConjunct(eqJoinExpression).stream()
+                    .map(EqualTo.class::cast)
+                    .map(e -> swapEqualToForChildrenOrder(e, hashJoin.left().getOutput()))
+                    .map(e -> ExpressionTranslator.translate(e, context))
+                    .collect(Collectors.toList());
+
+            HashJoinNode hashJoinNode = new HashJoinNode(context.nextPlanNodeId(), leftFragmentPlanRoot,
+                    rightFragmentPlanRoot,
+                    JoinType.toJoinOperator(physicalHashJoin.getJoinType()), execEqConjunctList, Lists.newArrayList());
+
+            hashJoinNode.setDistributionMode(DistributionMode.BROADCAST);

Review Comment:
   Could PlanNode also get a finalizeForNereids function that uniformly sets properties like this one?



##########
fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java:
##########
@@ -263,23 +313,18 @@ public PlanFragment visitPhysicalHashJoin(
         // NOTICE: We must visit from right to left, to ensure the last fragment is root fragment
         PlanFragment rightFragment = visit(hashJoin.child(1), context);
         PlanFragment leftFragment = visit(hashJoin.child(0), context);
-        PhysicalHashJoin physicalHashJoin = hashJoin.getOperator();
-
-        //        Expression predicateExpr = physicalHashJoin.getCondition().get();
-        //        List<Expression> eqExprList = Utils.getEqConjuncts(hashJoin.child(0).getOutput(),
-        //                hashJoin.child(1).getOutput(), predicateExpr);
-        JoinType joinType = physicalHashJoin.getJoinType();
-
         PlanNode leftFragmentPlanRoot = leftFragment.getPlanRoot();
         PlanNode rightFragmentPlanRoot = rightFragment.getPlanRoot();
+        PhysicalHashJoin physicalHashJoin = hashJoin.getOperator();
+        JoinType joinType = physicalHashJoin.getJoinType();
 
         if (joinType.equals(JoinType.CROSS_JOIN)

Review Comment:
   A cross join can only use the nested loop join physical operator, so this case should be handled in visitPhysicalNestedLoopJoin, not here.
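
To illustrate why a cross join fits a nested loop rather than a hash join: a hash join needs at least one equality key to build and probe its table, while a cross join has no condition at all and simply pairs every left row with every right row. The sketch below is plain Java with hypothetical names (`JoinSketch`, `crossJoin`, `hashEquiJoin`), not the Doris operators.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Hypothetical illustration only; not Doris/Nereids code.
final class JoinSketch {
    // Nested loop cross join: no join key is needed, every pair is emitted.
    static List<int[]> crossJoin(List<Integer> left, List<Integer> right) {
        List<int[]> out = new ArrayList<>();
        for (int l : left) {
            for (int r : right) {
                out.add(new int[] {l, r});
            }
        }
        return out;
    }

    // Hash equi-join: build a table keyed on the right value, probe with the left value.
    // Without an equality key there would be nothing to hash on.
    static List<int[]> hashEquiJoin(List<Integer> left, List<Integer> right) {
        Map<Integer, List<Integer>> built = new HashMap<>();
        for (int r : right) {
            built.computeIfAbsent(r, k -> new ArrayList<>()).add(r);
        }
        List<int[]> out = new ArrayList<>();
        for (int l : left) {
            List<Integer> matches = built.getOrDefault(l, Collections.<Integer>emptyList());
            for (int r : matches) {
                out.add(new int[] {l, r});
            }
        }
        return out;
    }

    public static void main(String[] args) {
        List<Integer> left = Arrays.asList(1, 2);
        List<Integer> right = Arrays.asList(2, 3, 4);
        System.out.println("cross join rows: " + crossJoin(left, right).size());
        System.out.println("equi-join rows:  " + hashEquiJoin(left, right).size());
    }
}
```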



##########
fe/fe-core/src/main/java/org/apache/doris/nereids/analyzer/UnboundFunction.java:
##########
@@ -52,6 +54,14 @@ public List<Expression> getArguments() {
         return children();
     }
 
+    @Override
+    public String sql() throws UnboundException {

Review Comment:
   ```suggestion
       public String toSql() throws UnboundException {
   ```



##########
fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java:
##########
@@ -112,62 +131,92 @@ public PlanFragment visit(Plan plan, PlanTranslatorContext context) {
 
     /**
      * Translate Agg.
+     * todo: support DISTINCT
      */
     @Override
-    public PlanFragment visitPhysicalAggregation(
-            PhysicalUnaryPlan<PhysicalAggregation, Plan> agg, PlanTranslatorContext context) {
-
-        PlanFragment inputPlanFragment = visit(agg.child(0), context);
-
-        AggregationNode aggregationNode;
-        List<Slot> slotList = new ArrayList<>();
-        PhysicalAggregation physicalAggregation = agg.getOperator();
-        AggregateInfo.AggPhase phase = physicalAggregation.getAggPhase().toExec();
-
-        List<Expression> groupByExpressionList = physicalAggregation.getGroupByExprList();
+    public PlanFragment visitPhysicalAggregate(
+            PhysicalUnaryPlan<PhysicalAggregate, Plan> aggregate, PlanTranslatorContext context) {
+        PlanFragment inputPlanFragment = visit(aggregate.child(0), context);
+        PhysicalAggregate physicalAggregate = aggregate.getOperator();
+
+        // TODO: stale planner generate aggregate tuple in a special way. tuple include 2 parts:
+        //    1. group by expressions: removing duplicate expressions add to tuple
+        //    2. agg functions: only removing duplicate agg functions in output expression should appear in tuple.
+        //       e.g. select sum(v1) + 1, sum(v1) + 2 from t1 should only generate one sum(v1) in tuple
+        //    We need:
+        //    1. add a project after agg, if agg function is not the top output expression.
+        //    2. introduce canonicalized, semanticEquals and deterministic in Expression
+        //       for removing duplicate.
+        List<Expression> groupByExpressionList = physicalAggregate.getGroupByExprList();
+        List<NamedExpression> outputExpressionList = physicalAggregate.getOutputExpressionList();
+
+        // 1. generate slot reference for each group expression
+        List<SlotReference> groupSlotList = Lists.newArrayList();
+        for (Expression e : groupByExpressionList) {
+            if (e instanceof SlotReference && outputExpressionList.stream().anyMatch(o -> o.anyMatch(e::equals))) {
+                groupSlotList.add((SlotReference) e);
+            } else {
+                groupSlotList.add(new SlotReference(e.sql(), e.getDataType(), e.nullable(), Collections.emptyList()));
+            }
+        }
         ArrayList<Expr> execGroupingExpressions = groupByExpressionList.stream()
-                // Since output of plan doesn't contain the slots of groupBy, which is actually needed by
-                // the BE execution, so we have to collect them and add to the slotList to generate corresponding
-                // TupleDesc.
-                .peek(x -> slotList.addAll(x.collect(SlotReference.class::isInstance)))
                 .map(e -> ExpressionTranslator.translate(e, context)).collect(Collectors.toCollection(ArrayList::new));
-        slotList.addAll(agg.getOutput());
-        TupleDescriptor outputTupleDesc = generateTupleDesc(slotList, context, null);
-
-        List<NamedExpression> outputExpressionList = physicalAggregation.getOutputExpressionList();
-        ArrayList<FunctionCallExpr> execAggExpressions = outputExpressionList.stream()
-                .map(e -> e.<List<AggregateFunction>>collect(AggregateFunction.class::isInstance))
+        // 2. collect agg functions and generate agg function to slot reference map

Review Comment:
   What if the agg function does not appear in output but in having?
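
A concrete case of the question above, sketched with plain Java collections (hypothetical names, not planner code): in `SELECT k FROM t GROUP BY k HAVING SUM(v) > 10`, `SUM(v)` is absent from the output list, yet the aggregate still has to compute it so the filter can run. Only afterwards can it be projected away.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;

// Hypothetical illustration only; not Doris/Nereids code.
final class HavingSketch {
    // Each row is a {k, v} pair. Returns the keys whose SUM(v) exceeds the threshold.
    static List<Integer> groupKeysHavingSumGreaterThan(List<int[]> rows, long threshold) {
        // Aggregate step: SUM(v) must be materialized even though it is not selected.
        Map<Integer, Long> sumByKey = new TreeMap<>();
        for (int[] row : rows) {
            sumByKey.merge(row[0], (long) row[1], Long::sum);
        }
        // Filter (HAVING) and project step: keep only k, dropping the sum.
        List<Integer> out = new ArrayList<>();
        for (Map.Entry<Integer, Long> e : sumByKey.entrySet()) {
            if (e.getValue() > threshold) {
                out.add(e.getKey());
            }
        }
        return out;
    }

    public static void main(String[] args) {
        List<int[]> rows = Arrays.asList(
                new int[] {1, 5}, new int[] {1, 7}, new int[] {2, 3}, new int[] {3, 11});
        System.out.println(groupKeysHavingSumGreaterThan(rows, 10));
    }
}
```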



##########
fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/ExpressionTranslator.java:
##########
@@ -122,6 +123,13 @@ public Expr visitLessThanEqual(LessThanEqual lessThanEqual, PlanTranslatorContex
                 lessThanEqual.child(1).accept(this, context));
     }
 
+    @Override
+    public Expr visitNullSafeEqual(NullSafeEqual nullSafeEqual, PlanTranslatorContext context) {
+        return new BinaryPredicate(Operator.EQ_FOR_NULL,

Review Comment:
   Use `NullSafeEqual` instead
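
For reference, null-safe equality (SQL `<=>`, which as far as I know is what `EQ_FOR_NULL` denotes) differs from ordinary `=` only when a NULL is involved: ordinary equality yields NULL, while null-safe equality always yields TRUE or FALSE. A minimal sketch using Java `null` to stand in for SQL NULL; the class and method names are hypothetical.

```java
import java.util.Objects;

// Hypothetical illustration only; not Doris/Nereids code.
final class NullSafeEqualSketch {
    // Ordinary SQL '=': the result is NULL (unknown) if either side is NULL.
    static Boolean sqlEquals(Integer a, Integer b) {
        if (a == null || b == null) {
            return null;
        }
        return a.equals(b);
    }

    // Null-safe '<=>': NULL <=> NULL is TRUE, NULL <=> 1 is FALSE, never NULL.
    static boolean nullSafeEquals(Integer a, Integer b) {
        return Objects.equals(a, b);
    }

    public static void main(String[] args) {
        System.out.println("NULL =   NULL : " + sqlEquals(null, null));
        System.out.println("NULL <=> NULL : " + nullSafeEquals(null, null));
        System.out.println("NULL <=> 1    : " + nullSafeEquals(null, 1));
    }
}
```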



##########
fe/fe-core/src/main/java/org/apache/doris/nereids/operators/plans/logical/LogicalAggregate.java:
##########
@@ -127,14 +129,19 @@ public boolean equals(Object o) {
             return false;
         }
         LogicalAggregate that = (LogicalAggregate) o;
-        return Objects.equals(groupByExprList, that.groupByExprList)
+        return Objects.equals(groupByExpressionList, that.groupByExpressionList)
                 && Objects.equals(outputExpressionList, that.outputExpressionList)
                 && Objects.equals(partitionExprList, that.partitionExprList)
                 && aggPhase == that.aggPhase;
     }
 
     @Override
     public int hashCode() {
-        return Objects.hash(groupByExprList, outputExpressionList, partitionExprList, aggPhase);
+        return Objects.hash(groupByExpressionList, outputExpressionList, partitionExprList, aggPhase);
+    }
+
+    public LogicalAggregate withGroupByAndOutput(List<Expression> groupByExprList,

Review Comment:
   What's the difference between the `withGroupByAndOutput` function and the `LogicalAggregate` constructor on line 57?
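
One common reason for a `with...` method alongside a constructor, sketched here with a hypothetical immutable class (the real `LogicalAggregate` may well differ): the method copies every field the caller did not mention, so a rewrite rule can swap two lists without restating the rest.

```java
import java.util.Arrays;
import java.util.List;

// Hypothetical stand-in for an immutable plan operator; not the Doris class.
final class AggregateSketch {
    final List<String> groupBy;
    final List<String> output;
    final List<String> partition;
    final String phase;

    // Full constructor: the caller must supply every field.
    AggregateSketch(List<String> groupBy, List<String> output, List<String> partition, String phase) {
        this.groupBy = groupBy;
        this.output = output;
        this.partition = partition;
        this.phase = phase;
    }

    // "Wither": replaces two fields and carries partition and phase over unchanged.
    AggregateSketch withGroupByAndOutput(List<String> newGroupBy, List<String> newOutput) {
        return new AggregateSketch(newGroupBy, newOutput, partition, phase);
    }

    public static void main(String[] args) {
        AggregateSketch original = new AggregateSketch(
                Arrays.asList("age"), Arrays.asList("SUM(id) AS sum"), Arrays.asList("age"), "GLOBAL");
        AggregateSketch rewritten = original.withGroupByAndOutput(
                Arrays.asList("a"), Arrays.asList("SUM(b) AS sum"));
        System.out.println(rewritten.groupBy + " " + rewritten.output + " "
                + rewritten.partition + " " + rewritten.phase);
    }
}
```

If the constructor on line 57 already takes exactly these two lists and defaults the rest, the two may be redundant; the sketch only shows when they would not be.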





[GitHub] [doris] EmmyMiao87 commented on a diff in pull request #10659: [enhancement](nereids) make aggregate works

Posted by GitBox <gi...@apache.org>.
EmmyMiao87 commented on code in PR #10659:
URL: https://github.com/apache/doris/pull/10659#discussion_r916493625


##########
fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AggregateDisassemble.java:
##########
@@ -17,144 +17,132 @@
 
 package org.apache.doris.nereids.rules.rewrite;
 
-import org.apache.doris.analysis.FunctionName;
-import org.apache.doris.catalog.Catalog;
-import org.apache.doris.catalog.Function;
-import org.apache.doris.catalog.Function.CompareMode;
-import org.apache.doris.catalog.Type;
-import org.apache.doris.nereids.operators.Operator;
 import org.apache.doris.nereids.operators.plans.AggPhase;
 import org.apache.doris.nereids.operators.plans.logical.LogicalAggregate;
 import org.apache.doris.nereids.rules.Rule;
 import org.apache.doris.nereids.rules.RuleType;
+import org.apache.doris.nereids.trees.expressions.Alias;
 import org.apache.doris.nereids.trees.expressions.Expression;
 import org.apache.doris.nereids.trees.expressions.NamedExpression;
-import org.apache.doris.nereids.trees.expressions.Slot;
+import org.apache.doris.nereids.trees.expressions.SlotReference;
 import org.apache.doris.nereids.trees.expressions.functions.AggregateFunction;
+import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter;
+import org.apache.doris.nereids.trees.plans.GroupPlan;
 import org.apache.doris.nereids.trees.plans.Plan;
-import org.apache.doris.nereids.types.DataType;
+import org.apache.doris.nereids.trees.plans.logical.LogicalUnaryPlan;
 
-import com.clearspring.analytics.util.Lists;
-import com.google.common.base.Preconditions;
+import com.google.common.collect.Lists;
+import com.google.common.collect.Maps;
 
-import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.stream.Collectors;
 
 /**
- * TODO: if instance count is 1, shouldn't disassemble the agg operator
  * Used to generate the merge agg node for distributed execution.
- * Do this in following steps:
- *  1. clone output expr list, find all agg function
- *  2. set found agg function intermediaType
- *  3. create new child plan rooted at new local agg
- *  4. update the slot referenced by expr of merge agg
- *  5. create plan rooted at merge agg, return it.
+ * If we have a query: SELECT SUM(v) + 1 FROM t GROUP BY k + 1
+ * the initial plan is:
+ *   Aggregate(phase: [GLOBAL], outputExpr: [SUM(v1 * v2) + 1], groupByExpr: [k + 1])
+ *   +-- childPlan
+ * we should rewrite to:
+ *   Aggregate(phase: [GLOBAL], outputExpr: [SUM(a) + 1], groupByExpr: [b])
+ *   +-- Aggregate(phase: [LOCAL], outputExpr: [SUM(v1 * v2) as a, (k + 1) as b], groupByExpr: [k + 1])
+ *       +-- childPlan
+ *
+ * TODO:
+ *     1. use different class represent different phase aggregate
+ *     2. if instance count is 1, shouldn't disassemble the agg operator
+ *     3. we need another rule to removing duplicated expressions in group by expression list
  */
 public class AggregateDisassemble extends OneRewriteRuleFactory {
 
     @Override
     public Rule<Plan> build() {
         return logicalAggregate().when(p -> {
-            LogicalAggregate logicalAggregation = p.getOperator();
-            return !logicalAggregation.isDisassembled();
+            LogicalAggregate logicalAggregate = p.getOperator();
+            return !logicalAggregate.isDisassembled();
         }).thenApply(ctx -> {
-            Plan plan = ctx.root;
-            Operator operator = plan.getOperator();
-            LogicalAggregate agg = (LogicalAggregate) operator;
-            List<NamedExpression> outputExpressionList = agg.getOutputExpressionList();
-            List<NamedExpression> intermediateAggExpressionList = Lists.newArrayList();
-            // TODO: shouldn't extract agg function from this field.
-            for (NamedExpression namedExpression : outputExpressionList) {
-                namedExpression = (NamedExpression) namedExpression.clone();
-                List<AggregateFunction> functionCallList =
-                        namedExpression.collect(org.apache.doris.catalog.AggregateFunction.class::isInstance);
-                // TODO: we will have another mechanism to get corresponding stale agg func.
-                for (AggregateFunction functionCall : functionCallList) {
-                    org.apache.doris.catalog.AggregateFunction staleAggFunc = findAggFunc(functionCall);
-                    Type staleIntermediateType = staleAggFunc.getIntermediateType();
-                    Type staleRetType = staleAggFunc.getReturnType();
-                    if (staleIntermediateType != null && !staleIntermediateType.equals(staleRetType)) {
-                        functionCall.setIntermediate(DataType.convertFromCatalogDataType(staleIntermediateType));
+            LogicalUnaryPlan<LogicalAggregate, GroupPlan> plan = ctx.root;
+            LogicalAggregate aggregate = plan.getOperator();
+            List<NamedExpression> originOutputExprs = aggregate.getOutputExpressionList();
+            List<Expression> originGroupByExprs = aggregate.getGroupByExpressionList();
+
+            Map<AggregateFunction, NamedExpression> originAggregateFunctionWithAlias = Maps.newHashMap();
+            for (NamedExpression originOutputExpr : originOutputExprs) {

Review Comment:
   Keep the TODO here.
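
For readers following the class comment in this hunk, here is a minimal sketch of what the disassembled plan computes, using plain Java collections rather than Nereids classes. The class name, the `{k, v}` row layout and the two-instance split are assumptions made for illustration only. It shows a LOCAL phase producing `SUM(v) as a, (k + 1) as b` per instance, followed by a GLOBAL phase that computes `SUM(a) + 1` grouped by `b`.

```java
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;

/** Illustration only: SELECT SUM(v) + 1 FROM t GROUP BY k + 1, split into LOCAL and GLOBAL phases. */
final class TwoPhaseSumSketch {

    /** LOCAL phase on one instance: rows are {k, v}; emits b = k + 1 mapped to a = SUM(v). */
    static Map<Integer, Long> localAggregate(List<int[]> rows) {
        Map<Integer, Long> partial = new TreeMap<>();
        for (int[] row : rows) {
            partial.merge(row[0] + 1, (long) row[1], Long::sum);
        }
        return partial;
    }

    /** GLOBAL phase: merges the partial sums by b, then applies the outer "+ 1". */
    static Map<Integer, Long> globalAggregate(List<Map<Integer, Long>> partials) {
        Map<Integer, Long> merged = new TreeMap<>();
        for (Map<Integer, Long> partial : partials) {
            partial.forEach((b, a) -> merged.merge(b, a, Long::sum));
        }
        merged.replaceAll((b, sumOfA) -> sumOfA + 1);
        return merged;
    }

    public static void main(String[] args) {
        // Two hypothetical instances, each holding part of table t.
        List<int[]> instance1 = Arrays.asList(new int[] {1, 10}, new int[] {2, 5});
        List<int[]> instance2 = Arrays.asList(new int[] {1, 7}, new int[] {2, 3}, new int[] {3, 4});
        Map<Integer, Long> result = globalAggregate(
                Arrays.asList(localAggregate(instance1), localAggregate(instance2)));
        System.out.println(result); // prints {2=18, 3=9, 4=5}
    }
}
```

Note that the outer `+ 1` is applied only after the merge; applying it in the LOCAL phase would add it once per instance.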




[GitHub] [doris] morrySnow commented on a diff in pull request #10659: [enhancement](nereids) make aggregate works

Posted by GitBox <gi...@apache.org>.
morrySnow commented on code in PR #10659:
URL: https://github.com/apache/doris/pull/10659#discussion_r915563444


##########
fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AggregateDisassemble.java:
##########
@@ -64,49 +65,69 @@ public Rule<Plan> build() {
             Operator operator = plan.getOperator();
             LogicalAggregate agg = (LogicalAggregate) operator;
             List<NamedExpression> outputExpressionList = agg.getOutputExpressionList();
-            List<NamedExpression> intermediateAggExpressionList = Lists.newArrayList();
-            // TODO: shouldn't extract agg function from this field.
-            for (NamedExpression namedExpression : outputExpressionList) {
-                namedExpression = (NamedExpression) namedExpression.clone();
-                List<AggregateFunction> functionCallList =
-                        namedExpression.collect(org.apache.doris.catalog.AggregateFunction.class::isInstance);
-                // TODO: we will have another mechanism to get corresponding stale agg func.
-                for (AggregateFunction functionCall : functionCallList) {
-                    org.apache.doris.catalog.AggregateFunction staleAggFunc = findAggFunc(functionCall);
-                    Type staleIntermediateType = staleAggFunc.getIntermediateType();
-                    Type staleRetType = staleAggFunc.getReturnType();
-                    if (staleIntermediateType != null && !staleIntermediateType.equals(staleRetType)) {
-                        functionCall.setIntermediate(DataType.convertFromCatalogDataType(staleIntermediateType));
+            List<Expression> groupByExpressionList = agg.getGroupByExpressionList();
+
+            Map<AggregateFunction, NamedExpression> aggregateFunctionAliasMap = Maps.newHashMap();
+            for (NamedExpression outputExpression : outputExpressionList) {
+                outputExpression.foreach(e -> {
+                    if (e instanceof AggregateFunction) {
+                        AggregateFunction a = (AggregateFunction) e;
+                        aggregateFunctionAliasMap.put(a, new Alias<>(a, a.sql()));
+                    }
+                });
+            }
+
+            List<Expression> updateGroupByExpressionList = groupByExpressionList;
+            List<NamedExpression> updateGroupByAliasList = updateGroupByExpressionList.stream()
+                    .map(g -> new Alias<>(g, g.sql()))
+                    .collect(Collectors.toList());
+
+            List<NamedExpression> updateOutputExpressionList = Lists.newArrayList();
+            updateOutputExpressionList.addAll(updateGroupByAliasList);
+            updateOutputExpressionList.addAll(aggregateFunctionAliasMap.values());
+
+            List<Expression> mergeGroupByExpressionList = updateGroupByAliasList.stream()
+                    .map(NamedExpression::toSlot).collect(Collectors.toList());
+
+            List<NamedExpression> mergeOutputExpressionList = Lists.newArrayList();
+            for (NamedExpression o : outputExpressionList) {
+                if (o.contains(AggregateFunction.class::isInstance)) {
+                    mergeOutputExpressionList.add((NamedExpression) new AggregateFunctionParamsRewriter()
+                            .visit(o, aggregateFunctionAliasMap));
+                } else {
+                    for (int i = 0; i < updateGroupByAliasList.size(); i++) {
+                        // TODO: we need to do sub tree match and replace. but we do not have semanticEquals now.
+                        //    e.g. a + 1 + 2 in output expression should be replaced by
+                        //    (slot reference to update phase out (a + 1)) + 2, if we do group by a + 1
+                        //   currently, we could only handle output expression same with group by expression
+                        if (o instanceof SlotReference) {
+                            // a in output expression will be SLotReference
+                            if (o.equals(updateGroupByExpressionList.get(i))) {
+                                mergeOutputExpressionList.add(updateGroupByAliasList.get(i).toSlot());
+                                break;
+                            }
+                        } else if (o instanceof Alias) {
+                            // a + 1 in output expression will be Alias
+                            if (o.child(0).equals(updateGroupByExpressionList.get(i))) {
+                                mergeOutputExpressionList.add(updateGroupByAliasList.get(i).toSlot());
+                                break;
+                            }
+                        }
                     }
                 }
-                intermediateAggExpressionList.add(namedExpression);
             }
+
             LogicalAggregate localAgg = new LogicalAggregate(
-                    agg.getGroupByExprList().stream().map(Expression::clone).collect(Collectors.toList()),
-                    intermediateAggExpressionList,
+                    updateGroupByExpressionList,
+                    updateOutputExpressionList,

Review Comment:
   I used the names from the stale planner. IMO, local and global are better than update and merge, but as discussed before, we want to reuse names from the stale planner as much as possible, so I kept these names.




[GitHub] [doris] github-actions[bot] commented on pull request #10659: [enhancement](nereids) make SSB works

Posted by GitBox <gi...@apache.org>.
github-actions[bot] commented on PR #10659:
URL: https://github.com/apache/doris/pull/10659#issuecomment-1179929007

   PR approved by at least one committer and no changes requested.



[GitHub] [doris] EmmyMiao87 merged pull request #10659: [enhancement](nereids) make SSB works

Posted by GitBox <gi...@apache.org>.
EmmyMiao87 merged PR #10659:
URL: https://github.com/apache/doris/pull/10659



[GitHub] [doris] 924060929 commented on a diff in pull request #10659: [enhancement](nereids) make aggregate works

Posted by GitBox <gi...@apache.org>.
924060929 commented on code in PR #10659:
URL: https://github.com/apache/doris/pull/10659#discussion_r915456668


##########
fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java:
##########
@@ -114,60 +134,96 @@ public PlanFragment visit(Plan plan, PlanTranslatorContext context) {
      * Translate Agg.
      */
     @Override
-    public PlanFragment visitPhysicalAggregation(
-            PhysicalUnaryPlan<PhysicalAggregation, Plan> agg, PlanTranslatorContext context) {
-
+    public PlanFragment visitPhysicalAggregate(
+            PhysicalUnaryPlan<PhysicalAggregate, Plan> agg, PlanTranslatorContext context) {
         PlanFragment inputPlanFragment = visit(agg.child(0), context);
-
-        AggregationNode aggregationNode;
-        List<Slot> slotList = new ArrayList<>();
-        PhysicalAggregation physicalAggregation = agg.getOperator();
-        AggregateInfo.AggPhase phase = physicalAggregation.getAggPhase().toExec();
-
-        List<Expression> groupByExpressionList = physicalAggregation.getGroupByExprList();
+        PhysicalAggregate physicalAggregate = agg.getOperator();
+
+        // TODO: stale planner generate aggregate tuple in a special way. tuple include 2 parts:
+        //    1. group by expressions: removing duplicate expressions add to tuple
+        //    2. agg functions: only removing duplicate agg functions in output expression should appear in tuple.
+        //       e.g. select sum(v1) + 1, sum(v1) + 2 from t1 should only generate one sum(v1) in tuple
+        //    We need:
+        //    1. add a project after agg, if output expressions include agg function as a expression tree leaf.

Review Comment:
   ```suggestion
           //    1. add a project after agg, if agg function is not the top output expression.
   ```
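
To make the TODO concrete, here is a rough sketch of "one tuple slot per distinct aggregate call, with a project on top". It is not the translator's code: the `Output` class, the use of the SQL string as the deduplication key and the hard-coded `sum(v1) = 40` are all assumptions for illustration. The TODO itself says real deduplication needs `semanticEquals`, which string equality only approximates.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

/** Illustration only: one tuple slot per distinct aggregate call, plus a projection on top. */
final class AggTupleDedupSketch {

    /** Hypothetical output expression of the form aggCall + constant, e.g. sum(v1) + 1. */
    static final class Output {
        final String aggCall;
        final long constant;

        Output(String aggCall, long constant) {
            this.aggCall = aggCall;
            this.constant = constant;
        }
    }

    /** Assigns slot indexes in first-seen order; a duplicate call reuses the existing slot. */
    static Map<String, Integer> assignSlots(List<Output> outputs) {
        Map<String, Integer> slots = new LinkedHashMap<>();
        for (Output o : outputs) {
            slots.putIfAbsent(o.aggCall, slots.size());
        }
        return slots;
    }

    /** The "project after agg": evaluates each output from the deduplicated aggregate tuple. */
    static List<Long> project(List<Output> outputs, Map<String, Integer> slots, long[] aggTuple) {
        List<Long> row = new ArrayList<>();
        for (Output o : outputs) {
            row.add(aggTuple[slots.get(o.aggCall)] + o.constant);
        }
        return row;
    }

    public static void main(String[] args) {
        // select sum(v1) + 1, sum(v1) + 2 from t1
        List<Output> outputs = Arrays.asList(new Output("sum(v1)", 1), new Output("sum(v1)", 2));
        Map<String, Integer> slots = assignSlots(outputs);
        long[] aggTuple = {40L}; // pretend the aggregate node computed sum(v1) = 40
        System.out.println(slots + " -> " + project(outputs, slots, aggTuple));
    }
}
```

With this split the aggregate node only ever sees bare aggregate calls, and any arithmetic around them moves into the projection.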



##########
fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java:
##########
@@ -100,8 +103,25 @@ private static Expression swapEqualToForChildrenOrder(EqualTo<?, ?> equalTo, Lis
         }
     }
 
-    public void translatePlan(PhysicalPlan physicalPlan, PlanTranslatorContext context) {
-        visit(physicalPlan, context);
+    /**
+     * Translate Nereids Physical Plan tree to Stale Planner PlanFragment tree.
+     *
+     * @param physicalPlan Nereids Physical Plan tree
+     * @param context context to help translate
+     * @return Stale Planner PlanFragment tree
+     */
+    public PlanFragment translatePlan(PhysicalPlan physicalPlan, PlanTranslatorContext context) {
+        PlanFragment rootFragment = visit(physicalPlan, context);
+        if (rootFragment.isPartitioned() && rootFragment.getPlanRoot().getNumInstances() > 1) {
+            rootFragment = createMergeFragment(rootFragment, context);
+            context.addPlanFragment(rootFragment);

Review Comment:
   Rename `createMergeFragment()` to `exchangeToMergeFragment()`, and move the `context.addPlanFragment(rootFragment)` call into `exchangeToMergeFragment(rootFragment, context)`.
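
A rough sketch of the shape this suggestion is aiming for. `Fragment` and `Context` below are hypothetical stand-ins for `PlanFragment` and `PlanTranslatorContext`, and the condition is simplified (the real code also checks `getNumInstances() > 1`). The only point is that the helper registers the fragment it creates, so the caller no longer needs its own `addPlanFragment` call.

```java
import java.util.ArrayList;
import java.util.List;

/** Illustration only: the helper registers the fragment it creates, so callers cannot forget to. */
final class MergeFragmentRefactorSketch {

    /** Hypothetical stand-in for PlanFragment. */
    static final class Fragment {
        final String name;
        final boolean partitioned;

        Fragment(String name, boolean partitioned) {
            this.name = name;
            this.partitioned = partitioned;
        }
    }

    /** Hypothetical stand-in for PlanTranslatorContext. */
    static final class Context {
        final List<Fragment> fragments = new ArrayList<>();

        void addPlanFragment(Fragment fragment) {
            fragments.add(fragment);
        }
    }

    /** Wraps the input in an unpartitioned merge fragment and registers it in one place. */
    static Fragment exchangeToMergeFragment(Fragment input, Context context) {
        Fragment merge = new Fragment("merge(" + input.name + ")", false);
        context.addPlanFragment(merge);
        return merge;
    }

    /** Caller after the refactor: no separate addPlanFragment call is needed here. */
    static Fragment translatePlan(Fragment root, Context context) {
        return root.partitioned ? exchangeToMergeFragment(root, context) : root;
    }

    public static void main(String[] args) {
        Context context = new Context();
        Fragment result = translatePlan(new Fragment("agg", true), context);
        System.out.println(result.name + ", registered fragments: " + context.fragments.size());
    }
}
```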
   
   



##########
fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java:
##########
@@ -114,60 +134,96 @@ public PlanFragment visit(Plan plan, PlanTranslatorContext context) {
      * Translate Agg.
      */
     @Override
-    public PlanFragment visitPhysicalAggregation(
-            PhysicalUnaryPlan<PhysicalAggregation, Plan> agg, PlanTranslatorContext context) {
-
+    public PlanFragment visitPhysicalAggregate(
+            PhysicalUnaryPlan<PhysicalAggregate, Plan> agg, PlanTranslatorContext context) {
         PlanFragment inputPlanFragment = visit(agg.child(0), context);
-
-        AggregationNode aggregationNode;
-        List<Slot> slotList = new ArrayList<>();
-        PhysicalAggregation physicalAggregation = agg.getOperator();
-        AggregateInfo.AggPhase phase = physicalAggregation.getAggPhase().toExec();
-
-        List<Expression> groupByExpressionList = physicalAggregation.getGroupByExprList();
+        PhysicalAggregate physicalAggregate = agg.getOperator();
+
+        // TODO: stale planner generate aggregate tuple in a special way. tuple include 2 parts:
+        //    1. group by expressions: removing duplicate expressions add to tuple
+        //    2. agg functions: only removing duplicate agg functions in output expression should appear in tuple.
+        //       e.g. select sum(v1) + 1, sum(v1) + 2 from t1 should only generate one sum(v1) in tuple
+        //    We need:
+        //    1. add a project after agg, if output expressions include agg function as a expression tree leaf.
+        //    2. introduce canonicalized, semanticEquals and deterministic in Expression
+        //       for removing duplicate.
+        List<Expression> groupByExpressionList = physicalAggregate.getGroupByExprList();
+        List<NamedExpression> outputExpressionList = physicalAggregate.getOutputExpressionList();
+
+        // 1. generate slot reference for each group expression
+        List<SlotReference> groupSlotList = Lists.newArrayList();
+        for (Expression e : groupByExpressionList) {
+            if (e instanceof SlotReference && outputExpressionList.stream().anyMatch(o -> o.contains(e::equals))) {
+                groupSlotList.add((SlotReference) e);
+            } else {
+                groupSlotList.add(new SlotReference(e.sql(), e.getDataType(), e.nullable(), Collections.emptyList()));
+            }
+        }
         ArrayList<Expr> execGroupingExpressions = groupByExpressionList.stream()
-                // Since output of plan doesn't contain the slots of groupBy, which is actually needed by
-                // the BE execution, so we have to collect them and add to the slotList to generate corresponding
-                // TupleDesc.
-                .peek(x -> slotList.addAll(x.collect(SlotReference.class::isInstance)))
                 .map(e -> ExpressionTranslator.translate(e, context)).collect(Collectors.toCollection(ArrayList::new));
-        slotList.addAll(agg.getOutput());
-        TupleDescriptor outputTupleDesc = generateTupleDesc(slotList, context, null);
-
-        List<NamedExpression> outputExpressionList = physicalAggregation.getOutputExpressionList();
-        ArrayList<FunctionCallExpr> execAggExpressions = outputExpressionList.stream()
-                .map(e -> e.<List<AggregateFunction>>collect(AggregateFunction.class::isInstance))
+        // 2. collect agg functions and generate agg function to slot reference map
+        List<Slot> aggFunctionOutput = Lists.newArrayList();
+        List<AggregateFunction> aggregateFunctionList = outputExpressionList.stream()
+                .filter(o -> o.contains(AggregateFunction.class::isInstance))
+                .peek(o -> aggFunctionOutput.add(o.toSlot()))
+                .map(o -> (List<AggregateFunction>) o.collect(AggregateFunction.class::isInstance))
                 .flatMap(List::stream)
+                .collect(Collectors.toList());
+        ArrayList<FunctionCallExpr> execAggExpressions = aggregateFunctionList.stream()
                 .map(x -> (FunctionCallExpr) ExpressionTranslator.translate(x, context))
                 .collect(Collectors.toCollection(ArrayList::new));
 
-        List<Expression> partitionExpressionList = physicalAggregation.getPartitionExprList();
+        // 3. generate output tuple
+        // TODO: currently, we only support sum(a), if we want to support sum(a) + 1, we need to
+        //  split merge agg to project(agg) and generate tuple like what first phase agg do.
+        List<Slot> slotList = Lists.newArrayList();
+        TupleDescriptor outputTupleDesc;
+        if (agg.getOperator().getAggPhase() == AggPhase.FIRST_MERGE) {
+            slotList.addAll(groupSlotList);
+            slotList.addAll(aggFunctionOutput);
+            outputTupleDesc = generateTupleDesc(slotList, null, context);
+        } else {
+            outputTupleDesc = generateTupleDesc(agg.getOutput(), null, context);
+        }
+
+        // process partition list
+        List<Expression> partitionExpressionList = physicalAggregate.getPartitionExprList();
         List<Expr> execPartitionExpressions = partitionExpressionList.stream()
-                .map(e -> (FunctionCallExpr) ExpressionTranslator.translate(e, context)).collect(Collectors.toList());
+                .map(e -> ExpressionTranslator.translate(e, context)).collect(Collectors.toList());
+        DataPartition mergePartition = DataPartition.UNPARTITIONED;
+        if (CollectionUtils.isNotEmpty(execPartitionExpressions)) {

Review Comment:
   Storing execPartitionExpressions in the merge LogicalAggregate does not seem reasonable, because the sender (input fragment) executes the partition expressions, not the receiver (merge fragment).
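
To illustrate the point about which side evaluates the partition expression, here is a small sketch in plain Java. Nothing in it is Doris code: the class, the use of the row value itself as the partition key and the modulo routing are assumptions chosen to keep the example short. The sender evaluates the key and picks a channel; the receiver only consumes what arrives.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/** Illustration only: the sending side evaluates the partition expression to route each row. */
final class SenderSidePartitionSketch {

    /** Sender: evaluates the partition key (here simply the row value) and picks a receiver. */
    static List<List<Integer>> send(List<Integer> rows, int receiverCount) {
        List<List<Integer>> channels = new ArrayList<>();
        for (int i = 0; i < receiverCount; i++) {
            channels.add(new ArrayList<>());
        }
        for (Integer row : rows) {
            // floorMod keeps the channel index non-negative for negative hash codes.
            int target = Math.floorMod(Integer.hashCode(row), receiverCount);
            channels.get(target).add(row);
        }
        return channels;
    }

    /** Receiver: never looks at the partition expression, it just aggregates its channel. */
    static long receiveAndSum(List<Integer> channel) {
        long sum = 0;
        for (int value : channel) {
            sum += value;
        }
        return sum;
    }

    public static void main(String[] args) {
        List<List<Integer>> channels = send(Arrays.asList(1, 2, 3, 4, 5, -1), 2);
        for (List<Integer> channel : channels) {
            System.out.println(channel + " sum=" + receiveAndSum(channel));
        }
    }
}
```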
   



##########
fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java:
##########
@@ -364,25 +416,43 @@ private TupleDescriptor generateTupleDesc(List<Slot> slotList, PlanTranslatorCon
     }
 
     private PlanFragment createParentFragment(PlanFragment childFragment, DataPartition parentPartition,

Review Comment:
   ```suggestion
       private PlanFragment exchangeToMergeFragment(PlanFragment childFragment, DataPartition parentPartition,
   ```



##########
fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java:
##########
@@ -364,25 +416,43 @@ private TupleDescriptor generateTupleDesc(List<Slot> slotList, PlanTranslatorCon
     }
 
     private PlanFragment createParentFragment(PlanFragment childFragment, DataPartition parentPartition,
-            PlanTranslatorContext ctx) {
-        ExchangeNode exchangeNode = new ExchangeNode(ctx.nextNodeId(), childFragment.getPlanRoot(), false);
+            PlanTranslatorContext context) {
+        ExchangeNode exchangeNode = new ExchangeNode(context.nextPlanNodeId(), childFragment.getPlanRoot(), false);
         exchangeNode.setNumInstances(childFragment.getPlanRoot().getNumInstances());
-        PlanFragment parentFragment = new PlanFragment(ctx.nextFragmentId(), exchangeNode, parentPartition);
+        PlanFragment parentFragment = new PlanFragment(context.nextFragmentId(), exchangeNode, parentPartition);
         childFragment.setDestination(exchangeNode);
         childFragment.setOutputPartition(parentPartition);
+        context.addPlanFragment(parentFragment);
         return parentFragment;
     }
 
     private void connectChildFragment(PlanNode node, int childIdx,
             PlanFragment parentFragment, PlanFragment childFragment,
             PlanTranslatorContext context) {
-        ExchangeNode exchangeNode = new ExchangeNode(context.nextNodeId(), childFragment.getPlanRoot(), false);
+        ExchangeNode exchangeNode = new ExchangeNode(context.nextPlanNodeId(), childFragment.getPlanRoot(), false);
         exchangeNode.setNumInstances(childFragment.getPlanRoot().getNumInstances());
         exchangeNode.setFragment(parentFragment);
         node.setChild(childIdx, exchangeNode);
         childFragment.setDestination(exchangeNode);
     }
 
+    /**
+     * Return unpartitioned fragment that merges the input fragment's output via
+     * an ExchangeNode.
+     * Requires that input fragment be partitioned.
+     */
+    private PlanFragment createMergeFragment(PlanFragment inputFragment, PlanTranslatorContext context) {

Review Comment:
   ```suggestion
       private PlanFragment exchangeToMergeFragment(PlanFragment inputFragment, PlanTranslatorContext context) {
   ```



##########
fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java:
##########
@@ -114,60 +134,96 @@ public PlanFragment visit(Plan plan, PlanTranslatorContext context) {
      * Translate Agg.
      */
     @Override
-    public PlanFragment visitPhysicalAggregation(
-            PhysicalUnaryPlan<PhysicalAggregation, Plan> agg, PlanTranslatorContext context) {
-
+    public PlanFragment visitPhysicalAggregate(
+            PhysicalUnaryPlan<PhysicalAggregate, Plan> agg, PlanTranslatorContext context) {
         PlanFragment inputPlanFragment = visit(agg.child(0), context);
-
-        AggregationNode aggregationNode;
-        List<Slot> slotList = new ArrayList<>();
-        PhysicalAggregation physicalAggregation = agg.getOperator();
-        AggregateInfo.AggPhase phase = physicalAggregation.getAggPhase().toExec();
-
-        List<Expression> groupByExpressionList = physicalAggregation.getGroupByExprList();
+        PhysicalAggregate physicalAggregate = agg.getOperator();
+
+        // TODO: stale planner generate aggregate tuple in a special way. tuple include 2 parts:
+        //    1. group by expressions: removing duplicate expressions add to tuple
+        //    2. agg functions: only removing duplicate agg functions in output expression should appear in tuple.
+        //       e.g. select sum(v1) + 1, sum(v1) + 2 from t1 should only generate one sum(v1) in tuple
+        //    We need:
+        //    1. add a project after agg, if output expressions include agg function as a expression tree leaf.
+        //    2. introduce canonicalized, semanticEquals and deterministic in Expression
+        //       for removing duplicate.
+        List<Expression> groupByExpressionList = physicalAggregate.getGroupByExprList();
+        List<NamedExpression> outputExpressionList = physicalAggregate.getOutputExpressionList();
+
+        // 1. generate slot reference for each group expression
+        List<SlotReference> groupSlotList = Lists.newArrayList();
+        for (Expression e : groupByExpressionList) {
+            if (e instanceof SlotReference && outputExpressionList.stream().anyMatch(o -> o.contains(e::equals))) {
+                groupSlotList.add((SlotReference) e);
+            } else {
+                groupSlotList.add(new SlotReference(e.sql(), e.getDataType(), e.nullable(), Collections.emptyList()));
+            }
+        }
         ArrayList<Expr> execGroupingExpressions = groupByExpressionList.stream()
-                // Since output of plan doesn't contain the slots of groupBy, which is actually needed by
-                // the BE execution, so we have to collect them and add to the slotList to generate corresponding
-                // TupleDesc.
-                .peek(x -> slotList.addAll(x.collect(SlotReference.class::isInstance)))
                 .map(e -> ExpressionTranslator.translate(e, context)).collect(Collectors.toCollection(ArrayList::new));
-        slotList.addAll(agg.getOutput());
-        TupleDescriptor outputTupleDesc = generateTupleDesc(slotList, context, null);
-
-        List<NamedExpression> outputExpressionList = physicalAggregation.getOutputExpressionList();
-        ArrayList<FunctionCallExpr> execAggExpressions = outputExpressionList.stream()
-                .map(e -> e.<List<AggregateFunction>>collect(AggregateFunction.class::isInstance))
+        // 2. collect agg functions and generate agg function to slot reference map
+        List<Slot> aggFunctionOutput = Lists.newArrayList();
+        List<AggregateFunction> aggregateFunctionList = outputExpressionList.stream()
+                .filter(o -> o.contains(AggregateFunction.class::isInstance))
+                .peek(o -> aggFunctionOutput.add(o.toSlot()))
+                .map(o -> (List<AggregateFunction>) o.collect(AggregateFunction.class::isInstance))
                 .flatMap(List::stream)
+                .collect(Collectors.toList());
+        ArrayList<FunctionCallExpr> execAggExpressions = aggregateFunctionList.stream()
                 .map(x -> (FunctionCallExpr) ExpressionTranslator.translate(x, context))
                 .collect(Collectors.toCollection(ArrayList::new));
 
-        List<Expression> partitionExpressionList = physicalAggregation.getPartitionExprList();
+        // 3. generate output tuple
+        // TODO: currently, we only support sum(a), if we want to support sum(a) + 1, we need to
+        //  split merge agg to project(agg) and generate tuple like what first phase agg do.
+        List<Slot> slotList = Lists.newArrayList();
+        TupleDescriptor outputTupleDesc;
+        if (agg.getOperator().getAggPhase() == AggPhase.FIRST_MERGE) {
+            slotList.addAll(groupSlotList);
+            slotList.addAll(aggFunctionOutput);
+            outputTupleDesc = generateTupleDesc(slotList, null, context);
+        } else {
+            outputTupleDesc = generateTupleDesc(agg.getOutput(), null, context);
+        }
+
+        // process partition list
+        List<Expression> partitionExpressionList = physicalAggregate.getPartitionExprList();
         List<Expr> execPartitionExpressions = partitionExpressionList.stream()
-                .map(e -> (FunctionCallExpr) ExpressionTranslator.translate(e, context)).collect(Collectors.toList());
+                .map(e -> ExpressionTranslator.translate(e, context)).collect(Collectors.toList());
+        DataPartition mergePartition = DataPartition.UNPARTITIONED;
+        if (CollectionUtils.isNotEmpty(execPartitionExpressions)) {
+            mergePartition = DataPartition.hashPartitioned(execGroupingExpressions);
+        }
+
         // todo: support DISTINCT
+        AggregationNode aggregationNode;
         AggregateInfo aggInfo;
-        switch (phase) {
+        switch (physicalAggregate.getAggPhase()) {
             case FIRST:
                 aggInfo = AggregateInfo.create(execGroupingExpressions, execAggExpressions, outputTupleDesc,
                         outputTupleDesc, AggregateInfo.AggPhase.FIRST);
-                aggregationNode = new AggregationNode(context.nextNodeId(), inputPlanFragment.getPlanRoot(), aggInfo);
+                aggregationNode = new AggregationNode(context.nextPlanNodeId(),
+                        inputPlanFragment.getPlanRoot(), aggInfo);
                 aggregationNode.unsetNeedsFinalize();
-                aggregationNode.setUseStreamingPreagg(physicalAggregation.isUsingStream());
+                aggregationNode.setUseStreamingPreagg(physicalAggregate.isUsingStream());
                 aggregationNode.setIntermediateTuple();
-                if (!partitionExpressionList.isEmpty()) {
-                    inputPlanFragment.setOutputPartition(DataPartition.hashPartitioned(execPartitionExpressions));
-                }
-                break;
+                inputPlanFragment.setPlanRoot(aggregationNode);
+                PlanFragment mergeFragment = createParentFragment(inputPlanFragment, mergePartition, context);

Review Comment:
   ```suggestion
                   PlanFragment mergeFragment = exchangeToMergeFragment(inputPlanFragment, mergePartition, context);
   ```



##########
fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AggregateDisassemble.java:
##########
@@ -115,46 +136,16 @@ public Rule<Plan> build() {
         }).toRule(RuleType.AGGREGATE_DISASSEMBLE);
     }
 
-    private org.apache.doris.catalog.AggregateFunction findAggFunc(AggregateFunction functionCall) {
-        FunctionName functionName = new FunctionName(functionCall.getName());
-        List<Expression> expressionList = functionCall.getArguments();
-        List<Type> staleTypeList = expressionList.stream().map(Expression::getDataType)
-                .map(DataType::toCatalogDataType).collect(Collectors.toList());
-        Function staleFuncDesc = new Function(functionName, staleTypeList,
-                functionCall.getDataType().toCatalogDataType(),
-                // I think an aggregate function will never have a variable length parameters
-                false);
-        Function staleFunc = Catalog.getCurrentCatalog()
-                .getFunction(staleFuncDesc, CompareMode.IS_IDENTICAL);
-        Preconditions.checkArgument(staleFunc instanceof org.apache.doris.catalog.AggregateFunction);
-        return  (org.apache.doris.catalog.AggregateFunction) staleFunc;
-    }
-
-    @SuppressWarnings("unchecked")
-    private <T extends Expression> void replaceSlot(Map<Slot, Slot> staleToNew,
-            List<T> expressionList, Expression root, int index) {
-        if (index != -1) {
-            if (root instanceof Slot) {
-                Slot v = staleToNew.get(root);
-                if (v == null) {
-                    return;
-                }
-                expressionList.set(index, (T) v);
-                return;
-            }
-        }
-        List<Expression> children = root.children();
-        for (int i = 0; i < children.size(); i++) {
-            Expression cur = children.get(i);
-            if (!(cur instanceof Slot)) {
-                replaceSlot(staleToNew, expressionList, cur, -1);
-                continue;
-            }
-            Expression v = staleToNew.get(cur);
-            if (v == null) {
-                continue;
+    private static class AggregateFunctionParamsRewriter
+            extends DefaultExpressionRewriter<Map<AggregateFunction, NamedExpression>> {
+        @Override
+        public Expression visitBoundFunction(BoundFunction boundFunction,

Review Comment:
   You can override `visitAggregateFunction` here instead.



##########
fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AggregateDisassemble.java:
##########
@@ -64,49 +65,69 @@ public Rule<Plan> build() {
             Operator operator = plan.getOperator();
             LogicalAggregate agg = (LogicalAggregate) operator;
             List<NamedExpression> outputExpressionList = agg.getOutputExpressionList();
-            List<NamedExpression> intermediateAggExpressionList = Lists.newArrayList();
-            // TODO: shouldn't extract agg function from this field.
-            for (NamedExpression namedExpression : outputExpressionList) {
-                namedExpression = (NamedExpression) namedExpression.clone();
-                List<AggregateFunction> functionCallList =
-                        namedExpression.collect(org.apache.doris.catalog.AggregateFunction.class::isInstance);
-                // TODO: we will have another mechanism to get corresponding stale agg func.
-                for (AggregateFunction functionCall : functionCallList) {
-                    org.apache.doris.catalog.AggregateFunction staleAggFunc = findAggFunc(functionCall);
-                    Type staleIntermediateType = staleAggFunc.getIntermediateType();
-                    Type staleRetType = staleAggFunc.getReturnType();
-                    if (staleIntermediateType != null && !staleIntermediateType.equals(staleRetType)) {
-                        functionCall.setIntermediate(DataType.convertFromCatalogDataType(staleIntermediateType));
+            List<Expression> groupByExpressionList = agg.getGroupByExpressionList();
+
+            Map<AggregateFunction, NamedExpression> aggregateFunctionAliasMap = Maps.newHashMap();
+            for (NamedExpression outputExpression : outputExpressionList) {
+                outputExpression.foreach(e -> {
+                    if (e instanceof AggregateFunction) {
+                        AggregateFunction a = (AggregateFunction) e;
+                        aggregateFunctionAliasMap.put(a, new Alias<>(a, a.sql()));
+                    }
+                });
+            }
+
+            List<Expression> updateGroupByExpressionList = groupByExpressionList;
+            List<NamedExpression> updateGroupByAliasList = updateGroupByExpressionList.stream()
+                    .map(g -> new Alias<>(g, g.sql()))
+                    .collect(Collectors.toList());
+
+            List<NamedExpression> updateOutputExpressionList = Lists.newArrayList();
+            updateOutputExpressionList.addAll(updateGroupByAliasList);
+            updateOutputExpressionList.addAll(aggregateFunctionAliasMap.values());
+
+            List<Expression> mergeGroupByExpressionList = updateGroupByAliasList.stream()
+                    .map(NamedExpression::toSlot).collect(Collectors.toList());
+
+            List<NamedExpression> mergeOutputExpressionList = Lists.newArrayList();
+            for (NamedExpression o : outputExpressionList) {
+                if (o.contains(AggregateFunction.class::isInstance)) {
+                    mergeOutputExpressionList.add((NamedExpression) new AggregateFunctionParamsRewriter()
+                            .visit(o, aggregateFunctionAliasMap));
+                } else {
+                    for (int i = 0; i < updateGroupByAliasList.size(); i++) {
+                        // TODO: we need to do sub tree match and replace. but we do not have semanticEquals now.
+                        //    e.g. a + 1 + 2 in output expression should be replaced by
+                        //    (slot reference to update phase out (a + 1)) + 2, if we do group by a + 1
+                        //   currently, we could only handle output expression same with group by expression
+                        if (o instanceof SlotReference) {
+                            // a in output expression will be SLotReference
+                            if (o.equals(updateGroupByExpressionList.get(i))) {
+                                mergeOutputExpressionList.add(updateGroupByAliasList.get(i).toSlot());
+                                break;
+                            }
+                        } else if (o instanceof Alias) {
+                            // a + 1 in output expression will be Alias
+                            if (o.child(0).equals(updateGroupByExpressionList.get(i))) {
+                                mergeOutputExpressionList.add(updateGroupByAliasList.get(i).toSlot());
+                                break;
+                            }
+                        }
                     }
                 }
-                intermediateAggExpressionList.add(namedExpression);
             }
+
             LogicalAggregate localAgg = new LogicalAggregate(
-                    agg.getGroupByExprList().stream().map(Expression::clone).collect(Collectors.toList()),
-                    intermediateAggExpressionList,
+                    updateGroupByExpressionList,
+                    updateOutputExpressionList,

Review Comment:
   I think these variable names and the computation logic are confusing. How about this:
   1. localGroupByExprs = originGroupByExprs
   2. localOutputExprs = originOutput.withAlias
   3. globalGroupByWithAlias = originGroupByExprs.withAlias
   4. globalOutputAlias = originOutput.replaceAggregateFunctionArguments
   
   The advantages are:
   1. each variable name states both the phase (local or global) and which member it holds
   2. each assignment contains the simplest possible computation
   





[GitHub] [doris] 924060929 commented on a diff in pull request #10659: [enhancement](nereids) make aggregate works

Posted by GitBox <gi...@apache.org>.
924060929 commented on code in PR #10659:
URL: https://github.com/apache/doris/pull/10659#discussion_r915441815


##########
fe/fe-core/src/test/java/org/apache/doris/nereids/AnalyzeSSBTest.java:
##########
@@ -161,10 +161,7 @@ private void executeRewriteBottomUpJob(PlannerContext plannerContext, RuleFactor
     }
 
     private boolean checkBound(LogicalPlan root) {
-        if (!checkPlanBound(root))  {
-            return false;
-        }
-        return true;
+        return checkPlanBound(root);

Review Comment:
   I think we can remove either the `checkBound` or the `checkPlanBound` function.



##########
fe/fe-core/src/main/java/org/apache/doris/nereids/trees/TreeNode.java:
##########
@@ -91,4 +91,22 @@ default <T> T collect(Predicate<TreeNode<NODE_TYPE>> predicate) {
         return (T) result.build();
     }
 
+    /**
+     * Test whether this tree satisfied predicate.
+     *
+     * @param predicate test condition
+     * @return true if satisfied
+     */
+    default boolean contains(Predicate<TreeNode<NODE_TYPE>> predicate) {

Review Comment:
   The `contains` function duplicates the existing `anyMatch` function.
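
To illustrate the overlap, here is a minimal sketch of what an `anyMatch`-style traversal usually does. The `Node`, `Leaf`, and `Branch` types are hypothetical stand-ins, not the real Nereids `TreeNode`, and the assumption is that `anyMatch` means "this node or any descendant satisfies the predicate", which is also what the javadoc of the new `contains` describes.

```java
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.function.Predicate;

// Hypothetical, simplified tree node; not the real Nereids TreeNode.
interface Node {
    List<Node> children();

    // Returns true if this node or any descendant satisfies the predicate.
    default boolean anyMatch(Predicate<Node> predicate) {
        if (predicate.test(this)) {
            return true;
        }
        for (Node child : children()) {
            if (child.anyMatch(predicate)) {
                return true;
            }
        }
        return false;
    }
}

// Leaf node carrying only a name.
final class Leaf implements Node {
    final String name;

    Leaf(String name) {
        this.name = name;
    }

    @Override
    public List<Node> children() {
        return Collections.emptyList();
    }
}

// Inner node holding an ordered list of children.
final class Branch implements Node {
    private final List<Node> children;

    Branch(Node... children) {
        this.children = Arrays.asList(children);
    }

    @Override
    public List<Node> children() {
        return children;
    }
}
```

If `anyMatch` already behaves like this, a separate `contains` with the same signature adds no new capability.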
   
   



##########
fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Add.java:
##########
@@ -57,6 +57,6 @@ public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
 
 
     public String toString() {
-        return sql();
+        return left().toString() + ' ' + getArithmeticOperator().toString() + ' ' + right().sql();

Review Comment:
   Why does the left side use `toString()` while the right side uses `sql()`?



##########
fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/Sum.java:
##########
@@ -51,7 +51,7 @@ public DataType getDataType() {
 
     @Override
     public boolean nullable() {
-        return false;
+        return child().nullable();

Review Comment:
   Can the sum aggregate function return a null value?
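
For context, a minimal sketch of standard SQL `SUM` semantics, where NULL inputs are skipped and the result is NULL when no non-NULL value was seen. The class and method names are hypothetical and this is not Doris code. It only illustrates why the result can be null when the argument is nullable.

```java
import java.util.List;

// Hypothetical illustration of SQL SUM semantics; not part of Doris.
final class SumNullabilitySketch {
    // NULL inputs are skipped; the result stays null if no non-null input was seen.
    static Long sqlSum(List<Long> values) {
        Long acc = null;
        for (Long v : values) {
            if (v == null) {
                continue;
            }
            acc = (acc == null) ? v : acc + v;
        }
        return acc;
    }
}
```

Under these semantics an all-NULL group yields NULL, so returning `child().nullable()` is at least consistent with nullable arguments. Whether an empty, ungrouped input also needs handling is a separate question that this sketch does not settle for Doris.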





[GitHub] [doris] morrySnow commented on a diff in pull request #10659: [enhancement](nereids) make aggregate works

Posted by GitBox <gi...@apache.org>.
morrySnow commented on code in PR #10659:
URL: https://github.com/apache/doris/pull/10659#discussion_r916506034


##########
fe/fe-core/src/main/java/org/apache/doris/nereids/operators/plans/logical/LogicalAggregate.java:
##########
@@ -127,14 +129,19 @@ public boolean equals(Object o) {
             return false;
         }
         LogicalAggregate that = (LogicalAggregate) o;
-        return Objects.equals(groupByExprList, that.groupByExprList)
+        return Objects.equals(groupByExpressionList, that.groupByExpressionList)
                 && Objects.equals(outputExpressionList, that.outputExpressionList)
                 && Objects.equals(partitionExprList, that.partitionExprList)
                 && aggPhase == that.aggPhase;
     }
 
     @Override
     public int hashCode() {
-        return Objects.hash(groupByExprList, outputExpressionList, partitionExprList, aggPhase);
+        return Objects.hash(groupByExpressionList, outputExpressionList, partitionExprList, aggPhase);
+    }
+
+    public LogicalAggregate withGroupByAndOutput(List<Expression> groupByExprList,

Review Comment:
   The constructor sets `disassembled` and `aggPhase` to their default values. This function preserves their current values.
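
The difference can be sketched as follows. `AggSketch` is a hypothetical, string-based stand-in for `LogicalAggregate`, and the defaults `false` and `"GLOBAL"` are assumptions made for illustration.

```java
import java.util.List;

// Hypothetical, simplified stand-in for LogicalAggregate.
final class AggSketch {
    final List<String> groupBy;
    final List<String> output;
    final boolean disassembled;
    final String aggPhase;

    // Short constructor: resets disassembled and aggPhase to (assumed) defaults.
    AggSketch(List<String> groupBy, List<String> output) {
        this(groupBy, output, false, "GLOBAL");
    }

    // Full constructor: every field is given explicitly.
    AggSketch(List<String> groupBy, List<String> output, boolean disassembled, String aggPhase) {
        this.groupBy = groupBy;
        this.output = output;
        this.disassembled = disassembled;
        this.aggPhase = aggPhase;
    }

    // Copies the node with new expressions while keeping disassembled and aggPhase.
    AggSketch withGroupByAndOutput(List<String> newGroupBy, List<String> newOutput) {
        return new AggSketch(newGroupBy, newOutput, disassembled, aggPhase);
    }
}
```

A rule that rewrites only the expressions of an already disassembled local aggregate would lose that state if it went through the short constructor.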





[GitHub] [doris] EmmyMiao87 commented on a diff in pull request #10659: [enhancement](nereids) make aggregate works

Posted by GitBox <gi...@apache.org>.
EmmyMiao87 commented on code in PR #10659:
URL: https://github.com/apache/doris/pull/10659#discussion_r916689097


##########
fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java:
##########
@@ -263,23 +313,18 @@ public PlanFragment visitPhysicalHashJoin(
         // NOTICE: We must visit from right to left, to ensure the last fragment is root fragment
         PlanFragment rightFragment = visit(hashJoin.child(1), context);
         PlanFragment leftFragment = visit(hashJoin.child(0), context);
-        PhysicalHashJoin physicalHashJoin = hashJoin.getOperator();
-
-        //        Expression predicateExpr = physicalHashJoin.getCondition().get();
-        //        List<Expression> eqExprList = Utils.getEqConjuncts(hashJoin.child(0).getOutput(),
-        //                hashJoin.child(1).getOutput(), predicateExpr);
-        JoinType joinType = physicalHashJoin.getJoinType();
-
         PlanNode leftFragmentPlanRoot = leftFragment.getPlanRoot();
         PlanNode rightFragmentPlanRoot = rightFragment.getPlanRoot();
+        PhysicalHashJoin physicalHashJoin = hashJoin.getOperator();
+        JoinType joinType = physicalHashJoin.getJoinType();
 
         if (joinType.equals(JoinType.CROSS_JOIN)

Review Comment:
   Then if we encounter a `PhysicalHashJoin` whose `JoinType` is cross join, an error should be reported directly here.
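
A minimal sketch of such a fail-fast check. The enum, the method name, and the choice of `IllegalStateException` are all assumptions for illustration; the PR may report the error differently.

```java
// Hypothetical sketch of rejecting a cross join inside the hash-join translation path.
final class HashJoinCheckSketch {
    enum JoinType { INNER_JOIN, LEFT_OUTER_JOIN, CROSS_JOIN }

    // Throws if a hash join is asked to execute a cross join; otherwise does nothing.
    static void checkHashJoinType(JoinType joinType) {
        if (joinType == JoinType.CROSS_JOIN) {
            throw new IllegalStateException("cross join should not be planned as a hash join");
        }
    }
}
```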





[GitHub] [doris] morrySnow commented on a diff in pull request #10659: [enhancement](nereids) make aggregate works

Posted by GitBox <gi...@apache.org>.
morrySnow commented on code in PR #10659:
URL: https://github.com/apache/doris/pull/10659#discussion_r915557011


##########
fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java:
##########
@@ -100,8 +103,25 @@ private static Expression swapEqualToForChildrenOrder(EqualTo<?, ?> equalTo, Lis
         }
     }
 
-    public void translatePlan(PhysicalPlan physicalPlan, PlanTranslatorContext context) {
-        visit(physicalPlan, context);
+    /**
+     * Translate Nereids Physical Plan tree to Stale Planner PlanFragment tree.
+     *
+     * @param physicalPlan Nereids Physical Plan tree
+     * @param context context to help translate
+     * @return Stale Planner PlanFragment tree
+     */
+    public PlanFragment translatePlan(PhysicalPlan physicalPlan, PlanTranslatorContext context) {
+        PlanFragment rootFragment = visit(physicalPlan, context);
+        if (rootFragment.isPartitioned() && rootFragment.getPlanRoot().getNumInstances() > 1) {
+            rootFragment = createMergeFragment(rootFragment, context);
+            context.addPlanFragment(rootFragment);

Review Comment:
   good way





[GitHub] [doris] 924060929 commented on a diff in pull request #10659: [enhancement](nereids) make aggregate works

Posted by GitBox <gi...@apache.org>.
924060929 commented on code in PR #10659:
URL: https://github.com/apache/doris/pull/10659#discussion_r915532556


##########
fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AggregateDisassemble.java:
##########
@@ -64,49 +65,69 @@ public Rule<Plan> build() {
             Operator operator = plan.getOperator();
             LogicalAggregate agg = (LogicalAggregate) operator;
             List<NamedExpression> outputExpressionList = agg.getOutputExpressionList();
-            List<NamedExpression> intermediateAggExpressionList = Lists.newArrayList();
-            // TODO: shouldn't extract agg function from this field.
-            for (NamedExpression namedExpression : outputExpressionList) {
-                namedExpression = (NamedExpression) namedExpression.clone();
-                List<AggregateFunction> functionCallList =
-                        namedExpression.collect(org.apache.doris.catalog.AggregateFunction.class::isInstance);
-                // TODO: we will have another mechanism to get corresponding stale agg func.
-                for (AggregateFunction functionCall : functionCallList) {
-                    org.apache.doris.catalog.AggregateFunction staleAggFunc = findAggFunc(functionCall);
-                    Type staleIntermediateType = staleAggFunc.getIntermediateType();
-                    Type staleRetType = staleAggFunc.getReturnType();
-                    if (staleIntermediateType != null && !staleIntermediateType.equals(staleRetType)) {
-                        functionCall.setIntermediate(DataType.convertFromCatalogDataType(staleIntermediateType));
+            List<Expression> groupByExpressionList = agg.getGroupByExpressionList();
+
+            Map<AggregateFunction, NamedExpression> aggregateFunctionAliasMap = Maps.newHashMap();
+            for (NamedExpression outputExpression : outputExpressionList) {
+                outputExpression.foreach(e -> {
+                    if (e instanceof AggregateFunction) {
+                        AggregateFunction a = (AggregateFunction) e;
+                        aggregateFunctionAliasMap.put(a, new Alias<>(a, a.sql()));
+                    }
+                });
+            }
+
+            List<Expression> updateGroupByExpressionList = groupByExpressionList;
+            List<NamedExpression> updateGroupByAliasList = updateGroupByExpressionList.stream()
+                    .map(g -> new Alias<>(g, g.sql()))
+                    .collect(Collectors.toList());
+
+            List<NamedExpression> updateOutputExpressionList = Lists.newArrayList();
+            updateOutputExpressionList.addAll(updateGroupByAliasList);
+            updateOutputExpressionList.addAll(aggregateFunctionAliasMap.values());
+
+            List<Expression> mergeGroupByExpressionList = updateGroupByAliasList.stream()
+                    .map(NamedExpression::toSlot).collect(Collectors.toList());
+
+            List<NamedExpression> mergeOutputExpressionList = Lists.newArrayList();
+            for (NamedExpression o : outputExpressionList) {
+                if (o.contains(AggregateFunction.class::isInstance)) {
+                    mergeOutputExpressionList.add((NamedExpression) new AggregateFunctionParamsRewriter()
+                            .visit(o, aggregateFunctionAliasMap));
+                } else {
+                    for (int i = 0; i < updateGroupByAliasList.size(); i++) {
+                        // TODO: we need to do sub tree match and replace. but we do not have semanticEquals now.
+                        //    e.g. a + 1 + 2 in output expression should be replaced by
+                        //    (slot reference to update phase out (a + 1)) + 2, if we do group by a + 1
+                        //   currently, we could only handle output expression same with group by expression
+                        if (o instanceof SlotReference) {
+                            // a in output expression will be SLotReference
+                            if (o.equals(updateGroupByExpressionList.get(i))) {
+                                mergeOutputExpressionList.add(updateGroupByAliasList.get(i).toSlot());
+                                break;
+                            }
+                        } else if (o instanceof Alias) {
+                            // a + 1 in output expression will be Alias
+                            if (o.child(0).equals(updateGroupByExpressionList.get(i))) {
+                                mergeOutputExpressionList.add(updateGroupByAliasList.get(i).toSlot());
+                                break;
+                            }
+                        }
                     }
                 }
-                intermediateAggExpressionList.add(namedExpression);
             }
+
             LogicalAggregate localAgg = new LogicalAggregate(
-                    agg.getGroupByExprList().stream().map(Expression::clone).collect(Collectors.toList()),
-                    intermediateAggExpressionList,
+                    updateGroupByExpressionList,
+                    updateOutputExpressionList,

Review Comment:
   I think these variable names and the computation logic are confusing. How about this:
   1. localGroupByExprs = originGroupByExprs
   2. localOutputExprs = originOutput.withAlias
   3. globalGroupByWithAlias = originGroupByExprs.withAlias
   4. globalOutputWithAlias = originOutput.replaceAggregateFunctionArgumentsToAliasReference
   
   The advantages are:
   1. each variable name states both the phase (local or global) and which member it holds
   2. each assignment contains the simplest possible computation
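
The four assignments can be sketched roughly as below. This is a hypothetical, string-based illustration: expressions are plain strings, the alias names `g0` and `a1` are generated by the sketch, and the string replacement stands in for a real expression rewriter. It is not the rule's actual implementation.

```java
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

// Hypothetical sketch of splitting one aggregate into a local and a global phase.
final class DisassembleSketch {
    final List<String> localGroupByExprs = new ArrayList<>();
    final List<String> localOutputExprs = new ArrayList<>();
    final List<String> globalGroupByExprs = new ArrayList<>();
    final List<String> globalOutputExprs = new ArrayList<>();

    DisassembleSketch(List<String> originGroupByExprs, List<String> originAggCalls,
            List<String> originOutputExprs) {
        // Give every group-by expression and every aggregate call an alias.
        Map<String, String> aliasByExpr = new LinkedHashMap<>();
        for (String expr : originGroupByExprs) {
            aliasByExpr.put(expr, "g" + aliasByExpr.size());
        }
        for (String call : originAggCalls) {
            aliasByExpr.put(call, "a" + aliasByExpr.size());
        }
        // 1. The local phase groups by the original expressions.
        localGroupByExprs.addAll(originGroupByExprs);
        // 2. The local phase outputs every aliased expression.
        for (Map.Entry<String, String> e : aliasByExpr.entrySet()) {
            localOutputExprs.add(e.getKey() + " AS " + e.getValue());
        }
        // 3. The global phase groups by references to the local aliases.
        for (String expr : originGroupByExprs) {
            globalGroupByExprs.add(aliasByExpr.get(expr));
        }
        // 4. The global output keeps its shape, but each aggregate reads the local alias.
        for (String output : originOutputExprs) {
            String rewritten = output;
            for (String call : originAggCalls) {
                String fn = call.substring(0, call.indexOf('('));
                rewritten = rewritten.replace(call, fn + "(" + aliasByExpr.get(call) + ")");
            }
            globalOutputExprs.add(rewritten);
        }
    }
}
```

For group-by `k + 1`, aggregate call `SUM(v1 * v2)`, and output `SUM(v1 * v2) + 1`, the sketch produces a local output of `k + 1 AS g0` and `SUM(v1 * v2) AS a1`, a global group-by of `g0`, and a global output of `SUM(a1) + 1`.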
   





[GitHub] [doris] morrySnow commented on a diff in pull request #10659: [enhancement](nereids) make aggregate works

Posted by GitBox <gi...@apache.org>.
morrySnow commented on code in PR #10659:
URL: https://github.com/apache/doris/pull/10659#discussion_r916496635


##########
fe/fe-core/src/main/java/org/apache/doris/analysis/FunctionCallExpr.java:
##########
@@ -1257,6 +1257,16 @@ public String forJSON(String str) {
 
     @Override
     public void finalizeImplForNereids() throws AnalysisException {
-        super.finalizeImplForNereids();
+        if (fnName.getFunction().equalsIgnoreCase("sum")) {
+            // Prevent the cast type in vector exec engine
+            Type childType = getChild(0).type.getMaxResolutionType();
+            fn = getBuiltinFunction(fnName.getFunction(), new Type[]{childType},
+                    Function.CompareMode.IS_NONSTRICT_SUPERTYPE_OF);
+            type = fn.getReturnType();
+        }
+    }
+
+    public void setMergeAggFn(boolean mergeAggFn) {

Review Comment:
   Not exactly. The expression translator should not know about aggregate phase information.



##########
fe/fe-core/src/main/java/org/apache/doris/nereids/analyzer/UnboundFunction.java:
##########
@@ -52,6 +54,14 @@ public List<Expression> getArguments() {
         return children();
     }
 
+    @Override
+    public String sql() throws UnboundException {

Review Comment:
   I think `sql` is better.





[GitHub] [doris] morrySnow commented on a diff in pull request #10659: [enhancement](nereids) make aggregate works

Posted by GitBox <gi...@apache.org>.
morrySnow commented on code in PR #10659:
URL: https://github.com/apache/doris/pull/10659#discussion_r916513342


##########
fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AggregateDisassemble.java:
##########
@@ -17,144 +17,132 @@
 
 package org.apache.doris.nereids.rules.rewrite;
 
-import org.apache.doris.analysis.FunctionName;
-import org.apache.doris.catalog.Catalog;
-import org.apache.doris.catalog.Function;
-import org.apache.doris.catalog.Function.CompareMode;
-import org.apache.doris.catalog.Type;
-import org.apache.doris.nereids.operators.Operator;
 import org.apache.doris.nereids.operators.plans.AggPhase;
 import org.apache.doris.nereids.operators.plans.logical.LogicalAggregate;
 import org.apache.doris.nereids.rules.Rule;
 import org.apache.doris.nereids.rules.RuleType;
+import org.apache.doris.nereids.trees.expressions.Alias;
 import org.apache.doris.nereids.trees.expressions.Expression;
 import org.apache.doris.nereids.trees.expressions.NamedExpression;
-import org.apache.doris.nereids.trees.expressions.Slot;
+import org.apache.doris.nereids.trees.expressions.SlotReference;
 import org.apache.doris.nereids.trees.expressions.functions.AggregateFunction;
+import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter;
+import org.apache.doris.nereids.trees.plans.GroupPlan;
 import org.apache.doris.nereids.trees.plans.Plan;
-import org.apache.doris.nereids.types.DataType;
+import org.apache.doris.nereids.trees.plans.logical.LogicalUnaryPlan;
 
-import com.clearspring.analytics.util.Lists;
-import com.google.common.base.Preconditions;
+import com.google.common.collect.Lists;
+import com.google.common.collect.Maps;
 
-import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.stream.Collectors;
 
 /**
- * TODO: if instance count is 1, shouldn't disassemble the agg operator
  * Used to generate the merge agg node for distributed execution.
- * Do this in following steps:
- *  1. clone output expr list, find all agg function
- *  2. set found agg function intermediaType
- *  3. create new child plan rooted at new local agg
- *  4. update the slot referenced by expr of merge agg
- *  5. create plan rooted at merge agg, return it.
+ * If we have a query: SELECT SUM(v) + 1 FROM t GROUP BY k + 1
+ * the initial plan is:
+ *   Aggregate(phase: [GLOBAL], outputExpr: [SUM(v1 * v2) + 1], groupByExpr: [k + 1])
+ *   +-- childPlan
+ * we should rewrite to:
+ *   Aggregate(phase: [GLOBAL], outputExpr: [SUM(a) + 1], groupByExpr: [b])
+ *   +-- Aggregate(phase: [LOCAL], outputExpr: [SUM(v1 * v2) as a, (k + 1) as b], groupByExpr: [k + 1])
+ *       +-- childPlan
+ *
+ * TODO:
+ *     1. use different class represent different phase aggregate
+ *     2. if instance count is 1, shouldn't disassemble the agg operator
+ *     3. we need another rule to removing duplicated expressions in group by expression list
  */
 public class AggregateDisassemble extends OneRewriteRuleFactory {
 
     @Override
     public Rule<Plan> build() {
         return logicalAggregate().when(p -> {
-            LogicalAggregate logicalAggregation = p.getOperator();
-            return !logicalAggregation.isDisassembled();
+            LogicalAggregate logicalAggregate = p.getOperator();
+            return !logicalAggregate.isDisassembled();
         }).thenApply(ctx -> {
-            Plan plan = ctx.root;
-            Operator operator = plan.getOperator();
-            LogicalAggregate agg = (LogicalAggregate) operator;
-            List<NamedExpression> outputExpressionList = agg.getOutputExpressionList();
-            List<NamedExpression> intermediateAggExpressionList = Lists.newArrayList();
-            // TODO: shouldn't extract agg function from this field.
-            for (NamedExpression namedExpression : outputExpressionList) {
-                namedExpression = (NamedExpression) namedExpression.clone();
-                List<AggregateFunction> functionCallList =
-                        namedExpression.collect(org.apache.doris.catalog.AggregateFunction.class::isInstance);
-                // TODO: we will have another mechanism to get corresponding stale agg func.
-                for (AggregateFunction functionCall : functionCallList) {
-                    org.apache.doris.catalog.AggregateFunction staleAggFunc = findAggFunc(functionCall);
-                    Type staleIntermediateType = staleAggFunc.getIntermediateType();
-                    Type staleRetType = staleAggFunc.getReturnType();
-                    if (staleIntermediateType != null && !staleIntermediateType.equals(staleRetType)) {
-                        functionCall.setIntermediate(DataType.convertFromCatalogDataType(staleIntermediateType));
+            LogicalUnaryPlan<LogicalAggregate, GroupPlan> plan = ctx.root;
+            LogicalAggregate aggregate = plan.getOperator();
+            List<NamedExpression> originOutputExprs = aggregate.getOutputExpressionList();
+            List<Expression> originGroupByExprs = aggregate.getGroupByExpressionList();
+
+            Map<AggregateFunction, NamedExpression> originAggregateFunctionWithAlias = Maps.newHashMap();
+            for (NamedExpression originOutputExpr : originOutputExprs) {

Review Comment:
   We don't need to get the corresponding stale function here, because we can get the intermediate type.





[GitHub] [doris] morrySnow commented on a diff in pull request #10659: [enhancement](nereids) make aggregate works

Posted by GitBox <gi...@apache.org>.
morrySnow commented on code in PR #10659:
URL: https://github.com/apache/doris/pull/10659#discussion_r916502178


##########
fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java:
##########
@@ -288,23 +333,24 @@ public PlanFragment visitPhysicalHashJoin(
             leftFragment.setPlanRoot(crossJoinNode);
             context.addPlanFragment(leftFragment);
             return leftFragment;
+        } else {
+            Expression eqJoinExpression = physicalHashJoin.getCondition().get();
+            List<Expr> execEqConjunctList = ExpressionUtils.extractConjunct(eqJoinExpression).stream()
+                    .map(EqualTo.class::cast)
+                    .map(e -> swapEqualToForChildrenOrder(e, hashJoin.left().getOutput()))
+                    .map(e -> ExpressionTranslator.translate(e, context))
+                    .collect(Collectors.toList());
+
+            HashJoinNode hashJoinNode = new HashJoinNode(context.nextPlanNodeId(), leftFragmentPlanRoot,
+                    rightFragmentPlanRoot,
+                    JoinType.toJoinOperator(physicalHashJoin.getJoinType()), execEqConjunctList, Lists.newArrayList());
+
+            hashJoinNode.setDistributionMode(DistributionMode.BROADCAST);

Review Comment:
   I think doing this in the translator is better than in `finalizeForNereids`, since we do not need to change the stale planner's classes at all. The reason expressions use `finalizeForNereids` is that the stale planner's `Expr` has too many properties that need to be set. I don't think `PlanNode` has the same problem.





[GitHub] [doris] 924060929 commented on a diff in pull request #10659: [enhancement](nereids) make aggregate works

Posted by GitBox <gi...@apache.org>.
924060929 commented on code in PR #10659:
URL: https://github.com/apache/doris/pull/10659#discussion_r915532556


##########
fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AggregateDisassemble.java:
##########
@@ -64,49 +65,69 @@ public Rule<Plan> build() {
             Operator operator = plan.getOperator();
             LogicalAggregate agg = (LogicalAggregate) operator;
             List<NamedExpression> outputExpressionList = agg.getOutputExpressionList();
-            List<NamedExpression> intermediateAggExpressionList = Lists.newArrayList();
-            // TODO: shouldn't extract agg function from this field.
-            for (NamedExpression namedExpression : outputExpressionList) {
-                namedExpression = (NamedExpression) namedExpression.clone();
-                List<AggregateFunction> functionCallList =
-                        namedExpression.collect(org.apache.doris.catalog.AggregateFunction.class::isInstance);
-                // TODO: we will have another mechanism to get corresponding stale agg func.
-                for (AggregateFunction functionCall : functionCallList) {
-                    org.apache.doris.catalog.AggregateFunction staleAggFunc = findAggFunc(functionCall);
-                    Type staleIntermediateType = staleAggFunc.getIntermediateType();
-                    Type staleRetType = staleAggFunc.getReturnType();
-                    if (staleIntermediateType != null && !staleIntermediateType.equals(staleRetType)) {
-                        functionCall.setIntermediate(DataType.convertFromCatalogDataType(staleIntermediateType));
+            List<Expression> groupByExpressionList = agg.getGroupByExpressionList();
+
+            Map<AggregateFunction, NamedExpression> aggregateFunctionAliasMap = Maps.newHashMap();
+            for (NamedExpression outputExpression : outputExpressionList) {
+                outputExpression.foreach(e -> {
+                    if (e instanceof AggregateFunction) {
+                        AggregateFunction a = (AggregateFunction) e;
+                        aggregateFunctionAliasMap.put(a, new Alias<>(a, a.sql()));
+                    }
+                });
+            }
+
+            List<Expression> updateGroupByExpressionList = groupByExpressionList;
+            List<NamedExpression> updateGroupByAliasList = updateGroupByExpressionList.stream()
+                    .map(g -> new Alias<>(g, g.sql()))
+                    .collect(Collectors.toList());
+
+            List<NamedExpression> updateOutputExpressionList = Lists.newArrayList();
+            updateOutputExpressionList.addAll(updateGroupByAliasList);
+            updateOutputExpressionList.addAll(aggregateFunctionAliasMap.values());
+
+            List<Expression> mergeGroupByExpressionList = updateGroupByAliasList.stream()
+                    .map(NamedExpression::toSlot).collect(Collectors.toList());
+
+            List<NamedExpression> mergeOutputExpressionList = Lists.newArrayList();
+            for (NamedExpression o : outputExpressionList) {
+                if (o.contains(AggregateFunction.class::isInstance)) {
+                    mergeOutputExpressionList.add((NamedExpression) new AggregateFunctionParamsRewriter()
+                            .visit(o, aggregateFunctionAliasMap));
+                } else {
+                    for (int i = 0; i < updateGroupByAliasList.size(); i++) {
+                        // TODO: we need to do sub tree match and replace. but we do not have semanticEquals now.
+                        //    e.g. a + 1 + 2 in output expression should be replaced by
+                        //    (slot reference to update phase out (a + 1)) + 2, if we do group by a + 1
+                        //   currently, we could only handle output expression same with group by expression
+                        if (o instanceof SlotReference) {
+                            // a in output expression will be SLotReference
+                            if (o.equals(updateGroupByExpressionList.get(i))) {
+                                mergeOutputExpressionList.add(updateGroupByAliasList.get(i).toSlot());
+                                break;
+                            }
+                        } else if (o instanceof Alias) {
+                            // a + 1 in output expression will be Alias
+                            if (o.child(0).equals(updateGroupByExpressionList.get(i))) {
+                                mergeOutputExpressionList.add(updateGroupByAliasList.get(i).toSlot());
+                                break;
+                            }
+                        }
                     }
                 }
-                intermediateAggExpressionList.add(namedExpression);
             }
+
             LogicalAggregate localAgg = new LogicalAggregate(
-                    agg.getGroupByExprList().stream().map(Expression::clone).collect(Collectors.toList()),
-                    intermediateAggExpressionList,
+                    updateGroupByExpressionList,
+                    updateOutputExpressionList,

Review Comment:
   I think these variable names and the computation logic are confusing. How about this:
   1. localGroupByExprs = originGroupByExprs
   2. localOutputExprs = originOutput.withAlias
   3. globalGroupByWithAlias = originGroupByExprs.withAlias
   4. globalOutputWithAlias = originOutput.replaceAggregateFunctionArgumentsToAliasReference
   
   Then `new LogicalAggregate(localGroupByExprs, localOutputExprs, ...)` and `new LogicalAggregate(globalGroupByWithAlias, globalOutputWithAlias, ...)`.
   
   The advantages are:
   1. each variable name carries position and member information
   2. each assignment statement contains the simplest possible computation
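   
   A minimal sketch of the naming scheme, assuming expressions are plain strings and an alias is the expression's SQL text in backticks. `AggregateNamingSketch`, `Agg`, `alias`, `local` and `global` are hypothetical names for illustration only, not Nereids classes, and the string rewriting stands in for the real expression rewriter:

```java
import java.util.ArrayList;
import java.util.List;

final class AggregateNamingSketch {
    /** Hypothetical stand-in for LogicalAggregate; expressions are plain strings. */
    static final class Agg {
        final List<String> groupBy;
        final List<String> output;

        Agg(List<String> groupBy, List<String> output) {
            this.groupBy = groupBy;
            this.output = output;
        }
    }

    /** Assumption: an alias is named after the SQL text of the expression it wraps. */
    static String alias(String expr) {
        return "`" + expr + "`";
    }

    /** Local phase: group by the original expressions, output everything under an alias. */
    static Agg local(List<String> originGroupByExprs, List<String> originAggFuncs) {
        List<String> localGroupByExprs = originGroupByExprs;
        List<String> localOutputExprs = new ArrayList<>();
        for (String g : originGroupByExprs) {
            localOutputExprs.add(g + " AS " + alias(g));
        }
        for (String f : originAggFuncs) {
            localOutputExprs.add(f + " AS " + alias(f));
        }
        return new Agg(localGroupByExprs, localOutputExprs);
    }

    /** Global phase: group by the aliases, aggregate over the aliased local results. */
    static Agg global(List<String> originGroupByExprs, List<String> originAggFuncs) {
        List<String> globalGroupByWithAlias = new ArrayList<>();
        for (String g : originGroupByExprs) {
            globalGroupByWithAlias.add(alias(g));
        }
        List<String> globalOutputWithAlias = new ArrayList<>(globalGroupByWithAlias);
        for (String f : originAggFuncs) {
            // Replace the function's arguments with a reference to the local alias.
            String functionName = f.substring(0, f.indexOf('('));
            globalOutputWithAlias.add(functionName + "(" + alias(f) + ")");
        }
        return new Agg(globalGroupByWithAlias, globalOutputWithAlias);
    }
}
```

   With group-by `k1` and aggregate `sum(v1)`, the local phase outputs the two aliased expressions and the global phase groups by the `k1` alias and sums over the `sum(v1)` alias.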
   





[GitHub] [doris] morrySnow commented on a diff in pull request #10659: [enhancement](nereids) make aggregate works

Posted by GitBox <gi...@apache.org>.
morrySnow commented on code in PR #10659:
URL: https://github.com/apache/doris/pull/10659#discussion_r916506034


##########
fe/fe-core/src/main/java/org/apache/doris/nereids/operators/plans/logical/LogicalAggregate.java:
##########
@@ -127,14 +129,19 @@ public boolean equals(Object o) {
             return false;
         }
         LogicalAggregate that = (LogicalAggregate) o;
-        return Objects.equals(groupByExprList, that.groupByExprList)
+        return Objects.equals(groupByExpressionList, that.groupByExpressionList)
                 && Objects.equals(outputExpressionList, that.outputExpressionList)
                 && Objects.equals(partitionExprList, that.partitionExprList)
                 && aggPhase == that.aggPhase;
     }
 
     @Override
     public int hashCode() {
-        return Objects.hash(groupByExprList, outputExpressionList, partitionExprList, aggPhase);
+        return Objects.hash(groupByExpressionList, outputExpressionList, partitionExprList, aggPhase);
+    }
+
+    public LogicalAggregate withGroupByAndOutput(List<Expression> groupByExprList,

Review Comment:
   The constructor sets `disassembled` and `aggPhase` to their default values, whereas this function preserves their current values. However, I found a bug: the partition expressions were reset to their default value by mistake.
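
   The difference between the constructor and the `with...` method can be sketched as follows. This is a simplified illustration, not the actual `LogicalAggregate`: the class `WithCopySketch`, its `Phase` enum and the string-based expression lists are assumptions made for the example.

```java
import java.util.Collections;
import java.util.List;

final class WithCopySketch {
    /** Hypothetical phase enum; the real AggPhase values may differ. */
    enum Phase { DEFAULT, MERGE }

    static final class Aggregate {
        final List<String> groupBy;
        final List<String> output;
        final List<String> partition;
        final boolean disassembled;
        final Phase phase;

        /** Short constructor: every field other than the two lists gets a default value. */
        Aggregate(List<String> groupBy, List<String> output) {
            this(groupBy, output, Collections.<String>emptyList(), false, Phase.DEFAULT);
        }

        Aggregate(List<String> groupBy, List<String> output, List<String> partition,
                boolean disassembled, Phase phase) {
            this.groupBy = groupBy;
            this.output = output;
            this.partition = partition;
            this.disassembled = disassembled;
            this.phase = phase;
        }

        /** Replaces only the two lists; partition, disassembled and phase are carried over. */
        Aggregate withGroupByAndOutput(List<String> newGroupBy, List<String> newOutput) {
            return new Aggregate(newGroupBy, newOutput, partition, disassembled, phase);
        }
    }
}
```

   Passing `partition` through in `withGroupByAndOutput` is the point of the sketch: omitting it there would silently reset the partition expressions, which is the kind of mistake described above.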





[GitHub] [doris] morrySnow commented on a diff in pull request #10659: [enhancement](nereids) make aggregate works

Posted by GitBox <gi...@apache.org>.
morrySnow commented on code in PR #10659:
URL: https://github.com/apache/doris/pull/10659#discussion_r916504214


##########
fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/ExpressionTranslator.java:
##########
@@ -122,6 +123,13 @@ public Expr visitLessThanEqual(LessThanEqual lessThanEqual, PlanTranslatorContex
                 lessThanEqual.child(1).accept(this, context));
     }
 
+    @Override
+    public Expr visitNullSafeEqual(NullSafeEqual nullSafeEqual, PlanTranslatorContext context) {
+        return new BinaryPredicate(Operator.EQ_FOR_NULL,

Review Comment:
   The stale planner has no `NullSafeEqual` class, so it is translated to a `BinaryPredicate` with `Operator.EQ_FOR_NULL`.
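
   For readers unfamiliar with null-safe equality, here is a small sketch of how it differs from the ordinary SQL `=`, assuming `EQ_FOR_NULL` corresponds to the SQL `<=>` operator. `NullSafeEqualSketch` and its methods are hypothetical helpers over boxed integers, not planner code:

```java
import java.util.Objects;

final class NullSafeEqualSketch {
    /** Ordinary SQL `=`: the result is NULL (unknown) if either operand is NULL. */
    static Boolean sqlEqual(Integer left, Integer right) {
        if (left == null || right == null) {
            return null;
        }
        return left.equals(right);
    }

    /** Null-safe equal: never yields NULL; two NULL operands compare as equal. */
    static boolean nullSafeEqual(Integer left, Integer right) {
        return Objects.equals(left, right);
    }
}
```

   Here `nullSafeEqual(null, null)` is true and `nullSafeEqual(null, 1)` is false, while `sqlEqual` returns null in both cases.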





[GitHub] [doris] morrySnow commented on a diff in pull request #10659: [enhancement](nereids) make aggregate works

Posted by GitBox <gi...@apache.org>.
morrySnow commented on code in PR #10659:
URL: https://github.com/apache/doris/pull/10659#discussion_r915560690


##########
fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java:
##########
@@ -114,60 +134,96 @@ public PlanFragment visit(Plan plan, PlanTranslatorContext context) {
      * Translate Agg.
      */
     @Override
-    public PlanFragment visitPhysicalAggregation(
-            PhysicalUnaryPlan<PhysicalAggregation, Plan> agg, PlanTranslatorContext context) {
-
+    public PlanFragment visitPhysicalAggregate(
+            PhysicalUnaryPlan<PhysicalAggregate, Plan> agg, PlanTranslatorContext context) {
         PlanFragment inputPlanFragment = visit(agg.child(0), context);
-
-        AggregationNode aggregationNode;
-        List<Slot> slotList = new ArrayList<>();
-        PhysicalAggregation physicalAggregation = agg.getOperator();
-        AggregateInfo.AggPhase phase = physicalAggregation.getAggPhase().toExec();
-
-        List<Expression> groupByExpressionList = physicalAggregation.getGroupByExprList();
+        PhysicalAggregate physicalAggregate = agg.getOperator();
+
+        // TODO: stale planner generate aggregate tuple in a special way. tuple include 2 parts:
+        //    1. group by expressions: removing duplicate expressions add to tuple
+        //    2. agg functions: only removing duplicate agg functions in output expression should appear in tuple.
+        //       e.g. select sum(v1) + 1, sum(v1) + 2 from t1 should only generate one sum(v1) in tuple
+        //    We need:
+        //    1. add a project after agg, if output expressions include agg function as a expression tree leaf.
+        //    2. introduce canonicalized, semanticEquals and deterministic in Expression
+        //       for removing duplicate.
+        List<Expression> groupByExpressionList = physicalAggregate.getGroupByExprList();
+        List<NamedExpression> outputExpressionList = physicalAggregate.getOutputExpressionList();
+
+        // 1. generate slot reference for each group expression
+        List<SlotReference> groupSlotList = Lists.newArrayList();
+        for (Expression e : groupByExpressionList) {
+            if (e instanceof SlotReference && outputExpressionList.stream().anyMatch(o -> o.contains(e::equals))) {
+                groupSlotList.add((SlotReference) e);
+            } else {
+                groupSlotList.add(new SlotReference(e.sql(), e.getDataType(), e.nullable(), Collections.emptyList()));
+            }
+        }
         ArrayList<Expr> execGroupingExpressions = groupByExpressionList.stream()
-                // Since output of plan doesn't contain the slots of groupBy, which is actually needed by
-                // the BE execution, so we have to collect them and add to the slotList to generate corresponding
-                // TupleDesc.
-                .peek(x -> slotList.addAll(x.collect(SlotReference.class::isInstance)))
                 .map(e -> ExpressionTranslator.translate(e, context)).collect(Collectors.toCollection(ArrayList::new));
-        slotList.addAll(agg.getOutput());
-        TupleDescriptor outputTupleDesc = generateTupleDesc(slotList, context, null);
-
-        List<NamedExpression> outputExpressionList = physicalAggregation.getOutputExpressionList();
-        ArrayList<FunctionCallExpr> execAggExpressions = outputExpressionList.stream()
-                .map(e -> e.<List<AggregateFunction>>collect(AggregateFunction.class::isInstance))
+        // 2. collect agg functions and generate agg function to slot reference map
+        List<Slot> aggFunctionOutput = Lists.newArrayList();
+        List<AggregateFunction> aggregateFunctionList = outputExpressionList.stream()
+                .filter(o -> o.contains(AggregateFunction.class::isInstance))
+                .peek(o -> aggFunctionOutput.add(o.toSlot()))
+                .map(o -> (List<AggregateFunction>) o.collect(AggregateFunction.class::isInstance))
                 .flatMap(List::stream)
+                .collect(Collectors.toList());
+        ArrayList<FunctionCallExpr> execAggExpressions = aggregateFunctionList.stream()
                 .map(x -> (FunctionCallExpr) ExpressionTranslator.translate(x, context))
                 .collect(Collectors.toCollection(ArrayList::new));
 
-        List<Expression> partitionExpressionList = physicalAggregation.getPartitionExprList();
+        // 3. generate output tuple
+        // TODO: currently, we only support sum(a), if we want to support sum(a) + 1, we need to
+        //  split merge agg to project(agg) and generate tuple like what first phase agg do.
+        List<Slot> slotList = Lists.newArrayList();
+        TupleDescriptor outputTupleDesc;
+        if (agg.getOperator().getAggPhase() == AggPhase.FIRST_MERGE) {
+            slotList.addAll(groupSlotList);
+            slotList.addAll(aggFunctionOutput);
+            outputTupleDesc = generateTupleDesc(slotList, null, context);
+        } else {
+            outputTupleDesc = generateTupleDesc(agg.getOutput(), null, context);
+        }
+
+        // process partition list
+        List<Expression> partitionExpressionList = physicalAggregate.getPartitionExprList();
         List<Expr> execPartitionExpressions = partitionExpressionList.stream()
-                .map(e -> (FunctionCallExpr) ExpressionTranslator.translate(e, context)).collect(Collectors.toList());
+                .map(e -> ExpressionTranslator.translate(e, context)).collect(Collectors.toList());
+        DataPartition mergePartition = DataPartition.UNPARTITIONED;
+        if (CollectionUtils.isNotEmpty(execPartitionExpressions)) {

Review Comment:
   Yes, you are right. Adding a TODO to use two classes to represent the global and local aggregates.


