You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@doris.apache.org by mo...@apache.org on 2023/01/03 11:09:57 UTC

[doris] branch master updated: [fix](Nereids) get datatype for binary arithmetic (#15548)

This is an automated email from the ASF dual-hosted git repository.

morrysnow pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git


The following commit(s) were added to refs/heads/master by this push:
     new a365486a25 [fix](Nereids) get datatype for binary arithmetic (#15548)
a365486a25 is described below

commit a365486a256055755de45baf1a015ff9bf5d526e
Author: morrySnow <10...@users.noreply.github.com>
AuthorDate: Tue Jan 3 19:09:48 2023 +0800

    [fix](Nereids) get datatype for binary arithmetic (#15548)
    
    This is just a temporary fix for binary arithmetic. Next, we will refactor the TypeCoercion rule to make its behavior exactly the same as that of the legacy planner.
---
 .../doris/nereids/rules/analysis/BindFunction.java | 58 +++++++++++++++-------
 .../trees/expressions/BinaryArithmetic.java        | 12 ++++-
 .../org/apache/doris/common/ExceptionChecker.java  |  2 +-
 3 files changed, 52 insertions(+), 20 deletions(-)

diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindFunction.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindFunction.java
index 387409753c..7878577322 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindFunction.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindFunction.java
@@ -27,6 +27,9 @@ import org.apache.doris.nereids.exceptions.AnalysisException;
 import org.apache.doris.nereids.properties.OrderKey;
 import org.apache.doris.nereids.rules.Rule;
 import org.apache.doris.nereids.rules.RuleType;
+import org.apache.doris.nereids.rules.expression.rewrite.ExpressionRewriteContext;
+import org.apache.doris.nereids.rules.expression.rewrite.rules.CharacterLiteralTypeCoercion;
+import org.apache.doris.nereids.rules.expression.rewrite.rules.TypeCoercion;
 import org.apache.doris.nereids.trees.expressions.Expression;
 import org.apache.doris.nereids.trees.expressions.NamedExpression;
 import org.apache.doris.nereids.trees.expressions.TVFProperties;
@@ -47,6 +50,7 @@ import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
 import org.apache.doris.nereids.trees.plans.logical.LogicalRepeat;
 import org.apache.doris.nereids.trees.plans.logical.LogicalSort;
 import org.apache.doris.nereids.trees.plans.logical.LogicalTVFRelation;
+import org.apache.doris.qe.ConnectContext;
 
 import com.google.common.collect.ImmutableList;
 
@@ -66,7 +70,7 @@ public class BindFunction implements AnalysisRuleFactory {
                 logicalOneRowRelation().thenApply(ctx -> {
                     LogicalOneRowRelation oneRowRelation = ctx.root;
                     List<NamedExpression> projects = oneRowRelation.getProjects();
-                    List<NamedExpression> boundProjects = bind(projects, ctx.connectContext.getEnv());
+                    List<NamedExpression> boundProjects = bindAndTypeCoercion(projects, ctx.connectContext);
                     if (projects.equals(boundProjects)) {
                         return oneRowRelation;
                     }
@@ -76,15 +80,18 @@ public class BindFunction implements AnalysisRuleFactory {
             RuleType.BINDING_PROJECT_FUNCTION.build(
                 logicalProject().thenApply(ctx -> {
                     LogicalProject<GroupPlan> project = ctx.root;
-                    List<NamedExpression> boundExpr = bind(project.getProjects(), ctx.connectContext.getEnv());
+                    List<NamedExpression> boundExpr = bindAndTypeCoercion(project.getProjects(),
+                            ctx.connectContext);
                     return new LogicalProject<>(boundExpr, project.child(), project.isDistinct());
                 })
             ),
             RuleType.BINDING_AGGREGATE_FUNCTION.build(
                 logicalAggregate().thenApply(ctx -> {
                     LogicalAggregate<GroupPlan> agg = ctx.root;
-                    List<Expression> groupBy = bind(agg.getGroupByExpressions(), ctx.connectContext.getEnv());
-                    List<NamedExpression> output = bind(agg.getOutputExpressions(), ctx.connectContext.getEnv());
+                    List<Expression> groupBy = bindAndTypeCoercion(agg.getGroupByExpressions(),
+                            ctx.connectContext);
+                    List<NamedExpression> output = bindAndTypeCoercion(agg.getOutputExpressions(),
+                            ctx.connectContext);
                     return agg.withGroupByAndOutput(groupBy, output);
                 })
             ),
@@ -93,23 +100,24 @@ public class BindFunction implements AnalysisRuleFactory {
                     LogicalRepeat<GroupPlan> repeat = ctx.root;
                     List<List<Expression>> groupingSets = repeat.getGroupingSets()
                             .stream()
-                            .map(groupingSet -> bind(groupingSet, ctx.connectContext.getEnv()))
+                            .map(groupingSet -> bindAndTypeCoercion(groupingSet, ctx.connectContext))
                             .collect(ImmutableList.toImmutableList());
-                    List<NamedExpression> output = bind(repeat.getOutputExpressions(), ctx.connectContext.getEnv());
+                    List<NamedExpression> output = bindAndTypeCoercion(repeat.getOutputExpressions(),
+                            ctx.connectContext);
                     return repeat.withGroupSetsAndOutput(groupingSets, output);
                 })
             ),
             RuleType.BINDING_FILTER_FUNCTION.build(
                logicalFilter().thenApply(ctx -> {
                    LogicalFilter<GroupPlan> filter = ctx.root;
-                   Set<Expression> conjuncts = bind(filter.getConjuncts(), ctx.connectContext.getEnv());
+                   Set<Expression> conjuncts = bindAndTypeCoercion(filter.getConjuncts(), ctx.connectContext);
                    return new LogicalFilter<>(conjuncts, filter.child());
                })
             ),
             RuleType.BINDING_HAVING_FUNCTION.build(
                 logicalHaving().thenApply(ctx -> {
                     LogicalHaving<GroupPlan> having = ctx.root;
-                    Set<Expression> conjuncts = bind(having.getConjuncts(), ctx.connectContext.getEnv());
+                    Set<Expression> conjuncts = bindAndTypeCoercion(having.getConjuncts(), ctx.connectContext);
                     return new LogicalHaving<>(conjuncts, having.child());
                 })
             ),
@@ -118,10 +126,14 @@ public class BindFunction implements AnalysisRuleFactory {
                     LogicalSort<GroupPlan> sort = ctx.root;
                     List<OrderKey> orderKeys = sort.getOrderKeys().stream()
                             .map(orderKey -> new OrderKey(
-                                    FunctionBinder.INSTANCE.bind(orderKey.getExpr(), ctx.connectContext.getEnv()),
-                                    orderKey.isAsc(),
-                                    orderKey.isNullFirst()
-                            ))
+                                        bindAndTypeCoercion(orderKey.getExpr(),
+                                                ctx.connectContext.getEnv(),
+                                                new ExpressionRewriteContext(ctx.connectContext)
+                                                ),
+                                        orderKey.isAsc(),
+                                        orderKey.isNullFirst())
+
+                            )
                             .collect(ImmutableList.toImmutableList());
                     return new LogicalSort<>(orderKeys, sort.child());
                 })
@@ -129,8 +141,10 @@ public class BindFunction implements AnalysisRuleFactory {
             RuleType.BINDING_JOIN_FUNCTION.build(
                 logicalJoin().thenApply(ctx -> {
                     LogicalJoin<GroupPlan, GroupPlan> join = ctx.root;
-                    List<Expression> hashConjuncts = bind(join.getHashJoinConjuncts(), ctx.connectContext.getEnv());
-                    List<Expression> otherConjuncts = bind(join.getOtherJoinConjuncts(), ctx.connectContext.getEnv());
+                    List<Expression> hashConjuncts = bindAndTypeCoercion(join.getHashJoinConjuncts(),
+                            ctx.connectContext);
+                    List<Expression> otherConjuncts = bindAndTypeCoercion(join.getOtherJoinConjuncts(),
+                            ctx.connectContext);
                     return new LogicalJoin<>(join.getJoinType(), hashConjuncts, otherConjuncts,
                             join.getHint(),
                             join.left(), join.right());
@@ -145,15 +159,23 @@ public class BindFunction implements AnalysisRuleFactory {
         );
     }
 
-    private <E extends Expression> List<E> bind(List<? extends E> exprList, Env env) {
+    private <E extends Expression> List<E> bindAndTypeCoercion(List<? extends E> exprList, ConnectContext ctx) {
+        ExpressionRewriteContext rewriteContext = new ExpressionRewriteContext(ctx);
         return exprList.stream()
-            .map(expr -> FunctionBinder.INSTANCE.bind(expr, env))
+            .map(expr -> bindAndTypeCoercion(expr, ctx.getEnv(), rewriteContext))
             .collect(Collectors.toList());
     }
 
-    private <E extends Expression> Set<E> bind(Set<? extends E> exprSet, Env env) {
+    private <E extends Expression> E bindAndTypeCoercion(E expr, Env env, ExpressionRewriteContext ctx) {
+        expr = FunctionBinder.INSTANCE.bind(expr, env);
+        expr = (E) CharacterLiteralTypeCoercion.INSTANCE.rewrite(expr, ctx);
+        return (E) TypeCoercion.INSTANCE.rewrite(expr, null);
+    }
+
+    private <E extends Expression> Set<E> bindAndTypeCoercion(Set<? extends E> exprSet, ConnectContext ctx) {
+        ExpressionRewriteContext rewriteContext = new ExpressionRewriteContext(ctx);
         return exprSet.stream()
-                .map(expr -> FunctionBinder.INSTANCE.bind(expr, env))
+                .map(expr -> bindAndTypeCoercion(expr, ctx.getEnv(), rewriteContext))
                 .collect(Collectors.toSet());
     }
 
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/BinaryArithmetic.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/BinaryArithmetic.java
index ec99e5e2d2..ec45b4f707 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/BinaryArithmetic.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/BinaryArithmetic.java
@@ -22,6 +22,7 @@ import org.apache.doris.nereids.exceptions.UnboundException;
 import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable;
 import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
 import org.apache.doris.nereids.types.DataType;
+import org.apache.doris.nereids.util.TypeCoercionUtils;
 
 /**
  * binary arithmetic operator. Such as +, -, *, /.
@@ -41,7 +42,16 @@ public abstract class BinaryArithmetic extends BinaryOperator implements Propaga
 
     @Override
     public DataType getDataType() throws UnboundException {
-        return left().getDataType();
+        if (left().getDataType().equals(right().getDataType())) {
+            return left().getDataType();
+        } else {
+            try {
+                return TypeCoercionUtils.findCommonNumericsType(left().getDataType(), right().getDataType());
+            } catch (Exception e) {
+                return TypeCoercionUtils.findTightestCommonType(left().getDataType(), right().getDataType())
+                        .orElseGet(() -> left().getDataType());
+            }
+        }
     }
 
     @Override
diff --git a/fe/fe-core/src/test/java/org/apache/doris/common/ExceptionChecker.java b/fe/fe-core/src/test/java/org/apache/doris/common/ExceptionChecker.java
index 42e7dae87d..26d5d1dc47 100644
--- a/fe/fe-core/src/test/java/org/apache/doris/common/ExceptionChecker.java
+++ b/fe/fe-core/src/test/java/org/apache/doris/common/ExceptionChecker.java
@@ -66,13 +66,13 @@ public class ExceptionChecker {
         try {
             runnable.run();
         } catch (Throwable e) {
-            e.printStackTrace();
             if (expectedType.isInstance(e)) {
                 if (!Strings.isNullOrEmpty(exceptionMsg)) {
                     if (!e.getMessage().contains(exceptionMsg)) {
                         AssertionFailedError assertion = new AssertionFailedError(
                                 "expected msg: " + exceptionMsg + ", actual: " + e.getMessage());
                         assertion.initCause(e);
+                        assertion.printStackTrace();
                         throw assertion;
                     }
                 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@doris.apache.org
For additional commands, e-mail: commits-help@doris.apache.org