You are viewing a plain-text version of this content. The canonical link to it was not preserved in this plain-text version.
Posted to commits@iotdb.apache.org by hu...@apache.org on 2023/05/14 15:09:27 UTC

[iotdb] 01/04: add `ModelInferenceFunction`

This is an automated email from the ASF dual-hosted git repository.

hui pushed a commit to branch lmh/forecast
in repository https://gitbox.apache.org/repos/asf/iotdb.git

commit 4ca7d6db46d625c0b62b59c04dea6e8cdcd71711
Author: liuminghui233 <54...@qq.com>
AuthorDate: Sun May 14 21:32:39 2023 +0800

    add `ModelInferenceFunction`
---
 .../udf/builtin/ModelInferenceFunction.java        | 34 ++++++++++++++++++----
 .../commons/udf/service/UDFManagementService.java  | 19 ++++++++++++
 .../plan/expression/multi/FunctionExpression.java  | 10 +++++++
 .../db/mpp/plan/expression/multi/FunctionType.java |  3 +-
 4 files changed, 59 insertions(+), 7 deletions(-)

diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/plan/expression/multi/FunctionType.java b/node-commons/src/main/java/org/apache/iotdb/commons/udf/builtin/ModelInferenceFunction.java
similarity index 52%
copy from server/src/main/java/org/apache/iotdb/db/mpp/plan/expression/multi/FunctionType.java
copy to node-commons/src/main/java/org/apache/iotdb/commons/udf/builtin/ModelInferenceFunction.java
index 734ebb4bef4..f52a0348af6 100644
--- a/server/src/main/java/org/apache/iotdb/db/mpp/plan/expression/multi/FunctionType.java
+++ b/node-commons/src/main/java/org/apache/iotdb/commons/udf/builtin/ModelInferenceFunction.java
@@ -17,11 +17,33 @@
  * under the License.
  */
 
-package org.apache.iotdb.db.mpp.plan.expression.multi;
+package org.apache.iotdb.commons.udf.builtin;
 
-/** */
-public enum FunctionType {
-  AGGREGATION_FUNCTION,
-  BUILT_IN_SCALAR_FUNCTION,
-  UDF
+import java.util.Arrays;
+import java.util.HashSet;
+import java.util.Set;
+import java.util.stream.Collectors;
+
+public enum ModelInferenceFunction {
+  FORECAST("forecast");
+
+  private final String functionName;
+
+  ModelInferenceFunction(String functionName) {
+    this.functionName = functionName;
+  }
+
+  public String getFunctionName() {
+    return functionName;
+  }
+
+  private static final Set<String> NATIVE_FUNCTION_NAMES =
+      new HashSet<>(
+          Arrays.stream(ModelInferenceFunction.values())
+              .map(ModelInferenceFunction::getFunctionName)
+              .collect(Collectors.toList()));
+
+  public static Set<String> getNativeFunctionNames() {
+    return NATIVE_FUNCTION_NAMES;
+  }
 }
diff --git a/node-commons/src/main/java/org/apache/iotdb/commons/udf/service/UDFManagementService.java b/node-commons/src/main/java/org/apache/iotdb/commons/udf/service/UDFManagementService.java
index fbdf8684e91..e7ecb32f008 100644
--- a/node-commons/src/main/java/org/apache/iotdb/commons/udf/service/UDFManagementService.java
+++ b/node-commons/src/main/java/org/apache/iotdb/commons/udf/service/UDFManagementService.java
@@ -22,6 +22,7 @@ package org.apache.iotdb.commons.udf.service;
 import org.apache.iotdb.commons.udf.UDFInformation;
 import org.apache.iotdb.commons.udf.UDFTable;
 import org.apache.iotdb.commons.udf.builtin.BuiltinAggregationFunction;
+import org.apache.iotdb.commons.udf.builtin.ModelInferenceFunction;
 import org.apache.iotdb.commons.utils.TestOnly;
 import org.apache.iotdb.udf.api.UDF;
 import org.apache.iotdb.udf.api.UDTF;
@@ -111,8 +112,26 @@ public class UDFManagementService {
     throw new UDFManagementException(errorMessage);
   }
 
+  private void checkIsModelInferenceFunctionName(UDFInformation udfInformation)
+      throws UDFManagementException {
+    String functionName = udfInformation.getFunctionName();
+    String className = udfInformation.getClassName();
+    if (!ModelInferenceFunction.getNativeFunctionNames().contains(functionName.toLowerCase())) {
+      return;
+    }
+
+    String errorMessage =
+        String.format(
+            "Failed to register UDF %s(%s), because the given function name conflicts with the ML model inference function name",
+            functionName, className);
+
+    LOGGER.warn(errorMessage);
+    throw new UDFManagementException(errorMessage);
+  }
+
   private void checkIfRegistered(UDFInformation udfInformation) throws UDFManagementException {
     checkIsBuiltInAggregationFunctionName(udfInformation);
+    checkIsModelInferenceFunctionName(udfInformation);
     String functionName = udfInformation.getFunctionName();
     String className = udfInformation.getClassName();
     UDFInformation information = udfTable.getUDFInformation(functionName);
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/plan/expression/multi/FunctionExpression.java b/server/src/main/java/org/apache/iotdb/db/mpp/plan/expression/multi/FunctionExpression.java
index 6a36dc93c9a..b9b8960cb48 100644
--- a/server/src/main/java/org/apache/iotdb/db/mpp/plan/expression/multi/FunctionExpression.java
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/plan/expression/multi/FunctionExpression.java
@@ -23,6 +23,7 @@ import org.apache.iotdb.commons.conf.IoTDBConstant;
 import org.apache.iotdb.commons.path.PartialPath;
 import org.apache.iotdb.commons.udf.builtin.BuiltinAggregationFunction;
 import org.apache.iotdb.commons.udf.builtin.BuiltinScalarFunction;
+import org.apache.iotdb.commons.udf.builtin.ModelInferenceFunction;
 import org.apache.iotdb.db.mpp.common.NodeRef;
 import org.apache.iotdb.db.mpp.plan.expression.Expression;
 import org.apache.iotdb.db.mpp.plan.expression.ExpressionType;
@@ -105,6 +106,8 @@ public class FunctionExpression extends Expression {
       functionType = FunctionType.AGGREGATION_FUNCTION;
     } else if (BuiltinScalarFunction.getNativeFunctionNames().contains(functionName)) {
       functionType = FunctionType.BUILT_IN_SCALAR_FUNCTION;
+    } else if (ModelInferenceFunction.getNativeFunctionNames().contains(functionName)) {
+      functionType = FunctionType.MODEL_INFERENCE_FUNCTION;
     } else {
       functionType = FunctionType.UDF;
     }
@@ -125,6 +128,13 @@ public class FunctionExpression extends Expression {
     return functionType == FunctionType.BUILT_IN_SCALAR_FUNCTION;
   }
 
+  public boolean isModelInferenceFunction() {
+    if (functionType == null) {
+      initializeFunctionType();
+    }
+    return functionType == FunctionType.MODEL_INFERENCE_FUNCTION;
+  }
+
   @Override
   public boolean isConstantOperandInternal() {
     if (isConstantOperandCache == null) {
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/plan/expression/multi/FunctionType.java b/server/src/main/java/org/apache/iotdb/db/mpp/plan/expression/multi/FunctionType.java
index 734ebb4bef4..c7e9d4ade2c 100644
--- a/server/src/main/java/org/apache/iotdb/db/mpp/plan/expression/multi/FunctionType.java
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/plan/expression/multi/FunctionType.java
@@ -23,5 +23,6 @@ package org.apache.iotdb.db.mpp.plan.expression.multi;
 public enum FunctionType {
   AGGREGATION_FUNCTION,
   BUILT_IN_SCALAR_FUNCTION,
-  UDF
+  UDF,
+  MODEL_INFERENCE_FUNCTION
 }