You are viewing a plain-text version of this content. The canonical link for it (a hyperlink in the original page) is not preserved in this copy.
Posted to commits@iotdb.apache.org by hu...@apache.org on 2023/04/03 13:37:03 UTC
[iotdb] 01/01: Merge remote-tracking branch 'liuyong/mlnode/test' into mlnode/test
This is an automated email from the ASF dual-hosted git repository.
hui pushed a commit to branch mlnode/test
in repository https://gitbox.apache.org/repos/asf/iotdb.git
commit f36a02b0debe1962863491a00cf16e306217ffab
Merge: 833d0619ed b08b0c40c5
Author: Minghui Liu <li...@foxmail.com>
AuthorDate: Mon Apr 3 21:35:43 2023 +0800
Merge remote-tracking branch 'liuyong/mlnode/test' into mlnode/test
# Conflicts:
# mlnode/iotdb/mlnode/data_access/enums.py
# mlnode/iotdb/mlnode/handler.py
# mlnode/iotdb/mlnode/parser.py
mlnode/iotdb/mlnode/algorithm/enums.py | 11 +++
mlnode/iotdb/mlnode/algorithm/factory.py | 23 +++---
.../mlnode/algorithm/models/forecast/__init__.py | 2 +
.../mlnode/algorithm/models/forecast/dlinear.py | 2 +-
.../mlnode/algorithm/models/forecast/nbeats.py | 9 ++-
mlnode/iotdb/mlnode/client.py | 2 +-
mlnode/iotdb/mlnode/data_access/enums.py | 16 +++-
mlnode/iotdb/mlnode/data_access/factory.py | 31 +++++---
mlnode/iotdb/mlnode/exception.py | 8 +-
mlnode/iotdb/mlnode/parser.py | 19 +++--
mlnode/iotdb/mlnode/storage.py | 5 +-
mlnode/test/test_create_forecast_dataset.py | 89 ++++++++++++++++++++++
mlnode/test/test_create_forecast_model.py | 77 +++++++++++++++++++
mlnode/test/test_model_storage.py | 28 +++++--
mlnode/test/test_parse_training_request.py | 16 ++--
15 files changed, 281 insertions(+), 57 deletions(-)
diff --cc mlnode/iotdb/mlnode/algorithm/enums.py
index 2def3751cd,cf57a20083..0f93cf056b
--- a/mlnode/iotdb/mlnode/algorithm/enums.py
+++ b/mlnode/iotdb/mlnode/algorithm/enums.py
@@@ -25,8 -33,8 +33,11 @@@ class ForecastTaskType(Enum)
def __str__(self):
return self.value
+ def __hash__(self):
+ return hash(self.value)
+
def __eq__(self, other: str) -> bool:
return self.value == other
+
+ def __hash__(self) -> int:
+ return hash(self.value)
diff --cc mlnode/iotdb/mlnode/algorithm/factory.py
index 26eab10860,e4c5deefe9..37f81c1e68
--- a/mlnode/iotdb/mlnode/algorithm/factory.py
+++ b/mlnode/iotdb/mlnode/algorithm/factory.py
@@@ -16,10 -16,9 +16,10 @@@
# under the License.
#
import torch.nn as nn
-
+ from iotdb.mlnode.algorithm.models.forecast import *
from iotdb.mlnode.algorithm.enums import ForecastTaskType
from iotdb.mlnode.algorithm.models.forecast import support_forecasting_models
+from iotdb.mlnode.algorithm.models.forecast.dlinear import dlinear
from iotdb.mlnode.exception import BadConfigValueError
diff --cc mlnode/iotdb/mlnode/storage.py
index 78a0be43bf,a04a30441d..68392be53b
--- a/mlnode/iotdb/mlnode/storage.py
+++ b/mlnode/iotdb/mlnode/storage.py
@@@ -30,11 -30,14 +30,14 @@@ from iotdb.mlnode.exception import Mode
class ModelStorage(object):
def __init__(self):
- self.__model_dir = os.path.join(os.getcwd(), config.get_mn_model_storage_dir())
+ self.__model_dir = os.path.join('.', descriptor.get_config().get_mn_model_storage_dir())
if not os.path.exists(self.__model_dir):
- os.mkdir(self.__model_dir)
+ try:
+ os.mkdir(self.__model_dir)
+ except PermissionError as e: # TODO: handle storage permission
+ raise e
- self.__model_cache = lrucache(config.get_mn_model_storage_cache_size())
+ self.__model_cache = lrucache(descriptor.get_config().get_mn_model_storage_cache_size())
def save_model(self,
model: nn.Module,