You are viewing a plain text version of this content. The canonical (linked) version is available in the Apache mailing-list archive; the hyperlink was lost in this plain-text rendering.
Posted to commits@iotdb.apache.org by hu...@apache.org on 2023/04/10 03:37:26 UTC

[iotdb] 01/01: init

This is an automated email from the ASF dual-hosted git repository.

hui pushed a commit to branch mlnode/dropModel
in repository https://gitbox.apache.org/repos/asf/iotdb.git

commit c29cfc28812d3c5aecc7268413e011dfd75fd769
Author: Minghui Liu <li...@foxmail.com>
AuthorDate: Mon Apr 10 11:37:03 2023 +0800

    init
---
 .../procedure/impl/model/DropModelProcedure.java   |   4 +-
 mlnode/iotdb/mlnode/client.py                      | 107 ++++++++++++++-------
 mlnode/iotdb/mlnode/handler.py                     |   6 +-
 mlnode/iotdb/mlnode/model_storage.py               |  23 +++--
 mlnode/pyproject.toml                              |   1 +
 mlnode/requirements.txt                            |   2 +-
 mlnode/test/test_model_storage.py                  |  30 ++++--
 .../db/mpp/plan/parser/StatementGenerator.java     |  21 +++-
 .../impl/DataNodeInternalRPCServiceImpl.java       |  27 +++++-
 .../service/thrift/impl/MLNodeRPCServiceImpl.java  |   7 +-
 thrift-mlnode/src/main/thrift/mlnode.thrift        |   2 +-
 thrift/src/main/thrift/datanode.thrift             |   1 -
 12 files changed, 166 insertions(+), 65 deletions(-)

diff --git a/confignode/src/main/java/org/apache/iotdb/confignode/procedure/impl/model/DropModelProcedure.java b/confignode/src/main/java/org/apache/iotdb/confignode/procedure/impl/model/DropModelProcedure.java
index bfa461a8b2..8f41de070d 100644
--- a/confignode/src/main/java/org/apache/iotdb/confignode/procedure/impl/model/DropModelProcedure.java
+++ b/confignode/src/main/java/org/apache/iotdb/confignode/procedure/impl/model/DropModelProcedure.java
@@ -156,7 +156,9 @@ public class DropModelProcedure extends AbstractNodeProcedure<DropModelState> {
         if (getCycles() > RETRY_THRESHOLD) {
           setFailure(
               new ProcedureException(
-                  String.format("Fail to drop model [%s] at STATE [%s]", modelId, state)));
+                  String.format(
+                      "Fail to drop model [%s] at STATE [%s], %s",
+                      modelId, state, e.getMessage())));
         }
       }
     }
diff --git a/mlnode/iotdb/mlnode/client.py b/mlnode/iotdb/mlnode/client.py
index 244b6975c9..3157006e57 100644
--- a/mlnode/iotdb/mlnode/client.py
+++ b/mlnode/iotdb/mlnode/client.py
@@ -16,38 +16,33 @@
 # under the License.
 #
 import time
+from typing import Dict, List
 
+import pandas as pd
 from thrift.protocol import TBinaryProtocol, TCompactProtocol
 from thrift.Thrift import TException
 from thrift.transport import TSocket, TTransport
 
-from iotdb.mlnode.config import config
+from iotdb.mlnode import serde
+from iotdb.mlnode.config import descriptor
+from iotdb.mlnode.constant import TSStatusCode
 from iotdb.mlnode.log import logger
-from iotdb.thrift.common.ttypes import TEndPoint, TSStatus
+from iotdb.mlnode.util import verify_success
+from iotdb.thrift.common.ttypes import TEndPoint, TrainingState, TSStatus
 from iotdb.thrift.confignode import IConfigNodeRPCService
-from iotdb.thrift.confignode.ttypes import TUpdateModelInfoReq
-from iotdb.thrift.datanode import IDataNodeRPCService
+from iotdb.thrift.confignode.ttypes import (TUpdateModelInfoReq,
+                                            TUpdateModelStateReq)
+from iotdb.thrift.datanode import IMLNodeInternalRPCService
 from iotdb.thrift.datanode.ttypes import (TFetchTimeseriesReq,
-                                          TFetchTimeseriesResp,
                                           TRecordModelMetricsReq)
 from iotdb.thrift.mlnode import IMLNodeRPCService
 from iotdb.thrift.mlnode.ttypes import TCreateTrainingTaskReq, TDeleteModelReq
 
-# status code
-SUCCESS_STATUS = 200
-REDIRECTION_RECOMMEND = 400
-
-
-def verify_success(status: TSStatus, err_msg: str) -> None:
-    if status.code != SUCCESS_STATUS:
-        logger.warn(err_msg + ", error status is ", status)
-        raise RuntimeError(str(status.code) + ": " + status.message)
-
 
 class ClientManager(object):
     def __init__(self):
-        self.__data_node_endpoint = config.get_mn_target_data_node()
-        self.__config_node_endpoint = config.get_mn_target_config_node()
+        self.__data_node_endpoint = descriptor.get_config().get_mn_target_data_node()
+        self.__config_node_endpoint = descriptor.get_config().get_mn_target_config_node()
 
     def borrow_data_node_client(self):
         return DataNodeClient(host=self.__data_node_endpoint.ip,
@@ -77,9 +72,9 @@ class MLNodeClient(object):
     def create_training_task(self,
                              model_id: str,
                              is_auto: bool,
-                             model_configs: dict,
-                             query_expressions: list[str],
-                             query_filter: str = None) -> None:
+                             model_configs: Dict,
+                             query_expressions: List[str],
+                             query_filter: str = '') -> None:
         req = TCreateTrainingTaskReq(
             modelId=model_id,
             isAuto=is_auto,
@@ -124,20 +119,17 @@ class DataNodeClient(object):
                 transport.open()
             except TTransport.TTransportException as e:
                 logger.exception("TTransportException!", exc_info=e)
+                raise e
 
         protocol = TBinaryProtocol.TBinaryProtocol(transport)
-        self.__client = IDataNodeRPCService.Client(protocol)
+        self.__client = IMLNodeInternalRPCService.Client(protocol)
 
     def fetch_timeseries(self,
-                         session_id: int,
-                         statement_id: int,
-                         query_expressions: list[str],
+                         query_expressions: List[str],
                          query_filter: str = None,
                          fetch_size: int = DEFAULT_FETCH_SIZE,
-                         timeout: int = DEFAULT_TIMEOUT) -> TFetchTimeseriesResp:
+                         timeout: int = DEFAULT_TIMEOUT) -> [int, bool, pd.DataFrame]:
         req = TFetchTimeseriesReq(
-            sessionId=session_id,
-            statementId=statement_id,
             queryExpressions=query_expressions,
             queryFilter=query_filter,
             fetchSize=fetch_size,
@@ -146,15 +138,35 @@ class DataNodeClient(object):
         try:
             resp = self.__client.fetchTimeseries(req)
             verify_success(resp.status, "An error occurs when calling fetch_timeseries()")
-            return resp
-        except TTransport.TException as e:
+
+            if len(resp.tsDataset) == 0:
+                raise RuntimeError(f'No data fetched with query filter: {query_filter}')
+
+            data = serde.convert_to_df(resp.columnNameList,
+                                       resp.columnTypeList,
+                                       resp.columnNameIndexMap,
+                                       resp.tsDataset)
+            if data.empty:
+                raise RuntimeError(
+                    f'Fetched empty data with query expressions: {query_expressions} and query filter: {query_filter}')
+            return resp.queryId, resp.hasMoreData, data
+        except Exception as e:
+            logger.warn(
+                f'Fail to fetch data with query expressions: {query_expressions} and query filter: {query_filter}')
             raise e
 
+    def fetch_window_batch(self,
+                           query_expressions: list,
+                           query_filter: str = None,
+                           fetch_size: int = DEFAULT_FETCH_SIZE,
+                           timeout: int = DEFAULT_TIMEOUT) -> [int, bool, List[pd.DataFrame]]:
+        pass
+
     def record_model_metrics(self,
                              model_id: str,
                              trial_id: str,
-                             metrics: list[str],
-                             values: list[float]) -> None:
+                             metrics: List[str],
+                             values: List) -> None:
         req = TRecordModelMetricsReq(
             modelId=model_id,
             trialId=trial_id,
@@ -192,6 +204,7 @@ class ConfigNodeClient(object):
         if self.__config_leader is not None:
             try:
                 self.__connect(self.__config_leader)
+                return
             except TException:
                 logger.warn("The current node {} may have been down, try next node", self.__config_leader)
                 self.__config_leader = None
@@ -206,6 +219,7 @@ class ConfigNodeClient(object):
             try_endpoint = self.__config_nodes[self.__cursor]
             try:
                 self.__connect(try_endpoint)
+                return
             except TException:
                 logger.warn("The current node {} may have been down, try next node", try_endpoint)
 
@@ -223,7 +237,7 @@ class ConfigNodeClient(object):
             except TTransport.TTransportException as e:
                 logger.exception("TTransportException!", exc_info=e)
 
-        protocol = TCompactProtocol.TBinaryProtocol(transport)
+        protocol = TBinaryProtocol.TBinaryProtocol(transport)
         self.__client = IConfigNodeRPCService.Client(protocol)
 
     def __wait_and_reconnect(self) -> None:
@@ -242,7 +256,7 @@ class ConfigNodeClient(object):
         pass
 
     def __update_config_node_leader(self, status: TSStatus) -> bool:
-        if status.code == REDIRECTION_RECOMMEND:
+        if status.code == TSStatusCode.REDIRECTION_RECOMMEND:
             if status.redirectNode is not None:
                 self.__config_leader = status.redirectNode
             else:
@@ -250,15 +264,38 @@ class ConfigNodeClient(object):
             return True
         return False
 
+    def update_model_state(self,
+                           model_id: str,
+                           training_state: TrainingState,
+                           best_trail_id: str = None) -> None:
+        req = TUpdateModelStateReq(
+            modelId=model_id,
+            state=training_state,
+            bestTrailId=best_trail_id
+        )
+        for i in range(0, self.__RETRY_NUM):
+            try:
+                status = self.__client.updateModelState(req)
+                if not self.__update_config_node_leader(status):
+                    verify_success(status, "An error occurs when calling update_model_state()")
+                    return
+            except TTransport.TException:
+                logger.warn("Failed to connect to ConfigNode {} from MLNode when executing update_model_info()",
+                            self.__config_leader)
+                self.__config_leader = None
+            self.__wait_and_reconnect()
+
+        raise TException(self.__MSG_RECONNECTION_FAIL)
+
     def update_model_info(self,
                           model_id: str,
                           trial_id: str,
-                          model_info: dict) -> None:
+                          model_info: Dict) -> None:
         if model_info is None:
             model_info = {}
         req = TUpdateModelInfoReq(
             modelId=model_id,
-            trialId=trial_id,
+            trailId=trial_id,
             modelInfo={k: str(v) for k, v in model_info.items()},
         )
 
diff --git a/mlnode/iotdb/mlnode/handler.py b/mlnode/iotdb/mlnode/handler.py
index 8a36353d47..a5dd639756 100644
--- a/mlnode/iotdb/mlnode/handler.py
+++ b/mlnode/iotdb/mlnode/handler.py
@@ -42,7 +42,11 @@ class MLNodeRPCServiceHandler(IMLNodeRPCService.Iface):
         pass
 
     def deleteModel(self, req: TDeleteModelReq):
-        return get_status(TSStatusCode.SUCCESS_STATUS, "")
+        try:
+            model_storage.delete_model(req.modelId)
+            return get_status(TSStatusCode.SUCCESS_STATUS)
+        except Exception as e:
+            return get_status(TSStatusCode.FAIL_STATUS, str(e))
 
     def createTrainingTask(self, req: TCreateTrainingTaskReq):
         return get_status(TSStatusCode.SUCCESS_STATUS, "")
diff --git a/mlnode/iotdb/mlnode/model_storage.py b/mlnode/iotdb/mlnode/model_storage.py
index ee745689b1..84d7dfd7ed 100644
--- a/mlnode/iotdb/mlnode/model_storage.py
+++ b/mlnode/iotdb/mlnode/model_storage.py
@@ -19,42 +19,49 @@
 import json
 import os
 import shutil
+from typing import Dict, Tuple
 
 import torch
 import torch.nn as nn
 from pylru import lrucache
 
-from iotdb.mlnode.config import config
+from iotdb.mlnode.config import descriptor
 from iotdb.mlnode.exception import ModelNotExistError
+from iotdb.mlnode.log import logger
 
 
 class ModelStorage(object):
     def __init__(self):
-        self.__model_dir = os.path.join(os.getcwd(), config.get_mn_model_storage_dir())
+        self.__model_dir = os.path.join('.', descriptor.get_config().get_mn_model_storage_dir())
         if not os.path.exists(self.__model_dir):
-            os.mkdir(self.__model_dir)
+            try:
+                os.mkdir(self.__model_dir)
+            except PermissionError as e:
+                logger.error(e)
+                raise e
 
-        self.__model_cache = lrucache(config.get_mn_model_storage_cache_size())
+        self.__model_cache = lrucache(descriptor.get_config().get_mn_model_storage_cache_size())
 
     def save_model(self,
                    model: nn.Module,
-                   model_config: dict,
+                   model_config: Dict,
                    model_id: str,
-                   trial_id: str) -> None:
+                   trial_id: str) -> str:
         """
         Note: model config for time series should contain 'input_len' and 'input_vars'
         """
         model_dir_path = os.path.join(self.__model_dir, f'{model_id}')
         if not os.path.exists(model_dir_path):
-            os.mkdir(model_dir_path)
+            os.makedirs(model_dir_path)
         model_file_path = os.path.join(model_dir_path, f'{trial_id}.pt')
 
         sample_input = [torch.randn(1, model_config['input_len'], model_config['input_vars'])]
         torch.jit.save(torch.jit.trace(model, sample_input),
                        model_file_path,
                        _extra_files={'model_config': json.dumps(model_config)})
+        return os.path.abspath(model_file_path)
 
-    def load_model(self, model_id: str, trial_id: str) -> (torch.jit.ScriptModule, dict):
+    def load_model(self, model_id: str, trial_id: str) -> Tuple[torch.jit.ScriptModule, Dict]:
         """
         Returns:
             jit_model: a ScriptModule contains model architecture and parameters, which can be deployed cross-platform
diff --git a/mlnode/pyproject.toml b/mlnode/pyproject.toml
index 3944e2910d..56290f8d4e 100644
--- a/mlnode/pyproject.toml
+++ b/mlnode/pyproject.toml
@@ -49,6 +49,7 @@ packages = [
 python = "^3.7"
 thrift = "^0.13.0"
 dynaconf = "^3.1.11"
+pylru = "^1.2.1"
 
 [tool.poetry.scripts]
 mlnode = "iotdb.mlnode.script:main"
\ No newline at end of file
diff --git a/mlnode/requirements.txt b/mlnode/requirements.txt
index edd85701ab..c49c8a0189 100644
--- a/mlnode/requirements.txt
+++ b/mlnode/requirements.txt
@@ -20,7 +20,7 @@ pandas>=1.3.5
 numpy>=1.21.4
 apache-iotdb
 poetry
-torch
+torch~=2.0.0
 pylru
 
 thrift~=0.13.0
diff --git a/mlnode/test/test_model_storage.py b/mlnode/test/test_model_storage.py
index 99857db37e..1b18f974aa 100644
--- a/mlnode/test/test_model_storage.py
+++ b/mlnode/test/test_model_storage.py
@@ -23,19 +23,20 @@ import time
 import torch.nn as nn
 
 from iotdb.mlnode.config import config
-from iotdb.mlnode.model_storage import model_storage
+from iotdb.mlnode.exception import ModelNotExistError
+from iotdb.mlnode.storage import model_storage
 
 
-class TestModel(nn.Module):
+class ExampleModel(nn.Module):
     def __init__(self):
-        super(TestModel, self).__init__()
+        super(ExampleModel, self).__init__()
         self.layer = nn.Identity()
 
     def forward(self, x):
         return self.layer(x)
 
 
-model = TestModel()
+model = ExampleModel()
 model_config = {
     'input_len': 1,
     'input_vars': 1,
@@ -47,7 +48,7 @@ def test_save_model():
     trial_id = 'tid_0'
     model_id = 'mid_test_model_save'
     model_storage.save_model(model, model_config, model_id=model_id, trial_id=trial_id)
-    assert os.path.exists(os.path.join(config.get_mn_model_storage_dir(), model_id, f'{trial_id}.pt'))
+    assert os.path.exists(os.path.join(config.get_mn_model_storage_dir(), f'{model_id}', f'{trial_id}.pt'))
 
 
 def test_load_model():
@@ -58,6 +59,17 @@ def test_load_model():
     assert model_config == model_config_loaded
 
 
+def test_load_not_exist_model():
+    trial_id = 'dummy_trial'
+    model_id = 'dummy_model'
+    try:
+        model_loaded, model_config_loaded = model_storage.load_model(model_id=model_id, trial_id=trial_id)
+    except Exception as e:
+        assert e.message == ModelNotExistError(
+            os.path.join(os.getcwd(), config.get_mn_model_storage_dir(),
+                         model_id, f'{trial_id}.pt')).message
+
+
 def test_delete_model():
     trial_id1 = 'tid_1'
     trial_id2 = 'tid_2'
@@ -65,9 +77,9 @@ def test_delete_model():
     model_storage.save_model(model, model_config, model_id=model_id, trial_id=trial_id1)
     model_storage.save_model(model, model_config, model_id=model_id, trial_id=trial_id2)
     model_storage.delete_model(model_id=model_id)
-    assert not os.path.exists(os.path.join(config.get_mn_model_storage_dir(), model_id, f'{trial_id1}.pt'))
-    assert not os.path.exists(os.path.join(config.get_mn_model_storage_dir(), model_id, f'{trial_id2}.pt'))
-    assert not os.path.exists(os.path.join(config.get_mn_model_storage_dir(), model_id))
+    assert not os.path.exists(os.path.join(config.get_mn_model_storage_dir(), f'{model_id}', f'{trial_id1}.pt'))
+    assert not os.path.exists(os.path.join(config.get_mn_model_storage_dir(), f'{model_id}', f'{trial_id2}.pt'))
+    assert not os.path.exists(os.path.join(config.get_mn_model_storage_dir(), f'{model_id}'))
 
 
 def test_delete_trial():
@@ -75,4 +87,4 @@ def test_delete_trial():
     model_id = 'mid_test_model_delete'
     model_storage.save_model(model, model_config, model_id=model_id, trial_id=trial_id)
     model_storage.delete_trial(model_id=model_id, trial_id=trial_id)
-    assert not os.path.exists(os.path.join(config.get_mn_model_storage_dir(), model_id, f'{trial_id}.pt'))
+    assert not os.path.exists(os.path.join(config.get_mn_model_storage_dir(), f'{model_id}', f'{trial_id}.pt'))
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/plan/parser/StatementGenerator.java b/server/src/main/java/org/apache/iotdb/db/mpp/plan/parser/StatementGenerator.java
index ee59471e6e..432b76a718 100644
--- a/server/src/main/java/org/apache/iotdb/db/mpp/plan/parser/StatementGenerator.java
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/plan/parser/StatementGenerator.java
@@ -66,6 +66,7 @@ import org.apache.iotdb.db.mpp.plan.statement.metadata.template.UnsetSchemaTempl
 import org.apache.iotdb.db.qp.sql.IoTDBSqlParser;
 import org.apache.iotdb.db.qp.sql.SqlLexer;
 import org.apache.iotdb.db.utils.QueryDataSetUtils;
+import org.apache.iotdb.mpp.rpc.thrift.TDeleteModelMetricsReq;
 import org.apache.iotdb.mpp.rpc.thrift.TFetchTimeseriesReq;
 import org.apache.iotdb.mpp.rpc.thrift.TRecordModelMetricsReq;
 import org.apache.iotdb.service.rpc.thrift.TSAggregationQueryReq;
@@ -111,6 +112,9 @@ import java.util.List;
 import java.util.Map;
 import java.util.Objects;
 
+import static org.apache.iotdb.commons.conf.IoTDBConstant.MULTI_LEVEL_PATH_WILDCARD;
+import static org.apache.iotdb.db.service.thrift.impl.MLNodeRPCServiceImpl.ML_METRICS_PATH_PREFIX;
+
 /** Convert SQL and RPC requests to {@link Statement}. */
 public class StatementGenerator {
   private static final PerformanceOverviewMetrics PERFORMANCE_OVERVIEW_METRICS =
@@ -807,10 +811,10 @@ public class StatementGenerator {
     return databasePath;
   }
 
-  public static InsertRowStatement createStatement(
-      TRecordModelMetricsReq recordModelMetricsReq, String prefix) throws IllegalPathException {
+  public static InsertRowStatement createStatement(TRecordModelMetricsReq recordModelMetricsReq)
+      throws IllegalPathException {
     String path =
-        prefix
+        ML_METRICS_PATH_PREFIX
             + TsFileConstant.PATH_SEPARATOR
             + recordModelMetricsReq.getModelId()
             + TsFileConstant.PATH_SEPARATOR
@@ -873,4 +877,15 @@ public class StatementGenerator {
     }
     return queryStatement;
   }
+
+  public static DeleteTimeSeriesStatement createStatement(TDeleteModelMetricsReq req)
+      throws IllegalPathException {
+    String path =
+        ML_METRICS_PATH_PREFIX
+            + TsFileConstant.PATH_SEPARATOR
+            + req.getModelId()
+            + TsFileConstant.PATH_SEPARATOR
+            + MULTI_LEVEL_PATH_WILDCARD;
+    return new DeleteTimeSeriesStatement(Collections.singletonList(new PartialPath(path)));
+  }
 }
diff --git a/server/src/main/java/org/apache/iotdb/db/service/thrift/impl/DataNodeInternalRPCServiceImpl.java b/server/src/main/java/org/apache/iotdb/db/service/thrift/impl/DataNodeInternalRPCServiceImpl.java
index c50c74e841..b575bc7543 100644
--- a/server/src/main/java/org/apache/iotdb/db/service/thrift/impl/DataNodeInternalRPCServiceImpl.java
+++ b/server/src/main/java/org/apache/iotdb/db/service/thrift/impl/DataNodeInternalRPCServiceImpl.java
@@ -104,6 +104,7 @@ import org.apache.iotdb.db.mpp.plan.planner.plan.node.write.DeleteDataNode;
 import org.apache.iotdb.db.mpp.plan.scheduler.load.LoadTsFileScheduler;
 import org.apache.iotdb.db.mpp.plan.statement.component.WhereCondition;
 import org.apache.iotdb.db.mpp.plan.statement.crud.QueryStatement;
+import org.apache.iotdb.db.mpp.plan.statement.metadata.DeleteTimeSeriesStatement;
 import org.apache.iotdb.db.pipe.agent.PipeAgent;
 import org.apache.iotdb.db.query.control.SessionManager;
 import org.apache.iotdb.db.query.control.clientsession.IClientSession;
@@ -200,6 +201,7 @@ import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.Optional;
+import java.util.TimeZone;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.atomic.AtomicInteger;
 import java.util.concurrent.atomic.AtomicLong;
@@ -880,7 +882,30 @@ public class DataNodeInternalRPCServiceImpl implements IDataNodeRPCService.Iface
 
   @Override
   public TSStatus deleteModelMetrics(TDeleteModelMetricsReq req) throws TException {
-    return RpcUtils.SUCCESS_STATUS;
+    IClientSession session = new InternalClientSession(req.getModelId());
+    SESSION_MANAGER.registerSession(session);
+    SESSION_MANAGER.supplySession(
+        session, "MLNode", TimeZone.getDefault().getID(), ClientVersion.V_1_0);
+
+    try {
+      DeleteTimeSeriesStatement deleteTimeSeriesStatement = StatementGenerator.createStatement(req);
+
+      long queryId = SESSION_MANAGER.requestQueryId();
+      ExecutionResult result =
+          COORDINATOR.execute(
+              deleteTimeSeriesStatement,
+              queryId,
+              SESSION_MANAGER.getSessionInfo(session),
+              "",
+              PARTITION_FETCHER,
+              SCHEMA_FETCHER);
+      return result.status;
+    } catch (Exception e) {
+      return onQueryException(e, OperationType.DELETE_TIMESERIES);
+    } finally {
+      SESSION_MANAGER.closeSession(session, COORDINATOR::cleanupQueryExecution);
+      SESSION_MANAGER.removeCurrSession();
+    }
   }
 
   @Override
diff --git a/server/src/main/java/org/apache/iotdb/db/service/thrift/impl/MLNodeRPCServiceImpl.java b/server/src/main/java/org/apache/iotdb/db/service/thrift/impl/MLNodeRPCServiceImpl.java
index 544d4cd04c..74e14f9f3f 100644
--- a/server/src/main/java/org/apache/iotdb/db/service/thrift/impl/MLNodeRPCServiceImpl.java
+++ b/server/src/main/java/org/apache/iotdb/db/service/thrift/impl/MLNodeRPCServiceImpl.java
@@ -60,14 +60,14 @@ import static org.apache.iotdb.db.utils.ErrorHandlingUtils.onQueryException;
 
 public class MLNodeRPCServiceImpl implements IMLNodeRPCServiceWithHandler {
 
+  public static final String ML_METRICS_PATH_PREFIX = "root.__system.ml.exp";
+
   private static final Logger LOGGER = LoggerFactory.getLogger(MLNodeRPCServiceImpl.class);
 
   private static final SessionManager SESSION_MANAGER = SessionManager.getInstance();
 
   private static final Coordinator COORDINATOR = Coordinator.getInstance();
 
-  private static final String ML_METRICS_STORAGE_GROUP = "root.__system.ml.exp";
-
   private final IPartitionFetcher PARTITION_FETCHER;
 
   private final ISchemaFetcher SCHEMA_FETCHER;
@@ -176,8 +176,7 @@ public class MLNodeRPCServiceImpl implements IMLNodeRPCServiceWithHandler {
   @Override
   public TSStatus recordModelMetrics(TRecordModelMetricsReq req) throws TException {
     try {
-      InsertRowStatement insertRowStatement =
-          StatementGenerator.createStatement(req, ML_METRICS_STORAGE_GROUP);
+      InsertRowStatement insertRowStatement = StatementGenerator.createStatement(req);
 
       long queryId = SESSION_MANAGER.requestQueryId();
       ExecutionResult result =
diff --git a/thrift-mlnode/src/main/thrift/mlnode.thrift b/thrift-mlnode/src/main/thrift/mlnode.thrift
index 916022e973..abadc79576 100644
--- a/thrift-mlnode/src/main/thrift/mlnode.thrift
+++ b/thrift-mlnode/src/main/thrift/mlnode.thrift
@@ -31,7 +31,7 @@ struct TCreateTrainingTaskReq {
 
 struct TDeleteModelReq {
   1: required string modelId
-  2: optional string trailId
+  2: optional string trialId
 }
 
 struct TForecastReq {
diff --git a/thrift/src/main/thrift/datanode.thrift b/thrift/src/main/thrift/datanode.thrift
index 49066f56c2..23cc0fa032 100644
--- a/thrift/src/main/thrift/datanode.thrift
+++ b/thrift/src/main/thrift/datanode.thrift
@@ -809,4 +809,3 @@ service IMLNodeInternalRPCService{
   */
   common.TSStatus recordModelMetrics(TRecordModelMetricsReq req)
 }
-