You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@doris.apache.org by hu...@apache.org on 2022/07/21 04:02:42 UTC

[doris] branch master updated: [enhancement](Nereids) support case when for TPC-H (#10947)

This is an automated email from the ASF dual-hosted git repository.

huajianlan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git


The following commit(s) were added to refs/heads/master by this push:
     new 329f70dc02 [enhancement](Nereids) support case when for TPC-H (#10947)
329f70dc02 is described below

commit 329f70dc0277a5e1fc5ee314313d746c6bd65c43
Author: yinzhijian <37...@qq.com>
AuthorDate: Thu Jul 21 12:02:37 2022 +0800

    [enhancement](Nereids) support case when for TPC-H (#10947)
    
    Support CASE WHEN expressions in the Nereids planner, as required by TPC-H queries.
    Both standard forms are supported:
    CASE [expression] WHEN [value] THEN [expression] ... ELSE [expression] END
    or
    CASE WHEN [predicate] THEN [expression] ... ELSE [expression] END
---
 .../antlr4/org/apache/doris/nereids/DorisParser.g4 |   7 +-
 .../java/org/apache/doris/analysis/CaseExpr.java   |   7 ++
 .../glue/translator/ExpressionTranslator.java      |  22 ++++
 .../doris/nereids/parser/LogicalPlanBuilder.java   |  52 ++++++++-
 .../doris/nereids/trees/expressions/CaseWhen.java  | 129 +++++++++++++++++++++
 .../nereids/trees/expressions/ExpressionType.java  |   4 +-
 .../nereids/trees/expressions/WhenClause.java      |  91 +++++++++++++++
 .../expressions/visitor/ExpressionVisitor.java     |  10 ++
 .../trees/expressions/ExpressionParserTest.java    |   9 ++
 9 files changed, 327 insertions(+), 4 deletions(-)

diff --git a/fe/fe-core/src/main/antlr4/org/apache/doris/nereids/DorisParser.g4 b/fe/fe-core/src/main/antlr4/org/apache/doris/nereids/DorisParser.g4
index cf8944423c..5b254c8074 100644
--- a/fe/fe-core/src/main/antlr4/org/apache/doris/nereids/DorisParser.g4
+++ b/fe/fe-core/src/main/antlr4/org/apache/doris/nereids/DorisParser.g4
@@ -190,7 +190,9 @@ valueExpression
     ;
 
 primaryExpression
-    : constant                                                                                 #constantDefault
+    : CASE whenClause+ (ELSE elseExpression=expression)? END                                   #searchedCase
+    | CASE value=expression whenClause+ (ELSE elseExpression=expression)? END                  #simpleCase
+    | constant                                                                                 #constantDefault
     | ASTERISK                                                                                 #star
     | qualifiedName DOT ASTERISK                                                               #star
     | identifier LEFT_PAREN DISTINCT? arguments+=expression
@@ -220,6 +222,9 @@ booleanValue
     : TRUE | FALSE
     ;
 
+whenClause  // a single WHEN ... THEN ... branch, used by both searchedCase and simpleCase
+    : WHEN condition=expression THEN result=expression  // predicate (searchedCase) or value compared with the CASE operand (simpleCase)
+    ;
 
 // this rule is used for explicitly capturing wrong identifiers such as test-table, which should actually be `test-table`
 // replace identifier with errorCapturingIdentifier where the immediate follow symbol is not an expression, otherwise
diff --git a/fe/fe-core/src/main/java/org/apache/doris/analysis/CaseExpr.java b/fe/fe-core/src/main/java/org/apache/doris/analysis/CaseExpr.java
index 40f0016453..83fb77c120 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/analysis/CaseExpr.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/CaseExpr.java
@@ -433,4 +433,11 @@ public class CaseExpr extends Expr {
         }
         return false;
     }
+
+    @Override
+    public void finalizeImplForNereids() throws AnalysisException {
+        // nereids do not have CaseExpr, and nereids will unify the types,
+        // so just use the first then type
+        type = children.get(1).getType();
+    }
 }
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/ExpressionTranslator.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/ExpressionTranslator.java
index 47635edc88..58c3aa7d05 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/ExpressionTranslator.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/ExpressionTranslator.java
@@ -21,6 +21,8 @@ import org.apache.doris.analysis.ArithmeticExpr;
 import org.apache.doris.analysis.BinaryPredicate;
 import org.apache.doris.analysis.BinaryPredicate.Operator;
 import org.apache.doris.analysis.BoolLiteral;
+import org.apache.doris.analysis.CaseExpr;
+import org.apache.doris.analysis.CaseWhenClause;
 import org.apache.doris.analysis.Expr;
 import org.apache.doris.analysis.FloatLiteral;
 import org.apache.doris.analysis.FunctionCallExpr;
@@ -33,6 +35,7 @@ import org.apache.doris.nereids.trees.expressions.Alias;
 import org.apache.doris.nereids.trees.expressions.Arithmetic;
 import org.apache.doris.nereids.trees.expressions.Between;
 import org.apache.doris.nereids.trees.expressions.BooleanLiteral;
+import org.apache.doris.nereids.trees.expressions.CaseWhen;
 import org.apache.doris.nereids.trees.expressions.CompoundPredicate;
 import org.apache.doris.nereids.trees.expressions.DoubleLiteral;
 import org.apache.doris.nereids.trees.expressions.EqualTo;
@@ -47,11 +50,13 @@ import org.apache.doris.nereids.trees.expressions.Not;
 import org.apache.doris.nereids.trees.expressions.NullSafeEqual;
 import org.apache.doris.nereids.trees.expressions.SlotReference;
 import org.apache.doris.nereids.trees.expressions.StringRegexPredicate;
+import org.apache.doris.nereids.trees.expressions.WhenClause;
 import org.apache.doris.nereids.trees.expressions.functions.BoundFunction;
 import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionVisitor;
 
 import java.util.ArrayList;
 import java.util.List;
+import java.util.Optional;
 
 /**
  * Used to translate expression of new optimizer to stale expr.
@@ -213,6 +218,25 @@ public class ExpressionTranslator extends DefaultExpressionVisitor<Expr, PlanTra
                 stringRegexPredicate.right().accept(this, context));
     }
 
+    /**
+     * Translates a Nereids {@link CaseWhen} into a stale (legacy) {@link CaseExpr}.
+     * No case operand is passed because the Nereids parser folds any CASE operand into each WHEN condition;
+     * a missing ELSE branch is translated to a null else expression.
+     */
+    @Override
+    public Expr visitCaseWhen(CaseWhen caseWhen, PlanTranslatorContext context) {
+        List<WhenClause> whenClauses = caseWhen.getWhenClauses();
+        List<CaseWhenClause> legacyClauses = new ArrayList<>(whenClauses.size());
+        for (WhenClause clause : whenClauses) {
+            Expr condition = clause.left().accept(this, context);
+            Expr result = clause.right().accept(this, context);
+            legacyClauses.add(new CaseWhenClause(condition, result));
+        }
+        Optional<Expression> defaultValue = caseWhen.getDefaultValue();
+        Expr elseExpr = defaultValue.map(value -> value.accept(this, context)).orElse(null);
+        return new CaseExpr(null, legacyClauses, elseExpr);
+    }
+
     // TODO: Supports for `distinct`
     @Override
     public Expr visitBoundFunction(BoundFunction function, PlanTranslatorContext context) {
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/parser/LogicalPlanBuilder.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/parser/LogicalPlanBuilder.java
index 1c25087dbd..bc3f5cf222 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/parser/LogicalPlanBuilder.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/parser/LogicalPlanBuilder.java
@@ -66,6 +66,7 @@ import org.apache.doris.nereids.trees.expressions.Alias;
 import org.apache.doris.nereids.trees.expressions.And;
 import org.apache.doris.nereids.trees.expressions.Between;
 import org.apache.doris.nereids.trees.expressions.BooleanLiteral;
+import org.apache.doris.nereids.trees.expressions.CaseWhen;
 import org.apache.doris.nereids.trees.expressions.Divide;
 import org.apache.doris.nereids.trees.expressions.EqualTo;
 import org.apache.doris.nereids.trees.expressions.Expression;
@@ -86,6 +87,7 @@ import org.apache.doris.nereids.trees.expressions.Or;
 import org.apache.doris.nereids.trees.expressions.Regexp;
 import org.apache.doris.nereids.trees.expressions.StringLiteral;
 import org.apache.doris.nereids.trees.expressions.Subtract;
+import org.apache.doris.nereids.trees.expressions.WhenClause;
 import org.apache.doris.nereids.trees.plans.JoinType;
 import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
 import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
@@ -105,6 +107,7 @@ import org.antlr.v4.runtime.tree.TerminalNode;
 import java.util.Collections;
 import java.util.List;
 import java.util.Optional;
+import java.util.stream.Collectors;
 
 /**
  * Build an logical plan tree with unbounded nodes.
@@ -334,13 +337,73 @@ public class LogicalPlanBuilder extends DorisParserBaseVisitor<Object> {
                     case DorisParser.MINUS:
                         return new Subtract(left, right);
                     default:
-                        throw new IllegalStateException("Unsupported arithmetic binary type: "
-                            + ctx.operator.getText());
+                        throw new IllegalStateException(
+                                "Unsupported arithmetic binary type: " + ctx.operator.getText());
                 }
             });
         });
     }
 
+    /**
+     * Create a value based {@link CaseWhen} expression. This has the following SQL form:
+     * <pre>
+     *   CASE [expression]
+     *    WHEN [value] THEN [expression]
+     *    ...
+     *    ELSE [expression]
+     *   END
+     * </pre>
+     * Each WHEN value is rewritten to {@code operand = value}, so the result is a searched CASE.
+     *
+     * @param context the parse tree
+     */
+    @Override
+    public Expression visitSimpleCase(DorisParser.SimpleCaseContext context) {
+        return ParserUtils.withOrigin(context, () -> {
+            Expression operand = getExpression(context.value);
+            List<WhenClause> whenClauses = context.whenClause().stream()
+                    .map(w -> new WhenClause(new EqualTo(operand, getExpression(w.condition)),
+                            getExpression(w.result)))
+                    .collect(Collectors.toList());
+            return buildCaseWhen(whenClauses, context.elseExpression);
+        });
+    }
+
+    /**
+     * Create a condition based {@link CaseWhen} expression. This has the following SQL syntax:
+     * <pre>
+     *   CASE
+     *    WHEN [predicate] THEN [expression]
+     *    ...
+     *    ELSE [expression]
+     *   END
+     * </pre>
+     *
+     * @param context the parse tree
+     */
+    @Override
+    public Expression visitSearchedCase(DorisParser.SearchedCaseContext context) {
+        return ParserUtils.withOrigin(context, () -> {
+            List<WhenClause> whenClauses = context.whenClause().stream()
+                    .map(w -> new WhenClause(getExpression(w.condition), getExpression(w.result)))
+                    .collect(Collectors.toList());
+            return buildCaseWhen(whenClauses, context.elseExpression);
+        });
+    }
+
+    /**
+     * Wraps parsed WHEN branches into a {@link CaseWhen}, adding the ELSE value when present.
+     *
+     * @param whenClauses the parsed WHEN ... THEN ... branches, in source order
+     * @param elseExpression the ELSE parse tree, or null when the CASE has no ELSE branch
+     */
+    private Expression buildCaseWhen(List<WhenClause> whenClauses, DorisParser.ExpressionContext elseExpression) {
+        if (elseExpression == null) {
+            return new CaseWhen(whenClauses);
+        }
+        return new CaseWhen(whenClauses, getExpression(elseExpression));
+    }
+
     @Override
     public UnboundFunction visitFunctionCall(DorisParser.FunctionCallContext ctx) {
         return ParserUtils.withOrigin(ctx, () -> {
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/CaseWhen.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/CaseWhen.java
new file mode 100644
index 0000000000..fdbdd6c813
--- /dev/null
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/CaseWhen.java
@@ -0,0 +1,129 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.trees.expressions;
+
+import org.apache.doris.nereids.exceptions.UnboundException;
+import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
+import org.apache.doris.nereids.types.DataType;
+
+import com.google.common.base.Preconditions;
+import com.google.common.collect.ImmutableList;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Optional;
+import java.util.stream.Collectors;
+
+/**
+ * The internal representation of
+ * CASE [expr] WHEN expr THEN expr [WHEN expr THEN expr ...] [ELSE expr] END
+ * Each When/Then is stored as two consecutive children (whenExpr, thenExpr).
+ * If a case expr is given, convert it to equalTo(caseExpr, whenExpr) and set it to whenExpr.
+ * If an else expr is given then it is the last child.
+ */
+public class CaseWhen extends Expression {
+    /**
+     * If default value exists, then defaultValueIndex is the index of the last element in children,
+     * otherwise it is -1
+     */
+    private final int defaultValueIndex;
+
+    public CaseWhen(List<WhenClause> whenClauses) {
+        super(ExpressionType.CASE, whenClauses.toArray(new Expression[0]));
+        defaultValueIndex = -1;
+    }
+
+    public CaseWhen(List<WhenClause> whenClauses, Expression defaultValue) {
+        super(ExpressionType.CASE,
+                ImmutableList.builder().addAll(whenClauses).add(defaultValue).build().toArray(new Expression[0]));
+        defaultValueIndex = children().size() - 1;
+    }
+
+    public List<WhenClause> getWhenClauses() {
+        List<WhenClause> whenClauses = children().stream()
+                .filter(e -> e instanceof WhenClause)
+                .map(e -> (WhenClause) e)
+                .collect(Collectors.toList());
+        return whenClauses;
+    }
+
+    public Optional<Expression> getDefaultValue() {
+        if (defaultValueIndex == -1) {
+            return Optional.empty();
+        }
+        return Optional.of(child(defaultValueIndex));
+    }
+
+    public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
+        return visitor.visitCaseWhen(this, context);
+    }
+
+    @Override
+    public DataType getDataType() {
+        return child(0).getDataType();
+    }
+
+    @Override
+    public boolean nullable() {
+        for (Expression child : children()) {
+            if (child.nullable()) {
+                return true;
+            }
+        }
+        return false;
+    }
+
+    @Override
+    public String toString() {
+        return toSql();
+    }
+
+    @Override
+    public String toSql() throws UnboundException {
+        StringBuilder output = new StringBuilder("CASE ");
+        for (Expression child : children()) {
+            if (child instanceof WhenClause) {
+                output.append(child.toSql());
+            } else {
+                output.append(" ELSE " + child.toSql());
+            }
+        }
+        output.append(" END");
+        return output.toString();
+    }
+
+    @Override
+    public Expression withChildren(List<Expression> children) {
+        Preconditions.checkArgument(children.size() >= 1);
+        List<WhenClause> whenClauseList = new ArrayList<>();
+        Expression defaultValue = null;
+        for (int i = 0; i < children.size(); i++) {
+            if (children.get(i) instanceof WhenClause) {
+                whenClauseList.add((WhenClause) children.get(i));
+            } else if (children.size() - 1 == i) {
+                defaultValue = children.get(i);
+            } else {
+                throw new IllegalArgumentException("The children format needs to be [WhenClause*, DefaultValue+]");
+            }
+        }
+        if (defaultValue == null) {
+            return new CaseWhen(whenClauseList);
+        }
+        return new CaseWhen(whenClauseList, defaultValue);
+    }
+}
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/ExpressionType.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/ExpressionType.java
index 5387203243..9109f30f0c 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/ExpressionType.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/ExpressionType.java
@@ -55,5 +55,7 @@ public enum ExpressionType {
     BITXOR,
     BITNOT,
     FACTORIAL,
-    FUNCTION_CALL
+    FUNCTION_CALL,
+    CASE,  // CaseWhen
+    WHEN_CLAUSE  // WhenClause, a single branch of a CaseWhen
 }
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/WhenClause.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/WhenClause.java
new file mode 100644
index 0000000000..d397a77048
--- /dev/null
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/WhenClause.java
@@ -0,0 +1,102 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.trees.expressions;
+
+import org.apache.doris.nereids.exceptions.UnboundException;
+import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
+import org.apache.doris.nereids.types.DataType;
+
+import com.google.common.base.Preconditions;
+
+import java.util.List;
+import java.util.Objects;
+
+/**
+ * Captures a single {@code WHEN condition THEN result} branch of a {@link CaseWhen}.
+ * The left child is the condition and the right child is the result.
+ */
+public class WhenClause extends Expression implements BinaryExpression {
+    /**
+     * Creates a branch that yields {@code result} when {@code operand} holds.
+     *
+     * @param operand the WHEN condition (left child)
+     * @param result the THEN result (right child)
+     */
+    public WhenClause(Expression operand, Expression result) {
+        super(ExpressionType.WHEN_CLAUSE, operand, result);
+    }
+
+    @Override
+    public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
+        return visitor.visitWhenClause(this, context);
+    }
+
+    /**
+     * The branch has the data type of its THEN result.
+     */
+    @Override
+    public DataType getDataType() {
+        return right().getDataType();
+    }
+
+    /**
+     * The branch is nullable exactly when its THEN result is nullable.
+     */
+    @Override
+    public boolean nullable() throws UnboundException {
+        return right().nullable();
+    }
+
+    @Override
+    public Expression withChildren(List<Expression> children) {
+        Preconditions.checkArgument(children.size() == 2);
+        Expression condition = children.get(0);
+        Expression result = children.get(1);
+        return new WhenClause(condition, result);
+    }
+
+    @Override
+    public String toSql() {
+        String condition = left().toSql();
+        String result = right().toSql();
+        return "WHEN " + condition + " THEN " + result;
+    }
+
+    @Override
+    public String toString() {
+        return toSql();
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) {
+            return true;
+        }
+        if (o == null || getClass() != o.getClass() || !super.equals(o)) {
+            return false;
+        }
+        WhenClause that = (WhenClause) o;
+        return Objects.equals(left(), that.left())
+                && Objects.equals(right(), that.right());
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(left(), right());
+    }
+}
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/ExpressionVisitor.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/ExpressionVisitor.java
index 7286812e82..1788abdb39 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/ExpressionVisitor.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/ExpressionVisitor.java
@@ -26,6 +26,7 @@ import org.apache.doris.nereids.trees.expressions.Alias;
 import org.apache.doris.nereids.trees.expressions.Arithmetic;
 import org.apache.doris.nereids.trees.expressions.Between;
 import org.apache.doris.nereids.trees.expressions.BooleanLiteral;
+import org.apache.doris.nereids.trees.expressions.CaseWhen;
 import org.apache.doris.nereids.trees.expressions.ComparisonPredicate;
 import org.apache.doris.nereids.trees.expressions.CompoundPredicate;
 import org.apache.doris.nereids.trees.expressions.Divide;
@@ -51,6 +52,7 @@ import org.apache.doris.nereids.trees.expressions.SlotReference;
 import org.apache.doris.nereids.trees.expressions.StringLiteral;
 import org.apache.doris.nereids.trees.expressions.StringRegexPredicate;
 import org.apache.doris.nereids.trees.expressions.Subtract;
+import org.apache.doris.nereids.trees.expressions.WhenClause;
 import org.apache.doris.nereids.trees.expressions.functions.AggregateFunction;
 import org.apache.doris.nereids.trees.expressions.functions.BoundFunction;
 
@@ -185,6 +187,20 @@ public abstract class ExpressionVisitor<R, C> {
         return visitArithmetic(mod, context);
     }
 
+    /**
+     * Visits a CASE expression; unless overridden, delegates to the generic {@code visit}.
+     */
+    public R visitCaseWhen(CaseWhen caseWhen, C context) {
+        return visit(caseWhen, context);
+    }
+
+    /**
+     * Visits a single WHEN ... THEN ... branch; unless overridden, delegates to the generic {@code visit}.
+     */
+    public R visitWhenClause(WhenClause whenClause, C context) {
+        return visit(whenClause, context);
+    }
+
     /* ********************************************************************************************
      * Unbound expressions
      * ********************************************************************************************/
diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/expressions/ExpressionParserTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/expressions/ExpressionParserTest.java
index 1014662cbb..7b238fa2de 100644
--- a/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/expressions/ExpressionParserTest.java
+++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/expressions/ExpressionParserTest.java
@@ -137,4 +137,22 @@ public class ExpressionParserTest {
         String sort1 = "select a from test order by 1";
         assertSql(sort1);
     }
+
+    @Test
+    public void testCaseWhen() throws Exception {
+        // simple CASE: the operand is compared with each WHEN value
+        String caseWhen = "select case a when 1 then 2 else 3 end from test";
+        assertSql(caseWhen);
+
+        // searched CASE: each WHEN holds a predicate
+        String caseWhen2 = "select case when a = 1 then 2 else 3 end from test";
+        assertSql(caseWhen2);
+
+        // multiple WHEN branches without an ELSE branch, in both forms
+        String caseWhen3 = "select case a when 1 then 2 when 3 then 4 end from test";
+        assertSql(caseWhen3);
+
+        String caseWhen4 = "select case when a = 1 then 2 when a > 1 then 4 end from test";
+        assertSql(caseWhen4);
+    }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@doris.apache.org
For additional commands, e-mail: commits-help@doris.apache.org