This is an automated email from the ASF dual-hosted git repository. yongzao pushed a commit to branch more-accurate-exceptions in repository https://gitbox.apache.org/repos/asf/iotdb.git
commit 7d3b8524a5ca63e40b9722dc17b7c30e1a254c27 Author: Yongzao <[email protected]> AuthorDate: Thu Dec 11 13:25:18 2025 +0800 ready 4 CI --- .../iotdb/ainode/it/AINodeClusterConfigIT.java | 3 - .../java/org/apache/iotdb/rpc/TSStatusCode.java | 21 +-- iotdb-core/ainode/iotdb/ainode/core/config.py | 8 +- iotdb-core/ainode/iotdb/ainode/core/constant.py | 33 ++--- iotdb-core/ainode/iotdb/ainode/core/exception.py | 65 +++++---- .../core/inference/dispatcher/basic_dispatcher.py | 6 +- .../iotdb/ainode/core/inference/pool_controller.py | 4 +- .../iotdb/ainode/core/inference/pool_group.py | 6 +- .../pool_scheduler/basic_pool_scheduler.py | 4 +- .../iotdb/ainode/core/manager/inference_manager.py | 4 +- .../iotdb/ainode/core/manager/model_manager.py | 43 +++--- .../ainode/iotdb/ainode/core/manager/utils.py | 4 +- .../ainode/iotdb/ainode/core/model/model_loader.py | 4 +- .../iotdb/ainode/core/model/model_storage.py | 158 +++++++-------------- .../core/model/sktime/configuration_sktime.py | 36 ++--- .../ainode/core/model/sktime/modeling_sktime.py | 12 +- .../ainode/core/model/sundial/pipeline_sundial.py | 4 +- .../ainode/core/model/timer_xl/pipeline_timer.py | 4 +- iotdb-core/ainode/iotdb/ainode/core/model/utils.py | 47 +++++- iotdb-core/ainode/iotdb/ainode/core/rpc/handler.py | 48 ++++--- iotdb-core/ainode/iotdb/ainode/core/util/serde.py | 6 +- .../db/exception/ainode/GetModelInfoException.java | 2 +- .../exception/ainode/ModelNotFoundException.java | 28 ---- .../iotdb/db/protocol/client/an/AINodeClient.java | 6 +- .../execution/config/TableConfigTaskVisitor.java | 4 +- .../execution/config/TreeConfigTaskVisitor.java | 4 +- .../config/executor/ClusterConfigTaskExecutor.java | 10 +- .../config/executor/IConfigTaskExecutor.java | 2 +- ...eateTrainingTask.java => CreateTuningTask.java} | 8 +- .../thrift-ainode/src/main/thrift/ainode.thrift | 17 +-- 30 files changed, 278 insertions(+), 323 deletions(-) diff --git a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeClusterConfigIT.java b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeClusterConfigIT.java index 2e62f618091..e148be6b20a 100644 --- a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeClusterConfigIT.java +++ b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeClusterConfigIT.java @@ -107,7 +107,4 @@ public class AINodeClusterConfigIT { } Assert.fail("The target AINode is not removed successfully after all retries."); } - - // TODO: We might need to add remove unknown test in the future, but current infrastructure is too - // hard to implement it. } diff --git a/iotdb-client/service-rpc/src/main/java/org/apache/iotdb/rpc/TSStatusCode.java b/iotdb-client/service-rpc/src/main/java/org/apache/iotdb/rpc/TSStatusCode.java index 7f94dd696ac..6bd0da063b0 100644 --- a/iotdb-client/service-rpc/src/main/java/org/apache/iotdb/rpc/TSStatusCode.java +++ b/iotdb-client/service-rpc/src/main/java/org/apache/iotdb/rpc/TSStatusCode.java @@ -244,16 +244,17 @@ public enum TSStatusCode { CQ_UPDATE_LAST_EXEC_TIME_ERROR(1403), // AI - CREATE_MODEL_ERROR(1500), - DROP_MODEL_ERROR(1501), - MODEL_EXIST_ERROR(1502), - GET_MODEL_INFO_ERROR(1503), - NO_REGISTERED_AI_NODE_ERROR(1504), - MODEL_NOT_FOUND_ERROR(1505), - REGISTER_AI_NODE_ERROR(1506), - UNAVAILABLE_AI_DEVICE_ERROR(1507), - AI_NODE_INTERNAL_ERROR(1510), - REMOVE_AI_NODE_ERROR(1511), + NO_REGISTERED_AI_NODE_ERROR(1500), + REGISTER_AI_NODE_ERROR(1501), + REMOVE_AI_NODE_ERROR(1502), + MODEL_EXISTED_ERROR(1503), + MODEL_NOT_EXIST_ERROR(1504), + CREATE_MODEL_ERROR(1505), + DROP_BUILTIN_MODEL_ERROR(1506), + DROP_MODEL_ERROR(1507), + UNAVAILABLE_AI_DEVICE_ERROR(1508), + + AINODE_INTERNAL_ERROR(1599), // In case somebody too lazy to add a new error code // Pipe Plugin CREATE_PIPE_PLUGIN_ERROR(1600), diff --git a/iotdb-core/ainode/iotdb/ainode/core/config.py b/iotdb-core/ainode/iotdb/ainode/core/config.py index e465df7e36d..b14efa3bedf 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/config.py +++ b/iotdb-core/ainode/iotdb/ainode/core/config.py @@ -46,7 +46,7 @@ from iotdb.ainode.core.constant import ( AINODE_THRIFT_COMPRESSION_ENABLED, AINODE_VERSION_INFO, ) -from iotdb.ainode.core.exception import BadNodeUrlError +from iotdb.ainode.core.exception import BadNodeUrlException from iotdb.ainode.core.log import Logger from iotdb.ainode.core.util.decorator import singleton from iotdb.thrift.common.ttypes import TEndPoint @@ -437,7 +437,7 @@ class AINodeDescriptor(object): file_configs["ain_cluster_ingress_time_zone"] ) - except BadNodeUrlError: + except BadNodeUrlException: logger.warning("Cannot load AINode conf file, use default configuration.") except Exception as e: @@ -489,7 +489,7 @@ def parse_endpoint_url(endpoint_url: str) -> TEndPoint: """ split = endpoint_url.split(":") if len(split) != 2: - raise BadNodeUrlError(endpoint_url) + raise BadNodeUrlException(endpoint_url) ip = split[0] try: @@ -497,4 +497,4 @@ def parse_endpoint_url(endpoint_url: str) -> TEndPoint: result = TEndPoint(ip, port) return result except ValueError: - raise BadNodeUrlError(endpoint_url) + raise BadNodeUrlException(endpoint_url) diff --git a/iotdb-core/ainode/iotdb/ainode/core/constant.py b/iotdb-core/ainode/iotdb/ainode/core/constant.py index d8f730c829c..44e76840f73 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/constant.py +++ b/iotdb-core/ainode/iotdb/ainode/core/constant.py @@ -81,33 +81,18 @@ DEFAULT_CHUNK_SIZE = 8192 class TSStatusCode(Enum): SUCCESS_STATUS = 200 REDIRECTION_RECOMMEND = 400 - MODEL_EXIST_ERROR = 1502 - MODEL_NOT_FOUND_ERROR = 1505 - UNAVAILABLE_AI_DEVICE_ERROR = 1507 - AINODE_INTERNAL_ERROR = 1510 + MODEL_EXISTED_ERROR = 1503 + MODEL_NOT_EXIST_ERROR = 1504 + CREATE_MODEL_ERROR = 1505 + DROP_BUILTIN_MODEL_ERROR = 1506 + DROP_MODEL_ERROR = 1507 + UNAVAILABLE_AI_DEVICE_ERROR = 1508 + INVALID_URI_ERROR = 1511 INVALID_INFERENCE_CONFIG = 1512 INFERENCE_INTERNAL_ERROR = 1520 - def get_status_code(self) -> int: - return self.value - + AINODE_INTERNAL_ERROR = 1599 # In case somebody too lazy to add a new error code -class HyperparameterName(Enum): - # Training hyperparameter - LEARNING_RATE = "learning_rate" - EPOCHS = "epochs" - BATCH_SIZE = "batch_size" - USE_GPU = "use_gpu" - NUM_WORKERS = "num_workers" - - # Structure hyperparameter - KERNEL_SIZE = "kernel_size" - INPUT_VARS = "input_vars" - BLOCK_TYPE = "block_type" - D_MODEL = "d_model" - INNER_LAYERS = "inner_layer" - OUTER_LAYERS = "outer_layer" - - def name(self): + def get_status_code(self) -> int: return self.value diff --git a/iotdb-core/ainode/iotdb/ainode/core/exception.py b/iotdb-core/ainode/iotdb/ainode/core/exception.py index 30b9d54dcc7..9a50b209495 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/exception.py +++ b/iotdb-core/ainode/iotdb/ainode/core/exception.py @@ -23,7 +23,7 @@ from iotdb.ainode.core.model.model_constants import ( ) -class _BaseError(Exception): +class _BaseException(Exception): """Base class for exceptions in this module.""" def __init__(self): @@ -33,85 +33,102 @@ class _BaseError(Exception): return self.message -class BadNodeUrlError(_BaseError): +class BadNodeUrlException(_BaseException): def __init__(self, node_url: str): self.message = "Bad node url: {}".format(node_url) -class ModelNotExistError(_BaseError): +# ==================== Model Management ==================== + + +class ModelExistedException(_BaseException): + def __init__(self, model_id: str): + self.message = "Model {} already exists".format(model_id) + + +class ModelNotExistException(_BaseException): + def __init__(self, model_id: str): + self.message = "Model {} is not exists".format(model_id) + + +class InvalidModelUriException(_BaseException): def __init__(self, msg: str): - self.message = "Model is not exists: {} ".format(msg) + self.message = ( + "Model registration failed because the specified uri is invalid: {}".format( + msg + ) + ) + + +class BuiltInModelDeletionException(_BaseException): + def __init__(self, model_id: str): + self.message = "Cannot delete built-in model: {}".format(model_id) -class BadConfigValueError(_BaseError): +class BadConfigValueException(_BaseException): def __init__(self, config_name: str, config_value, hint: str = ""): self.message = "Bad value [{0}] for config {1}. {2}".format( config_value, config_name, hint ) -class MissingConfigError(_BaseError): +class MissingConfigException(_BaseException): def __init__(self, config_name: str): self.message = "Missing config: {}".format(config_name) -class MissingOptionError(_BaseError): +class MissingOptionException(_BaseException): def __init__(self, config_name: str): self.message = "Missing task option: {}".format(config_name) -class RedundantOptionError(_BaseError): +class RedundantOptionException(_BaseException): def __init__(self, option_name: str): self.message = "Redundant task option: {}".format(option_name) -class WrongTypeConfigError(_BaseError): +class WrongTypeConfigException(_BaseException): def __init__(self, config_name: str, expected_type: str): self.message = "Wrong type for config: {0}, expected: {1}".format( config_name, expected_type ) -class UnsupportedError(_BaseError): +class UnsupportedException(_BaseException): def __init__(self, msg: str): self.message = "{0} is not supported in current version".format(msg) -class InvalidUriError(_BaseError): +class InvalidUriException(_BaseException): def __init__(self, uri: str): self.message = "Invalid uri: {}, there are no {} or {} under this uri.".format( uri, MODEL_WEIGHTS_FILE_IN_PT, MODEL_CONFIG_FILE_IN_YAML ) -class InvalidWindowArgumentError(_BaseError): +class InvalidWindowArgumentException(_BaseException): def __init__(self, window_interval, window_step, dataset_length): self.message = f"Invalid inference input: window_interval {window_interval}, window_step {window_step}, dataset_length {dataset_length}" -class InferenceModelInternalError(_BaseError): +class InferenceModelInternalException(_BaseException): def __init__(self, msg: str): self.message = "Inference model internal error: {0}".format(msg) -class BuiltInModelNotSupportError(_BaseError): +class BuiltInModelNotSupportException(_BaseException): def __init__(self, msg: str): self.message = "Built-in model not support: {0}".format(msg) -class BuiltInModelDeletionError(_BaseError): - def __init__(self, model_id: str): - self.message = "Cannot delete built-in model: {0}".format(model_id) - - -class WrongAttributeTypeError(_BaseError): +class WrongAttributeTypeException(_BaseException): def __init__(self, attribute_name: str, expected_type: str): self.message = "Wrong type for attribute: {0}, expected: {1}".format( attribute_name, expected_type ) -class NumericalRangeException(_BaseError): +class NumericalRangeException(_BaseException): def __init__(self, attribute_name: str, value, min_value, max_value): self.message = ( "Attribute {0} expect value between {1} and {2}, got {3} instead.".format( @@ -120,14 +137,14 @@ class NumericalRangeException(_BaseError): ) -class StringRangeException(_BaseError): +class StringRangeException(_BaseException): def __init__(self, attribute_name: str, value: str, expect_value): self.message = "Attribute {0} expect value in {1}, got {2} instead.".format( attribute_name, expect_value, value ) -class ListRangeException(_BaseError): +class ListRangeException(_BaseException): def __init__(self, attribute_name: str, value: list, expected_type: str): self.message = ( "Attribute {0} expect value type list[{1}], got {2} instead.".format( @@ -136,7 +153,7 @@ class ListRangeException(_BaseError): ) -class AttributeNotSupportError(_BaseError): +class AttributeNotSupportException(_BaseException): def __init__(self, model_name: str, attribute_name: str): self.message = "Attribute {0} is not supported in model {1}".format( attribute_name, model_name diff --git a/iotdb-core/ainode/iotdb/ainode/core/inference/dispatcher/basic_dispatcher.py b/iotdb-core/ainode/iotdb/ainode/core/inference/dispatcher/basic_dispatcher.py index d06ba55546f..d9a2d8663e5 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/inference/dispatcher/basic_dispatcher.py +++ b/iotdb-core/ainode/iotdb/ainode/core/inference/dispatcher/basic_dispatcher.py @@ -16,7 +16,7 @@ # under the License. # -from iotdb.ainode.core.exception import InferenceModelInternalError +from iotdb.ainode.core.exception import InferenceModelInternalException from iotdb.ainode.core.inference.dispatcher.abstract_dispatcher import ( AbstractDispatcher, ) @@ -41,7 +41,7 @@ class BasicDispatcher(AbstractDispatcher): """ model_id = req.model_id if not pool_ids: - raise InferenceModelInternalError( + raise InferenceModelInternalException( f"No available pools for model {model_id}" ) start_idx = hash(req.req_id) % len(pool_ids) @@ -51,7 +51,7 @@ class BasicDispatcher(AbstractDispatcher): state = self.pool_states[pool_id] if state == PoolState.RUNNING: return pool_id - raise InferenceModelInternalError( + raise InferenceModelInternalException( f"No RUNNING pools available for model {model_id}" ) diff --git a/iotdb-core/ainode/iotdb/ainode/core/inference/pool_controller.py b/iotdb-core/ainode/iotdb/ainode/core/inference/pool_controller.py index 8ffa89ffd67..c580a89916d 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/inference/pool_controller.py +++ b/iotdb-core/ainode/iotdb/ainode/core/inference/pool_controller.py @@ -24,7 +24,7 @@ from typing import Dict, Optional import torch.multiprocessing as mp -from iotdb.ainode.core.exception import InferenceModelInternalError +from iotdb.ainode.core.exception import InferenceModelInternalException from iotdb.ainode.core.inference.inference_request import ( InferenceRequest, InferenceRequestProxy, @@ -374,7 +374,7 @@ class PoolController: if not self.has_request_pools(model_id): logger.error(f"[Inference] No pools found for model {model_id}.") infer_proxy.set_result(None) - raise InferenceModelInternalError( + raise InferenceModelInternalException( "Dispatch request failed, because no inference pools are init." ) # TODO: Implement adaptive scaling based on requests.(e.g. lazy initialization) diff --git a/iotdb-core/ainode/iotdb/ainode/core/inference/pool_group.py b/iotdb-core/ainode/iotdb/ainode/core/inference/pool_group.py index a700dcee473..b85f64d42cc 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/inference/pool_group.py +++ b/iotdb-core/ainode/iotdb/ainode/core/inference/pool_group.py @@ -19,7 +19,7 @@ from typing import Dict, Tuple import torch.multiprocessing as mp -from iotdb.ainode.core.exception import InferenceModelInternalError +from iotdb.ainode.core.exception import InferenceModelInternalException from iotdb.ainode.core.inference.dispatcher.basic_dispatcher import BasicDispatcher from iotdb.ainode.core.inference.inference_request import ( InferenceRequest, @@ -90,14 +90,14 @@ class PoolGroup: def get_request_pool(self, pool_id) -> InferenceRequestPool: if pool_id not in self.pool_group: - raise InferenceModelInternalError( + raise InferenceModelInternalException( f"[Inference][Pool-{pool_id}] Pool not found for model {self.model_id}" ) return self.pool_group[pool_id][0] def get_request_queue(self, pool_id) -> mp.Queue: if pool_id not in self.pool_group: - raise InferenceModelInternalError( + raise InferenceModelInternalException( f"[Inference][Pool-{pool_id}] Pool not found for model {self.model_id}" ) return self.pool_group[pool_id][1] diff --git a/iotdb-core/ainode/iotdb/ainode/core/inference/pool_scheduler/basic_pool_scheduler.py b/iotdb-core/ainode/iotdb/ainode/core/inference/pool_scheduler/basic_pool_scheduler.py index d2e7292ecd8..21140cafb1f 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/inference/pool_scheduler/basic_pool_scheduler.py +++ b/iotdb-core/ainode/iotdb/ainode/core/inference/pool_scheduler/basic_pool_scheduler.py @@ -20,7 +20,7 @@ from typing import Dict, List, Optional import torch -from iotdb.ainode.core.exception import InferenceModelInternalError +from iotdb.ainode.core.exception import InferenceModelInternalException from iotdb.ainode.core.inference.pool_group import PoolGroup from iotdb.ainode.core.inference.pool_scheduler.abstract_pool_scheduler import ( AbstractPoolScheduler, @@ -113,7 +113,7 @@ class BasicPoolScheduler(AbstractPoolScheduler): if model_id not in self._request_pool_map: pool_num = estimate_pool_size(self.DEFAULT_DEVICE, model_id) if pool_num <= 0: - raise InferenceModelInternalError( + raise InferenceModelInternalException( f"Not enough memory to run model {model_id}." ) return [ScaleAction(ScaleActionType.SCALE_UP, pool_num, model_id)] diff --git a/iotdb-core/ainode/iotdb/ainode/core/manager/inference_manager.py b/iotdb-core/ainode/iotdb/ainode/core/manager/inference_manager.py index 1ce2e84e059..34c315274f5 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/manager/inference_manager.py +++ b/iotdb-core/ainode/iotdb/ainode/core/manager/inference_manager.py @@ -27,7 +27,7 @@ import torch.multiprocessing as mp from iotdb.ainode.core.config import AINodeDescriptor from iotdb.ainode.core.constant import TSStatusCode from iotdb.ainode.core.exception import ( - InferenceModelInternalError, + InferenceModelInternalException, NumericalRangeException, ) from iotdb.ainode.core.inference.inference_request import ( @@ -161,7 +161,7 @@ class InferenceManager: return outputs except Exception as e: logger.error(e) - raise InferenceModelInternalError(str(e)) + raise InferenceModelInternalException(str(e)) finally: with self._result_wrapper_lock: del self._result_wrapper_map[req_id] diff --git a/iotdb-core/ainode/iotdb/ainode/core/manager/model_manager.py b/iotdb-core/ainode/iotdb/ainode/core/manager/model_manager.py index 8ffb33d91e2..76fc46b411a 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/manager/model_manager.py +++ b/iotdb-core/ainode/iotdb/ainode/core/manager/model_manager.py @@ -16,12 +16,16 @@ # under the License. # -from typing import Any, List, Optional +from typing import List, Optional from iotdb.ainode.core.constant import TSStatusCode -from iotdb.ainode.core.exception import BuiltInModelDeletionError +from iotdb.ainode.core.exception import ( + BuiltInModelDeletionException, + InvalidModelUriException, + ModelExistedException, + ModelNotExistException, +) from iotdb.ainode.core.log import Logger -from iotdb.ainode.core.model.model_loader import load_model from iotdb.ainode.core.model.model_storage import ModelCategory, ModelInfo, ModelStorage from iotdb.ainode.core.rpc.status import get_status from iotdb.ainode.core.util.decorator import singleton @@ -47,16 +51,15 @@ class ModelManager: req: TRegisterModelReq, ) -> TRegisterModelResp: try: - if self._model_storage.register_model(model_id=req.modelId, uri=req.uri): - return TRegisterModelResp(get_status(TSStatusCode.SUCCESS_STATUS)) - return TRegisterModelResp(get_status(TSStatusCode.AINODE_INTERNAL_ERROR)) - except ValueError as e: + self._model_storage.register_model(model_id=req.modelId, uri=req.uri) + return TRegisterModelResp(get_status(TSStatusCode.SUCCESS_STATUS)) + except ModelExistedException as e: return TRegisterModelResp( - get_status(TSStatusCode.INVALID_URI_ERROR, str(e)) + get_status(TSStatusCode.MODEL_EXISTED_ERROR, str(e)) ) - except Exception as e: + except InvalidModelUriException as e: return TRegisterModelResp( - get_status(TSStatusCode.AINODE_INTERNAL_ERROR, str(e)) + get_status(TSStatusCode.CREATE_MODEL_ERROR, str(e)) ) def show_models(self, req: TShowModelsReq) -> TShowModelsResp: @@ -67,12 +70,12 @@ class ModelManager: try: self._model_storage.delete_model(req.modelId) return get_status(TSStatusCode.SUCCESS_STATUS) - except BuiltInModelDeletionError as e: - logger.warning(e) - return get_status(TSStatusCode.AINODE_INTERNAL_ERROR, str(e)) + except ModelNotExistException as e: + return get_status(TSStatusCode.MODEL_NOT_EXIST_ERROR, str(e)) + except BuiltInModelDeletionException as e: + return get_status(TSStatusCode.DROP_BUILTIN_MODEL_ERROR, str(e)) except Exception as e: - logger.warning(e) - return get_status(TSStatusCode.AINODE_INTERNAL_ERROR, str(e)) + return get_status(TSStatusCode.DROP_MODEL_ERROR, str(e)) def get_model_info( self, @@ -81,19 +84,9 @@ class ModelManager: ) -> Optional[ModelInfo]: return self._model_storage.get_model_info(model_id, category) - def get_model_infos( - self, - category: Optional[ModelCategory] = None, - model_type: Optional[str] = None, - ) -> List[ModelInfo]: - return self._model_storage.get_model_infos(category, model_type) - def _refresh(self): """Refresh the model list (re-scan the file system)""" self._model_storage.discover_all_models() - def get_registered_models(self) -> List[str]: - return self._model_storage.get_registered_models() - def is_model_registered(self, model_id: str) -> bool: return self._model_storage.is_model_registered(model_id) diff --git a/iotdb-core/ainode/iotdb/ainode/core/manager/utils.py b/iotdb-core/ainode/iotdb/ainode/core/manager/utils.py index 23a98f26bbf..17516876201 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/manager/utils.py +++ b/iotdb-core/ainode/iotdb/ainode/core/manager/utils.py @@ -22,7 +22,7 @@ import psutil import torch from iotdb.ainode.core.config import AINodeDescriptor -from iotdb.ainode.core.exception import ModelNotExistError +from iotdb.ainode.core.exception import ModelNotExistException from iotdb.ainode.core.log import Logger from iotdb.ainode.core.manager.model_manager import ModelManager from iotdb.ainode.core.model.model_loader import load_model @@ -86,7 +86,7 @@ def estimate_pool_size(device: torch.device, model_id: str) -> int: logger.error( f"[Inference] Cannot estimate inference pool size on device: {device}, because model: {model_id} is not supported." ) - raise ModelNotExistError(model_id) + raise ModelNotExistException(model_id) system_res = evaluate_system_resources(device) free_mem = system_res["free_mem"] diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/model_loader.py b/iotdb-core/ainode/iotdb/ainode/core/model/model_loader.py index a6e3b1f7b5e..29a7c14c972 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/model/model_loader.py +++ b/iotdb-core/ainode/iotdb/ainode/core/model/model_loader.py @@ -32,7 +32,7 @@ from transformers import ( ) from iotdb.ainode.core.config import AINodeDescriptor -from iotdb.ainode.core.exception import ModelNotExistError +from iotdb.ainode.core.exception import ModelNotExistException from iotdb.ainode.core.log import Logger from iotdb.ainode.core.model.model_constants import ModelCategory from iotdb.ainode.core.model.model_info import ModelInfo @@ -131,7 +131,7 @@ def load_model_from_pt(model_info: ModelInfo, **kwargs): model_file = os.path.join(model_path, "model.pt") if not os.path.exists(model_file): logger.error(f"Model file not found at {model_file}.") - raise ModelNotExistError(model_file) + raise ModelNotExistException(model_file) model = torch.jit.load(model_file) if isinstance(model, torch._dynamo.eval_frame.OptimizedModule) or not acceleration: return model diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/model_storage.py b/iotdb-core/ainode/iotdb/ainode/core/model/model_storage.py index 5194ed4df1b..18f5b7a90ba 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/model/model_storage.py +++ b/iotdb-core/ainode/iotdb/ainode/core/model/model_storage.py @@ -23,12 +23,16 @@ import shutil from pathlib import Path from typing import Dict, List, Optional -from huggingface_hub import hf_hub_download, snapshot_download +from huggingface_hub import hf_hub_download from transformers import AutoConfig, AutoModelForCausalLM from iotdb.ainode.core.config import AINodeDescriptor from iotdb.ainode.core.constant import TSStatusCode -from iotdb.ainode.core.exception import BuiltInModelDeletionError +from iotdb.ainode.core.exception import ( + BuiltInModelDeletionException, + ModelExistedException, + ModelNotExistException, +) from iotdb.ainode.core.log import Logger from iotdb.ainode.core.model.model_constants import ( MODEL_CONFIG_FILE_IN_JSON, @@ -43,6 +47,8 @@ from iotdb.ainode.core.model.model_info import ( ModelInfo, ) from iotdb.ainode.core.model.utils import ( + _fetch_model_from_hf_repo, + _fetch_model_from_local, ensure_init_file, get_parsed_uri, import_class_from_path, @@ -233,12 +239,23 @@ class ModelStorage: # ==================== Registration Methods ==================== - def register_model(self, model_id: str, uri: str) -> bool: + def register_model(self, model_id: str, uri: str): """ - Supported URI formats: - - repo://<huggingface_repo_id> (Maybe in the future) - - file://<local_path> + Register a user-defined model from a given URI. + Args: + model_id (str): Unique identifier for the model. + uri (str): URI to fetch the model from. + Supported URI formats: + - file://<local_path> + - repo://<huggingface_repo_id> (Maybe in the future) + Raises: + ModelExistedError: If the model_id already exists. + InvalidModelUriError: If the URI format is invalid. """ + + if self.is_model_registered(model_id): + raise ModelExistedException(model_id) + uri_type = parse_uri_type(uri) parsed_uri = get_parsed_uri(uri) @@ -249,9 +266,9 @@ class ModelStorage: ensure_init_file(model_dir) if uri_type == UriType.REPO: - self._fetch_model_from_hf_repo(parsed_uri, model_dir) + _fetch_model_from_hf_repo(parsed_uri, model_dir) else: - self._fetch_model_from_local(os.path.expanduser(parsed_uri), model_dir) + _fetch_model_from_local(os.path.expanduser(parsed_uri), model_dir) config_path, _ = validate_model_files(model_dir) config = load_model_config_in_json(config_path) @@ -272,7 +289,7 @@ class ModelStorage: self._models[ModelCategory.USER_DEFINED.value][model_id] = model_info if auto_map: - # Transformers model: immediately register to Transformers auto-loading mechanism + # Transformers model: immediately register to Transformers autoloading mechanism success = self._register_transformers_model(model_info) if success: with self._lock_pool.get_lock(model_id).write_lock(): @@ -281,46 +298,15 @@ class ModelStorage: with self._lock_pool.get_lock(model_id).write_lock(): model_info.state = ModelStates.INACTIVE logger.error(f"Failed to register Transformers model {model_id}") - return False else: # Other type models: only log self._register_other_model(model_info) logger.info(f"Successfully registered model {model_id} from URI: {uri}") - return True - def _fetch_model_from_hf_repo(self, repo_id: str, storage_path: str): - logger.info( - f"Downloading model from HuggingFace repository: {repo_id} -> {storage_path}" - ) - # Use snapshot_download to download entire repository (including config.json and model.safetensors) - try: - snapshot_download( - repo_id=repo_id, - local_dir=storage_path, - local_dir_use_symlinks=False, - ) - except Exception as e: - logger.error(f"Failed to download model from HuggingFace: {e}") - raise - - def _fetch_model_from_local(self, source_path: str, storage_path: str): - logger.info(f"Copying model from local path: {source_path} -> {storage_path}") - source_dir = Path(source_path) - if not source_dir.is_dir(): - raise ValueError( - f"Source path does not exist or is not a directory: {source_path}" - ) - - storage_dir = Path(storage_path) - for file in source_dir.iterdir(): - if file.is_file(): - shutil.copy2(file, storage_dir / file.name) - return - - def _register_transformers_model(self, model_info: ModelInfo) -> bool: + def _register_transformers_model(self, model_info: ModelInfo): """ - Register Transformers model to auto-loading mechanism (internal method) + Register Transformers model to autoloading mechanism (internal method) """ auto_map = model_info.auto_map if not auto_map: @@ -350,7 +336,6 @@ class ModelStorage: logger.info( f"Registered AutoModelForCausalLM: {config_class.__name__} -> {auto_model_path}" ) - return True except Exception as e: logger.warning( @@ -471,8 +456,16 @@ class ModelStorage: stateMap=state_map, ) - def delete_model(self, model_id: str) -> None: - # Use write lock to protect entire deletion process + def delete_model(self, model_id: str): + """ + Delete a user-defined model by model_id. + Args: + model_id (str): Unique identifier for the model to be deleted. + Raises: + ModelNotExistException: If the model_id does not exist. + BuiltInModelDeletionException: If attempting to delete a built-in model. + Others: Any exceptions raised during file deletion. + """ with self._lock_pool.get_lock(model_id).write_lock(): model_info = None category_value = None @@ -481,30 +474,25 @@ class ModelStorage: model_info = category_dict[model_id] category_value = cat_value break - if not model_info: logger.warning(f"Model {model_id} does not exist, cannot delete") - return - + raise ModelNotExistException(model_id) if model_info.category == ModelCategory.BUILTIN: - raise BuiltInModelDeletionError(model_id) + logger.warning(f"Model {model_id} is builtin, cannot delete") + raise BuiltInModelDeletionException(model_id) model_info.state = ModelStates.DROPPING model_path = os.path.join( self._models_dir, model_info.category.value, model_id ) - if model_path.exists(): + if os.path.exists(model_path): try: shutil.rmtree(model_path) - logger.info(f"Deleted model directory: {model_path}") + logger.info(f"Model directory is deleted: {model_path}") except Exception as e: logger.error(f"Failed to delete model directory {model_path}: {e}") - raise - - if category_value and model_id in self._models[category_value]: - del self._models[category_value][model_id] - logger.info(f"Model {model_id} has been removed from storage") - - return + raise e + del self._models[category_value][model_id] + logger.info(f"Model {model_id} has been removed from model storage") # ==================== Query Methods ==================== @@ -512,10 +500,14 @@ class ModelStorage: self, model_id: str, category: Optional[ModelCategory] = None ) -> Optional[ModelInfo]: """ - Get single model information - - If category is specified, use model_id's lock - If category is not specified, need to traverse all dictionaries, use global lock + Get specified model information. + Args: + model_id (str): Unique identifier for the model. + category (Optional[ModelCategory]): Category of the model (if known). + Returns: + ModelInfo: Information of the specified model. + Raises: + ModelNotExistException: If the model_id does not exist. """ if category: # Category specified, only need to access specific dictionary, use model_id's lock @@ -527,39 +519,7 @@ class ModelStorage: for category_dict in self._models.values(): if model_id in category_dict: return category_dict[model_id] - return None - - def get_model_infos( - self, category: Optional[ModelCategory] = None, model_type: Optional[str] = None - ) -> List[ModelInfo]: - """ - Get model information list - - Note: Since we need to traverse all models, use a global lock to protect the entire dictionary structure - For single model access, using model_id-based lock would be more efficient - """ - matching_models = [] - - # For traversal operations, we need to protect the entire dictionary structure - # Use a special lock (using empty string as key) to protect the entire dictionary - with self._lock_pool.get_lock("").read_lock(): - if category and model_type: - for model_info in self._models[category.value].values(): - if model_info.model_type == model_type: - matching_models.append(model_info) - return matching_models - elif category: - return list(self._models[category.value].values()) - elif model_type: - for category_dict in self._models.values(): - for model_info in category_dict.values(): - if model_info.model_type == model_type: - matching_models.append(model_info) - return matching_models - else: - for category_dict in self._models.values(): - matching_models.extend(category_dict.values()) - return matching_models + raise ModelNotExistException(model_id) def is_model_registered(self, model_id: str) -> bool: """Check if model is registered (search in _models)""" @@ -572,11 +532,3 @@ class ModelStorage: if model_id in category_dict: return True return False - - def get_registered_models(self) -> List[str]: - """Get list of all registered model IDs""" - with self._lock_pool.get_lock("").read_lock(): - model_ids = [] - for category_dict in self._models.values(): - model_ids.extend(category_dict.keys()) - return model_ids diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/sktime/configuration_sktime.py b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/configuration_sktime.py index 261de3c9abe..d9d20545af6 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/model/sktime/configuration_sktime.py +++ b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/configuration_sktime.py @@ -20,11 +20,11 @@ from dataclasses import dataclass, field from typing import Any, Dict, List, Union from iotdb.ainode.core.exception import ( - BuiltInModelNotSupportError, + BuiltInModelNotSupportException, ListRangeException, NumericalRangeException, StringRangeException, - WrongAttributeTypeError, + WrongAttributeTypeException, ) from iotdb.ainode.core.log import Logger @@ -49,7 +49,7 @@ class AttributeConfig: if value is None: return True # Allow None for optional int parameters if not isinstance(value, int): - raise WrongAttributeTypeError(self.name, "int") + raise WrongAttributeTypeException(self.name, "int") if self.low is not None and self.high is not None: if not (self.low <= value <= self.high): raise NumericalRangeException(self.name, value, self.low, self.high) @@ -57,7 +57,7 @@ class AttributeConfig: if value is None: return True # Allow None for optional float parameters if not isinstance(value, (int, float)): - raise WrongAttributeTypeError(self.name, "float") + raise WrongAttributeTypeException(self.name, "float") value = float(value) if self.low is not None and self.high is not None: if not (self.low <= value <= self.high): @@ -66,26 +66,26 @@ class AttributeConfig: if value is None: return True # Allow None for optional str parameters if not isinstance(value, str): - raise WrongAttributeTypeError(self.name, "str") + raise WrongAttributeTypeException(self.name, "str") if self.choices and value not in self.choices: raise StringRangeException(self.name, value, self.choices) elif self.type == "bool": if value is None: return True # Allow None for optional bool parameters if not isinstance(value, bool): - raise WrongAttributeTypeError(self.name, "bool") + raise WrongAttributeTypeException(self.name, "bool") elif self.type == "list": if not isinstance(value, list): - raise WrongAttributeTypeError(self.name, "list") + raise WrongAttributeTypeException(self.name, "list") for item in value: if not isinstance(item, self.value_type): - raise WrongAttributeTypeError(self.name, self.value_type) + raise WrongAttributeTypeException(self.name, self.value_type) elif self.type == "tuple": if not isinstance(value, tuple): - raise WrongAttributeTypeError(self.name, "tuple") + raise WrongAttributeTypeException(self.name, "tuple") for item in value: if not isinstance(item, self.value_type): - raise WrongAttributeTypeError(self.name, self.value_type) + raise WrongAttributeTypeException(self.name, self.value_type) return True def parse(self, string_value: str): @@ -96,14 +96,14 @@ class AttributeConfig: try: return int(string_value) except: - raise WrongAttributeTypeError(self.name, "int") + raise WrongAttributeTypeException(self.name, "int") elif self.type == "float": if string_value.lower() == "none" or string_value.strip() == "": return None try: return float(string_value) except: - raise WrongAttributeTypeError(self.name, "float") + raise WrongAttributeTypeException(self.name, "float") elif self.type == "str": if string_value.lower() == "none" or string_value.strip() == "": return None @@ -116,14 +116,14 @@ class AttributeConfig: elif string_value.lower() == "none" or string_value.strip() == "": return None else: - raise WrongAttributeTypeError(self.name, "bool") + raise WrongAttributeTypeException(self.name, "bool") elif self.type == "list": try: list_value = eval(string_value) except: - raise WrongAttributeTypeError(self.name, "list") + raise WrongAttributeTypeException(self.name, "list") if not isinstance(list_value, list): - raise WrongAttributeTypeError(self.name, "list") + raise WrongAttributeTypeException(self.name, "list") for i in range(len(list_value)): try: list_value[i] = self.value_type(list_value[i]) @@ -136,9 +136,9 @@ class AttributeConfig: try: tuple_value = eval(string_value) except: - raise WrongAttributeTypeError(self.name, "tuple") + raise WrongAttributeTypeException(self.name, "tuple") if not isinstance(tuple_value, tuple): - raise WrongAttributeTypeError(self.name, "tuple") + raise WrongAttributeTypeException(self.name, "tuple") list_value = list(tuple_value) for i in range(len(list_value)): try: @@ -390,7 +390,7 @@ def get_attributes(model_id: str) -> Dict[str, AttributeConfig]: """Get attribute configuration for Sktime model""" model_id = "EXPONENTIAL_SMOOTHING" if model_id == "HOLTWINTERS" else model_id if model_id not in MODEL_CONFIGS: - raise BuiltInModelNotSupportError(model_id) + raise BuiltInModelNotSupportException(model_id) return MODEL_CONFIGS[model_id] diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/sktime/modeling_sktime.py b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/modeling_sktime.py index eca812d35ec..9ddbcab286f 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/model/sktime/modeling_sktime.py +++ b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/modeling_sktime.py @@ -30,8 +30,8 @@ from sktime.forecasting.naive import NaiveForecaster from sktime.forecasting.trend import STLForecaster from iotdb.ainode.core.exception import ( - BuiltInModelNotSupportError, - InferenceModelInternalError, + BuiltInModelNotSupportException, + InferenceModelInternalException, ) from iotdb.ainode.core.log import Logger @@ -66,7 +66,7 @@ class ForecastingModel(SktimeModel): output = self._model.predict(fh=range(predict_length)) return np.array(output, dtype=np.float64) except Exception as e: - raise InferenceModelInternalError(str(e)) + raise InferenceModelInternalException(str(e)) class DetectionModel(SktimeModel): @@ -82,7 +82,7 @@ class DetectionModel(SktimeModel): else: return np.array(output, dtype=np.int32) except Exception as e: - raise InferenceModelInternalError(str(e)) + raise InferenceModelInternalException(str(e)) class ArimaModel(ForecastingModel): @@ -155,7 +155,7 @@ class STRAYModel(DetectionModel): scaled_data = pd.Series(scaled_data.flatten()) return super().generate(scaled_data, **kwargs) except Exception as e: - raise InferenceModelInternalError(str(e)) + raise InferenceModelInternalException(str(e)) # Model factory mapping @@ -176,5 +176,5 @@ def create_sktime_model(model_id: str, **kwargs) -> SktimeModel: attributes = update_attribute({**kwargs}, get_attributes(model_id.upper())) model_class = _MODEL_FACTORY.get(model_id.upper()) if model_class is None: - raise BuiltInModelNotSupportError(model_id) + raise BuiltInModelNotSupportException(model_id) return model_class(attributes) diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/sundial/pipeline_sundial.py b/iotdb-core/ainode/iotdb/ainode/core/model/sundial/pipeline_sundial.py index 85b6f7db2ff..ee128802d24 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/model/sundial/pipeline_sundial.py +++ b/iotdb-core/ainode/iotdb/ainode/core/model/sundial/pipeline_sundial.py @@ -18,7 +18,7 @@ import torch -from iotdb.ainode.core.exception import InferenceModelInternalError +from iotdb.ainode.core.exception import InferenceModelInternalException from iotdb.ainode.core.inference.pipeline.basic_pipeline import ForecastPipeline @@ -28,7 +28,7 @@ class SundialPipeline(ForecastPipeline): def _preprocess(self, inputs): if len(inputs.shape) != 2: - raise InferenceModelInternalError( + raise InferenceModelInternalException( f"[Inference] Input shape must be: [batch_size, seq_len], but receives {inputs.shape}" ) return inputs diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/timer_xl/pipeline_timer.py b/iotdb-core/ainode/iotdb/ainode/core/model/timer_xl/pipeline_timer.py index c0f00b1f5ca..65c6cdd74cd 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/model/timer_xl/pipeline_timer.py +++ b/iotdb-core/ainode/iotdb/ainode/core/model/timer_xl/pipeline_timer.py @@ -18,7 +18,7 @@ import torch -from iotdb.ainode.core.exception import InferenceModelInternalError +from iotdb.ainode.core.exception import InferenceModelInternalException from iotdb.ainode.core.inference.pipeline.basic_pipeline import ForecastPipeline @@ -28,7 +28,7 @@ class TimerPipeline(ForecastPipeline): def _preprocess(self, inputs): if len(inputs.shape) != 2: - raise InferenceModelInternalError( + raise InferenceModelInternalException( f"[Inference] Input shape must be: [batch_size, seq_len], but receives {inputs.shape}" ) return inputs diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/utils.py b/iotdb-core/ainode/iotdb/ainode/core/model/utils.py index 1cd0ee44912..0a843b01d42 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/model/utils.py +++ b/iotdb-core/ainode/iotdb/ainode/core/model/utils.py @@ -19,25 +19,29 @@ import importlib import json import os.path +import shutil import sys from contextlib import contextmanager +from pathlib import Path from typing import Dict, Tuple +from huggingface_hub import snapshot_download + +from iotdb.ainode.core.exception import InvalidModelUriException from iotdb.ainode.core.model.model_constants import ( MODEL_CONFIG_FILE_IN_JSON, MODEL_WEIGHTS_FILE_IN_SAFETENSORS, UriType, ) +from iotdb.ainode.core.model.model_storage import logger def parse_uri_type(uri: str) -> UriType: - if uri.startswith("repo://"): - return UriType.REPO - elif uri.startswith("file://"): + if uri.startswith("file://"): return UriType.FILE else: - raise ValueError( - f"Unsupported URI type: {uri}. Supported formats: repo:// or file://" + raise InvalidModelUriException( + f"Unknown uri type {uri}, currently supporting formats: file://" ) @@ -70,9 +74,13 @@ def validate_model_files(model_dir: str) -> Tuple[str, str]: weights_path = os.path.join(model_dir, MODEL_WEIGHTS_FILE_IN_SAFETENSORS) if not os.path.exists(config_path): - raise ValueError(f"Model config file does not exist: {config_path}") + raise InvalidModelUriException( + f"Model config file does not exist: {config_path}" + ) if not os.path.exists(weights_path): - raise ValueError(f"Model weights file does not exist: {weights_path}") + raise InvalidModelUriException( + f"Model weights file does not exist: {weights_path}" + ) # Create __init__.py file to ensure model directory can be imported as a module init_file = os.path.join(model_dir, "__init__.py") @@ -96,3 +104,28 @@ def ensure_init_file(dir_path: str): if not os.path.exists(init_file): with open(init_file, "w"): pass + + +def _fetch_model_from_local(source_path: str, storage_path: str): + logger.info(f"Copying model from local path: {source_path} -> {storage_path}") + source_dir = Path(source_path) + storage_dir = Path(storage_path) + for file in source_dir.iterdir(): + if file.is_file(): + shutil.copy2(file, storage_dir / file.name) + + +def _fetch_model_from_hf_repo(repo_id: str, storage_path: str): + logger.info( + f"Downloading model from HuggingFace repository: {repo_id} -> {storage_path}" + ) + # Use snapshot_download to download entire repository (including config.json and model.safetensors) + try: + snapshot_download( + repo_id=repo_id, + local_dir=storage_path, + local_dir_use_symlinks=False, + ) + except Exception as e: + logger.error(f"Failed to download model from HuggingFace: {e}") + raise diff --git a/iotdb-core/ainode/iotdb/ainode/core/rpc/handler.py b/iotdb-core/ainode/iotdb/ainode/core/rpc/handler.py index 6c4eedeb99f..492802fc060 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/rpc/handler.py +++ b/iotdb-core/ainode/iotdb/ainode/core/rpc/handler.py @@ -40,7 +40,7 @@ from iotdb.thrift.ainode.ttypes import ( TShowLoadedModelsResp, TShowModelsReq, TShowModelsResp, - TTrainingReq, + TTuningReq, TUnloadModelReq, ) from iotdb.thrift.common.ttypes import TSStatus @@ -56,8 +56,8 @@ def _ensure_device_id_is_available(device_id_list: list[str]) -> TSStatus: for device_id in device_id_list: if device_id not in available_devices: return TSStatus( - code=TSStatusCode.INVALID_URI_ERROR.value, - message=f"Device ID [{device_id}] is not available. You can use 'SHOW AI_DEVICES' to retrieve the available devices.", + code=TSStatusCode.UNAVAILABLE_AI_DEVICE_ERROR.value, + message=f"AIDevice ID [{device_id}] is not available. You can use 'SHOW AI_DEVICES' to retrieve the available devices.", ) return TSStatus(code=TSStatusCode.SUCCESS_STATUS.value) @@ -68,7 +68,9 @@ class AINodeRPCServiceHandler(IAINodeRPCService.Iface): self._model_manager = ModelManager() self._inference_manager = InferenceManager() - def stop(self) -> None: + # ==================== Cluster Management ==================== + + def stop(self): logger.info("Stopping the RPC service handler of IoTDB-AINode...") self._inference_manager.stop() @@ -76,6 +78,17 @@ class AINodeRPCServiceHandler(IAINodeRPCService.Iface): self._ainode.stop() return get_status(TSStatusCode.SUCCESS_STATUS, "AINode stopped successfully.") + def getAIHeartbeat(self, req: TAIHeartbeatReq) -> TAIHeartbeatResp: + return ClusterManager.get_heart_beat(req) + + def showAIDevices(self) -> TShowAIDevicesResp: + return TShowAIDevicesResp( + status=TSStatus(code=TSStatusCode.SUCCESS_STATUS.value), + deviceIdList=get_available_devices(), + ) + + # ==================== Model Management ==================== + def registerModel(self, req: TRegisterModelReq) -> TRegisterModelResp: return self._model_manager.register_model(req) @@ -109,11 +122,15 @@ class AINodeRPCServiceHandler(IAINodeRPCService.Iface): return TShowLoadedModelsResp(status=status, deviceLoadedModelsMap={}) return self._inference_manager.show_loaded_models(req) - def showAIDevices(self) -> TShowAIDevicesResp: - return TShowAIDevicesResp( - status=TSStatus(code=TSStatusCode.SUCCESS_STATUS.value), - deviceIdList=get_available_devices(), - ) + def _ensure_model_is_registered(self, model_id: str) -> TSStatus: + if not self._model_manager.is_model_registered(model_id): + return TSStatus( + code=TSStatusCode.MODEL_NOT_EXIST_ERROR.value, + message=f"Model [{model_id}] is not registered yet. You can use 'SHOW MODELS' to retrieve the available models.", + ) + return TSStatus(code=TSStatusCode.SUCCESS_STATUS.value) + + # ==================== Inference ==================== def inference(self, req: TInferenceReq) -> TInferenceResp: status = self._ensure_model_is_registered(req.modelId) @@ -127,16 +144,7 @@ class AINodeRPCServiceHandler(IAINodeRPCService.Iface): return TForecastResp(status, []) return self._inference_manager.forecast(req) - def getAIHeartbeat(self, req: TAIHeartbeatReq) -> TAIHeartbeatResp: - return ClusterManager.get_heart_beat(req) + # ==================== Tuning ==================== - def createTrainingTask(self, req: TTrainingReq) -> TSStatus: + def createTuningTask(self, req: TTuningReq) -> TSStatus: pass - - def _ensure_model_is_registered(self, model_id: str) -> TSStatus: - if not self._model_manager.is_model_registered(model_id): - return TSStatus( - code=TSStatusCode.MODEL_NOT_FOUND_ERROR.value, - message=f"Model [{model_id}] is not registered yet. You can use 'SHOW MODELS' to retrieve the available models.", - ) - return TSStatus(code=TSStatusCode.SUCCESS_STATUS.value) diff --git a/iotdb-core/ainode/iotdb/ainode/core/util/serde.py b/iotdb-core/ainode/iotdb/ainode/core/util/serde.py index f8188209a37..9c6020019fc 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/util/serde.py +++ b/iotdb-core/ainode/iotdb/ainode/core/util/serde.py @@ -21,7 +21,7 @@ from enum import Enum import numpy as np import pandas as pd -from iotdb.ainode.core.exception import BadConfigValueError +from iotdb.ainode.core.exception import BadConfigValueException class TSDataType(Enum): @@ -122,7 +122,7 @@ def _get_type_in_byte(data_type: pd.Series): elif data_type == "text": return b"\x05" else: - raise BadConfigValueError( + raise BadConfigValueException( "data_type", data_type, "data_type should be in ['bool', 'int32', 'int64', 'float32', 'float64', 'text']", @@ -138,7 +138,7 @@ def get_data_type_byte_from_str(value): byte: corresponding data type in [b'\x00', b'\x01', b'\x02', b'\x03', b'\x04', b'\x05'] """ if value not in ["bool", "int32", "int64", "float32", "float64", "text"]: - raise BadConfigValueError( + raise BadConfigValueException( "data_type", value, "data_type should be in ['bool', 'int32', 'int64', 'float32', 'float64', 'text']", diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/exception/ainode/GetModelInfoException.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/exception/ainode/GetModelInfoException.java index 03402d30c64..b8d98c9c1d7 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/exception/ainode/GetModelInfoException.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/exception/ainode/GetModelInfoException.java @@ -23,6 +23,6 @@ import org.apache.iotdb.rpc.TSStatusCode; public class GetModelInfoException extends ModelException { public GetModelInfoException(String message) { - super(message, TSStatusCode.GET_MODEL_INFO_ERROR); + super(message, TSStatusCode.AINODE_INTERNAL_ERROR); } } diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/exception/ainode/ModelNotFoundException.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/exception/ainode/ModelNotFoundException.java deleted file mode 100644 index 38a5105cded..00000000000 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/exception/ainode/ModelNotFoundException.java +++ /dev/null @@ -1,28 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.apache.iotdb.db.exception.ainode; - -import org.apache.iotdb.rpc.TSStatusCode; - -public class ModelNotFoundException extends ModelException { - public ModelNotFoundException(String message) { - super(message, TSStatusCode.MODEL_NOT_FOUND_ERROR); - } -} diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/client/an/AINodeClient.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/client/an/AINodeClient.java index 5eaffc40af9..ffad889a223 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/client/an/AINodeClient.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/client/an/AINodeClient.java @@ -35,7 +35,7 @@ import org.apache.iotdb.ainode.rpc.thrift.TShowLoadedModelsReq; import org.apache.iotdb.ainode.rpc.thrift.TShowLoadedModelsResp; import org.apache.iotdb.ainode.rpc.thrift.TShowModelsReq; import org.apache.iotdb.ainode.rpc.thrift.TShowModelsResp; -import org.apache.iotdb.ainode.rpc.thrift.TTrainingReq; +import org.apache.iotdb.ainode.rpc.thrift.TTuningReq; import org.apache.iotdb.ainode.rpc.thrift.TUnloadModelReq; import org.apache.iotdb.common.rpc.thrift.TAINodeLocation; import org.apache.iotdb.common.rpc.thrift.TEndPoint; @@ -129,8 +129,8 @@ public class AINodeClient implements IAINodeRPCService.Iface, AutoCloseable, Thr } @Override - public TSStatus createTrainingTask(TTrainingReq req) throws TException { - return executeRemoteCallWithRetry(() -> client.createTrainingTask(req)); + public TSStatus createTuningTask(TTuningReq req) throws TException { + return executeRemoteCallWithRetry(() -> client.createTuningTask(req)); } @Override diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/TableConfigTaskVisitor.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/TableConfigTaskVisitor.java index 847b5d62880..bc52a3a7e8c 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/TableConfigTaskVisitor.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/TableConfigTaskVisitor.java @@ -62,7 +62,7 @@ import org.apache.iotdb.db.queryengine.plan.execution.config.metadata.ShowPipePl import org.apache.iotdb.db.queryengine.plan.execution.config.metadata.ShowRegionTask; import org.apache.iotdb.db.queryengine.plan.execution.config.metadata.ShowVariablesTask; import org.apache.iotdb.db.queryengine.plan.execution.config.metadata.ai.CreateModelTask; -import org.apache.iotdb.db.queryengine.plan.execution.config.metadata.ai.CreateTrainingTask; +import org.apache.iotdb.db.queryengine.plan.execution.config.metadata.ai.CreateTuningTask; import org.apache.iotdb.db.queryengine.plan.execution.config.metadata.ai.DropModelTask; import org.apache.iotdb.db.queryengine.plan.execution.config.metadata.ai.LoadModelTask; import org.apache.iotdb.db.queryengine.plan.execution.config.metadata.ai.ShowAIDevicesTask; @@ -1520,7 +1520,7 @@ public class TableConfigTaskVisitor extends AstVisitor<IConfigTask, MPPQueryCont protected IConfigTask visitCreateTraining(CreateTraining node, MPPQueryContext context) { context.setQueryType(QueryType.WRITE); accessControl.checkUserGlobalSysPrivilege(context); - return new CreateTrainingTask( + return new CreateTuningTask( node.getModelId(), node.getParameters(), node.getExistingModelId(), node.getTargetSql()); } diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/TreeConfigTaskVisitor.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/TreeConfigTaskVisitor.java index cb4d05f1b99..ed5e2e434f4 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/TreeConfigTaskVisitor.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/TreeConfigTaskVisitor.java @@ -68,7 +68,7 @@ import org.apache.iotdb.db.queryengine.plan.execution.config.metadata.ShowTrigge import org.apache.iotdb.db.queryengine.plan.execution.config.metadata.ShowVariablesTask; import org.apache.iotdb.db.queryengine.plan.execution.config.metadata.UnSetTTLTask; import org.apache.iotdb.db.queryengine.plan.execution.config.metadata.ai.CreateModelTask; -import org.apache.iotdb.db.queryengine.plan.execution.config.metadata.ai.CreateTrainingTask; +import org.apache.iotdb.db.queryengine.plan.execution.config.metadata.ai.CreateTuningTask; import org.apache.iotdb.db.queryengine.plan.execution.config.metadata.ai.DropModelTask; import org.apache.iotdb.db.queryengine.plan.execution.config.metadata.ai.LoadModelTask; import org.apache.iotdb.db.queryengine.plan.execution.config.metadata.ai.ShowAIDevicesTask; @@ -926,7 +926,7 @@ public class TreeConfigTaskVisitor extends StatementVisitor<IConfigTask, MPPQuer for (PartialPath partialPath : partialPathList) { targetPathPatterns.add(partialPath.getFullPath()); } - return new CreateTrainingTask( + return new CreateTuningTask( createTrainingStatement.getModelId(), createTrainingStatement.getParameters(), createTrainingStatement.getTargetTimeRanges(), diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/executor/ClusterConfigTaskExecutor.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/executor/ClusterConfigTaskExecutor.java index 01f6757f02e..4d99a3ed892 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/executor/ClusterConfigTaskExecutor.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/executor/ClusterConfigTaskExecutor.java @@ -28,7 +28,7 @@ import org.apache.iotdb.ainode.rpc.thrift.TShowLoadedModelsReq; import org.apache.iotdb.ainode.rpc.thrift.TShowLoadedModelsResp; import org.apache.iotdb.ainode.rpc.thrift.TShowModelsReq; import org.apache.iotdb.ainode.rpc.thrift.TShowModelsResp; -import org.apache.iotdb.ainode.rpc.thrift.TTrainingReq; +import org.apache.iotdb.ainode.rpc.thrift.TTuningReq; import org.apache.iotdb.ainode.rpc.thrift.TUnloadModelReq; import org.apache.iotdb.common.rpc.thrift.FunctionType; import org.apache.iotdb.common.rpc.thrift.Model; @@ -3727,7 +3727,7 @@ public class ClusterConfigTaskExecutor implements IConfigTaskExecutor { } @Override - public SettableFuture<ConfigTaskResult> createTraining( + public SettableFuture<ConfigTaskResult> createTuningTask( String modelId, boolean isTableModel, Map<String, String> parameters, @@ -3736,9 +3736,9 @@ public class ClusterConfigTaskExecutor implements IConfigTaskExecutor { @Nullable String targetSql, @Nullable List<String> pathList) { final SettableFuture<ConfigTaskResult> future = SettableFuture.create(); - try (final AINodeClient ai = + try (final AINodeClient aiNodeClient = AINodeClientManager.getInstance().borrowClient(AINodeClientManager.AINODE_ID_PLACEHOLDER)) { - final TTrainingReq req = new TTrainingReq(); + final TTuningReq req = new TTuningReq(); req.setModelId(modelId); req.setParameters(parameters); if (existingModelId != null) { @@ -3747,7 +3747,7 @@ public class ClusterConfigTaskExecutor implements IConfigTaskExecutor { if (existingModelId != null) { req.setExistingModelId(existingModelId); } - final TSStatus status = ai.createTrainingTask(req); + final TSStatus status = aiNodeClient.createTuningTask(req); if (TSStatusCode.SUCCESS_STATUS.getStatusCode() != status.getCode()) { future.setException(new IoTDBException(status)); } else { diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/executor/IConfigTaskExecutor.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/executor/IConfigTaskExecutor.java index 85455756b05..2d7c0f6d1f3 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/executor/IConfigTaskExecutor.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/executor/IConfigTaskExecutor.java @@ -443,7 +443,7 @@ public interface IConfigTaskExecutor { SettableFuture<ConfigTaskResult> unloadModel(String existingModelId, List<String> deviceIdList); - SettableFuture<ConfigTaskResult> createTraining( + SettableFuture<ConfigTaskResult> createTuningTask( String modelId, boolean isTableModel, Map<String, String> parameters, diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/metadata/ai/CreateTrainingTask.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/metadata/ai/CreateTuningTask.java similarity index 93% rename from iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/metadata/ai/CreateTrainingTask.java rename to iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/metadata/ai/CreateTuningTask.java index 9c93c5b7577..a66b7e700b7 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/metadata/ai/CreateTrainingTask.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/metadata/ai/CreateTuningTask.java @@ -28,7 +28,7 @@ import com.google.common.util.concurrent.ListenableFuture; import java.util.List; import java.util.Map; -public class CreateTrainingTask implements IConfigTask { +public class CreateTuningTask implements IConfigTask { private final String modelId; private final boolean isTableModel; @@ -43,7 +43,7 @@ public class CreateTrainingTask implements IConfigTask { private List<List<Long>> timeRanges; // For table model - public CreateTrainingTask( + public CreateTuningTask( String modelId, Map<String, String> parameters, String existingModelId, String targetSql) { this.modelId = modelId; this.parameters = parameters; @@ -53,7 +53,7 @@ public class CreateTrainingTask implements IConfigTask { } // For tree model - public CreateTrainingTask( + public CreateTuningTask( String modelId, Map<String, String> parameters, List<List<Long>> timeRanges, @@ -71,7 +71,7 @@ public class CreateTrainingTask implements IConfigTask { @Override public ListenableFuture<ConfigTaskResult> execute(IConfigTaskExecutor configTaskExecutor) throws InterruptedException { - return configTaskExecutor.createTraining( + return configTaskExecutor.createTuningTask( modelId, isTableModel, parameters, timeRanges, existingModelId, targetSql, targetPaths); } } diff --git a/iotdb-protocol/thrift-ainode/src/main/thrift/ainode.thrift b/iotdb-protocol/thrift-ainode/src/main/thrift/ainode.thrift index 1dc2f025f5c..cda356a948e 100644 --- a/iotdb-protocol/thrift-ainode/src/main/thrift/ainode.thrift +++ b/iotdb-protocol/thrift-ainode/src/main/thrift/ainode.thrift @@ -73,7 +73,7 @@ struct IDataSchema { 2: optional list<i64> timeRange } -struct TTrainingReq { +struct TTuningReq { 1: required string dbType 2: required string modelId 3: required string existingModelId @@ -131,30 +131,27 @@ struct TUnloadModelReq { service IAINodeRPCService { - // -------------- For Config Node -------------- common.TSStatus stopAINode() + TAIHeartbeatResp getAIHeartbeat(TAIHeartbeatReq req) + + TShowAIDevicesResp showAIDevices() + TShowModelsResp showModels(TShowModelsReq req) TShowLoadedModelsResp showLoadedModels(TShowLoadedModelsReq req) - TShowAIDevicesResp showAIDevices() - common.TSStatus deleteModel(TDeleteModelReq req) TRegisterModelResp registerModel(TRegisterModelReq req) - TAIHeartbeatResp getAIHeartbeat(TAIHeartbeatReq req) - - common.TSStatus createTrainingTask(TTrainingReq req) - common.TSStatus loadModel(TLoadModelReq req) common.TSStatus unloadModel(TUnloadModelReq req) - // -------------- For Data Node -------------- - TInferenceResp inference(TInferenceReq req) TForecastResp forecast(TForecastReq req) + + common.TSStatus createTuningTask(TTuningReq req) } \ No newline at end of file
