You are viewing a plain text version of this content. (The link to the canonical version was not preserved in this plain-text copy.)
Posted to commits@doris.apache.org by yi...@apache.org on 2022/12/02 02:13:17 UTC

[doris] branch master updated: [refactor](datev2) refine function expr for datev2 (#14697)

This is an automated email from the ASF dual-hosted git repository.

yiguolei pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git


The following commit(s) were added to refs/heads/master by this push:
     new e9799fab09 [refactor](datev2) refine function expr for datev2 (#14697)
e9799fab09 is described below

commit e9799fab09cae98b64561e74a36bb8b706790d0d
Author: Gabriel <ga...@gmail.com>
AuthorDate: Fri Dec 2 10:13:11 2022 +0800

    [refactor](datev2) refine function expr for datev2 (#14697)
    
    * [refactor](datev2) refine function expr for datev2
    
    * update
---
 .../apache/doris/analysis/FunctionCallExpr.java    | 117 +++++++++------------
 .../java/org/apache/doris/catalog/Function.java    |   4 -
 2 files changed, 47 insertions(+), 74 deletions(-)

diff --git a/fe/fe-core/src/main/java/org/apache/doris/analysis/FunctionCallExpr.java b/fe/fe-core/src/main/java/org/apache/doris/analysis/FunctionCallExpr.java
index 7a6a500666..0f3618ad4c 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/analysis/FunctionCallExpr.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/FunctionCallExpr.java
@@ -72,44 +72,56 @@ public class FunctionCallExpr extends Expr {
             new ImmutableSortedSet.Builder(String.CASE_INSENSITIVE_ORDER)
                     .add("stddev").add("stddev_val").add("stddev_samp").add("stddev_pop")
                     .add("variance").add("variance_pop").add("variance_pop").add("var_samp").add("var_pop").build();
-    public static final Map<String, java.util.function.Function<Type[], Type>> DECIMAL_INFER_RULE;
-    public static final java.util.function.Function<Type[], Type> DEFAULT_DECIMAL_INFER_RULE;
+    public static final Map<String, java.util.function.BiFunction<Type[], Type, Type>> PRECISION_INFER_RULE;
+    public static final java.util.function.BiFunction<Type[], Type, Type> DEFAULT_PRECISION_INFER_RULE;
 
     static {
-        java.util.function.Function<Type[], Type> sumRule = (com.google.common.base.Function<Type[], Type>) type -> {
-            Preconditions.checkArgument(type != null && type.length > 0);
-            if (type[0].isDecimalV3()) {
+        java.util.function.BiFunction<Type[], Type, Type> sumRule = (childrenType, returnType) -> {
+            Preconditions.checkArgument(childrenType != null && childrenType.length > 0);
+            if (childrenType[0].isDecimalV3()) {
                 return ScalarType.createDecimalV3Type(ScalarType.MAX_DECIMAL128_PRECISION,
-                        ((ScalarType) type[0]).getScalarScale());
+                        ((ScalarType) childrenType[0]).getScalarScale());
             } else {
-                return type[0];
+                return returnType;
             }
         };
-        DEFAULT_DECIMAL_INFER_RULE = (com.google.common.base.Function<Type[], Type>) type -> {
-            Preconditions.checkArgument(type != null && type.length > 0);
-            return type[0];
+        DEFAULT_PRECISION_INFER_RULE = (childrenType, returnType) -> {
+            if (childrenType != null && childrenType.length > 0
+                    && childrenType[0].isDecimalV3() && returnType.isDecimalV3()) {
+                return childrenType[0];
+            } else if (childrenType != null && childrenType.length > 0 && childrenType[0].isDatetimeV2()
+                    && returnType.isDatetimeV2()) {
+                return childrenType[0];
+            } else {
+                return returnType;
+            }
         };
-        DECIMAL_INFER_RULE = new HashMap<>();
-        DECIMAL_INFER_RULE.put("sum", sumRule);
-        DECIMAL_INFER_RULE.put("multi_distinct_sum", sumRule);
-        DECIMAL_INFER_RULE.put("avg", (com.google.common.base.Function<Type[], Type>) type -> {
+        PRECISION_INFER_RULE = new HashMap<>();
+        PRECISION_INFER_RULE.put("sum", sumRule);
+        PRECISION_INFER_RULE.put("multi_distinct_sum", sumRule);
+        PRECISION_INFER_RULE.put("avg", (childrenType, returnType) -> {
             // TODO: how to set scale?
-            Preconditions.checkArgument(type != null && type.length > 0);
-            if (type[0].isDecimalV3()) {
+            Preconditions.checkArgument(childrenType != null && childrenType.length > 0);
+            if (childrenType[0].isDecimalV3()) {
                 return ScalarType.createDecimalV3Type(ScalarType.MAX_DECIMAL128_PRECISION,
-                        ((ScalarType) type[0]).getScalarScale());
+                        ((ScalarType) childrenType[0]).getScalarScale());
             } else {
-                return type[0];
+                return returnType;
             }
         });
-        DECIMAL_INFER_RULE.put("if", (com.google.common.base.Function<Type[], Type>) type -> {
-            Preconditions.checkArgument(type != null && type.length == 3);
-            if (type[1].isDecimalV3() && type[2].isDecimalV3()) {
+        PRECISION_INFER_RULE.put("if", (childrenType, returnType) -> {
+            Preconditions.checkArgument(childrenType != null && childrenType.length == 3);
+            if (childrenType[1].isDecimalV3() && childrenType[2].isDecimalV3()) {
                 return ScalarType.createDecimalV3Type(
-                        Math.max(((ScalarType) type[1]).decimalPrecision(), ((ScalarType) type[2]).decimalPrecision()),
-                        Math.max(((ScalarType) type[1]).decimalScale(), ((ScalarType) type[2]).decimalScale()));
+                        Math.max(((ScalarType) childrenType[1]).decimalPrecision(),
+                                ((ScalarType) childrenType[2]).decimalPrecision()),
+                        Math.max(((ScalarType) childrenType[1]).decimalScale(),
+                                ((ScalarType) childrenType[2]).decimalScale()));
+            } else if (childrenType[1].isDatetimeV2() && childrenType[2].isDatetimeV2()) {
+                return ((ScalarType) childrenType[1]).decimalScale() > ((ScalarType) childrenType[2]).decimalScale()
+                        ? childrenType[1] : childrenType[2];
             } else {
-                return type[0];
+                return returnType;
             }
         });
     }
@@ -1081,11 +1093,6 @@ public class FunctionCallExpr extends Expr {
             }
             fn = getBuiltinFunction(fnName.getFunction(), childTypes,
                     Function.CompareMode.IS_NONSTRICT_SUPERTYPE_OF);
-            if (fn != null && fn.getArgs()[2].isDatetime() && childTypes[2].isDatetimeV2()) {
-                fn.setArgType(childTypes[2], 2);
-            } else if (fn != null && fn.getArgs()[2].isDatetime() && childTypes[2].isDateV2()) {
-                fn.setArgType(ScalarType.DATETIMEV2, 2);
-            }
             if (fn != null && childTypes[2].isDate()) {
                 // cast date to datetime
                 uncheckedCastChild(ScalarType.DATETIME, 2);
@@ -1139,18 +1146,6 @@ public class FunctionCallExpr extends Expr {
             }
             fn = getBuiltinFunction(fnName.getFunction(), childTypes,
                 Function.CompareMode.IS_NONSTRICT_SUPERTYPE_OF);
-            if (fn != null && fn.getArgs()[1].isDatetime() && childTypes[1].isDatetimeV2()) {
-                fn.setArgType(childTypes[1], 1);
-            } else if (fn != null && fn.getArgs()[1].isDatetime() && childTypes[1].isDateV2()) {
-                fn.setArgType(ScalarType.DATETIMEV2, 1);
-            }
-            if (fn != null && childTypes[1].isDate()) {
-                // cast date to datetime
-                uncheckedCastChild(ScalarType.DATETIME, 1);
-            } else if (fn != null && childTypes[1].isDateV2()) {
-                // cast date to datetime
-                uncheckedCastChild(ScalarType.DATETIMEV2, 1);
-            }
         } else if (fnName.getFunction().equalsIgnoreCase("if")) {
             Type[] childTypes = collectChildReturnTypes();
             Type assignmentCompatibleType = ScalarType.getAssignmentCompatibleType(childTypes[1], childTypes[2], true);
@@ -1229,8 +1224,6 @@ public class FunctionCallExpr extends Expr {
             fn.setReturnType(new ArrayType(getChild(0).type));
         }
 
-        applyAutoTypeConversionForDatetimeV2();
-
         if (fnName.getFunction().equalsIgnoreCase("from_unixtime")
                 || fnName.getFunction().equalsIgnoreCase("date_format")) {
             // if has only one child, it has default time format: yyyy-MM-dd HH:mm:ss.SSSSSS
@@ -1342,6 +1335,13 @@ public class FunctionCallExpr extends Expr {
             } else {
                 this.type = ScalarType.getDefaultDateType(Type.DATETIME);
             }
+        } else if (TIME_FUNCTIONS_WITH_PRECISION.contains(fnName.getFunction().toLowerCase())
+                && fn.getReturnType().isDatetimeV2()) {
+            if (children.size() == 1 && children.get(0) instanceof IntLiteral) {
+                this.type = ScalarType.createDatetimeV2Type((int) ((IntLiteral) children.get(0)).getLongValue());
+            } else if (children.size() == 1) {
+                this.type = ScalarType.createDatetimeV2Type(6);
+            }
         } else {
             this.type = fn.getReturnType();
         }
@@ -1383,39 +1383,16 @@ public class FunctionCallExpr extends Expr {
             fn.setReturnType(Type.MAX_DECIMALV2_TYPE);
         }
 
-        if (this.type.isDecimalV3()) {
+        if (this.type.isDecimalV3() || (this.type.isDatetimeV2()
+                && !TIME_FUNCTIONS_WITH_PRECISION.contains(fnName.getFunction().toLowerCase()))) {
             // TODO(gabriel): If type exceeds max precision of DECIMALV3, we should change it to a double function
-            this.type = DECIMAL_INFER_RULE.getOrDefault(fnName.getFunction(), DEFAULT_DECIMAL_INFER_RULE)
-                    .apply(collectChildReturnTypes());
+            this.type = PRECISION_INFER_RULE.getOrDefault(fnName.getFunction(), DEFAULT_PRECISION_INFER_RULE)
+                    .apply(collectChildReturnTypes(), this.type);
         }
         // rewrite return type if is nested type function
         analyzeNestedFunction();
     }
 
-    private void applyAutoTypeConversionForDatetimeV2() {
-        // Rule1: Now we treat datetimev2 with different precisions as different types and we only register functions
-        // for datetimev2(0). So we must apply an automatic type conversion from datetimev2(0) to the real type.
-        if (fn.getArgs().length == children.size() && fn.getArgs().length > 0) {
-            if (fn.getArgs()[0].isDatetimeV2() && children.get(0).getType().isDatetimeV2()) {
-                fn.setArgType(children.get(0).getType(), 0);
-                if (fn.getReturnType().isDatetimeV2()) {
-                    fn.setReturnType(children.get(0).getType());
-                }
-            }
-        }
-
-        // Rule2: For functions in TIME_FUNCTIONS_WITH_PRECISION, we can't figure out which function should be use when
-        // searching in FunctionSet. So we adjust the return type by hand here.
-        if (TIME_FUNCTIONS_WITH_PRECISION.contains(fnName.getFunction().toLowerCase())
-                && fn != null && fn.getReturnType().isDatetimeV2()) {
-            if (children.size() == 1 && children.get(0) instanceof IntLiteral) {
-                fn.setReturnType(ScalarType.createDatetimeV2Type((int) ((IntLiteral) children.get(0)).getLongValue()));
-            } else if (children.size() == 1) {
-                fn.setReturnType(ScalarType.createDatetimeV2Type(6));
-            }
-        }
-    }
-
     // if return type is nested type, need to be determined the sub-element type
     private void analyzeNestedFunction() {
         // array
diff --git a/fe/fe-core/src/main/java/org/apache/doris/catalog/Function.java b/fe/fe-core/src/main/java/org/apache/doris/catalog/Function.java
index ac9ebec41e..f1a608a9b8 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/catalog/Function.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/catalog/Function.java
@@ -183,10 +183,6 @@ public class Function implements Writable {
         this.retType = type;
     }
 
-    public void setArgType(Type type, int i) {
-        argTypes[i] = type;
-    }
-
     public Type[] getArgs() {
         return argTypes;
     }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@doris.apache.org
For additional commands, e-mail: commits-help@doris.apache.org