This is an automated email from the ASF dual-hosted git repository. hui pushed a commit to branch mlnode/test in repository https://gitbox.apache.org/repos/asf/iotdb.git
commit f36a02b0debe1962863491a00cf16e306217ffab Merge: 833d0619ed b08b0c40c5 Author: Minghui Liu <[email protected]> AuthorDate: Mon Apr 3 21:35:43 2023 +0800 Merge remote-tracking branch 'liuyong/mlnode/test' into mlnode/test # Conflicts: # mlnode/iotdb/mlnode/data_access/enums.py # mlnode/iotdb/mlnode/handler.py # mlnode/iotdb/mlnode/parser.py mlnode/iotdb/mlnode/algorithm/enums.py | 11 +++ mlnode/iotdb/mlnode/algorithm/factory.py | 23 +++--- .../mlnode/algorithm/models/forecast/__init__.py | 2 + .../mlnode/algorithm/models/forecast/dlinear.py | 2 +- .../mlnode/algorithm/models/forecast/nbeats.py | 9 ++- mlnode/iotdb/mlnode/client.py | 2 +- mlnode/iotdb/mlnode/data_access/enums.py | 16 +++- mlnode/iotdb/mlnode/data_access/factory.py | 31 +++++--- mlnode/iotdb/mlnode/exception.py | 8 +- mlnode/iotdb/mlnode/parser.py | 19 +++-- mlnode/iotdb/mlnode/storage.py | 5 +- mlnode/test/test_create_forecast_dataset.py | 89 ++++++++++++++++++++++ mlnode/test/test_create_forecast_model.py | 77 +++++++++++++++++++ mlnode/test/test_model_storage.py | 28 +++++-- mlnode/test/test_parse_training_request.py | 16 ++-- 15 files changed, 281 insertions(+), 57 deletions(-) diff --cc mlnode/iotdb/mlnode/algorithm/enums.py index 2def3751cd,cf57a20083..0f93cf056b --- a/mlnode/iotdb/mlnode/algorithm/enums.py +++ b/mlnode/iotdb/mlnode/algorithm/enums.py @@@ -25,8 -33,8 +33,11 @@@ class ForecastTaskType(Enum) def __str__(self): return self.value + def __hash__(self): + return hash(self.value) + def __eq__(self, other: str) -> bool: return self.value == other + + def __hash__(self) -> int: + return hash(self.value) diff --cc mlnode/iotdb/mlnode/algorithm/factory.py index 26eab10860,e4c5deefe9..37f81c1e68 --- a/mlnode/iotdb/mlnode/algorithm/factory.py +++ b/mlnode/iotdb/mlnode/algorithm/factory.py @@@ -16,10 -16,9 +16,10 @@@ # under the License. # import torch.nn as nn - + from iotdb.mlnode.algorithm.models.forecast import * from iotdb.mlnode.algorithm.enums import ForecastTaskType from iotdb.mlnode.algorithm.models.forecast import support_forecasting_models +from iotdb.mlnode.algorithm.models.forecast.dlinear import dlinear from iotdb.mlnode.exception import BadConfigValueError diff --cc mlnode/iotdb/mlnode/storage.py index 78a0be43bf,a04a30441d..68392be53b --- a/mlnode/iotdb/mlnode/storage.py +++ b/mlnode/iotdb/mlnode/storage.py @@@ -30,11 -30,14 +30,14 @@@ from iotdb.mlnode.exception import Mode class ModelStorage(object): def __init__(self): - self.__model_dir = os.path.join(os.getcwd(), config.get_mn_model_storage_dir()) + self.__model_dir = os.path.join('.', descriptor.get_config().get_mn_model_storage_dir()) if not os.path.exists(self.__model_dir): - os.mkdir(self.__model_dir) + try: + os.mkdir(self.__model_dir) + except PermissionError as e: # TODO: handle storage permission + raise e - self.__model_cache = lrucache(config.get_mn_model_storage_cache_size()) + self.__model_cache = lrucache(descriptor.get_config().get_mn_model_storage_cache_size()) def save_model(self, model: nn.Module,
