You are viewing a plain text version of this content; the link to its canonical version was not preserved in this copy.
Posted to commits@doris.apache.org by mo...@apache.org on 2023/01/19 07:30:32 UTC

[doris] branch master updated: [fix](nereids) fix bug in CaseWhen.getDataType and add some missing case for findTightestCommonType (#15776)

This is an automated email from the ASF dual-hosted git repository.

morrysnow pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git


The following commit(s) were added to refs/heads/master by this push:
     new 0144c51ddb [fix](nereids) fix bug in CaseWhen.getDataType and add some missing case for findTightestCommonType (#15776)
0144c51ddb is described below

commit 0144c51ddbcc9eac59a50d5e628d4e5ed111f35d
Author: minghong <en...@gmail.com>
AuthorDate: Thu Jan 19 15:30:25 2023 +0800

    [fix](nereids) fix bug in CaseWhen.getDataType and add some missing case for findTightestCommonType (#15776)
---
 .../doris/nereids/trees/expressions/CaseWhen.java    |  9 ++++++++-
 .../apache/doris/nereids/util/TypeCoercionUtils.java | 20 ++++++++++++++++++++
 .../doris/nereids/util/TypeCoercionUtilsTest.java    |  3 ++-
 .../data/nereids_syntax_p0/test_query_between.out    |  1 +
 .../suites/nereids_syntax_p0/explain.groovy          | 11 +++++++++++
 .../nereids_syntax_p0/test_query_between.groovy      |  4 ++--
 6 files changed, 44 insertions(+), 4 deletions(-)

diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/CaseWhen.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/CaseWhen.java
index 94015ec5ff..2ebdc5cecd 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/CaseWhen.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/CaseWhen.java
@@ -20,6 +20,7 @@ package org.apache.doris.nereids.trees.expressions;
 import org.apache.doris.nereids.exceptions.UnboundException;
 import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
 import org.apache.doris.nereids.types.DataType;
+import org.apache.doris.nereids.util.TypeCoercionUtils;
 
 import com.google.common.base.Preconditions;
 import com.google.common.collect.ImmutableList;
@@ -78,7 +79,13 @@ public class CaseWhen extends Expression {
 
     @Override
     public DataType getDataType() {
-        return child(0).getDataType();
+        DataType outputType = child(0).getDataType();
+        for (Expression child : children) {
+            DataType tempType = outputType;
+            outputType = TypeCoercionUtils.findTightestCommonType(null,
+                            outputType, child.getDataType()).orElseGet(() -> tempType);
+        }
+        return outputType;
     }
 
     @Override
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/TypeCoercionUtils.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/TypeCoercionUtils.java
index 0c4e70ee00..6ea711efee 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/TypeCoercionUtils.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/TypeCoercionUtils.java
@@ -239,6 +239,26 @@ public class TypeCoercionUtils {
             } else if (left instanceof DateV2Type || right instanceof DateV2Type) {
                 tightestCommonType = DateV2Type.INSTANCE;
             }
+        } else if (left instanceof DoubleType && right instanceof DecimalV2Type
+                || left instanceof DecimalV2Type && right instanceof DoubleType) {
+            tightestCommonType = DoubleType.INSTANCE;
+        } else if (left instanceof DecimalV2Type && right instanceof DecimalV2Type) {
+            tightestCommonType = DecimalV2Type.widerDecimalV2Type((DecimalV2Type) left, (DecimalV2Type) right);
+        } else if (left instanceof FloatType && right instanceof DecimalV2Type
+                || left instanceof DecimalV2Type && right instanceof FloatType) {
+            // TODO: needs refactoring; the operator itself should decide how to upgrade the data type.
+            if (binaryOperator != null) {
+                // for any binary operator (e.g. arithmetic like Float + Decimal), upgrade to Double
+                tightestCommonType = DoubleType.INSTANCE;
+            } else {
+                // for other cases (no binary operator), such as
+                //          case
+                //            when 1=1 then cast(1 as int)
+                //            when 1>1 then cast(1 as float)
+                //            else 0.0 end;
+                // do not upgrade the data type; keep Float
+                tightestCommonType = FloatType.INSTANCE;
+            }
         } else if (canCompareDate(left, right)) {
             if (binaryOperator instanceof BinaryArithmetic) {
                 tightestCommonType = IntegerType.INSTANCE;
diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/util/TypeCoercionUtilsTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/util/TypeCoercionUtilsTest.java
index 16ac52b8ec..c09f6e8797 100644
--- a/fe/fe-core/src/test/java/org/apache/doris/nereids/util/TypeCoercionUtilsTest.java
+++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/util/TypeCoercionUtilsTest.java
@@ -145,7 +145,8 @@ public class TypeCoercionUtilsTest {
         testFindTightestCommonType(BigIntType.INSTANCE, IntegerType.INSTANCE, BigIntType.INSTANCE);
         testFindTightestCommonType(StringType.INSTANCE, StringType.INSTANCE, IntegerType.INSTANCE);
         testFindTightestCommonType(StringType.INSTANCE, IntegerType.INSTANCE, StringType.INSTANCE);
-        testFindTightestCommonType(DoubleType.INSTANCE, DecimalV2Type.SYSTEM_DEFAULT, DecimalV2Type.createDecimalV2Type(2, 1));
+        testFindTightestCommonType(DecimalV2Type.SYSTEM_DEFAULT, DecimalV2Type.SYSTEM_DEFAULT, DecimalV2Type.createDecimalV2Type(2, 1));
+        testFindTightestCommonType(FloatType.INSTANCE, FloatType.INSTANCE, DecimalV2Type.SYSTEM_DEFAULT);
         testFindTightestCommonType(VarcharType.createVarcharType(10), CharType.createCharType(8), CharType.createCharType(10));
         testFindTightestCommonType(VarcharType.createVarcharType(10), VarcharType.createVarcharType(8), VarcharType.createVarcharType(10));
         testFindTightestCommonType(VarcharType.createVarcharType(10), VarcharType.createVarcharType(8), CharType.createCharType(10));
diff --git a/regression-test/data/nereids_syntax_p0/test_query_between.out b/regression-test/data/nereids_syntax_p0/test_query_between.out
index c97f6621c7..258945a391 100644
--- a/regression-test/data/nereids_syntax_p0/test_query_between.out
+++ b/regression-test/data/nereids_syntax_p0/test_query_between.out
@@ -34,6 +34,7 @@ false
 -- !between11 --
 
 -- !between12 --
+6.333
 
 -- !between13 --
 123.123
diff --git a/regression-test/suites/nereids_syntax_p0/explain.groovy b/regression-test/suites/nereids_syntax_p0/explain.groovy
index 734c33fc69..251b490b26 100644
--- a/regression-test/suites/nereids_syntax_p0/explain.groovy
+++ b/regression-test/suites/nereids_syntax_p0/explain.groovy
@@ -53,4 +53,15 @@ suite("nereids_explain") {
         sql("plan with s as (select * from supplier) select * from s as s1, s as s2")
         contains "*LogicalSubQueryAlias"
     }
+
+    explain {
+        sql """
+        verbose 
+        select case 
+            when 1=1 then cast(1 as int) 
+            when 1>1 then cast(1 as float)
+            else 0.0 end;
+            """
+        contains "SlotDescriptor{id=0, col=null, colUniqueId=null, type=FLOAT, nullable=false}"
+    }
 }
diff --git a/regression-test/suites/nereids_syntax_p0/test_query_between.groovy b/regression-test/suites/nereids_syntax_p0/test_query_between.groovy
index 4c5971e3cb..7a76a76da3 100644
--- a/regression-test/suites/nereids_syntax_p0/test_query_between.groovy
+++ b/regression-test/suites/nereids_syntax_p0/test_query_between.groovy
@@ -36,6 +36,6 @@ suite("nereids_test_query_between", "query,p0") {
                 and \"9999-12-31 12:12:12\" order by k1, k2, k3, k4"""
     qt_between11 """select k10 from ${tableName} where k10 between \"2015-04-02\"
                 and \"9999-12-31\" order by k1, k2, k3, k4"""
-    qt_between12 "select k9 from ${tableName} where k9 between -1 and 6.333 order by k1, k2, k3, k4"
-    qt_between13 "select k5 from ${tableName} where k5 between 0 and 1243.5 order by k1, k2, k3, k4"
+    qt_between12 "select k9 from ${tableName} where k9 between -1 and 6.34 order by k1, k2, k3, k4"
+    qt_between13 "select k5 from ${tableName} where k5 between 0 and 1243.6 order by k1, k2, k3, k4"
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@doris.apache.org
For additional commands, e-mail: commits-help@doris.apache.org