You are viewing a plain-text version of this content; the canonical link to the original message is not preserved in this rendering.
Posted to commits@doris.apache.org by en...@apache.org on 2024/04/02 00:52:08 UTC
(doris) branch master updated: [feature](nereids) support common sub expression by multi-layer projections (fe part) (#33087)
This is an automated email from the ASF dual-hosted git repository.
englefly pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/master by this push:
new ac5c9c686e8 [feature](nereids) support common sub expression by multi-layer projections (fe part) (#33087)
ac5c9c686e8 is described below
commit ac5c9c686e8be019899af4fe54f69450281e62bf
Author: minghong <en...@gmail.com>
AuthorDate: Tue Apr 2 08:52:03 2024 +0800
[feature](nereids) support common sub expression by multi-layer projections (fe part) (#33087)
* CSE (common sub-expression elimination): FE part
---
.../glue/translator/PhysicalPlanTranslator.java | 50 ++++++--
.../post/CommonSubExpressionCollector.java | 59 ++++++++++
.../processor/post/CommonSubExpressionOpt.java | 125 ++++++++++++++++++++
.../nereids/processor/post/PlanPostProcessors.java | 3 +-
.../trees/plans/physical/PhysicalProject.java | 81 ++++++++++++-
.../java/org/apache/doris/planner/PlanNode.java | 38 +++++-
.../apache/doris/catalog/CreateFunctionTest.java | 41 ++++---
.../postprocess/CommonSubExpressionTest.java | 131 +++++++++++++++++++++
regression-test/data/tpch_sf0.1_p1/sql/cse.out | 30 +++++
.../doris/regression/action/ExplainAction.groovy | 15 +++
.../suites/tpch_sf0.1_p1/sql/cse.groovy | 49 ++++++++
11 files changed, 591 insertions(+), 31 deletions(-)
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java
index f47b6826ebe..205cfbd2530 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java
@@ -1837,15 +1837,38 @@ public class PhysicalPlanTranslator extends DefaultPlanVisitor<PlanFragment, Pla
registerRewrittenSlot(project, (OlapScanNode) inputFragment.getPlanRoot());
}
- List<Expr> projectionExprs = project.getProjects()
- .stream()
- .map(e -> ExpressionTranslator.translate(e, context))
- .collect(Collectors.toList());
- List<Slot> slots = project.getProjects()
- .stream()
- .map(NamedExpression::toSlot)
- .collect(Collectors.toList());
-
+ PlanNode inputPlanNode = inputFragment.getPlanRoot();
+ List<Expr> projectionExprs = null;
+ List<Expr> allProjectionExprs = Lists.newArrayList();
+ List<Slot> slots = null;
+ if (project.hasMultiLayerProjection()) {
+ int layerCount = project.getMultiLayerProjects().size();
+ for (int i = 0; i < layerCount; i++) {
+ List<NamedExpression> layer = project.getMultiLayerProjects().get(i);
+ projectionExprs = layer.stream()
+ .map(e -> ExpressionTranslator.translate(e, context))
+ .collect(Collectors.toList());
+ slots = layer.stream()
+ .map(NamedExpression::toSlot)
+ .collect(Collectors.toList());
+ if (i < layerCount - 1) {
+ inputPlanNode.addIntermediateProjectList(projectionExprs);
+ TupleDescriptor projectionTuple = generateTupleDesc(slots, null, context);
+ inputPlanNode.addIntermediateOutputTupleDescList(projectionTuple);
+ }
+ allProjectionExprs.addAll(projectionExprs);
+ }
+ } else {
+ projectionExprs = project.getProjects()
+ .stream()
+ .map(e -> ExpressionTranslator.translate(e, context))
+ .collect(Collectors.toList());
+ slots = project.getProjects()
+ .stream()
+ .map(NamedExpression::toSlot)
+ .collect(Collectors.toList());
+ allProjectionExprs.addAll(projectionExprs);
+ }
// process multicast sink
if (inputFragment instanceof MultiCastPlanFragment) {
MultiCastDataSink multiCastDataSink = (MultiCastDataSink) inputFragment.getSink();
@@ -1857,10 +1880,9 @@ public class PhysicalPlanTranslator extends DefaultPlanVisitor<PlanFragment, Pla
return inputFragment;
}
- PlanNode inputPlanNode = inputFragment.getPlanRoot();
List<Expr> conjuncts = inputPlanNode.getConjuncts();
Set<SlotId> requiredSlotIdSet = Sets.newHashSet();
- for (Expr expr : projectionExprs) {
+ for (Expr expr : allProjectionExprs) {
Expr.extractSlots(expr, requiredSlotIdSet);
}
Set<SlotId> requiredByProjectSlotIdSet = Sets.newHashSet(requiredSlotIdSet);
@@ -1895,8 +1917,10 @@ public class PhysicalPlanTranslator extends DefaultPlanVisitor<PlanFragment, Pla
requiredSlotIdSet.forEach(e -> requiredExprIds.add(context.findExprId(e)));
for (ExprId exprId : requiredExprIds) {
SlotId slotId = ((HashJoinNode) joinNode).getHashOutputExprSlotIdMap().get(exprId);
- Preconditions.checkState(slotId != null);
- ((HashJoinNode) joinNode).addSlotIdToHashOutputSlotIds(slotId);
+ // Preconditions.checkState(slotId != null);
+ if (slotId != null) {
+ ((HashJoinNode) joinNode).addSlotIdToHashOutputSlotIds(slotId);
+ }
}
}
return inputFragment;
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/CommonSubExpressionCollector.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/CommonSubExpressionCollector.java
new file mode 100644
index 00000000000..5abc5f6f60f
--- /dev/null
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/CommonSubExpressionCollector.java
@@ -0,0 +1,59 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.processor.post;
+
+import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
+
+import java.util.HashMap;
+import java.util.LinkedHashSet;
+import java.util.Map;
+import java.util.Set;
+
+/**
+ * Finds sub-expressions occurring more than once, grouped by depth (leaf: 0, others: 1 + max child depth).
+ */
+public class CommonSubExpressionCollector extends ExpressionVisitor<Integer, Void> {
+    public final Map<Integer, Set<Expression>> commonExprByDepth = new HashMap<>();
+    private final Map<Integer, Set<Expression>> expressionsByDepth = new HashMap<>();
+
+    @Override
+    public Integer visit(Expression expr, Void context) {
+        if (expr.children().isEmpty()) {
+            return 0;
+        }
+        int maxChildDepth = 0;
+        for (Expression child : expr.children()) {
+            maxChildDepth = Math.max(maxChildDepth, child.accept(this, context));
+        }
+        return collectCommonExpressionByDepth(maxChildDepth + 1, expr);
+    }
+
+    private int collectCommonExpressionByDepth(int depth, Expression expr) {
+        // Set.add() returns false when an equal expression was already seen at this depth
+        if (!getExpressionsFromDepthMap(depth, expressionsByDepth).add(expr)) {
+            getExpressionsFromDepthMap(depth, commonExprByDepth).add(expr);
+        }
+        return depth;
+    }
+
+    /** Returns the insertion-ordered set stored for {@code depth}, creating it on first access. */
+    public static Set<Expression> getExpressionsFromDepthMap(int depth, Map<Integer, Set<Expression>> depthMap) {
+        return depthMap.computeIfAbsent(depth, d -> new LinkedHashSet<>());
+    }
+}
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/CommonSubExpressionOpt.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/CommonSubExpressionOpt.java
new file mode 100644
index 00000000000..dfaf2de757e
--- /dev/null
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/CommonSubExpressionOpt.java
@@ -0,0 +1,125 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.processor.post;
+
+import org.apache.doris.nereids.CascadesContext;
+import org.apache.doris.nereids.trees.expressions.Alias;
+import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.NamedExpression;
+import org.apache.doris.nereids.trees.expressions.Slot;
+import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter;
+import org.apache.doris.nereids.trees.plans.Plan;
+import org.apache.doris.nereids.trees.plans.physical.PhysicalProject;
+
+import com.google.common.collect.Lists;
+
+import java.util.LinkedHashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.TreeMap;
+
+/**
+ * Extracts common sub-expressions of a {@link PhysicalProject} into extra projection layers,
+ * so that every common sub-expression is evaluated only once.
+ *
+ * <pre>
+ * SELECT A+B, (A+B+C)*2, (A+B+C)*3, D FROM T
+ * before: a single projection
+ *   A+B, (A+B+C)*2, (A+B+C)*3, D
+ * after: multi-layer projections, each layer reads the output of the previous one
+ *   L1: A, B, C, D, A+B AS x
+ *   L2: A, B, C, D, x, x+C AS y
+ *   L3: x, y*2, y*3, D
+ * </pre>
+ */
+public class CommonSubExpressionOpt extends PlanPostProcessor {
+    @Override
+    public PhysicalProject visitPhysicalProject(PhysicalProject<? extends Plan> project, CascadesContext ctx) {
+        // projects below this one must be optimized too; nodes are mutated in place, not replaced
+        project.child().accept(this, ctx);
+        List<List<NamedExpression>> multiLayers = computeMultiLayerProjections(
+                project.getInputSlots(), project.getProjects());
+        project.setMultiLayerProjects(multiLayers);
+        return project;
+    }
+
+    /**
+     * Splits {@code projects} into layers, or returns an empty list if there is no common sub-expression.
+     * The last layer has the same outputs (ExprIds and names) as {@code projects}.
+     */
+    private List<List<NamedExpression>> computeMultiLayerProjections(
+            Set<Slot> inputSlots, List<NamedExpression> projects) {
+        List<List<NamedExpression>> multiLayers = Lists.newArrayList();
+        CommonSubExpressionCollector collector = new CommonSubExpressionCollector();
+        for (Expression expr : projects) {
+            expr.accept(collector, null);
+        }
+        if (collector.commonExprByDepth.isEmpty()) {
+            return multiLayers;
+        }
+        Map<Expression, Alias> aliasMap = new LinkedHashMap<>(); // insertion order => deterministic plan
+        List<Slot> producedSlots = Lists.newArrayList();
+        // one layer per non-empty depth, shallowest first: a deeper expression only contains shallower ones
+        for (Set<Expression> exprsInDepth : new TreeMap<>(collector.commonExprByDepth).values()) {
+            List<NamedExpression> layer = Lists.newArrayList();
+            layer.addAll(inputSlots);
+            // results of earlier layers are forwarded as plain slots instead of being recomputed
+            layer.addAll(producedSlots);
+            for (Expression expr : exprsInDepth) {
+                Alias alias = new Alias(expr.accept(ExpressionReplacer.INSTANCE, aliasMap));
+                aliasMap.put(expr, alias);
+                layer.add(alias);
+                producedSlots.add(alias.toSlot());
+            }
+            multiLayers.add(layer);
+        }
+        // final layer: the original outputs, rewritten to read the slots of the common sub-expressions
+        List<NamedExpression> finalLayer = Lists.newArrayList();
+        for (NamedExpression expr : projects) {
+            Expression rewritten = expr.accept(ExpressionReplacer.INSTANCE, aliasMap);
+            if (expr instanceof Alias) {
+                // keep the original ExprId and name so that parent operators still resolve this output
+                Expression body = rewritten instanceof Alias ? ((Alias) rewritten).child() : rewritten;
+                finalLayer.add(new Alias(expr.getExprId(), body, expr.getName()));
+            } else {
+                // slots are leaves, never collected as common sub-expressions, so they come back unchanged
+                finalLayer.add((NamedExpression) rewritten);
+            }
+        }
+        multiLayers.add(finalLayer);
+        return multiLayers;
+    }
+
+    /** Replaces every sub-expression found in the map by the slot of its alias. */
+    public static class ExpressionReplacer
+            extends DefaultExpressionRewriter<Map<? extends Expression, ? extends Alias>> {
+        public static final ExpressionReplacer INSTANCE = new ExpressionReplacer();
+
+        private ExpressionReplacer() {
+        }
+
+        @Override
+        public Expression visit(Expression expr, Map<? extends Expression, ? extends Alias> replaceMap) {
+            if (replaceMap.containsKey(expr)) {
+                return replaceMap.get(expr).toSlot();
+            }
+            return super.visit(expr, replaceMap);
+        }
+    }
+}
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/PlanPostProcessors.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/PlanPostProcessors.java
index 60c1a74445e..86c8486ef45 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/PlanPostProcessors.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/PlanPostProcessors.java
@@ -63,8 +63,9 @@ public class PlanPostProcessors {
builder.add(new MergeProjectPostProcessor());
builder.add(new RecomputeLogicalPropertiesProcessor());
builder.add(new AddOffsetIntoDistribute());
+ builder.add(new CommonSubExpressionOpt());
+ // DO NOT replace PLAN NODE from here
builder.add(new TopNScanOpt());
- // after generate rf, DO NOT replace PLAN NODE
builder.add(new FragmentProcessor());
if (!cascadesContext.getConnectContext().getSessionVariable().getRuntimeFilterMode()
.toUpperCase().equals(TRuntimeFilterMode.OFF.name())) {
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/PhysicalProject.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/PhysicalProject.java
index a2419b7870a..93fde854a1c 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/PhysicalProject.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/PhysicalProject.java
@@ -24,6 +24,7 @@ import org.apache.doris.nereids.processor.post.RuntimeFilterContext;
import org.apache.doris.nereids.processor.post.RuntimeFilterGenerator;
import org.apache.doris.nereids.properties.LogicalProperties;
import org.apache.doris.nereids.properties.PhysicalProperties;
+import org.apache.doris.nereids.trees.expressions.Add;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
@@ -41,6 +42,7 @@ import org.apache.doris.thrift.TRuntimeFilterType;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
+import com.google.common.collect.Lists;
import java.util.List;
import java.util.Objects;
@@ -52,6 +54,12 @@ import java.util.Optional;
public class PhysicalProject<CHILD_TYPE extends Plan> extends PhysicalUnary<CHILD_TYPE> implements Project {
private final List<NamedExpression> projects;
+ // multiLayerProjects: projects split into layers by CommonSubExpressionOpt (empty = single layer)
+ // projects: (A+B) * 2, (A+B) * 3
+ // multiLayerProjects:
+ // L1: A, B, A+B as x
+ // L2: x*2, x*3
+ private List<List<NamedExpression>> multiLayerProjects = Lists.newArrayList();
public PhysicalProject(List<NamedExpression> projects, LogicalProperties logicalProperties, CHILD_TYPE child) {
this(projects, Optional.empty(), logicalProperties, child);
@@ -227,7 +235,12 @@ public class PhysicalProject<CHILD_TYPE extends Plan> extends PhysicalUnary<CHIL
@Override
public List<Slot> computeOutput() {
- return projects.stream()
+ // if the projects were split into layers, the last layer defines the output
+ List<NamedExpression> output = multiLayerProjects.isEmpty()
+ ? projects
+ : multiLayerProjects.get(multiLayerProjects.size() - 1);
+ return output
+ .stream()
.map(NamedExpression::toSlot)
.collect(ImmutableList.toImmutableList());
}
@@ -237,4 +250,70 @@ public class PhysicalProject<CHILD_TYPE extends Plan> extends PhysicalUnary<CHIL
return new PhysicalProject<>(projects, groupExpression, null, physicalProperties,
statistics, child());
}
+
+ /**
+ * Hard-coded prototype that sets multiLayerProjects for two fixed queries. NOTE(review): unused in this patch.
+ */
+ public void computeMultiLayerProjectsForCommonExpress() {
+ // hard code: select (s_suppkey + s_nationkey), 1+(s_suppkey + s_nationkey), s_name from supplier;
+ if (projects.size() == 3) {
+ if (projects.get(2) instanceof SlotReference) {
+ SlotReference sName = (SlotReference) projects.get(2);
+ if (sName.getName().equals("s_name")) {
+ Alias a1 = (Alias) projects.get(0); // (s_suppkey + s_nationkey)
+ Alias a2 = (Alias) projects.get(1); // 1+(s_suppkey + s_nationkey)
+ // L1: (s_suppkey + s_nationkey) as x, s_name
+ multiLayerProjects.add(Lists.newArrayList(projects.get(0), projects.get(2)));
+ List<NamedExpression> l2 = Lists.newArrayList();
+ l2.add(a1.toSlot());
+ Alias a3 = new Alias(a2.getExprId(), new Add(a1.toSlot(), a2.child().child(1)), a2.getName());
+ l2.add(a3);
+ l2.add(sName);
+ // L2: x, (1+x) as y, s_name (NOTE(review): a3 uses a2.child().child(1); check it is the literal 1)
+ multiLayerProjects.add(l2);
+ }
+ }
+ }
+ // hard code (NOTE(review): matches on alias names only, then casts children without checks):
+ // select (s_suppkey + n_regionkey) + 1 as x, (s_suppkey + n_regionkey) + 2 as y
+ // from supplier join nation on s_nationkey=n_nationkey
+ // projects: x, y
+ // multi L1: s_suppkey, n_regionkey, (s_suppkey + n_regionkey) as z
+ // L2: z +1 as x, z+2 as y (NOTE(review): add1/add2 below are built on `common`, not on z)
+ if (projects.size() == 2 && projects.get(0) instanceof Alias && projects.get(1) instanceof Alias
+ && ((Alias) projects.get(0)).getName().equals("x")
+ && ((Alias) projects.get(1)).getName().equals("y")) {
+ Alias a0 = (Alias) projects.get(0);
+ Alias a1 = (Alias) projects.get(1);
+ Add common = (Add) a0.child().child(0); // s_suppkey + n_regionkey
+ List<NamedExpression> l1 = Lists.newArrayList();
+ common.children().stream().forEach(child -> l1.add((SlotReference) child));
+ Alias aliasOfCommon = new Alias(common);
+ l1.add(aliasOfCommon);
+ multiLayerProjects.add(l1);
+ Add add1 = new Add(common, a0.child().child(0).child(1)); // NOTE(review): likely n_regionkey, not 1
+ Alias aliasOfAdd1 = new Alias(a0.getExprId(), add1, a0.getName());
+ Add add2 = new Add(common, a1.child().child(0).child(1)); // NOTE(review): likely n_regionkey, not 2
+ Alias aliasOfAdd2 = new Alias(a1.getExprId(), add2, a1.getName());
+ List<NamedExpression> l2 = Lists.newArrayList(aliasOfAdd1, aliasOfAdd2);
+ multiLayerProjects.add(l2);
+ }
+ }
+
+ public boolean hasMultiLayerProjection() { // non-empty only if common sub-expressions were extracted
+ return !multiLayerProjects.isEmpty();
+ }
+
+ public List<List<NamedExpression>> getMultiLayerProjects() { // the live list, not a copy
+ return multiLayerProjects;
+ }
+
+ public void setMultiLayerProjects(List<List<NamedExpression>> multiLayers) { // stored without copying
+ this.multiLayerProjects = multiLayers;
+ }
+
+ @Override
+ public List<Slot> getOutput() { // not cached, presumably so that later setMultiLayerProjects calls are seen
+ return computeOutput();
+ }
}
diff --git a/fe/fe-core/src/main/java/org/apache/doris/planner/PlanNode.java b/fe/fe-core/src/main/java/org/apache/doris/planner/PlanNode.java
index b404bc4ad35..8cc18a527a8 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/planner/PlanNode.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/planner/PlanNode.java
@@ -59,6 +59,7 @@ import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
+import java.util.stream.Collectors;
/**
* Each PlanNode represents a single relational operator
@@ -155,6 +156,8 @@ public abstract class PlanNode extends TreeNode<PlanNode> implements PlanStats {
protected int nereidsId = -1;
private List<List<Expr>> childrenDistributeExprLists = new ArrayList<>();
+ private List<TupleDescriptor> intermediateOutputTupleDescList = Lists.newArrayList();
+ private List<List<Expr>> intermediateProjectListList = Lists.newArrayList();
protected PlanNode(PlanNodeId id, ArrayList<TupleId> tupleIds, String planNodeName,
StatisticalType statisticalType) {
@@ -536,10 +539,20 @@ public abstract class PlanNode extends TreeNode<PlanNode> implements PlanStats {
expBuilder.append(detailPrefix + "limit: " + limit + "\n");
}
if (!CollectionUtils.isEmpty(projectList)) {
- expBuilder.append(detailPrefix).append("projections: ").append(getExplainString(projectList)).append("\n");
- expBuilder.append(detailPrefix).append("project output tuple id: ")
+ expBuilder.append(detailPrefix).append("final projections: ")
+ .append(getExplainString(projectList)).append("\n");
+ expBuilder.append(detailPrefix).append("final project output tuple id: ")
.append(outputTupleDesc.getId().asInt()).append("\n");
}
+ if (!intermediateProjectListList.isEmpty()) {
+ int layers = intermediateProjectListList.size();
+ for (int i = layers - 1; i >= 0; i--) {
+ expBuilder.append(detailPrefix).append("intermediate projections: ")
+ .append(getExplainString(intermediateProjectListList.get(i))).append("\n");
+ expBuilder.append(detailPrefix).append("intermediate tuple id: ")
+ .append(intermediateOutputTupleDescList.get(i).getId().asInt()).append("\n");
+ }
+ }
if (!CollectionUtils.isEmpty(childrenDistributeExprLists)) {
for (List<Expr> distributeExprList : childrenDistributeExprLists) {
expBuilder.append(detailPrefix).append("distribute expr lists: ")
@@ -660,6 +673,19 @@ public abstract class PlanNode extends TreeNode<PlanNode> implements PlanStats {
}
}
}
+
+ if (!intermediateOutputTupleDescList.isEmpty()) {
+ intermediateOutputTupleDescList
+ .forEach(
+ tupleDescriptor -> msg.addToIntermediateOutputTupleIdList(tupleDescriptor.getId().asInt()));
+ }
+
+ if (!intermediateProjectListList.isEmpty()) {
+ intermediateProjectListList.forEach(
+ projectList -> msg.addToIntermediateProjectionsList(
+ projectList.stream().map(expr -> expr.treeToThrift()).collect(Collectors.toList())));
+ }
+
if (this instanceof ExchangeNode) {
msg.num_children = 0;
return;
@@ -1221,4 +1247,12 @@ public abstract class PlanNode extends TreeNode<PlanNode> implements PlanStats {
public void setNereidsId(int nereidsId) {
this.nereidsId = nereidsId;
}
+
+ public void addIntermediateOutputTupleDescList(TupleDescriptor tupleDescriptor) {
+ intermediateOutputTupleDescList.add(tupleDescriptor); // index i pairs with intermediateProjectListList[i]
+ }
+
+ public void addIntermediateProjectList(List<Expr> exprs) {
+ intermediateProjectListList.add(exprs); // one list per intermediate projection layer, in call order
+ }
}
diff --git a/fe/fe-core/src/test/java/org/apache/doris/catalog/CreateFunctionTest.java b/fe/fe-core/src/test/java/org/apache/doris/catalog/CreateFunctionTest.java
index 0f464ba2946..c342d858fe1 100644
--- a/fe/fe-core/src/test/java/org/apache/doris/catalog/CreateFunctionTest.java
+++ b/fe/fe-core/src/test/java/org/apache/doris/catalog/CreateFunctionTest.java
@@ -74,6 +74,7 @@ public class CreateFunctionTest {
public void test() throws Exception {
ConnectContext ctx = UtFrameUtils.createDefaultCtx();
ctx.getSessionVariable().setEnableNereidsPlanner(false);
+ ctx.getSessionVariable().enableFallbackToOriginalPlanner = true;
ctx.getSessionVariable().setEnableFoldConstantByBe(false);
// create database db1
createDatabase(ctx, "create database db1;");
@@ -113,8 +114,8 @@ public class CreateFunctionTest {
Assert.assertTrue(constExprLists.get(0).get(0) instanceof FunctionCallExpr);
queryStr = "select db1.id_masking(k1) from db1.tbl1";
- Assert.assertTrue(
- dorisAssert.query(queryStr).explainQuery().contains("concat(left(CAST(CAST(k1 AS BIGINT) AS VARCHAR(65533)), 3), '****', right(CAST(CAST(k1 AS BIGINT) AS VARCHAR(65533)), 4))"));
+ Assert.assertTrue(containsIgnoreCase(dorisAssert.query(queryStr).explainQuery(),
+ "concat(left(CAST(CAST(k1 AS BIGINT) AS VARCHAR(65533)), 3), '****', right(CAST(CAST(k1 AS BIGINT) AS VARCHAR(65533)), 4))"));
// create alias function with cast
// cast any type to decimal with specific precision and scale
@@ -142,14 +143,16 @@ public class CreateFunctionTest {
queryStr = "select db1.decimal(k3, 4, 1) from db1.tbl1;";
if (Config.enable_decimal_conversion) {
- Assert.assertTrue(dorisAssert.query(queryStr).explainQuery().contains("CAST(`k3` AS DECIMALV3(4, 1))"));
+ Assert.assertTrue(containsIgnoreCase(dorisAssert.query(queryStr).explainQuery(),
+ "CAST(`k3` AS DECIMALV3(4, 1))"));
} else {
- Assert.assertTrue(dorisAssert.query(queryStr).explainQuery().contains("CAST(`k3` AS DECIMAL(4, 1))"));
+ Assert.assertTrue(containsIgnoreCase(dorisAssert.query(queryStr).explainQuery(),
+ "CAST(`k3` AS DECIMAL(4, 1))"));
}
// cast any type to varchar with fixed length
- createFuncStr = "create alias function db1.varchar(all) with parameter(text) as "
- + "cast(text as varchar(65533));";
+ createFuncStr = "create alias function db1.varchar(all, int) with parameter(text, length) as "
+ + "cast(text as varchar(length));";
createFunctionStmt = (CreateFunctionStmt) UtFrameUtils.parseAndAnalyzeStmt(createFuncStr, ctx);
Env.getCurrentEnv().createFunction(createFunctionStmt);
@@ -172,7 +175,8 @@ public class CreateFunctionTest {
Assert.assertTrue(constExprLists.get(0).get(0) instanceof StringLiteral);
queryStr = "select db1.varchar(k1, 4) from db1.tbl1;";
- Assert.assertTrue(dorisAssert.query(queryStr).explainQuery().contains("CAST(`k1` AS VARCHAR(65533))"));
+ Assert.assertTrue(containsIgnoreCase(dorisAssert.query(queryStr).explainQuery(),
+ "CAST(`k1` AS VARCHAR(65533))"));
// cast any type to char with fixed length
createFuncStr = "create alias function db1.to_char(all, int) with parameter(text, length) as "
@@ -199,7 +203,8 @@ public class CreateFunctionTest {
Assert.assertTrue(constExprLists.get(0).get(0) instanceof StringLiteral);
queryStr = "select db1.to_char(k1, 4) from db1.tbl1;";
- Assert.assertTrue(dorisAssert.query(queryStr).explainQuery().contains("CAST(`k1` AS CHARACTER"));
+ Assert.assertTrue(containsIgnoreCase(dorisAssert.query(queryStr).explainQuery(),
+ "CAST(`k1` AS CHARACTER"));
}
@Test
@@ -235,8 +240,8 @@ public class CreateFunctionTest {
testFunctionQuery(ctx, queryStr, false);
queryStr = "select id_masking(k1) from db2.tbl1";
- Assert.assertTrue(
- dorisAssert.query(queryStr).explainQuery().contains("concat(left(CAST(CAST(k1 AS BIGINT) AS VARCHAR(65533)), 3), '****', right(CAST(CAST(k1 AS BIGINT) AS VARCHAR(65533)), 4))"));
+ Assert.assertTrue(containsIgnoreCase(dorisAssert.query(queryStr).explainQuery(),
+ "concat(left(CAST(CAST(k1 AS BIGINT) AS VARCHAR(65533)), 3), '****', right(CAST(CAST(k1 AS BIGINT) AS VARCHAR(65533)), 4))"));
// 4. create alias function with cast
// cast any type to decimal with specific precision and scale
@@ -253,9 +258,11 @@ public class CreateFunctionTest {
queryStr = "select decimal(k3, 4, 1) from db2.tbl1;";
if (Config.enable_decimal_conversion) {
- Assert.assertTrue(dorisAssert.query(queryStr).explainQuery().contains("CAST(`k3` AS DECIMALV3(4, 1))"));
+ Assert.assertTrue(containsIgnoreCase(dorisAssert.query(queryStr).explainQuery(),
+ "CAST(`k3` AS DECIMALV3(4, 1))"));
} else {
- Assert.assertTrue(dorisAssert.query(queryStr).explainQuery().contains("CAST(`k3` AS DECIMAL(4, 1))"));
+ Assert.assertTrue(containsIgnoreCase(dorisAssert.query(queryStr).explainQuery(),
+ "CAST(`k3` AS DECIMAL(4, 1))"));
}
// 5. cast any type to varchar with fixed length
@@ -271,7 +278,8 @@ public class CreateFunctionTest {
testFunctionQuery(ctx, queryStr, true);
queryStr = "select varchar(k1, 4) from db2.tbl1;";
- Assert.assertTrue(dorisAssert.query(queryStr).explainQuery().contains("CAST(`k1` AS VARCHAR(65533))"));
+ Assert.assertTrue(containsIgnoreCase(dorisAssert.query(queryStr).explainQuery(),
+ "CAST(`k1` AS VARCHAR(65533))"));
// 6. cast any type to char with fixed length
createFuncStr = "create global alias function db2.to_char(all, int) with parameter(text, length) as "
@@ -286,7 +294,8 @@ public class CreateFunctionTest {
testFunctionQuery(ctx, queryStr, true);
queryStr = "select to_char(k1, 4) from db2.tbl1;";
- Assert.assertTrue(dorisAssert.query(queryStr).explainQuery().contains("CAST(`k1` AS CHARACTER)"));
+ Assert.assertTrue(containsIgnoreCase(dorisAssert.query(queryStr).explainQuery(),
+ "CAST(`k1` AS CHARACTER)"));
}
private void testFunctionQuery(ConnectContext ctx, String queryStr, Boolean isStringLiteral) throws Exception {
@@ -320,4 +329,8 @@ public class CreateFunctionTest {
Env.getCurrentEnv().createDb(createDbStmt);
System.out.println(Env.getCurrentInternalCatalog().getDbNames());
}
+
+ private boolean containsIgnoreCase(String str, String sub) { // uses Locale.ROOT, not the JVM default locale
+ return str.toLowerCase(java.util.Locale.ROOT).contains(sub.toLowerCase(java.util.Locale.ROOT));
+ }
}
diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/postprocess/CommonSubExpressionTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/postprocess/CommonSubExpressionTest.java
new file mode 100644
index 00000000000..56b67e087d5
--- /dev/null
+++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/postprocess/CommonSubExpressionTest.java
@@ -0,0 +1,131 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.postprocess;
+
+import org.apache.doris.nereids.processor.post.CommonSubExpressionCollector;
+import org.apache.doris.nereids.processor.post.CommonSubExpressionOpt;
+import org.apache.doris.nereids.rules.expression.ExpressionRewriteTestHelper;
+import org.apache.doris.nereids.trees.expressions.Alias;
+import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.NamedExpression;
+import org.apache.doris.nereids.trees.expressions.Slot;
+import org.apache.doris.nereids.trees.expressions.SlotReference;
+import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter;
+import org.apache.doris.nereids.types.IntegerType;
+
+import org.junit.jupiter.api.Assertions;
+import org.junit.jupiter.api.Test;
+
+import java.lang.reflect.Method;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.stream.Collectors;
+
+public class CommonSubExpressionTest extends ExpressionRewriteTestHelper {
+    @Test
+    public void testExtractCommonExpr() {
+        List<NamedExpression> exprs = parseProjections("a+b, a+b+1, abs(a+b+1), a");
+        CommonSubExpressionCollector collector =
+                new CommonSubExpressionCollector();
+        exprs.forEach(expr -> collector.visit(expr, null));
+        // a+b (depth 1) and a+b+1 (depth 2) occur more than once; leaves are never collected
+        Assertions.assertEquals(2, collector.commonExprByDepth.size());
+        List<Expression> l1 = collector.commonExprByDepth.get(Integer.valueOf(1))
+                .stream().collect(Collectors.toList());
+        List<Expression> l2 = collector.commonExprByDepth.get(Integer.valueOf(2))
+                .stream().collect(Collectors.toList());
+        Assertions.assertEquals(1, l1.size());
+        assertExpression(l1.get(0), "a+b");
+        Assertions.assertEquals(1, l2.size());
+        assertExpression(l2.get(0), "a+b+1");
+    }
+
+    @Test
+    public void testMultiLayers() throws Exception {
+        List<NamedExpression> exprs = parseProjections("a, a+b, a+b+1, abs(a+b+1), a");
+        Set<Slot> inputSlots = exprs.get(0).getInputSlots();
+        CommonSubExpressionOpt opt = new CommonSubExpressionOpt();
+        Method computeMultiLayerProjectionsMethod = CommonSubExpressionOpt.class
+                .getDeclaredMethod("computeMultiLayerProjections", Set.class, List.class);
+        computeMultiLayerProjectionsMethod.setAccessible(true);
+        List<List<NamedExpression>> multiLayers = (List<List<NamedExpression>>) computeMultiLayerProjectionsMethod
+                .invoke(opt, inputSlots, exprs);
+        // two intermediate layers (depth 1: a+b, depth 2: a+b+1) plus the final layer
+        Assertions.assertEquals(3, multiLayers.size());
+        List<NamedExpression> l0 = multiLayers.get(0);
+        Assertions.assertEquals(2, l0.size());
+        Assertions.assertTrue(l0.contains(ExprParser.INSTANCE.parseExpression("a")));
+        Assertions.assertTrue(l0.get(1) instanceof Alias);
+        assertExpression(l0.get(1).child(0), "a+b");
+        Assertions.assertEquals(3, multiLayers.get(1).size());
+        Assertions.assertEquals(5, multiLayers.get(2).size());
+        List<NamedExpression> l2 = multiLayers.get(2);
+        // the final layer keeps the ExprIds of the original projections
+        for (int i = 0; i < 5; i++) {
+            Assertions.assertEquals(exprs.get(i).getExprId().asInt(), l2.get(i).getExprId().asInt());
+        }
+    }
+
+    private void assertExpression(Expression expr, String str) {
+        Assertions.assertEquals(ExprParser.INSTANCE.parseExpression(str), expr);
+    }
+
+    private List<NamedExpression> parseProjections(String exprList) {
+        List<NamedExpression> result = new ArrayList<>();
+        String[] exprArray = exprList.split(",");
+        for (String item : exprArray) {
+            Expression expr = ExprParser.INSTANCE.parseExpression(item);
+            if (expr instanceof NamedExpression) {
+                result.add((NamedExpression) expr);
+            } else {
+                result.add(new Alias(expr));
+            }
+        }
+        return result;
+    }
+
+    public static class ExprParser {
+        public static final ExprParser INSTANCE = new ExprParser();
+        HashMap<String, SlotReference> slotMap = new HashMap<>();
+
+        public Expression parseExpression(String str) {
+            Expression expr = PARSER.parseExpression(str);
+            return expr.accept(DataTypeAssignor.INSTANCE, slotMap);
+        }
+    }
+
+    public static class DataTypeAssignor extends DefaultExpressionRewriter<Map<String, SlotReference>> {
+        public static final DataTypeAssignor INSTANCE = new DataTypeAssignor();
+
+        @Override
+        public Expression visitSlot(Slot slot, Map<String, SlotReference> slotMap) {
+            SlotReference existingSlot = slotMap.get(slot.getName());
+            if (existingSlot != null) {
+                return existingSlot;
+            } else {
+                SlotReference slotReference = new SlotReference(slot.getName(), IntegerType.INSTANCE);
+                slotMap.put(slot.getName(), slotReference);
+                return slotReference;
+            }
+        }
+    }
+
+}
diff --git a/regression-test/data/tpch_sf0.1_p1/sql/cse.out b/regression-test/data/tpch_sf0.1_p1/sql/cse.out
new file mode 100644
index 00000000000..5ab44655661
--- /dev/null
+++ b/regression-test/data/tpch_sf0.1_p1/sql/cse.out
@@ -0,0 +1,30 @@
+-- This file is automatically generated. You should know what you did if you want to edit this
+-- !cse --
+1 1 3 4
+2 0 3 4
+3 1 5 6
+4 0 5 6
+5 4 10 11
+6 0 7 8
+7 3 11 12
+8 1 10 11
+9 4 14 15
+10 1 12 13
+
+-- !cse_2 --
+17 1 18 19 19
+5 2 7 8 8
+1 3 4 5 5
+15 4 19 20 20
+11 5 16 17 17
+14 6 20 21 21
+23 7 30 31 31
+17 8 25 26 26
+10 9 19 20 20
+24 10 34 35 35
+
+-- !cse_3 --
+12093 13093 14093 15093
+
+-- !cse_4 --
+12093 13093 14093 15093
\ No newline at end of file
diff --git a/regression-test/framework/src/main/groovy/org/apache/doris/regression/action/ExplainAction.groovy b/regression-test/framework/src/main/groovy/org/apache/doris/regression/action/ExplainAction.groovy
index e6f05c6c765..cf0c03fc3bd 100644
--- a/regression-test/framework/src/main/groovy/org/apache/doris/regression/action/ExplainAction.groovy
+++ b/regression-test/framework/src/main/groovy/org/apache/doris/regression/action/ExplainAction.groovy
@@ -32,6 +32,7 @@ class ExplainAction implements SuiteAction {
private SuiteContext context
private Set<String> containsStrings = new LinkedHashSet<>()
private Set<String> notContainsStrings = new LinkedHashSet<>()
+ private Map<String, Integer> multiContainsStrings = new HashMap<>()
private String coonType
private Closure checkFunction
@@ -56,6 +57,10 @@ class ExplainAction implements SuiteAction {
containsStrings.add(subString)
}
+    /**
+     * Registers an expectation that {@code subString} occurs exactly {@code n} times in the
+     * explain output. Registering the same substring again replaces the earlier count.
+     */
+    void multiContains(String subString, int n) {
+        multiContainsStrings[subString] = n
+    }
+
void notContains(String subString) {
notContainsStrings.add(subString)
}
@@ -112,6 +117,16 @@ class ExplainAction implements SuiteAction {
throw t
}
}
+        // Every substring registered via multiContains must occur exactly the expected number of times.
+        for (Map.Entry<String, Integer> entry : multiContainsStrings.entrySet()) {
+            int count = explainString.count(entry.key)
+            if (count != entry.value) {
+                String msg = ("Explain and check failed, expect multiContains '${entry.key}' ${entry.value} times, actual ${count} times. "
+                        + "Actual explain string is:\n${explainString}").toString()
+                log.info(msg)
+                def t = new IllegalStateException(msg)
+                throw t
+            }
+        }
}
}
diff --git a/regression-test/suites/tpch_sf0.1_p1/sql/cse.groovy b/regression-test/suites/tpch_sf0.1_p1/sql/cse.groovy
new file mode 100644
index 00000000000..698dbd3e5d0
--- /dev/null
+++ b/regression-test/suites/tpch_sf0.1_p1/sql/cse.groovy
@@ -0,0 +1,49 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// These cases are copied from
+// https://github.com/trinodb/trino/tree/master/testing/trino-product-tests/src/main/resources/sql-tests/testcases/tpcds
+// and modified by Doris.
+
+suite('cse') {
+    // (s_suppkey + n_regionkey) is shared by x and y, so the plan should compute it once
+    // in an intermediate projection.
+    def q1 = """select s_suppkey,n_regionkey,(s_suppkey + n_regionkey) + 1 as x, (s_suppkey + n_regionkey) + 2 as y
+        from supplier join nation on s_nationkey=n_nationkey order by s_suppkey , n_regionkey limit 10 ;
+        """
+
+    // Nested reuse: (s_nationkey + s_suppkey) feeds ((s_nationkey + s_suppkey) + 1), which in turn
+    // feeds abs(...); the explain check below expects two intermediate projection layers.
+    def q2 = """select s_nationkey,s_suppkey ,(s_nationkey + s_suppkey), (s_nationkey + s_suppkey) + 1, abs((s_nationkey + s_suppkey) + 1)
+        from supplier order by s_suppkey limit 10 ;"""
+
+    qt_cse "${q1}"
+
+    explain {
+        sql "${q1}"
+        contains "intermediate projections:"
+    }
+
+    qt_cse_2 "${q2}"
+
+    explain {
+        sql "${q2}"
+        multiContains("intermediate projections:", 2)
+    }
+
+    qt_cse_3 """ select sum(s_nationkey),sum(s_nationkey +1 ) ,sum(s_nationkey +2 ) , sum(s_nationkey + 3 ) from supplier ;"""
+
+    qt_cse_4 """select sum(s_nationkey),sum(s_nationkey) + count(1) ,sum(s_nationkey) + 2 * count(1) , sum(s_nationkey) + 3 * count(1) from supplier ;"""
+}
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@doris.apache.org
For additional commands, e-mail: commits-help@doris.apache.org