You are viewing a plain text version of this content. The link to the canonical version was not preserved in this plain-text copy.
Posted to commits@iotdb.apache.org by hu...@apache.org on 2023/05/14 15:09:27 UTC
[iotdb] 01/04: add `ModelInferenceFunction`
This is an automated email from the ASF dual-hosted git repository.
hui pushed a commit to branch lmh/forecast
in repository https://gitbox.apache.org/repos/asf/iotdb.git
commit 4ca7d6db46d625c0b62b59c04dea6e8cdcd71711
Author: liuminghui233 <54...@qq.com>
AuthorDate: Sun May 14 21:32:39 2023 +0800
add `ModelInferenceFunction`
---
.../udf/builtin/ModelInferenceFunction.java | 34 ++++++++++++++++++----
.../commons/udf/service/UDFManagementService.java | 19 ++++++++++++
.../plan/expression/multi/FunctionExpression.java | 10 +++++++
.../db/mpp/plan/expression/multi/FunctionType.java | 3 +-
4 files changed, 59 insertions(+), 7 deletions(-)
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/plan/expression/multi/FunctionType.java b/node-commons/src/main/java/org/apache/iotdb/commons/udf/builtin/ModelInferenceFunction.java
similarity index 52%
copy from server/src/main/java/org/apache/iotdb/db/mpp/plan/expression/multi/FunctionType.java
copy to node-commons/src/main/java/org/apache/iotdb/commons/udf/builtin/ModelInferenceFunction.java
index 734ebb4bef4..f52a0348af6 100644
--- a/server/src/main/java/org/apache/iotdb/db/mpp/plan/expression/multi/FunctionType.java
+++ b/node-commons/src/main/java/org/apache/iotdb/commons/udf/builtin/ModelInferenceFunction.java
@@ -17,11 +17,33 @@
* under the License.
*/
-package org.apache.iotdb.db.mpp.plan.expression.multi;
+package org.apache.iotdb.commons.udf.builtin;
-/** */
-public enum FunctionType {
- AGGREGATION_FUNCTION,
- BUILT_IN_SCALAR_FUNCTION,
- UDF
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashSet;
+import java.util.Set;
+import java.util.stream.Collectors;
+
+/**
+ * Built-in functions that are evaluated by running inference with an ML model. Each constant
+ * carries the function name used to recognise calls to it (e.g. {@code "forecast"}).
+ */
+public enum ModelInferenceFunction {
+  FORECAST("forecast");
+
+  /** The function name, e.g. {@code "forecast"}. */
+  private final String functionName;
+
+  ModelInferenceFunction(String functionName) {
+    this.functionName = functionName;
+  }
+
+  /** Returns the function name associated with this constant. */
+  public String getFunctionName() {
+    return functionName;
+  }
+
+  // Exposed read-only: a shared mutable set would let any caller change, process-wide, which
+  // names are treated as model inference functions.
+  private static final Set<String> NATIVE_FUNCTION_NAMES =
+      Collections.unmodifiableSet(
+          Arrays.stream(ModelInferenceFunction.values())
+              .map(ModelInferenceFunction::getFunctionName)
+              .collect(Collectors.toCollection(HashSet::new)));
+
+  /**
+   * Returns the names of all model inference functions.
+   *
+   * @return an unmodifiable set of function names; mutation attempts throw {@link
+   *     UnsupportedOperationException}
+   */
+  public static Set<String> getNativeFunctionNames() {
+    return NATIVE_FUNCTION_NAMES;
+  }
}
diff --git a/node-commons/src/main/java/org/apache/iotdb/commons/udf/service/UDFManagementService.java b/node-commons/src/main/java/org/apache/iotdb/commons/udf/service/UDFManagementService.java
index fbdf8684e91..e7ecb32f008 100644
--- a/node-commons/src/main/java/org/apache/iotdb/commons/udf/service/UDFManagementService.java
+++ b/node-commons/src/main/java/org/apache/iotdb/commons/udf/service/UDFManagementService.java
@@ -22,6 +22,7 @@ package org.apache.iotdb.commons.udf.service;
import org.apache.iotdb.commons.udf.UDFInformation;
import org.apache.iotdb.commons.udf.UDFTable;
import org.apache.iotdb.commons.udf.builtin.BuiltinAggregationFunction;
+import org.apache.iotdb.commons.udf.builtin.ModelInferenceFunction;
import org.apache.iotdb.commons.utils.TestOnly;
import org.apache.iotdb.udf.api.UDF;
import org.apache.iotdb.udf.api.UDTF;
@@ -111,8 +112,26 @@ public class UDFManagementService {
throw new UDFManagementException(errorMessage);
}
+  /**
+   * Rejects a UDF registration whose name collides with a model inference function name
+   * returned by {@code ModelInferenceFunction.getNativeFunctionNames()}. The candidate name is
+   * lower-cased before the lookup.
+   *
+   * @param udfInformation the UDF being registered
+   * @throws UDFManagementException if the function name is reserved for model inference
+   */
+  private void checkIsModelInferenceFunctionName(UDFInformation udfInformation)
+      throws UDFManagementException {
+    String functionName = udfInformation.getFunctionName();
+    // Locale.ROOT makes the lower-casing independent of the JVM default locale (e.g. the Turkish
+    // dotless i), so a reserved name cannot slip past the check on some machines.
+    if (!ModelInferenceFunction.getNativeFunctionNames()
+        .contains(functionName.toLowerCase(java.util.Locale.ROOT))) {
+      return;
+    }
+
+    String errorMessage =
+        String.format(
+            "Failed to register UDF %s(%s), because the given function name conflicts with the ML model inference function name",
+            functionName, udfInformation.getClassName());
+
+    LOGGER.warn(errorMessage);
+    throw new UDFManagementException(errorMessage);
+  }
+
private void checkIfRegistered(UDFInformation udfInformation) throws UDFManagementException {
checkIsBuiltInAggregationFunctionName(udfInformation);
+ checkIsModelInferenceFunctionName(udfInformation);
String functionName = udfInformation.getFunctionName();
String className = udfInformation.getClassName();
UDFInformation information = udfTable.getUDFInformation(functionName);
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/plan/expression/multi/FunctionExpression.java b/server/src/main/java/org/apache/iotdb/db/mpp/plan/expression/multi/FunctionExpression.java
index 6a36dc93c9a..b9b8960cb48 100644
--- a/server/src/main/java/org/apache/iotdb/db/mpp/plan/expression/multi/FunctionExpression.java
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/plan/expression/multi/FunctionExpression.java
@@ -23,6 +23,7 @@ import org.apache.iotdb.commons.conf.IoTDBConstant;
import org.apache.iotdb.commons.path.PartialPath;
import org.apache.iotdb.commons.udf.builtin.BuiltinAggregationFunction;
import org.apache.iotdb.commons.udf.builtin.BuiltinScalarFunction;
+import org.apache.iotdb.commons.udf.builtin.ModelInferenceFunction;
import org.apache.iotdb.db.mpp.common.NodeRef;
import org.apache.iotdb.db.mpp.plan.expression.Expression;
import org.apache.iotdb.db.mpp.plan.expression.ExpressionType;
@@ -105,6 +106,8 @@ public class FunctionExpression extends Expression {
functionType = FunctionType.AGGREGATION_FUNCTION;
} else if (BuiltinScalarFunction.getNativeFunctionNames().contains(functionName)) {
functionType = FunctionType.BUILT_IN_SCALAR_FUNCTION;
+ } else if (ModelInferenceFunction.getNativeFunctionNames().contains(functionName)) {
+ functionType = FunctionType.MODEL_INFERENCE_FUNCTION;
} else {
functionType = FunctionType.UDF;
}
@@ -125,6 +128,13 @@ public class FunctionExpression extends Expression {
return functionType == FunctionType.BUILT_IN_SCALAR_FUNCTION;
}
+ public boolean isModelInferenceFunction() {
+ if (functionType == null) {
+ initializeFunctionType();
+ }
+ return functionType == FunctionType.MODEL_INFERENCE_FUNCTION;
+ }
+
@Override
public boolean isConstantOperandInternal() {
if (isConstantOperandCache == null) {
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/plan/expression/multi/FunctionType.java b/server/src/main/java/org/apache/iotdb/db/mpp/plan/expression/multi/FunctionType.java
index 734ebb4bef4..c7e9d4ade2c 100644
--- a/server/src/main/java/org/apache/iotdb/db/mpp/plan/expression/multi/FunctionType.java
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/plan/expression/multi/FunctionType.java
@@ -23,5 +23,6 @@ package org.apache.iotdb.db.mpp.plan.expression.multi;
public enum FunctionType {
+ /** Aggregation function. */
AGGREGATION_FUNCTION,
+ /** Function whose name is in {@code BuiltinScalarFunction.getNativeFunctionNames()}. */
BUILT_IN_SCALAR_FUNCTION,
- UDF
+ /** User-defined function; FunctionExpression falls back to this when no other category matches. */
+ UDF,
+ /** Function whose name is in {@code ModelInferenceFunction.getNativeFunctionNames()}, e.g. forecast. */
+ MODEL_INFERENCE_FUNCTION
}