You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@doris.apache.org by mo...@apache.org on 2020/02/17 14:13:56 UTC

[incubator-doris] branch master updated: [UDF] Fix bug that UDF can't handle constant null value (#2914)

This is an automated email from the ASF dual-hosted git repository.

morningman pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-doris.git


The following commit(s) were added to refs/heads/master by this push:
     new 0fb52c5  [UDF] Fix bug that UDF can't handle constant null value (#2914)
0fb52c5 is described below

commit 0fb52c514b5191d48079f804f3123571f0e47c5c
Author: Mingyu Chen <mo...@gmail.com>
AuthorDate: Mon Feb 17 22:13:50 2020 +0800

    [UDF] Fix bug that UDF can't handle constant null value (#2914)
    
    This CL modifies `evalExpr()` in ExpressionFunctions so that it no longer replaces a
    `FunctionCallExpr` with a `NullLiteral` when a UDF is called with a null parameter. This fixes
    the problem described in issue #2913.
---
 .../apache/doris/analysis/CreateFunctionStmt.java  |   8 +-
 .../apache/doris/analysis/ExpressionFunctions.java |  38 +++-----
 .../java/org/apache/doris/catalog/Catalog.java     |   4 +
 .../java/org/apache/doris/catalog/Function.java    |  15 ++-
 .../java/org/apache/doris/catalog/FunctionSet.java |  38 ++++++--
 .../java/org/apache/doris/qe/StmtExecutor.java     |   2 +-
 .../org/apache/doris/alter/BatchRollupJobTest.java |   4 +-
 .../apache/doris/catalog/CreateFunctionTest.java   | 105 +++++++++++++++++++++
 .../org/apache/doris/utframe/UtFrameUtils.java     |   6 +-
 gensrc/script/doris_builtins_functions.py          |  14 +++
 gensrc/script/gen_builtins_functions.py            |  11 ++-
 11 files changed, 202 insertions(+), 43 deletions(-)

diff --git a/fe/src/main/java/org/apache/doris/analysis/CreateFunctionStmt.java b/fe/src/main/java/org/apache/doris/analysis/CreateFunctionStmt.java
index 05c742d..a69412d 100644
--- a/fe/src/main/java/org/apache/doris/analysis/CreateFunctionStmt.java
+++ b/fe/src/main/java/org/apache/doris/analysis/CreateFunctionStmt.java
@@ -17,9 +17,6 @@
 
 package org.apache.doris.analysis;
 
-import com.google.common.base.Strings;
-import com.google.common.collect.ImmutableSortedMap;
-import org.apache.commons.codec.binary.Hex;
 import org.apache.doris.catalog.AggregateFunction;
 import org.apache.doris.catalog.Catalog;
 import org.apache.doris.catalog.Function;
@@ -31,6 +28,11 @@ import org.apache.doris.common.UserException;
 import org.apache.doris.mysql.privilege.PrivPredicate;
 import org.apache.doris.qe.ConnectContext;
 
+import com.google.common.base.Strings;
+import com.google.common.collect.ImmutableSortedMap;
+
+import org.apache.commons.codec.binary.Hex;
+
 import java.io.IOException;
 import java.io.InputStream;
 import java.net.URL;
diff --git a/fe/src/main/java/org/apache/doris/analysis/ExpressionFunctions.java b/fe/src/main/java/org/apache/doris/analysis/ExpressionFunctions.java
index e62b05a..ee9a6f5 100644
--- a/fe/src/main/java/org/apache/doris/analysis/ExpressionFunctions.java
+++ b/fe/src/main/java/org/apache/doris/analysis/ExpressionFunctions.java
@@ -17,18 +17,20 @@
 
 package org.apache.doris.analysis;
 
-import com.google.common.base.Joiner;
-import com.google.common.base.Preconditions;
-import com.google.common.collect.ImmutableMultimap;
-import com.google.common.collect.ImmutableSet;
-import com.google.common.collect.Lists;
-import com.google.common.collect.Sets;
+import org.apache.doris.catalog.Catalog;
 import org.apache.doris.catalog.Function;
 import org.apache.doris.catalog.ScalarType;
 import org.apache.doris.catalog.Type;
 import org.apache.doris.common.AnalysisException;
 import org.apache.doris.rewrite.FEFunction;
 import org.apache.doris.rewrite.FEFunctions;
+
+import com.google.common.base.Joiner;
+import com.google.common.base.Preconditions;
+import com.google.common.collect.ImmutableMultimap;
+import com.google.common.collect.Lists;
+import com.google.common.collect.Sets;
+
 import org.apache.logging.log4j.LogManager;
 import org.apache.logging.log4j.Logger;
 
@@ -46,12 +48,6 @@ public enum ExpressionFunctions {
 
     private static final Logger LOG = LogManager.getLogger(ExpressionFunctions.class);
     private ImmutableMultimap<String, FEFunctionInvoker> functions;
-    // For most build-in functions, it will return NullLiteral when params contain NullLiteral.
-    // But a few functions need to handle NullLiteral differently, such as "if". It need to add
-    // an attribute to LiteralExpr to mark null and check the attribute to decide whether to
-    // replace the result with NullLiteral when function finished. It leaves to be realized.
-    // TODO chenhao16.
-    private ImmutableSet<String> nonNullResultWithNullParamFunctions;
 
     private ExpressionFunctions() {
         registerFunctions();
@@ -71,8 +67,13 @@ public enum ExpressionFunctions {
             Function fn = constExpr.getFn();
             
             Preconditions.checkNotNull(fn, "Expr's fn can't be null.");
-            // null
-            if (!nonNullResultWithNullParamFunctions.contains(fn.getFunctionName().getFunction())) {
+            
+            // return NullLiteral directly iff:
+            // 1. Not UDF
+            // 2. Not in NonNullResultWithNullParamFunctions
+            // 3. Has null parameter
+            if (!Catalog.getCurrentCatalog().isNonNullResultWithNullParamFunction(fn.getFunctionName().getFunction())
+                    && !fn.isUdf()) {
                 for (Expr e : constExpr.getChildren()) {
                     if (e instanceof NullLiteral) {
                         return new NullLiteral();
@@ -144,15 +145,6 @@ public enum ExpressionFunctions {
             }
         }
         this.functions = mapBuilder.build();
-
-        // Functions that need to handle null.
-        ImmutableSet.Builder<String> setBuilder =
-                new ImmutableSet.Builder<String>();
-        setBuilder.add("if");
-        setBuilder.add("hll_hash");
-        setBuilder.add("concat_ws");
-        setBuilder.add("ifnull");
-        this.nonNullResultWithNullParamFunctions = setBuilder.build();
     }
 
     public static class FEFunctionInvoker {
diff --git a/fe/src/main/java/org/apache/doris/catalog/Catalog.java b/fe/src/main/java/org/apache/doris/catalog/Catalog.java
index 52850fc..0c73b10 100644
--- a/fe/src/main/java/org/apache/doris/catalog/Catalog.java
+++ b/fe/src/main/java/org/apache/doris/catalog/Catalog.java
@@ -5317,6 +5317,10 @@ public class Catalog {
         return functionSet.getBulitinFunctions();
     }
 
+    public boolean isNonNullResultWithNullParamFunction(String funcName) {
+        return functionSet.isNonNullResultWithNullParamFunctions(funcName);
+    }
+
     /**
      * create cluster
      *
diff --git a/fe/src/main/java/org/apache/doris/catalog/Function.java b/fe/src/main/java/org/apache/doris/catalog/Function.java
index f5f7710..52d36a5 100644
--- a/fe/src/main/java/org/apache/doris/catalog/Function.java
+++ b/fe/src/main/java/org/apache/doris/catalog/Function.java
@@ -19,15 +19,17 @@ package org.apache.doris.catalog;
 
 import static org.apache.doris.common.io.IOUtils.writeOptionString;
 
-import com.google.common.base.Joiner;
-import com.google.common.base.Preconditions;
-import com.google.common.collect.Lists;
 import org.apache.doris.analysis.FunctionName;
 import org.apache.doris.analysis.HdfsURI;
 import org.apache.doris.common.io.Text;
 import org.apache.doris.common.io.Writable;
 import org.apache.doris.thrift.TFunction;
 import org.apache.doris.thrift.TFunctionBinaryType;
+
+import com.google.common.base.Joiner;
+import com.google.common.base.Preconditions;
+import com.google.common.collect.Lists;
+
 import org.apache.logging.log4j.LogManager;
 import org.apache.logging.log4j.Logger;
 
@@ -208,6 +210,13 @@ public class Function implements Writable {
     public void setChecksum(String checksum) { this.checksum = checksum; }
     public String getChecksum() { return checksum; }
 
+    // TODO(cmy): Currently we judge whether it is UDF by whether the 'location' is set.
+    // Maybe we should use a separate variable to identify,
+    // but additional variables need to modify the persistence information.
+    public boolean isUdf() {
+        return location != null;
+    }
+
     // Returns a string with the signature in human readable format:
     // FnName(argtype1, argtyp2).  e.g. Add(int, int)
     public String signatureString() {
diff --git a/fe/src/main/java/org/apache/doris/catalog/FunctionSet.java b/fe/src/main/java/org/apache/doris/catalog/FunctionSet.java
index a5b436a..114c41c 100644
--- a/fe/src/main/java/org/apache/doris/catalog/FunctionSet.java
+++ b/fe/src/main/java/org/apache/doris/catalog/FunctionSet.java
@@ -17,9 +17,6 @@
 
 package org.apache.doris.catalog;
 
-import com.google.common.collect.ImmutableMap;
-import com.google.common.collect.Lists;
-import com.google.common.collect.Maps;
 import org.apache.doris.analysis.ArithmeticExpr;
 import org.apache.doris.analysis.BinaryPredicate;
 import org.apache.doris.analysis.CastExpr;
@@ -27,14 +24,20 @@ import org.apache.doris.analysis.InPredicate;
 import org.apache.doris.analysis.IsNullPredicate;
 import org.apache.doris.analysis.LikePredicate;
 import org.apache.doris.builtins.ScalarBuiltins;
+
+import com.google.common.collect.ImmutableMap;
+import com.google.common.collect.ImmutableSet;
+import com.google.common.collect.Lists;
+import com.google.common.collect.Maps;
+
 import org.apache.logging.log4j.LogManager;
 import org.apache.logging.log4j.Logger;
 
-
 import java.util.ArrayList;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
+import java.util.Set;
 
 public class FunctionSet {
     private static final Logger LOG = LogManager.getLogger(FunctionSet.class);
@@ -46,6 +49,16 @@ public class FunctionSet {
     // FunctionResolutionOrder.
     private final HashMap<String, List<Function>> functions;
 
+    // For most built-in functions, the result is a NullLiteral when the params contain a NullLiteral.
+    // But a few functions need to handle NullLiteral differently, such as "if". This needs
+    // an attribute on LiteralExpr to mark null, which is then checked to decide whether to
+    // replace the result with NullLiteral when the function finishes. This is yet to be implemented.
+    // Functions in this set are defined in `gensrc/script/doris_builtins_functions.py`,
+    // and the set is built automatically.
+
+    // cmy: This does not contain any user defined functions. All UDFs handle null values by themselves.
+    private ImmutableSet<String> nonNullResultWithNullParamFunctions;
+
     public FunctionSet() {
         functions = Maps.newHashMap();
     }
@@ -63,6 +76,18 @@ public class FunctionSet {
         InPredicate.initBuiltins(this);
     }
 
+    public void buildNonNullResultWithNullParamFunction(Set<String> funcNames) {
+        ImmutableSet.Builder<String> setBuilder = new ImmutableSet.Builder<String>();
+        for (String funcName : funcNames) {
+            setBuilder.add(funcName);
+        }
+        this.nonNullResultWithNullParamFunctions = setBuilder.build();
+    }
+
+    public boolean isNonNullResultWithNullParamFunctions(String funcName) {
+        return nonNullResultWithNullParamFunctions.contains(funcName);
+    }
+
     private static final Map<Type, String> MIN_UPDATE_SYMBOL =
         ImmutableMap.<Type, String>builder()
                 .put(Type.BOOLEAN,
@@ -746,8 +771,7 @@ public class FunctionSet {
         return null;
     }
 
-    // Only used
-    public boolean addFunction(Function fn) {
+    private boolean addFunction(Function fn, boolean isBuiltin) {
         // TODO: add this to persistent store
         if (getFunction(fn, Function.CompareMode.IS_INDISTINGUISHABLE) != null) {
             return false;
@@ -791,7 +815,7 @@ public class FunctionSet {
      * Adds a builtin to this database. The function must not already exist.
      */
     public void addBuiltin(Function fn) {
-        addFunction(fn);
+        addFunction(fn, true);
     }
 
     // Populate all the aggregate builtins in the catalog.
diff --git a/fe/src/main/java/org/apache/doris/qe/StmtExecutor.java b/fe/src/main/java/org/apache/doris/qe/StmtExecutor.java
index de9517f..5430c75 100644
--- a/fe/src/main/java/org/apache/doris/qe/StmtExecutor.java
+++ b/fe/src/main/java/org/apache/doris/qe/StmtExecutor.java
@@ -547,7 +547,7 @@ public class StmtExecutor {
 
         coord.exec();
 
-        // if python's MysqlDb get error after sendfields, it can't catch the excpetion
+        // if python's MysqlDb get error after sendfields, it can't catch the exception
         // so We need to send fields after first batch arrived
 
         // send result
diff --git a/fe/src/test/java/org/apache/doris/alter/BatchRollupJobTest.java b/fe/src/test/java/org/apache/doris/alter/BatchRollupJobTest.java
index 1c92db4..a24e307 100644
--- a/fe/src/test/java/org/apache/doris/alter/BatchRollupJobTest.java
+++ b/fe/src/test/java/org/apache/doris/alter/BatchRollupJobTest.java
@@ -40,8 +40,6 @@ public class BatchRollupJobTest {
 
     private static String runningDir = "fe/mocked/BatchRollupJobTest/" + UUID.randomUUID().toString() + "/";
 
-    private static ConnectContext ctx = UtFrameUtils.createDefaultCtx();
-
     @BeforeClass
     public static void setup() throws Exception {
         UtFrameUtils.createMinDorisCluster(runningDir);
@@ -49,7 +47,7 @@ public class BatchRollupJobTest {
 
     @Test
     public void test() throws Exception {
-        System.out.println("xxx");
+        ConnectContext ctx = UtFrameUtils.createDefaultCtx();
         // create database db1
         String createDbStmtStr = "create database db1;";
         CreateDbStmt createDbStmt = (CreateDbStmt) UtFrameUtils.parseAndAnalyzeStmt(createDbStmtStr, ctx);
diff --git a/fe/src/test/java/org/apache/doris/catalog/CreateFunctionTest.java b/fe/src/test/java/org/apache/doris/catalog/CreateFunctionTest.java
new file mode 100644
index 0000000..74c37d7
--- /dev/null
+++ b/fe/src/test/java/org/apache/doris/catalog/CreateFunctionTest.java
@@ -0,0 +1,105 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.catalog;
+
+import org.apache.doris.analysis.CreateDbStmt;
+import org.apache.doris.analysis.CreateFunctionStmt;
+import org.apache.doris.analysis.Expr;
+import org.apache.doris.analysis.FunctionCallExpr;
+import org.apache.doris.common.jmockit.Deencapsulation;
+import org.apache.doris.planner.PlanFragment;
+import org.apache.doris.planner.Planner;
+import org.apache.doris.planner.UnionNode;
+import org.apache.doris.qe.ConnectContext;
+import org.apache.doris.qe.QueryState;
+import org.apache.doris.qe.StmtExecutor;
+import org.apache.doris.utframe.UtFrameUtils;
+
+import org.junit.AfterClass;
+import org.junit.Assert;
+import org.junit.BeforeClass;
+import org.junit.Test;
+
+import java.io.File;
+import java.util.List;
+import java.util.UUID;
+
+/*
+ * Author: Chenmingyu
+ * Date: Feb 16, 2020
+ */
+
+public class CreateFunctionTest {
+
+    private static String runningDir = "fe/mocked/CreateFunctionTest/" + UUID.randomUUID().toString() + "/";
+
+    @BeforeClass
+    public static void setup() throws Exception {
+        UtFrameUtils.createMinDorisCluster(runningDir);
+    }
+
+    @AfterClass
+    public static void teardown() {
+        File file = new File("fe/mocked/CreateFunctionTest/");
+        file.delete();
+    }
+
+    @Test
+    public void test() throws Exception {
+        ConnectContext ctx = UtFrameUtils.createDefaultCtx();
+
+        // create database db1
+        String createDbStmtStr = "create database db1;";
+        CreateDbStmt createDbStmt = (CreateDbStmt) UtFrameUtils.parseAndAnalyzeStmt(createDbStmtStr, ctx);
+        Catalog.getCurrentCatalog().createDb(createDbStmt);
+        System.out.println(Catalog.getCurrentCatalog().getDbNames());
+
+        Database db = Catalog.getCurrentCatalog().getDb("default_cluster:db1");
+        Assert.assertNotNull(db);
+
+        String createFuncStr = "create function db1.my_add(VARCHAR(1024)) RETURNS BOOLEAN properties\n" +
+                "(\n" +
+                "\"symbol\" =  \"_ZN9doris_udf6AddUdfEPNS_15FunctionContextERKNS_9StringValE\",\n" +
+                "\"prepare_fn\" = \"_ZN9doris_udf13AddUdfPrepareEPNS_15FunctionContextENS0_18FunctionStateScopeE\",\n" +
+                "\"close_fn\" = \"_ZN9doris_udf11AddUdfCloseEPNS_15FunctionContextENS0_18FunctionStateScopeE\",\n" +
+                "\"object_file\" = \"http://nmg01-inf-dorishb00.nmg01.baidu.com:8456/libcmy_udf.so\"\n" +
+                ");";
+        
+        CreateFunctionStmt createFunctionStmt = (CreateFunctionStmt) UtFrameUtils.parseAndAnalyzeStmt(createFuncStr, ctx);
+        Catalog.getCurrentCatalog().createFunction(createFunctionStmt);
+
+        List<Function> functions = db.getFunctions();
+        Assert.assertEquals(1, functions.size());
+        Assert.assertTrue(functions.get(0).isUdf());
+
+        String queryStr = "select db1.my_add(null)";
+        ctx.getState().reset();
+        StmtExecutor stmtExecutor = new StmtExecutor(ctx, queryStr);
+        stmtExecutor.execute();
+        Assert.assertNotEquals(QueryState.MysqlStateType.ERR, ctx.getState().getStateType());
+        Planner planner = stmtExecutor.planner();
+        Assert.assertEquals(1, planner.getFragments().size());
+        PlanFragment fragment = planner.getFragments().get(0);
+        Assert.assertTrue(fragment.getPlanRoot() instanceof UnionNode);
+        UnionNode unionNode =  (UnionNode)fragment.getPlanRoot();
+        List<List<Expr>> constExprLists = Deencapsulation.getField(unionNode, "constExprLists_");
+        Assert.assertEquals(1, constExprLists.size());
+        Assert.assertEquals(1, constExprLists.get(0).size());
+        Assert.assertTrue(constExprLists.get(0).get(0) instanceof FunctionCallExpr);
+    }
+}
diff --git a/fe/src/test/java/org/apache/doris/utframe/UtFrameUtils.java b/fe/src/test/java/org/apache/doris/utframe/UtFrameUtils.java
index 325afc9..f35b1af 100644
--- a/fe/src/test/java/org/apache/doris/utframe/UtFrameUtils.java
+++ b/fe/src/test/java/org/apache/doris/utframe/UtFrameUtils.java
@@ -42,6 +42,7 @@ import com.google.common.collect.Maps;
 
 import java.io.IOException;
 import java.io.StringReader;
+import java.nio.channels.SocketChannel;
 import java.util.List;
 import java.util.Map;
 import java.util.Random;
@@ -49,8 +50,9 @@ import java.util.Random;
 public class UtFrameUtils {
 
     // Help to create a mocked ConnectContext.
-    public static ConnectContext createDefaultCtx() {
-        ConnectContext ctx = new ConnectContext();
+    public static ConnectContext createDefaultCtx() throws IOException {
+        SocketChannel channel = SocketChannel.open();
+        ConnectContext ctx = new ConnectContext(channel);
         ctx.setCluster(SystemInfoService.DEFAULT_CLUSTER);
         ctx.setCurrentUserIdentity(UserIdentity.ROOT);
         ctx.setQualifiedUser(PaloAuth.ROOT_USER);
diff --git a/gensrc/script/doris_builtins_functions.py b/gensrc/script/doris_builtins_functions.py
index 0b938a0..919c0ff 100755
--- a/gensrc/script/doris_builtins_functions.py
+++ b/gensrc/script/doris_builtins_functions.py
@@ -708,5 +708,19 @@ visible_functions = [
     [['grouping'], 'BIGINT', ['BIGINT'], '_ZN5doris21GroupingSetsFunctions8groupingEPN9doris_udf15FunctionContextERKNS1_9BigIntValE'],
 ]
 
+# Except for the following functions, all other functions directly return
+# null if any parameter is null.
+# Functions in this set handle null values themselves rather than just returning null.
+#
+# This set is only used to replace 'functions with null parameters' with NullLiteral
+# when applying FoldConstantsRule rules on the FE side.
+# TODO(cmy): Are these functions only required to handle null values?
+non_null_result_with_null_param_functions = [
+    'if',
+    'hll_hash',
+    'concat_ws',
+    'ifnull'
+]
+
 invisible_functions = [
 ]
diff --git a/gensrc/script/gen_builtins_functions.py b/gensrc/script/gen_builtins_functions.py
index 8e906b7..647a812 100755
--- a/gensrc/script/gen_builtins_functions.py
+++ b/gensrc/script/gen_builtins_functions.py
@@ -36,6 +36,8 @@ package org.apache.doris.builtins;\n\
 \n\
 import org.apache.doris.catalog.PrimitiveType;\n\
 import org.apache.doris.catalog.FunctionSet;\n\
+import com.google.common.collect.Sets;\n\
+import java.util.Set;\n\
 \n\
 public class ScalarBuiltins { \n\
     public static void initBuiltins(FunctionSet functionSet) { \
@@ -111,9 +113,16 @@ def generate_fe_registry_init(filename):
     for entry in meta_data_entries:
         for name in entry["sql_names"]:
             java_output = generate_fe_entry(entry, name)
-            java_registry_file.write("    functionSet.addScalarBuiltin(%s);\n" % java_output)
+            java_registry_file.write("        functionSet.addScalarBuiltin(%s);\n" % java_output)
 
     java_registry_file.write("\n")
+
+    # add non_null_result_with_null_param_functions
+    java_registry_file.write("        Set<String> funcNames = Sets.newHashSet();\n")
+    for entry in doris_builtins_functions.non_null_result_with_null_param_functions:
+        java_registry_file.write("        funcNames.add(\"%s\");\n" % entry)
+    java_registry_file.write("        functionSet.buildNonNullResultWithNullParamFunction(funcNames);\n");
+
     java_registry_file.write(java_registry_epilogue)
     java_registry_file.close()
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@doris.apache.org
For additional commands, e-mail: commits-help@doris.apache.org