You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@iotdb.apache.org by hu...@apache.org on 2023/04/03 13:37:03 UTC

[iotdb] 01/01: Merge remote-tracking branch 'liuyong/mlnode/test' into mlnode/test

This is an automated email from the ASF dual-hosted git repository.

hui pushed a commit to branch mlnode/test
in repository https://gitbox.apache.org/repos/asf/iotdb.git

commit f36a02b0debe1962863491a00cf16e306217ffab
Merge: 833d0619ed b08b0c40c5
Author: Minghui Liu <li...@foxmail.com>
AuthorDate: Mon Apr 3 21:35:43 2023 +0800

    Merge remote-tracking branch 'liuyong/mlnode/test' into mlnode/test
    
    # Conflicts:
    #       mlnode/iotdb/mlnode/data_access/enums.py
    #       mlnode/iotdb/mlnode/handler.py
    #       mlnode/iotdb/mlnode/parser.py

 mlnode/iotdb/mlnode/algorithm/enums.py             | 11 +++
 mlnode/iotdb/mlnode/algorithm/factory.py           | 23 +++---
 .../mlnode/algorithm/models/forecast/__init__.py   |  2 +
 .../mlnode/algorithm/models/forecast/dlinear.py    |  2 +-
 .../mlnode/algorithm/models/forecast/nbeats.py     |  9 ++-
 mlnode/iotdb/mlnode/client.py                      |  2 +-
 mlnode/iotdb/mlnode/data_access/enums.py           | 16 +++-
 mlnode/iotdb/mlnode/data_access/factory.py         | 31 +++++---
 mlnode/iotdb/mlnode/exception.py                   |  8 +-
 mlnode/iotdb/mlnode/parser.py                      | 19 +++--
 mlnode/iotdb/mlnode/storage.py                     |  5 +-
 mlnode/test/test_create_forecast_dataset.py        | 89 ++++++++++++++++++++++
 mlnode/test/test_create_forecast_model.py          | 77 +++++++++++++++++++
 mlnode/test/test_model_storage.py                  | 28 +++++--
 mlnode/test/test_parse_training_request.py         | 16 ++--
 15 files changed, 281 insertions(+), 57 deletions(-)

diff --cc mlnode/iotdb/mlnode/algorithm/enums.py
index 2def3751cd,cf57a20083..0f93cf056b
--- a/mlnode/iotdb/mlnode/algorithm/enums.py
+++ b/mlnode/iotdb/mlnode/algorithm/enums.py
@@@ -25,8 -33,8 +33,11 @@@ class ForecastTaskType(Enum)
      def __str__(self):
          return self.value
  
+     def __hash__(self):
+         return hash(self.value)
+ 
      def __eq__(self, other: str) -> bool:
          return self.value == other
 +
 +    def __hash__(self) -> int:
 +        return hash(self.value)
diff --cc mlnode/iotdb/mlnode/algorithm/factory.py
index 26eab10860,e4c5deefe9..37f81c1e68
--- a/mlnode/iotdb/mlnode/algorithm/factory.py
+++ b/mlnode/iotdb/mlnode/algorithm/factory.py
@@@ -16,10 -16,9 +16,10 @@@
  # under the License.
  #
  import torch.nn as nn
- 
+ from iotdb.mlnode.algorithm.models.forecast import *
  from iotdb.mlnode.algorithm.enums import ForecastTaskType
  from iotdb.mlnode.algorithm.models.forecast import support_forecasting_models
 +from iotdb.mlnode.algorithm.models.forecast.dlinear import dlinear
  from iotdb.mlnode.exception import BadConfigValueError
  
  
diff --cc mlnode/iotdb/mlnode/storage.py
index 78a0be43bf,a04a30441d..68392be53b
--- a/mlnode/iotdb/mlnode/storage.py
+++ b/mlnode/iotdb/mlnode/storage.py
@@@ -30,11 -30,14 +30,14 @@@ from iotdb.mlnode.exception import Mode
  
  class ModelStorage(object):
      def __init__(self):
 -        self.__model_dir = os.path.join(os.getcwd(), config.get_mn_model_storage_dir())
 +        self.__model_dir = os.path.join('.', descriptor.get_config().get_mn_model_storage_dir())
          if not os.path.exists(self.__model_dir):
-             os.mkdir(self.__model_dir)
+             try:
+                 os.mkdir(self.__model_dir)
+             except PermissionError as e: # TODO: handle storage permission
+                 raise e
  
 -        self.__model_cache = lrucache(config.get_mn_model_storage_cache_size())
 +        self.__model_cache = lrucache(descriptor.get_config().get_mn_model_storage_cache_size())
  
      def save_model(self,
                     model: nn.Module,