Posted to commits@iotdb.apache.org by hu...@apache.org on 2023/04/03 13:21:26 UTC

[iotdb] 02/02: make mlnode available

This is an automated email from the ASF dual-hosted git repository.

hui pushed a commit to branch mlnode/test
in repository https://gitbox.apache.org/repos/asf/iotdb.git

commit 833d0619ed19711132042578ae0ea6123bec77c4
Author: Minghui Liu <li...@foxmail.com>
AuthorDate: Mon Apr 3 16:11:31 2023 +0800

    make mlnode available
---
 mlnode/iotdb/mlnode/algorithm/enums.py             |  3 ++
 mlnode/iotdb/mlnode/algorithm/factory.py           |  1 +
 .../mlnode/algorithm/models/forecast/dlinear.py    |  3 +-
 mlnode/iotdb/mlnode/client.py                      | 32 ++++++------
 mlnode/iotdb/mlnode/config.py                      | 10 ++--
 mlnode/iotdb/mlnode/constant.py                    |  6 ---
 mlnode/iotdb/mlnode/data_access/enums.py           |  3 ++
 mlnode/iotdb/mlnode/data_access/offline/source.py  |  4 +-
 mlnode/iotdb/mlnode/handler.py                     | 32 +++++-------
 mlnode/iotdb/mlnode/parser.py                      |  7 ++-
 mlnode/iotdb/mlnode/process/manager.py             | 36 +++++++------
 mlnode/iotdb/mlnode/process/task.py                |  4 +-
 mlnode/iotdb/mlnode/process/task_factory.py        |  2 +-
 mlnode/iotdb/mlnode/process/trial.py               | 59 ++++++++++++----------
 mlnode/iotdb/mlnode/service.py                     | 11 ++--
 mlnode/iotdb/mlnode/storage.py                     | 11 ++--
 mlnode/iotdb/mlnode/util.py                        |  2 +-
 mlnode/pyproject.toml                              |  1 +
 mlnode/requirements.txt                            |  2 +-
 19 files changed, 122 insertions(+), 107 deletions(-)

diff --git a/mlnode/iotdb/mlnode/algorithm/enums.py b/mlnode/iotdb/mlnode/algorithm/enums.py
index 4b05aa4bf8..2def3751cd 100644
--- a/mlnode/iotdb/mlnode/algorithm/enums.py
+++ b/mlnode/iotdb/mlnode/algorithm/enums.py
@@ -27,3 +27,6 @@ class ForecastTaskType(Enum):
 
     def __eq__(self, other: str) -> bool:
         return self.value == other
+
+    def __hash__(self) -> int:
+        return hash(self.value)
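
Note: the __hash__ additions here and in data_access/enums.py below are needed
because overriding __eq__ on a class resets __hash__ to None, which would make
the enum members unhashable. A minimal sketch of the effect (the member name
ENDOGENOUS is taken from the dlinear.py hunk below; its string value here is an
assumption for illustration only):

    from enum import Enum

    class TaskType(Enum):
        ENDOGENOUS = 'endogenous'  # assumed value, illustration only

        def __eq__(self, other: str) -> bool:
            return self.value == other

        def __hash__(self) -> int:
            return hash(self.value)

    assert TaskType.ENDOGENOUS == 'endogenous'   # string comparison still works
    tasks = {TaskType.ENDOGENOUS: 'dlinear'}     # without __hash__: TypeError, unhashable type
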
diff --git a/mlnode/iotdb/mlnode/algorithm/factory.py b/mlnode/iotdb/mlnode/algorithm/factory.py
index 92cb01a883..26eab10860 100644
--- a/mlnode/iotdb/mlnode/algorithm/factory.py
+++ b/mlnode/iotdb/mlnode/algorithm/factory.py
@@ -19,6 +19,7 @@ import torch.nn as nn
 
 from iotdb.mlnode.algorithm.enums import ForecastTaskType
 from iotdb.mlnode.algorithm.models.forecast import support_forecasting_models
+from iotdb.mlnode.algorithm.models.forecast.dlinear import dlinear
 from iotdb.mlnode.exception import BadConfigValueError
 
 
diff --git a/mlnode/iotdb/mlnode/algorithm/models/forecast/dlinear.py b/mlnode/iotdb/mlnode/algorithm/models/forecast/dlinear.py
index fa9ee04e56..966ea20347 100644
--- a/mlnode/iotdb/mlnode/algorithm/models/forecast/dlinear.py
+++ b/mlnode/iotdb/mlnode/algorithm/models/forecast/dlinear.py
@@ -21,6 +21,7 @@ import math
 import torch
 import torch.nn as nn
 
+from iotdb.mlnode.algorithm.enums import ForecastTaskType
 from iotdb.mlnode.exception import BadConfigValueError
 
 
@@ -65,7 +66,7 @@ class DLinear(nn.Module):
             pred_len=96,
             input_vars=1,
             output_vars=1,
-            forecast_type='m',  # TODO, support others
+            forecast_task_type=ForecastTaskType.ENDOGENOUS,  # TODO, support others
     ):
         super(DLinear, self).__init__()
         self.input_len = input_len
diff --git a/mlnode/iotdb/mlnode/client.py b/mlnode/iotdb/mlnode/client.py
index 6c1c549ea1..724d517316 100644
--- a/mlnode/iotdb/mlnode/client.py
+++ b/mlnode/iotdb/mlnode/client.py
@@ -21,7 +21,7 @@ from thrift.protocol import TBinaryProtocol, TCompactProtocol
 from thrift.Thrift import TException
 from thrift.transport import TSocket, TTransport
 
-from iotdb.mlnode.config import config
+from iotdb.mlnode.config import descriptor
 from iotdb.mlnode.constant import TSStatusCode
 from iotdb.mlnode.log import logger
 from iotdb.mlnode.util import verify_success
@@ -29,7 +29,7 @@ from iotdb.thrift.common.ttypes import TEndPoint, TrainingState, TSStatus
 from iotdb.thrift.confignode import IConfigNodeRPCService
 from iotdb.thrift.confignode.ttypes import (TUpdateModelInfoReq,
                                             TUpdateModelStateReq)
-from iotdb.thrift.datanode import IDataNodeRPCService
+from iotdb.thrift.datanode import IMLNodeInternalRPCService
 from iotdb.thrift.datanode.ttypes import (TFetchTimeseriesReq,
                                           TFetchTimeseriesResp,
                                           TRecordModelMetricsReq)
@@ -39,8 +39,8 @@ from iotdb.thrift.mlnode.ttypes import TCreateTrainingTaskReq, TDeleteModelReq
 
 class ClientManager(object):
     def __init__(self):
-        self.__data_node_endpoint = config.get_mn_target_data_node()
-        self.__config_node_endpoint = config.get_mn_target_config_node()
+        self.__data_node_endpoint = descriptor.get_config().get_mn_target_data_node()
+        self.__config_node_endpoint = descriptor.get_config().get_mn_target_config_node()
 
     def borrow_data_node_client(self):
         return DataNodeClient(host=self.__data_node_endpoint.ip,
@@ -120,18 +120,14 @@ class DataNodeClient(object):
                 raise e
 
         protocol = TBinaryProtocol.TBinaryProtocol(transport)
-        self.__client = IDataNodeRPCService.Client(protocol)
+        self.__client = IMLNodeInternalRPCService.Client(protocol)
 
     def fetch_timeseries(self,
-                         session_id: int,
-                         statement_id: int,
                          query_expressions: list = [],
                          query_filter: str = None,
                          fetch_size: int = DEFAULT_FETCH_SIZE,
                          timeout: int = DEFAULT_TIMEOUT) -> TFetchTimeseriesResp:
         req = TFetchTimeseriesReq(
-            sessionId=session_id,
-            statementId=statement_id,
             queryExpressions=query_expressions,
             queryFilter=query_filter,
             fetchSize=fetch_size,
@@ -147,8 +143,8 @@ class DataNodeClient(object):
     def record_model_metrics(self,
                              model_id: str,
                              trial_id: str,
-                             metrics: list = [],
-                             values: list = []) -> None:
+                             metrics: list,
+                             values: list) -> None:
         req = TRecordModelMetricsReq(
             modelId=model_id,
             trialId=trial_id,
@@ -186,6 +182,7 @@ class ConfigNodeClient(object):
         if self.__config_leader is not None:
             try:
                 self.__connect(self.__config_leader)
+                return
             except TException:
                 logger.warn("The current node {} may have been down, try next node", self.__config_leader)
                 self.__config_leader = None
@@ -200,6 +197,7 @@ class ConfigNodeClient(object):
             try_endpoint = self.__config_nodes[self.__cursor]
             try:
                 self.__connect(try_endpoint)
+                return
             except TException:
                 logger.warn("The current node {} may have been down, try next node", try_endpoint)
 
@@ -217,7 +215,7 @@ class ConfigNodeClient(object):
             except TTransport.TTransportException as e:
                 logger.exception("TTransportException!", exc_info=e)
 
-        protocol = TCompactProtocol.TBinaryProtocol(transport)
+        protocol = TBinaryProtocol.TBinaryProtocol(transport)
         self.__client = IConfigNodeRPCService.Client(protocol)
 
     def __wait_and_reconnect(self) -> None:
@@ -246,12 +244,12 @@ class ConfigNodeClient(object):
 
     def update_model_state(self,
                            model_id: str,
-                           trial_id: str,
-                           training_state: TrainingState) -> None:
+                           training_state: TrainingState,
+                           best_trail_id: str = None) -> None:
         req = TUpdateModelStateReq(
             modelId=model_id,
-            trialId=trial_id,
-            trainingState=training_state
+            state=training_state,
+            bestTrailId=best_trail_id
         )
         for i in range(0, self.__RETRY_NUM):
             try:
@@ -275,7 +273,7 @@ class ConfigNodeClient(object):
             model_info = {}
         req = TUpdateModelInfoReq(
             modelId=model_id,
-            trialId=trial_id,
+            trailId=trial_id,
             modelInfo={k: str(v) for k, v in model_info.items()},
         )
 
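
Note: with sessionId/statementId dropped from TFetchTimeseriesReq, callers only
pass the query itself, and the ThriftDataSource hunk below switches to the
snake_case keyword names accordingly. A hedged usage sketch (it assumes the
module-level client_manager imported in trial.py is a ClientManager instance;
the timeseries path is a placeholder):

    from iotdb.mlnode.client import client_manager

    data_client = client_manager.borrow_data_node_client()
    resp = data_client.fetch_timeseries(
        query_expressions=['root.sg.d1.s1'],   # placeholder path
        query_filter=None,
    )
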
diff --git a/mlnode/iotdb/mlnode/config.py b/mlnode/iotdb/mlnode/config.py
index e59338209a..109452eab5 100644
--- a/mlnode/iotdb/mlnode/config.py
+++ b/mlnode/iotdb/mlnode/config.py
@@ -44,7 +44,7 @@ class MLNodeConfig(object):
         self.__mn_target_config_node: TEndPoint = TEndPoint("127.0.0.1", 10710)
 
         # Target DataNode to be connected by MLNode
-        self.__mn_target_data_node: TEndPoint = TEndPoint("127.0.0.1", 10730)
+        self.__mn_target_data_node: TEndPoint = TEndPoint("127.0.0.1", 10780)
 
     def get_mn_rpc_address(self) -> str:
         return self.__mn_rpc_address
@@ -86,9 +86,8 @@ class MLNodeConfig(object):
 class MLNodeDescriptor(object):
     def __init__(self):
         self.__config = MLNodeConfig()
-        self.__load_config_from_file()
 
-    def __load_config_from_file(self) -> None:
+    def load_config_from_file(self) -> None:
         conf_file = os.path.join(os.getcwd(), MLNODE_CONF_DIRECTORY_NAME, MLNODE_CONF_FILE_NAME)
         if not os.path.exists(conf_file):
             logger.info("Cannot find MLNode config file '{}', use default configuration.".format(conf_file))
@@ -113,7 +112,7 @@ class MLNodeDescriptor(object):
                 self.__config.set_mn_model_storage_dir(file_configs.mn_model_storage_dir)
 
             if file_configs.mn_model_storage_cache_size is not None:
-                self.__config.set_mn_model_storage_cachesize(file_configs.mn_model_storage_cache_size)
+                self.__config.set_mn_model_storage_cache_size(file_configs.mn_model_storage_cache_size)
 
             if file_configs.mn_target_config_node is not None:
                 self.__config.set_mn_target_config_node(file_configs.mn_target_config_node)
@@ -129,4 +128,5 @@ class MLNodeDescriptor(object):
         return self.__config
 
 
-config = MLNodeDescriptor().get_config()
+# initialize a singleton
+descriptor = MLNodeDescriptor()
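
Note: the module-level config object is replaced by a descriptor singleton, and
reading the configuration file becomes an explicit step (MLNode.__init__ in
service.py below now calls it) rather than an import-time side effect. A short
sketch of the intended order of calls; reading a setting before
load_config_from_file() simply returns the defaults:

    from iotdb.mlnode.config import descriptor

    descriptor.load_config_from_file()   # explicit; no longer happens on import
    cfg = descriptor.get_config()
    print(cfg.get_mn_rpc_address(), cfg.get_mn_rpc_port())
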
diff --git a/mlnode/iotdb/mlnode/constant.py b/mlnode/iotdb/mlnode/constant.py
index e0be2a7b63..3bffa06526 100644
--- a/mlnode/iotdb/mlnode/constant.py
+++ b/mlnode/iotdb/mlnode/constant.py
@@ -31,9 +31,3 @@ class TSStatusCode(Enum):
 
     def get_status_code(self) -> int:
         return self.value
-
-
-class ModelState(Enum):
-    RUNNING = 'running'
-    FINISHED = 'finished'
-    FAILED = 'failed'
diff --git a/mlnode/iotdb/mlnode/data_access/enums.py b/mlnode/iotdb/mlnode/data_access/enums.py
index d21a9f69c4..e7f5417b3d 100644
--- a/mlnode/iotdb/mlnode/data_access/enums.py
+++ b/mlnode/iotdb/mlnode/data_access/enums.py
@@ -27,3 +27,6 @@ class DatasetType(Enum):
 
     def __eq__(self, other: str) -> bool:
         return self.value == other
+
+    def __hash__(self) -> int:
+        return hash(self.value)
diff --git a/mlnode/iotdb/mlnode/data_access/offline/source.py b/mlnode/iotdb/mlnode/data_access/offline/source.py
index a63371ec7a..0422bb373d 100644
--- a/mlnode/iotdb/mlnode/data_access/offline/source.py
+++ b/mlnode/iotdb/mlnode/data_access/offline/source.py
@@ -74,8 +74,8 @@ class ThriftDataSource(DataSource):
 
         try:
             res = data_client.fetch_timeseries(
-                queryExpressions=self.query_expressions,
-                queryFilter=self.query_filter,
+                query_expressions=self.query_expressions,
+                query_filter=self.query_filter,
             )
         except Exception:
             raise RuntimeError(f'Fail to fetch data with query expressions: {self.query_expressions}'
diff --git a/mlnode/iotdb/mlnode/handler.py b/mlnode/iotdb/mlnode/handler.py
index e43f26c226..1a6e3eb90a 100644
--- a/mlnode/iotdb/mlnode/handler.py
+++ b/mlnode/iotdb/mlnode/handler.py
@@ -19,7 +19,6 @@
 from iotdb.mlnode.algorithm.factory import create_forecast_model
 from iotdb.mlnode.constant import TSStatusCode
 from iotdb.mlnode.data_access.factory import create_forecast_dataset
-from iotdb.mlnode.log import logger
 from iotdb.mlnode.parser import parse_training_request
 from iotdb.mlnode.process.manager import TaskManager
 from iotdb.mlnode.util import get_status
@@ -37,29 +36,26 @@ class MLNodeRPCServiceHandler(IMLNodeRPCService.Iface):
         return get_status(TSStatusCode.SUCCESS_STATUS, "")
 
     def createTrainingTask(self, req: TCreateTrainingTaskReq):
-        # parse request stage (check required config and config type)
-        data_config, model_config, task_config = parse_training_request(req)
-
-        # create model stage (check model config legitimacy)
+        task = None
         try:
+            # parse request, check required config and config type
+            data_config, model_config, task_config = parse_training_request(req)
+
+            # create model & check model config legitimacy
             model, model_config = create_forecast_model(**model_config)
-        except Exception as e:  # Create model failed
-            return get_status(TSStatusCode.FAIL_STATUS, str(e))
-        logger.info('model config: ' + str(model_config))
 
-        # create data stage (check data config legitimacy)
-        try:
+            # create dataset & check data config legitimacy
             dataset, data_config = create_forecast_dataset(**data_config)
-        except Exception as e:  # Create data failed
-            return get_status(TSStatusCode.FAIL_STATUS, str(e))
-        logger.info('data config: ' + str(data_config))
-
-        # create task stage (check task config legitimacy)
 
-        # submit task stage (check resource and decide pending/start)
-        self.__task_manager.submit_training_task(task_config, model_config, model, dataset)
+            # create task & check task config legitimacy
+            task = self.__task_manager.create_training_task(dataset, model, model_config, task_config)
 
-        return get_status(TSStatusCode.SUCCESS_STATUS, 'Successfully create training task')
+            return get_status(TSStatusCode.SUCCESS_STATUS, 'Successfully create training task')
+        except Exception as e:
+            return get_status(TSStatusCode.FAIL_STATUS, str(e))
+        finally:
+            # submit task stage & check resource and decide pending/start
+            self.__task_manager.submit_training_task(task)
 
     def forecast(self, req: TForecastReq):
         status = get_status(TSStatusCode.SUCCESS_STATUS, "")
diff --git a/mlnode/iotdb/mlnode/parser.py b/mlnode/iotdb/mlnode/parser.py
index 236032b9a0..c052cd5050 100644
--- a/mlnode/iotdb/mlnode/parser.py
+++ b/mlnode/iotdb/mlnode/parser.py
@@ -91,8 +91,9 @@ class _ConfigParser(argparse.ArgumentParser):
  - output_vars: number of output variables
 """
 _data_config_parser = _ConfigParser()
-_data_config_parser.add_argument('--source_type', type=str, required=True)
-_data_config_parser.add_argument('--dataset_type', type=DatasetType, required=True)
+_data_config_parser.add_argument('--source_type', type=str, default="thrift")
+_data_config_parser.add_argument('--dataset_type', type=DatasetType, default=DatasetType.WINDOW,
+                                 choices=list(DatasetType))
 _data_config_parser.add_argument('--filename', type=str, default='')
 _data_config_parser.add_argument('--query_expressions', type=str, nargs='*', default=[])
 _data_config_parser.add_argument('--query_filter', type=str, default='')
@@ -183,6 +184,8 @@ def parse_training_request(req: TCreateTrainingTaskReq):
         task_config: configurations related to task
     """
     config = req.modelConfigs
+    config.update(model_name=config['model_type'])
+    config.update(task_class=config['model_task'])
     config.update(model_id=req.modelId)
     config.update(tuning=req.isAuto)
     config.update(query_expressions=req.queryExpressions)
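
Note: parse_training_request now derives model_name and task_class from the
model_type / model_task entries of req.modelConfigs. A hedged sketch of the key
renaming (the concrete values are assumptions; 'forecast' is the key that
task_factory.py below maps to ForecastingTrainingTask):

    # hypothetical request configs, illustrating only the renaming above
    configs = {'model_type': 'dlinear', 'model_task': 'forecast'}
    configs.update(model_name=configs['model_type'])
    configs.update(task_class=configs['model_task'])
    assert configs['task_class'] == 'forecast'   # selects ForecastingTrainingTask
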
diff --git a/mlnode/iotdb/mlnode/process/manager.py b/mlnode/iotdb/mlnode/process/manager.py
index bfb035f27b..0af0353973 100644
--- a/mlnode/iotdb/mlnode/process/manager.py
+++ b/mlnode/iotdb/mlnode/process/manager.py
@@ -18,7 +18,11 @@
 
 import multiprocessing as mp
 
+from torch import nn
+from torch.utils.data import Dataset
+
 from iotdb.mlnode.log import logger
+from iotdb.mlnode.process.task import ForecastingTrainingTask
 from iotdb.mlnode.process.task_factory import create_task
 
 
@@ -33,22 +37,22 @@ class TaskManager(object):
         self.__pid_info = self.__shared_resource_manager.dict()
         self.__training_process_pool = mp.Pool(pool_num)
 
-    def submit_training_task(self, task_configs, model_configs, model, dataset):
-        assert 'model_id' in task_configs.keys(), 'Task config should contain model_id'
+    def create_training_task(self,
+                             dataset: Dataset,
+                             model: nn.Module,
+                             model_configs: dict,
+                             task_configs: dict) -> ForecastingTrainingTask:
         model_id = task_configs['model_id']
         self.__pid_info[model_id] = self.__shared_resource_manager.dict()
-        try:
-            task = create_task(
-                task_configs,
-                model_configs,
-                model,
-                dataset,
-                self.__pid_info
-            )
-        except Exception as e:
-            logger.exception(e)
-            return e, False
+        return create_task(
+            task_configs,
+            model_configs,
+            model,
+            dataset,
+            self.__pid_info
+        )
 
-        logger.info(f'Task: ({model_id}) - Training process submitted successfully')
-        self.__training_process_pool.apply_async(task, args=())
-        return model_id, True
+    def submit_training_task(self, task: ForecastingTrainingTask) -> None:
+        if task is not None:
+            self.__training_process_pool.apply_async(task, args=())
+            logger.info(f'Task: ({task.model_id}) - Training process submitted successfully')
diff --git a/mlnode/iotdb/mlnode/process/task.py b/mlnode/iotdb/mlnode/process/task.py
index 7fac9cb1c5..85d5b5d2cf 100644
--- a/mlnode/iotdb/mlnode/process/task.py
+++ b/mlnode/iotdb/mlnode/process/task.py
@@ -75,7 +75,7 @@ class _BasicTask(object):
 class ForecastingTrainingTask(_BasicTask):
     def __init__(self, task_configs, model_configs, model, dataset, task_trial_map):
         super(ForecastingTrainingTask, self).__init__(task_configs, model_configs, model, dataset, task_trial_map)
-        model_id = self.task_configs['model_id']
+        self.model_id = self.task_configs['model_id']
         self.tuning = self.task_configs["tuning"]
 
         if self.tuning:  # TODO implement tuning task
@@ -83,7 +83,7 @@ class ForecastingTrainingTask(_BasicTask):
         else:
             self.task_configs['trial_id'] = 'tid_0'  # TODO: set a default trial id
             self.trial = ForecastingTrainingTrial(self.task_configs, self.model, self.model_configs, self.dataset)
-            self.task_trial_map[model_id]['tid_0'] = os.getpid()
+            self.task_trial_map[self.model_id]['tid_0'] = os.getpid()
 
     def __call__(self):
         try:
diff --git a/mlnode/iotdb/mlnode/process/task_factory.py b/mlnode/iotdb/mlnode/process/task_factory.py
index 7b9966a8f3..083b84eba2 100644
--- a/mlnode/iotdb/mlnode/process/task_factory.py
+++ b/mlnode/iotdb/mlnode/process/task_factory.py
@@ -20,7 +20,7 @@
 from iotdb.mlnode.process.task import ForecastingTrainingTask
 
 support_task_types = {
-    'forecast_training_task': ForecastingTrainingTask
+    'forecast': ForecastingTrainingTask
 }
 
 
diff --git a/mlnode/iotdb/mlnode/process/trial.py b/mlnode/iotdb/mlnode/process/trial.py
index f8671b4657..9852e3ffb4 100644
--- a/mlnode/iotdb/mlnode/process/trial.py
+++ b/mlnode/iotdb/mlnode/process/trial.py
@@ -23,11 +23,11 @@ import torch
 import torch.nn as nn
 from torch.utils.data import DataLoader, Dataset
 
-from iotdb.mlnode.algorithm.metric import all_metrics
+from iotdb.mlnode.algorithm.metric import MAE, MSE, all_metrics
 from iotdb.mlnode.client import client_manager
 from iotdb.mlnode.log import logger
 from iotdb.mlnode.storage import model_storage
-from iotdb.mlnode.constant import ModelState
+from iotdb.thrift.common.ttypes import TrainingState
 
 
 def _parse_trial_config(**kwargs):
@@ -188,8 +188,8 @@ class ForecastingTrainingTrial(BasicTrial):
 
             val_loss.append(loss.item())
             for name in self.metric_names:
-                value = eval(name)(outputs.detach().cpu().numpy(),
-                                   batch_y.detach().cpu().numpy())
+                metric = eval(name)()
+                value = metric(outputs.detach().cpu().numpy(), batch_y.detach().cpu().numpy())
                 metrics_dict[name].append(value)
 
         for name, value_list in metrics_dict.items():
@@ -207,25 +207,32 @@ class ForecastingTrainingTrial(BasicTrial):
         return val_loss, metrics_dict
 
     def start(self) -> float:
-        self.confignode_client.update_model_state(self.model_id, self.trial_id, ModelState.RUNNING)
-        best_loss = np.inf
-        best_metrics_dict = None
-        for epoch in range(self.epochs):
-            self._train(epoch)
-            val_loss, metrics_dict = self._validate(epoch)
-            if val_loss < best_loss:
-                best_loss = val_loss
-                best_metrics_dict = metrics_dict
-                model_storage.save_model(self.model,
-                                         self.model_configs,
-                                         model_id=self.model_id,
-                                         trial_id=self.trial_id)
-
-        logger.info(f'Trial: ({self.model_id}_{self.trial_id}) - Finished with best model saved successfully')
-
-        self.confignode_client.update_model_state(self.model_id, self.trial_id, ModelState.RUNNING)
-        model_info = {}
-        model_info.update(best_metrics_dict)
-        model_info.update(self.trial_configs)
-        self.confignode_client.update_model_info(self.model_id, self.trial_id, model_info)
-        return best_loss
+        try:
+            self.confignode_client.update_model_state(self.model_id, TrainingState.RUNNING)
+            best_loss = np.inf
+            best_metrics_dict = None
+            model_path = None
+            for epoch in range(self.epochs):
+                self._train(epoch)
+                val_loss, metrics_dict = self._validate(epoch)
+                if val_loss < best_loss:
+                    best_loss = val_loss
+                    best_metrics_dict = metrics_dict
+                    model_path = model_storage.save_model(self.model,
+                                                          self.model_configs,
+                                                          model_id=self.model_id,
+                                                          trial_id=self.trial_id)
+
+            logger.info(f'Trial: ({self.model_id}_{self.trial_id}) - Finished with best model saved successfully')
+
+            model_info = {}
+            model_info.update(best_metrics_dict)
+            model_info.update(self.trial_configs)
+            model_info['model_path'] = model_path
+            self.confignode_client.update_model_info(self.model_id, self.trial_id, model_info)
+            self.confignode_client.update_model_state(self.model_id, TrainingState.FINISHED, self.trial_id)
+            return best_loss
+        except Exception as e:
+            logger.warn(e)
+            self.confignode_client.update_model_state(self.model_id, TrainingState.FAILED)
+            raise e
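
Note: the trial now records the path returned by model_storage.save_model() in
model_info['model_path'], and the model config dict travels inside the
TorchScript file via _extra_files (see the storage.py hunk below). A hedged
sketch of reading such a file back with the standard torch.jit API (the path is
hypothetical; the single-tensor forward call mirrors how save_model traces the
model):

    import json
    import torch

    extra = {'model_config': ''}
    model = torch.jit.load('models/mid_0/tid_0.pt', _extra_files=extra)  # hypothetical path
    model_config = json.loads(extra['model_config'])   # dict written by save_model
    x = torch.randn(1, model_config['input_len'], model_config['input_vars'])
    y_hat = model(x)
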
diff --git a/mlnode/iotdb/mlnode/service.py b/mlnode/iotdb/mlnode/service.py
index a2c05ea5c3..ae0727cc5a 100644
--- a/mlnode/iotdb/mlnode/service.py
+++ b/mlnode/iotdb/mlnode/service.py
@@ -19,10 +19,10 @@ import threading
 import time
 
 from thrift.protocol import TCompactProtocol
-from thrift.server import TServer
+from thrift.server import TProcessPoolServer
 from thrift.transport import TSocket, TTransport
 
-from iotdb.mlnode.config import config
+from iotdb.mlnode.config import descriptor
 from iotdb.mlnode.handler import MLNodeRPCServiceHandler
 from iotdb.mlnode.log import logger
 from iotdb.thrift.mlnode import IMLNodeRPCService
@@ -32,11 +32,13 @@ class RPCService(threading.Thread):
     def __init__(self):
         super().__init__()
         processor = IMLNodeRPCService.Processor(handler=MLNodeRPCServiceHandler())
-        transport = TSocket.TServerSocket(host=config.get_mn_rpc_address(), port=config.get_mn_rpc_port())
+        transport = TSocket.TServerSocket(host=descriptor.get_config().get_mn_rpc_address(),
+                                          port=descriptor.get_config().get_mn_rpc_port())
         transport_factory = TTransport.TFramedTransportFactory()
         protocol_factory = TCompactProtocol.TCompactProtocolFactory()
 
-        self.__pool_server = TServer.TThreadPoolServer(processor, transport, transport_factory, protocol_factory)
+        self.__pool_server = TProcessPoolServer.TProcessPoolServer(processor, transport, transport_factory,
+                                                                   protocol_factory)
 
     def run(self) -> None:
         logger.info("The RPC service thread begin to run...")
@@ -45,6 +47,7 @@ class RPCService(threading.Thread):
 
 class MLNode(object):
     def __init__(self):
+        descriptor.load_config_from_file()
         self.__rpc_service = RPCService()
 
     def start(self) -> None:
diff --git a/mlnode/iotdb/mlnode/storage.py b/mlnode/iotdb/mlnode/storage.py
index ee745689b1..78a0be43bf 100644
--- a/mlnode/iotdb/mlnode/storage.py
+++ b/mlnode/iotdb/mlnode/storage.py
@@ -24,35 +24,36 @@ import torch
 import torch.nn as nn
 from pylru import lrucache
 
-from iotdb.mlnode.config import config
+from iotdb.mlnode.config import descriptor
 from iotdb.mlnode.exception import ModelNotExistError
 
 
 class ModelStorage(object):
     def __init__(self):
-        self.__model_dir = os.path.join(os.getcwd(), config.get_mn_model_storage_dir())
+        self.__model_dir = os.path.join('.', descriptor.get_config().get_mn_model_storage_dir())
         if not os.path.exists(self.__model_dir):
             os.mkdir(self.__model_dir)
 
-        self.__model_cache = lrucache(config.get_mn_model_storage_cache_size())
+        self.__model_cache = lrucache(descriptor.get_config().get_mn_model_storage_cache_size())
 
     def save_model(self,
                    model: nn.Module,
                    model_config: dict,
                    model_id: str,
-                   trial_id: str) -> None:
+                   trial_id: str) -> str:
         """
         Note: model config for time series should contain 'input_len' and 'input_vars'
         """
         model_dir_path = os.path.join(self.__model_dir, f'{model_id}')
         if not os.path.exists(model_dir_path):
-            os.mkdir(model_dir_path)
+            os.makedirs(model_dir_path)
         model_file_path = os.path.join(model_dir_path, f'{trial_id}.pt')
 
         sample_input = [torch.randn(1, model_config['input_len'], model_config['input_vars'])]
         torch.jit.save(torch.jit.trace(model, sample_input),
                        model_file_path,
                        _extra_files={'model_config': json.dumps(model_config)})
+        return os.path.abspath(model_file_path)
 
     def load_model(self, model_id: str, trial_id: str) -> (torch.jit.ScriptModule, dict):
         """
diff --git a/mlnode/iotdb/mlnode/util.py b/mlnode/iotdb/mlnode/util.py
index e451d2b25a..5d3a2d670e 100644
--- a/mlnode/iotdb/mlnode/util.py
+++ b/mlnode/iotdb/mlnode/util.py
@@ -52,6 +52,6 @@ def get_status(status_code: TSStatusCode, message: str) -> TSStatus:
 
 
 def verify_success(status: TSStatus, err_msg: str) -> None:
-    if status.code != TSStatusCode.SUCCESS_STATUS:
+    if status.code != TSStatusCode.SUCCESS_STATUS.get_status_code():
         logger.warn(err_msg + ", error status is ", status)
         raise RuntimeError(str(status.code) + ": " + status.message)
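
Note: the verify_success fix matters because TSStatus.code is a plain int while
TSStatusCode.SUCCESS_STATUS is an Enum member, so the old inequality was always
true and even successful responses raised. A small sketch of the difference
(the 200 value is an assumption, not taken from this diff):

    from enum import Enum

    class TSStatusCode(Enum):
        SUCCESS_STATUS = 200   # assumed value, illustration only

        def get_status_code(self) -> int:
            return self.value

    code = 200                                                    # as carried in TSStatus.code
    print(code != TSStatusCode.SUCCESS_STATUS)                    # True: int vs Enum member
    print(code != TSStatusCode.SUCCESS_STATUS.get_status_code())  # False: int vs int
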
diff --git a/mlnode/pyproject.toml b/mlnode/pyproject.toml
index 3944e2910d..56290f8d4e 100644
--- a/mlnode/pyproject.toml
+++ b/mlnode/pyproject.toml
@@ -49,6 +49,7 @@ packages = [
 python = "^3.7"
 thrift = "^0.13.0"
 dynaconf = "^3.1.11"
+pylru = "^1.2.1"
 
 [tool.poetry.scripts]
 mlnode = "iotdb.mlnode.script:main"
\ No newline at end of file
diff --git a/mlnode/requirements.txt b/mlnode/requirements.txt
index edd85701ab..c49c8a0189 100644
--- a/mlnode/requirements.txt
+++ b/mlnode/requirements.txt
@@ -20,7 +20,7 @@ pandas>=1.3.5
 numpy>=1.21.4
 apache-iotdb
 poetry
-torch
+torch~=2.0.0
 pylru
 
 thrift~=0.13.0