You are viewing a plain-text version of this content. (The hyperlink to the canonical version was not preserved in this plain-text rendering.)
Posted to commits@iotdb.apache.org by hu...@apache.org on 2023/02/17 07:00:52 UTC

[iotdb] 16/16: implement ModelInfo

This is an automated email from the ASF dual-hosted git repository.

hui pushed a commit to branch lmh/modelManager
in repository https://gitbox.apache.org/repos/asf/iotdb.git

commit 7202cba5b6e22c3dc914edc1367c721c00eeaa02
Author: Minghui Liu <li...@foxmail.com>
AuthorDate: Fri Feb 17 15:00:10 2023 +0800

    implement ModelInfo
---
 .../request/write/model/UpdateModelInfoPlan.java   |  14 +-
 .../consensus/response/ModelTableResp.java         |  11 +-
 .../consensus/response/TrailTableResp.java         |  11 +-
 .../iotdb/confignode/manager/ModelManager.java     |  11 +-
 .../iotdb/confignode/manager/ProcedureManager.java |   5 +-
 .../iotdb/confignode/persistence/ModelInfo.java    |  43 +++---
 .../procedure/impl/model/CreateModelProcedure.java |  28 ++--
 .../procedure/state/model/CreateModelState.java    |   5 +-
 .../commons/model/ForecastTrailInformation.java    |  32 -----
 .../iotdb/commons/model/ModelHyperparameter.java   |  71 +++++++++
 .../iotdb/commons/model/ModelInformation.java      | 158 ++++++++++++++++++---
 .../org/apache/iotdb/commons/model/ModelTable.java |   4 +-
 .../iotdb/commons/model/TrailInformation.java      |  84 +++++++++--
 thrift-commons/src/main/thrift/common.thrift       |  12 +-
 .../src/main/thrift/confignode.thrift              |  26 +++-
 15 files changed, 390 insertions(+), 125 deletions(-)

diff --git a/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/write/model/UpdateModelInfoPlan.java b/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/write/model/UpdateModelInfoPlan.java
index d1eacadae2..52c1100a00 100644
--- a/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/write/model/UpdateModelInfoPlan.java
+++ b/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/write/model/UpdateModelInfoPlan.java
@@ -33,6 +33,7 @@ import java.util.Objects;
 public class UpdateModelInfoPlan extends ConfigPhysicalPlan {
 
   private String modelId;
+  private String trailId;
   private Map<String, String> modelInfo;
 
   public UpdateModelInfoPlan() {
@@ -42,6 +43,7 @@ public class UpdateModelInfoPlan extends ConfigPhysicalPlan {
   public UpdateModelInfoPlan(TUpdateModelInfoReq updateModelInfoReq) {
     super(ConfigPhysicalPlanType.UpdateModelInfo);
     this.modelId = updateModelInfoReq.getModelId();
+    this.trailId = updateModelInfoReq.getTrailId();
     this.modelInfo = updateModelInfoReq.getModelInfo();
   }
 
@@ -49,6 +51,10 @@ public class UpdateModelInfoPlan extends ConfigPhysicalPlan {
     return modelId;
   }
 
+  public String getTrailId() {
+    return trailId;
+  }
+
   public Map<String, String> getModelInfo() {
     return modelInfo;
   }
@@ -57,12 +63,14 @@ public class UpdateModelInfoPlan extends ConfigPhysicalPlan {
   protected void serializeImpl(DataOutputStream stream) throws IOException {
     stream.writeShort(getType().getPlanType());
     ReadWriteIOUtils.write(modelId, stream);
+    ReadWriteIOUtils.write(trailId, stream);
     ReadWriteIOUtils.write(modelInfo, stream);
   }
 
   @Override
   protected void deserializeImpl(ByteBuffer buffer) throws IOException {
     this.modelId = ReadWriteIOUtils.readString(buffer);
+    this.trailId = ReadWriteIOUtils.readString(buffer);
     this.modelInfo = ReadWriteIOUtils.readMap(buffer);
   }
 
@@ -78,11 +86,13 @@ public class UpdateModelInfoPlan extends ConfigPhysicalPlan {
       return false;
     }
     UpdateModelInfoPlan that = (UpdateModelInfoPlan) o;
-    return modelId.equals(that.modelId) && modelInfo.equals(that.modelInfo);
+    return modelId.equals(that.modelId)
+        && trailId.equals(that.trailId)
+        && modelInfo.equals(that.modelInfo);
   }
 
   @Override
   public int hashCode() {
-    return Objects.hash(super.hashCode(), modelId, modelInfo);
+    return Objects.hash(super.hashCode(), modelId, trailId, modelInfo);
   }
 }
diff --git a/confignode/src/main/java/org/apache/iotdb/confignode/consensus/response/ModelTableResp.java b/confignode/src/main/java/org/apache/iotdb/confignode/consensus/response/ModelTableResp.java
index 82787ca653..6642f76be9 100644
--- a/confignode/src/main/java/org/apache/iotdb/confignode/consensus/response/ModelTableResp.java
+++ b/confignode/src/main/java/org/apache/iotdb/confignode/consensus/response/ModelTableResp.java
@@ -34,14 +34,21 @@ public class ModelTableResp implements DataSet {
   private final TSStatus status;
   private final List<ByteBuffer> serializedAllModelInformation;
 
-  public ModelTableResp(TSStatus status, List<ModelInformation> allModelInformation) {
+  public ModelTableResp(TSStatus status) {
     this.status = status;
     this.serializedAllModelInformation = new ArrayList<>();
-    for (ModelInformation modelInformation : allModelInformation) {
+  }
+
+  public void addModelInformation(List<ModelInformation> modelInformationList) throws IOException {
+    for (ModelInformation modelInformation : modelInformationList) {
       this.serializedAllModelInformation.add(modelInformation.serializeShowModelResult());
     }
   }
 
+  public void addModelInformation(ModelInformation modelInformation) throws IOException {
+    this.serializedAllModelInformation.add(modelInformation.serializeShowModelResult());
+  }
+
   public TShowModelResp convertToThriftResponse() throws IOException {
     return new TShowModelResp(status, serializedAllModelInformation);
   }
diff --git a/confignode/src/main/java/org/apache/iotdb/confignode/consensus/response/TrailTableResp.java b/confignode/src/main/java/org/apache/iotdb/confignode/consensus/response/TrailTableResp.java
index 21e3fcd230..1f9a6b5acb 100644
--- a/confignode/src/main/java/org/apache/iotdb/confignode/consensus/response/TrailTableResp.java
+++ b/confignode/src/main/java/org/apache/iotdb/confignode/consensus/response/TrailTableResp.java
@@ -34,10 +34,17 @@ public class TrailTableResp implements DataSet {
   private final TSStatus status;
   private final List<ByteBuffer> serializedAllTrailInformation;
 
-  public TrailTableResp(TSStatus status, List<TrailInformation> allTrailInformation) {
+  public TrailTableResp(TSStatus status) {
     this.status = status;
     this.serializedAllTrailInformation = new ArrayList<>();
-    for (TrailInformation trailInformation : allTrailInformation) {
+  }
+
+  public void addTrailInformation(TrailInformation trailInformation) throws IOException {
+    this.serializedAllTrailInformation.add(trailInformation.serializeShowTrailResult());
+  }
+
+  public void addTrailInformation(List<TrailInformation> trailInformationList) throws IOException {
+    for (TrailInformation trailInformation : trailInformationList) {
       this.serializedAllTrailInformation.add(trailInformation.serializeShowTrailResult());
     }
   }
diff --git a/confignode/src/main/java/org/apache/iotdb/confignode/manager/ModelManager.java b/confignode/src/main/java/org/apache/iotdb/confignode/manager/ModelManager.java
index 347c5c0b94..1d268df5b5 100644
--- a/confignode/src/main/java/org/apache/iotdb/confignode/manager/ModelManager.java
+++ b/confignode/src/main/java/org/apache/iotdb/confignode/manager/ModelManager.java
@@ -61,8 +61,15 @@ public class ModelManager {
   }
 
   public TSStatus createModel(TCreateModelReq req) {
-    ModelInformation modelInformation = new ModelInformation();
-    return configManager.getProcedureManager().createModel(modelInformation);
+    ModelInformation modelInformation =
+        new ModelInformation(
+            req.getModelId(),
+            req.getModelTask(),
+            req.getModelType(),
+            req.isIsAuto(),
+            req.getQueryExpressions(),
+            req.getQueryFilter());
+    return configManager.getProcedureManager().createModel(modelInformation, req.getModelConfigs());
   }
 
   public TSStatus dropModel(TDropModelReq req) {
diff --git a/confignode/src/main/java/org/apache/iotdb/confignode/manager/ProcedureManager.java b/confignode/src/main/java/org/apache/iotdb/confignode/manager/ProcedureManager.java
index 7cd0da45cd..7b318b4e87 100644
--- a/confignode/src/main/java/org/apache/iotdb/confignode/manager/ProcedureManager.java
+++ b/confignode/src/main/java/org/apache/iotdb/confignode/manager/ProcedureManager.java
@@ -535,8 +535,9 @@ public class ProcedureManager {
     return statusList.get(0);
   }
 
-  public TSStatus createModel(ModelInformation modelInformation) {
-    long procedureId = executor.submitProcedure(new CreateModelProcedure(modelInformation));
+  public TSStatus createModel(ModelInformation modelInformation, Map<String, String> modelConfigs) {
+    long procedureId =
+        executor.submitProcedure(new CreateModelProcedure(modelInformation, modelConfigs));
     List<TSStatus> statusList = new ArrayList<>();
     boolean isSucceed =
         waitingProcedureFinished(Collections.singletonList(procedureId), statusList);
diff --git a/confignode/src/main/java/org/apache/iotdb/confignode/persistence/ModelInfo.java b/confignode/src/main/java/org/apache/iotdb/confignode/persistence/ModelInfo.java
index d2c5144d10..bb422b3b1e 100644
--- a/confignode/src/main/java/org/apache/iotdb/confignode/persistence/ModelInfo.java
+++ b/confignode/src/main/java/org/apache/iotdb/confignode/persistence/ModelInfo.java
@@ -43,7 +43,6 @@ import java.io.File;
 import java.io.FileInputStream;
 import java.io.FileOutputStream;
 import java.io.IOException;
-import java.util.Collections;
 import java.util.concurrent.locks.ReentrantLock;
 
 @ThreadSafe
@@ -98,18 +97,22 @@ public class ModelInfo implements SnapshotProcessor {
   public ModelTableResp showModel(ShowModelPlan plan) {
     acquireModelTableLock();
     try {
+      ModelTableResp modelTableResp =
+          new ModelTableResp(new TSStatus(TSStatusCode.SUCCESS_STATUS.getStatusCode()));
       if (plan.isSetModelId()) {
         ModelInformation modelInformation = modelTable.getModelInformationById(plan.getModelId());
-        return new ModelTableResp(
-            new TSStatus(TSStatusCode.SUCCESS_STATUS.getStatusCode()),
-            modelInformation != null
-                ? Collections.singletonList(modelInformation)
-                : Collections.emptyList());
+        if (modelInformation != null) {
+          modelTableResp.addModelInformation(modelInformation);
+        }
       } else {
-        return new ModelTableResp(
-            new TSStatus(TSStatusCode.SUCCESS_STATUS.getStatusCode()),
-            modelTable.getAllModelInformation());
+        modelTableResp.addModelInformation(modelTable.getAllModelInformation());
       }
+      return modelTableResp;
+    } catch (IOException e) {
+      LOGGER.warn("Fail to get ModelTable", e);
+      return new ModelTableResp(
+          new TSStatus(TSStatusCode.EXECUTE_STATEMENT_ERROR.getStatusCode())
+              .setMessage(e.getMessage()));
     } finally {
       releaseModelTableLock();
     }
@@ -119,19 +122,23 @@ public class ModelInfo implements SnapshotProcessor {
     acquireModelTableLock();
     try {
       ModelInformation modelInformation = modelTable.getModelInformationById(plan.getModelId());
+      TrailTableResp trailTableResp =
+          new TrailTableResp(new TSStatus(TSStatusCode.SUCCESS_STATUS.getStatusCode()));
       if (plan.isSetTrailId()) {
         TrailInformation trailInformation =
             modelInformation.getTrailInformationById(plan.getTrailId());
-        return new TrailTableResp(
-            new TSStatus(TSStatusCode.SUCCESS_STATUS.getStatusCode()),
-            trailInformation != null
-                ? Collections.singletonList(trailInformation)
-                : Collections.emptyList());
+        if (trailInformation != null) {
+          trailTableResp.addTrailInformation(trailInformation);
+        }
       } else {
-        return new TrailTableResp(
-            new TSStatus(TSStatusCode.SUCCESS_STATUS.getStatusCode()),
-            modelInformation.getAllTrailInformation());
+        trailTableResp.addTrailInformation(modelInformation.getAllTrailInformation());
       }
+      return trailTableResp;
+    } catch (IOException e) {
+      LOGGER.warn("Fail to get TrailTable", e);
+      return new TrailTableResp(
+          new TSStatus(TSStatusCode.EXECUTE_STATEMENT_ERROR.getStatusCode())
+              .setMessage(e.getMessage()));
     } finally {
       releaseModelTableLock();
     }
@@ -142,7 +149,7 @@ public class ModelInfo implements SnapshotProcessor {
     try {
       String modelId = plan.getModelId();
       if (modelTable.containsModel(modelId)) {
-        modelTable.updateModel(modelId, plan.getModelInfo());
+        modelTable.updateModel(modelId, plan.getTrailId(), plan.getModelInfo());
       }
       return new TSStatus(TSStatusCode.SUCCESS_STATUS.getStatusCode());
     } finally {
diff --git a/confignode/src/main/java/org/apache/iotdb/confignode/procedure/impl/model/CreateModelProcedure.java b/confignode/src/main/java/org/apache/iotdb/confignode/procedure/impl/model/CreateModelProcedure.java
index 0eb4b3800b..b003210f94 100644
--- a/confignode/src/main/java/org/apache/iotdb/confignode/procedure/impl/model/CreateModelProcedure.java
+++ b/confignode/src/main/java/org/apache/iotdb/confignode/procedure/impl/model/CreateModelProcedure.java
@@ -23,7 +23,6 @@ import org.apache.iotdb.commons.model.ModelInformation;
 import org.apache.iotdb.commons.model.exception.ModelManagementException;
 import org.apache.iotdb.confignode.consensus.request.write.model.CreateModelPlan;
 import org.apache.iotdb.confignode.consensus.request.write.model.DropModelPlan;
-import org.apache.iotdb.confignode.consensus.request.write.model.UpdateModelInfoPlan;
 import org.apache.iotdb.confignode.manager.ConfigManager;
 import org.apache.iotdb.confignode.persistence.ModelInfo;
 import org.apache.iotdb.confignode.procedure.env.ConfigNodeProcedureEnv;
@@ -32,6 +31,7 @@ import org.apache.iotdb.confignode.procedure.impl.node.AbstractNodeProcedure;
 import org.apache.iotdb.confignode.procedure.state.model.CreateModelState;
 import org.apache.iotdb.confignode.procedure.store.ProcedureType;
 import org.apache.iotdb.consensus.common.response.ConsensusWriteResponse;
+import org.apache.iotdb.tsfile.utils.ReadWriteIOUtils;
 
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -39,6 +39,7 @@ import org.slf4j.LoggerFactory;
 import java.io.DataOutputStream;
 import java.io.IOException;
 import java.nio.ByteBuffer;
+import java.util.Map;
 
 public class CreateModelProcedure extends AbstractNodeProcedure<CreateModelState> {
 
@@ -46,14 +47,16 @@ public class CreateModelProcedure extends AbstractNodeProcedure<CreateModelState
   private static final int RETRY_THRESHOLD = 5;
 
   private ModelInformation modelInformation;
+  private Map<String, String> modelConfigs;
 
   public CreateModelProcedure() {
     super();
   }
 
-  public CreateModelProcedure(ModelInformation modelInformation) {
+  public CreateModelProcedure(ModelInformation modelInformation, Map<String, String> modelConfigs) {
     super();
     this.modelInformation = modelInformation;
+    this.modelConfigs = modelConfigs;
   }
 
   @Override
@@ -90,10 +93,10 @@ public class CreateModelProcedure extends AbstractNodeProcedure<CreateModelState
             throw new ModelManagementException(response.getErrorMessage());
           }
 
-          setNextState(CreateModelState.CONFIG_NODE_INACTIVE);
+          setNextState(CreateModelState.CONFIG_NODE_ACTIVE);
           break;
 
-        case CONFIG_NODE_INACTIVE:
+        case CONFIG_NODE_ACTIVE:
           LOGGER.info("Start to train model [{}] on ML Node", modelInformation.getModelId());
 
           if (true) {
@@ -107,16 +110,6 @@ public class CreateModelProcedure extends AbstractNodeProcedure<CreateModelState
           break;
 
         case ML_NODE_ACTIVE:
-          LOGGER.info("Start to active model [{}] on Config Nodes", modelInformation.getModelId());
-          env.getConfigManager()
-              .getConsensusManager()
-              .write(
-                  // TODO
-                  new UpdateModelInfoPlan());
-          setNextState(CreateModelState.CONFIG_NODE_ACTIVE);
-          break;
-
-        case CONFIG_NODE_ACTIVE:
           env.getConfigManager().getTriggerManager().getTriggerInfo().releaseTriggerTableLock();
           return Flow.NO_MORE_STATE;
       }
@@ -160,7 +153,7 @@ public class CreateModelProcedure extends AbstractNodeProcedure<CreateModelState
             .write(new DropModelPlan(modelInformation.getModelId()));
         break;
 
-      case CONFIG_NODE_INACTIVE:
+      case ML_NODE_ACTIVE:
         LOGGER.info(
             "Start to [CONFIG_NODE_INACTIVE] rollback of model [{}]",
             modelInformation.getModelId());
@@ -197,12 +190,14 @@ public class CreateModelProcedure extends AbstractNodeProcedure<CreateModelState
     stream.writeShort(ProcedureType.CREATE_MODEL_PROCEDURE.getTypeCode());
     super.serialize(stream);
     modelInformation.serialize(stream);
+    ReadWriteIOUtils.write(modelConfigs, stream);
   }
 
   @Override
   public void deserialize(ByteBuffer byteBuffer) {
     super.deserialize(byteBuffer);
     modelInformation = ModelInformation.deserialize(byteBuffer);
+    modelConfigs = ReadWriteIOUtils.readMap(byteBuffer);
   }
 
   @Override
@@ -211,7 +206,8 @@ public class CreateModelProcedure extends AbstractNodeProcedure<CreateModelState
       CreateModelProcedure thatProc = (CreateModelProcedure) that;
       return thatProc.getProcId() == this.getProcId()
           && thatProc.getState() == this.getState()
-          && thatProc.modelInformation.equals(this.modelInformation);
+          && thatProc.modelInformation.equals(this.modelInformation)
+          && thatProc.modelConfigs.equals(this.modelConfigs);
     }
     return false;
   }
diff --git a/confignode/src/main/java/org/apache/iotdb/confignode/procedure/state/model/CreateModelState.java b/confignode/src/main/java/org/apache/iotdb/confignode/procedure/state/model/CreateModelState.java
index 2f9380ce80..304d209817 100644
--- a/confignode/src/main/java/org/apache/iotdb/confignode/procedure/state/model/CreateModelState.java
+++ b/confignode/src/main/java/org/apache/iotdb/confignode/procedure/state/model/CreateModelState.java
@@ -22,7 +22,6 @@ package org.apache.iotdb.confignode.procedure.state.model;
 public enum CreateModelState {
   INIT,
   VALIDATED,
-  CONFIG_NODE_INACTIVE,
-  ML_NODE_ACTIVE,
-  CONFIG_NODE_ACTIVE
+  CONFIG_NODE_ACTIVE,
+  ML_NODE_ACTIVE
 }
diff --git a/node-commons/src/main/java/org/apache/iotdb/commons/model/ForecastTrailInformation.java b/node-commons/src/main/java/org/apache/iotdb/commons/model/ForecastTrailInformation.java
deleted file mode 100644
index f98d7339ea..0000000000
--- a/node-commons/src/main/java/org/apache/iotdb/commons/model/ForecastTrailInformation.java
+++ /dev/null
@@ -1,32 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.iotdb.commons.model;
-
-import java.nio.ByteBuffer;
-
-public class ForecastTrailInformation extends TrailInformation {
-
-    private long outputLen;
-
-    @Override
-    public ByteBuffer serializeShowTrailResult() {
-        return null;
-    }
-}
diff --git a/node-commons/src/main/java/org/apache/iotdb/commons/model/ModelHyperparameter.java b/node-commons/src/main/java/org/apache/iotdb/commons/model/ModelHyperparameter.java
new file mode 100644
index 0000000000..151a6b7c59
--- /dev/null
+++ b/node-commons/src/main/java/org/apache/iotdb/commons/model/ModelHyperparameter.java
@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.iotdb.commons.model;
+
+import org.apache.iotdb.tsfile.utils.ReadWriteIOUtils;
+
+import java.io.DataOutputStream;
+import java.io.FileOutputStream;
+import java.io.IOException;
+import java.io.InputStream;
+import java.nio.ByteBuffer;
+import java.util.Map;
+
+public class ModelHyperparameter {
+
+  private final Map<String, String> keyValueMap;
+
+  public ModelHyperparameter(Map<String, String> keyValueMap) {
+    this.keyValueMap = keyValueMap;
+  }
+
+  public void update(Map<String, String> modelInfo) {
+    this.keyValueMap.putAll(modelInfo);
+  }
+
+  @Override
+  public String toString() {
+    StringBuilder stringBuilder = new StringBuilder();
+    for (Map.Entry<String, String> keyValuePair : keyValueMap.entrySet()) {
+      stringBuilder
+          .append(keyValuePair.getKey())
+          .append('=')
+          .append(keyValuePair.getValue())
+          .append('\n');
+    }
+    return stringBuilder.toString();
+  }
+
+  public void serialize(DataOutputStream stream) throws IOException {
+    ReadWriteIOUtils.write(keyValueMap, stream);
+  }
+
+  public void serialize(FileOutputStream stream) throws IOException {
+    ReadWriteIOUtils.write(keyValueMap, stream);
+  }
+
+  public static ModelHyperparameter deserialize(ByteBuffer buffer) {
+    return new ModelHyperparameter(ReadWriteIOUtils.readMap(buffer));
+  }
+
+  public static ModelHyperparameter deserialize(InputStream stream) throws IOException {
+    return new ModelHyperparameter(ReadWriteIOUtils.readMap(stream));
+  }
+}
diff --git a/node-commons/src/main/java/org/apache/iotdb/commons/model/ModelInformation.java b/node-commons/src/main/java/org/apache/iotdb/commons/model/ModelInformation.java
index 308c1d7f36..bdd9b13fc9 100644
--- a/node-commons/src/main/java/org/apache/iotdb/commons/model/ModelInformation.java
+++ b/node-commons/src/main/java/org/apache/iotdb/commons/model/ModelInformation.java
@@ -22,32 +22,98 @@ package org.apache.iotdb.commons.model;
 import org.apache.iotdb.common.rpc.thrift.ModelTask;
 import org.apache.iotdb.common.rpc.thrift.TrainingState;
 import org.apache.iotdb.tsfile.utils.PublicBAOS;
+import org.apache.iotdb.tsfile.utils.ReadWriteIOUtils;
 
 import java.io.DataOutputStream;
 import java.io.FileOutputStream;
+import java.io.IOException;
 import java.io.InputStream;
 import java.nio.ByteBuffer;
 import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 
-public class ModelInformation {
+import static org.apache.iotdb.commons.model.TrailInformation.MODEL_PATH;
 
-  private String modelId;
-  private ModelTask modelTask;
-  private String modelType;
+public class ModelInformation {
 
-  private TrainingState modelState;
+  private final String modelId;
+  private final ModelTask modelTask;
+  private final String modelType;
 
-  private List<String> queryExpressions;
-  private String queryFilter;
+  private final List<String> queryExpressions;
+  private final String queryFilter;
 
-  private boolean isAuto;
+  private final boolean isAuto;
+  private TrainingState trainingState;
 
+  private String bestTrailId;
   private Map<String, TrailInformation> trailMap;
 
-  private String bestTrailId;
-  private String modelPath;
+  public ModelInformation(
+      String modelId,
+      ModelTask modelTask,
+      String modelType,
+      boolean isAuto,
+      List<String> queryExpressions,
+      String queryFilter) {
+    this.modelId = modelId;
+    this.modelTask = modelTask;
+    this.modelType = modelType;
+    this.isAuto = isAuto;
+    this.queryExpressions = queryExpressions;
+    this.queryFilter = queryFilter;
+  }
+
+  public ModelInformation(ByteBuffer buffer) {
+    this.modelId = ReadWriteIOUtils.readString(buffer);
+    this.modelTask = ModelTask.findByValue(ReadWriteIOUtils.readInt(buffer));
+    this.modelType = ReadWriteIOUtils.readString(buffer);
+
+    int listSize = ReadWriteIOUtils.readInt(buffer);
+    this.queryExpressions = new ArrayList<>(listSize);
+    for (int i = 0; i < listSize; i++) {
+      this.queryExpressions.add(ReadWriteIOUtils.readString(buffer));
+    }
+
+    this.queryFilter = ReadWriteIOUtils.readString(buffer);
+    this.isAuto = ReadWriteIOUtils.readBool(buffer);
+    this.trainingState = TrainingState.findByValue(ReadWriteIOUtils.readInt(buffer));
+    this.bestTrailId = ReadWriteIOUtils.readString(buffer);
+
+    int mapSize = ReadWriteIOUtils.readInt(buffer);
+    this.trailMap = new HashMap<>();
+    for (int i = 0; i < mapSize; i++) {
+      TrailInformation trailInformation = TrailInformation.deserialize(buffer);
+      this.trailMap.put(trailInformation.getTrailId(), trailInformation);
+    }
+  }
+
+  public ModelInformation(InputStream stream) throws IOException {
+    this.modelId = ReadWriteIOUtils.readString(stream);
+    this.modelTask = ModelTask.findByValue(ReadWriteIOUtils.readInt(stream));
+    this.modelType = ReadWriteIOUtils.readString(stream);
+
+    int listSize = ReadWriteIOUtils.readInt(stream);
+    this.queryExpressions = new ArrayList<>(listSize);
+    for (int i = 0; i < listSize; i++) {
+      this.queryExpressions.add(ReadWriteIOUtils.readString(stream));
+    }
+
+    this.queryFilter = ReadWriteIOUtils.readString(stream);
+    this.isAuto = ReadWriteIOUtils.readBool(stream);
+    this.trainingState = TrainingState.findByValue(ReadWriteIOUtils.readInt(stream));
+    this.bestTrailId = ReadWriteIOUtils.readString(stream);
+
+    int mapSize = ReadWriteIOUtils.readInt(stream);
+    this.trailMap = new HashMap<>();
+    for (int i = 0; i < mapSize; i++) {
+      TrailInformation trailInformation = TrailInformation.deserialize(stream);
+      this.trailMap.put(trailInformation.getTrailId(), trailInformation);
+    }
+  }
 
   public String getModelId() {
     return modelId;
@@ -68,24 +134,80 @@ public class ModelInformation {
     return new ArrayList<>(trailMap.values());
   }
 
-  public void update(Map<String, String> modelInfo) {}
+  public void update(String trailId, Map<String, String> modelInfo) {
+    if (!trailMap.containsKey(trailId)) {
+      String modelPath = null;
+      if (modelInfo.containsKey(MODEL_PATH)) {
+        modelPath = modelInfo.get(MODEL_PATH);
+        modelInfo.remove(MODEL_PATH);
+      }
+      TrailInformation trailInformation =
+          new TrailInformation(trailId, new ModelHyperparameter(modelInfo), modelPath);
+      trailMap.put(trailId, trailInformation);
+    } else {
+      trailMap.get(trailId).update(modelInfo);
+    }
+  }
 
-  public void serialize(DataOutputStream stream) {}
+  public void serialize(DataOutputStream stream) throws IOException {
+    ReadWriteIOUtils.write(modelId, stream);
+    ReadWriteIOUtils.write(modelTask.ordinal(), stream);
+    ReadWriteIOUtils.write(modelType, stream);
+    ReadWriteIOUtils.write(queryExpressions.size(), stream);
+    for (String queryExpression : queryExpressions) {
+      ReadWriteIOUtils.write(queryExpression, stream);
+    }
+    ReadWriteIOUtils.write(queryFilter, stream);
+    ReadWriteIOUtils.write(isAuto, stream);
+    ReadWriteIOUtils.write(trainingState.ordinal(), stream);
+    ReadWriteIOUtils.write(bestTrailId, stream);
+    ReadWriteIOUtils.write(trailMap.size(), stream);
+    for (TrailInformation trailInformation : trailMap.values()) {
+      trailInformation.serialize(stream);
+    }
+  }
 
-  public void serialize(FileOutputStream stream) {}
+  public void serialize(FileOutputStream stream) throws IOException {
+    ReadWriteIOUtils.write(modelId, stream);
+    ReadWriteIOUtils.write(modelTask.ordinal(), stream);
+    ReadWriteIOUtils.write(modelType, stream);
 
-  public static ModelInformation deserialize(InputStream stream) {
-    return null;
+    ReadWriteIOUtils.write(queryExpressions.size(), stream);
+    for (String queryExpression : queryExpressions) {
+      ReadWriteIOUtils.write(queryExpression, stream);
+    }
+
+    ReadWriteIOUtils.write(queryFilter, stream);
+    ReadWriteIOUtils.write(isAuto, stream);
+    ReadWriteIOUtils.write(trainingState.ordinal(), stream);
+    ReadWriteIOUtils.write(bestTrailId, stream);
+
+    ReadWriteIOUtils.write(trailMap.size(), stream);
+    for (TrailInformation trailInformation : trailMap.values()) {
+      trailInformation.serialize(stream);
+    }
+  }
+
+  public static ModelInformation deserialize(InputStream stream) throws IOException {
+    return new ModelInformation(stream);
   }
 
   public static ModelInformation deserialize(ByteBuffer buffer) {
-    return null;
+    return new ModelInformation(buffer);
   }
 
-  public ByteBuffer serializeShowModelResult() {
+  public ByteBuffer serializeShowModelResult() throws IOException {
     PublicBAOS buffer = new PublicBAOS();
     DataOutputStream stream = new DataOutputStream(buffer);
-
+    ReadWriteIOUtils.write(modelId, stream);
+    ReadWriteIOUtils.write(modelTask.toString(), stream);
+    ReadWriteIOUtils.write(modelType, stream);
+    ReadWriteIOUtils.write(Arrays.toString(queryExpressions.toArray(new String[0])), stream);
+    ReadWriteIOUtils.write(trainingState.toString(), stream);
+
+    TrailInformation bestTrail = trailMap.get(bestTrailId);
+    ReadWriteIOUtils.write(bestTrail.getModelHyperparameter().toString(), stream);
+    ReadWriteIOUtils.write(bestTrail.getModelPath(), stream);
     return ByteBuffer.wrap(buffer.getBuf(), 0, buffer.size());
   }
 }
diff --git a/node-commons/src/main/java/org/apache/iotdb/commons/model/ModelTable.java b/node-commons/src/main/java/org/apache/iotdb/commons/model/ModelTable.java
index 598c26017e..756a57060f 100644
--- a/node-commons/src/main/java/org/apache/iotdb/commons/model/ModelTable.java
+++ b/node-commons/src/main/java/org/apache/iotdb/commons/model/ModelTable.java
@@ -60,8 +60,8 @@ public class ModelTable {
     return null;
   }
 
-  public void updateModel(String modelId, Map<String, String> modelInfo) {
-    modelInfoMap.get(modelId).update(modelInfo);
+  public void updateModel(String modelId, String trailId, Map<String, String> modelInfo) {
+    modelInfoMap.get(modelId).update(trailId, modelInfo);
   }
 
   public void clear() {
diff --git a/node-commons/src/main/java/org/apache/iotdb/commons/model/TrailInformation.java b/node-commons/src/main/java/org/apache/iotdb/commons/model/TrailInformation.java
index 0b35c3a98e..8551534d41 100644
--- a/node-commons/src/main/java/org/apache/iotdb/commons/model/TrailInformation.java
+++ b/node-commons/src/main/java/org/apache/iotdb/commons/model/TrailInformation.java
@@ -19,23 +19,83 @@
 
 package org.apache.iotdb.commons.model;
 
-import org.apache.iotdb.common.rpc.thrift.Activation;
-import org.apache.iotdb.common.rpc.thrift.TrainingState;
+import org.apache.iotdb.tsfile.utils.PublicBAOS;
+import org.apache.iotdb.tsfile.utils.ReadWriteIOUtils;
 
+import java.io.DataOutputStream;
+import java.io.FileOutputStream;
+import java.io.IOException;
+import java.io.InputStream;
 import java.nio.ByteBuffer;
+import java.util.Map;
 
-public abstract class TrailInformation {
+public class TrailInformation {
 
-  protected String trailId;
-  protected TrainingState trailState;
+  public static final String MODEL_PATH = "model_path";
 
-  protected long batchSize;
-  protected double learningRate;
-  protected long epochs;
+  private final String trailId;
+  private final ModelHyperparameter modelHyperparameter;
+  private String modelPath;
 
-  protected long dModel;
-  protected long dFF;
-  protected Activation activation;
+  public TrailInformation(
+      String trailId, ModelHyperparameter modelHyperparameter, String modelPath) {
+    this.trailId = trailId;
+    this.modelHyperparameter = modelHyperparameter;
+    this.modelPath = modelPath;
+  }
 
-  public abstract ByteBuffer serializeShowTrailResult();
+  public void update(Map<String, String> modelInfo) {
+    if (modelInfo.containsKey(MODEL_PATH)) {
+      modelPath = modelInfo.get(MODEL_PATH);
+      modelInfo.remove(MODEL_PATH);
+    }
+    modelHyperparameter.update(modelInfo);
+  }
+
+  public String getTrailId() {
+    return trailId;
+  }
+
+  public ModelHyperparameter getModelHyperparameter() {
+    return modelHyperparameter;
+  }
+
+  public String getModelPath() {
+    return modelPath;
+  }
+
+  public ByteBuffer serializeShowTrailResult() throws IOException {
+    PublicBAOS buffer = new PublicBAOS();
+    DataOutputStream stream = new DataOutputStream(buffer);
+    ReadWriteIOUtils.write(trailId, stream);
+    ReadWriteIOUtils.write(modelHyperparameter.toString(), stream);
+    ReadWriteIOUtils.write(modelPath, stream);
+    return ByteBuffer.wrap(buffer.getBuf(), 0, buffer.size());
+  }
+
+  public void serialize(DataOutputStream stream) throws IOException {
+    ReadWriteIOUtils.write(trailId, stream);
+    modelHyperparameter.serialize(stream);
+    ReadWriteIOUtils.write(modelPath, stream);
+  }
+
+  public void serialize(FileOutputStream stream) throws IOException {
+    ReadWriteIOUtils.write(trailId, stream);
+    modelHyperparameter.serialize(stream);
+    ReadWriteIOUtils.write(modelPath, stream);
+  }
+
+  public static TrailInformation deserialize(ByteBuffer buffer) {
+    return new TrailInformation(
+        ReadWriteIOUtils.readString(buffer),
+        ModelHyperparameter.deserialize(buffer),
+        ReadWriteIOUtils.readString(buffer));
+  }
+
+  public static TrailInformation deserialize(InputStream stream) throws IOException {
+    return new TrailInformation(
+        ReadWriteIOUtils.readString(stream),
+        ModelHyperparameter.deserialize(stream),
+        ReadWriteIOUtils.readString(stream));
+  }
 }
diff --git a/thrift-commons/src/main/thrift/common.thrift b/thrift-commons/src/main/thrift/common.thrift
index 108010de8b..916f810624 100644
--- a/thrift-commons/src/main/thrift/common.thrift
+++ b/thrift-commons/src/main/thrift/common.thrift
@@ -126,7 +126,10 @@ struct TFilesResp {
 
 // for MLNode
 enum TrainingState {
-
+  PENDING,
+  RUNNING,
+  FINISHED,
+  FAILED
 }
 
 enum ModelTask {
@@ -137,11 +140,4 @@ enum EvaluateMetric {
   MSE,
   MAE,
   RMSE
-}
-
-enum Activation {
-  RELU,
-  GELU,
-  SIGMOID,
-  TANH
 }
\ No newline at end of file
diff --git a/thrift-confignode/src/main/thrift/confignode.thrift b/thrift-confignode/src/main/thrift/confignode.thrift
index 0544374411..51dbd39d9c 100644
--- a/thrift-confignode/src/main/thrift/confignode.thrift
+++ b/thrift-confignode/src/main/thrift/confignode.thrift
@@ -668,11 +668,12 @@ struct TUnsetSchemaTemplateReq{
 
 struct TCreateModelReq {
   1: required string modelId
-  2: required byte modelTask
-  3: required bool isAuto
-  4: required map<string, string> modelConfigs
-  5: required list<string> queryExpressions
-  6: optional string queryFilter
+  2: required common.ModelTask modelTask
+  3: required string modelType
+  4: required list<string> queryExpressions
+  5: optional string queryFilter
+  6: required bool isAuto
+  7: required map<string, string> modelConfigs
 }
 
 struct TDropModelReq {
@@ -700,7 +701,13 @@ struct TShowTrailResp {
 
 struct TUpdateModelInfoReq {
   1: required string modelId
-  2: required map<string, string> modelInfo
+  2: required string trailId
+  3: required map<string, string> modelInfo
+}
+
+struct TUpdateModelStateReq {
+  1: required string modelId
+  2: required common.TrainingState state
 }
 
 service IConfigNodeRPCService {
@@ -1250,5 +1257,12 @@ service IConfigNodeRPCService {
    * @return SUCCESS_STATUS if the model was removed successfully
    */
   common.TSStatus updateModelInfo(TUpdateModelInfoReq req)
+
+  /**
+   * Update the model state
+   *
+   * @return SUCCESS_STATUS if the model state was updated successfully
+   */
+   common.TSStatus updateModelState(TUpdateModelStateReq req)
 }