You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@iotdb.apache.org by hu...@apache.org on 2023/04/03 13:21:26 UTC
[iotdb] 02/02: make mlnode available
This is an automated email from the ASF dual-hosted git repository.
hui pushed a commit to branch mlnode/test
in repository https://gitbox.apache.org/repos/asf/iotdb.git
commit 833d0619ed19711132042578ae0ea6123bec77c4
Author: Minghui Liu <li...@foxmail.com>
AuthorDate: Mon Apr 3 16:11:31 2023 +0800
make mlnode available
---
mlnode/iotdb/mlnode/algorithm/enums.py | 3 ++
mlnode/iotdb/mlnode/algorithm/factory.py | 1 +
.../mlnode/algorithm/models/forecast/dlinear.py | 3 +-
mlnode/iotdb/mlnode/client.py | 32 ++++++------
mlnode/iotdb/mlnode/config.py | 10 ++--
mlnode/iotdb/mlnode/constant.py | 6 ---
mlnode/iotdb/mlnode/data_access/enums.py | 3 ++
mlnode/iotdb/mlnode/data_access/offline/source.py | 4 +-
mlnode/iotdb/mlnode/handler.py | 32 +++++-------
mlnode/iotdb/mlnode/parser.py | 7 ++-
mlnode/iotdb/mlnode/process/manager.py | 36 +++++++------
mlnode/iotdb/mlnode/process/task.py | 4 +-
mlnode/iotdb/mlnode/process/task_factory.py | 2 +-
mlnode/iotdb/mlnode/process/trial.py | 59 ++++++++++++----------
mlnode/iotdb/mlnode/service.py | 11 ++--
mlnode/iotdb/mlnode/storage.py | 11 ++--
mlnode/iotdb/mlnode/util.py | 2 +-
mlnode/pyproject.toml | 1 +
mlnode/requirements.txt | 2 +-
19 files changed, 122 insertions(+), 107 deletions(-)
diff --git a/mlnode/iotdb/mlnode/algorithm/enums.py b/mlnode/iotdb/mlnode/algorithm/enums.py
index 4b05aa4bf8..2def3751cd 100644
--- a/mlnode/iotdb/mlnode/algorithm/enums.py
+++ b/mlnode/iotdb/mlnode/algorithm/enums.py
@@ -27,3 +27,6 @@ class ForecastTaskType(Enum):
def __eq__(self, other: str) -> bool:
return self.value == other
+
+ def __hash__(self) -> int:
+ return hash(self.value)
diff --git a/mlnode/iotdb/mlnode/algorithm/factory.py b/mlnode/iotdb/mlnode/algorithm/factory.py
index 92cb01a883..26eab10860 100644
--- a/mlnode/iotdb/mlnode/algorithm/factory.py
+++ b/mlnode/iotdb/mlnode/algorithm/factory.py
@@ -19,6 +19,7 @@ import torch.nn as nn
from iotdb.mlnode.algorithm.enums import ForecastTaskType
from iotdb.mlnode.algorithm.models.forecast import support_forecasting_models
+from iotdb.mlnode.algorithm.models.forecast.dlinear import dlinear
from iotdb.mlnode.exception import BadConfigValueError
diff --git a/mlnode/iotdb/mlnode/algorithm/models/forecast/dlinear.py b/mlnode/iotdb/mlnode/algorithm/models/forecast/dlinear.py
index fa9ee04e56..966ea20347 100644
--- a/mlnode/iotdb/mlnode/algorithm/models/forecast/dlinear.py
+++ b/mlnode/iotdb/mlnode/algorithm/models/forecast/dlinear.py
@@ -21,6 +21,7 @@ import math
import torch
import torch.nn as nn
+from iotdb.mlnode.algorithm.enums import ForecastTaskType
from iotdb.mlnode.exception import BadConfigValueError
@@ -65,7 +66,7 @@ class DLinear(nn.Module):
pred_len=96,
input_vars=1,
output_vars=1,
- forecast_type='m', # TODO, support others
+ forecast_task_type=ForecastTaskType.ENDOGENOUS, # TODO, support others
):
super(DLinear, self).__init__()
self.input_len = input_len
diff --git a/mlnode/iotdb/mlnode/client.py b/mlnode/iotdb/mlnode/client.py
index 6c1c549ea1..724d517316 100644
--- a/mlnode/iotdb/mlnode/client.py
+++ b/mlnode/iotdb/mlnode/client.py
@@ -21,7 +21,7 @@ from thrift.protocol import TBinaryProtocol, TCompactProtocol
from thrift.Thrift import TException
from thrift.transport import TSocket, TTransport
-from iotdb.mlnode.config import config
+from iotdb.mlnode.config import descriptor
from iotdb.mlnode.constant import TSStatusCode
from iotdb.mlnode.log import logger
from iotdb.mlnode.util import verify_success
@@ -29,7 +29,7 @@ from iotdb.thrift.common.ttypes import TEndPoint, TrainingState, TSStatus
from iotdb.thrift.confignode import IConfigNodeRPCService
from iotdb.thrift.confignode.ttypes import (TUpdateModelInfoReq,
TUpdateModelStateReq)
-from iotdb.thrift.datanode import IDataNodeRPCService
+from iotdb.thrift.datanode import IMLNodeInternalRPCService
from iotdb.thrift.datanode.ttypes import (TFetchTimeseriesReq,
TFetchTimeseriesResp,
TRecordModelMetricsReq)
@@ -39,8 +39,8 @@ from iotdb.thrift.mlnode.ttypes import TCreateTrainingTaskReq, TDeleteModelReq
class ClientManager(object):
def __init__(self):
- self.__data_node_endpoint = config.get_mn_target_data_node()
- self.__config_node_endpoint = config.get_mn_target_config_node()
+ self.__data_node_endpoint = descriptor.get_config().get_mn_target_data_node()
+ self.__config_node_endpoint = descriptor.get_config().get_mn_target_config_node()
def borrow_data_node_client(self):
return DataNodeClient(host=self.__data_node_endpoint.ip,
@@ -120,18 +120,14 @@ class DataNodeClient(object):
raise e
protocol = TBinaryProtocol.TBinaryProtocol(transport)
- self.__client = IDataNodeRPCService.Client(protocol)
+ self.__client = IMLNodeInternalRPCService.Client(protocol)
def fetch_timeseries(self,
- session_id: int,
- statement_id: int,
query_expressions: list = [],
query_filter: str = None,
fetch_size: int = DEFAULT_FETCH_SIZE,
timeout: int = DEFAULT_TIMEOUT) -> TFetchTimeseriesResp:
req = TFetchTimeseriesReq(
- sessionId=session_id,
- statementId=statement_id,
queryExpressions=query_expressions,
queryFilter=query_filter,
fetchSize=fetch_size,
@@ -147,8 +143,8 @@ class DataNodeClient(object):
def record_model_metrics(self,
model_id: str,
trial_id: str,
- metrics: list = [],
- values: list = []) -> None:
+ metrics: list,
+ values: list) -> None:
req = TRecordModelMetricsReq(
modelId=model_id,
trialId=trial_id,
@@ -186,6 +182,7 @@ class ConfigNodeClient(object):
if self.__config_leader is not None:
try:
self.__connect(self.__config_leader)
+ return
except TException:
logger.warn("The current node {} may have been down, try next node", self.__config_leader)
self.__config_leader = None
@@ -200,6 +197,7 @@ class ConfigNodeClient(object):
try_endpoint = self.__config_nodes[self.__cursor]
try:
self.__connect(try_endpoint)
+ return
except TException:
logger.warn("The current node {} may have been down, try next node", try_endpoint)
@@ -217,7 +215,7 @@ class ConfigNodeClient(object):
except TTransport.TTransportException as e:
logger.exception("TTransportException!", exc_info=e)
- protocol = TCompactProtocol.TBinaryProtocol(transport)
+ protocol = TBinaryProtocol.TBinaryProtocol(transport)
self.__client = IConfigNodeRPCService.Client(protocol)
def __wait_and_reconnect(self) -> None:
@@ -246,12 +244,12 @@ class ConfigNodeClient(object):
def update_model_state(self,
model_id: str,
- trial_id: str,
- training_state: TrainingState) -> None:
+ training_state: TrainingState,
+ best_trail_id: str = None) -> None:
req = TUpdateModelStateReq(
modelId=model_id,
- trialId=trial_id,
- trainingState=training_state
+ state=training_state,
+ bestTrailId=best_trail_id
)
for i in range(0, self.__RETRY_NUM):
try:
@@ -275,7 +273,7 @@ class ConfigNodeClient(object):
model_info = {}
req = TUpdateModelInfoReq(
modelId=model_id,
- trialId=trial_id,
+ trailId=trial_id,
modelInfo={k: str(v) for k, v in model_info.items()},
)
diff --git a/mlnode/iotdb/mlnode/config.py b/mlnode/iotdb/mlnode/config.py
index e59338209a..109452eab5 100644
--- a/mlnode/iotdb/mlnode/config.py
+++ b/mlnode/iotdb/mlnode/config.py
@@ -44,7 +44,7 @@ class MLNodeConfig(object):
self.__mn_target_config_node: TEndPoint = TEndPoint("127.0.0.1", 10710)
# Target DataNode to be connected by MLNode
- self.__mn_target_data_node: TEndPoint = TEndPoint("127.0.0.1", 10730)
+ self.__mn_target_data_node: TEndPoint = TEndPoint("127.0.0.1", 10780)
def get_mn_rpc_address(self) -> str:
return self.__mn_rpc_address
@@ -86,9 +86,8 @@ class MLNodeConfig(object):
class MLNodeDescriptor(object):
def __init__(self):
self.__config = MLNodeConfig()
- self.__load_config_from_file()
- def __load_config_from_file(self) -> None:
+ def load_config_from_file(self) -> None:
conf_file = os.path.join(os.getcwd(), MLNODE_CONF_DIRECTORY_NAME, MLNODE_CONF_FILE_NAME)
if not os.path.exists(conf_file):
logger.info("Cannot find MLNode config file '{}', use default configuration.".format(conf_file))
@@ -113,7 +112,7 @@ class MLNodeDescriptor(object):
self.__config.set_mn_model_storage_dir(file_configs.mn_model_storage_dir)
if file_configs.mn_model_storage_cache_size is not None:
- self.__config.set_mn_model_storage_cachesize(file_configs.mn_model_storage_cache_size)
+ self.__config.set_mn_model_storage_cache_size(file_configs.mn_model_storage_cache_size)
if file_configs.mn_target_config_node is not None:
self.__config.set_mn_target_config_node(file_configs.mn_target_config_node)
@@ -129,4 +128,5 @@ class MLNodeDescriptor(object):
return self.__config
-config = MLNodeDescriptor().get_config()
+# initialize a singleton
+descriptor = MLNodeDescriptor()
diff --git a/mlnode/iotdb/mlnode/constant.py b/mlnode/iotdb/mlnode/constant.py
index e0be2a7b63..3bffa06526 100644
--- a/mlnode/iotdb/mlnode/constant.py
+++ b/mlnode/iotdb/mlnode/constant.py
@@ -31,9 +31,3 @@ class TSStatusCode(Enum):
def get_status_code(self) -> int:
return self.value
-
-
-class ModelState(Enum):
- RUNNING = 'running'
- FINISHED = 'finished'
- FAILED = 'failed'
diff --git a/mlnode/iotdb/mlnode/data_access/enums.py b/mlnode/iotdb/mlnode/data_access/enums.py
index d21a9f69c4..e7f5417b3d 100644
--- a/mlnode/iotdb/mlnode/data_access/enums.py
+++ b/mlnode/iotdb/mlnode/data_access/enums.py
@@ -27,3 +27,6 @@ class DatasetType(Enum):
def __eq__(self, other: str) -> bool:
return self.value == other
+
+ def __hash__(self) -> int:
+ return hash(self.value)
diff --git a/mlnode/iotdb/mlnode/data_access/offline/source.py b/mlnode/iotdb/mlnode/data_access/offline/source.py
index a63371ec7a..0422bb373d 100644
--- a/mlnode/iotdb/mlnode/data_access/offline/source.py
+++ b/mlnode/iotdb/mlnode/data_access/offline/source.py
@@ -74,8 +74,8 @@ class ThriftDataSource(DataSource):
try:
res = data_client.fetch_timeseries(
- queryExpressions=self.query_expressions,
- queryFilter=self.query_filter,
+ query_expressions=self.query_expressions,
+ query_filter=self.query_filter,
)
except Exception:
raise RuntimeError(f'Fail to fetch data with query expressions: {self.query_expressions}'
diff --git a/mlnode/iotdb/mlnode/handler.py b/mlnode/iotdb/mlnode/handler.py
index e43f26c226..1a6e3eb90a 100644
--- a/mlnode/iotdb/mlnode/handler.py
+++ b/mlnode/iotdb/mlnode/handler.py
@@ -19,7 +19,6 @@
from iotdb.mlnode.algorithm.factory import create_forecast_model
from iotdb.mlnode.constant import TSStatusCode
from iotdb.mlnode.data_access.factory import create_forecast_dataset
-from iotdb.mlnode.log import logger
from iotdb.mlnode.parser import parse_training_request
from iotdb.mlnode.process.manager import TaskManager
from iotdb.mlnode.util import get_status
@@ -37,29 +36,26 @@ class MLNodeRPCServiceHandler(IMLNodeRPCService.Iface):
return get_status(TSStatusCode.SUCCESS_STATUS, "")
def createTrainingTask(self, req: TCreateTrainingTaskReq):
- # parse request stage (check required config and config type)
- data_config, model_config, task_config = parse_training_request(req)
-
- # create model stage (check model config legitimacy)
+ task = None
try:
+ # parse request, check required config and config type
+ data_config, model_config, task_config = parse_training_request(req)
+
+ # create model & check model config legitimacy
model, model_config = create_forecast_model(**model_config)
- except Exception as e: # Create model failed
- return get_status(TSStatusCode.FAIL_STATUS, str(e))
- logger.info('model config: ' + str(model_config))
- # create data stage (check data config legitimacy)
- try:
+ # create dataset & check data config legitimacy
dataset, data_config = create_forecast_dataset(**data_config)
- except Exception as e: # Create data failed
- return get_status(TSStatusCode.FAIL_STATUS, str(e))
- logger.info('data config: ' + str(data_config))
-
- # create task stage (check task config legitimacy)
- # submit task stage (check resource and decide pending/start)
- self.__task_manager.submit_training_task(task_config, model_config, model, dataset)
+ # create task & check task config legitimacy
+ task = self.__task_manager.create_training_task(dataset, model, model_config, task_config)
- return get_status(TSStatusCode.SUCCESS_STATUS, 'Successfully create training task')
+ return get_status(TSStatusCode.SUCCESS_STATUS, 'Successfully create training task')
+ except Exception as e:
+ return get_status(TSStatusCode.FAIL_STATUS, str(e))
+ finally:
+ # submit task stage & check resource and decide pending/start
+ self.__task_manager.submit_training_task(task)
def forecast(self, req: TForecastReq):
status = get_status(TSStatusCode.SUCCESS_STATUS, "")
diff --git a/mlnode/iotdb/mlnode/parser.py b/mlnode/iotdb/mlnode/parser.py
index 236032b9a0..c052cd5050 100644
--- a/mlnode/iotdb/mlnode/parser.py
+++ b/mlnode/iotdb/mlnode/parser.py
@@ -91,8 +91,9 @@ class _ConfigParser(argparse.ArgumentParser):
- output_vars: number of output variables
"""
_data_config_parser = _ConfigParser()
-_data_config_parser.add_argument('--source_type', type=str, required=True)
-_data_config_parser.add_argument('--dataset_type', type=DatasetType, required=True)
+_data_config_parser.add_argument('--source_type', type=str, default="thrift")
+_data_config_parser.add_argument('--dataset_type', type=DatasetType, default=DatasetType.WINDOW,
+ choices=list(DatasetType))
_data_config_parser.add_argument('--filename', type=str, default='')
_data_config_parser.add_argument('--query_expressions', type=str, nargs='*', default=[])
_data_config_parser.add_argument('--query_filter', type=str, default='')
@@ -183,6 +184,8 @@ def parse_training_request(req: TCreateTrainingTaskReq):
task_config: configurations related to task
"""
config = req.modelConfigs
+ config.update(model_name=config['model_type'])
+ config.update(task_class=config['model_task'])
config.update(model_id=req.modelId)
config.update(tuning=req.isAuto)
config.update(query_expressions=req.queryExpressions)
diff --git a/mlnode/iotdb/mlnode/process/manager.py b/mlnode/iotdb/mlnode/process/manager.py
index bfb035f27b..0af0353973 100644
--- a/mlnode/iotdb/mlnode/process/manager.py
+++ b/mlnode/iotdb/mlnode/process/manager.py
@@ -18,7 +18,11 @@
import multiprocessing as mp
+from torch import nn
+from torch.utils.data import Dataset
+
from iotdb.mlnode.log import logger
+from iotdb.mlnode.process.task import ForecastingTrainingTask
from iotdb.mlnode.process.task_factory import create_task
@@ -33,22 +37,22 @@ class TaskManager(object):
self.__pid_info = self.__shared_resource_manager.dict()
self.__training_process_pool = mp.Pool(pool_num)
- def submit_training_task(self, task_configs, model_configs, model, dataset):
- assert 'model_id' in task_configs.keys(), 'Task config should contain model_id'
+ def create_training_task(self,
+ dataset: Dataset,
+ model: nn.Module,
+ model_configs: dict,
+ task_configs: dict) -> ForecastingTrainingTask:
model_id = task_configs['model_id']
self.__pid_info[model_id] = self.__shared_resource_manager.dict()
- try:
- task = create_task(
- task_configs,
- model_configs,
- model,
- dataset,
- self.__pid_info
- )
- except Exception as e:
- logger.exception(e)
- return e, False
+ return create_task(
+ task_configs,
+ model_configs,
+ model,
+ dataset,
+ self.__pid_info
+ )
- logger.info(f'Task: ({model_id}) - Training process submitted successfully')
- self.__training_process_pool.apply_async(task, args=())
- return model_id, True
+ def submit_training_task(self, task: ForecastingTrainingTask) -> None:
+ if task is not None:
+ self.__training_process_pool.apply_async(task, args=())
+ logger.info(f'Task: ({task.model_id}) - Training process submitted successfully')
diff --git a/mlnode/iotdb/mlnode/process/task.py b/mlnode/iotdb/mlnode/process/task.py
index 7fac9cb1c5..85d5b5d2cf 100644
--- a/mlnode/iotdb/mlnode/process/task.py
+++ b/mlnode/iotdb/mlnode/process/task.py
@@ -75,7 +75,7 @@ class _BasicTask(object):
class ForecastingTrainingTask(_BasicTask):
def __init__(self, task_configs, model_configs, model, dataset, task_trial_map):
super(ForecastingTrainingTask, self).__init__(task_configs, model_configs, model, dataset, task_trial_map)
- model_id = self.task_configs['model_id']
+ self.model_id = self.task_configs['model_id']
self.tuning = self.task_configs["tuning"]
if self.tuning: # TODO implement tuning task
@@ -83,7 +83,7 @@ class ForecastingTrainingTask(_BasicTask):
else:
self.task_configs['trial_id'] = 'tid_0' # TODO: set a default trial id
self.trial = ForecastingTrainingTrial(self.task_configs, self.model, self.model_configs, self.dataset)
- self.task_trial_map[model_id]['tid_0'] = os.getpid()
+ self.task_trial_map[self.model_id]['tid_0'] = os.getpid()
def __call__(self):
try:
diff --git a/mlnode/iotdb/mlnode/process/task_factory.py b/mlnode/iotdb/mlnode/process/task_factory.py
index 7b9966a8f3..083b84eba2 100644
--- a/mlnode/iotdb/mlnode/process/task_factory.py
+++ b/mlnode/iotdb/mlnode/process/task_factory.py
@@ -20,7 +20,7 @@
from iotdb.mlnode.process.task import ForecastingTrainingTask
support_task_types = {
- 'forecast_training_task': ForecastingTrainingTask
+ 'forecast': ForecastingTrainingTask
}
diff --git a/mlnode/iotdb/mlnode/process/trial.py b/mlnode/iotdb/mlnode/process/trial.py
index f8671b4657..9852e3ffb4 100644
--- a/mlnode/iotdb/mlnode/process/trial.py
+++ b/mlnode/iotdb/mlnode/process/trial.py
@@ -23,11 +23,11 @@ import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
-from iotdb.mlnode.algorithm.metric import all_metrics
+from iotdb.mlnode.algorithm.metric import MAE, MSE, all_metrics
from iotdb.mlnode.client import client_manager
from iotdb.mlnode.log import logger
from iotdb.mlnode.storage import model_storage
-from iotdb.mlnode.constant import ModelState
+from iotdb.thrift.common.ttypes import TrainingState
def _parse_trial_config(**kwargs):
@@ -188,8 +188,8 @@ class ForecastingTrainingTrial(BasicTrial):
val_loss.append(loss.item())
for name in self.metric_names:
- value = eval(name)(outputs.detach().cpu().numpy(),
- batch_y.detach().cpu().numpy())
+ metric = eval(name)()
+ value = metric(outputs.detach().cpu().numpy(), batch_y.detach().cpu().numpy())
metrics_dict[name].append(value)
for name, value_list in metrics_dict.items():
@@ -207,25 +207,32 @@ class ForecastingTrainingTrial(BasicTrial):
return val_loss, metrics_dict
def start(self) -> float:
- self.confignode_client.update_model_state(self.model_id, self.trial_id, ModelState.RUNNING)
- best_loss = np.inf
- best_metrics_dict = None
- for epoch in range(self.epochs):
- self._train(epoch)
- val_loss, metrics_dict = self._validate(epoch)
- if val_loss < best_loss:
- best_loss = val_loss
- best_metrics_dict = metrics_dict
- model_storage.save_model(self.model,
- self.model_configs,
- model_id=self.model_id,
- trial_id=self.trial_id)
-
- logger.info(f'Trial: ({self.model_id}_{self.trial_id}) - Finished with best model saved successfully')
-
- self.confignode_client.update_model_state(self.model_id, self.trial_id, ModelState.RUNNING)
- model_info = {}
- model_info.update(best_metrics_dict)
- model_info.update(self.trial_configs)
- self.confignode_client.update_model_info(self.model_id, self.trial_id, model_info)
- return best_loss
+ try:
+ self.confignode_client.update_model_state(self.model_id, TrainingState.RUNNING)
+ best_loss = np.inf
+ best_metrics_dict = None
+ model_path = None
+ for epoch in range(self.epochs):
+ self._train(epoch)
+ val_loss, metrics_dict = self._validate(epoch)
+ if val_loss < best_loss:
+ best_loss = val_loss
+ best_metrics_dict = metrics_dict
+ model_path = model_storage.save_model(self.model,
+ self.model_configs,
+ model_id=self.model_id,
+ trial_id=self.trial_id)
+
+ logger.info(f'Trial: ({self.model_id}_{self.trial_id}) - Finished with best model saved successfully')
+
+ model_info = {}
+ model_info.update(best_metrics_dict)
+ model_info.update(self.trial_configs)
+ model_info['model_path'] = model_path
+ self.confignode_client.update_model_info(self.model_id, self.trial_id, model_info)
+ self.confignode_client.update_model_state(self.model_id, TrainingState.FINISHED, self.trial_id)
+ return best_loss
+ except Exception as e:
+ logger.warn(e)
+ self.confignode_client.update_model_state(self.model_id, TrainingState.FAILED)
+ raise e
diff --git a/mlnode/iotdb/mlnode/service.py b/mlnode/iotdb/mlnode/service.py
index a2c05ea5c3..ae0727cc5a 100644
--- a/mlnode/iotdb/mlnode/service.py
+++ b/mlnode/iotdb/mlnode/service.py
@@ -19,10 +19,10 @@ import threading
import time
from thrift.protocol import TCompactProtocol
-from thrift.server import TServer
+from thrift.server import TProcessPoolServer
from thrift.transport import TSocket, TTransport
-from iotdb.mlnode.config import config
+from iotdb.mlnode.config import descriptor
from iotdb.mlnode.handler import MLNodeRPCServiceHandler
from iotdb.mlnode.log import logger
from iotdb.thrift.mlnode import IMLNodeRPCService
@@ -32,11 +32,13 @@ class RPCService(threading.Thread):
def __init__(self):
super().__init__()
processor = IMLNodeRPCService.Processor(handler=MLNodeRPCServiceHandler())
- transport = TSocket.TServerSocket(host=config.get_mn_rpc_address(), port=config.get_mn_rpc_port())
+ transport = TSocket.TServerSocket(host=descriptor.get_config().get_mn_rpc_address(),
+ port=descriptor.get_config().get_mn_rpc_port())
transport_factory = TTransport.TFramedTransportFactory()
protocol_factory = TCompactProtocol.TCompactProtocolFactory()
- self.__pool_server = TServer.TThreadPoolServer(processor, transport, transport_factory, protocol_factory)
+ self.__pool_server = TProcessPoolServer.TProcessPoolServer(processor, transport, transport_factory,
+ protocol_factory)
def run(self) -> None:
logger.info("The RPC service thread begin to run...")
@@ -45,6 +47,7 @@ class RPCService(threading.Thread):
class MLNode(object):
def __init__(self):
+ descriptor.load_config_from_file()
self.__rpc_service = RPCService()
def start(self) -> None:
diff --git a/mlnode/iotdb/mlnode/storage.py b/mlnode/iotdb/mlnode/storage.py
index ee745689b1..78a0be43bf 100644
--- a/mlnode/iotdb/mlnode/storage.py
+++ b/mlnode/iotdb/mlnode/storage.py
@@ -24,35 +24,36 @@ import torch
import torch.nn as nn
from pylru import lrucache
-from iotdb.mlnode.config import config
+from iotdb.mlnode.config import descriptor
from iotdb.mlnode.exception import ModelNotExistError
class ModelStorage(object):
def __init__(self):
- self.__model_dir = os.path.join(os.getcwd(), config.get_mn_model_storage_dir())
+ self.__model_dir = os.path.join('.', descriptor.get_config().get_mn_model_storage_dir())
if not os.path.exists(self.__model_dir):
os.mkdir(self.__model_dir)
- self.__model_cache = lrucache(config.get_mn_model_storage_cache_size())
+ self.__model_cache = lrucache(descriptor.get_config().get_mn_model_storage_cache_size())
def save_model(self,
model: nn.Module,
model_config: dict,
model_id: str,
- trial_id: str) -> None:
+ trial_id: str) -> str:
"""
Note: model config for time series should contain 'input_len' and 'input_vars'
"""
model_dir_path = os.path.join(self.__model_dir, f'{model_id}')
if not os.path.exists(model_dir_path):
- os.mkdir(model_dir_path)
+ os.makedirs(model_dir_path)
model_file_path = os.path.join(model_dir_path, f'{trial_id}.pt')
sample_input = [torch.randn(1, model_config['input_len'], model_config['input_vars'])]
torch.jit.save(torch.jit.trace(model, sample_input),
model_file_path,
_extra_files={'model_config': json.dumps(model_config)})
+ return os.path.abspath(model_file_path)
def load_model(self, model_id: str, trial_id: str) -> (torch.jit.ScriptModule, dict):
"""
diff --git a/mlnode/iotdb/mlnode/util.py b/mlnode/iotdb/mlnode/util.py
index e451d2b25a..5d3a2d670e 100644
--- a/mlnode/iotdb/mlnode/util.py
+++ b/mlnode/iotdb/mlnode/util.py
@@ -52,6 +52,6 @@ def get_status(status_code: TSStatusCode, message: str) -> TSStatus:
def verify_success(status: TSStatus, err_msg: str) -> None:
- if status.code != TSStatusCode.SUCCESS_STATUS:
+ if status.code != TSStatusCode.SUCCESS_STATUS.get_status_code():
logger.warn(err_msg + ", error status is ", status)
raise RuntimeError(str(status.code) + ": " + status.message)
diff --git a/mlnode/pyproject.toml b/mlnode/pyproject.toml
index 3944e2910d..56290f8d4e 100644
--- a/mlnode/pyproject.toml
+++ b/mlnode/pyproject.toml
@@ -49,6 +49,7 @@ packages = [
python = "^3.7"
thrift = "^0.13.0"
dynaconf = "^3.1.11"
+pylru = "^1.2.1"
[tool.poetry.scripts]
mlnode = "iotdb.mlnode.script:main"
\ No newline at end of file
diff --git a/mlnode/requirements.txt b/mlnode/requirements.txt
index edd85701ab..c49c8a0189 100644
--- a/mlnode/requirements.txt
+++ b/mlnode/requirements.txt
@@ -20,7 +20,7 @@ pandas>=1.3.5
numpy>=1.21.4
apache-iotdb
poetry
-torch
+torch~=2.0.0
pylru
thrift~=0.13.0