You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@iotdb.apache.org by hu...@apache.org on 2023/05/16 06:58:21 UTC
[iotdb] branch lmh/forecast updated: implement BE
This is an automated email from the ASF dual-hosted git repository.
hui pushed a commit to branch lmh/forecast
in repository https://gitbox.apache.org/repos/asf/iotdb.git
The following commit(s) were added to refs/heads/lmh/forecast by this push:
new c338fd9fdf implement BE
c338fd9fdf is described below
commit c338fd9fdf0ac9c677dac0505afff263dcd0b244
Author: Minghui Liu <li...@foxmail.com>
AuthorDate: Tue May 16 14:56:52 2023 +0800
implement BE
---
.../org/apache/iotdb/db/client/MLNodeClient.java | 28 ++-
.../java/org/apache/iotdb/db/conf/IoTDBConfig.java | 11 +
.../exception/ModelInferenceProcessException.java | 27 +++
.../fragment/FragmentInstanceManager.java | 11 +
.../db/mpp/execution/operator/AggregationUtil.java | 2 +-
.../operator/process/ml/ForecastOperator.java | 232 +++++++++++++++++++++
.../db/mpp/plan/planner/OperatorTreeGenerator.java | 47 +++++
.../planner/plan/node/process/ml/ForecastNode.java | 4 +
.../model/ForecastModelInferenceDescriptor.java | 8 +
.../parameter/model/ModelInferenceDescriptor.java | 4 +
thrift-mlnode/src/main/thrift/mlnode.thrift | 5 +-
11 files changed, 370 insertions(+), 9 deletions(-)
diff --git a/server/src/main/java/org/apache/iotdb/db/client/MLNodeClient.java b/server/src/main/java/org/apache/iotdb/db/client/MLNodeClient.java
index 1ff54d43b6..768bf7a74d 100644
--- a/server/src/main/java/org/apache/iotdb/db/client/MLNodeClient.java
+++ b/server/src/main/java/org/apache/iotdb/db/client/MLNodeClient.java
@@ -30,7 +30,7 @@ import org.apache.iotdb.mlnode.rpc.thrift.TDeleteModelReq;
import org.apache.iotdb.mlnode.rpc.thrift.TForecastReq;
import org.apache.iotdb.mlnode.rpc.thrift.TForecastResp;
import org.apache.iotdb.rpc.TConfigurationConst;
-import org.apache.iotdb.rpc.TSStatusCode;
+import org.apache.iotdb.tsfile.file.metadata.enums.TSDataType;
import org.apache.iotdb.tsfile.read.common.block.TsBlock;
import org.apache.iotdb.tsfile.read.common.block.column.TsBlockSerde;
@@ -45,6 +45,8 @@ import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
import java.util.Map;
import java.util.Optional;
@@ -115,14 +117,26 @@ public class MLNodeClient implements AutoCloseable {
}
}
- public TsBlock forecast(String modelPath, TsBlock inputTsBlock) throws TException {
+ public TForecastResp forecast(
+ String modelPath,
+ TsBlock inputTsBlock,
+ List<TSDataType> inputTypeList,
+ List<String> inputColumnNameList,
+ int predictLength)
+ throws TException {
try {
- TForecastReq forecastReq = new TForecastReq(modelPath, tsBlockSerde.serialize(inputTsBlock));
- TForecastResp resp = client.forecast(forecastReq);
- if (resp.status.code != TSStatusCode.SUCCESS_STATUS.getStatusCode()) {
- throw new TException("Failed to execute forecast task, because: " + resp.status.message);
+ List<String> reqInputTypeList = new ArrayList<>();
+ for (TSDataType dataType : inputTypeList) {
+ reqInputTypeList.add(dataType.toString());
}
- return tsBlockSerde.deserialize(resp.forecastResult);
+ TForecastReq forecastReq =
+ new TForecastReq(
+ modelPath,
+ tsBlockSerde.serialize(inputTsBlock),
+ reqInputTypeList,
+ inputColumnNameList,
+ predictLength);
+ return client.forecast(forecastReq);
} catch (IOException e) {
throw new TException("An exception occurred while serializing input tsblock", e);
} catch (TException e) {
diff --git a/server/src/main/java/org/apache/iotdb/db/conf/IoTDBConfig.java b/server/src/main/java/org/apache/iotdb/db/conf/IoTDBConfig.java
index 882d6c7a22..d08cbaee0d 100644
--- a/server/src/main/java/org/apache/iotdb/db/conf/IoTDBConfig.java
+++ b/server/src/main/java/org/apache/iotdb/db/conf/IoTDBConfig.java
@@ -717,6 +717,9 @@ public class IoTDBConfig {
/** The number of threads in the thread pool that execute insert-tablet tasks. */
private int intoOperationExecutionThreadCount = 2;
+ /** The number of threads in the thread pool that execute model inference tasks. */
+ private int modelInferenceExecutionThreadCount = 10;
+
/** Default TSfile storage is in local file system */
private FSType tsFileStorageFs = FSType.LOCAL;
@@ -1979,6 +1982,14 @@ public class IoTDBConfig {
this.intoOperationExecutionThreadCount = intoOperationExecutionThreadCount;
}
+  /** Returns the number of threads in the pool that executes model inference tasks. */
+  public int getModelInferenceExecutionThreadCount() {
+    return this.modelInferenceExecutionThreadCount;
+  }
+
+  /**
+   * Sets the number of threads in the pool that executes model inference tasks.
+   *
+   * <p>NOTE(review): FragmentInstanceManager reads this value once, when it creates its
+   * "model-inference-executor" pool, so changing it afterwards presumably does not resize that
+   * pool.
+   */
+  public void setModelInferenceExecutionThreadCount(int threadCount) {
+    modelInferenceExecutionThreadCount = threadCount;
+  }
+
public int getCompactionWriteThroughputMbPerSec() {
return compactionWriteThroughputMbPerSec;
}
diff --git a/server/src/main/java/org/apache/iotdb/db/exception/ModelInferenceProcessException.java b/server/src/main/java/org/apache/iotdb/db/exception/ModelInferenceProcessException.java
new file mode 100644
index 0000000000..1ddb212f11
--- /dev/null
+++ b/server/src/main/java/org/apache/iotdb/db/exception/ModelInferenceProcessException.java
@@ -0,0 +1,34 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.iotdb.db.exception;
+
+/** Unchecked exception raised when a model inference task (such as a forecast) fails. */
+public class ModelInferenceProcessException extends RuntimeException {
+
+  /** Creates an exception with the given detail message. */
+  public ModelInferenceProcessException(String message) {
+    super(message);
+  }
+
+  /** Creates an exception with the given detail message, keeping {@code cause} for diagnostics. */
+  public ModelInferenceProcessException(String message, Throwable cause) {
+    super(message, cause);
+  }
+}
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/execution/fragment/FragmentInstanceManager.java b/server/src/main/java/org/apache/iotdb/db/mpp/execution/fragment/FragmentInstanceManager.java
index d563f9cf24..74f61f15a6 100644
--- a/server/src/main/java/org/apache/iotdb/db/mpp/execution/fragment/FragmentInstanceManager.java
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/execution/fragment/FragmentInstanceManager.java
@@ -75,6 +75,8 @@ public class FragmentInstanceManager {
private final ExecutorService intoOperationExecutor;
+ private final ExecutorService modelInferenceExecutor;
+
private static final QueryMetricsManager QUERY_METRICS = QueryMetricsManager.getInstance();
public static FragmentInstanceManager getInstance() {
@@ -104,6 +106,11 @@ public class FragmentInstanceManager {
IoTDBThreadPoolFactory.newFixedThreadPool(
IoTDBDescriptor.getInstance().getConfig().getIntoOperationExecutionThreadCount(),
"into-operation-executor");
+
+ this.modelInferenceExecutor =
+ IoTDBThreadPoolFactory.newFixedThreadPool(
+ IoTDBDescriptor.getInstance().getConfig().getModelInferenceExecutionThreadCount(),
+ "model-inference-executor");
}
public FragmentInstanceInfo execDataQueryFragmentInstance(
@@ -314,6 +321,10 @@ public class FragmentInstanceManager {
return intoOperationExecutor;
}
+  /** Returns the fixed-size "model-inference-executor" pool used to run model inference tasks. */
+  public ExecutorService getModelInferenceExecutor() {
+    return this.modelInferenceExecutor;
+  }
+
private static class InstanceHolder {
private InstanceHolder() {}
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/execution/operator/AggregationUtil.java b/server/src/main/java/org/apache/iotdb/db/mpp/execution/operator/AggregationUtil.java
index 586ff48be3..da8efa94e7 100644
--- a/server/src/main/java/org/apache/iotdb/db/mpp/execution/operator/AggregationUtil.java
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/execution/operator/AggregationUtil.java
@@ -223,7 +223,7 @@ public class AggregationUtil {
return timeValueColumnsSizePerLine;
}
- private static long getOutputColumnSizePerLine(
+ public static long getOutputColumnSizePerLine(
TSDataType tsDataType, PartialPath inputSeriesPath) {
switch (tsDataType) {
case INT32:
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/execution/operator/process/ml/ForecastOperator.java b/server/src/main/java/org/apache/iotdb/db/mpp/execution/operator/process/ml/ForecastOperator.java
new file mode 100644
index 0000000000..5aac6830fb
--- /dev/null
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/execution/operator/process/ml/ForecastOperator.java
@@ -0,0 +1,295 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.iotdb.db.mpp.execution.operator.process.ml;
+
+import org.apache.iotdb.db.client.MLNodeClient;
+import org.apache.iotdb.db.exception.ModelInferenceProcessException;
+import org.apache.iotdb.db.mpp.execution.operator.Operator;
+import org.apache.iotdb.db.mpp.execution.operator.OperatorContext;
+import org.apache.iotdb.db.mpp.execution.operator.process.ProcessOperator;
+import org.apache.iotdb.mlnode.rpc.thrift.TForecastResp;
+import org.apache.iotdb.rpc.TSStatusCode;
+import org.apache.iotdb.tsfile.file.metadata.enums.TSDataType;
+import org.apache.iotdb.tsfile.read.common.block.TsBlock;
+import org.apache.iotdb.tsfile.read.common.block.TsBlockBuilder;
+import org.apache.iotdb.tsfile.read.common.block.column.ColumnBuilder;
+import org.apache.iotdb.tsfile.read.common.block.column.TimeColumnBuilder;
+import org.apache.iotdb.tsfile.read.common.block.column.TsBlockSerde;
+
+import com.google.common.util.concurrent.Futures;
+import com.google.common.util.concurrent.ListenableFuture;
+import org.apache.thrift.TException;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.util.Arrays;
+import java.util.List;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.ExecutorService;
+
+import static com.google.common.util.concurrent.Futures.successfulAsList;
+
+/**
+ * Process operator that runs a forecasting model on the rows produced by its child.
+ *
+ * <p>Execution has two phases:
+ *
+ * <ol>
+ *   <li>While the child still has data, each non-null child block is copied into {@code
+ *       inputTsBlockBuilder} and {@link #next()} returns {@code null}.
+ *   <li>Once the child is exhausted, the buffered rows are sent to the ML node in a single
+ *       request on {@code modelInferenceExecutor}; after the request completes, {@link #next()}
+ *       returns the deserialized forecast result and the operator is finished.
+ * </ol>
+ *
+ * <p>NOTE(review): every child row is buffered before the request is sent, whereas {@code
+ * maxRetainedSize} is an estimate supplied by the planner; confirm the child cannot emit more
+ * rows than that estimate assumes.
+ */
+public class ForecastOperator implements ProcessOperator {
+
+  private static final Logger LOGGER = LoggerFactory.getLogger(ForecastOperator.class);
+
+  private final OperatorContext operatorContext;
+  private final Operator child;
+
+  private final String modelPath;
+  private final List<TSDataType> inputTypeList;
+  private final List<String> inputColumnNameList;
+  private final int expectedPredictLength;
+
+  // Accumulates every row produced by the child until the forecast request is submitted.
+  private final TsBlockBuilder inputTsBlockBuilder;
+
+  // Created lazily in submitForecastTask(); null until the child has been drained.
+  private MLNodeClient client;
+  private final ExecutorService modelInferenceExecutor;
+  // Null until the forecast request has been submitted.
+  private ListenableFuture<TForecastResp> forecastExecutionFuture;
+
+  private boolean finished = false;
+
+  private final long maxRetainedSize;
+  private final long maxReturnSize;
+
+  public ForecastOperator(
+      OperatorContext operatorContext,
+      Operator child,
+      String modelPath,
+      List<TSDataType> inputTypeList,
+      List<String> inputColumnNameList,
+      int expectedPredictLength,
+      ExecutorService modelInferenceExecutor,
+      long maxRetainedSize,
+      long maxReturnSize) {
+    this.operatorContext = operatorContext;
+    this.child = child;
+    this.modelPath = modelPath;
+    this.inputTypeList = inputTypeList;
+    this.inputColumnNameList = inputColumnNameList;
+    this.expectedPredictLength = expectedPredictLength;
+    this.inputTsBlockBuilder = new TsBlockBuilder(inputTypeList);
+    this.modelInferenceExecutor = modelInferenceExecutor;
+    this.maxRetainedSize = maxRetainedSize;
+    this.maxReturnSize = maxReturnSize;
+  }
+
+  @Override
+  public OperatorContext getOperatorContext() {
+    return operatorContext;
+  }
+
+  /**
+   * Returns {@code NOT_BLOCKED} when neither the child nor a pending forecast request is blocked,
+   * otherwise a future for whichever of them is still pending (or for both).
+   */
+  @Override
+  public ListenableFuture<?> isBlocked() {
+    ListenableFuture<?> childBlocked = child.isBlocked();
+    boolean executionDone = forecastExecutionDone();
+    if (executionDone && childBlocked.isDone()) {
+      return NOT_BLOCKED;
+    } else if (childBlocked.isDone()) {
+      return forecastExecutionFuture;
+    } else if (executionDone) {
+      return childBlocked;
+    } else {
+      return successfulAsList(Arrays.asList(forecastExecutionFuture, childBlocked));
+    }
+  }
+
+  /** Returns true if no forecast request was submitted yet or the submitted one has completed. */
+  private boolean forecastExecutionDone() {
+    if (forecastExecutionFuture == null) {
+      return true;
+    }
+    return forecastExecutionFuture.isDone();
+  }
+
+  @Override
+  public boolean hasNext() throws Exception {
+    return !finished;
+  }
+
+  /**
+   * Advances the operator by one step.
+   *
+   * @return {@code null} while child input is still being buffered and on the call that submits
+   *     the forecast request; the forecast result once that request has completed successfully
+   * @throws IllegalStateException if called while the forecast request is still running
+   * @throws ModelInferenceProcessException if the ML node returns a non-success status, the
+   *     request fails, or the calling thread is interrupted while reading the result
+   */
+  @Override
+  public TsBlock next() throws Exception {
+    if (forecastExecutionFuture == null) {
+      if (child.hasNextWithTimer()) {
+        TsBlock inputTsBlock = child.nextWithTimer();
+        if (inputTsBlock != null) {
+          appendTsBlockToBuilder(inputTsBlock);
+        }
+      } else {
+        submitForecastTask();
+      }
+      return null;
+    }
+
+    if (!forecastExecutionFuture.isDone()) {
+      throw new IllegalStateException(
+          "The operator cannot continue until the forecast execution is done.");
+    }
+
+    try {
+      TForecastResp forecastResp = forecastExecutionFuture.get();
+      if (forecastResp.getStatus().getCode() != TSStatusCode.SUCCESS_STATUS.getStatusCode()) {
+        String message =
+            String.format(
+                "Error occurred while executing forecast: %s",
+                forecastResp.getStatus().getMessage());
+        throw new ModelInferenceProcessException(message);
+      }
+
+      finished = true;
+      return new TsBlockSerde().deserialize(forecastResp.bufferForForecastResult());
+    } catch (InterruptedException e) {
+      LOGGER.warn("{}: interrupted while waiting for the forecast execution future", this, e);
+      Thread.currentThread().interrupt();
+      throw withCause(new ModelInferenceProcessException(e.getMessage()), e);
+    } catch (ExecutionException e) {
+      // Keep the failure raised inside the inference task (e.g. a TException) as the cause.
+      Throwable cause = e.getCause() != null ? e.getCause() : e;
+      throw withCause(new ModelInferenceProcessException(e.getMessage()), cause);
+    }
+  }
+
+  /** Copies every row of {@code inputTsBlock} into {@code inputTsBlockBuilder}. */
+  private void appendTsBlockToBuilder(TsBlock inputTsBlock) {
+    TimeColumnBuilder timeColumnBuilder = inputTsBlockBuilder.getTimeColumnBuilder();
+    ColumnBuilder[] columnBuilders = inputTsBlockBuilder.getValueColumnBuilders();
+
+    for (int i = 0; i < inputTsBlock.getPositionCount(); i++) {
+      timeColumnBuilder.writeLong(inputTsBlock.getTimeByIndex(i));
+      for (int columnIndex = 0; columnIndex < inputTsBlock.getValueColumnCount(); columnIndex++) {
+        columnBuilders[columnIndex].write(inputTsBlock.getColumn(columnIndex), i);
+      }
+      inputTsBlockBuilder.declarePosition();
+    }
+  }
+
+  /**
+   * Builds the buffered input and submits one forecast request to {@code modelInferenceExecutor},
+   * creating the ML node client on first use.
+   *
+   * @throws ModelInferenceProcessException if the ML node client cannot be created
+   */
+  private void submitForecastTask() {
+    try {
+      if (client == null) {
+        client = new MLNodeClient();
+      }
+    } catch (TException e) {
+      throw withCause(new ModelInferenceProcessException(e.getMessage()), e);
+    }
+
+    TsBlock inputTsBlock = inputTsBlockBuilder.build();
+    // NOTE(review): the buffered block is reversed before it is sent; presumably the child emits
+    // rows in descending time order and the model expects ascending order - confirm.
+    inputTsBlock.reverse();
+
+    forecastExecutionFuture =
+        Futures.submit(
+            () ->
+                client.forecast(
+                    modelPath,
+                    inputTsBlock,
+                    inputTypeList,
+                    inputColumnNameList,
+                    expectedPredictLength),
+            modelInferenceExecutor);
+  }
+
+  /** Attaches {@code cause} to {@code exception} so the original failure is not lost. */
+  private static ModelInferenceProcessException withCause(
+      ModelInferenceProcessException exception, Throwable cause) {
+    exception.initCause(cause);
+    return exception;
+  }
+
+  @Override
+  public boolean isFinished() throws Exception {
+    return finished;
+  }
+
+  /**
+   * Cancels a pending forecast request, then closes the ML node client (if one was created) and
+   * the child. The child is closed even if closing the client throws.
+   */
+  @Override
+  public void close() throws Exception {
+    // Cancel (mayInterruptIfRunning) before closing the client that the submitted task calls.
+    if (forecastExecutionFuture != null) {
+      forecastExecutionFuture.cancel(true);
+    }
+    try {
+      // The client is created lazily in submitForecastTask(), so it is still null if the child
+      // was never drained; closing it unconditionally would throw a NullPointerException.
+      if (client != null) {
+        client.close();
+      }
+    } finally {
+      child.close();
+    }
+  }
+
+  @Override
+  public long calculateMaxPeekMemory() {
+    return maxReturnSize + maxRetainedSize + child.calculateMaxPeekMemory();
+  }
+
+  @Override
+  public long calculateMaxReturnSize() {
+    return maxReturnSize;
+  }
+
+  @Override
+  public long calculateRetainedSizeAfterCallingNext() {
+    return maxRetainedSize + child.calculateRetainedSizeAfterCallingNext();
+  }
+}
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/OperatorTreeGenerator.java b/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/OperatorTreeGenerator.java
index 4e945c8be6..e40a87c976 100644
--- a/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/OperatorTreeGenerator.java
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/OperatorTreeGenerator.java
@@ -102,6 +102,7 @@ import org.apache.iotdb.db.mpp.execution.operator.process.last.LastQueryOperator
import org.apache.iotdb.db.mpp.execution.operator.process.last.LastQuerySortOperator;
import org.apache.iotdb.db.mpp.execution.operator.process.last.LastQueryUtil;
import org.apache.iotdb.db.mpp.execution.operator.process.last.UpdateLastCacheOperator;
+import org.apache.iotdb.db.mpp.execution.operator.process.ml.ForecastOperator;
import org.apache.iotdb.db.mpp.execution.operator.schema.CountGroupByLevelMergeOperator;
import org.apache.iotdb.db.mpp.execution.operator.schema.CountGroupByLevelScanOperator;
import org.apache.iotdb.db.mpp.execution.operator.schema.CountMergeOperator;
@@ -177,6 +178,7 @@ import org.apache.iotdb.db.mpp.plan.planner.plan.node.process.TransformNode;
import org.apache.iotdb.db.mpp.plan.planner.plan.node.process.last.LastQueryCollectNode;
import org.apache.iotdb.db.mpp.plan.planner.plan.node.process.last.LastQueryMergeNode;
import org.apache.iotdb.db.mpp.plan.planner.plan.node.process.last.LastQueryNode;
+import org.apache.iotdb.db.mpp.plan.planner.plan.node.process.ml.ForecastNode;
import org.apache.iotdb.db.mpp.plan.planner.plan.node.sink.IdentitySinkNode;
import org.apache.iotdb.db.mpp.plan.planner.plan.node.sink.ShuffleSinkNode;
import org.apache.iotdb.db.mpp.plan.planner.plan.node.source.AlignedLastQueryScanNode;
@@ -200,6 +202,7 @@ import org.apache.iotdb.db.mpp.plan.planner.plan.parameter.InputLocation;
import org.apache.iotdb.db.mpp.plan.planner.plan.parameter.IntoPathDescriptor;
import org.apache.iotdb.db.mpp.plan.planner.plan.parameter.OutputColumn;
import org.apache.iotdb.db.mpp.plan.planner.plan.parameter.SeriesScanOptions;
+import org.apache.iotdb.db.mpp.plan.planner.plan.parameter.model.ForecastModelInferenceDescriptor;
import org.apache.iotdb.db.mpp.plan.statement.component.FillPolicy;
import org.apache.iotdb.db.mpp.plan.statement.component.OrderByKey;
import org.apache.iotdb.db.mpp.plan.statement.component.Ordering;
@@ -213,6 +216,8 @@ import org.apache.iotdb.db.utils.datastructure.TimeSelector;
import org.apache.iotdb.tsfile.file.metadata.enums.TSDataType;
import org.apache.iotdb.tsfile.read.TimeValuePair;
import org.apache.iotdb.tsfile.read.common.block.TsBlockBuilder;
+import org.apache.iotdb.tsfile.read.common.block.column.DoubleColumn;
+import org.apache.iotdb.tsfile.read.common.block.column.TimeColumn;
import org.apache.iotdb.tsfile.read.filter.basic.Filter;
import org.apache.iotdb.tsfile.read.filter.operator.Gt;
import org.apache.iotdb.tsfile.read.filter.operator.GtEq;
@@ -241,6 +246,7 @@ import static com.google.common.base.Preconditions.checkArgument;
import static org.apache.iotdb.db.mpp.common.DataNodeEndPoints.isSameNode;
import static org.apache.iotdb.db.mpp.execution.operator.AggregationUtil.calculateMaxAggregationResultSize;
import static org.apache.iotdb.db.mpp.execution.operator.AggregationUtil.calculateMaxAggregationResultSizeForLastQuery;
+import static org.apache.iotdb.db.mpp.execution.operator.AggregationUtil.getOutputColumnSizePerLine;
import static org.apache.iotdb.db.mpp.execution.operator.AggregationUtil.initTimeRangeIterator;
import static org.apache.iotdb.db.mpp.plan.planner.plan.parameter.SeriesScanOptions.updateFilterUsingTTL;
@@ -1643,6 +1649,47 @@ public class OperatorTreeGenerator extends PlanVisitor<Operator, LocalExecutionP
MergeSortComparator.getComparator(sortItemList, sortItemIndexList, sortItemDataTypeList));
}
+  /**
+   * Builds a {@link ForecastOperator} for {@code node}. The child operator is generated first,
+   * then the operator's memory estimates are derived from the node's forecast descriptor.
+   */
+  @Override
+  public Operator visitForecast(ForecastNode node, LocalExecutionPlanContext context) {
+    Operator childOperator = node.getChild().accept(this, context);
+    OperatorContext forecastContext =
+        context
+            .getDriverContext()
+            .addOperatorContext(
+                context.getNextOperatorId(),
+                node.getPlanNodeId(),
+                ForecastOperator.class.getSimpleName());
+
+    ForecastModelInferenceDescriptor descriptor = node.getModelInferenceDescriptor();
+    List<TSDataType> inputTypes = descriptor.getInputTypeList();
+
+    // Retained memory: one buffered row (time + every input column) per model input point.
+    long bytesPerInputRow = TimeColumn.SIZE_IN_BYTES_PER_POSITION;
+    for (TSDataType inputType : inputTypes) {
+      // NOTE(review): no concrete series path exists here, so an empty PartialPath is passed;
+      // confirm the per-line size estimate is meaningful for variable-length input types.
+      bytesPerInputRow += getOutputColumnSizePerLine(inputType, new PartialPath());
+    }
+    long maxRetainedSize = bytesPerInputRow * descriptor.getModelInputLength();
+
+    // Returned memory: a time column plus one DOUBLE column per entry of the predict index list,
+    // for each expected predicted point.
+    int predictLength = descriptor.getExpectedPredictLength();
+    int predictedColumnCount = descriptor.getPredictIndexList().size();
+    long bytesPerOutputRow =
+        TimeColumn.SIZE_IN_BYTES_PER_POSITION
+            + (long) predictedColumnCount * DoubleColumn.SIZE_IN_BYTES_PER_POSITION;
+    long maxReturnSize = bytesPerOutputRow * predictLength;
+
+    return new ForecastOperator(
+        forecastContext,
+        childOperator,
+        descriptor.getModelPath(),
+        inputTypes,
+        node.getChild().getOutputColumnNames(),
+        predictLength,
+        FragmentInstanceManager.getInstance().getModelInferenceExecutor(),
+        maxRetainedSize,
+        maxReturnSize);
+  }
+
@Override
public Operator visitInto(IntoNode node, LocalExecutionPlanContext context) {
Operator child = node.getChild().accept(this, context);
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/plan/node/process/ml/ForecastNode.java b/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/plan/node/process/ml/ForecastNode.java
index f71f1fe663..4780935464 100644
--- a/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/plan/node/process/ml/ForecastNode.java
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/plan/node/process/ml/ForecastNode.java
@@ -51,6 +51,10 @@ public class ForecastNode extends SingleChildProcessNode {
this.modelInferenceDescriptor = modelInferenceDescriptor;
}
+  /** Returns the forecast model inference descriptor this node was constructed with. */
+  public ForecastModelInferenceDescriptor getModelInferenceDescriptor() {
+    return this.modelInferenceDescriptor;
+  }
+
@Override
public <R, C> R accept(PlanVisitor<R, C> visitor, C context) {
return visitor.visitForecast(this, context);
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/plan/parameter/model/ForecastModelInferenceDescriptor.java b/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/plan/parameter/model/ForecastModelInferenceDescriptor.java
index f47f319f54..97c1a8cab1 100644
--- a/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/plan/parameter/model/ForecastModelInferenceDescriptor.java
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/plan/parameter/model/ForecastModelInferenceDescriptor.java
@@ -88,6 +88,14 @@ public class ForecastModelInferenceDescriptor extends ModelInferenceDescriptor {
return modelInputLength;
}
+  /** Returns the model's own predict length (the {@code modelPredictLength} field). */
+  public int getModelPredictLength() {
+    return this.modelPredictLength;
+  }
+
+  /**
+   * Returns the expected predict length; the planner passes it to the ML node as the forecast
+   * request's predict length and uses it to size the operator's output.
+   */
+  public int getExpectedPredictLength() {
+    return this.expectedPredictLength;
+  }
+
@Override
public LinkedHashMap<String, String> getOutputAttributes() {
if (outputAttributes == null) {
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/plan/parameter/model/ModelInferenceDescriptor.java b/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/plan/parameter/model/ModelInferenceDescriptor.java
index db11e50f18..1948ebeb98 100644
--- a/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/plan/parameter/model/ModelInferenceDescriptor.java
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/plan/parameter/model/ModelInferenceDescriptor.java
@@ -62,6 +62,10 @@ public abstract class ModelInferenceDescriptor {
return modelId;
}
+  /** Returns the path of the model to run, as sent in the forecast request. */
+  public String getModelPath() {
+    return this.modelPath;
+  }
+
public List<FunctionExpression> getModelInferenceOutputExpressions() {
return modelInferenceOutputExpressions;
}
diff --git a/thrift-mlnode/src/main/thrift/mlnode.thrift b/thrift-mlnode/src/main/thrift/mlnode.thrift
index abadc79576..46f7b025f4 100644
--- a/thrift-mlnode/src/main/thrift/mlnode.thrift
+++ b/thrift-mlnode/src/main/thrift/mlnode.thrift
@@ -36,7 +36,10 @@ struct TDeleteModelReq {
+// Forecast request sent by a DataNode's MLNodeClient to run the model at modelPath.
struct TForecastReq {
1: required string modelPath
- 2: required binary dataset
+ // Input rows as a TsBlock serialized with TsBlockSerde.
+ 2: required binary inputData
+ // TSDataType names (TSDataType.toString()) of the input value columns.
+ 3: required list<string> inputTypeList
+ // Names of the input columns; presumably aligned index-by-index with inputTypeList - confirm.
+ 4: required list<string> inputColumnNameList
+ // Number of points the model is asked to predict.
+ 5: required i32 predictLength
}
struct TForecastResp {