You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@iotdb.apache.org by hu...@apache.org on 2023/05/16 06:58:21 UTC

[iotdb] branch lmh/forecast updated: implement BE

This is an automated email from the ASF dual-hosted git repository.

hui pushed a commit to branch lmh/forecast
in repository https://gitbox.apache.org/repos/asf/iotdb.git


The following commit(s) were added to refs/heads/lmh/forecast by this push:
     new c338fd9fdf implement BE
c338fd9fdf is described below

commit c338fd9fdf0ac9c677dac0505afff263dcd0b244
Author: Minghui Liu <li...@foxmail.com>
AuthorDate: Tue May 16 14:56:52 2023 +0800

    implement BE
---
 .../org/apache/iotdb/db/client/MLNodeClient.java   |  28 ++-
 .../java/org/apache/iotdb/db/conf/IoTDBConfig.java |  11 +
 .../exception/ModelInferenceProcessException.java  |  27 +++
 .../fragment/FragmentInstanceManager.java          |  11 +
 .../db/mpp/execution/operator/AggregationUtil.java |   2 +-
 .../operator/process/ml/ForecastOperator.java      | 232 +++++++++++++++++++++
 .../db/mpp/plan/planner/OperatorTreeGenerator.java |  47 +++++
 .../planner/plan/node/process/ml/ForecastNode.java |   4 +
 .../model/ForecastModelInferenceDescriptor.java    |   8 +
 .../parameter/model/ModelInferenceDescriptor.java  |   4 +
 thrift-mlnode/src/main/thrift/mlnode.thrift        |   5 +-
 11 files changed, 370 insertions(+), 9 deletions(-)

diff --git a/server/src/main/java/org/apache/iotdb/db/client/MLNodeClient.java b/server/src/main/java/org/apache/iotdb/db/client/MLNodeClient.java
index 1ff54d43b6..768bf7a74d 100644
--- a/server/src/main/java/org/apache/iotdb/db/client/MLNodeClient.java
+++ b/server/src/main/java/org/apache/iotdb/db/client/MLNodeClient.java
@@ -30,7 +30,7 @@ import org.apache.iotdb.mlnode.rpc.thrift.TDeleteModelReq;
 import org.apache.iotdb.mlnode.rpc.thrift.TForecastReq;
 import org.apache.iotdb.mlnode.rpc.thrift.TForecastResp;
 import org.apache.iotdb.rpc.TConfigurationConst;
-import org.apache.iotdb.rpc.TSStatusCode;
+import org.apache.iotdb.tsfile.file.metadata.enums.TSDataType;
 import org.apache.iotdb.tsfile.read.common.block.TsBlock;
 import org.apache.iotdb.tsfile.read.common.block.column.TsBlockSerde;
 
@@ -45,6 +45,8 @@ import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
 import java.util.Map;
 import java.util.Optional;
 
@@ -115,14 +117,26 @@ public class MLNodeClient implements AutoCloseable {
     }
   }
 
-  public TsBlock forecast(String modelPath, TsBlock inputTsBlock) throws TException {
+  public TForecastResp forecast(
+      String modelPath,
+      TsBlock inputTsBlock,
+      List<TSDataType> inputTypeList,
+      List<String> inputColumnNameList,
+      int predictLength)
+      throws TException {
     try {
-      TForecastReq forecastReq = new TForecastReq(modelPath, tsBlockSerde.serialize(inputTsBlock));
-      TForecastResp resp = client.forecast(forecastReq);
-      if (resp.status.code != TSStatusCode.SUCCESS_STATUS.getStatusCode()) {
-        throw new TException("Failed to execute forecast task, because: " + resp.status.message);
+      List<String> reqInputTypeList = new ArrayList<>();
+      for (TSDataType dataType : inputTypeList) {
+        reqInputTypeList.add(dataType.toString());
       }
-      return tsBlockSerde.deserialize(resp.forecastResult);
+      TForecastReq forecastReq =
+          new TForecastReq(
+              modelPath,
+              tsBlockSerde.serialize(inputTsBlock),
+              reqInputTypeList,
+              inputColumnNameList,
+              predictLength);
+      return client.forecast(forecastReq);
     } catch (IOException e) {
       throw new TException("An exception occurred while serializing input tsblock", e);
     } catch (TException e) {
diff --git a/server/src/main/java/org/apache/iotdb/db/conf/IoTDBConfig.java b/server/src/main/java/org/apache/iotdb/db/conf/IoTDBConfig.java
index 882d6c7a22..d08cbaee0d 100644
--- a/server/src/main/java/org/apache/iotdb/db/conf/IoTDBConfig.java
+++ b/server/src/main/java/org/apache/iotdb/db/conf/IoTDBConfig.java
@@ -717,6 +717,9 @@ public class IoTDBConfig {
   /** The number of threads in the thread pool that execute insert-tablet tasks. */
   private int intoOperationExecutionThreadCount = 2;
 
+  /** The number of threads in the thread pool that execute model inference tasks. */
+  private int modelInferenceExecutionThreadCount = 10;
+
   /** Default TSfile storage is in local file system */
   private FSType tsFileStorageFs = FSType.LOCAL;
 
@@ -1979,6 +1982,14 @@ public class IoTDBConfig {
     this.intoOperationExecutionThreadCount = intoOperationExecutionThreadCount;
   }
 
+  /** Returns how many threads the model inference thread pool is configured to use. */
+  public int getModelInferenceExecutionThreadCount() {
+    return this.modelInferenceExecutionThreadCount;
+  }
+
+  /**
+   * Sets how many threads the model inference thread pool should use.
+   *
+   * <p>NOTE(review): FragmentInstanceManager reads this value when it builds its executor, so
+   * changes made afterwards presumably do not resize a running pool — confirm.
+   */
+  public void setModelInferenceExecutionThreadCount(int threadCount) {
+    modelInferenceExecutionThreadCount = threadCount;
+  }
+
   public int getCompactionWriteThroughputMbPerSec() {
     return compactionWriteThroughputMbPerSec;
   }
diff --git a/server/src/main/java/org/apache/iotdb/db/exception/ModelInferenceProcessException.java b/server/src/main/java/org/apache/iotdb/db/exception/ModelInferenceProcessException.java
new file mode 100644
index 0000000000..1ddb212f11
--- /dev/null
+++ b/server/src/main/java/org/apache/iotdb/db/exception/ModelInferenceProcessException.java
@@ -0,0 +1,27 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.iotdb.db.exception;
+
+/** Unchecked exception thrown when processing a model inference task (such as a forecast) fails. */
+public class ModelInferenceProcessException extends RuntimeException {
+
+  /**
+   * Creates an exception with a detail message and no cause.
+   *
+   * @param message description of the failure
+   */
+  public ModelInferenceProcessException(String message) {
+    super(message);
+  }
+
+  /**
+   * Creates an exception that keeps the underlying failure as its cause, so the original stack
+   * trace survives when a checked exception (e.g. a Thrift or execution error) is wrapped.
+   *
+   * @param message description of the failure
+   * @param cause the underlying exception
+   */
+  public ModelInferenceProcessException(String message, Throwable cause) {
+    super(message, cause);
+  }
+}
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/execution/fragment/FragmentInstanceManager.java b/server/src/main/java/org/apache/iotdb/db/mpp/execution/fragment/FragmentInstanceManager.java
index d563f9cf24..74f61f15a6 100644
--- a/server/src/main/java/org/apache/iotdb/db/mpp/execution/fragment/FragmentInstanceManager.java
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/execution/fragment/FragmentInstanceManager.java
@@ -75,6 +75,8 @@ public class FragmentInstanceManager {
 
   private final ExecutorService intoOperationExecutor;
 
+  private final ExecutorService modelInferenceExecutor;
+
   private static final QueryMetricsManager QUERY_METRICS = QueryMetricsManager.getInstance();
 
   public static FragmentInstanceManager getInstance() {
@@ -104,6 +106,11 @@ public class FragmentInstanceManager {
         IoTDBThreadPoolFactory.newFixedThreadPool(
             IoTDBDescriptor.getInstance().getConfig().getIntoOperationExecutionThreadCount(),
             "into-operation-executor");
+
+    this.modelInferenceExecutor =
+        IoTDBThreadPoolFactory.newFixedThreadPool(
+            IoTDBDescriptor.getInstance().getConfig().getModelInferenceExecutionThreadCount(),
+            "model-inference-executor");
   }
 
   public FragmentInstanceInfo execDataQueryFragmentInstance(
@@ -314,6 +321,10 @@ public class FragmentInstanceManager {
     return intoOperationExecutor;
   }
 
+  /** Returns the shared fixed-size pool ("model-inference-executor") that runs forecast RPCs. */
+  public ExecutorService getModelInferenceExecutor() {
+    return this.modelInferenceExecutor;
+  }
+
   private static class InstanceHolder {
 
     private InstanceHolder() {}
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/execution/operator/AggregationUtil.java b/server/src/main/java/org/apache/iotdb/db/mpp/execution/operator/AggregationUtil.java
index 586ff48be3..da8efa94e7 100644
--- a/server/src/main/java/org/apache/iotdb/db/mpp/execution/operator/AggregationUtil.java
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/execution/operator/AggregationUtil.java
@@ -223,7 +223,7 @@ public class AggregationUtil {
     return timeValueColumnsSizePerLine;
   }
 
-  private static long getOutputColumnSizePerLine(
+  public static long getOutputColumnSizePerLine(
       TSDataType tsDataType, PartialPath inputSeriesPath) {
     switch (tsDataType) {
       case INT32:
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/execution/operator/process/ml/ForecastOperator.java b/server/src/main/java/org/apache/iotdb/db/mpp/execution/operator/process/ml/ForecastOperator.java
new file mode 100644
index 0000000000..5aac6830fb
--- /dev/null
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/execution/operator/process/ml/ForecastOperator.java
@@ -0,0 +1,232 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.iotdb.db.mpp.execution.operator.process.ml;
+
+import org.apache.iotdb.db.client.MLNodeClient;
+import org.apache.iotdb.db.exception.ModelInferenceProcessException;
+import org.apache.iotdb.db.mpp.execution.operator.Operator;
+import org.apache.iotdb.db.mpp.execution.operator.OperatorContext;
+import org.apache.iotdb.db.mpp.execution.operator.process.ProcessOperator;
+import org.apache.iotdb.mlnode.rpc.thrift.TForecastResp;
+import org.apache.iotdb.rpc.TSStatusCode;
+import org.apache.iotdb.tsfile.file.metadata.enums.TSDataType;
+import org.apache.iotdb.tsfile.read.common.block.TsBlock;
+import org.apache.iotdb.tsfile.read.common.block.TsBlockBuilder;
+import org.apache.iotdb.tsfile.read.common.block.column.ColumnBuilder;
+import org.apache.iotdb.tsfile.read.common.block.column.TimeColumnBuilder;
+import org.apache.iotdb.tsfile.read.common.block.column.TsBlockSerde;
+
+import com.google.common.util.concurrent.Futures;
+import com.google.common.util.concurrent.ListenableFuture;
+import org.apache.thrift.TException;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.util.Arrays;
+import java.util.List;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.ExecutorService;
+
+import static com.google.common.util.concurrent.Futures.successfulAsList;
+
+/**
+ * Runs a forecast model over all rows produced by its child.
+ *
+ * <p>Every row returned by the child is buffered into one {@link TsBlock}. Once the child is
+ * exhausted, that block is sent to the ML node through {@link MLNodeClient} on {@code
+ * modelInferenceExecutor}, and the deserialized response becomes the operator's single output
+ * block. Calls to {@link #next()} made before the result is available return {@code null}.
+ */
+public class ForecastOperator implements ProcessOperator {
+
+  private static final Logger LOGGER = LoggerFactory.getLogger(ForecastOperator.class);
+
+  private final OperatorContext operatorContext;
+  private final Operator child;
+
+  private final String modelPath;
+  private final List<TSDataType> inputTypeList;
+  private final List<String> inputColumnNameList;
+  private final int expectedPredictLength;
+
+  // Accumulates every row read from the child until the forecast request is submitted.
+  private final TsBlockBuilder inputTsBlockBuilder;
+
+  // Created lazily in submitForecastTask(); still null if the operator is closed before that.
+  private MLNodeClient client;
+  private final ExecutorService modelInferenceExecutor;
+  // Null until the child is exhausted and the forecast request has been submitted.
+  private ListenableFuture<TForecastResp> forecastExecutionFuture;
+
+  private boolean finished = false;
+
+  private final long maxRetainedSize;
+  private final long maxReturnSize;
+
+  /**
+   * @param operatorContext context of this operator
+   * @param child operator whose output rows are the model input
+   * @param modelPath model path forwarded to the ML node
+   * @param inputTypeList data types of the child's value columns
+   * @param inputColumnNameList column names forwarded to the ML node along with the input
+   * @param expectedPredictLength predict length forwarded to the ML node
+   * @param modelInferenceExecutor executor on which the forecast RPC is run
+   * @param maxRetainedSize memory estimate for the buffered input, in bytes
+   * @param maxReturnSize memory estimate for the returned forecast block, in bytes
+   */
+  public ForecastOperator(
+      OperatorContext operatorContext,
+      Operator child,
+      String modelPath,
+      List<TSDataType> inputTypeList,
+      List<String> inputColumnNameList,
+      int expectedPredictLength,
+      ExecutorService modelInferenceExecutor,
+      long maxRetainedSize,
+      long maxReturnSize) {
+    this.operatorContext = operatorContext;
+    this.child = child;
+    this.modelPath = modelPath;
+    this.inputTypeList = inputTypeList;
+    this.inputColumnNameList = inputColumnNameList;
+    this.expectedPredictLength = expectedPredictLength;
+    this.inputTsBlockBuilder = new TsBlockBuilder(inputTypeList);
+    this.modelInferenceExecutor = modelInferenceExecutor;
+    this.maxRetainedSize = maxRetainedSize;
+    this.maxReturnSize = maxReturnSize;
+  }
+
+  @Override
+  public OperatorContext getOperatorContext() {
+    return operatorContext;
+  }
+
+  /**
+   * Waits on whichever of the child and the pending forecast request is not yet ready (both, if
+   * neither is); returns {@code NOT_BLOCKED} when nothing is pending.
+   */
+  @Override
+  public ListenableFuture<?> isBlocked() {
+    ListenableFuture<?> childBlocked = child.isBlocked();
+    boolean executionDone = forecastExecutionDone();
+    if (executionDone && childBlocked.isDone()) {
+      return NOT_BLOCKED;
+    } else if (childBlocked.isDone()) {
+      return forecastExecutionFuture;
+    } else if (executionDone) {
+      return childBlocked;
+    } else {
+      return successfulAsList(Arrays.asList(forecastExecutionFuture, childBlocked));
+    }
+  }
+
+  /** Returns true when no forecast request is pending: not submitted yet, or already complete. */
+  private boolean forecastExecutionDone() {
+    if (forecastExecutionFuture == null) {
+      return true;
+    }
+    return forecastExecutionFuture.isDone();
+  }
+
+  @Override
+  public boolean hasNext() throws Exception {
+    return !finished;
+  }
+
+  /**
+   * Advances the operator by one step.
+   *
+   * <p>Before the request is submitted, this either reads one block from the child into the input
+   * buffer or, once the child has no more data, submits the forecast request; in both cases it
+   * returns {@code null}. After submission it returns the deserialized forecast result and marks
+   * the operator finished.
+   *
+   * @throws IllegalStateException if called while the forecast request is still running
+   * @throws ModelInferenceProcessException if the ML node returns a non-success status, or the
+   *     forecast task failed or was interrupted
+   */
+  @Override
+  public TsBlock next() throws Exception {
+    if (forecastExecutionFuture == null) {
+      if (child.hasNextWithTimer()) {
+        TsBlock inputTsBlock = child.nextWithTimer();
+        if (inputTsBlock != null) {
+          appendTsBlockToBuilder(inputTsBlock);
+        }
+      } else {
+        submitForecastTask();
+      }
+      return null;
+    } else {
+      try {
+        if (!forecastExecutionFuture.isDone()) {
+          throw new IllegalStateException(
+              "The operator cannot continue until the forecast execution is done.");
+        }
+
+        TForecastResp forecastResp = forecastExecutionFuture.get();
+        if (forecastResp.getStatus().getCode() != TSStatusCode.SUCCESS_STATUS.getStatusCode()) {
+          String message =
+              String.format(
+                  "Error occurred while executing forecast: %s",
+                  forecastResp.getStatus().getMessage());
+          throw new ModelInferenceProcessException(message);
+        }
+
+        finished = true;
+        return new TsBlockSerde().deserialize(forecastResp.bufferForForecastResult());
+      } catch (InterruptedException e) {
+        LOGGER.warn(
+            "{}: interrupted when processing forecast execution future with exception {}",
+            this,
+            e);
+        Thread.currentThread().interrupt();
+        throw new ModelInferenceProcessException(e.getMessage());
+      } catch (ExecutionException e) {
+        throw new ModelInferenceProcessException(e.getMessage());
+      }
+    }
+  }
+
+  /** Copies every row of {@code inputTsBlock} (time plus all value columns) into the buffer. */
+  private void appendTsBlockToBuilder(TsBlock inputTsBlock) {
+    TimeColumnBuilder timeColumnBuilder = inputTsBlockBuilder.getTimeColumnBuilder();
+    ColumnBuilder[] columnBuilders = inputTsBlockBuilder.getValueColumnBuilders();
+
+    for (int i = 0; i < inputTsBlock.getPositionCount(); i++) {
+      timeColumnBuilder.writeLong(inputTsBlock.getTimeByIndex(i));
+      for (int columnIndex = 0; columnIndex < inputTsBlock.getValueColumnCount(); columnIndex++) {
+        columnBuilders[columnIndex].write(inputTsBlock.getColumn(columnIndex), i);
+      }
+      inputTsBlockBuilder.declarePosition();
+    }
+  }
+
+  /**
+   * Builds the buffered input and submits the forecast RPC to {@code modelInferenceExecutor}.
+   *
+   * @throws ModelInferenceProcessException if the ML node client cannot be created
+   */
+  private void submitForecastTask() {
+    try {
+      if (client == null) {
+        client = new MLNodeClient();
+      }
+    } catch (TException e) {
+      throw new ModelInferenceProcessException(e.getMessage());
+    }
+
+    TsBlock inputTsBlock = inputTsBlockBuilder.build();
+    // NOTE(review): the buffered rows are reversed before being sent; presumably the child emits
+    // them in descending time order — confirm against the planner.
+    inputTsBlock.reverse();
+
+    forecastExecutionFuture =
+        Futures.submit(
+            () ->
+                client.forecast(
+                    modelPath,
+                    inputTsBlock,
+                    inputTypeList,
+                    inputColumnNameList,
+                    expectedPredictLength),
+            modelInferenceExecutor);
+  }
+
+  @Override
+  public boolean isFinished() throws Exception {
+    return finished;
+  }
+
+  /**
+   * Cancels any pending forecast request, closes the ML node client if one was created, and
+   * always closes the child, even if releasing the client fails.
+   */
+  @Override
+  public void close() throws Exception {
+    try {
+      // Cancel the pending request before tearing down the client it runs on.
+      if (forecastExecutionFuture != null) {
+        forecastExecutionFuture.cancel(true);
+      }
+      // The client only exists once submitForecastTask() has run.
+      if (client != null) {
+        client.close();
+      }
+    } finally {
+      child.close();
+    }
+  }
+
+  @Override
+  public long calculateMaxPeekMemory() {
+    return maxReturnSize + maxRetainedSize + child.calculateMaxPeekMemory();
+  }
+
+  @Override
+  public long calculateMaxReturnSize() {
+    return maxReturnSize;
+  }
+
+  @Override
+  public long calculateRetainedSizeAfterCallingNext() {
+    return maxRetainedSize + child.calculateRetainedSizeAfterCallingNext();
+  }
+}
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/OperatorTreeGenerator.java b/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/OperatorTreeGenerator.java
index 4e945c8be6..e40a87c976 100644
--- a/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/OperatorTreeGenerator.java
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/OperatorTreeGenerator.java
@@ -102,6 +102,7 @@ import org.apache.iotdb.db.mpp.execution.operator.process.last.LastQueryOperator
 import org.apache.iotdb.db.mpp.execution.operator.process.last.LastQuerySortOperator;
 import org.apache.iotdb.db.mpp.execution.operator.process.last.LastQueryUtil;
 import org.apache.iotdb.db.mpp.execution.operator.process.last.UpdateLastCacheOperator;
+import org.apache.iotdb.db.mpp.execution.operator.process.ml.ForecastOperator;
 import org.apache.iotdb.db.mpp.execution.operator.schema.CountGroupByLevelMergeOperator;
 import org.apache.iotdb.db.mpp.execution.operator.schema.CountGroupByLevelScanOperator;
 import org.apache.iotdb.db.mpp.execution.operator.schema.CountMergeOperator;
@@ -177,6 +178,7 @@ import org.apache.iotdb.db.mpp.plan.planner.plan.node.process.TransformNode;
 import org.apache.iotdb.db.mpp.plan.planner.plan.node.process.last.LastQueryCollectNode;
 import org.apache.iotdb.db.mpp.plan.planner.plan.node.process.last.LastQueryMergeNode;
 import org.apache.iotdb.db.mpp.plan.planner.plan.node.process.last.LastQueryNode;
+import org.apache.iotdb.db.mpp.plan.planner.plan.node.process.ml.ForecastNode;
 import org.apache.iotdb.db.mpp.plan.planner.plan.node.sink.IdentitySinkNode;
 import org.apache.iotdb.db.mpp.plan.planner.plan.node.sink.ShuffleSinkNode;
 import org.apache.iotdb.db.mpp.plan.planner.plan.node.source.AlignedLastQueryScanNode;
@@ -200,6 +202,7 @@ import org.apache.iotdb.db.mpp.plan.planner.plan.parameter.InputLocation;
 import org.apache.iotdb.db.mpp.plan.planner.plan.parameter.IntoPathDescriptor;
 import org.apache.iotdb.db.mpp.plan.planner.plan.parameter.OutputColumn;
 import org.apache.iotdb.db.mpp.plan.planner.plan.parameter.SeriesScanOptions;
+import org.apache.iotdb.db.mpp.plan.planner.plan.parameter.model.ForecastModelInferenceDescriptor;
 import org.apache.iotdb.db.mpp.plan.statement.component.FillPolicy;
 import org.apache.iotdb.db.mpp.plan.statement.component.OrderByKey;
 import org.apache.iotdb.db.mpp.plan.statement.component.Ordering;
@@ -213,6 +216,8 @@ import org.apache.iotdb.db.utils.datastructure.TimeSelector;
 import org.apache.iotdb.tsfile.file.metadata.enums.TSDataType;
 import org.apache.iotdb.tsfile.read.TimeValuePair;
 import org.apache.iotdb.tsfile.read.common.block.TsBlockBuilder;
+import org.apache.iotdb.tsfile.read.common.block.column.DoubleColumn;
+import org.apache.iotdb.tsfile.read.common.block.column.TimeColumn;
 import org.apache.iotdb.tsfile.read.filter.basic.Filter;
 import org.apache.iotdb.tsfile.read.filter.operator.Gt;
 import org.apache.iotdb.tsfile.read.filter.operator.GtEq;
@@ -241,6 +246,7 @@ import static com.google.common.base.Preconditions.checkArgument;
 import static org.apache.iotdb.db.mpp.common.DataNodeEndPoints.isSameNode;
 import static org.apache.iotdb.db.mpp.execution.operator.AggregationUtil.calculateMaxAggregationResultSize;
 import static org.apache.iotdb.db.mpp.execution.operator.AggregationUtil.calculateMaxAggregationResultSizeForLastQuery;
+import static org.apache.iotdb.db.mpp.execution.operator.AggregationUtil.getOutputColumnSizePerLine;
 import static org.apache.iotdb.db.mpp.execution.operator.AggregationUtil.initTimeRangeIterator;
 import static org.apache.iotdb.db.mpp.plan.planner.plan.parameter.SeriesScanOptions.updateFilterUsingTTL;
 
@@ -1643,6 +1649,47 @@ public class OperatorTreeGenerator extends PlanVisitor<Operator, LocalExecutionP
         MergeSortComparator.getComparator(sortItemList, sortItemIndexList, sortItemDataTypeList));
   }
 
+  /**
+   * Builds a {@link ForecastOperator} on top of the operator generated for the node's child.
+   *
+   * <p>Memory estimates: retained size is one input row (timestamp plus every input value column)
+   * times the model input length; return size is one output row (timestamp plus one double per
+   * predicted column) times the expected predict length.
+   */
+  @Override
+  public Operator visitForecast(ForecastNode node, LocalExecutionPlanContext context) {
+    // Keep child generation ahead of getNextOperatorId() so operator-id assignment order is kept.
+    Operator childOperator = node.getChild().accept(this, context);
+    OperatorContext forecastContext =
+        context
+            .getDriverContext()
+            .addOperatorContext(
+                context.getNextOperatorId(),
+                node.getPlanNodeId(),
+                ForecastOperator.class.getSimpleName());
+
+    ForecastModelInferenceDescriptor descriptor = node.getModelInferenceDescriptor();
+    List<TSDataType> inputTypes = descriptor.getInputTypeList();
+    int predictLength = descriptor.getExpectedPredictLength();
+
+    long bytesPerInputRow = TimeColumn.SIZE_IN_BYTES_PER_POSITION;
+    for (TSDataType type : inputTypes) {
+      bytesPerInputRow += getOutputColumnSizePerLine(type, new PartialPath());
+    }
+
+    long bytesPerOutputRow =
+        TimeColumn.SIZE_IN_BYTES_PER_POSITION
+            + (long) descriptor.getPredictIndexList().size()
+                * DoubleColumn.SIZE_IN_BYTES_PER_POSITION;
+
+    return new ForecastOperator(
+        forecastContext,
+        childOperator,
+        descriptor.getModelPath(),
+        inputTypes,
+        node.getChild().getOutputColumnNames(),
+        predictLength,
+        FragmentInstanceManager.getInstance().getModelInferenceExecutor(),
+        bytesPerInputRow * descriptor.getModelInputLength(),
+        bytesPerOutputRow * predictLength);
+  }
+
   @Override
   public Operator visitInto(IntoNode node, LocalExecutionPlanContext context) {
     Operator child = node.getChild().accept(this, context);
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/plan/node/process/ml/ForecastNode.java b/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/plan/node/process/ml/ForecastNode.java
index f71f1fe663..4780935464 100644
--- a/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/plan/node/process/ml/ForecastNode.java
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/plan/node/process/ml/ForecastNode.java
@@ -51,6 +51,10 @@ public class ForecastNode extends SingleChildProcessNode {
     this.modelInferenceDescriptor = modelInferenceDescriptor;
   }
 
+  /** Returns the descriptor carrying the model path, input types and predict lengths. */
+  public ForecastModelInferenceDescriptor getModelInferenceDescriptor() {
+    return this.modelInferenceDescriptor;
+  }
+
   @Override
   public <R, C> R accept(PlanVisitor<R, C> visitor, C context) {
     return visitor.visitForecast(this, context);
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/plan/parameter/model/ForecastModelInferenceDescriptor.java b/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/plan/parameter/model/ForecastModelInferenceDescriptor.java
index f47f319f54..97c1a8cab1 100644
--- a/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/plan/parameter/model/ForecastModelInferenceDescriptor.java
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/plan/parameter/model/ForecastModelInferenceDescriptor.java
@@ -88,6 +88,14 @@ public class ForecastModelInferenceDescriptor extends ModelInferenceDescriptor {
     return modelInputLength;
   }
 
+  /**
+   * Returns the model's predict length. NOTE(review): presumably the model's own output length, as
+   * opposed to the length requested by the query (getExpectedPredictLength) — confirm.
+   */
+  public int getModelPredictLength() {
+    return this.modelPredictLength;
+  }
+
+  /** Returns the predict length that is sent to the ML node with the forecast request. */
+  public int getExpectedPredictLength() {
+    return this.expectedPredictLength;
+  }
+
   @Override
   public LinkedHashMap<String, String> getOutputAttributes() {
     if (outputAttributes == null) {
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/plan/parameter/model/ModelInferenceDescriptor.java b/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/plan/parameter/model/ModelInferenceDescriptor.java
index db11e50f18..1948ebeb98 100644
--- a/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/plan/parameter/model/ModelInferenceDescriptor.java
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/plan/parameter/model/ModelInferenceDescriptor.java
@@ -62,6 +62,10 @@ public abstract class ModelInferenceDescriptor {
     return modelId;
   }
 
+  /** Returns the model path that identifies the model to the ML node. */
+  public String getModelPath() {
+    return this.modelPath;
+  }
+
   public List<FunctionExpression> getModelInferenceOutputExpressions() {
     return modelInferenceOutputExpressions;
   }
diff --git a/thrift-mlnode/src/main/thrift/mlnode.thrift b/thrift-mlnode/src/main/thrift/mlnode.thrift
index abadc79576..46f7b025f4 100644
--- a/thrift-mlnode/src/main/thrift/mlnode.thrift
+++ b/thrift-mlnode/src/main/thrift/mlnode.thrift
@@ -36,7 +36,10 @@ struct TDeleteModelReq {
 
 struct TForecastReq {
   1: required string modelPath
-  2: required binary dataset
+  // model input rows, serialized on the DataNode with TsBlockSerde
+  2: required binary inputData
+  // one TSDataType name (TSDataType.toString()) per value column of inputData
+  3: required list<string> inputTypeList
+  // output column names of the input query; presumably aligned with inputTypeList — confirm
+  4: required list<string> inputColumnNameList
+  // expected number of predicted rows; the DataNode's result-size estimate assumes this many
+  5: required i32 predictLength
 }
 
 struct TForecastResp {