You are viewing a plain-text version of this content; the link to its canonical version was not preserved in this copy.
Posted to commits@iotdb.apache.org by hu...@apache.org on 2023/05/20 12:38:48 UTC
[iotdb] branch lmh/forecastTest updated: modify handler
This is an automated email from the ASF dual-hosted git repository.
hui pushed a commit to branch lmh/forecastTest
in repository https://gitbox.apache.org/repos/asf/iotdb.git
The following commit(s) were added to refs/heads/lmh/forecastTest by this push:
new 1a1369d441e modify handler
1a1369d441e is described below
commit 1a1369d441eabbb8b8d7510f129d02d1e0d7f8e3
Author: liuminghui233 <54...@qq.com>
AuthorDate: Sat May 20 20:35:21 2023 +0800
modify handler
---
mlnode/iotdb/mlnode/handler.py | 34 ++++++++--------------------------
mlnode/iotdb/mlnode/parser.py | 11 +++++------
mlnode/iotdb/mlnode/serde.py | 2 +-
3 files changed, 14 insertions(+), 33 deletions(-)
diff --git a/mlnode/iotdb/mlnode/handler.py b/mlnode/iotdb/mlnode/handler.py
index 07c4224cf03..a349a99d613 100644
--- a/mlnode/iotdb/mlnode/handler.py
+++ b/mlnode/iotdb/mlnode/handler.py
@@ -18,6 +18,7 @@
from iotdb.mlnode.constant import TSStatusCode
from iotdb.mlnode.data_access.factory import create_forecast_dataset
+from iotdb.mlnode.log import logger
from iotdb.mlnode.parser import parse_training_request, parse_forecast_request
from iotdb.mlnode.process.manager import TaskManager
from iotdb.mlnode.storage import model_storage
@@ -39,6 +40,7 @@ class MLNodeRPCServiceHandler(IMLNodeRPCService.Iface):
model_storage.delete_model(req.modelId)
return get_status(TSStatusCode.SUCCESS_STATUS)
except Exception as e:
+ logger.warn(e)
return get_status(TSStatusCode.MLNODE_INTERNAL_ERROR, str(e))
def createTrainingTask(self, req: TCreateTrainingTaskReq):
@@ -57,6 +59,7 @@ class MLNodeRPCServiceHandler(IMLNodeRPCService.Iface):
return get_status(TSStatusCode.SUCCESS_STATUS)
except Exception as e:
+ logger.warn(e)
return get_status(TSStatusCode.MLNODE_INTERNAL_ERROR, str(e))
finally:
# submit task stage & check resource and decide pending/start
@@ -65,7 +68,6 @@ class MLNodeRPCServiceHandler(IMLNodeRPCService.Iface):
def forecast(self, req: TForecastReq):
model_path, data, pred_length = parse_forecast_request(req)
model, model_configs = model_storage.load_model(model_path)
- task = None
task_configs = {'pred_len': pred_length}
try:
task = self.__task_manager.create_forecast_task(
@@ -74,30 +76,10 @@ class MLNodeRPCServiceHandler(IMLNodeRPCService.Iface):
data,
model_path
)
- except Exception as e:
- print(e)
- return get_status(TSStatusCode.MLNODE_INTERNAL_ERROR, str(e))
- finally:
# submit task stage & check resource and decide pending/start
- forecast_result = self.__task_manager.submit_forecast_task(task)
- binary_result = convert_to_binary(forecast_result)
- binary_result = binary_result[0]
- resp = TForecastResp(get_status(TSStatusCode.SUCCESS_STATUS), binary_result)
+ forecast_result = convert_to_binary(self.__task_manager.submit_forecast_task(task))
+ resp = TForecastResp(get_status(TSStatusCode.SUCCESS_STATUS), forecast_result)
return resp
-
-
-# if __name__ == '__main__':
-# handler = MLNodeRPCServiceHandler()
-# import pickle
-# f = open('D:\\undergraduate\\DL\\iotdb\\mlnode\\iotdb\\mlnode\\test_tsdataset.pkl', 'rb')
-# ts_dataset = pickle.load(f)
-# req = TForecastReq(
-# 'D:\\undergraduate\\DL\\iotdb\\mlnode\\iotdb\\mlnode\\models\\Model_1\\tid_0.pt',
-# ts_dataset,
-# ['root.eg.etth1.s0'],
-# ['FLOAT'],
-# {'root.eg.etth1.s0': 0},
-# 192,
-# 'Model_2'
-# )
-# handler.forecast(req)
+ except Exception as e:
+ logger.warn(e)
+ return get_status(TSStatusCode.MLNODE_INTERNAL_ERROR, str(e))
diff --git a/mlnode/iotdb/mlnode/parser.py b/mlnode/iotdb/mlnode/parser.py
index bd70426c1bb..71ffbe7c372 100644
--- a/mlnode/iotdb/mlnode/parser.py
+++ b/mlnode/iotdb/mlnode/parser.py
@@ -217,13 +217,12 @@ def parse_training_request(req: TCreateTrainingTaskReq) -> Tuple[Dict, Dict, Dic
def parse_forecast_request(req: TForecastReq):
model_path = req.modelPath
- column_name_list = req.columnNameList
- column_type_list = req.columnTypeList
- column_name_index_map = req.columnNameIndexMap
- ts_dataset = req.tsDataset
- pred_len = req.predLength
+ column_name_list = req.inputColumnNameList
+ column_type_list = req.inputTypeList
+ ts_dataset = req.inputData
+ pred_len = req.predictLength
- data = convert_to_df(column_name_list, column_type_list, column_name_index_map, ts_dataset)
+ data = convert_to_df(column_name_list, column_type_list, None, ts_dataset)
time_stamp, data = data[data.columns[0:1]], data[data.columns[1:]]
full_data = (data, time_stamp)
return model_path, full_data, pred_len
diff --git a/mlnode/iotdb/mlnode/serde.py b/mlnode/iotdb/mlnode/serde.py
index 6f53ba93e47..4b5e90c2554 100644
--- a/mlnode/iotdb/mlnode/serde.py
+++ b/mlnode/iotdb/mlnode/serde.py
@@ -92,7 +92,7 @@ def convert_to_binary(data_frame: pd.DataFrame):
value = value.byteswap()
binary += value.tobytes()
- return [binary]
+ return binary
# convert tsBlock in binary to dataFrame