You are viewing a plain-text version of this content. Its canonical link (a hyperlink on the original archive page) is not preserved in this plain-text version.
Posted to commits@iotdb.apache.org by hu...@apache.org on 2023/05/15 09:23:58 UTC

[iotdb] branch lmh/forecast updated (445537b5197 -> 7f2fd39a826)

This is an automated email from the ASF dual-hosted git repository.

hui pushed a change to branch lmh/forecast
in repository https://gitbox.apache.org/repos/asf/iotdb.git


    from 445537b5197 implement analyzer (tmp save)
     new 4eb0f8c2556 implement analyzer (finish)
     new c4fee543fbd implement planner
     new 7f2fd39a826 add model available check

The 3 revisions listed above as "new" are entirely new to this
repository and will be described in separate emails.  The revisions
listed as "add" were already present in the repository and have only
been added to this reference.


Summary of changes:
 .../iotdb/commons/model/ModelInformation.java      |  13 ++
 .../org/apache/iotdb/db/constant/SqlConstant.java  |   3 +
 .../apache/iotdb/db/mpp/plan/analyze/Analysis.java |  11 ++
 .../iotdb/db/mpp/plan/analyze/AnalyzeVisitor.java  |  88 ++++++++++-
 .../mpp/plan/analyze/ExpressionTypeAnalyzer.java   |   2 +
 .../plan/expression/multi/FunctionExpression.java  |   2 +-
 .../db/mpp/plan/planner/LogicalPlanBuilder.java    |  10 ++
 .../db/mpp/plan/planner/LogicalPlanVisitor.java    |  17 +++
 .../plan/planner/plan/node/PlanGraphPrinter.java   |  12 ++
 .../mpp/plan/planner/plan/node/PlanNodeType.java   |   6 +-
 .../db/mpp/plan/planner/plan/node/PlanVisitor.java |   5 +
 .../planner/plan/node/process/ml/ForecastNode.java | 118 +++++++++++++++
 .../plan/parameter/ModelInferenceDescriptor.java   |  22 ---
 .../model/ForecastModelInferenceDescriptor.java    | 168 +++++++++++++++++++++
 .../parameter/model/ModelInferenceDescriptor.java  | 107 +++++++++++++
 .../db/mpp/plan/statement/crud/QueryStatement.java |  17 ++-
 16 files changed, 575 insertions(+), 26 deletions(-)
 create mode 100644 server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/plan/node/process/ml/ForecastNode.java
 delete mode 100644 server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/plan/parameter/ModelInferenceDescriptor.java
 create mode 100644 server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/plan/parameter/model/ForecastModelInferenceDescriptor.java
 create mode 100644 server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/plan/parameter/model/ModelInferenceDescriptor.java


[iotdb] 02/03: implement planner

Posted by hu...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

hui pushed a commit to branch lmh/forecast
in repository https://gitbox.apache.org/repos/asf/iotdb.git

commit c4fee543fbd0f9a06ac6df7de1e7c71b86c91df9
Author: Minghui Liu <li...@foxmail.com>
AuthorDate: Mon May 15 17:09:14 2023 +0800

    implement planner
---
 .../iotdb/commons/model/ModelInformation.java      |   9 ++
 .../iotdb/db/mpp/plan/analyze/AnalyzeVisitor.java  |  11 +-
 .../db/mpp/plan/planner/LogicalPlanBuilder.java    |  10 ++
 .../db/mpp/plan/planner/LogicalPlanVisitor.java    |  17 +++
 .../plan/planner/plan/node/PlanGraphPrinter.java   |  12 +++
 .../mpp/plan/planner/plan/node/PlanNodeType.java   |   6 +-
 .../db/mpp/plan/planner/plan/node/PlanVisitor.java |   5 +
 .../planner/plan/node/process/ml/ForecastNode.java | 118 +++++++++++++++++++++
 .../model/ForecastModelInferenceDescriptor.java    | 118 +++++++++++++++++----
 .../parameter/model/ModelInferenceDescriptor.java  |  54 +++++++++-
 .../db/mpp/plan/statement/crud/QueryStatement.java |  10 +-
 11 files changed, 337 insertions(+), 33 deletions(-)

diff --git a/node-commons/src/main/java/org/apache/iotdb/commons/model/ModelInformation.java b/node-commons/src/main/java/org/apache/iotdb/commons/model/ModelInformation.java
index 522f609e51e..f052c8d2121 100644
--- a/node-commons/src/main/java/org/apache/iotdb/commons/model/ModelInformation.java
+++ b/node-commons/src/main/java/org/apache/iotdb/commons/model/ModelInformation.java
@@ -186,6 +186,15 @@ public class ModelInformation {
     }
   }
 
+  public String getModelPath() {
+    if (bestTrailId != null) {
+      TrailInformation bestTrail = trailMap.get(bestTrailId);
+      return bestTrail.getModelPath();
+    } else {
+      return "UNKNOWN";
+    }
+  }
+
   public void serialize(DataOutputStream stream) throws IOException {
     ReadWriteIOUtils.write(modelId, stream);
     ReadWriteIOUtils.write(modelTask.ordinal(), stream);
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/plan/analyze/AnalyzeVisitor.java b/server/src/main/java/org/apache/iotdb/db/mpp/plan/analyze/AnalyzeVisitor.java
index 39b76dbfafb..f96f3de274a 100644
--- a/server/src/main/java/org/apache/iotdb/db/mpp/plan/analyze/AnalyzeVisitor.java
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/plan/analyze/AnalyzeVisitor.java
@@ -90,6 +90,7 @@ import org.apache.iotdb.db.mpp.plan.statement.component.GroupBySessionComponent;
 import org.apache.iotdb.db.mpp.plan.statement.component.GroupByTimeComponent;
 import org.apache.iotdb.db.mpp.plan.statement.component.GroupByVariationComponent;
 import org.apache.iotdb.db.mpp.plan.statement.component.IntoComponent;
+import org.apache.iotdb.db.mpp.plan.statement.component.OrderByComponent;
 import org.apache.iotdb.db.mpp.plan.statement.component.OrderByKey;
 import org.apache.iotdb.db.mpp.plan.statement.component.Ordering;
 import org.apache.iotdb.db.mpp.plan.statement.component.ResultColumn;
@@ -375,7 +376,7 @@ public class AnalyzeVisitor extends StatementVisitor<Analysis, MPPQueryContext>
 
     ModelInformation modelInformation = partitionFetcher.getModelInformation(modelId);
     if (modelInformation == null) {
-      // throw new SemanticException("");
+      throw new SemanticException("");
     }
 
     ModelInferenceFunction functionType =
@@ -383,7 +384,7 @@ public class AnalyzeVisitor extends StatementVisitor<Analysis, MPPQueryContext>
     switch (functionType) {
       case FORECAST:
         ModelInferenceDescriptor modelInferenceDescriptor =
-            new ForecastModelInferenceDescriptor(functionType, modelId, modelInformation);
+            new ForecastModelInferenceDescriptor(functionType, modelInformation);
         analysis.setModelInferenceDescriptor(modelInferenceDescriptor);
 
         List<ResultColumn> newResultColumns = new ArrayList<>();
@@ -391,9 +392,13 @@ public class AnalyzeVisitor extends StatementVisitor<Analysis, MPPQueryContext>
           newResultColumns.add(new ResultColumn(inputExpression, ResultColumn.ColumnType.RAW));
         }
         queryStatement.getSelectComponent().setResultColumns(newResultColumns);
+
+        OrderByComponent descTimeOrder = new OrderByComponent();
+        descTimeOrder.addSortItem(new SortItem("TIME", Ordering.DESC));
+        queryStatement.setOrderByComponent(descTimeOrder);
         break;
       default:
-        throw new SemanticException("");
+        throw new IllegalArgumentException("");
     }
   }
 
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/LogicalPlanBuilder.java b/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/LogicalPlanBuilder.java
index 486b9f2f455..66d3b8ff2d2 100644
--- a/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/LogicalPlanBuilder.java
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/LogicalPlanBuilder.java
@@ -70,6 +70,7 @@ import org.apache.iotdb.db.mpp.plan.planner.plan.node.process.SortNode;
 import org.apache.iotdb.db.mpp.plan.planner.plan.node.process.TimeJoinNode;
 import org.apache.iotdb.db.mpp.plan.planner.plan.node.process.TransformNode;
 import org.apache.iotdb.db.mpp.plan.planner.plan.node.process.last.LastQueryNode;
+import org.apache.iotdb.db.mpp.plan.planner.plan.node.process.ml.ForecastNode;
 import org.apache.iotdb.db.mpp.plan.planner.plan.node.source.AlignedLastQueryScanNode;
 import org.apache.iotdb.db.mpp.plan.planner.plan.node.source.AlignedSeriesAggregationScanNode;
 import org.apache.iotdb.db.mpp.plan.planner.plan.node.source.AlignedSeriesScanNode;
@@ -87,6 +88,7 @@ import org.apache.iotdb.db.mpp.plan.planner.plan.parameter.GroupByParameter;
 import org.apache.iotdb.db.mpp.plan.planner.plan.parameter.GroupByTimeParameter;
 import org.apache.iotdb.db.mpp.plan.planner.plan.parameter.IntoPathDescriptor;
 import org.apache.iotdb.db.mpp.plan.planner.plan.parameter.OrderByParameter;
+import org.apache.iotdb.db.mpp.plan.planner.plan.parameter.model.ForecastModelInferenceDescriptor;
 import org.apache.iotdb.db.mpp.plan.statement.component.OrderByKey;
 import org.apache.iotdb.db.mpp.plan.statement.component.Ordering;
 import org.apache.iotdb.db.mpp.plan.statement.component.SortItem;
@@ -1253,4 +1255,12 @@ public class LogicalPlanBuilder {
     }
     return this;
   }
+
+  public LogicalPlanBuilder planForecast(
+      ForecastModelInferenceDescriptor forecastModelInferenceDescriptor) {
+    this.root =
+        new ForecastNode(
+            context.getQueryId().genPlanNodeId(), root, forecastModelInferenceDescriptor);
+    return this;
+  }
 }
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/LogicalPlanVisitor.java b/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/LogicalPlanVisitor.java
index 5a864398097..4d6006bc458 100644
--- a/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/LogicalPlanVisitor.java
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/LogicalPlanVisitor.java
@@ -49,6 +49,8 @@ import org.apache.iotdb.db.mpp.plan.planner.plan.node.write.InsertRowsOfOneDevic
 import org.apache.iotdb.db.mpp.plan.planner.plan.node.write.InsertTabletNode;
 import org.apache.iotdb.db.mpp.plan.planner.plan.parameter.AggregationStep;
 import org.apache.iotdb.db.mpp.plan.planner.plan.parameter.OrderByParameter;
+import org.apache.iotdb.db.mpp.plan.planner.plan.parameter.model.ForecastModelInferenceDescriptor;
+import org.apache.iotdb.db.mpp.plan.planner.plan.parameter.model.ModelInferenceDescriptor;
 import org.apache.iotdb.db.mpp.plan.statement.StatementNode;
 import org.apache.iotdb.db.mpp.plan.statement.StatementVisitor;
 import org.apache.iotdb.db.mpp.plan.statement.component.Ordering;
@@ -203,6 +205,21 @@ public class LogicalPlanVisitor extends StatementVisitor<PlanNode, MPPQueryConte
             .planOffset(queryStatement.getRowOffset())
             .planLimit(queryStatement.getRowLimit());
 
+    if (queryStatement.isModelInferenceQuery()) {
+      ModelInferenceDescriptor modelInferenceDescriptor = analysis.getModelInferenceDescriptor();
+      switch (modelInferenceDescriptor.getFunctionType()) {
+        case FORECAST:
+          ForecastModelInferenceDescriptor forecastModelInferenceDescriptor =
+              (ForecastModelInferenceDescriptor) modelInferenceDescriptor;
+          planBuilder
+              .planLimit(forecastModelInferenceDescriptor.getModelInputLength())
+              .planForecast(forecastModelInferenceDescriptor);
+          break;
+        default:
+          throw new IllegalArgumentException();
+      }
+    }
+
     // plan select into
     if (queryStatement.isAlignByDevice()) {
       planBuilder = planBuilder.planDeviceViewInto(analysis.getDeviceViewIntoPathDescriptor());
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/plan/node/PlanGraphPrinter.java b/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/plan/node/PlanGraphPrinter.java
index a9726d0c2f7..2ad21bb98f3 100644
--- a/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/plan/node/PlanGraphPrinter.java
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/plan/node/PlanGraphPrinter.java
@@ -45,6 +45,7 @@ import org.apache.iotdb.db.mpp.plan.planner.plan.node.process.TransformNode;
 import org.apache.iotdb.db.mpp.plan.planner.plan.node.process.last.LastQueryCollectNode;
 import org.apache.iotdb.db.mpp.plan.planner.plan.node.process.last.LastQueryMergeNode;
 import org.apache.iotdb.db.mpp.plan.planner.plan.node.process.last.LastQueryNode;
+import org.apache.iotdb.db.mpp.plan.planner.plan.node.process.ml.ForecastNode;
 import org.apache.iotdb.db.mpp.plan.planner.plan.node.sink.IdentitySinkNode;
 import org.apache.iotdb.db.mpp.plan.planner.plan.node.sink.ShuffleSinkNode;
 import org.apache.iotdb.db.mpp.plan.planner.plan.node.source.AlignedLastQueryScanNode;
@@ -443,6 +444,17 @@ public class PlanGraphPrinter extends PlanVisitor<List<String>, PlanGraphPrinter
     return render(node, boxValue, context);
   }
 
+  @Override
+  public List<String> visitForecast(ForecastNode node, GraphContext context) {
+    List<String> boxValue = new ArrayList<>();
+    boxValue.add(String.format("Forecast-%s", node.getPlanNodeId().getId()));
+    boxValue.add("Output: ");
+    for (String outputColumnName : node.getOutputColumnNames()) {
+      boxValue.add(String.format("  %s", outputColumnName));
+    }
+    return render(node, boxValue, context);
+  }
+
   private String printRegion(TRegionReplicaSet regionReplicaSet) {
     return String.format(
         "Partition: %s",
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/plan/node/PlanNodeType.java b/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/plan/node/PlanNodeType.java
index 762238ac915..8ddc931a9fc 100644
--- a/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/plan/node/PlanNodeType.java
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/plan/node/PlanNodeType.java
@@ -74,6 +74,7 @@ import org.apache.iotdb.db.mpp.plan.planner.plan.node.process.TransformNode;
 import org.apache.iotdb.db.mpp.plan.planner.plan.node.process.last.LastQueryCollectNode;
 import org.apache.iotdb.db.mpp.plan.planner.plan.node.process.last.LastQueryMergeNode;
 import org.apache.iotdb.db.mpp.plan.planner.plan.node.process.last.LastQueryNode;
+import org.apache.iotdb.db.mpp.plan.planner.plan.node.process.ml.ForecastNode;
 import org.apache.iotdb.db.mpp.plan.planner.plan.node.sink.IdentitySinkNode;
 import org.apache.iotdb.db.mpp.plan.planner.plan.node.sink.ShuffleSinkNode;
 import org.apache.iotdb.db.mpp.plan.planner.plan.node.source.AlignedLastQueryScanNode;
@@ -169,7 +170,8 @@ public enum PlanNodeType {
   IDENTITY_SINK((short) 70),
   SHUFFLE_SINK((short) 71),
   BATCH_ACTIVATE_TEMPLATE((short) 72),
-  CREATE_LOGICAL_VIEW((short) 73);
+  CREATE_LOGICAL_VIEW((short) 73),
+  FORECAST((short) 74);
 
   public static final int BYTES = Short.BYTES;
 
@@ -364,6 +366,8 @@ public enum PlanNodeType {
         return BatchActivateTemplateNode.deserialize(buffer);
       case 73:
         return CreateLogicalViewNode.deserialize(buffer);
+      case 74:
+        return ForecastNode.deserialize(buffer);
       default:
         throw new IllegalArgumentException("Invalid node type: " + nodeType);
     }
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/plan/node/PlanVisitor.java b/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/plan/node/PlanVisitor.java
index 771a8050145..bc07fd049aa 100644
--- a/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/plan/node/PlanVisitor.java
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/plan/node/PlanVisitor.java
@@ -74,6 +74,7 @@ import org.apache.iotdb.db.mpp.plan.planner.plan.node.process.TransformNode;
 import org.apache.iotdb.db.mpp.plan.planner.plan.node.process.last.LastQueryCollectNode;
 import org.apache.iotdb.db.mpp.plan.planner.plan.node.process.last.LastQueryMergeNode;
 import org.apache.iotdb.db.mpp.plan.planner.plan.node.process.last.LastQueryNode;
+import org.apache.iotdb.db.mpp.plan.planner.plan.node.process.ml.ForecastNode;
 import org.apache.iotdb.db.mpp.plan.planner.plan.node.sink.IdentitySinkNode;
 import org.apache.iotdb.db.mpp.plan.planner.plan.node.sink.ShuffleSinkNode;
 import org.apache.iotdb.db.mpp.plan.planner.plan.node.source.AlignedLastQueryScanNode;
@@ -187,6 +188,10 @@ public abstract class PlanVisitor<R, C> {
     return visitSingleChildProcess(node, context);
   }
 
+  public R visitForecast(ForecastNode node, C context) {
+    return visitSingleChildProcess(node, context);
+  }
+
   // multi child --------------------------------------------------------------------------------
 
   public R visitMultiChildProcess(MultiChildProcessNode node, C context) {
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/plan/node/process/ml/ForecastNode.java b/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/plan/node/process/ml/ForecastNode.java
new file mode 100644
index 00000000000..f71f1fe6631
--- /dev/null
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/plan/node/process/ml/ForecastNode.java
@@ -0,0 +1,118 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.iotdb.db.mpp.plan.planner.plan.node.process.ml;
+
+import org.apache.iotdb.db.mpp.plan.expression.Expression;
+import org.apache.iotdb.db.mpp.plan.planner.plan.node.PlanNode;
+import org.apache.iotdb.db.mpp.plan.planner.plan.node.PlanNodeId;
+import org.apache.iotdb.db.mpp.plan.planner.plan.node.PlanNodeType;
+import org.apache.iotdb.db.mpp.plan.planner.plan.node.PlanVisitor;
+import org.apache.iotdb.db.mpp.plan.planner.plan.node.process.SingleChildProcessNode;
+import org.apache.iotdb.db.mpp.plan.planner.plan.parameter.model.ForecastModelInferenceDescriptor;
+
+import java.io.DataOutputStream;
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Objects;
+
+public class ForecastNode extends SingleChildProcessNode {
+
+  private final ForecastModelInferenceDescriptor modelInferenceDescriptor;
+
+  private List<String> outputColumnNames;
+
+  public ForecastNode(
+      PlanNodeId id, PlanNode child, ForecastModelInferenceDescriptor modelInferenceDescriptor) {
+    super(id, child);
+    this.modelInferenceDescriptor = modelInferenceDescriptor;
+  }
+
+  public ForecastNode(PlanNodeId id, ForecastModelInferenceDescriptor modelInferenceDescriptor) {
+    super(id);
+    this.modelInferenceDescriptor = modelInferenceDescriptor;
+  }
+
+  @Override
+  public <R, C> R accept(PlanVisitor<R, C> visitor, C context) {
+    return visitor.visitForecast(this, context);
+  }
+
+  @Override
+  public PlanNode clone() {
+    return new ForecastNode(getPlanNodeId(), child, modelInferenceDescriptor);
+  }
+
+  @Override
+  public List<String> getOutputColumnNames() {
+    if (outputColumnNames == null) {
+      outputColumnNames = new ArrayList<>();
+      for (Expression expression : modelInferenceDescriptor.getModelInferenceOutputExpressions()) {
+        outputColumnNames.add(expression.toString());
+      }
+    }
+    return outputColumnNames;
+  }
+
+  @Override
+  protected void serializeAttributes(ByteBuffer byteBuffer) {
+    PlanNodeType.FORECAST.serialize(byteBuffer);
+    modelInferenceDescriptor.serialize(byteBuffer);
+  }
+
+  @Override
+  protected void serializeAttributes(DataOutputStream stream) throws IOException {
+    PlanNodeType.FORECAST.serialize(stream);
+    modelInferenceDescriptor.serialize(stream);
+  }
+
+  public static ForecastNode deserialize(ByteBuffer buffer) {
+    ForecastModelInferenceDescriptor modelInferenceDescriptor =
+        ForecastModelInferenceDescriptor.deserialize(buffer);
+    PlanNodeId planNodeId = PlanNodeId.deserialize(buffer);
+    return new ForecastNode(planNodeId, modelInferenceDescriptor);
+  }
+
+  @Override
+  public String toString() {
+    return "ForecastNode-" + this.getPlanNodeId();
+  }
+
+  @Override
+  public boolean equals(Object o) {
+    if (this == o) {
+      return true;
+    }
+    if (o == null || getClass() != o.getClass()) {
+      return false;
+    }
+    if (!super.equals(o)) {
+      return false;
+    }
+    ForecastNode that = (ForecastNode) o;
+    return modelInferenceDescriptor.equals(that.modelInferenceDescriptor);
+  }
+
+  @Override
+  public int hashCode() {
+    return Objects.hash(super.hashCode(), modelInferenceDescriptor);
+  }
+}
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/plan/parameter/model/ForecastModelInferenceDescriptor.java b/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/plan/parameter/model/ForecastModelInferenceDescriptor.java
index cd339b1c570..f47f319f54e 100644
--- a/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/plan/parameter/model/ForecastModelInferenceDescriptor.java
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/plan/parameter/model/ForecastModelInferenceDescriptor.java
@@ -22,10 +22,15 @@ package org.apache.iotdb.db.mpp.plan.planner.plan.parameter.model;
 import org.apache.iotdb.commons.model.ModelInformation;
 import org.apache.iotdb.commons.udf.builtin.ModelInferenceFunction;
 import org.apache.iotdb.tsfile.file.metadata.enums.TSDataType;
+import org.apache.iotdb.tsfile.utils.ReadWriteIOUtils;
 
-import java.util.Arrays;
+import java.io.DataOutputStream;
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.util.ArrayList;
 import java.util.LinkedHashMap;
 import java.util.List;
+import java.util.Objects;
 
 import static org.apache.iotdb.db.constant.SqlConstant.MODEL_ID;
 import static org.apache.iotdb.db.constant.SqlConstant.PREDICT_LENGTH;
@@ -39,16 +44,32 @@ public class ForecastModelInferenceDescriptor extends ModelInferenceDescriptor {
   private int modelPredictLength;
   private int expectedPredictLength;
 
-  private String parametersString;
   private LinkedHashMap<String, String> outputAttributes;
 
   public ForecastModelInferenceDescriptor(
-      ModelInferenceFunction functionType, String modelId, ModelInformation modelInformation) {
-    super(functionType, modelId);
+      ModelInferenceFunction functionType, ModelInformation modelInformation) {
+    super(functionType, modelInformation);
+  }
+
+  public ForecastModelInferenceDescriptor(ByteBuffer buffer) {
+    super(buffer);
+    int listSize = ReadWriteIOUtils.readInt(buffer);
+    this.inputTypeList = new ArrayList<>(listSize);
+    for (int i = 0; i < listSize; i++) {
+      this.inputTypeList.add(TSDataType.deserializeFrom(buffer));
+    }
+    listSize = ReadWriteIOUtils.readInt(buffer);
+    this.predictIndexList = new ArrayList<>(listSize);
+    for (int i = 0; i < listSize; i++) {
+      this.predictIndexList.add(ReadWriteIOUtils.readInt(buffer));
+    }
+    this.modelInputLength = ReadWriteIOUtils.readInt(buffer);
+    this.modelPredictLength = ReadWriteIOUtils.readInt(buffer);
+    this.expectedPredictLength = ReadWriteIOUtils.readInt(buffer);
   }
 
   public List<Integer> getPredictIndexList() {
-    return Arrays.asList(0, 1);
+    return predictIndexList;
   }
 
   public void setPredictIndexList(List<Integer> predictIndexList) {
@@ -56,30 +77,15 @@ public class ForecastModelInferenceDescriptor extends ModelInferenceDescriptor {
   }
 
   public List<TSDataType> getInputTypeList() {
-    return Arrays.asList(TSDataType.FLOAT, TSDataType.FLOAT);
+    return inputTypeList;
   }
 
   public void setInputTypeList(List<TSDataType> inputTypeList) {
     this.inputTypeList = inputTypeList;
   }
 
-  @Override
-  public String getParametersString() {
-    if (parametersString == null) {
-      StringBuilder builder = new StringBuilder();
-      builder.append("\"").append(MODEL_ID).append("\"=\"").append(modelId).append("\"");
-      if (expectedPredictLength != modelPredictLength) {
-        builder
-            .append(", ")
-            .append("\"")
-            .append(PREDICT_LENGTH)
-            .append("\"=\"")
-            .append(expectedPredictLength)
-            .append("\"");
-      }
-      parametersString = builder.toString();
-    }
-    return parametersString;
+  public int getModelInputLength() {
+    return modelInputLength;
   }
 
   @Override
@@ -93,4 +99,70 @@ public class ForecastModelInferenceDescriptor extends ModelInferenceDescriptor {
     }
     return outputAttributes;
   }
+
+  @Override
+  public void serialize(ByteBuffer byteBuffer) {
+    super.serialize(byteBuffer);
+    ReadWriteIOUtils.write(inputTypeList.size(), byteBuffer);
+    for (TSDataType dataType : inputTypeList) {
+      dataType.serializeTo(byteBuffer);
+    }
+    ReadWriteIOUtils.write(predictIndexList.size(), byteBuffer);
+    for (Integer index : predictIndexList) {
+      ReadWriteIOUtils.write(index, byteBuffer);
+    }
+    ReadWriteIOUtils.write(modelInputLength, byteBuffer);
+    ReadWriteIOUtils.write(modelPredictLength, byteBuffer);
+    ReadWriteIOUtils.write(expectedPredictLength, byteBuffer);
+  }
+
+  @Override
+  public void serialize(DataOutputStream stream) throws IOException {
+    super.serialize(stream);
+    ReadWriteIOUtils.write(inputTypeList.size(), stream);
+    for (TSDataType dataType : inputTypeList) {
+      dataType.serializeTo(stream);
+    }
+    ReadWriteIOUtils.write(predictIndexList.size(), stream);
+    for (Integer index : predictIndexList) {
+      ReadWriteIOUtils.write(index, stream);
+    }
+    ReadWriteIOUtils.write(modelInputLength, stream);
+    ReadWriteIOUtils.write(modelPredictLength, stream);
+    ReadWriteIOUtils.write(expectedPredictLength, stream);
+  }
+
+  public static ForecastModelInferenceDescriptor deserialize(ByteBuffer buffer) {
+    return new ForecastModelInferenceDescriptor(buffer);
+  }
+
+  @Override
+  public boolean equals(Object o) {
+    if (this == o) {
+      return true;
+    }
+    if (o == null || getClass() != o.getClass()) {
+      return false;
+    }
+    if (!super.equals(o)) {
+      return false;
+    }
+    ForecastModelInferenceDescriptor that = (ForecastModelInferenceDescriptor) o;
+    return modelInputLength == that.modelInputLength
+        && modelPredictLength == that.modelPredictLength
+        && expectedPredictLength == that.expectedPredictLength
+        && inputTypeList.equals(that.inputTypeList)
+        && predictIndexList.equals(that.predictIndexList);
+  }
+
+  @Override
+  public int hashCode() {
+    return Objects.hash(
+        super.hashCode(),
+        inputTypeList,
+        predictIndexList,
+        modelInputLength,
+        modelPredictLength,
+        expectedPredictLength);
+  }
 }
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/plan/parameter/model/ModelInferenceDescriptor.java b/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/plan/parameter/model/ModelInferenceDescriptor.java
index eb0f29b38c2..db11e50f186 100644
--- a/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/plan/parameter/model/ModelInferenceDescriptor.java
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/plan/parameter/model/ModelInferenceDescriptor.java
@@ -19,11 +19,17 @@
 
 package org.apache.iotdb.db.mpp.plan.planner.plan.parameter.model;
 
+import org.apache.iotdb.commons.model.ModelInformation;
 import org.apache.iotdb.commons.udf.builtin.ModelInferenceFunction;
 import org.apache.iotdb.db.mpp.plan.expression.multi.FunctionExpression;
+import org.apache.iotdb.tsfile.utils.ReadWriteIOUtils;
 
+import java.io.DataOutputStream;
+import java.io.IOException;
+import java.nio.ByteBuffer;
 import java.util.LinkedHashMap;
 import java.util.List;
+import java.util.Objects;
 
 public abstract class ModelInferenceDescriptor {
 
@@ -31,13 +37,21 @@ public abstract class ModelInferenceDescriptor {
 
   protected final String modelId;
 
-  protected String modelPath;
+  protected final String modelPath;
 
   protected List<FunctionExpression> modelInferenceOutputExpressions;
 
-  public ModelInferenceDescriptor(ModelInferenceFunction functionType, String modelId) {
+  public ModelInferenceDescriptor(
+      ModelInferenceFunction functionType, ModelInformation modelInformation) {
     this.functionType = functionType;
-    this.modelId = modelId;
+    this.modelId = modelInformation.getModelId();
+    this.modelPath = modelInformation.getModelPath();
+  }
+
+  public ModelInferenceDescriptor(ByteBuffer buffer) {
+    this.functionType = ModelInferenceFunction.values()[ReadWriteIOUtils.readInt(buffer)];
+    this.modelId = ReadWriteIOUtils.readString(buffer);
+    this.modelPath = ReadWriteIOUtils.readString(buffer);
   }
 
   public ModelInferenceFunction getFunctionType() {
@@ -57,7 +71,37 @@ public abstract class ModelInferenceDescriptor {
     this.modelInferenceOutputExpressions = modelInferenceOutputExpressions;
   }
 
-  public abstract String getParametersString();
-
   public abstract LinkedHashMap<String, String> getOutputAttributes();
+
+  public void serialize(ByteBuffer byteBuffer) {
+    ReadWriteIOUtils.write(functionType.ordinal(), byteBuffer);
+    ReadWriteIOUtils.write(modelId, byteBuffer);
+    ReadWriteIOUtils.write(modelPath, byteBuffer);
+  }
+
+  public void serialize(DataOutputStream stream) throws IOException {
+    ReadWriteIOUtils.write(functionType.ordinal(), stream);
+    ReadWriteIOUtils.write(modelId, stream);
+    ReadWriteIOUtils.write(modelPath, stream);
+  }
+
+  @Override
+  public boolean equals(Object o) {
+    if (this == o) {
+      return true;
+    }
+    if (o == null || getClass() != o.getClass()) {
+      return false;
+    }
+    ModelInferenceDescriptor that = (ModelInferenceDescriptor) o;
+    return functionType == that.functionType
+        && modelId.equals(that.modelId)
+        && modelPath.equals(that.modelPath)
+        && modelInferenceOutputExpressions.equals(that.modelInferenceOutputExpressions);
+  }
+
+  @Override
+  public int hashCode() {
+    return Objects.hash(functionType, modelId, modelPath, modelInferenceOutputExpressions);
+  }
 }
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/plan/statement/crud/QueryStatement.java b/server/src/main/java/org/apache/iotdb/db/mpp/plan/statement/crud/QueryStatement.java
index decdfbeca47..89b58cb197a 100644
--- a/server/src/main/java/org/apache/iotdb/db/mpp/plan/statement/crud/QueryStatement.java
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/plan/statement/crud/QueryStatement.java
@@ -482,7 +482,15 @@ public class QueryStatement extends Statement {
           || isLastQuery()
           || seriesLimit > 0
           || seriesOffset > 0
-          || isSelectInto()) {
+          || isSelectInto()
+          || isOrderByDevice()
+          || isOrderByTimeseries()) {
+        throw new SemanticException("");
+      }
+
+      if (orderByComponent != null
+          && (!orderByComponent.isOrderByTime()
+              || orderByComponent.getTimeOrder() != Ordering.ASC)) {
         throw new SemanticException("");
       }
     }


[iotdb] 03/03: add model available check

Posted by hu...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

hui pushed a commit to branch lmh/forecast
in repository https://gitbox.apache.org/repos/asf/iotdb.git

commit 7f2fd39a826f5288bd081ea30bbbe01c6c6dad64
Author: Minghui Liu <li...@foxmail.com>
AuthorDate: Mon May 15 17:13:48 2023 +0800

    add model available check
---
 .../main/java/org/apache/iotdb/commons/model/ModelInformation.java    | 4 ++++
 .../java/org/apache/iotdb/db/mpp/plan/analyze/AnalyzeVisitor.java     | 2 +-
 2 files changed, 5 insertions(+), 1 deletion(-)

diff --git a/node-commons/src/main/java/org/apache/iotdb/commons/model/ModelInformation.java b/node-commons/src/main/java/org/apache/iotdb/commons/model/ModelInformation.java
index f052c8d2121..7c46a6c3437 100644
--- a/node-commons/src/main/java/org/apache/iotdb/commons/model/ModelInformation.java
+++ b/node-commons/src/main/java/org/apache/iotdb/commons/model/ModelInformation.java
@@ -152,6 +152,10 @@ public class ModelInformation {
     return queryFilter;
   }
 
+  public boolean available() {
+    return trainingState == TrainingState.FINISHED;
+  }
+
   public TrailInformation getTrailInformationById(String trailId) {
     if (trailMap.containsKey(trailId)) {
       return trailMap.get(trailId);
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/plan/analyze/AnalyzeVisitor.java b/server/src/main/java/org/apache/iotdb/db/mpp/plan/analyze/AnalyzeVisitor.java
index f96f3de274a..5034007063e 100644
--- a/server/src/main/java/org/apache/iotdb/db/mpp/plan/analyze/AnalyzeVisitor.java
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/plan/analyze/AnalyzeVisitor.java
@@ -375,7 +375,7 @@ public class AnalyzeVisitor extends StatementVisitor<Analysis, MPPQueryContext>
     String modelId = modelInferenceExpression.getFunctionAttributes().get(MODEL_ID);
 
     ModelInformation modelInformation = partitionFetcher.getModelInformation(modelId);
-    if (modelInformation == null) {
+    if (modelInformation == null || !modelInformation.available()) {
       throw new SemanticException("");
     }
 


[iotdb] 01/03: implement analyzer (finish)

Posted by hu...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

hui pushed a commit to branch lmh/forecast
in repository https://gitbox.apache.org/repos/asf/iotdb.git

commit 4eb0f8c25561e8324f6709d6ef662f707e7bbbe5
Author: Minghui Liu <li...@foxmail.com>
AuthorDate: Mon May 15 15:32:48 2023 +0800

    implement analyzer (finish)
---
 .../org/apache/iotdb/db/constant/SqlConstant.java  |  3 +
 .../apache/iotdb/db/mpp/plan/analyze/Analysis.java | 11 +++
 .../iotdb/db/mpp/plan/analyze/AnalyzeVisitor.java  | 83 ++++++++++++++++++-
 .../mpp/plan/analyze/ExpressionTypeAnalyzer.java   |  2 +
 .../plan/expression/multi/FunctionExpression.java  |  2 +-
 .../plan/parameter/ModelInferenceDescriptor.java   | 22 -----
 .../model/ForecastModelInferenceDescriptor.java    | 96 ++++++++++++++++++++++
 .../parameter/model/ModelInferenceDescriptor.java  | 63 ++++++++++++++
 .../db/mpp/plan/statement/crud/QueryStatement.java |  7 ++
 9 files changed, 265 insertions(+), 24 deletions(-)

diff --git a/server/src/main/java/org/apache/iotdb/db/constant/SqlConstant.java b/server/src/main/java/org/apache/iotdb/db/constant/SqlConstant.java
index 4f18776b6b5..976862547cb 100644
--- a/server/src/main/java/org/apache/iotdb/db/constant/SqlConstant.java
+++ b/server/src/main/java/org/apache/iotdb/db/constant/SqlConstant.java
@@ -78,6 +78,9 @@ public class SqlConstant {
   public static final String SUBSTRING_IS_STANDARD = "isStandard";
   public static final String SUBSTRING_FOR = "FOR";
 
+  public static final String MODEL_ID = "model_id";
+  public static final String PREDICT_LENGTH = "predict_length";
+
   public static String[] getSingleRootArray() {
     return SINGLE_ROOT_ARRAY;
   }
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/plan/analyze/Analysis.java b/server/src/main/java/org/apache/iotdb/db/mpp/plan/analyze/Analysis.java
index 6b3f4780a67..96ae2299974 100644
--- a/server/src/main/java/org/apache/iotdb/db/mpp/plan/analyze/Analysis.java
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/plan/analyze/Analysis.java
@@ -38,6 +38,7 @@ import org.apache.iotdb.db.mpp.plan.planner.plan.parameter.GroupByParameter;
 import org.apache.iotdb.db.mpp.plan.planner.plan.parameter.GroupByTimeParameter;
 import org.apache.iotdb.db.mpp.plan.planner.plan.parameter.IntoPathDescriptor;
 import org.apache.iotdb.db.mpp.plan.planner.plan.parameter.OrderByParameter;
+import org.apache.iotdb.db.mpp.plan.planner.plan.parameter.model.ModelInferenceDescriptor;
 import org.apache.iotdb.db.mpp.plan.statement.Statement;
 import org.apache.iotdb.db.mpp.plan.statement.component.SortItem;
 import org.apache.iotdb.db.mpp.plan.statement.crud.QueryStatement;
@@ -199,6 +200,8 @@ public class Analysis {
   // indicate whether the Nodes produce source data are VirtualSourceNodes
   private boolean isVirtualSource = false;
 
+  private ModelInferenceDescriptor modelInferenceDescriptor;
+
   /////////////////////////////////////////////////////////////////////////////////////////////////
   // SELECT INTO Analysis
   /////////////////////////////////////////////////////////////////////////////////////////////////
@@ -666,4 +669,12 @@ public class Analysis {
   public void setDeviceToSortItems(Map<String, List<SortItem>> deviceToSortItems) {
     this.deviceToSortItems = deviceToSortItems;
   }
+
+  public ModelInferenceDescriptor getModelInferenceDescriptor() {
+    return modelInferenceDescriptor;
+  }
+
+  public void setModelInferenceDescriptor(ModelInferenceDescriptor modelInferenceDescriptor) {
+    this.modelInferenceDescriptor = modelInferenceDescriptor;
+  }
 }
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/plan/analyze/AnalyzeVisitor.java b/server/src/main/java/org/apache/iotdb/db/mpp/plan/analyze/AnalyzeVisitor.java
index 973c02a9602..39b76dbfafb 100644
--- a/server/src/main/java/org/apache/iotdb/db/mpp/plan/analyze/AnalyzeVisitor.java
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/plan/analyze/AnalyzeVisitor.java
@@ -24,6 +24,7 @@ import org.apache.iotdb.commons.client.exception.ClientManagerException;
 import org.apache.iotdb.commons.conf.IoTDBConstant;
 import org.apache.iotdb.commons.exception.IllegalPathException;
 import org.apache.iotdb.commons.exception.MetadataException;
+import org.apache.iotdb.commons.model.ModelInformation;
 import org.apache.iotdb.commons.partition.DataPartition;
 import org.apache.iotdb.commons.partition.DataPartitionQueryParam;
 import org.apache.iotdb.commons.partition.SchemaNodeManagementPartition;
@@ -31,6 +32,7 @@ import org.apache.iotdb.commons.partition.SchemaPartition;
 import org.apache.iotdb.commons.path.MeasurementPath;
 import org.apache.iotdb.commons.path.PartialPath;
 import org.apache.iotdb.commons.path.PathPatternTree;
+import org.apache.iotdb.commons.udf.builtin.ModelInferenceFunction;
 import org.apache.iotdb.confignode.rpc.thrift.TGetDataNodeLocationsResp;
 import org.apache.iotdb.db.client.ConfigNodeClient;
 import org.apache.iotdb.db.client.ConfigNodeClientManager;
@@ -75,6 +77,8 @@ import org.apache.iotdb.db.mpp.plan.planner.plan.parameter.GroupByTimeParameter;
 import org.apache.iotdb.db.mpp.plan.planner.plan.parameter.GroupByVariationParameter;
 import org.apache.iotdb.db.mpp.plan.planner.plan.parameter.IntoPathDescriptor;
 import org.apache.iotdb.db.mpp.plan.planner.plan.parameter.OrderByParameter;
+import org.apache.iotdb.db.mpp.plan.planner.plan.parameter.model.ForecastModelInferenceDescriptor;
+import org.apache.iotdb.db.mpp.plan.planner.plan.parameter.model.ModelInferenceDescriptor;
 import org.apache.iotdb.db.mpp.plan.statement.Statement;
 import org.apache.iotdb.db.mpp.plan.statement.StatementNode;
 import org.apache.iotdb.db.mpp.plan.statement.StatementVisitor;
@@ -182,6 +186,8 @@ import static org.apache.iotdb.commons.conf.IoTDBConstant.ALLOWED_SCHEMA_PROPS;
 import static org.apache.iotdb.commons.conf.IoTDBConstant.DEADBAND;
 import static org.apache.iotdb.commons.conf.IoTDBConstant.LOSS;
 import static org.apache.iotdb.commons.conf.IoTDBConstant.ONE_LEVEL_PATH_WILDCARD;
+import static org.apache.iotdb.commons.udf.builtin.ModelInferenceFunction.FORECAST;
+import static org.apache.iotdb.db.constant.SqlConstant.MODEL_ID;
 import static org.apache.iotdb.db.mpp.common.header.ColumnHeaderConstant.DEVICE;
 import static org.apache.iotdb.db.mpp.common.header.ColumnHeaderConstant.ENDTIME;
 import static org.apache.iotdb.db.mpp.metric.QueryPlanCostMetricSet.PARTITION_FETCHER;
@@ -361,7 +367,35 @@ public class AnalyzeVisitor extends StatementVisitor<Analysis, MPPQueryContext>
     return analysis;
   }
 
-  private void analyzeModelInference(Analysis analysis, QueryStatement queryStatement) {}
+  private void analyzeModelInference(Analysis analysis, QueryStatement queryStatement) {
+    FunctionExpression modelInferenceExpression =
+        (FunctionExpression)
+            queryStatement.getSelectComponent().getResultColumns().get(0).getExpression();
+    String modelId = modelInferenceExpression.getFunctionAttributes().get(MODEL_ID);
+
+    ModelInformation modelInformation = partitionFetcher.getModelInformation(modelId);
+    if (modelInformation == null) {
+      // throw new SemanticException("");
+    }
+
+    ModelInferenceFunction functionType =
+        ModelInferenceFunction.valueOf(modelInferenceExpression.getFunctionName().toUpperCase());
+    switch (functionType) {
+      case FORECAST:
+        ModelInferenceDescriptor modelInferenceDescriptor =
+            new ForecastModelInferenceDescriptor(functionType, modelId, modelInformation);
+        analysis.setModelInferenceDescriptor(modelInferenceDescriptor);
+
+        List<ResultColumn> newResultColumns = new ArrayList<>();
+        for (Expression inputExpression : modelInferenceExpression.getExpressions()) {
+          newResultColumns.add(new ResultColumn(inputExpression, ResultColumn.ColumnType.RAW));
+        }
+        queryStatement.getSelectComponent().setResultColumns(newResultColumns);
+        break;
+      default:
+        throw new SemanticException("");
+    }
+  }
 
   private Analysis finishQuery(QueryStatement queryStatement, Analysis analysis) {
     if (queryStatement.isSelectInto()) {
@@ -1215,6 +1249,53 @@ public class AnalyzeVisitor extends StatementVisitor<Analysis, MPPQueryContext>
       return;
     }
 
+    if (queryStatement.isModelInferenceQuery()) {
+      List<ColumnHeader> columnHeaders = new ArrayList<>();
+      boolean isIgnoreTimestamp;
+
+      ModelInferenceDescriptor modelInferenceDescriptor = analysis.getModelInferenceDescriptor();
+      switch (modelInferenceDescriptor.getFunctionType()) {
+        case FORECAST:
+          isIgnoreTimestamp = false;
+          ForecastModelInferenceDescriptor forecastModelInferenceDescriptor =
+              (ForecastModelInferenceDescriptor) modelInferenceDescriptor;
+
+          List<TSDataType> inputTypeList = forecastModelInferenceDescriptor.getInputTypeList();
+          if (outputExpressions.size() != inputTypeList.size()) {
+            throw new SemanticException("");
+          }
+          for (int i = 0; i < inputTypeList.size(); i++) {
+            Expression inputExpression = outputExpressions.get(i).left;
+            if (analysis.getType(inputExpression) != inputTypeList.get(i)) {
+              throw new SemanticException("");
+            }
+          }
+
+          List<FunctionExpression> modelInferenceOutputExpressions = new ArrayList<>();
+          for (int predictIndex : forecastModelInferenceDescriptor.getPredictIndexList()) {
+            Expression inputExpression = outputExpressions.get(predictIndex).left;
+            FunctionExpression modelInferenceOutputExpression =
+                new FunctionExpression(
+                    FORECAST.getFunctionName(),
+                    forecastModelInferenceDescriptor.getOutputAttributes(),
+                    Collections.singletonList(inputExpression));
+            analyzeExpression(analysis, modelInferenceOutputExpression);
+            modelInferenceOutputExpressions.add(modelInferenceOutputExpression);
+            columnHeaders.add(
+                new ColumnHeader(
+                    modelInferenceOutputExpression.toString(),
+                    analysis.getType(modelInferenceOutputExpression)));
+          }
+          forecastModelInferenceDescriptor.setModelInferenceOutputExpressions(
+              modelInferenceOutputExpressions);
+          break;
+        default:
+          throw new SemanticException("");
+      }
+      analysis.setRespDatasetHeader(new DatasetHeader(columnHeaders, isIgnoreTimestamp));
+      return;
+    }
+
     boolean isIgnoreTimestamp = queryStatement.isAggregationQuery() && !queryStatement.isGroupBy();
     List<ColumnHeader> columnHeaders = new ArrayList<>();
     if (queryStatement.isAlignByDevice()) {
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/plan/analyze/ExpressionTypeAnalyzer.java b/server/src/main/java/org/apache/iotdb/db/mpp/plan/analyze/ExpressionTypeAnalyzer.java
index ee2f2e1b73d..754b948e006 100644
--- a/server/src/main/java/org/apache/iotdb/db/mpp/plan/analyze/ExpressionTypeAnalyzer.java
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/plan/analyze/ExpressionTypeAnalyzer.java
@@ -265,6 +265,8 @@ public class ExpressionTypeAnalyzer {
             functionExpression,
             TypeInferenceUtils.getBuiltInScalarFunctionDataType(
                 functionExpression, expressionTypes.get(NodeRef.of(inputExpressions.get(0)))));
+      } else if (functionExpression.isModelInferenceFunction()) {
+        return setExpressionType(functionExpression, TSDataType.DOUBLE);
       } else {
         return setExpressionType(
             functionExpression,
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/plan/expression/multi/FunctionExpression.java b/server/src/main/java/org/apache/iotdb/db/mpp/plan/expression/multi/FunctionExpression.java
index b9b8960cb48..6295d0e5d12 100644
--- a/server/src/main/java/org/apache/iotdb/db/mpp/plan/expression/multi/FunctionExpression.java
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/plan/expression/multi/FunctionExpression.java
@@ -161,7 +161,7 @@ public class FunctionExpression extends Expression {
   }
 
   public void addAttribute(String key, String value) {
-    functionAttributes.put(key, value);
+    functionAttributes.put(key.toLowerCase(), value);
   }
 
   public void addExpression(Expression expression) {
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/plan/parameter/ModelInferenceDescriptor.java b/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/plan/parameter/ModelInferenceDescriptor.java
deleted file mode 100644
index 9e2c7dc1817..00000000000
--- a/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/plan/parameter/ModelInferenceDescriptor.java
+++ /dev/null
@@ -1,22 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.iotdb.db.mpp.plan.planner.plan.parameter;
-
-public class ModelInferenceDescriptor {}
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/plan/parameter/model/ForecastModelInferenceDescriptor.java b/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/plan/parameter/model/ForecastModelInferenceDescriptor.java
new file mode 100644
index 00000000000..cd339b1c570
--- /dev/null
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/plan/parameter/model/ForecastModelInferenceDescriptor.java
@@ -0,0 +1,96 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.iotdb.db.mpp.plan.planner.plan.parameter.model;
+
+import org.apache.iotdb.commons.model.ModelInformation;
+import org.apache.iotdb.commons.udf.builtin.ModelInferenceFunction;
+import org.apache.iotdb.tsfile.file.metadata.enums.TSDataType;
+
+import java.util.Arrays;
+import java.util.LinkedHashMap;
+import java.util.List;
+
+import static org.apache.iotdb.db.constant.SqlConstant.MODEL_ID;
+import static org.apache.iotdb.db.constant.SqlConstant.PREDICT_LENGTH;
+
+public class ForecastModelInferenceDescriptor extends ModelInferenceDescriptor {
+
+  private List<TSDataType> inputTypeList;
+  private List<Integer> predictIndexList;
+
+  private int modelInputLength;
+  private int modelPredictLength;
+  private int expectedPredictLength;
+
+  private String parametersString;
+  private LinkedHashMap<String, String> outputAttributes;
+
+  public ForecastModelInferenceDescriptor(
+      ModelInferenceFunction functionType, String modelId, ModelInformation modelInformation) {
+    super(functionType, modelId);
+  }
+
+  public List<Integer> getPredictIndexList() {
+    return Arrays.asList(0, 1);
+  }
+
+  public void setPredictIndexList(List<Integer> predictIndexList) {
+    this.predictIndexList = predictIndexList;
+  }
+
+  public List<TSDataType> getInputTypeList() {
+    return Arrays.asList(TSDataType.FLOAT, TSDataType.FLOAT);
+  }
+
+  public void setInputTypeList(List<TSDataType> inputTypeList) {
+    this.inputTypeList = inputTypeList;
+  }
+
+  @Override
+  public String getParametersString() {
+    if (parametersString == null) {
+      StringBuilder builder = new StringBuilder();
+      builder.append("\"").append(MODEL_ID).append("\"=\"").append(modelId).append("\"");
+      if (expectedPredictLength != modelPredictLength) {
+        builder
+            .append(", ")
+            .append("\"")
+            .append(PREDICT_LENGTH)
+            .append("\"=\"")
+            .append(expectedPredictLength)
+            .append("\"");
+      }
+      parametersString = builder.toString();
+    }
+    return parametersString;
+  }
+
+  @Override
+  public LinkedHashMap<String, String> getOutputAttributes() {
+    if (outputAttributes == null) {
+      outputAttributes = new LinkedHashMap<>();
+      outputAttributes.put(MODEL_ID, modelId);
+      if (expectedPredictLength != modelPredictLength) {
+        outputAttributes.put(PREDICT_LENGTH, String.valueOf(expectedPredictLength));
+      }
+    }
+    return outputAttributes;
+  }
+}
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/plan/parameter/model/ModelInferenceDescriptor.java b/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/plan/parameter/model/ModelInferenceDescriptor.java
new file mode 100644
index 00000000000..eb0f29b38c2
--- /dev/null
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/plan/parameter/model/ModelInferenceDescriptor.java
@@ -0,0 +1,63 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.iotdb.db.mpp.plan.planner.plan.parameter.model;
+
+import org.apache.iotdb.commons.udf.builtin.ModelInferenceFunction;
+import org.apache.iotdb.db.mpp.plan.expression.multi.FunctionExpression;
+
+import java.util.LinkedHashMap;
+import java.util.List;
+
+public abstract class ModelInferenceDescriptor {
+
+  protected final ModelInferenceFunction functionType;
+
+  protected final String modelId;
+
+  protected String modelPath;
+
+  protected List<FunctionExpression> modelInferenceOutputExpressions;
+
+  public ModelInferenceDescriptor(ModelInferenceFunction functionType, String modelId) {
+    this.functionType = functionType;
+    this.modelId = modelId;
+  }
+
+  public ModelInferenceFunction getFunctionType() {
+    return functionType;
+  }
+
+  public String getModelId() {
+    return modelId;
+  }
+
+  public List<FunctionExpression> getModelInferenceOutputExpressions() {
+    return modelInferenceOutputExpressions;
+  }
+
+  public void setModelInferenceOutputExpressions(
+      List<FunctionExpression> modelInferenceOutputExpressions) {
+    this.modelInferenceOutputExpressions = modelInferenceOutputExpressions;
+  }
+
+  public abstract String getParametersString();
+
+  public abstract LinkedHashMap<String, String> getOutputAttributes();
+}
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/plan/statement/crud/QueryStatement.java b/server/src/main/java/org/apache/iotdb/db/mpp/plan/statement/crud/QueryStatement.java
index 4f02cc41a77..decdfbeca47 100644
--- a/server/src/main/java/org/apache/iotdb/db/mpp/plan/statement/crud/QueryStatement.java
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/plan/statement/crud/QueryStatement.java
@@ -51,6 +51,8 @@ import java.util.HashSet;
 import java.util.List;
 import java.util.Set;
 
+import static org.apache.iotdb.db.constant.SqlConstant.MODEL_ID;
+
 /**
  * Base class of SELECT statement.
  *
@@ -463,6 +465,11 @@ public class QueryStatement extends Statement {
           && ((FunctionExpression) modelInferenceExpression).isModelInferenceFunction())) {
         throw new SemanticException("");
       }
+      if (!((FunctionExpression) modelInferenceExpression)
+          .getFunctionAttributes()
+          .containsKey(MODEL_ID)) {
+        throw new SemanticException("");
+      }
       if (ExpressionAnalyzer.searchAggregationExpressions(modelInferenceExpression).size() > 0) {
         throw new SemanticException("");
       }