You are viewing a plain-text version of this content. (The link to the canonical version was not preserved in this plain-text copy.)
Posted to commits@iotdb.apache.org by hu...@apache.org on 2023/05/22 02:55:33 UTC
[iotdb] 01/02: fix mlnode
This is an automated email from the ASF dual-hosted git repository.
hui pushed a commit to branch lmh/forecastTest
in repository https://gitbox.apache.org/repos/asf/iotdb.git
commit 4c8e850b422b9b826fbb18ee42af6aaa80f25184
Author: Minghui Liu <li...@foxmail.com>
AuthorDate: Mon May 22 10:40:42 2023 +0800
fix mlnode
---
mlnode/iotdb/mlnode/algorithm/metric.py | 2 +-
.../mlnode/algorithm/models/forecast/dlinear.py | 2 +-
mlnode/iotdb/mlnode/client.py | 9 +++----
mlnode/iotdb/mlnode/data_access/offline/source.py | 4 +--
mlnode/iotdb/mlnode/handler.py | 6 ++---
mlnode/iotdb/mlnode/parser.py | 4 +--
mlnode/iotdb/mlnode/process/manager.py | 13 +++++-----
mlnode/iotdb/mlnode/process/task.py | 30 ++++++++--------------
mlnode/iotdb/mlnode/process/trial.py | 4 +--
mlnode/iotdb/mlnode/storage.py | 3 +--
mlnode/test/test_create_forecast_dataset.py | 1 +
mlnode/test/test_create_forecast_model.py | 1 +
mlnode/test/test_model_storage.py | 1 +
mlnode/test/test_serde.py | 3 ++-
14 files changed, 38 insertions(+), 45 deletions(-)
diff --git a/mlnode/iotdb/mlnode/algorithm/metric.py b/mlnode/iotdb/mlnode/algorithm/metric.py
index c32580743ba..5ffc5f24a5c 100644
--- a/mlnode/iotdb/mlnode/algorithm/metric.py
+++ b/mlnode/iotdb/mlnode/algorithm/metric.py
@@ -16,7 +16,7 @@
# under the License.
#
from abc import abstractmethod
-from typing import List, Dict
+from typing import Dict, List
import numpy as np
diff --git a/mlnode/iotdb/mlnode/algorithm/models/forecast/dlinear.py b/mlnode/iotdb/mlnode/algorithm/models/forecast/dlinear.py
index a2ae149134f..35ba728fa78 100644
--- a/mlnode/iotdb/mlnode/algorithm/models/forecast/dlinear.py
+++ b/mlnode/iotdb/mlnode/algorithm/models/forecast/dlinear.py
@@ -17,10 +17,10 @@
#
import math
+from typing import Dict, Tuple
import torch
import torch.nn as nn
-from typing import Dict, Tuple
from iotdb.mlnode.algorithm.enums import ForecastTaskType
from iotdb.mlnode.exception import BadConfigValueError
diff --git a/mlnode/iotdb/mlnode/client.py b/mlnode/iotdb/mlnode/client.py
index bb442cd7b8d..a5bdf109e8f 100644
--- a/mlnode/iotdb/mlnode/client.py
+++ b/mlnode/iotdb/mlnode/client.py
@@ -33,13 +33,12 @@ from iotdb.thrift.confignode import IConfigNodeRPCService
from iotdb.thrift.confignode.ttypes import (TUpdateModelInfoReq,
TUpdateModelStateReq)
from iotdb.thrift.datanode import IMLNodeInternalRPCService
-from iotdb.thrift.datanode.ttypes import (TFetchTimeseriesReq,
- TRecordModelMetricsReq,
- TFetchMoreDataReq)
+from iotdb.thrift.datanode.ttypes import (TFetchMoreDataReq,
+ TFetchTimeseriesReq,
+ TRecordModelMetricsReq)
from iotdb.thrift.mlnode import IMLNodeRPCService
from iotdb.thrift.mlnode.ttypes import (TCreateTrainingTaskReq,
- TDeleteModelReq,
- TForecastReq)
+ TDeleteModelReq, TForecastReq)
class ClientManager(object):
diff --git a/mlnode/iotdb/mlnode/data_access/offline/source.py b/mlnode/iotdb/mlnode/data_access/offline/source.py
index 418c9c7afc0..4f29241913f 100644
--- a/mlnode/iotdb/mlnode/data_access/offline/source.py
+++ b/mlnode/iotdb/mlnode/data_access/offline/source.py
@@ -15,11 +15,11 @@
# specific language governing permissions and limitations
# under the License.
#
+from typing import List
+
import numpy as np
import pandas as pd
-from typing import List
-
from iotdb.mlnode.client import client_manager
diff --git a/mlnode/iotdb/mlnode/handler.py b/mlnode/iotdb/mlnode/handler.py
index a349a99d613..d4b4ca3b034 100644
--- a/mlnode/iotdb/mlnode/handler.py
+++ b/mlnode/iotdb/mlnode/handler.py
@@ -16,19 +16,19 @@
# under the License.
#
+from iotdb.mlnode.config import descriptor
from iotdb.mlnode.constant import TSStatusCode
from iotdb.mlnode.data_access.factory import create_forecast_dataset
from iotdb.mlnode.log import logger
-from iotdb.mlnode.parser import parse_training_request, parse_forecast_request
+from iotdb.mlnode.parser import parse_forecast_request, parse_training_request
from iotdb.mlnode.process.manager import TaskManager
+from iotdb.mlnode.serde import convert_to_binary
from iotdb.mlnode.storage import model_storage
from iotdb.mlnode.util import get_status
-from iotdb.mlnode.config import descriptor
from iotdb.thrift.mlnode import IMLNodeRPCService
from iotdb.thrift.mlnode.ttypes import (TCreateTrainingTaskReq,
TDeleteModelReq, TForecastReq,
TForecastResp)
-from iotdb.mlnode.serde import convert_to_binary
class MLNodeRPCServiceHandler(IMLNodeRPCService.Iface):
diff --git a/mlnode/iotdb/mlnode/parser.py b/mlnode/iotdb/mlnode/parser.py
index 71ffbe7c372..34bb7eb65a9 100644
--- a/mlnode/iotdb/mlnode/parser.py
+++ b/mlnode/iotdb/mlnode/parser.py
@@ -24,8 +24,8 @@ from typing import Dict, List, Tuple
from iotdb.mlnode.algorithm.enums import ForecastTaskType
from iotdb.mlnode.data_access.enums import DatasetType, DataSourceType
from iotdb.mlnode.exception import MissingConfigError, WrongTypeConfigError
-from iotdb.thrift.mlnode.ttypes import TCreateTrainingTaskReq, TForecastReq
from iotdb.mlnode.serde import convert_to_df
+from iotdb.thrift.mlnode.ttypes import TCreateTrainingTaskReq, TForecastReq
class _ConfigParser(argparse.ArgumentParser):
@@ -222,7 +222,7 @@ def parse_forecast_request(req: TForecastReq):
ts_dataset = req.inputData
pred_len = req.predictLength
- data = convert_to_df(column_name_list, column_type_list, None, ts_dataset)
+ data = convert_to_df(column_name_list, column_type_list, None, [ts_dataset])
time_stamp, data = data[data.columns[0:1]], data[data.columns[1:]]
full_data = (data, time_stamp)
return model_path, full_data, pred_len
diff --git a/mlnode/iotdb/mlnode/process/manager.py b/mlnode/iotdb/mlnode/process/manager.py
index ee7df83acae..5957dadbed2 100644
--- a/mlnode/iotdb/mlnode/process/manager.py
+++ b/mlnode/iotdb/mlnode/process/manager.py
@@ -17,13 +17,15 @@
#
import multiprocessing as mp
-import pandas as pd
-
from typing import Dict, Union
+
+import pandas as pd
from torch.utils.data import Dataset
+
from iotdb.mlnode.log import logger
-from iotdb.mlnode.process.task import ForecastingSingleTrainingTask, ForecastingTuningTrainingTask, \
- ForecastingInferenceTask
+from iotdb.mlnode.process.task import (ForecastingInferenceTask,
+ ForecastingSingleTrainingTask,
+ ForecastingTuningTrainingTask)
class TaskManager(object):
@@ -102,6 +104,5 @@ class TaskManager(object):
read_pipe, send_pipe = mp.Pipe()
if task is not None:
self.__training_process_pool.apply_async(task, args=(send_pipe,))
- logger.info(f'Forecasting process submitted successfully')
- # task(send_pipe)
+ logger.info('Forecasting process submitted successfully')
return read_pipe.recv()
diff --git a/mlnode/iotdb/mlnode/process/task.py b/mlnode/iotdb/mlnode/process/task.py
index ff0ee4c5a18..85611a948fb 100644
--- a/mlnode/iotdb/mlnode/process/task.py
+++ b/mlnode/iotdb/mlnode/process/task.py
@@ -17,25 +17,23 @@
#
import os
-import pandas as pd
-import numpy as np
-
from abc import abstractmethod
+from multiprocessing.connection import Connection
from typing import Dict, Tuple
+import numpy as np
import optuna
+import pandas as pd
import torch
-from torch import nn
from torch.utils.data import Dataset
-from multiprocessing.connection import Connection
-from iotdb.mlnode.log import logger
-from iotdb.mlnode.process.trial import ForecastingTrainingTrial
from iotdb.mlnode.algorithm.factory import create_forecast_model
from iotdb.mlnode.client import client_manager
from iotdb.mlnode.config import descriptor
-from iotdb.thrift.common.ttypes import TrainingState
+from iotdb.mlnode.log import logger
+from iotdb.mlnode.process.trial import ForecastingTrainingTrial
from iotdb.mlnode.storage import model_storage
+from iotdb.thrift.common.ttypes import TrainingState
class ForestingTrainingObjective:
@@ -229,7 +227,7 @@ class ForecastingInferenceTask(_BasicInferenceTask):
task_configs: Dict,
model_configs: Dict,
pid_info: Dict,
- data:Tuple,
+ data: Tuple,
model_path: str
):
super().__init__(task_configs, model_configs, pid_info, data, model_path)
@@ -247,17 +245,14 @@ class ForecastingInferenceTask(_BasicInferenceTask):
current_pred_len = 0
while current_pred_len < self.pred_len:
current_data = full_data[:, -self.input_len:, :]
- # current_data_stamp = timefeatures.time_features(full_data_stamp.iloc[-self.input_len:, :])[None, :] # batch
current_data = torch.Tensor(current_data)
output_data = self.model(current_data).detach().numpy()
full_data = np.concatenate([full_data, output_data], axis=1)
- # full_data_stamp = pd.concat([full_data_stamp, self.generate_future_mark(full_data_stamp, self.pred_len)])
current_pred_len += self.model_pred_len
full_data_stamp = self.generate_future_mark(full_data_stamp, self.pred_len)
- # ret_data = np.concatenate([full_data_stamp, full_data[0, -self.pred_len:, :]], axis=1)
- # ret_data = ret_data[-
- ret_data = pd.concat([pd.DataFrame(full_data_stamp.astype(np.int64)), pd.DataFrame(full_data[0, -self.pred_len:, :])], axis=1)
- # ret_data = pd.DataFrame(ret_data)
+ ret_data = pd.concat(
+ [pd.DataFrame(full_data_stamp.astype(np.int64)),
+ pd.DataFrame(full_data[0, -self.pred_len:, :]).astype(np.double)], axis=1)
ret_data.columns = list(np.arange(0, C + 1))
pipe.send(ret_data)
@@ -284,12 +279,9 @@ class ForecastingInferenceTask(_BasicInferenceTask):
data = data[None, :] # add batch dim
return data, data_stamp
- def generate_future_mark(self, data_stamp:pd.DataFrame, future_len: int) -> pd.DatetimeIndex:
+ def generate_future_mark(self, data_stamp: pd.DataFrame, future_len: int) -> pd.DatetimeIndex:
time_deltas = data_stamp.diff().dropna()
mean_timedelta = time_deltas.mean()[0]
extrapolated_timestamp = pd.date_range(data_stamp.values[0][0], periods=future_len,
freq=mean_timedelta)
return extrapolated_timestamp[:, None]
-
-if __name__ == '__main__':
- pass
diff --git a/mlnode/iotdb/mlnode/process/trial.py b/mlnode/iotdb/mlnode/process/trial.py
index 973fb31f6f0..7fc34477686 100644
--- a/mlnode/iotdb/mlnode/process/trial.py
+++ b/mlnode/iotdb/mlnode/process/trial.py
@@ -22,12 +22,10 @@ from typing import Dict, Tuple
import numpy as np
import torch
import torch.nn as nn
-from torch.nn.modules import loss
-from torch.optim import Optimizer
from torch.utils.data import DataLoader, Dataset
from iotdb.mlnode.algorithm.metric import all_metrics, build_metrics
-from iotdb.mlnode.client import client_manager, DataNodeClient, ConfigNodeClient
+from iotdb.mlnode.client import client_manager
from iotdb.mlnode.log import logger
from iotdb.mlnode.storage import model_storage
from iotdb.thrift.common.ttypes import TrainingState
diff --git a/mlnode/iotdb/mlnode/storage.py b/mlnode/iotdb/mlnode/storage.py
index db3b081a2f1..4c038786174 100644
--- a/mlnode/iotdb/mlnode/storage.py
+++ b/mlnode/iotdb/mlnode/storage.py
@@ -19,12 +19,11 @@
import json
import os
import shutil
+import threading
from typing import Dict, Tuple
import torch
import torch.nn as nn
-import threading
-
from pylru import lrucache
from iotdb.mlnode.config import descriptor
diff --git a/mlnode/test/test_create_forecast_dataset.py b/mlnode/test/test_create_forecast_dataset.py
index c9e506dfefb..49e2d177e8f 100644
--- a/mlnode/test/test_create_forecast_dataset.py
+++ b/mlnode/test/test_create_forecast_dataset.py
@@ -18,6 +18,7 @@
import os
import requests
+
from iotdb.mlnode.data_access.enums import DatasetType, DataSourceType
from iotdb.mlnode.data_access.factory import create_forecast_dataset
diff --git a/mlnode/test/test_create_forecast_model.py b/mlnode/test/test_create_forecast_model.py
index a100d01c1ca..ed439fd9c05 100644
--- a/mlnode/test/test_create_forecast_model.py
+++ b/mlnode/test/test_create_forecast_model.py
@@ -16,6 +16,7 @@
# under the License.
#
import torch
+
from iotdb.mlnode.algorithm.enums import ForecastTaskType
from iotdb.mlnode.algorithm.factory import create_forecast_model
from iotdb.mlnode.exception import BadConfigValueError
diff --git a/mlnode/test/test_model_storage.py b/mlnode/test/test_model_storage.py
index 90d6bebdc1d..9a6399b2ea2 100644
--- a/mlnode/test/test_model_storage.py
+++ b/mlnode/test/test_model_storage.py
@@ -20,6 +20,7 @@ import os
import time
import torch.nn as nn
+
from iotdb.mlnode.config import descriptor
from iotdb.mlnode.exception import ModelNotExistError
from iotdb.mlnode.storage import model_storage
diff --git a/mlnode/test/test_serde.py b/mlnode/test/test_serde.py
index 3454f41ddec..c05083be417 100644
--- a/mlnode/test/test_serde.py
+++ b/mlnode/test/test_serde.py
@@ -18,9 +18,10 @@
import numpy as np
import pandas as pd
-from iotdb.mlnode.serde import convert_to_df
from pandas.testing import assert_frame_equal
+from iotdb.mlnode.serde import convert_to_df
+
device_id = "root.wt1"
ts_path_lst = [