You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@iotdb.apache.org by hu...@apache.org on 2023/05/20 12:38:48 UTC

[iotdb] branch lmh/forecastTest updated: modify handler

This is an automated email from the ASF dual-hosted git repository.

hui pushed a commit to branch lmh/forecastTest
in repository https://gitbox.apache.org/repos/asf/iotdb.git


The following commit(s) were added to refs/heads/lmh/forecastTest by this push:
     new 1a1369d441e modify handler
1a1369d441e is described below

commit 1a1369d441eabbb8b8d7510f129d02d1e0d7f8e3
Author: liuminghui233 <54...@qq.com>
AuthorDate: Sat May 20 20:35:21 2023 +0800

    modify handler
---
 mlnode/iotdb/mlnode/handler.py | 34 ++++++++--------------------------
 mlnode/iotdb/mlnode/parser.py  | 11 +++++------
 mlnode/iotdb/mlnode/serde.py   |  2 +-
 3 files changed, 14 insertions(+), 33 deletions(-)

diff --git a/mlnode/iotdb/mlnode/handler.py b/mlnode/iotdb/mlnode/handler.py
index 07c4224cf03..a349a99d613 100644
--- a/mlnode/iotdb/mlnode/handler.py
+++ b/mlnode/iotdb/mlnode/handler.py
@@ -18,6 +18,7 @@
 
 from iotdb.mlnode.constant import TSStatusCode
 from iotdb.mlnode.data_access.factory import create_forecast_dataset
+from iotdb.mlnode.log import logger
 from iotdb.mlnode.parser import parse_training_request, parse_forecast_request
 from iotdb.mlnode.process.manager import TaskManager
 from iotdb.mlnode.storage import model_storage
@@ -39,6 +40,7 @@ class MLNodeRPCServiceHandler(IMLNodeRPCService.Iface):
             model_storage.delete_model(req.modelId)
             return get_status(TSStatusCode.SUCCESS_STATUS)
         except Exception as e:
+            logger.warn(e)
             return get_status(TSStatusCode.MLNODE_INTERNAL_ERROR, str(e))
 
     def createTrainingTask(self, req: TCreateTrainingTaskReq):
@@ -57,6 +59,7 @@ class MLNodeRPCServiceHandler(IMLNodeRPCService.Iface):
 
             return get_status(TSStatusCode.SUCCESS_STATUS)
         except Exception as e:
+            logger.warn(e)
             return get_status(TSStatusCode.MLNODE_INTERNAL_ERROR, str(e))
         finally:
             # submit task stage & check resource and decide pending/start
@@ -65,7 +68,6 @@ class MLNodeRPCServiceHandler(IMLNodeRPCService.Iface):
     def forecast(self, req: TForecastReq):
         model_path, data, pred_length = parse_forecast_request(req)
         model, model_configs = model_storage.load_model(model_path)
-        task = None
         task_configs = {'pred_len': pred_length}
         try:
             task = self.__task_manager.create_forecast_task(
@@ -74,30 +76,10 @@ class MLNodeRPCServiceHandler(IMLNodeRPCService.Iface):
                 data,
                 model_path
             )
-        except Exception as e:
-            print(e)
-            return get_status(TSStatusCode.MLNODE_INTERNAL_ERROR, str(e))
-        finally:
             # submit task stage & check resource and decide pending/start
-            forecast_result = self.__task_manager.submit_forecast_task(task)
-            binary_result = convert_to_binary(forecast_result)
-            binary_result = binary_result[0]
-            resp = TForecastResp(get_status(TSStatusCode.SUCCESS_STATUS), binary_result)
+            forecast_result = convert_to_binary(self.__task_manager.submit_forecast_task(task))
+            resp = TForecastResp(get_status(TSStatusCode.SUCCESS_STATUS), forecast_result)
             return resp
-
-
-# if __name__ == '__main__':
-#     handler = MLNodeRPCServiceHandler()
-#     import pickle
-#     f = open('D:\\undergraduate\\DL\\iotdb\\mlnode\\iotdb\\mlnode\\test_tsdataset.pkl', 'rb')
-#     ts_dataset = pickle.load(f)
-#     req = TForecastReq(
-#         'D:\\undergraduate\\DL\\iotdb\\mlnode\\iotdb\\mlnode\\models\\Model_1\\tid_0.pt',
-#         ts_dataset,
-#         ['root.eg.etth1.s0'],
-#         ['FLOAT'],
-#         {'root.eg.etth1.s0': 0},
-#         192,
-#         'Model_2'
-#     )
-#     handler.forecast(req)
+        except Exception as e:
+            logger.warn(e)
+            return get_status(TSStatusCode.MLNODE_INTERNAL_ERROR, str(e))
diff --git a/mlnode/iotdb/mlnode/parser.py b/mlnode/iotdb/mlnode/parser.py
index bd70426c1bb..71ffbe7c372 100644
--- a/mlnode/iotdb/mlnode/parser.py
+++ b/mlnode/iotdb/mlnode/parser.py
@@ -217,13 +217,12 @@ def parse_training_request(req: TCreateTrainingTaskReq) -> Tuple[Dict, Dict, Dic
 
 def parse_forecast_request(req: TForecastReq):
     model_path = req.modelPath
-    column_name_list = req.columnNameList
-    column_type_list = req.columnTypeList
-    column_name_index_map = req.columnNameIndexMap
-    ts_dataset = req.tsDataset
-    pred_len = req.predLength
+    column_name_list = req.inputColumnNameList
+    column_type_list = req.inputTypeList
+    ts_dataset = req.inputData
+    pred_len = req.predictLength
 
-    data = convert_to_df(column_name_list, column_type_list, column_name_index_map, ts_dataset)
+    data = convert_to_df(column_name_list, column_type_list, None, ts_dataset)
     time_stamp, data = data[data.columns[0:1]], data[data.columns[1:]]
     full_data = (data, time_stamp)
     return model_path, full_data, pred_len
diff --git a/mlnode/iotdb/mlnode/serde.py b/mlnode/iotdb/mlnode/serde.py
index 6f53ba93e47..4b5e90c2554 100644
--- a/mlnode/iotdb/mlnode/serde.py
+++ b/mlnode/iotdb/mlnode/serde.py
@@ -92,7 +92,7 @@ def convert_to_binary(data_frame: pd.DataFrame):
                 value = value.byteswap()
             binary += value.tobytes()
 
-    return [binary]
+    return binary
 
 
 # convert tsBlock in binary to dataFrame