You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@doris.apache.org by hu...@apache.org on 2022/07/21 04:02:42 UTC
[doris] branch master updated: [enhancement](Nereids) support case when for TPC-H (#10947)
This is an automated email from the ASF dual-hosted git repository.
huajianlan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/master by this push:
new 329f70dc02 [enhancement](Nereids) support case when for TPC-H (#10947)
329f70dc02 is described below
commit 329f70dc0277a5e1fc5ee314313d746c6bd65c43
Author: yinzhijian <37...@qq.com>
AuthorDate: Thu Jul 21 12:02:37 2022 +0800
[enhancement](Nereids) support case when for TPC-H (#10947)
Support CASE WHEN expressions in the Nereids planner (required by TPC-H queries).
Both forms are supported:
CASE [expression] WHEN [value] THEN [expression] ... ELSE [expression] END
or
CASE WHEN [predicate] THEN [expression] ... ELSE [expression] END
---
.../antlr4/org/apache/doris/nereids/DorisParser.g4 | 7 +-
.../java/org/apache/doris/analysis/CaseExpr.java | 7 ++
.../glue/translator/ExpressionTranslator.java | 22 ++++
.../doris/nereids/parser/LogicalPlanBuilder.java | 52 ++++++++-
.../doris/nereids/trees/expressions/CaseWhen.java | 129 +++++++++++++++++++++
.../nereids/trees/expressions/ExpressionType.java | 4 +-
.../nereids/trees/expressions/WhenClause.java | 91 +++++++++++++++
.../expressions/visitor/ExpressionVisitor.java | 10 ++
.../trees/expressions/ExpressionParserTest.java | 9 ++
9 files changed, 327 insertions(+), 4 deletions(-)
diff --git a/fe/fe-core/src/main/antlr4/org/apache/doris/nereids/DorisParser.g4 b/fe/fe-core/src/main/antlr4/org/apache/doris/nereids/DorisParser.g4
index cf8944423c..5b254c8074 100644
--- a/fe/fe-core/src/main/antlr4/org/apache/doris/nereids/DorisParser.g4
+++ b/fe/fe-core/src/main/antlr4/org/apache/doris/nereids/DorisParser.g4
@@ -190,7 +190,9 @@ valueExpression
;
primaryExpression
- : constant #constantDefault
+ : CASE whenClause+ (ELSE elseExpression=expression)? END #searchedCase
+ | CASE value=expression whenClause+ (ELSE elseExpression=expression)? END #simpleCase
+ | constant #constantDefault
| ASTERISK #star
| qualifiedName DOT ASTERISK #star
| identifier LEFT_PAREN DISTINCT? arguments+=expression
@@ -220,6 +222,9 @@ booleanValue
: TRUE | FALSE
;
+// A single WHEN ... THEN ... branch, shared by the simple and searched CASE forms.
+// `condition` is the WHEN operand: a value to compare against for simple CASE, or a
+// predicate for searched CASE. `result` is the THEN expression.
+whenClause
+ : WHEN condition=expression THEN result=expression
+ ;
// this rule is used for explicitly capturing wrong identifiers such as test-table, which should actually be `test-table`
// replace identifier with errorCapturingIdentifier where the immediate follow symbol is not an expression, otherwise
diff --git a/fe/fe-core/src/main/java/org/apache/doris/analysis/CaseExpr.java b/fe/fe-core/src/main/java/org/apache/doris/analysis/CaseExpr.java
index 40f0016453..83fb77c120 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/analysis/CaseExpr.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/CaseExpr.java
@@ -433,4 +433,11 @@ public class CaseExpr extends Expr {
}
return false;
}
+
+ @Override
+ public void finalizeImplForNereids() throws AnalysisException {
+ // nereids do not have CaseExpr, and nereids will unify the types,
+ // so just use the first then type
+ type = children.get(1).getType();
+ }
}
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/ExpressionTranslator.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/ExpressionTranslator.java
index 47635edc88..58c3aa7d05 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/ExpressionTranslator.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/ExpressionTranslator.java
@@ -21,6 +21,8 @@ import org.apache.doris.analysis.ArithmeticExpr;
import org.apache.doris.analysis.BinaryPredicate;
import org.apache.doris.analysis.BinaryPredicate.Operator;
import org.apache.doris.analysis.BoolLiteral;
+import org.apache.doris.analysis.CaseExpr;
+import org.apache.doris.analysis.CaseWhenClause;
import org.apache.doris.analysis.Expr;
import org.apache.doris.analysis.FloatLiteral;
import org.apache.doris.analysis.FunctionCallExpr;
@@ -33,6 +35,7 @@ import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.Arithmetic;
import org.apache.doris.nereids.trees.expressions.Between;
import org.apache.doris.nereids.trees.expressions.BooleanLiteral;
+import org.apache.doris.nereids.trees.expressions.CaseWhen;
import org.apache.doris.nereids.trees.expressions.CompoundPredicate;
import org.apache.doris.nereids.trees.expressions.DoubleLiteral;
import org.apache.doris.nereids.trees.expressions.EqualTo;
@@ -47,11 +50,13 @@ import org.apache.doris.nereids.trees.expressions.Not;
import org.apache.doris.nereids.trees.expressions.NullSafeEqual;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.StringRegexPredicate;
+import org.apache.doris.nereids.trees.expressions.WhenClause;
import org.apache.doris.nereids.trees.expressions.functions.BoundFunction;
import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionVisitor;
import java.util.ArrayList;
import java.util.List;
+import java.util.Optional;
/**
* Used to translate expression of new optimizer to stale expr.
@@ -213,6 +218,23 @@ public class ExpressionTranslator extends DefaultExpressionVisitor<Expr, PlanTra
stringRegexPredicate.right().accept(this, context));
}
+ @Override
+ public Expr visitCaseWhen(CaseWhen caseWhen, PlanTranslatorContext context) {
+ List<CaseWhenClause> caseWhenClauses = new ArrayList<>();
+ for (WhenClause whenClause : caseWhen.getWhenClauses()) {
+ caseWhenClauses.add(new CaseWhenClause(
+ whenClause.left().accept(this, context),
+ whenClause.right().accept(this, context)
+ ));
+ }
+ Expr elseExpr = null;
+ Optional<Expression> defaultValue = caseWhen.getDefaultValue();
+ if (defaultValue.isPresent()) {
+ elseExpr = defaultValue.get().accept(this, context);
+ }
+ return new CaseExpr(null, caseWhenClauses, elseExpr);
+ }
+
// TODO: Supports for `distinct`
@Override
public Expr visitBoundFunction(BoundFunction function, PlanTranslatorContext context) {
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/parser/LogicalPlanBuilder.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/parser/LogicalPlanBuilder.java
index 1c25087dbd..bc3f5cf222 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/parser/LogicalPlanBuilder.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/parser/LogicalPlanBuilder.java
@@ -66,6 +66,7 @@ import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.And;
import org.apache.doris.nereids.trees.expressions.Between;
import org.apache.doris.nereids.trees.expressions.BooleanLiteral;
+import org.apache.doris.nereids.trees.expressions.CaseWhen;
import org.apache.doris.nereids.trees.expressions.Divide;
import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.Expression;
@@ -86,6 +87,7 @@ import org.apache.doris.nereids.trees.expressions.Or;
import org.apache.doris.nereids.trees.expressions.Regexp;
import org.apache.doris.nereids.trees.expressions.StringLiteral;
import org.apache.doris.nereids.trees.expressions.Subtract;
+import org.apache.doris.nereids.trees.expressions.WhenClause;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
@@ -105,6 +107,7 @@ import org.antlr.v4.runtime.tree.TerminalNode;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
+import java.util.stream.Collectors;
/**
* Build an logical plan tree with unbounded nodes.
@@ -334,13 +337,58 @@ public class LogicalPlanBuilder extends DorisParserBaseVisitor<Object> {
case DorisParser.MINUS:
return new Subtract(left, right);
default:
- throw new IllegalStateException("Unsupported arithmetic binary type: "
- + ctx.operator.getText());
+ throw new IllegalStateException(
+ "Unsupported arithmetic binary type: " + ctx.operator.getText());
}
});
});
}
+ /**
+ * Create a value based [[CaseWhen]] expression. This has the following SQL form:
+ * {{{
+ * CASE [expression]
+ * WHEN [value] THEN [expression]
+ * ...
+ * ELSE [expression]
+ * END
+ * }}}
+ */
+ @Override
+ public Expression visitSimpleCase(DorisParser.SimpleCaseContext context) {
+ Expression e = getExpression(context.value);
+ List<WhenClause> whenClauses = context.whenClause().stream()
+ .map(w -> new WhenClause(new EqualTo(e, getExpression(w.condition)), getExpression(w.result)))
+ .collect(Collectors.toList());
+ if (context.elseExpression == null) {
+ return new CaseWhen(whenClauses);
+ }
+ return new CaseWhen(whenClauses, getExpression(context.elseExpression));
+ }
+
+ /**
+ * Create a condition based [[CaseWhen]] expression. This has the following SQL syntax:
+ * {{{
+ * CASE
+ * WHEN [predicate] THEN [expression]
+ * ...
+ * ELSE [expression]
+ * END
+ * }}}
+ *
+ * @param context the parse tree
+ */
+ @Override
+ public Expression visitSearchedCase(DorisParser.SearchedCaseContext context) {
+ List<WhenClause> whenClauses = context.whenClause().stream()
+ .map(w -> new WhenClause(getExpression(w.condition), getExpression(w.result)))
+ .collect(Collectors.toList());
+ if (context.elseExpression == null) {
+ return new CaseWhen(whenClauses);
+ }
+ return new CaseWhen(whenClauses, getExpression(context.elseExpression));
+ }
+
@Override
public UnboundFunction visitFunctionCall(DorisParser.FunctionCallContext ctx) {
return ParserUtils.withOrigin(ctx, () -> {
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/CaseWhen.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/CaseWhen.java
new file mode 100644
index 0000000000..fdbdd6c813
--- /dev/null
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/CaseWhen.java
@@ -0,0 +1,129 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.trees.expressions;
+
+import org.apache.doris.nereids.exceptions.UnboundException;
+import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
+import org.apache.doris.nereids.types.DataType;
+
+import com.google.common.base.Preconditions;
+import com.google.common.collect.ImmutableList;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Optional;
+import java.util.stream.Collectors;
+
+/**
+ * The internal representation of
+ * CASE [expr] WHEN expr THEN expr [WHEN expr THEN expr ...] [ELSE expr] END
+ * Each When/Then is stored as two consecutive children (whenExpr, thenExpr).
+ * If a case expr is given, convert it to equalTo(caseExpr, whenExpr) and set it to whenExpr.
+ * If an else expr is given then it is the last child.
+ */
+public class CaseWhen extends Expression {
+ /**
+ * If default value exists, then defaultValueIndex is the index of the last element in children,
+ * otherwise it is -1
+ */
+ private final int defaultValueIndex;
+
+ public CaseWhen(List<WhenClause> whenClauses) {
+ super(ExpressionType.CASE, whenClauses.toArray(new Expression[0]));
+ defaultValueIndex = -1;
+ }
+
+ public CaseWhen(List<WhenClause> whenClauses, Expression defaultValue) {
+ super(ExpressionType.CASE,
+ ImmutableList.builder().addAll(whenClauses).add(defaultValue).build().toArray(new Expression[0]));
+ defaultValueIndex = children().size() - 1;
+ }
+
+ public List<WhenClause> getWhenClauses() {
+ List<WhenClause> whenClauses = children().stream()
+ .filter(e -> e instanceof WhenClause)
+ .map(e -> (WhenClause) e)
+ .collect(Collectors.toList());
+ return whenClauses;
+ }
+
+ public Optional<Expression> getDefaultValue() {
+ if (defaultValueIndex == -1) {
+ return Optional.empty();
+ }
+ return Optional.of(child(defaultValueIndex));
+ }
+
+ public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
+ return visitor.visitCaseWhen(this, context);
+ }
+
+ @Override
+ public DataType getDataType() {
+ return child(0).getDataType();
+ }
+
+ @Override
+ public boolean nullable() {
+ for (Expression child : children()) {
+ if (child.nullable()) {
+ return true;
+ }
+ }
+ return false;
+ }
+
+ @Override
+ public String toString() {
+ return toSql();
+ }
+
+ @Override
+ public String toSql() throws UnboundException {
+ StringBuilder output = new StringBuilder("CASE ");
+ for (Expression child : children()) {
+ if (child instanceof WhenClause) {
+ output.append(child.toSql());
+ } else {
+ output.append(" ELSE " + child.toSql());
+ }
+ }
+ output.append(" END");
+ return output.toString();
+ }
+
+ @Override
+ public Expression withChildren(List<Expression> children) {
+ Preconditions.checkArgument(children.size() >= 1);
+ List<WhenClause> whenClauseList = new ArrayList<>();
+ Expression defaultValue = null;
+ for (int i = 0; i < children.size(); i++) {
+ if (children.get(i) instanceof WhenClause) {
+ whenClauseList.add((WhenClause) children.get(i));
+ } else if (children.size() - 1 == i) {
+ defaultValue = children.get(i);
+ } else {
+ throw new IllegalArgumentException("The children format needs to be [WhenClause*, DefaultValue+]");
+ }
+ }
+ if (defaultValue == null) {
+ return new CaseWhen(whenClauseList);
+ }
+ return new CaseWhen(whenClauseList, defaultValue);
+ }
+}
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/ExpressionType.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/ExpressionType.java
index 5387203243..9109f30f0c 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/ExpressionType.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/ExpressionType.java
@@ -55,5 +55,7 @@ public enum ExpressionType {
BITXOR,
BITNOT,
FACTORIAL,
- FUNCTION_CALL
+ FUNCTION_CALL,
+ CASE,
+ WHEN_CLAUSE
}
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/WhenClause.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/WhenClause.java
new file mode 100644
index 0000000000..d397a77048
--- /dev/null
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/WhenClause.java
@@ -0,0 +1,91 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.trees.expressions;
+
+import org.apache.doris.nereids.exceptions.UnboundException;
+import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
+import org.apache.doris.nereids.types.DataType;
+
+import com.google.common.base.Preconditions;
+
+import java.util.List;
+import java.util.Objects;
+
+/**
+ * captures info of a single WHEN expr THEN expr clause.
+ */
+public class WhenClause extends Expression implements BinaryExpression {
+ public WhenClause(Expression operand, Expression result) {
+ super(ExpressionType.WHEN_CLAUSE, operand, result);
+ }
+
+ @Override
+ public String toSql() {
+ return "WHEN " + left().toSql() + " THEN " + right().toSql();
+ }
+
+ @Override
+ public Expression withChildren(List<Expression> children) {
+ Preconditions.checkArgument(children.size() == 2);
+ return new WhenClause(children.get(0), children.get(1));
+ }
+
+ @Override
+ public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
+ return visitor.visitWhenClause(this, context);
+ }
+
+
+ @Override
+ public DataType getDataType() {
+ // when left() then right()
+ // Depends on the data type of the result
+ return right().getDataType();
+ }
+
+ @Override
+ public boolean nullable() throws UnboundException {
+ // Depends on whether the result is nullable or not
+ return right().nullable();
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (this == o) {
+ return true;
+ }
+ if (o == null || getClass() != o.getClass()) {
+ return false;
+ }
+ if (!super.equals(o)) {
+ return false;
+ }
+ WhenClause other = (WhenClause) o;
+ return Objects.equals(left(), other.left()) && Objects.equals(right(), other.right());
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hash(left(), right());
+ }
+
+ @Override
+ public String toString() {
+ return toSql();
+ }
+}
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/ExpressionVisitor.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/ExpressionVisitor.java
index 7286812e82..1788abdb39 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/ExpressionVisitor.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/ExpressionVisitor.java
@@ -26,6 +26,7 @@ import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.Arithmetic;
import org.apache.doris.nereids.trees.expressions.Between;
import org.apache.doris.nereids.trees.expressions.BooleanLiteral;
+import org.apache.doris.nereids.trees.expressions.CaseWhen;
import org.apache.doris.nereids.trees.expressions.ComparisonPredicate;
import org.apache.doris.nereids.trees.expressions.CompoundPredicate;
import org.apache.doris.nereids.trees.expressions.Divide;
@@ -51,6 +52,7 @@ import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.StringLiteral;
import org.apache.doris.nereids.trees.expressions.StringRegexPredicate;
import org.apache.doris.nereids.trees.expressions.Subtract;
+import org.apache.doris.nereids.trees.expressions.WhenClause;
import org.apache.doris.nereids.trees.expressions.functions.AggregateFunction;
import org.apache.doris.nereids.trees.expressions.functions.BoundFunction;
@@ -185,6 +187,14 @@ public abstract class ExpressionVisitor<R, C> {
return visitArithmetic(mod, context);
}
+ /**
+ * Visits a {@link WhenClause}; by default delegates to {@link #visit}.
+ */
+ public R visitWhenClause(WhenClause whenClause, C context) {
+ return visit(whenClause, context);
+ }
+
+ /**
+ * Visits a {@link CaseWhen}; by default delegates to {@link #visit}.
+ */
+ public R visitCaseWhen(CaseWhen caseWhen, C context) {
+ return visit(caseWhen, context);
+ }
+
/* ********************************************************************************************
* Unbound expressions
* ********************************************************************************************/
diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/expressions/ExpressionParserTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/expressions/ExpressionParserTest.java
index 1014662cbb..7b238fa2de 100644
--- a/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/expressions/ExpressionParserTest.java
+++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/expressions/ExpressionParserTest.java
@@ -137,4 +137,13 @@ public class ExpressionParserTest {
String sort1 = "select a from test order by 1";
assertSql(sort1);
}
+
+ @Test
+ public void testCaseWhen() throws Exception {
+ String caseWhen = "select case a when 1 then 2 else 3 end from test";
+ assertSql(caseWhen);
+
+ String caseWhen2 = "select case when a = 1 then 2 else 3 end from test";
+ assertSql(caseWhen2);
+ }
}
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@doris.apache.org
For additional commands, e-mail: commits-help@doris.apache.org