You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@doris.apache.org by en...@apache.org on 2024/04/02 00:52:08 UTC

(doris) branch master updated: [feature](nereids) support common sub expression by multi-layer projections (fe part) (#33087)

This is an automated email from the ASF dual-hosted git repository.

englefly pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git


The following commit(s) were added to refs/heads/master by this push:
     new ac5c9c686e8 [feature](nereids) support common sub expression by multi-layer projections (fe part) (#33087)
ac5c9c686e8 is described below

commit ac5c9c686e8be019899af4fe54f69450281e62bf
Author: minghong <en...@gmail.com>
AuthorDate: Tue Apr 2 08:52:03 2024 +0800

    [feature](nereids) support common sub expression by multi-layer projections (fe part) (#33087)
    
    * CSE (common subexpression elimination), FE (frontend) part
---
 .../glue/translator/PhysicalPlanTranslator.java    |  50 ++++++--
 .../post/CommonSubExpressionCollector.java         |  59 ++++++++++
 .../processor/post/CommonSubExpressionOpt.java     | 125 ++++++++++++++++++++
 .../nereids/processor/post/PlanPostProcessors.java |   3 +-
 .../trees/plans/physical/PhysicalProject.java      |  81 ++++++++++++-
 .../java/org/apache/doris/planner/PlanNode.java    |  38 +++++-
 .../apache/doris/catalog/CreateFunctionTest.java   |  41 ++++---
 .../postprocess/CommonSubExpressionTest.java       | 131 +++++++++++++++++++++
 regression-test/data/tpch_sf0.1_p1/sql/cse.out     |  30 +++++
 .../doris/regression/action/ExplainAction.groovy   |  15 +++
 .../suites/tpch_sf0.1_p1/sql/cse.groovy            |  49 ++++++++
 11 files changed, 591 insertions(+), 31 deletions(-)

diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java
index f47b6826ebe..205cfbd2530 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java
@@ -1837,15 +1837,38 @@ public class PhysicalPlanTranslator extends DefaultPlanVisitor<PlanFragment, Pla
             registerRewrittenSlot(project, (OlapScanNode) inputFragment.getPlanRoot());
         }
 
-        List<Expr> projectionExprs = project.getProjects()
-                .stream()
-                .map(e -> ExpressionTranslator.translate(e, context))
-                .collect(Collectors.toList());
-        List<Slot> slots = project.getProjects()
-                .stream()
-                .map(NamedExpression::toSlot)
-                .collect(Collectors.toList());
-
+        PlanNode inputPlanNode = inputFragment.getPlanRoot();
+        List<Expr> projectionExprs = null;
+        List<Expr> allProjectionExprs = Lists.newArrayList();
+        List<Slot> slots = null;
+        if (project.hasMultiLayerProjection()) {
+            int layerCount = project.getMultiLayerProjects().size();
+            for (int i = 0; i < layerCount; i++) {
+                List<NamedExpression> layer = project.getMultiLayerProjects().get(i);
+                projectionExprs = layer.stream()
+                        .map(e -> ExpressionTranslator.translate(e, context))
+                        .collect(Collectors.toList());
+                slots = layer.stream()
+                        .map(NamedExpression::toSlot)
+                        .collect(Collectors.toList());
+                if (i < layerCount - 1) {
+                    inputPlanNode.addIntermediateProjectList(projectionExprs);
+                    TupleDescriptor projectionTuple = generateTupleDesc(slots, null, context);
+                    inputPlanNode.addIntermediateOutputTupleDescList(projectionTuple);
+                }
+                allProjectionExprs.addAll(projectionExprs);
+            }
+        } else {
+            projectionExprs = project.getProjects()
+                    .stream()
+                    .map(e -> ExpressionTranslator.translate(e, context))
+                    .collect(Collectors.toList());
+            slots = project.getProjects()
+                    .stream()
+                    .map(NamedExpression::toSlot)
+                    .collect(Collectors.toList());
+            allProjectionExprs.addAll(projectionExprs);
+        }
         // process multicast sink
         if (inputFragment instanceof MultiCastPlanFragment) {
             MultiCastDataSink multiCastDataSink = (MultiCastDataSink) inputFragment.getSink();
@@ -1857,10 +1880,9 @@ public class PhysicalPlanTranslator extends DefaultPlanVisitor<PlanFragment, Pla
             return inputFragment;
         }
 
-        PlanNode inputPlanNode = inputFragment.getPlanRoot();
         List<Expr> conjuncts = inputPlanNode.getConjuncts();
         Set<SlotId> requiredSlotIdSet = Sets.newHashSet();
-        for (Expr expr : projectionExprs) {
+        for (Expr expr : allProjectionExprs) {
             Expr.extractSlots(expr, requiredSlotIdSet);
         }
         Set<SlotId> requiredByProjectSlotIdSet = Sets.newHashSet(requiredSlotIdSet);
@@ -1895,8 +1917,10 @@ public class PhysicalPlanTranslator extends DefaultPlanVisitor<PlanFragment, Pla
                 requiredSlotIdSet.forEach(e -> requiredExprIds.add(context.findExprId(e)));
                 for (ExprId exprId : requiredExprIds) {
                     SlotId slotId = ((HashJoinNode) joinNode).getHashOutputExprSlotIdMap().get(exprId);
-                    Preconditions.checkState(slotId != null);
-                    ((HashJoinNode) joinNode).addSlotIdToHashOutputSlotIds(slotId);
+                    // Preconditions.checkState(slotId != null);
+                    if (slotId != null) {
+                        ((HashJoinNode) joinNode).addSlotIdToHashOutputSlotIds(slotId);
+                    }
                 }
             }
             return inputFragment;
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/CommonSubExpressionCollector.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/CommonSubExpressionCollector.java
new file mode 100644
index 00000000000..5abc5f6f60f
--- /dev/null
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/CommonSubExpressionCollector.java
@@ -0,0 +1,59 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.processor.post;
+
+import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
+
+import java.util.HashMap;
+import java.util.LinkedHashSet;
+import java.util.Map;
+import java.util.Set;
+
+/**
+ * collect common expr
+ */
+public class CommonSubExpressionCollector extends ExpressionVisitor<Integer, Void> {
+    public final Map<Integer, Set<Expression>> commonExprByDepth = new HashMap<>();
+    private final Map<Integer, Set<Expression>> expressionsByDepth = new HashMap<>();
+
+    @Override
+    public Integer visit(Expression expr, Void context) {
+        if (expr.children().isEmpty()) {
+            return 0;
+        }
+        return collectCommonExpressionByDepth(expr.children().stream().map(child ->
+                child.accept(this, context)).reduce(Math::max).map(m -> m + 1).orElse(1), expr);
+    }
+
+    private int collectCommonExpressionByDepth(int depth, Expression expr) {
+        Set<Expression> expressions = getExpressionsFromDepthMap(depth, expressionsByDepth);
+        if (expressions.contains(expr)) {
+            Set<Expression> commonExpression = getExpressionsFromDepthMap(depth, commonExprByDepth);
+            commonExpression.add(expr);
+        }
+        expressions.add(expr);
+        return depth;
+    }
+
+    public static Set<Expression> getExpressionsFromDepthMap(
+            int depth, Map<Integer, Set<Expression>> depthMap) {
+        depthMap.putIfAbsent(depth, new LinkedHashSet<>());
+        return depthMap.get(depth);
+    }
+}
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/CommonSubExpressionOpt.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/CommonSubExpressionOpt.java
new file mode 100644
index 00000000000..dfaf2de757e
--- /dev/null
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/CommonSubExpressionOpt.java
@@ -0,0 +1,125 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.processor.post;
+
+import org.apache.doris.nereids.CascadesContext;
+import org.apache.doris.nereids.trees.expressions.Alias;
+import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.NamedExpression;
+import org.apache.doris.nereids.trees.expressions.Slot;
+import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter;
+import org.apache.doris.nereids.trees.plans.Plan;
+import org.apache.doris.nereids.trees.plans.physical.PhysicalProject;
+
+import com.google.common.collect.Lists;
+
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+
+/**
+ * Select A+B, (A+B+C)*2, (A+B+C)*3, D from T
+ *
+ * before optimize
+ * projection:
+ * Proj: A+B, (A+B+C)*2, (A+B+C)*3, D
+ *
+ * ---
+ * after optimize:
+ * Projection: List < List < Expression > >
+ * A+B, C, D
+ * A+B, A+B+C, D
+ * A+B, (A+B+C)*2, (A+B+C)*3, D
+ */
+public class CommonSubExpressionOpt extends PlanPostProcessor {
+    @Override
+    public PhysicalProject visitPhysicalProject(PhysicalProject<? extends Plan> project, CascadesContext ctx) {
+
+        List<List<NamedExpression>> multiLayers = computeMultiLayerProjections(
+                project.getInputSlots(), project.getProjects());
+        project.setMultiLayerProjects(multiLayers);
+        return project;
+    }
+
+    private List<List<NamedExpression>> computeMultiLayerProjections(
+            Set<Slot> inputSlots, List<NamedExpression> projects) {
+
+        List<List<NamedExpression>> multiLayers = Lists.newArrayList();
+        CommonSubExpressionCollector collector = new CommonSubExpressionCollector();
+        for (Expression expr : projects) {
+            expr.accept(collector, null);
+        }
+        Map<Expression, Alias> commonExprToAliasMap = new HashMap<>();
+        collector.commonExprByDepth.values().stream().flatMap(expressions -> expressions.stream())
+                .forEach(expression -> {
+                    if (expression instanceof Alias) {
+                        commonExprToAliasMap.put(expression, (Alias) expression);
+                    } else {
+                        commonExprToAliasMap.put(expression, new Alias(expression));
+                    }
+                });
+        Map<Expression, Alias> aliasMap = new HashMap<>();
+        if (!collector.commonExprByDepth.isEmpty()) {
+            for (int i = 1; i <= collector.commonExprByDepth.size(); i++) {
+                List<NamedExpression> layer = Lists.newArrayList();
+                layer.addAll(inputSlots);
+                Set<Expression> exprsInDepth = CommonSubExpressionCollector
+                        .getExpressionsFromDepthMap(i, collector.commonExprByDepth);
+                exprsInDepth.forEach(expr -> {
+                    Expression rewritten = expr.accept(ExpressionReplacer.INSTANCE, aliasMap);
+                    Alias alias = new Alias(rewritten);
+                    aliasMap.put(expr, alias);
+                });
+                layer.addAll(aliasMap.values());
+                multiLayers.add(layer);
+            }
+            // final layer
+            List<NamedExpression> finalLayer = Lists.newArrayList();
+            projects.forEach(expr -> {
+                Expression rewritten = expr.accept(ExpressionReplacer.INSTANCE, aliasMap);
+                if (rewritten instanceof Slot) {
+                    finalLayer.add((NamedExpression) rewritten);
+                } else if (rewritten instanceof Alias) {
+                    finalLayer.add(new Alias(expr.getExprId(), ((Alias) rewritten).child(), expr.getName()));
+                }
+            });
+            multiLayers.add(finalLayer);
+        }
+        return multiLayers;
+    }
+
+    /**
+     * replace sub expr by aliasMap
+     */
+    public static class ExpressionReplacer
+            extends DefaultExpressionRewriter<Map<? extends Expression, ? extends Alias>> {
+        public static final ExpressionReplacer INSTANCE = new ExpressionReplacer();
+
+        private ExpressionReplacer() {
+        }
+
+        @Override
+        public Expression visit(Expression expr, Map<? extends Expression, ? extends Alias> replaceMap) {
+            if (replaceMap.containsKey(expr)) {
+                return replaceMap.get(expr).toSlot();
+            }
+            return super.visit(expr, replaceMap);
+        }
+    }
+}
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/PlanPostProcessors.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/PlanPostProcessors.java
index 60c1a74445e..86c8486ef45 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/PlanPostProcessors.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/PlanPostProcessors.java
@@ -63,8 +63,9 @@ public class PlanPostProcessors {
         builder.add(new MergeProjectPostProcessor());
         builder.add(new RecomputeLogicalPropertiesProcessor());
         builder.add(new AddOffsetIntoDistribute());
+        builder.add(new CommonSubExpressionOpt());
+        // DO NOT replace PLAN NODE from here
         builder.add(new TopNScanOpt());
-        // after generate rf, DO NOT replace PLAN NODE
         builder.add(new FragmentProcessor());
         if (!cascadesContext.getConnectContext().getSessionVariable().getRuntimeFilterMode()
                         .toUpperCase().equals(TRuntimeFilterMode.OFF.name())) {
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/PhysicalProject.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/PhysicalProject.java
index a2419b7870a..93fde854a1c 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/PhysicalProject.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/PhysicalProject.java
@@ -24,6 +24,7 @@ import org.apache.doris.nereids.processor.post.RuntimeFilterContext;
 import org.apache.doris.nereids.processor.post.RuntimeFilterGenerator;
 import org.apache.doris.nereids.properties.LogicalProperties;
 import org.apache.doris.nereids.properties.PhysicalProperties;
+import org.apache.doris.nereids.trees.expressions.Add;
 import org.apache.doris.nereids.trees.expressions.Alias;
 import org.apache.doris.nereids.trees.expressions.Expression;
 import org.apache.doris.nereids.trees.expressions.NamedExpression;
@@ -41,6 +42,7 @@ import org.apache.doris.thrift.TRuntimeFilterType;
 
 import com.google.common.base.Preconditions;
 import com.google.common.collect.ImmutableList;
+import com.google.common.collect.Lists;
 
 import java.util.List;
 import java.util.Objects;
@@ -52,6 +54,12 @@ import java.util.Optional;
 public class PhysicalProject<CHILD_TYPE extends Plan> extends PhysicalUnary<CHILD_TYPE> implements Project {
 
     private final List<NamedExpression> projects;
+    //multiLayerProjects is used to extract common expressions
+    // projects: (A+B) * 2, (A+B) * 3
+    // multiLayerProjects:
+    //            L1: A+B as x
+    //            L2: x*2, x*3
+    private List<List<NamedExpression>> multiLayerProjects = Lists.newArrayList();
 
     public PhysicalProject(List<NamedExpression> projects, LogicalProperties logicalProperties, CHILD_TYPE child) {
         this(projects, Optional.empty(), logicalProperties, child);
@@ -227,7 +235,12 @@ public class PhysicalProject<CHILD_TYPE extends Plan> extends PhysicalUnary<CHIL
 
     @Override
     public List<Slot> computeOutput() {
-        return projects.stream()
+        List<NamedExpression> output = projects;
+        if (! multiLayerProjects.isEmpty()) {
+            int layers = multiLayerProjects.size();
+            output = multiLayerProjects.get(layers - 1);
+        }
+        return output.stream()
                 .map(NamedExpression::toSlot)
                 .collect(ImmutableList.toImmutableList());
     }
@@ -237,4 +250,70 @@ public class PhysicalProject<CHILD_TYPE extends Plan> extends PhysicalUnary<CHIL
         return new PhysicalProject<>(projects, groupExpression, null, physicalProperties,
                 statistics, child());
     }
+
+    /**
+     * extract common expr, set multi layer projects
+     */
+    public void computeMultiLayerProjectsForCommonExpress() {
+        // hard code: select (s_suppkey + s_nationkey), 1+(s_suppkey + s_nationkey), s_name from supplier;
+        if (projects.size() == 3) {
+            if (projects.get(2) instanceof SlotReference) {
+                SlotReference sName = (SlotReference) projects.get(2);
+                if (sName.getName().equals("s_name")) {
+                    Alias a1 = (Alias) projects.get(0); // (s_suppkey + s_nationkey)
+                    Alias a2 = (Alias) projects.get(1); // 1+(s_suppkey + s_nationkey)
+                    // L1: (s_suppkey + s_nationkey) as x, s_name
+                    multiLayerProjects.add(Lists.newArrayList(projects.get(0), projects.get(2)));
+                    List<NamedExpression> l2 = Lists.newArrayList();
+                    l2.add(a1.toSlot());
+                    Alias a3 = new Alias(a2.getExprId(), new Add(a1.toSlot(), a2.child().child(1)), a2.getName());
+                    l2.add(a3);
+                    l2.add(sName);
+                    // L2: x, (1+x) as y, s_name
+                    multiLayerProjects.add(l2);
+                }
+            }
+        }
+        // hard code:
+        // select (s_suppkey + n_regionkey) + 1 as x, (s_suppkey + n_regionkey) + 2 as y
+        // from supplier join nation on s_nationkey=n_nationkey
+        // projects: x, y
+        // multi L1: s_suppkey, n_regionkey, (s_suppkey + n_regionkey) as z
+        //       L2: z +1 as x, z+2 as y
+        if (projects.size() == 2 && projects.get(0) instanceof Alias && projects.get(1) instanceof Alias
+                && ((Alias) projects.get(0)).getName().equals("x")
+                && ((Alias) projects.get(1)).getName().equals("y")) {
+            Alias a0 = (Alias) projects.get(0);
+            Alias a1 = (Alias) projects.get(1);
+            Add common = (Add) a0.child().child(0); // s_suppkey + n_regionkey
+            List<NamedExpression> l1 = Lists.newArrayList();
+            common.children().stream().forEach(child -> l1.add((SlotReference) child));
+            Alias aliasOfCommon = new Alias(common);
+            l1.add(aliasOfCommon);
+            multiLayerProjects.add(l1);
+            Add add1 = new Add(common, a0.child().child(0).child(1));
+            Alias aliasOfAdd1 = new Alias(a0.getExprId(), add1, a0.getName());
+            Add add2 = new Add(common, a1.child().child(0).child(1));
+            Alias aliasOfAdd2 = new Alias(a1.getExprId(), add2, a1.getName());
+            List<NamedExpression> l2 = Lists.newArrayList(aliasOfAdd1, aliasOfAdd2);
+            multiLayerProjects.add(l2);
+        }
+    }
+
+    public boolean hasMultiLayerProjection() {
+        return !multiLayerProjects.isEmpty();
+    }
+
+    public List<List<NamedExpression>> getMultiLayerProjects() {
+        return multiLayerProjects;
+    }
+
+    public void setMultiLayerProjects(List<List<NamedExpression>> multiLayers) {
+        this.multiLayerProjects = multiLayers;
+    }
+
+    @Override
+    public List<Slot> getOutput() {
+        return computeOutput();
+    }
 }
diff --git a/fe/fe-core/src/main/java/org/apache/doris/planner/PlanNode.java b/fe/fe-core/src/main/java/org/apache/doris/planner/PlanNode.java
index b404bc4ad35..8cc18a527a8 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/planner/PlanNode.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/planner/PlanNode.java
@@ -59,6 +59,7 @@ import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
+import java.util.stream.Collectors;
 
 /**
  * Each PlanNode represents a single relational operator
@@ -155,6 +156,8 @@ public abstract class PlanNode extends TreeNode<PlanNode> implements PlanStats {
     protected int nereidsId = -1;
 
     private List<List<Expr>> childrenDistributeExprLists = new ArrayList<>();
+    private List<TupleDescriptor> intermediateOutputTupleDescList = Lists.newArrayList();
+    private List<List<Expr>> intermediateProjectListList = Lists.newArrayList();
 
     protected PlanNode(PlanNodeId id, ArrayList<TupleId> tupleIds, String planNodeName,
             StatisticalType statisticalType) {
@@ -536,10 +539,20 @@ public abstract class PlanNode extends TreeNode<PlanNode> implements PlanStats {
             expBuilder.append(detailPrefix + "limit: " + limit + "\n");
         }
         if (!CollectionUtils.isEmpty(projectList)) {
-            expBuilder.append(detailPrefix).append("projections: ").append(getExplainString(projectList)).append("\n");
-            expBuilder.append(detailPrefix).append("project output tuple id: ")
+            expBuilder.append(detailPrefix).append("final projections: ")
+                .append(getExplainString(projectList)).append("\n");
+            expBuilder.append(detailPrefix).append("final project output tuple id: ")
                     .append(outputTupleDesc.getId().asInt()).append("\n");
         }
+        if (!intermediateProjectListList.isEmpty()) {
+            int layers = intermediateProjectListList.size();
+            for (int i = layers - 1; i >= 0; i--) {
+                expBuilder.append(detailPrefix).append("intermediate projections: ")
+                        .append(getExplainString(intermediateProjectListList.get(i))).append("\n");
+                expBuilder.append(detailPrefix).append("intermediate tuple id: ")
+                        .append(intermediateOutputTupleDescList.get(i).getId().asInt()).append("\n");
+            }
+        }
         if (!CollectionUtils.isEmpty(childrenDistributeExprLists)) {
             for (List<Expr> distributeExprList : childrenDistributeExprLists) {
                 expBuilder.append(detailPrefix).append("distribute expr lists: ")
@@ -660,6 +673,19 @@ public abstract class PlanNode extends TreeNode<PlanNode> implements PlanStats {
                 }
             }
         }
+
+        if (!intermediateOutputTupleDescList.isEmpty()) {
+            intermediateOutputTupleDescList
+                    .forEach(
+                            tupleDescriptor -> msg.addToIntermediateOutputTupleIdList(tupleDescriptor.getId().asInt()));
+        }
+
+        if (!intermediateProjectListList.isEmpty()) {
+            intermediateProjectListList.forEach(
+                    projectList -> msg.addToIntermediateProjectionsList(
+                            projectList.stream().map(expr -> expr.treeToThrift()).collect(Collectors.toList())));
+        }
+
         if (this instanceof ExchangeNode) {
             msg.num_children = 0;
             return;
@@ -1221,4 +1247,12 @@ public abstract class PlanNode extends TreeNode<PlanNode> implements PlanStats {
     public void setNereidsId(int nereidsId) {
         this.nereidsId = nereidsId;
     }
+
+    public void addIntermediateOutputTupleDescList(TupleDescriptor tupleDescriptor) {
+        intermediateOutputTupleDescList.add(tupleDescriptor);
+    }
+
+    public void addIntermediateProjectList(List<Expr> exprs) {
+        intermediateProjectListList.add(exprs);
+    }
 }
diff --git a/fe/fe-core/src/test/java/org/apache/doris/catalog/CreateFunctionTest.java b/fe/fe-core/src/test/java/org/apache/doris/catalog/CreateFunctionTest.java
index 0f464ba2946..c342d858fe1 100644
--- a/fe/fe-core/src/test/java/org/apache/doris/catalog/CreateFunctionTest.java
+++ b/fe/fe-core/src/test/java/org/apache/doris/catalog/CreateFunctionTest.java
@@ -74,6 +74,7 @@ public class CreateFunctionTest {
     public void test() throws Exception {
         ConnectContext ctx = UtFrameUtils.createDefaultCtx();
         ctx.getSessionVariable().setEnableNereidsPlanner(false);
+        ctx.getSessionVariable().enableFallbackToOriginalPlanner = true;
         ctx.getSessionVariable().setEnableFoldConstantByBe(false);
         // create database db1
         createDatabase(ctx, "create database db1;");
@@ -113,8 +114,8 @@ public class CreateFunctionTest {
         Assert.assertTrue(constExprLists.get(0).get(0) instanceof FunctionCallExpr);
 
         queryStr = "select db1.id_masking(k1) from db1.tbl1";
-        Assert.assertTrue(
-                dorisAssert.query(queryStr).explainQuery().contains("concat(left(CAST(CAST(k1 AS BIGINT) AS VARCHAR(65533)), 3), '****', right(CAST(CAST(k1 AS BIGINT) AS VARCHAR(65533)), 4))"));
+        Assert.assertTrue(containsIgnoreCase(dorisAssert.query(queryStr).explainQuery(),
+                "concat(left(CAST(CAST(k1 AS BIGINT) AS VARCHAR(65533)), 3), '****', right(CAST(CAST(k1 AS BIGINT) AS VARCHAR(65533)), 4))"));
 
         // create alias function with cast
         // cast any type to decimal with specific precision and scale
@@ -142,14 +143,16 @@ public class CreateFunctionTest {
 
         queryStr = "select db1.decimal(k3, 4, 1) from db1.tbl1;";
         if (Config.enable_decimal_conversion) {
-            Assert.assertTrue(dorisAssert.query(queryStr).explainQuery().contains("CAST(`k3` AS DECIMALV3(4, 1))"));
+            Assert.assertTrue(containsIgnoreCase(dorisAssert.query(queryStr).explainQuery(),
+                    "CAST(`k3` AS DECIMALV3(4, 1))"));
         } else {
-            Assert.assertTrue(dorisAssert.query(queryStr).explainQuery().contains("CAST(`k3` AS DECIMAL(4, 1))"));
+            Assert.assertTrue(containsIgnoreCase(dorisAssert.query(queryStr).explainQuery(),
+                    "CAST(`k3` AS DECIMAL(4, 1))"));
         }
 
         // cast any type to varchar with fixed length
-        createFuncStr = "create alias function db1.varchar(all) with parameter(text) as "
-                + "cast(text as varchar(65533));";
+        createFuncStr = "create alias function db1.varchar(all, int) with parameter(text, length) as "
+                + "cast(text as varchar(length));";
         createFunctionStmt = (CreateFunctionStmt) UtFrameUtils.parseAndAnalyzeStmt(createFuncStr, ctx);
         Env.getCurrentEnv().createFunction(createFunctionStmt);
 
@@ -172,7 +175,8 @@ public class CreateFunctionTest {
         Assert.assertTrue(constExprLists.get(0).get(0) instanceof StringLiteral);
 
         queryStr = "select db1.varchar(k1, 4) from db1.tbl1;";
-        Assert.assertTrue(dorisAssert.query(queryStr).explainQuery().contains("CAST(`k1` AS VARCHAR(65533))"));
+        Assert.assertTrue(containsIgnoreCase(dorisAssert.query(queryStr).explainQuery(),
+                "CAST(`k1` AS VARCHAR(65533))"));
 
         // cast any type to char with fixed length
         createFuncStr = "create alias function db1.to_char(all, int) with parameter(text, length) as "
@@ -199,7 +203,8 @@ public class CreateFunctionTest {
         Assert.assertTrue(constExprLists.get(0).get(0) instanceof StringLiteral);
 
         queryStr = "select db1.to_char(k1, 4) from db1.tbl1;";
-        Assert.assertTrue(dorisAssert.query(queryStr).explainQuery().contains("CAST(`k1` AS CHARACTER"));
+        Assert.assertTrue(containsIgnoreCase(dorisAssert.query(queryStr).explainQuery(),
+                "CAST(`k1` AS CHARACTER"));
     }
 
     @Test
@@ -235,8 +240,8 @@ public class CreateFunctionTest {
         testFunctionQuery(ctx, queryStr, false);
 
         queryStr = "select id_masking(k1) from db2.tbl1";
-        Assert.assertTrue(
-                dorisAssert.query(queryStr).explainQuery().contains("concat(left(CAST(CAST(k1 AS BIGINT) AS VARCHAR(65533)), 3), '****', right(CAST(CAST(k1 AS BIGINT) AS VARCHAR(65533)), 4))"));
+        Assert.assertTrue(containsIgnoreCase(dorisAssert.query(queryStr).explainQuery(),
+                "concat(left(CAST(CAST(k1 AS BIGINT) AS VARCHAR(65533)), 3), '****', right(CAST(CAST(k1 AS BIGINT) AS VARCHAR(65533)), 4))"));
 
         // 4. create alias function with cast
         // cast any type to decimal with specific precision and scale
@@ -253,9 +258,11 @@ public class CreateFunctionTest {
 
         queryStr = "select decimal(k3, 4, 1) from db2.tbl1;";
         if (Config.enable_decimal_conversion) {
-            Assert.assertTrue(dorisAssert.query(queryStr).explainQuery().contains("CAST(`k3` AS DECIMALV3(4, 1))"));
+            Assert.assertTrue(containsIgnoreCase(dorisAssert.query(queryStr).explainQuery(),
+                    "CAST(`k3` AS DECIMALV3(4, 1))"));
         } else {
-            Assert.assertTrue(dorisAssert.query(queryStr).explainQuery().contains("CAST(`k3` AS DECIMAL(4, 1))"));
+            Assert.assertTrue(containsIgnoreCase(dorisAssert.query(queryStr).explainQuery(),
+                    "CAST(`k3` AS DECIMAL(4, 1))"));
         }
 
         // 5. cast any type to varchar with fixed length
@@ -271,7 +278,8 @@ public class CreateFunctionTest {
         testFunctionQuery(ctx, queryStr, true);
 
         queryStr = "select varchar(k1, 4) from db2.tbl1;";
-        Assert.assertTrue(dorisAssert.query(queryStr).explainQuery().contains("CAST(`k1` AS VARCHAR(65533))"));
+        Assert.assertTrue(containsIgnoreCase(dorisAssert.query(queryStr).explainQuery(),
+                "CAST(`k1` AS VARCHAR(65533))"));
 
         // 6. cast any type to char with fixed length
         createFuncStr = "create global alias function db2.to_char(all, int) with parameter(text, length) as "
@@ -286,7 +294,8 @@ public class CreateFunctionTest {
         testFunctionQuery(ctx, queryStr, true);
 
         queryStr = "select to_char(k1, 4) from db2.tbl1;";
-        Assert.assertTrue(dorisAssert.query(queryStr).explainQuery().contains("CAST(`k1` AS CHARACTER)"));
+        Assert.assertTrue(containsIgnoreCase(dorisAssert.query(queryStr).explainQuery(),
+                "CAST(`k1` AS CHARACTER)"));
     }
 
     private void testFunctionQuery(ConnectContext ctx, String queryStr, Boolean isStringLiteral) throws Exception {
@@ -320,4 +329,8 @@ public class CreateFunctionTest {
         Env.getCurrentEnv().createDb(createDbStmt);
         System.out.println(Env.getCurrentInternalCatalog().getDbNames());
     }
+
+    private static boolean containsIgnoreCase(String str, String sub) { // Locale.ROOT avoids tr-locale i/I issues
+        return str.toLowerCase(java.util.Locale.ROOT).contains(sub.toLowerCase(java.util.Locale.ROOT));
+    }
 }
diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/postprocess/CommonSubExpressionTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/postprocess/CommonSubExpressionTest.java
new file mode 100644
index 00000000000..56b67e087d5
--- /dev/null
+++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/postprocess/CommonSubExpressionTest.java
@@ -0,0 +1,131 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.postprocess;
+
+import org.apache.doris.nereids.processor.post.CommonSubExpressionCollector;
+import org.apache.doris.nereids.processor.post.CommonSubExpressionOpt;
+import org.apache.doris.nereids.rules.expression.ExpressionRewriteTestHelper;
+import org.apache.doris.nereids.trees.expressions.Alias;
+import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.NamedExpression;
+import org.apache.doris.nereids.trees.expressions.Slot;
+import org.apache.doris.nereids.trees.expressions.SlotReference;
+import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter;
+import org.apache.doris.nereids.types.IntegerType;
+
+import org.junit.jupiter.api.Assertions;
+import org.junit.jupiter.api.Test;
+
+import java.lang.reflect.Method;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.stream.Collectors;
+
+public class CommonSubExpressionTest extends ExpressionRewriteTestHelper {
+    @Test
+    public void testExtractCommonExpr() {
+        List<NamedExpression> exprs = parseProjections("a+b, a+b+1, abs(a+b+1), a");
+        CommonSubExpressionCollector collector =
+                new CommonSubExpressionCollector();
+        exprs.forEach(expr -> collector.visit(expr, null));
+        // "a+b" is shared at depth 1 and "a+b+1" at depth 2; the lone slot "a" is not collected.
+        Assertions.assertEquals(2, collector.commonExprByDepth.size());
+        List<Expression> l1 = collector.commonExprByDepth.get(Integer.valueOf(1))
+                .stream().collect(Collectors.toList());
+        List<Expression> l2 = collector.commonExprByDepth.get(Integer.valueOf(2))
+                .stream().collect(Collectors.toList());
+        Assertions.assertEquals(1, l1.size());
+        assertExpression(l1.get(0), "a+b");
+        Assertions.assertEquals(1, l2.size());
+        assertExpression(l2.get(0), "a+b+1");
+    }
+
+    @Test
+    public void testMultiLayers() throws Exception {
+        List<NamedExpression> exprs = parseProjections("a, a+b, a+b+1, abs(a+b+1), a");
+        Set<Slot> inputSlots = exprs.get(0).getInputSlots();
+        CommonSubExpressionOpt opt = new CommonSubExpressionOpt();
+        Method computeMultLayerProjectionsMethod = CommonSubExpressionOpt.class
+                .getDeclaredMethod("computeMultiLayerProjections", Set.class, List.class);
+        computeMultLayerProjectionsMethod.setAccessible(true);
+        List<List<NamedExpression>> multiLayers = (List<List<NamedExpression>>) computeMultLayerProjectionsMethod
+                .invoke(opt, inputSlots, exprs);
+        // Expect 3 layers: layer 0 = {a, a+b}; the last layer rebuilds all 5 projections with their exprIds.
+        Assertions.assertEquals(3, multiLayers.size());
+        List<NamedExpression> l0 = multiLayers.get(0);
+        Assertions.assertEquals(2, l0.size());
+        Assertions.assertTrue(l0.contains(ExprParser.INSTANCE.parseExpression("a")));
+        Assertions.assertTrue(l0.get(1) instanceof Alias);
+        assertExpression(l0.get(1).child(0), "a+b");
+        Assertions.assertEquals(3, multiLayers.get(1).size());
+        Assertions.assertEquals(5, multiLayers.get(2).size());
+        List<NamedExpression> l2 = multiLayers.get(2);
+        for (int i = 0; i < 5; i++) {
+            Assertions.assertEquals(exprs.get(i).getExprId().asInt(), l2.get(i).getExprId().asInt());
+        }
+    }
+
+    private void assertExpression(Expression expr, String str) {
+        Assertions.assertEquals(ExprParser.INSTANCE.parseExpression(str), expr);
+    }
+
+    private List<NamedExpression> parseProjections(String exprList) {
+        List<NamedExpression> result = new ArrayList<>();
+        String[] exprArray = exprList.split(",");
+        for (String item : exprArray) {
+            Expression expr = ExprParser.INSTANCE.parseExpression(item);
+            if (expr instanceof NamedExpression) {
+                result.add((NamedExpression) expr);
+            } else {
+                result.add(new Alias(expr));
+            }
+        }
+        return result;
+    }
+
+    /** Parses test expressions, binding each slot name to one shared IntegerType SlotReference. */
+    public static class ExprParser {
+        public static final ExprParser INSTANCE = new ExprParser();
+        final Map<String, SlotReference> slotMap = new HashMap<>();
+
+        public Expression parseExpression(String str) {
+            Expression expr = PARSER.parseExpression(str);
+            return expr.accept(DataTypeAssignor.INSTANCE, slotMap);
+        }
+    }
+
+    public static class DataTypeAssignor extends DefaultExpressionRewriter<Map<String, SlotReference>> {
+        public static final DataTypeAssignor INSTANCE = new DataTypeAssignor();
+
+        @Override
+        public Expression visitSlot(Slot slot, Map<String, SlotReference> slotMap) {
+            SlotReference existingSlot = slotMap.get(slot.getName());
+            if (existingSlot != null) {
+                return existingSlot;
+            }
+            // First time this name is seen: type it as INT and cache it so every later use shares one slot.
+            SlotReference slotReference = new SlotReference(slot.getName(), IntegerType.INSTANCE);
+            slotMap.put(slot.getName(), slotReference);
+            return slotReference;
+        }
+    }
+
+}
diff --git a/regression-test/data/tpch_sf0.1_p1/sql/cse.out b/regression-test/data/tpch_sf0.1_p1/sql/cse.out
new file mode 100644
index 00000000000..5ab44655661
--- /dev/null
+++ b/regression-test/data/tpch_sf0.1_p1/sql/cse.out
@@ -0,0 +1,30 @@
+-- This file is automatically generated. You should know what you did if you want to edit this
+-- !cse --
+1	1	3	4
+2	0	3	4
+3	1	5	6
+4	0	5	6
+5	4	10	11
+6	0	7	8
+7	3	11	12
+8	1	10	11
+9	4	14	15
+10	1	12	13
+
+-- !cse_2 --
+17	1	18	19	19
+5	2	7	8	8
+1	3	4	5	5
+15	4	19	20	20
+11	5	16	17	17
+14	6	20	21	21
+23	7	30	31	31
+17	8	25	26	26
+10	9	19	20	20
+24	10	34	35	35
+
+-- !cse_3 --
+12093	13093	14093	15093
+
+-- !cse_4 --
+12093	13093	14093	15093
\ No newline at end of file
diff --git a/regression-test/framework/src/main/groovy/org/apache/doris/regression/action/ExplainAction.groovy b/regression-test/framework/src/main/groovy/org/apache/doris/regression/action/ExplainAction.groovy
index e6f05c6c765..cf0c03fc3bd 100644
--- a/regression-test/framework/src/main/groovy/org/apache/doris/regression/action/ExplainAction.groovy
+++ b/regression-test/framework/src/main/groovy/org/apache/doris/regression/action/ExplainAction.groovy
@@ -32,6 +32,7 @@ class ExplainAction implements SuiteAction {
     private SuiteContext context
     private Set<String> containsStrings = new LinkedHashSet<>()
     private Set<String> notContainsStrings = new LinkedHashSet<>()
+    private Map<String, Integer> multiContainsStrings = new LinkedHashMap<>() // subString -> exact expected count
     private String coonType
     private Closure checkFunction
 
@@ -56,6 +57,10 @@ class ExplainAction implements SuiteAction {
         containsStrings.add(subString)
     }
 
+    void multiContains(String subString, int n) { // expect subString exactly n times in the explain output
+        multiContainsStrings.put(subString, n)
+    }
+
     void notContains(String subString) {
         notContainsStrings.add(subString)
     }
@@ -112,6 +117,16 @@ class ExplainAction implements SuiteAction {
                     throw t
                 }
             }
+            // Each multiContains(sub, n) expectation requires exactly n occurrences of sub in the explain output.
+            for (Map.Entry<String, Integer> entry : multiContainsStrings.entrySet()) {
+                int count = explainString.count(entry.key)
+                if (count != entry.value) {
+                    String msg = ("Explain and check failed, expect multiContains '${entry.key}' , '${entry.value}' times, actual '${count}' times. "
+                            + "Actual explain string is:\n${explainString}").toString()
+                    log.info(msg)
+                    throw new IllegalStateException(msg)
+                }
+            }
         }
     }
 
diff --git a/regression-test/suites/tpch_sf0.1_p1/sql/cse.groovy b/regression-test/suites/tpch_sf0.1_p1/sql/cse.groovy
new file mode 100644
index 00000000000..698dbd3e5d0
--- /dev/null
+++ b/regression-test/suites/tpch_sf0.1_p1/sql/cse.groovy
@@ -0,0 +1,49 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// The cases are copied from https://github.com/trinodb/trino/tree/master
+// /testing/trino-product-tests/src/main/resources/sql-tests/testcases/tpcds
+// and modified by Doris.
+
+suite('cse') {
+    // qt_* cases compare query results with cse.out; explain cases inspect the plan text.
+    def q1 = """select s_suppkey,n_regionkey,(s_suppkey + n_regionkey) + 1 as x, (s_suppkey + n_regionkey) + 2 as y 
+            from supplier join nation on s_nationkey=n_nationkey order by s_suppkey , n_regionkey limit 10 ;
+            """
+
+    def q2 = """select s_nationkey,s_suppkey ,(s_nationkey + s_suppkey), (s_nationkey + s_suppkey) + 1,  abs((s_nationkey + s_suppkey) + 1) 
+    from supplier order by s_suppkey limit 10 ;"""
+
+    qt_cse "${q1}"
+
+    explain {
+        sql "${q1}"
+        contains "intermediate projections:"
+    }
+
+    qt_cse_2 "${q2}"
+
+    // q2 nests x = (s_nationkey + s_suppkey) as x, x + 1, abs(x + 1); its plan must show the marker exactly twice
+    explain {
+        sql "${q2}"
+        multiContains("intermediate projections:", 2)
+    }
+
+    qt_cse_3 """ select sum(s_nationkey),sum(s_nationkey +1 ) ,sum(s_nationkey +2 )  , sum(s_nationkey + 3 ) from supplier ;"""
+
+    qt_cse_4 """select sum(s_nationkey),sum(s_nationkey) + count(1) ,sum(s_nationkey) + 2 * count(1) , sum(s_nationkey)  + 3 * count(1) from supplier ;"""
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@doris.apache.org
For additional commands, e-mail: commits-help@doris.apache.org