Copilot commented on code in PR #16794: URL: https://github.com/apache/iotdb/pull/16794#discussion_r2558408785
########## iotdb-core/ainode/iotdb/ainode/core/model/sktime/configuration_sktime.py: ########## @@ -0,0 +1,379 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +""" +Sktime model configuration module - simplified version +""" + +from dataclasses import dataclass, field +from typing import Any, Dict, List, Union + +from iotdb.ainode.core.exception import ( + BuiltInModelNotSupportError, + ListRangeException, + NumericalRangeException, + StringRangeException, + WrongAttributeTypeError, +) +from iotdb.ainode.core.log import Logger + +logger = Logger() + + +@dataclass +class AttributeConfig: + """Base class for attribute configuration""" + + name: str + default: Any + type: str # 'int', 'float', 'str', 'bool', 'list', 'tuple' + low: Union[int, float, None] = None + high: Union[int, float, None] = None + choices: List[str] = field(default_factory=list) + value_type: type = None # Element type for list and tuple + + def validate_value(self, value): + """Validate if the value meets the requirements""" + if self.type == "int": + if not isinstance(value, int): + raise WrongAttributeTypeError(self.name, "int") + if self.low is not None and self.high is not None: + if not (self.low <= value <= self.high): + raise 
NumericalRangeException(self.name, value, self.low, self.high) + elif self.type == "float": + if not isinstance(value, (int, float)): + raise WrongAttributeTypeError(self.name, "float") + value = float(value) + if self.low is not None and self.high is not None: + if not (self.low <= value <= self.high): + raise NumericalRangeException(self.name, value, self.low, self.high) + elif self.type == "str": + if not isinstance(value, str): + raise WrongAttributeTypeError(self.name, "str") + if self.choices and value not in self.choices: + raise StringRangeException(self.name, value, self.choices) + elif self.type == "bool": + if not isinstance(value, bool): + raise WrongAttributeTypeError(self.name, "bool") + elif self.type == "list": + if not isinstance(value, list): + raise WrongAttributeTypeError(self.name, "list") + for item in value: + if not isinstance(item, self.value_type): + raise WrongAttributeTypeError(self.name, self.value_type) + elif self.type == "tuple": + if not isinstance(value, tuple): + raise WrongAttributeTypeError(self.name, "tuple") + for item in value: + if not isinstance(item, self.value_type): + raise WrongAttributeTypeError(self.name, self.value_type) + return True + + def parse(self, string_value: str): + """Parse string value to corresponding type""" + if self.type == "int": + try: + return int(string_value) + except: + raise WrongAttributeTypeError(self.name, "int") + elif self.type == "float": + try: + return float(string_value) + except: + raise WrongAttributeTypeError(self.name, "float") + elif self.type == "str": + return string_value + elif self.type == "bool": + if string_value.lower() == "true": + return True + elif string_value.lower() == "false": + return False + else: + raise WrongAttributeTypeError(self.name, "bool") + elif self.type == "list": + try: + list_value = eval(string_value) + except: + raise WrongAttributeTypeError(self.name, "list") + if not isinstance(list_value, list): + raise WrongAttributeTypeError(self.name, "list") 
+ for i in range(len(list_value)): + try: + list_value[i] = self.value_type(list_value[i]) + except: + raise ListRangeException( + self.name, list_value, str(self.value_type) + ) + return list_value + elif self.type == "tuple": + try: + tuple_value = eval(string_value) + except: + raise WrongAttributeTypeError(self.name, "tuple") + if not isinstance(tuple_value, tuple): + raise WrongAttributeTypeError(self.name, "tuple") + list_value = list(tuple_value) + for i in range(len(list_value)): + try: + list_value[i] = self.value_type(list_value[i]) + except: + raise ListRangeException( + self.name, list_value, str(self.value_type) + ) + return tuple(list_value) + + +# Model configuration definitions - using concise dictionary format +MODEL_CONFIGS = { + "NAIVE_FORECASTER": { + "predict_length": AttributeConfig("predict_length", 1, "int", 1, 5000), + "pipeline": AttributeConfig( + "pipeline", "last", "str", choices=["last", "mean"] + ), + "sp": AttributeConfig("sp", 1, "int", 1, 5000), + }, + "EXPONENTIAL_SMOOTHING": { + "predict_length": AttributeConfig("predict_length", 1, "int", 1, 5000), + "damped_trend": AttributeConfig("damped_trend", False, "bool"), + "initialization_method": AttributeConfig( + "initialization_method", + "estimated", + "str", + choices=["estimated", "heuristic", "legacy-heuristic", "known"], + ), + "optimized": AttributeConfig("optimized", True, "bool"), + "remove_bias": AttributeConfig("remove_bias", False, "bool"), + "use_brute": AttributeConfig("use_brute", False, "bool"), + }, + "ARIMA": { + "predict_length": AttributeConfig("predict_length", 1, "int", 1, 5000), + "order": AttributeConfig("order", (1, 0, 0), "tuple", value_type=int), + "seasonal_order": AttributeConfig( + "seasonal_order", (0, 0, 0, 0), "tuple", value_type=int + ), + "method": AttributeConfig( + "method", + "lbfgs", + "str", + choices=["lbfgs", "bfgs", "newton", "nm", "cg", "ncg", "powell"], + ), + "maxiter": AttributeConfig("maxiter", 1, "int", 1, 5000), + "suppress_warnings": 
AttributeConfig("suppress_warnings", True, "bool"), + "out_of_sample_size": AttributeConfig("out_of_sample_size", 0, "int", 0, 5000), + "scoring": AttributeConfig( + "scoring", + "mse", + "str", + choices=["mse", "mae", "rmse", "mape", "smape", "rmsle", "r2"], + ), + "with_intercept": AttributeConfig("with_intercept", True, "bool"), + "time_varying_regression": AttributeConfig( + "time_varying_regression", False, "bool" + ), + "enforce_stationarity": AttributeConfig("enforce_stationarity", True, "bool"), + "enforce_invertibility": AttributeConfig("enforce_invertibility", True, "bool"), + "simple_differencing": AttributeConfig("simple_differencing", False, "bool"), + "measurement_error": AttributeConfig("measurement_error", False, "bool"), + "mle_regression": AttributeConfig("mle_regression", True, "bool"), + "hamilton_representation": AttributeConfig( + "hamilton_representation", False, "bool" + ), + "concentrate_scale": AttributeConfig("concentrate_scale", False, "bool"), + }, + "STL_FORECASTER": { + "predict_length": AttributeConfig("predict_length", 1, "int", 1, 5000), + "sp": AttributeConfig("sp", 2, "int", 1, 5000), + "seasonal": AttributeConfig("seasonal", 7, "int", 1, 5000), + "seasonal_deg": AttributeConfig("seasonal_deg", 1, "int", 0, 5000), + "trend_deg": AttributeConfig("trend_deg", 1, "int", 0, 5000), + "low_pass_deg": AttributeConfig("low_pass_deg", 1, "int", 0, 5000), + "seasonal_jump": AttributeConfig("seasonal_jump", 1, "int", 0, 5000), + "trend_jump": AttributeConfig("trend_jump", 1, "int", 0, 5000), + "low_pass_jump": AttributeConfig("low_pass_jump", 1, "int", 0, 5000), + }, + "GAUSSIAN_HMM": { + "n_components": AttributeConfig("n_components", 1, "int", 1, 5000), + "covariance_type": AttributeConfig( + "covariance_type", + "diag", + "str", + choices=["spherical", "diag", "full", "tied"], + ), + "min_covar": AttributeConfig("min_covar", 1e-3, "float", -1e10, 1e10), + "startprob_prior": AttributeConfig( + "startprob_prior", 1.0, "float", -1e10, 1e10 
+ ), + "transmat_prior": AttributeConfig("transmat_prior", 1.0, "float", -1e10, 1e10), + "means_prior": AttributeConfig("means_prior", 0.0, "float", -1e10, 1e10), + "means_weight": AttributeConfig("means_weight", 0.0, "float", -1e10, 1e10), + "covars_prior": AttributeConfig("covars_prior", 1e-2, "float", -1e10, 1e10), + "covars_weight": AttributeConfig("covars_weight", 1.0, "float", -1e10, 1e10), + "algorithm": AttributeConfig( + "algorithm", "viterbi", "str", choices=["viterbi", "map"] + ), + "n_iter": AttributeConfig("n_iter", 10, "int", 1, 5000), + "tol": AttributeConfig("tol", 1e-2, "float", -1e10, 1e10), + "params": AttributeConfig("params", "stmc", "str", choices=["stmc", "stm"]), + "init_params": AttributeConfig( + "init_params", "stmc", "str", choices=["stmc", "stm"] + ), + "implementation": AttributeConfig( + "implementation", "log", "str", choices=["log", "scaling"] + ), + }, + "GMM_HMM": { + "n_components": AttributeConfig("n_components", 1, "int", 1, 5000), + "n_mix": AttributeConfig("n_mix", 1, "int", 1, 5000), + "min_covar": AttributeConfig("min_covar", 1e-3, "float", -1e10, 1e10), + "startprob_prior": AttributeConfig( + "startprob_prior", 1.0, "float", -1e10, 1e10 + ), + "transmat_prior": AttributeConfig("transmat_prior", 1.0, "float", -1e10, 1e10), + "weights_prior": AttributeConfig("weights_prior", 1.0, "float", -1e10, 1e10), + "means_prior": AttributeConfig("means_prior", 0.0, "float", -1e10, 1e10), + "means_weight": AttributeConfig("means_weight", 0.0, "float", -1e10, 1e10), + "algorithm": AttributeConfig( + "algorithm", "viterbi", "str", choices=["viterbi", "map"] + ), + "covariance_type": AttributeConfig( + "covariance_type", + "diag", + "str", + choices=["sperical", "diag", "full", "tied"], Review Comment: Typo in choice value "sperical" should be "spherical". This will cause validation failures when users try to use the correct spelling. 
```suggestion choices=["spherical", "diag", "full", "tied"], ``` ########## iotdb-core/ainode/iotdb/ainode/core/model/model_storage.py: ########## @@ -17,440 +17,560 @@ # import concurrent.futures -import json import os import shutil -from collections.abc import Callable -from typing import Dict +from typing import List, Optional -import torch -from torch import nn +from huggingface_hub import hf_hub_download, snapshot_download +from transformers import AutoConfig, AutoModelForCausalLM -from iotdb.ainode.core.config import AINodeDescriptor -from iotdb.ainode.core.constant import ( - MODEL_CONFIG_FILE_IN_JSON, - MODEL_WEIGHTS_FILE_IN_PT, - TSStatusCode, -) -from iotdb.ainode.core.exception import ( - BuiltInModelDeletionError, - ModelNotExistError, - UnsupportedError, -) +from iotdb.ainode.core.constant import TSStatusCode +from iotdb.ainode.core.exception import BuiltInModelDeletionError from iotdb.ainode.core.log import Logger -from iotdb.ainode.core.model.built_in_model_factory import ( - download_built_in_ltsm_from_hf_if_necessary, - fetch_built_in_model, -) -from iotdb.ainode.core.model.model_enums import ( - BuiltInModelType, - ModelCategory, - ModelFileType, - ModelStates, -) -from iotdb.ainode.core.model.model_factory import fetch_model_by_uri +from iotdb.ainode.core.model.model_enums import REPO_ID_MAP, ModelCategory, ModelStates from iotdb.ainode.core.model.model_info import ( - BUILT_IN_LTSM_MAP, - BUILT_IN_MACHINE_LEARNING_MODEL_MAP, + BUILTIN_HF_TRANSFORMERS_MODEL_MAP, + BUILTIN_SKTIME_MODEL_MAP, ModelInfo, - get_built_in_model_type, - get_model_file_type, ) -from iotdb.ainode.core.model.uri_utils import get_model_register_strategy +from iotdb.ainode.core.model.utils import * from iotdb.ainode.core.util.lock import ModelLockPool from iotdb.thrift.ainode.ttypes import TShowModelsReq, TShowModelsResp from iotdb.thrift.common.ttypes import TSStatus logger = Logger() -class ModelStorage(object): - def __init__(self): - self._model_dir = os.path.join( - 
os.getcwd(), AINodeDescriptor().get_config().get_ain_models_dir() - ) - if not os.path.exists(self._model_dir): - try: - os.makedirs(self._model_dir) - except PermissionError as e: - logger.error(e) - raise e - self._builtin_model_dir = os.path.join( - os.getcwd(), AINodeDescriptor().get_config().get_ain_builtin_models_dir() - ) - if not os.path.exists(self._builtin_model_dir): - try: - os.makedirs(self._builtin_model_dir) - except PermissionError as e: - logger.error(e) - raise e +class ModelStorage: + """Model storage class - unified management of model discovery and registration""" + + def __init__(self, models_dir: str): + self.models_dir = Path(models_dir) + # Unified storage: category -> {model_id -> ModelInfo} + self._models: Dict[str, Dict[str, ModelInfo]] = { + ModelCategory.BUILTIN.value: {}, + ModelCategory.USER_DEFINED.value: {}, + ModelCategory.FINETUNE.value: {}, + } + # Async download executor (using single-threaded executor because hf download interface is unstable with concurrent downloads) + self._executor = concurrent.futures.ThreadPoolExecutor(max_workers=1) + # Thread lock pool for protecting concurrent access to model information self._lock_pool = ModelLockPool() - self._executor = concurrent.futures.ThreadPoolExecutor( - max_workers=1 - ) # TODO: Here we set the work_num=1 cause we found that the hf download interface is not stable for concurrent downloading. 
- self._model_info_map: Dict[str, ModelInfo] = {} - self._init_model_info_map() + self._initialize_directories() + + def _initialize_directories(self): + """Initialize directory structure and ensure __init__.py files exist""" + self.models_dir.mkdir(parents=True, exist_ok=True) + ensure_init_file(self.models_dir) + + for category in ModelCategory: + category_path = self.models_dir / category.value + category_path.mkdir(parents=True, exist_ok=True) + ensure_init_file(category_path) + + # ==================== Discovery Methods ==================== + + def discover_all(self) -> Dict[str, Dict[str, ModelInfo]]: + """Scan file system to discover all models""" + self._discover_category(ModelCategory.BUILTIN) + self._discover_category(ModelCategory.USER_DEFINED) + self._discover_category(ModelCategory.FINETUNE) + return self._models + + def _discover_category(self, category: ModelCategory): + """Discover all models in a category directory""" + category_path = self.models_dir / category.value + if not category_path.exists(): + return + + if category == ModelCategory.BUILTIN: + self._discover_builtin_models(category_path) + else: + # For finetune and user_defined, scan directories + for item in category_path.iterdir(): + if item.is_dir() and not item.name.startswith("__"): + relative_path = item.relative_to(category_path) + model_id = str(relative_path).replace("/", "_").replace("\\", "_") + self._process_model_directory(item, model_id, category) + + def _discover_builtin_models(self, category_path: Path): + # Register SKTIME models directly from map + for model_id in BUILTIN_SKTIME_MODEL_MAP.keys(): + with self._lock_pool.get_lock(model_id).write_lock(): + self._models[ModelCategory.BUILTIN.value][model_id] = ( + BUILTIN_SKTIME_MODEL_MAP[model_id] + ) + + # Process HuggingFace Transformers models + for model_id in BUILTIN_HF_TRANSFORMERS_MODEL_MAP.keys(): + model_dir = category_path / model_id + model_dir.mkdir(parents=True, exist_ok=True) + 
self._process_model_directory(model_dir, model_id, ModelCategory.BUILTIN) + + def _process_model_directory( + self, model_dir: Path, model_id: str, category: ModelCategory + ): + """Handling the discovery logic for a single model directory.""" Review Comment: [nitpick] Missing period at end of comment. The comment should end with a period for consistency with documentation standards. ########## iotdb-core/ainode/iotdb/ainode/core/model/model_storage.py: ########## @@ -17,440 +17,560 @@ # import concurrent.futures -import json import os import shutil -from collections.abc import Callable -from typing import Dict +from typing import List, Optional -import torch -from torch import nn +from huggingface_hub import hf_hub_download, snapshot_download +from transformers import AutoConfig, AutoModelForCausalLM -from iotdb.ainode.core.config import AINodeDescriptor -from iotdb.ainode.core.constant import ( - MODEL_CONFIG_FILE_IN_JSON, - MODEL_WEIGHTS_FILE_IN_PT, - TSStatusCode, -) -from iotdb.ainode.core.exception import ( - BuiltInModelDeletionError, - ModelNotExistError, - UnsupportedError, -) +from iotdb.ainode.core.constant import TSStatusCode +from iotdb.ainode.core.exception import BuiltInModelDeletionError from iotdb.ainode.core.log import Logger -from iotdb.ainode.core.model.built_in_model_factory import ( - download_built_in_ltsm_from_hf_if_necessary, - fetch_built_in_model, -) -from iotdb.ainode.core.model.model_enums import ( - BuiltInModelType, - ModelCategory, - ModelFileType, - ModelStates, -) -from iotdb.ainode.core.model.model_factory import fetch_model_by_uri +from iotdb.ainode.core.model.model_enums import REPO_ID_MAP, ModelCategory, ModelStates from iotdb.ainode.core.model.model_info import ( - BUILT_IN_LTSM_MAP, - BUILT_IN_MACHINE_LEARNING_MODEL_MAP, + BUILTIN_HF_TRANSFORMERS_MODEL_MAP, + BUILTIN_SKTIME_MODEL_MAP, ModelInfo, - get_built_in_model_type, - get_model_file_type, ) -from iotdb.ainode.core.model.uri_utils import get_model_register_strategy +from 
iotdb.ainode.core.model.utils import * Review Comment: Missing import statement. The code uses `json` module on line 206 but doesn't import it at the top of the file. Add `import json` to the imports section. ########## iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/sundial_pipeline.py: ########## @@ -0,0 +1,48 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +import pandas as pd Review Comment: Import of 'pd' is not used. ```suggestion ``` ########## iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/timerxl_pipeline.py: ########## @@ -0,0 +1,44 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +import pandas as pd Review Comment: Import of 'pd' is not used. ```suggestion ``` ########## iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/timerxl_pipeline.py: ########## @@ -0,0 +1,44 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +import pandas as pd +import torch + +from iotdb.ainode.core.inference.pipeline.basic_pipeline import ForecastPipeline +from iotdb.ainode.core.util.serde import convert_to_binary Review Comment: Import of 'convert_to_binary' is not used. 
```suggestion ``` ########## iotdb-core/ainode/iotdb/ainode/core/manager/utils.py: ########## @@ -80,8 +79,8 @@ def evaluate_system_resources(device: torch.device) -> dict: def estimate_pool_size(device: torch.device, model_id: str) -> int: - model_info = BUILT_IN_LTSM_MAP.get(model_id, None) - if model_info is None or model_info.model_type not in MODEL_MEM_USAGE_MAP: + model_info = get_model_manager.get_model_info(model_id) Review Comment: Missing call parentheses. The function call `get_model_manager.get_model_info(model_id)` is missing parentheses after `get_model_manager`. It should be `get_model_manager().get_model_info(model_id)`. ```suggestion model_info = get_model_manager().get_model_info(model_id) ``` ########## iotdb-core/ainode/iotdb/ainode/core/model/model_info.py: ########## @@ -15,53 +15,21 @@ # specific language governing permissions and limitations # under the License. # -import glob -import os -from iotdb.ainode.core.constant import ( - MODEL_CONFIG_FILE_IN_JSON, - MODEL_CONFIG_FILE_IN_YAML, - MODEL_WEIGHTS_FILE_IN_PT, - MODEL_WEIGHTS_FILE_IN_SAFETENSORS, -) -from iotdb.ainode.core.model.model_enums import ( - BuiltInModelType, - ModelCategory, - ModelFileType, - ModelStates, -) +from typing import Dict, List, Optional, Tuple Review Comment: Import of 'List' is not used. Import of 'Tuple' is not used.
```suggestion from typing import Dict, Optional ``` ########## iotdb-core/ainode/iotdb/ainode/core/model/model_storage.py: ########## @@ -17,440 +17,560 @@ # import concurrent.futures -import json import os import shutil -from collections.abc import Callable -from typing import Dict +from typing import List, Optional -import torch -from torch import nn +from huggingface_hub import hf_hub_download, snapshot_download +from transformers import AutoConfig, AutoModelForCausalLM -from iotdb.ainode.core.config import AINodeDescriptor -from iotdb.ainode.core.constant import ( - MODEL_CONFIG_FILE_IN_JSON, - MODEL_WEIGHTS_FILE_IN_PT, - TSStatusCode, -) -from iotdb.ainode.core.exception import ( - BuiltInModelDeletionError, - ModelNotExistError, - UnsupportedError, -) +from iotdb.ainode.core.constant import TSStatusCode +from iotdb.ainode.core.exception import BuiltInModelDeletionError from iotdb.ainode.core.log import Logger -from iotdb.ainode.core.model.built_in_model_factory import ( - download_built_in_ltsm_from_hf_if_necessary, - fetch_built_in_model, -) -from iotdb.ainode.core.model.model_enums import ( - BuiltInModelType, - ModelCategory, - ModelFileType, - ModelStates, -) -from iotdb.ainode.core.model.model_factory import fetch_model_by_uri +from iotdb.ainode.core.model.model_enums import REPO_ID_MAP, ModelCategory, ModelStates from iotdb.ainode.core.model.model_info import ( - BUILT_IN_LTSM_MAP, - BUILT_IN_MACHINE_LEARNING_MODEL_MAP, + BUILTIN_HF_TRANSFORMERS_MODEL_MAP, + BUILTIN_SKTIME_MODEL_MAP, ModelInfo, - get_built_in_model_type, - get_model_file_type, ) -from iotdb.ainode.core.model.uri_utils import get_model_register_strategy +from iotdb.ainode.core.model.utils import * Review Comment: Import pollutes the enclosing namespace, as the imported module [iotdb.ainode.core.model.utils](1) does not define '__all__'. 
```suggestion from iotdb.ainode.core.model.utils import ensure_init_file ``` ########## iotdb-core/ainode/iotdb/ainode/core/model/sktime/configuration_sktime.py: ########## @@ -0,0 +1,379 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +""" +Sktime model configuration module - simplified version +""" + +from dataclasses import dataclass, field +from typing import Any, Dict, List, Union + +from iotdb.ainode.core.exception import ( + BuiltInModelNotSupportError, + ListRangeException, + NumericalRangeException, + StringRangeException, + WrongAttributeTypeError, +) +from iotdb.ainode.core.log import Logger + +logger = Logger() + + +@dataclass +class AttributeConfig: + """Base class for attribute configuration""" + + name: str + default: Any + type: str # 'int', 'float', 'str', 'bool', 'list', 'tuple' + low: Union[int, float, None] = None + high: Union[int, float, None] = None + choices: List[str] = field(default_factory=list) + value_type: type = None # Element type for list and tuple + + def validate_value(self, value): + """Validate if the value meets the requirements""" + if self.type == "int": + if not isinstance(value, int): + raise WrongAttributeTypeError(self.name, "int") + if self.low is not None and self.high is not 
None: + if not (self.low <= value <= self.high): + raise NumericalRangeException(self.name, value, self.low, self.high) + elif self.type == "float": + if not isinstance(value, (int, float)): + raise WrongAttributeTypeError(self.name, "float") + value = float(value) + if self.low is not None and self.high is not None: + if not (self.low <= value <= self.high): + raise NumericalRangeException(self.name, value, self.low, self.high) + elif self.type == "str": + if not isinstance(value, str): + raise WrongAttributeTypeError(self.name, "str") + if self.choices and value not in self.choices: + raise StringRangeException(self.name, value, self.choices) + elif self.type == "bool": + if not isinstance(value, bool): + raise WrongAttributeTypeError(self.name, "bool") + elif self.type == "list": + if not isinstance(value, list): + raise WrongAttributeTypeError(self.name, "list") + for item in value: + if not isinstance(item, self.value_type): + raise WrongAttributeTypeError(self.name, self.value_type) + elif self.type == "tuple": + if not isinstance(value, tuple): + raise WrongAttributeTypeError(self.name, "tuple") + for item in value: + if not isinstance(item, self.value_type): + raise WrongAttributeTypeError(self.name, self.value_type) + return True + + def parse(self, string_value: str): + """Parse string value to corresponding type""" + if self.type == "int": + try: + return int(string_value) + except: + raise WrongAttributeTypeError(self.name, "int") + elif self.type == "float": + try: + return float(string_value) + except: + raise WrongAttributeTypeError(self.name, "float") + elif self.type == "str": + return string_value + elif self.type == "bool": + if string_value.lower() == "true": + return True + elif string_value.lower() == "false": + return False + else: + raise WrongAttributeTypeError(self.name, "bool") + elif self.type == "list": + try: + list_value = eval(string_value) + except: + raise WrongAttributeTypeError(self.name, "list") + if not isinstance(list_value, 
list): + raise WrongAttributeTypeError(self.name, "list") + for i in range(len(list_value)): + try: + list_value[i] = self.value_type(list_value[i]) + except: + raise ListRangeException( + self.name, list_value, str(self.value_type) + ) + return list_value + elif self.type == "tuple": + try: + tuple_value = eval(string_value) + except: + raise WrongAttributeTypeError(self.name, "tuple") + if not isinstance(tuple_value, tuple): + raise WrongAttributeTypeError(self.name, "tuple") + list_value = list(tuple_value) + for i in range(len(list_value)): + try: + list_value[i] = self.value_type(list_value[i]) + except: + raise ListRangeException( + self.name, list_value, str(self.value_type) + ) + return tuple(list_value) + + +# Model configuration definitions - using concise dictionary format +MODEL_CONFIGS = { + "NAIVE_FORECASTER": { + "predict_length": AttributeConfig("predict_length", 1, "int", 1, 5000), + "pipeline": AttributeConfig( + "pipeline", "last", "str", choices=["last", "mean"] Review Comment: The attribute name "pipeline" should likely be "strategy" to match the NaiveForecaster parameter. According to sktime documentation, the NaiveForecaster parameter for choosing between "last" and "mean" is called "strategy", not "pipeline". ```suggestion "strategy": AttributeConfig( "strategy", "last", "str", choices=["last", "mean"] ``` ########## iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/sundial_pipeline.py: ########## @@ -0,0 +1,48 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +import pandas as pd +import torch + +from iotdb.ainode.core.inference.pipeline.basic_pipeline import ForecastPipeline +from iotdb.ainode.core.util.serde import convert_to_binary Review Comment: Import of 'convert_to_binary' is not used. ```suggestion ``` ########## iotdb-core/ainode/iotdb/ainode/core/manager/utils.py: ########## @@ -47,7 +46,7 @@ def measure_model_memory(device: torch.device, model_id: str) -> int: torch.cuda.synchronize(device) start = torch.cuda.memory_reserved(device) - model = ModelManager().load_model(model_id, {}).to(device) + model = get_model_manager().load_model(model_id, {}).to(device) Review Comment: Call to [method ModelManager.load_model](1) with too many arguments; should be no more than 1. ```suggestion model = get_model_manager().load_model(model_id).to(device) ``` -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
