This is an automated email from the ASF dual-hosted git repository.
yongzao pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/iotdb.git
The following commit(s) were added to refs/heads/master by this push:
new c722aaf6c8d [AINode] More accurate exception for model management (#16895)
c722aaf6c8d is described below
commit c722aaf6c8d3bb3b040d537ec8696efecc806156
Author: Yongzao <[email protected]>
AuthorDate: Thu Dec 11 16:32:56 2025 +0800
[AINode] More accurate exception for model management (#16895)
---
.../iotdb/ainode/it/AINodeClusterConfigIT.java | 3 -
.../iotdb/ainode/it/AINodeModelManageIT.java | 4 +-
.../java/org/apache/iotdb/rpc/TSStatusCode.java | 21 +--
iotdb-core/ainode/iotdb/ainode/core/config.py | 8 +-
iotdb-core/ainode/iotdb/ainode/core/constant.py | 33 ++---
iotdb-core/ainode/iotdb/ainode/core/exception.py | 108 ++++++--------
.../core/inference/dispatcher/basic_dispatcher.py | 6 +-
.../iotdb/ainode/core/inference/pool_controller.py | 4 +-
.../iotdb/ainode/core/inference/pool_group.py | 6 +-
.../pool_scheduler/basic_pool_scheduler.py | 4 +-
.../iotdb/ainode/core/manager/inference_manager.py | 4 +-
.../iotdb/ainode/core/manager/model_manager.py | 43 +++---
.../ainode/iotdb/ainode/core/manager/utils.py | 4 +-
.../iotdb/ainode/core/model/model_constants.py | 7 -
.../ainode/iotdb/ainode/core/model/model_loader.py | 4 +-
.../iotdb/ainode/core/model/model_storage.py | 158 +++++++--------------
.../core/model/sktime/configuration_sktime.py | 36 ++---
.../ainode/core/model/sktime/modeling_sktime.py | 12 +-
.../ainode/core/model/sundial/pipeline_sundial.py | 4 +-
.../ainode/core/model/timer_xl/pipeline_timer.py | 4 +-
iotdb-core/ainode/iotdb/ainode/core/model/utils.py | 53 ++++++-
iotdb-core/ainode/iotdb/ainode/core/rpc/handler.py | 48 ++++---
iotdb-core/ainode/iotdb/ainode/core/util/serde.py | 6 +-
.../db/exception/ainode/GetModelInfoException.java | 2 +-
.../exception/ainode/ModelNotFoundException.java | 28 ----
.../iotdb/db/protocol/client/an/AINodeClient.java | 6 +-
.../execution/config/TableConfigTaskVisitor.java | 4 +-
.../execution/config/TreeConfigTaskVisitor.java | 4 +-
.../config/executor/ClusterConfigTaskExecutor.java | 10 +-
.../config/executor/IConfigTaskExecutor.java | 2 +-
...eateTrainingTask.java => CreateTuningTask.java} | 8 +-
.../thrift-ainode/src/main/thrift/ainode.thrift | 17 +--
32 files changed, 285 insertions(+), 376 deletions(-)
diff --git a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeClusterConfigIT.java b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeClusterConfigIT.java
index 2e62f618091..e148be6b20a 100644
--- a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeClusterConfigIT.java
+++ b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeClusterConfigIT.java
@@ -107,7 +107,4 @@ public class AINodeClusterConfigIT {
}
Assert.fail("The target AINode is not removed successfully after all retries.");
}
-
- // TODO: We might need to add remove unknown test in the future, but current infrastructure is too
- // hard to implement it.
}
diff --git a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeModelManageIT.java b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeModelManageIT.java
index b92b80aecf3..3315617e7fd 100644
--- a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeModelManageIT.java
+++ b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeModelManageIT.java
@@ -131,7 +131,7 @@ public class AINodeModelManageIT {
public void dropBuiltInModelErrorTestInTree() throws SQLException {
try (Connection connection =
EnvFactory.getEnv().getConnection(BaseEnv.TREE_SQL_DIALECT);
Statement statement = connection.createStatement()) {
- errorTest(statement, "drop model sundial", "1510: Cannot delete built-in model: sundial");
+ errorTest(statement, "drop model sundial", "1506: Cannot delete built-in model: sundial");
}
}
@@ -139,7 +139,7 @@ public class AINodeModelManageIT {
public void dropBuiltInModelErrorTestInTable() throws SQLException {
try (Connection connection =
EnvFactory.getEnv().getConnection(BaseEnv.TABLE_SQL_DIALECT);
Statement statement = connection.createStatement()) {
- errorTest(statement, "drop model sundial", "1510: Cannot delete built-in model: sundial");
+ errorTest(statement, "drop model sundial", "1506: Cannot delete built-in model: sundial");
}
}
diff --git a/iotdb-client/service-rpc/src/main/java/org/apache/iotdb/rpc/TSStatusCode.java b/iotdb-client/service-rpc/src/main/java/org/apache/iotdb/rpc/TSStatusCode.java
index 7f94dd696ac..6bd0da063b0 100644
--- a/iotdb-client/service-rpc/src/main/java/org/apache/iotdb/rpc/TSStatusCode.java
+++ b/iotdb-client/service-rpc/src/main/java/org/apache/iotdb/rpc/TSStatusCode.java
@@ -244,16 +244,17 @@ public enum TSStatusCode {
CQ_UPDATE_LAST_EXEC_TIME_ERROR(1403),
// AI
- CREATE_MODEL_ERROR(1500),
- DROP_MODEL_ERROR(1501),
- MODEL_EXIST_ERROR(1502),
- GET_MODEL_INFO_ERROR(1503),
- NO_REGISTERED_AI_NODE_ERROR(1504),
- MODEL_NOT_FOUND_ERROR(1505),
- REGISTER_AI_NODE_ERROR(1506),
- UNAVAILABLE_AI_DEVICE_ERROR(1507),
- AI_NODE_INTERNAL_ERROR(1510),
- REMOVE_AI_NODE_ERROR(1511),
+ NO_REGISTERED_AI_NODE_ERROR(1500),
+ REGISTER_AI_NODE_ERROR(1501),
+ REMOVE_AI_NODE_ERROR(1502),
+ MODEL_EXISTED_ERROR(1503),
+ MODEL_NOT_EXIST_ERROR(1504),
+ CREATE_MODEL_ERROR(1505),
+ DROP_BUILTIN_MODEL_ERROR(1506),
+ DROP_MODEL_ERROR(1507),
+ UNAVAILABLE_AI_DEVICE_ERROR(1508),
+
+ AINODE_INTERNAL_ERROR(1599), // In case somebody too lazy to add a new error code
// Pipe Plugin
CREATE_PIPE_PLUGIN_ERROR(1600),
diff --git a/iotdb-core/ainode/iotdb/ainode/core/config.py b/iotdb-core/ainode/iotdb/ainode/core/config.py
index e465df7e36d..b14efa3bedf 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/config.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/config.py
@@ -46,7 +46,7 @@ from iotdb.ainode.core.constant import (
AINODE_THRIFT_COMPRESSION_ENABLED,
AINODE_VERSION_INFO,
)
-from iotdb.ainode.core.exception import BadNodeUrlError
+from iotdb.ainode.core.exception import BadNodeUrlException
from iotdb.ainode.core.log import Logger
from iotdb.ainode.core.util.decorator import singleton
from iotdb.thrift.common.ttypes import TEndPoint
@@ -437,7 +437,7 @@ class AINodeDescriptor(object):
file_configs["ain_cluster_ingress_time_zone"]
)
- except BadNodeUrlError:
+ except BadNodeUrlException:
logger.warning("Cannot load AINode conf file, use default
configuration.")
except Exception as e:
@@ -489,7 +489,7 @@ def parse_endpoint_url(endpoint_url: str) -> TEndPoint:
"""
split = endpoint_url.split(":")
if len(split) != 2:
- raise BadNodeUrlError(endpoint_url)
+ raise BadNodeUrlException(endpoint_url)
ip = split[0]
try:
@@ -497,4 +497,4 @@ def parse_endpoint_url(endpoint_url: str) -> TEndPoint:
result = TEndPoint(ip, port)
return result
except ValueError:
- raise BadNodeUrlError(endpoint_url)
+ raise BadNodeUrlException(endpoint_url)
diff --git a/iotdb-core/ainode/iotdb/ainode/core/constant.py b/iotdb-core/ainode/iotdb/ainode/core/constant.py
index d8f730c829c..44e76840f73 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/constant.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/constant.py
@@ -81,33 +81,18 @@ DEFAULT_CHUNK_SIZE = 8192
class TSStatusCode(Enum):
SUCCESS_STATUS = 200
REDIRECTION_RECOMMEND = 400
- MODEL_EXIST_ERROR = 1502
- MODEL_NOT_FOUND_ERROR = 1505
- UNAVAILABLE_AI_DEVICE_ERROR = 1507
- AINODE_INTERNAL_ERROR = 1510
+ MODEL_EXISTED_ERROR = 1503
+ MODEL_NOT_EXIST_ERROR = 1504
+ CREATE_MODEL_ERROR = 1505
+ DROP_BUILTIN_MODEL_ERROR = 1506
+ DROP_MODEL_ERROR = 1507
+ UNAVAILABLE_AI_DEVICE_ERROR = 1508
+
INVALID_URI_ERROR = 1511
INVALID_INFERENCE_CONFIG = 1512
INFERENCE_INTERNAL_ERROR = 1520
- def get_status_code(self) -> int:
- return self.value
-
+ AINODE_INTERNAL_ERROR = 1599 # In case somebody too lazy to add a new error code
-class HyperparameterName(Enum):
- # Training hyperparameter
- LEARNING_RATE = "learning_rate"
- EPOCHS = "epochs"
- BATCH_SIZE = "batch_size"
- USE_GPU = "use_gpu"
- NUM_WORKERS = "num_workers"
-
- # Structure hyperparameter
- KERNEL_SIZE = "kernel_size"
- INPUT_VARS = "input_vars"
- BLOCK_TYPE = "block_type"
- D_MODEL = "d_model"
- INNER_LAYERS = "inner_layer"
- OUTER_LAYERS = "outer_layer"
-
- def name(self):
+ def get_status_code(self) -> int:
return self.value
diff --git a/iotdb-core/ainode/iotdb/ainode/core/exception.py b/iotdb-core/ainode/iotdb/ainode/core/exception.py
index 30b9d54dcc7..b007ee58c48 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/exception.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/exception.py
@@ -23,7 +23,7 @@ from iotdb.ainode.core.model.model_constants import (
)
-class _BaseError(Exception):
+class _BaseException(Exception):
"""Base class for exceptions in this module."""
def __init__(self):
@@ -33,86 +33,74 @@ class _BaseError(Exception):
return self.message
-class BadNodeUrlError(_BaseError):
+class BadNodeUrlException(_BaseException):
def __init__(self, node_url: str):
+ super().__init__()
self.message = "Bad node url: {}".format(node_url)
-class ModelNotExistError(_BaseError):
- def __init__(self, msg: str):
- self.message = "Model is not exists: {} ".format(msg)
-
-
-class BadConfigValueError(_BaseError):
- def __init__(self, config_name: str, config_value, hint: str = ""):
- self.message = "Bad value [{0}] for config {1}. {2}".format(
- config_value, config_name, hint
- )
-
+# ==================== Model Management ====================
-class MissingConfigError(_BaseError):
- def __init__(self, config_name: str):
- self.message = "Missing config: {}".format(config_name)
-
-class MissingOptionError(_BaseError):
- def __init__(self, config_name: str):
- self.message = "Missing task option: {}".format(config_name)
+class ModelExistedException(_BaseException):
+ def __init__(self, model_id: str):
+ super().__init__()
+ self.message = "Model {} already exists".format(model_id)
-class RedundantOptionError(_BaseError):
- def __init__(self, option_name: str):
- self.message = "Redundant task option: {}".format(option_name)
+class ModelNotExistException(_BaseException):
+ def __init__(self, model_id: str):
+ super().__init__()
+ self.message = "Model {} does not exist".format(model_id)
-class WrongTypeConfigError(_BaseError):
- def __init__(self, config_name: str, expected_type: str):
- self.message = "Wrong type for config: {0}, expected: {1}".format(
- config_name, expected_type
+class InvalidModelUriException(_BaseException):
+ def __init__(self, msg: str):
+ super().__init__()
+ self.message = (
+ "Model registration failed because the specified uri is invalid:
{}".format(
+ msg
+ )
)
-class UnsupportedError(_BaseError):
- def __init__(self, msg: str):
- self.message = "{0} is not supported in current version".format(msg)
+class BuiltInModelDeletionException(_BaseException):
+ def __init__(self, model_id: str):
+ super().__init__()
+ self.message = "Cannot delete built-in model: {}".format(model_id)
-class InvalidUriError(_BaseError):
- def __init__(self, uri: str):
- self.message = "Invalid uri: {}, there are no {} or {} under this
uri.".format(
- uri, MODEL_WEIGHTS_FILE_IN_PT, MODEL_CONFIG_FILE_IN_YAML
+class BadConfigValueException(_BaseException):
+ def __init__(self, config_name: str, config_value, hint: str = ""):
+ super().__init__()
+ self.message = "Bad value [{0}] for config {1}. {2}".format(
+ config_value, config_name, hint
)
-class InvalidWindowArgumentError(_BaseError):
- def __init__(self, window_interval, window_step, dataset_length):
- self.message = f"Invalid inference input: window_interval
{window_interval}, window_step {window_step}, dataset_length {dataset_length}"
-
-
-class InferenceModelInternalError(_BaseError):
+class InferenceModelInternalException(_BaseException):
def __init__(self, msg: str):
+ super().__init__()
self.message = "Inference model internal error: {0}".format(msg)
-class BuiltInModelNotSupportError(_BaseError):
+class BuiltInModelNotSupportException(_BaseException):
def __init__(self, msg: str):
+ super().__init__()
self.message = "Built-in model not support: {0}".format(msg)
-class BuiltInModelDeletionError(_BaseError):
- def __init__(self, model_id: str):
- self.message = "Cannot delete built-in model: {0}".format(model_id)
-
-
-class WrongAttributeTypeError(_BaseError):
+class WrongAttributeTypeException(_BaseException):
def __init__(self, attribute_name: str, expected_type: str):
+ super().__init__()
self.message = "Wrong type for attribute: {0}, expected: {1}".format(
attribute_name, expected_type
)
-class NumericalRangeException(_BaseError):
+class NumericalRangeException(_BaseException):
def __init__(self, attribute_name: str, value, min_value, max_value):
+ super().__init__()
self.message = (
"Attribute {0} expect value between {1} and {2}, got {3}
instead.".format(
attribute_name, min_value, max_value, value
@@ -120,35 +108,19 @@ class NumericalRangeException(_BaseError):
)
-class StringRangeException(_BaseError):
+class StringRangeException(_BaseException):
def __init__(self, attribute_name: str, value: str, expect_value):
+ super().__init__()
self.message = "Attribute {0} expect value in {1}, got {2}
instead.".format(
attribute_name, expect_value, value
)
-class ListRangeException(_BaseError):
+class ListRangeException(_BaseException):
def __init__(self, attribute_name: str, value: list, expected_type: str):
+ super().__init__()
self.message = (
"Attribute {0} expect value type list[{1}], got {2}
instead.".format(
attribute_name, expected_type, value
)
)
-
-
-class AttributeNotSupportError(_BaseError):
- def __init__(self, model_name: str, attribute_name: str):
- self.message = "Attribute {0} is not supported in model {1}".format(
- attribute_name, model_name
- )
-
-
-# This is used to extract the key message in RuntimeError instead of the
traceback message
-def runtime_error_extractor(error_message):
- pattern = re.compile(r"RuntimeError: (.+)")
- match = pattern.search(error_message)
-
- if match:
- return match.group(1)
- else:
- return ""
diff --git a/iotdb-core/ainode/iotdb/ainode/core/inference/dispatcher/basic_dispatcher.py b/iotdb-core/ainode/iotdb/ainode/core/inference/dispatcher/basic_dispatcher.py
index d06ba55546f..d9a2d8663e5 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/inference/dispatcher/basic_dispatcher.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/inference/dispatcher/basic_dispatcher.py
@@ -16,7 +16,7 @@
# under the License.
#
-from iotdb.ainode.core.exception import InferenceModelInternalError
+from iotdb.ainode.core.exception import InferenceModelInternalException
from iotdb.ainode.core.inference.dispatcher.abstract_dispatcher import (
AbstractDispatcher,
)
@@ -41,7 +41,7 @@ class BasicDispatcher(AbstractDispatcher):
"""
model_id = req.model_id
if not pool_ids:
- raise InferenceModelInternalError(
+ raise InferenceModelInternalException(
f"No available pools for model {model_id}"
)
start_idx = hash(req.req_id) % len(pool_ids)
@@ -51,7 +51,7 @@ class BasicDispatcher(AbstractDispatcher):
state = self.pool_states[pool_id]
if state == PoolState.RUNNING:
return pool_id
- raise InferenceModelInternalError(
+ raise InferenceModelInternalException(
f"No RUNNING pools available for model {model_id}"
)
diff --git a/iotdb-core/ainode/iotdb/ainode/core/inference/pool_controller.py b/iotdb-core/ainode/iotdb/ainode/core/inference/pool_controller.py
index 8ffa89ffd67..c580a89916d 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/inference/pool_controller.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/inference/pool_controller.py
@@ -24,7 +24,7 @@ from typing import Dict, Optional
import torch.multiprocessing as mp
-from iotdb.ainode.core.exception import InferenceModelInternalError
+from iotdb.ainode.core.exception import InferenceModelInternalException
from iotdb.ainode.core.inference.inference_request import (
InferenceRequest,
InferenceRequestProxy,
@@ -374,7 +374,7 @@ class PoolController:
if not self.has_request_pools(model_id):
logger.error(f"[Inference] No pools found for model {model_id}.")
infer_proxy.set_result(None)
- raise InferenceModelInternalError(
+ raise InferenceModelInternalException(
"Dispatch request failed, because no inference pools are init."
)
# TODO: Implement adaptive scaling based on requests.(e.g. lazy
initialization)
diff --git a/iotdb-core/ainode/iotdb/ainode/core/inference/pool_group.py b/iotdb-core/ainode/iotdb/ainode/core/inference/pool_group.py
index a700dcee473..b85f64d42cc 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/inference/pool_group.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/inference/pool_group.py
@@ -19,7 +19,7 @@ from typing import Dict, Tuple
import torch.multiprocessing as mp
-from iotdb.ainode.core.exception import InferenceModelInternalError
+from iotdb.ainode.core.exception import InferenceModelInternalException
from iotdb.ainode.core.inference.dispatcher.basic_dispatcher import
BasicDispatcher
from iotdb.ainode.core.inference.inference_request import (
InferenceRequest,
@@ -90,14 +90,14 @@ class PoolGroup:
def get_request_pool(self, pool_id) -> InferenceRequestPool:
if pool_id not in self.pool_group:
- raise InferenceModelInternalError(
+ raise InferenceModelInternalException(
f"[Inference][Pool-{pool_id}] Pool not found for model
{self.model_id}"
)
return self.pool_group[pool_id][0]
def get_request_queue(self, pool_id) -> mp.Queue:
if pool_id not in self.pool_group:
- raise InferenceModelInternalError(
+ raise InferenceModelInternalException(
f"[Inference][Pool-{pool_id}] Pool not found for model
{self.model_id}"
)
return self.pool_group[pool_id][1]
diff --git a/iotdb-core/ainode/iotdb/ainode/core/inference/pool_scheduler/basic_pool_scheduler.py b/iotdb-core/ainode/iotdb/ainode/core/inference/pool_scheduler/basic_pool_scheduler.py
index d2e7292ecd8..21140cafb1f 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/inference/pool_scheduler/basic_pool_scheduler.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/inference/pool_scheduler/basic_pool_scheduler.py
@@ -20,7 +20,7 @@ from typing import Dict, List, Optional
import torch
-from iotdb.ainode.core.exception import InferenceModelInternalError
+from iotdb.ainode.core.exception import InferenceModelInternalException
from iotdb.ainode.core.inference.pool_group import PoolGroup
from iotdb.ainode.core.inference.pool_scheduler.abstract_pool_scheduler import
(
AbstractPoolScheduler,
@@ -113,7 +113,7 @@ class BasicPoolScheduler(AbstractPoolScheduler):
if model_id not in self._request_pool_map:
pool_num = estimate_pool_size(self.DEFAULT_DEVICE, model_id)
if pool_num <= 0:
- raise InferenceModelInternalError(
+ raise InferenceModelInternalException(
f"Not enough memory to run model {model_id}."
)
return [ScaleAction(ScaleActionType.SCALE_UP, pool_num, model_id)]
diff --git a/iotdb-core/ainode/iotdb/ainode/core/manager/inference_manager.py b/iotdb-core/ainode/iotdb/ainode/core/manager/inference_manager.py
index 1ce2e84e059..34c315274f5 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/manager/inference_manager.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/manager/inference_manager.py
@@ -27,7 +27,7 @@ import torch.multiprocessing as mp
from iotdb.ainode.core.config import AINodeDescriptor
from iotdb.ainode.core.constant import TSStatusCode
from iotdb.ainode.core.exception import (
- InferenceModelInternalError,
+ InferenceModelInternalException,
NumericalRangeException,
)
from iotdb.ainode.core.inference.inference_request import (
@@ -161,7 +161,7 @@ class InferenceManager:
return outputs
except Exception as e:
logger.error(e)
- raise InferenceModelInternalError(str(e))
+ raise InferenceModelInternalException(str(e))
finally:
with self._result_wrapper_lock:
del self._result_wrapper_map[req_id]
diff --git a/iotdb-core/ainode/iotdb/ainode/core/manager/model_manager.py b/iotdb-core/ainode/iotdb/ainode/core/manager/model_manager.py
index 8ffb33d91e2..ef0846c3d78 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/manager/model_manager.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/manager/model_manager.py
@@ -16,12 +16,16 @@
# under the License.
#
-from typing import Any, List, Optional
+from typing import Optional
from iotdb.ainode.core.constant import TSStatusCode
-from iotdb.ainode.core.exception import BuiltInModelDeletionError
+from iotdb.ainode.core.exception import (
+ BuiltInModelDeletionException,
+ InvalidModelUriException,
+ ModelExistedException,
+ ModelNotExistException,
+)
from iotdb.ainode.core.log import Logger
-from iotdb.ainode.core.model.model_loader import load_model
from iotdb.ainode.core.model.model_storage import ModelCategory, ModelInfo,
ModelStorage
from iotdb.ainode.core.rpc.status import get_status
from iotdb.ainode.core.util.decorator import singleton
@@ -47,16 +51,15 @@ class ModelManager:
req: TRegisterModelReq,
) -> TRegisterModelResp:
try:
- if self._model_storage.register_model(model_id=req.modelId,
uri=req.uri):
- return
TRegisterModelResp(get_status(TSStatusCode.SUCCESS_STATUS))
- return
TRegisterModelResp(get_status(TSStatusCode.AINODE_INTERNAL_ERROR))
- except ValueError as e:
+ self._model_storage.register_model(model_id=req.modelId,
uri=req.uri)
+ return TRegisterModelResp(get_status(TSStatusCode.SUCCESS_STATUS))
+ except ModelExistedException as e:
return TRegisterModelResp(
- get_status(TSStatusCode.INVALID_URI_ERROR, str(e))
+ get_status(TSStatusCode.MODEL_EXISTED_ERROR, str(e))
)
- except Exception as e:
+ except InvalidModelUriException as e:
return TRegisterModelResp(
- get_status(TSStatusCode.AINODE_INTERNAL_ERROR, str(e))
+ get_status(TSStatusCode.CREATE_MODEL_ERROR, str(e))
)
def show_models(self, req: TShowModelsReq) -> TShowModelsResp:
@@ -67,12 +70,12 @@ class ModelManager:
try:
self._model_storage.delete_model(req.modelId)
return get_status(TSStatusCode.SUCCESS_STATUS)
- except BuiltInModelDeletionError as e:
- logger.warning(e)
- return get_status(TSStatusCode.AINODE_INTERNAL_ERROR, str(e))
+ except ModelNotExistException as e:
+ return get_status(TSStatusCode.MODEL_NOT_EXIST_ERROR, str(e))
+ except BuiltInModelDeletionException as e:
+ return get_status(TSStatusCode.DROP_BUILTIN_MODEL_ERROR, str(e))
except Exception as e:
- logger.warning(e)
- return get_status(TSStatusCode.AINODE_INTERNAL_ERROR, str(e))
+ return get_status(TSStatusCode.DROP_MODEL_ERROR, str(e))
def get_model_info(
self,
@@ -81,19 +84,9 @@ class ModelManager:
) -> Optional[ModelInfo]:
return self._model_storage.get_model_info(model_id, category)
- def get_model_infos(
- self,
- category: Optional[ModelCategory] = None,
- model_type: Optional[str] = None,
- ) -> List[ModelInfo]:
- return self._model_storage.get_model_infos(category, model_type)
-
def _refresh(self):
"""Refresh the model list (re-scan the file system)"""
self._model_storage.discover_all_models()
- def get_registered_models(self) -> List[str]:
- return self._model_storage.get_registered_models()
-
def is_model_registered(self, model_id: str) -> bool:
return self._model_storage.is_model_registered(model_id)
diff --git a/iotdb-core/ainode/iotdb/ainode/core/manager/utils.py b/iotdb-core/ainode/iotdb/ainode/core/manager/utils.py
index 23a98f26bbf..17516876201 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/manager/utils.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/manager/utils.py
@@ -22,7 +22,7 @@ import psutil
import torch
from iotdb.ainode.core.config import AINodeDescriptor
-from iotdb.ainode.core.exception import ModelNotExistError
+from iotdb.ainode.core.exception import ModelNotExistException
from iotdb.ainode.core.log import Logger
from iotdb.ainode.core.manager.model_manager import ModelManager
from iotdb.ainode.core.model.model_loader import load_model
@@ -86,7 +86,7 @@ def estimate_pool_size(device: torch.device, model_id: str) -> int:
logger.error(
f"[Inference] Cannot estimate inference pool size on device:
{device}, because model: {model_id} is not supported."
)
- raise ModelNotExistError(model_id)
+ raise ModelNotExistException(model_id)
system_res = evaluate_system_resources(device)
free_mem = system_res["free_mem"]
diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/model_constants.py b/iotdb-core/ainode/iotdb/ainode/core/model/model_constants.py
index c42ec98551b..9f1801b5073 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/model/model_constants.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/model/model_constants.py
@@ -24,13 +24,6 @@ MODEL_WEIGHTS_FILE_IN_PT = "model.pt"
MODEL_CONFIG_FILE_IN_YAML = "config.yaml"
-# Model file constants
-MODEL_WEIGHTS_FILE_IN_SAFETENSORS = "model.safetensors"
-MODEL_CONFIG_FILE_IN_JSON = "config.json"
-MODEL_WEIGHTS_FILE_IN_PT = "model.pt"
-MODEL_CONFIG_FILE_IN_YAML = "config.yaml"
-
-
class ModelCategory(Enum):
BUILTIN = "builtin"
USER_DEFINED = "user_defined"
diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/model_loader.py b/iotdb-core/ainode/iotdb/ainode/core/model/model_loader.py
index a6e3b1f7b5e..29a7c14c972 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/model/model_loader.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/model/model_loader.py
@@ -32,7 +32,7 @@ from transformers import (
)
from iotdb.ainode.core.config import AINodeDescriptor
-from iotdb.ainode.core.exception import ModelNotExistError
+from iotdb.ainode.core.exception import ModelNotExistException
from iotdb.ainode.core.log import Logger
from iotdb.ainode.core.model.model_constants import ModelCategory
from iotdb.ainode.core.model.model_info import ModelInfo
@@ -131,7 +131,7 @@ def load_model_from_pt(model_info: ModelInfo, **kwargs):
model_file = os.path.join(model_path, "model.pt")
if not os.path.exists(model_file):
logger.error(f"Model file not found at {model_file}.")
- raise ModelNotExistError(model_file)
+ raise ModelNotExistException(model_file)
model = torch.jit.load(model_file)
if isinstance(model, torch._dynamo.eval_frame.OptimizedModule) or not
acceleration:
return model
diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/model_storage.py b/iotdb-core/ainode/iotdb/ainode/core/model/model_storage.py
index 5194ed4df1b..a79371d2e79 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/model/model_storage.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/model/model_storage.py
@@ -23,12 +23,16 @@ import shutil
from pathlib import Path
from typing import Dict, List, Optional
-from huggingface_hub import hf_hub_download, snapshot_download
+from huggingface_hub import hf_hub_download
from transformers import AutoConfig, AutoModelForCausalLM
from iotdb.ainode.core.config import AINodeDescriptor
from iotdb.ainode.core.constant import TSStatusCode
-from iotdb.ainode.core.exception import BuiltInModelDeletionError
+from iotdb.ainode.core.exception import (
+ BuiltInModelDeletionException,
+ ModelExistedException,
+ ModelNotExistException,
+)
from iotdb.ainode.core.log import Logger
from iotdb.ainode.core.model.model_constants import (
MODEL_CONFIG_FILE_IN_JSON,
@@ -43,6 +47,8 @@ from iotdb.ainode.core.model.model_info import (
ModelInfo,
)
from iotdb.ainode.core.model.utils import (
+ _fetch_model_from_hf_repo,
+ _fetch_model_from_local,
ensure_init_file,
get_parsed_uri,
import_class_from_path,
@@ -233,12 +239,23 @@ class ModelStorage:
# ==================== Registration Methods ====================
- def register_model(self, model_id: str, uri: str) -> bool:
+ def register_model(self, model_id: str, uri: str):
"""
- Supported URI formats:
- - repo://<huggingface_repo_id> (Maybe in the future)
- - file://<local_path>
+ Register a user-defined model from a given URI.
+ Args:
+ model_id (str): Unique identifier for the model.
+ uri (str): URI to fetch the model from.
+ Supported URI formats:
+ - file://<local_path>
+ - repo://<huggingface_repo_id> (Maybe in the future)
+ Raises:
+ ModelExistedException: If the model_id already exists.
+ InvalidModelUriException: If the URI format is invalid.
"""
+
+ if self.is_model_registered(model_id):
+ raise ModelExistedException(model_id)
+
uri_type = parse_uri_type(uri)
parsed_uri = get_parsed_uri(uri)
@@ -249,9 +266,9 @@ class ModelStorage:
ensure_init_file(model_dir)
if uri_type == UriType.REPO:
- self._fetch_model_from_hf_repo(parsed_uri, model_dir)
+ _fetch_model_from_hf_repo(parsed_uri, model_dir)
else:
- self._fetch_model_from_local(os.path.expanduser(parsed_uri),
model_dir)
+ _fetch_model_from_local(os.path.expanduser(parsed_uri), model_dir)
config_path, _ = validate_model_files(model_dir)
config = load_model_config_in_json(config_path)
@@ -272,7 +289,7 @@ class ModelStorage:
self._models[ModelCategory.USER_DEFINED.value][model_id] =
model_info
if auto_map:
- # Transformers model: immediately register to Transformers
auto-loading mechanism
+ # Transformers model: immediately register to Transformers
autoloading mechanism
success = self._register_transformers_model(model_info)
if success:
with self._lock_pool.get_lock(model_id).write_lock():
@@ -281,46 +298,15 @@ class ModelStorage:
with self._lock_pool.get_lock(model_id).write_lock():
model_info.state = ModelStates.INACTIVE
logger.error(f"Failed to register Transformers model
{model_id}")
- return False
else:
# Other type models: only log
self._register_other_model(model_info)
logger.info(f"Successfully registered model {model_id} from URI:
{uri}")
- return True
- def _fetch_model_from_hf_repo(self, repo_id: str, storage_path: str):
- logger.info(
- f"Downloading model from HuggingFace repository: {repo_id} ->
{storage_path}"
- )
- # Use snapshot_download to download entire repository (including
config.json and model.safetensors)
- try:
- snapshot_download(
- repo_id=repo_id,
- local_dir=storage_path,
- local_dir_use_symlinks=False,
- )
- except Exception as e:
- logger.error(f"Failed to download model from HuggingFace: {e}")
- raise
-
- def _fetch_model_from_local(self, source_path: str, storage_path: str):
- logger.info(f"Copying model from local path: {source_path} ->
{storage_path}")
- source_dir = Path(source_path)
- if not source_dir.is_dir():
- raise ValueError(
- f"Source path does not exist or is not a directory:
{source_path}"
- )
-
- storage_dir = Path(storage_path)
- for file in source_dir.iterdir():
- if file.is_file():
- shutil.copy2(file, storage_dir / file.name)
- return
-
- def _register_transformers_model(self, model_info: ModelInfo) -> bool:
+ def _register_transformers_model(self, model_info: ModelInfo):
"""
- Register Transformers model to auto-loading mechanism (internal method)
+ Register Transformers model to autoloading mechanism (internal method)
"""
auto_map = model_info.auto_map
if not auto_map:
@@ -350,7 +336,6 @@ class ModelStorage:
logger.info(
f"Registered AutoModelForCausalLM: {config_class.__name__}
-> {auto_model_path}"
)
-
return True
except Exception as e:
logger.warning(
@@ -471,8 +456,16 @@ class ModelStorage:
stateMap=state_map,
)
- def delete_model(self, model_id: str) -> None:
- # Use write lock to protect entire deletion process
+ def delete_model(self, model_id: str):
+ """
+ Delete a user-defined model by model_id.
+ Args:
+ model_id (str): Unique identifier for the model to be deleted.
+ Raises:
+ ModelNotExistException: If the model_id does not exist.
+ BuiltInModelDeletionException: If attempting to delete a built-in
model.
+ Others: Any exceptions raised during file deletion.
+ """
with self._lock_pool.get_lock(model_id).write_lock():
model_info = None
category_value = None
@@ -481,30 +474,25 @@ class ModelStorage:
model_info = category_dict[model_id]
category_value = cat_value
break
-
if not model_info:
logger.warning(f"Model {model_id} does not exist, cannot
delete")
- return
-
+ raise ModelNotExistException(model_id)
if model_info.category == ModelCategory.BUILTIN:
- raise BuiltInModelDeletionError(model_id)
+ logger.warning(f"Model {model_id} is builtin, cannot delete")
+ raise BuiltInModelDeletionException(model_id)
model_info.state = ModelStates.DROPPING
model_path = os.path.join(
self._models_dir, model_info.category.value, model_id
)
- if model_path.exists():
+ if os.path.exists(model_path):
try:
shutil.rmtree(model_path)
- logger.info(f"Deleted model directory: {model_path}")
+ logger.info(f"Model directory is deleted: {model_path}")
except Exception as e:
logger.error(f"Failed to delete model directory
{model_path}: {e}")
- raise
-
- if category_value and model_id in self._models[category_value]:
- del self._models[category_value][model_id]
- logger.info(f"Model {model_id} has been removed from storage")
-
- return
+ raise e
+ del self._models[category_value][model_id]
+ logger.info(f"Model {model_id} has been removed from model
storage")
# ==================== Query Methods ====================
@@ -512,10 +500,14 @@ class ModelStorage:
self, model_id: str, category: Optional[ModelCategory] = None
) -> Optional[ModelInfo]:
"""
- Get single model information
-
- If category is specified, use model_id's lock
- If category is not specified, need to traverse all dictionaries, use
global lock
+ Get specified model information.
+ Args:
+ model_id (str): Unique identifier for the model.
+ category (Optional[ModelCategory]): Category of the model (if
known).
+ Returns:
+ ModelInfo: Information of the specified model.
+ Raises:
+ ModelNotExistException: If the model_id does not exist.
"""
if category:
# Category specified, only need to access specific dictionary, use
model_id's lock
@@ -527,39 +519,7 @@ class ModelStorage:
for category_dict in self._models.values():
if model_id in category_dict:
return category_dict[model_id]
- return None
-
- def get_model_infos(
- self, category: Optional[ModelCategory] = None, model_type:
Optional[str] = None
- ) -> List[ModelInfo]:
- """
- Get model information list
-
- Note: Since we need to traverse all models, use a global lock to
protect the entire dictionary structure
- For single model access, using model_id-based lock would be more
efficient
- """
- matching_models = []
-
- # For traversal operations, we need to protect the entire dictionary
structure
- # Use a special lock (using empty string as key) to protect the entire
dictionary
- with self._lock_pool.get_lock("").read_lock():
- if category and model_type:
- for model_info in self._models[category.value].values():
- if model_info.model_type == model_type:
- matching_models.append(model_info)
- return matching_models
- elif category:
- return list(self._models[category.value].values())
- elif model_type:
- for category_dict in self._models.values():
- for model_info in category_dict.values():
- if model_info.model_type == model_type:
- matching_models.append(model_info)
- return matching_models
- else:
- for category_dict in self._models.values():
- matching_models.extend(category_dict.values())
- return matching_models
+ raise ModelNotExistException(model_id)
def is_model_registered(self, model_id: str) -> bool:
"""Check if model is registered (search in _models)"""
@@ -572,11 +532,3 @@ class ModelStorage:
if model_id in category_dict:
return True
return False
-
- def get_registered_models(self) -> List[str]:
- """Get list of all registered model IDs"""
- with self._lock_pool.get_lock("").read_lock():
- model_ids = []
- for category_dict in self._models.values():
- model_ids.extend(category_dict.keys())
- return model_ids
diff --git
a/iotdb-core/ainode/iotdb/ainode/core/model/sktime/configuration_sktime.py
b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/configuration_sktime.py
index 261de3c9abe..d9d20545af6 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/model/sktime/configuration_sktime.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/configuration_sktime.py
@@ -20,11 +20,11 @@ from dataclasses import dataclass, field
from typing import Any, Dict, List, Union
from iotdb.ainode.core.exception import (
- BuiltInModelNotSupportError,
+ BuiltInModelNotSupportException,
ListRangeException,
NumericalRangeException,
StringRangeException,
- WrongAttributeTypeError,
+ WrongAttributeTypeException,
)
from iotdb.ainode.core.log import Logger
@@ -49,7 +49,7 @@ class AttributeConfig:
if value is None:
return True # Allow None for optional int parameters
if not isinstance(value, int):
- raise WrongAttributeTypeError(self.name, "int")
+ raise WrongAttributeTypeException(self.name, "int")
if self.low is not None and self.high is not None:
if not (self.low <= value <= self.high):
raise NumericalRangeException(self.name, value, self.low,
self.high)
@@ -57,7 +57,7 @@ class AttributeConfig:
if value is None:
return True # Allow None for optional float parameters
if not isinstance(value, (int, float)):
- raise WrongAttributeTypeError(self.name, "float")
+ raise WrongAttributeTypeException(self.name, "float")
value = float(value)
if self.low is not None and self.high is not None:
if not (self.low <= value <= self.high):
@@ -66,26 +66,26 @@ class AttributeConfig:
if value is None:
return True # Allow None for optional str parameters
if not isinstance(value, str):
- raise WrongAttributeTypeError(self.name, "str")
+ raise WrongAttributeTypeException(self.name, "str")
if self.choices and value not in self.choices:
raise StringRangeException(self.name, value, self.choices)
elif self.type == "bool":
if value is None:
return True # Allow None for optional bool parameters
if not isinstance(value, bool):
- raise WrongAttributeTypeError(self.name, "bool")
+ raise WrongAttributeTypeException(self.name, "bool")
elif self.type == "list":
if not isinstance(value, list):
- raise WrongAttributeTypeError(self.name, "list")
+ raise WrongAttributeTypeException(self.name, "list")
for item in value:
if not isinstance(item, self.value_type):
- raise WrongAttributeTypeError(self.name, self.value_type)
+ raise WrongAttributeTypeException(self.name,
self.value_type)
elif self.type == "tuple":
if not isinstance(value, tuple):
- raise WrongAttributeTypeError(self.name, "tuple")
+ raise WrongAttributeTypeException(self.name, "tuple")
for item in value:
if not isinstance(item, self.value_type):
- raise WrongAttributeTypeError(self.name, self.value_type)
+ raise WrongAttributeTypeException(self.name,
self.value_type)
return True
def parse(self, string_value: str):
@@ -96,14 +96,14 @@ class AttributeConfig:
try:
return int(string_value)
except:
- raise WrongAttributeTypeError(self.name, "int")
+ raise WrongAttributeTypeException(self.name, "int")
elif self.type == "float":
if string_value.lower() == "none" or string_value.strip() == "":
return None
try:
return float(string_value)
except:
- raise WrongAttributeTypeError(self.name, "float")
+ raise WrongAttributeTypeException(self.name, "float")
elif self.type == "str":
if string_value.lower() == "none" or string_value.strip() == "":
return None
@@ -116,14 +116,14 @@ class AttributeConfig:
elif string_value.lower() == "none" or string_value.strip() == "":
return None
else:
- raise WrongAttributeTypeError(self.name, "bool")
+ raise WrongAttributeTypeException(self.name, "bool")
elif self.type == "list":
try:
list_value = eval(string_value)
except:
- raise WrongAttributeTypeError(self.name, "list")
+ raise WrongAttributeTypeException(self.name, "list")
if not isinstance(list_value, list):
- raise WrongAttributeTypeError(self.name, "list")
+ raise WrongAttributeTypeException(self.name, "list")
for i in range(len(list_value)):
try:
list_value[i] = self.value_type(list_value[i])
@@ -136,9 +136,9 @@ class AttributeConfig:
try:
tuple_value = eval(string_value)
except:
- raise WrongAttributeTypeError(self.name, "tuple")
+ raise WrongAttributeTypeException(self.name, "tuple")
if not isinstance(tuple_value, tuple):
- raise WrongAttributeTypeError(self.name, "tuple")
+ raise WrongAttributeTypeException(self.name, "tuple")
list_value = list(tuple_value)
for i in range(len(list_value)):
try:
@@ -390,7 +390,7 @@ def get_attributes(model_id: str) -> Dict[str,
AttributeConfig]:
"""Get attribute configuration for Sktime model"""
model_id = "EXPONENTIAL_SMOOTHING" if model_id == "HOLTWINTERS" else
model_id
if model_id not in MODEL_CONFIGS:
- raise BuiltInModelNotSupportError(model_id)
+ raise BuiltInModelNotSupportException(model_id)
return MODEL_CONFIGS[model_id]
diff --git
a/iotdb-core/ainode/iotdb/ainode/core/model/sktime/modeling_sktime.py
b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/modeling_sktime.py
index eca812d35ec..9ddbcab286f 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/model/sktime/modeling_sktime.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/modeling_sktime.py
@@ -30,8 +30,8 @@ from sktime.forecasting.naive import NaiveForecaster
from sktime.forecasting.trend import STLForecaster
from iotdb.ainode.core.exception import (
- BuiltInModelNotSupportError,
- InferenceModelInternalError,
+ BuiltInModelNotSupportException,
+ InferenceModelInternalException,
)
from iotdb.ainode.core.log import Logger
@@ -66,7 +66,7 @@ class ForecastingModel(SktimeModel):
output = self._model.predict(fh=range(predict_length))
return np.array(output, dtype=np.float64)
except Exception as e:
- raise InferenceModelInternalError(str(e))
+ raise InferenceModelInternalException(str(e))
class DetectionModel(SktimeModel):
@@ -82,7 +82,7 @@ class DetectionModel(SktimeModel):
else:
return np.array(output, dtype=np.int32)
except Exception as e:
- raise InferenceModelInternalError(str(e))
+ raise InferenceModelInternalException(str(e))
class ArimaModel(ForecastingModel):
@@ -155,7 +155,7 @@ class STRAYModel(DetectionModel):
scaled_data = pd.Series(scaled_data.flatten())
return super().generate(scaled_data, **kwargs)
except Exception as e:
- raise InferenceModelInternalError(str(e))
+ raise InferenceModelInternalException(str(e))
# Model factory mapping
@@ -176,5 +176,5 @@ def create_sktime_model(model_id: str, **kwargs) ->
SktimeModel:
attributes = update_attribute({**kwargs}, get_attributes(model_id.upper()))
model_class = _MODEL_FACTORY.get(model_id.upper())
if model_class is None:
- raise BuiltInModelNotSupportError(model_id)
+ raise BuiltInModelNotSupportException(model_id)
return model_class(attributes)
diff --git
a/iotdb-core/ainode/iotdb/ainode/core/model/sundial/pipeline_sundial.py
b/iotdb-core/ainode/iotdb/ainode/core/model/sundial/pipeline_sundial.py
index 85b6f7db2ff..ee128802d24 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/model/sundial/pipeline_sundial.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/model/sundial/pipeline_sundial.py
@@ -18,7 +18,7 @@
import torch
-from iotdb.ainode.core.exception import InferenceModelInternalError
+from iotdb.ainode.core.exception import InferenceModelInternalException
from iotdb.ainode.core.inference.pipeline.basic_pipeline import
ForecastPipeline
@@ -28,7 +28,7 @@ class SundialPipeline(ForecastPipeline):
def _preprocess(self, inputs):
if len(inputs.shape) != 2:
- raise InferenceModelInternalError(
+ raise InferenceModelInternalException(
f"[Inference] Input shape must be: [batch_size, seq_len], but
receives {inputs.shape}"
)
return inputs
diff --git
a/iotdb-core/ainode/iotdb/ainode/core/model/timer_xl/pipeline_timer.py
b/iotdb-core/ainode/iotdb/ainode/core/model/timer_xl/pipeline_timer.py
index c0f00b1f5ca..65c6cdd74cd 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/model/timer_xl/pipeline_timer.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/model/timer_xl/pipeline_timer.py
@@ -18,7 +18,7 @@
import torch
-from iotdb.ainode.core.exception import InferenceModelInternalError
+from iotdb.ainode.core.exception import InferenceModelInternalException
from iotdb.ainode.core.inference.pipeline.basic_pipeline import
ForecastPipeline
@@ -28,7 +28,7 @@ class TimerPipeline(ForecastPipeline):
def _preprocess(self, inputs):
if len(inputs.shape) != 2:
- raise InferenceModelInternalError(
+ raise InferenceModelInternalException(
f"[Inference] Input shape must be: [batch_size, seq_len], but
receives {inputs.shape}"
)
return inputs
diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/utils.py
b/iotdb-core/ainode/iotdb/ainode/core/model/utils.py
index 1cd0ee44912..815232c52b0 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/model/utils.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/model/utils.py
@@ -19,25 +19,31 @@
import importlib
import json
import os.path
+import shutil
import sys
from contextlib import contextmanager
+from pathlib import Path
from typing import Dict, Tuple
+from huggingface_hub import snapshot_download
+
+from iotdb.ainode.core.exception import InvalidModelUriException
+from iotdb.ainode.core.log import Logger
from iotdb.ainode.core.model.model_constants import (
MODEL_CONFIG_FILE_IN_JSON,
MODEL_WEIGHTS_FILE_IN_SAFETENSORS,
UriType,
)
+logger = Logger()
+
def parse_uri_type(uri: str) -> UriType:
- if uri.startswith("repo://"):
- return UriType.REPO
- elif uri.startswith("file://"):
+ if uri.startswith("file://"):
return UriType.FILE
else:
- raise ValueError(
- f"Unsupported URI type: {uri}. Supported formats: repo:// or
file://"
+ raise InvalidModelUriException(
+ f"Unknown uri type {uri}, currently supporting formats: file://"
)
@@ -70,9 +76,13 @@ def validate_model_files(model_dir: str) -> Tuple[str, str]:
weights_path = os.path.join(model_dir, MODEL_WEIGHTS_FILE_IN_SAFETENSORS)
if not os.path.exists(config_path):
- raise ValueError(f"Model config file does not exist: {config_path}")
+ raise InvalidModelUriException(
+ f"Model config file does not exist: {config_path}"
+ )
if not os.path.exists(weights_path):
- raise ValueError(f"Model weights file does not exist: {weights_path}")
+ raise InvalidModelUriException(
+ f"Model weights file does not exist: {weights_path}"
+ )
# Create __init__.py file to ensure model directory can be imported as a
module
init_file = os.path.join(model_dir, "__init__.py")
@@ -96,3 +106,32 @@ def ensure_init_file(dir_path: str):
if not os.path.exists(init_file):
with open(init_file, "w"):
pass
+
+
+def _fetch_model_from_local(source_path: str, storage_path: str):
+ logger.info(f"Copying model from local path: {source_path} ->
{storage_path}")
+ source_dir = Path(source_path)
+ if not source_dir.exists():
+ raise InvalidModelUriException(f"Source path does not exist:
{source_path}")
+ if not source_dir.is_dir():
+ raise InvalidModelUriException(f"Source path is not a directory:
{source_path}")
+ storage_dir = Path(storage_path)
+ for file in source_dir.iterdir():
+ if file.is_file():
+ shutil.copy2(file, storage_dir / file.name)
+
+
+def _fetch_model_from_hf_repo(repo_id: str, storage_path: str):
+ logger.info(
+ f"Downloading model from HuggingFace repository: {repo_id} ->
{storage_path}"
+ )
+ # Use snapshot_download to download entire repository (including
config.json and model.safetensors)
+ try:
+ snapshot_download(
+ repo_id=repo_id,
+ local_dir=storage_path,
+ local_dir_use_symlinks=False,
+ )
+ except Exception as e:
+ logger.error(f"Failed to download model from HuggingFace: {e}")
+ raise
diff --git a/iotdb-core/ainode/iotdb/ainode/core/rpc/handler.py
b/iotdb-core/ainode/iotdb/ainode/core/rpc/handler.py
index 6c4eedeb99f..492802fc060 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/rpc/handler.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/rpc/handler.py
@@ -40,7 +40,7 @@ from iotdb.thrift.ainode.ttypes import (
TShowLoadedModelsResp,
TShowModelsReq,
TShowModelsResp,
- TTrainingReq,
+ TTuningReq,
TUnloadModelReq,
)
from iotdb.thrift.common.ttypes import TSStatus
@@ -56,8 +56,8 @@ def _ensure_device_id_is_available(device_id_list: list[str])
-> TSStatus:
for device_id in device_id_list:
if device_id not in available_devices:
return TSStatus(
- code=TSStatusCode.INVALID_URI_ERROR.value,
- message=f"Device ID [{device_id}] is not available. You can
use 'SHOW AI_DEVICES' to retrieve the available devices.",
+ code=TSStatusCode.UNAVAILABLE_AI_DEVICE_ERROR.value,
+ message=f"AIDevice ID [{device_id}] is not available. You can
use 'SHOW AI_DEVICES' to retrieve the available devices.",
)
return TSStatus(code=TSStatusCode.SUCCESS_STATUS.value)
@@ -68,7 +68,9 @@ class AINodeRPCServiceHandler(IAINodeRPCService.Iface):
self._model_manager = ModelManager()
self._inference_manager = InferenceManager()
- def stop(self) -> None:
+ # ==================== Cluster Management ====================
+
+ def stop(self):
logger.info("Stopping the RPC service handler of IoTDB-AINode...")
self._inference_manager.stop()
@@ -76,6 +78,17 @@ class AINodeRPCServiceHandler(IAINodeRPCService.Iface):
self._ainode.stop()
return get_status(TSStatusCode.SUCCESS_STATUS, "AINode stopped
successfully.")
+ def getAIHeartbeat(self, req: TAIHeartbeatReq) -> TAIHeartbeatResp:
+ return ClusterManager.get_heart_beat(req)
+
+ def showAIDevices(self) -> TShowAIDevicesResp:
+ return TShowAIDevicesResp(
+ status=TSStatus(code=TSStatusCode.SUCCESS_STATUS.value),
+ deviceIdList=get_available_devices(),
+ )
+
+ # ==================== Model Management ====================
+
def registerModel(self, req: TRegisterModelReq) -> TRegisterModelResp:
return self._model_manager.register_model(req)
@@ -109,11 +122,15 @@ class AINodeRPCServiceHandler(IAINodeRPCService.Iface):
return TShowLoadedModelsResp(status=status,
deviceLoadedModelsMap={})
return self._inference_manager.show_loaded_models(req)
- def showAIDevices(self) -> TShowAIDevicesResp:
- return TShowAIDevicesResp(
- status=TSStatus(code=TSStatusCode.SUCCESS_STATUS.value),
- deviceIdList=get_available_devices(),
- )
+ def _ensure_model_is_registered(self, model_id: str) -> TSStatus:
+ if not self._model_manager.is_model_registered(model_id):
+ return TSStatus(
+ code=TSStatusCode.MODEL_NOT_EXIST_ERROR.value,
+ message=f"Model [{model_id}] is not registered yet. You can
use 'SHOW MODELS' to retrieve the available models.",
+ )
+ return TSStatus(code=TSStatusCode.SUCCESS_STATUS.value)
+
+ # ==================== Inference ====================
def inference(self, req: TInferenceReq) -> TInferenceResp:
status = self._ensure_model_is_registered(req.modelId)
@@ -127,16 +144,7 @@ class AINodeRPCServiceHandler(IAINodeRPCService.Iface):
return TForecastResp(status, [])
return self._inference_manager.forecast(req)
- def getAIHeartbeat(self, req: TAIHeartbeatReq) -> TAIHeartbeatResp:
- return ClusterManager.get_heart_beat(req)
+ # ==================== Tuning ====================
- def createTrainingTask(self, req: TTrainingReq) -> TSStatus:
+ def createTuningTask(self, req: TTuningReq) -> TSStatus:
pass
-
- def _ensure_model_is_registered(self, model_id: str) -> TSStatus:
- if not self._model_manager.is_model_registered(model_id):
- return TSStatus(
- code=TSStatusCode.MODEL_NOT_FOUND_ERROR.value,
- message=f"Model [{model_id}] is not registered yet. You can
use 'SHOW MODELS' to retrieve the available models.",
- )
- return TSStatus(code=TSStatusCode.SUCCESS_STATUS.value)
diff --git a/iotdb-core/ainode/iotdb/ainode/core/util/serde.py
b/iotdb-core/ainode/iotdb/ainode/core/util/serde.py
index f8188209a37..9c6020019fc 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/util/serde.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/util/serde.py
@@ -21,7 +21,7 @@ from enum import Enum
import numpy as np
import pandas as pd
-from iotdb.ainode.core.exception import BadConfigValueError
+from iotdb.ainode.core.exception import BadConfigValueException
class TSDataType(Enum):
@@ -122,7 +122,7 @@ def _get_type_in_byte(data_type: pd.Series):
elif data_type == "text":
return b"\x05"
else:
- raise BadConfigValueError(
+ raise BadConfigValueException(
"data_type",
data_type,
"data_type should be in ['bool', 'int32', 'int64', 'float32',
'float64', 'text']",
@@ -138,7 +138,7 @@ def get_data_type_byte_from_str(value):
byte: corresponding data type in [b'\x00', b'\x01', b'\x02', b'\x03',
b'\x04', b'\x05']
"""
if value not in ["bool", "int32", "int64", "float32", "float64", "text"]:
- raise BadConfigValueError(
+ raise BadConfigValueException(
"data_type",
value,
"data_type should be in ['bool', 'int32', 'int64', 'float32',
'float64', 'text']",
diff --git
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/exception/ainode/GetModelInfoException.java
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/exception/ainode/GetModelInfoException.java
index 03402d30c64..b8d98c9c1d7 100644
---
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/exception/ainode/GetModelInfoException.java
+++
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/exception/ainode/GetModelInfoException.java
@@ -23,6 +23,6 @@ import org.apache.iotdb.rpc.TSStatusCode;
public class GetModelInfoException extends ModelException {
public GetModelInfoException(String message) {
- super(message, TSStatusCode.GET_MODEL_INFO_ERROR);
+ super(message, TSStatusCode.AINODE_INTERNAL_ERROR);
}
}
diff --git
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/exception/ainode/ModelNotFoundException.java
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/exception/ainode/ModelNotFoundException.java
deleted file mode 100644
index 38a5105cded..00000000000
---
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/exception/ainode/ModelNotFoundException.java
+++ /dev/null
@@ -1,28 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.iotdb.db.exception.ainode;
-
-import org.apache.iotdb.rpc.TSStatusCode;
-
-public class ModelNotFoundException extends ModelException {
- public ModelNotFoundException(String message) {
- super(message, TSStatusCode.MODEL_NOT_FOUND_ERROR);
- }
-}
diff --git
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/client/an/AINodeClient.java
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/client/an/AINodeClient.java
index 5eaffc40af9..ffad889a223 100644
---
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/client/an/AINodeClient.java
+++
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/client/an/AINodeClient.java
@@ -35,7 +35,7 @@ import
org.apache.iotdb.ainode.rpc.thrift.TShowLoadedModelsReq;
import org.apache.iotdb.ainode.rpc.thrift.TShowLoadedModelsResp;
import org.apache.iotdb.ainode.rpc.thrift.TShowModelsReq;
import org.apache.iotdb.ainode.rpc.thrift.TShowModelsResp;
-import org.apache.iotdb.ainode.rpc.thrift.TTrainingReq;
+import org.apache.iotdb.ainode.rpc.thrift.TTuningReq;
import org.apache.iotdb.ainode.rpc.thrift.TUnloadModelReq;
import org.apache.iotdb.common.rpc.thrift.TAINodeLocation;
import org.apache.iotdb.common.rpc.thrift.TEndPoint;
@@ -129,8 +129,8 @@ public class AINodeClient implements
IAINodeRPCService.Iface, AutoCloseable, Thr
}
@Override
- public TSStatus createTrainingTask(TTrainingReq req) throws TException {
- return executeRemoteCallWithRetry(() -> client.createTrainingTask(req));
+ public TSStatus createTuningTask(TTuningReq req) throws TException {
+ return executeRemoteCallWithRetry(() -> client.createTuningTask(req));
}
@Override
diff --git
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/TableConfigTaskVisitor.java
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/TableConfigTaskVisitor.java
index 847b5d62880..bc52a3a7e8c 100644
---
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/TableConfigTaskVisitor.java
+++
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/TableConfigTaskVisitor.java
@@ -62,7 +62,7 @@ import
org.apache.iotdb.db.queryengine.plan.execution.config.metadata.ShowPipePl
import
org.apache.iotdb.db.queryengine.plan.execution.config.metadata.ShowRegionTask;
import
org.apache.iotdb.db.queryengine.plan.execution.config.metadata.ShowVariablesTask;
import
org.apache.iotdb.db.queryengine.plan.execution.config.metadata.ai.CreateModelTask;
-import
org.apache.iotdb.db.queryengine.plan.execution.config.metadata.ai.CreateTrainingTask;
+import
org.apache.iotdb.db.queryengine.plan.execution.config.metadata.ai.CreateTuningTask;
import
org.apache.iotdb.db.queryengine.plan.execution.config.metadata.ai.DropModelTask;
import
org.apache.iotdb.db.queryengine.plan.execution.config.metadata.ai.LoadModelTask;
import
org.apache.iotdb.db.queryengine.plan.execution.config.metadata.ai.ShowAIDevicesTask;
@@ -1520,7 +1520,7 @@ public class TableConfigTaskVisitor extends
AstVisitor<IConfigTask, MPPQueryCont
protected IConfigTask visitCreateTraining(CreateTraining node,
MPPQueryContext context) {
context.setQueryType(QueryType.WRITE);
accessControl.checkUserGlobalSysPrivilege(context);
- return new CreateTrainingTask(
+ return new CreateTuningTask(
node.getModelId(), node.getParameters(), node.getExistingModelId(),
node.getTargetSql());
}
diff --git
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/TreeConfigTaskVisitor.java
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/TreeConfigTaskVisitor.java
index cb4d05f1b99..ed5e2e434f4 100644
---
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/TreeConfigTaskVisitor.java
+++
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/TreeConfigTaskVisitor.java
@@ -68,7 +68,7 @@ import
org.apache.iotdb.db.queryengine.plan.execution.config.metadata.ShowTrigge
import
org.apache.iotdb.db.queryengine.plan.execution.config.metadata.ShowVariablesTask;
import
org.apache.iotdb.db.queryengine.plan.execution.config.metadata.UnSetTTLTask;
import
org.apache.iotdb.db.queryengine.plan.execution.config.metadata.ai.CreateModelTask;
-import
org.apache.iotdb.db.queryengine.plan.execution.config.metadata.ai.CreateTrainingTask;
+import
org.apache.iotdb.db.queryengine.plan.execution.config.metadata.ai.CreateTuningTask;
import
org.apache.iotdb.db.queryengine.plan.execution.config.metadata.ai.DropModelTask;
import
org.apache.iotdb.db.queryengine.plan.execution.config.metadata.ai.LoadModelTask;
import
org.apache.iotdb.db.queryengine.plan.execution.config.metadata.ai.ShowAIDevicesTask;
@@ -926,7 +926,7 @@ public class TreeConfigTaskVisitor extends
StatementVisitor<IConfigTask, MPPQuer
for (PartialPath partialPath : partialPathList) {
targetPathPatterns.add(partialPath.getFullPath());
}
- return new CreateTrainingTask(
+ return new CreateTuningTask(
createTrainingStatement.getModelId(),
createTrainingStatement.getParameters(),
createTrainingStatement.getTargetTimeRanges(),
diff --git
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/executor/ClusterConfigTaskExecutor.java
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/executor/ClusterConfigTaskExecutor.java
index 01f6757f02e..4d99a3ed892 100644
---
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/executor/ClusterConfigTaskExecutor.java
+++
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/executor/ClusterConfigTaskExecutor.java
@@ -28,7 +28,7 @@ import
org.apache.iotdb.ainode.rpc.thrift.TShowLoadedModelsReq;
import org.apache.iotdb.ainode.rpc.thrift.TShowLoadedModelsResp;
import org.apache.iotdb.ainode.rpc.thrift.TShowModelsReq;
import org.apache.iotdb.ainode.rpc.thrift.TShowModelsResp;
-import org.apache.iotdb.ainode.rpc.thrift.TTrainingReq;
+import org.apache.iotdb.ainode.rpc.thrift.TTuningReq;
import org.apache.iotdb.ainode.rpc.thrift.TUnloadModelReq;
import org.apache.iotdb.common.rpc.thrift.FunctionType;
import org.apache.iotdb.common.rpc.thrift.Model;
@@ -3727,7 +3727,7 @@ public class ClusterConfigTaskExecutor implements
IConfigTaskExecutor {
}
@Override
- public SettableFuture<ConfigTaskResult> createTraining(
+ public SettableFuture<ConfigTaskResult> createTuningTask(
String modelId,
boolean isTableModel,
Map<String, String> parameters,
@@ -3736,9 +3736,9 @@ public class ClusterConfigTaskExecutor implements
IConfigTaskExecutor {
@Nullable String targetSql,
@Nullable List<String> pathList) {
final SettableFuture<ConfigTaskResult> future = SettableFuture.create();
- try (final AINodeClient ai =
+ try (final AINodeClient aiNodeClient =
AINodeClientManager.getInstance().borrowClient(AINodeClientManager.AINODE_ID_PLACEHOLDER))
{
- final TTrainingReq req = new TTrainingReq();
+ final TTuningReq req = new TTuningReq();
req.setModelId(modelId);
req.setParameters(parameters);
if (existingModelId != null) {
@@ -3747,7 +3747,7 @@ public class ClusterConfigTaskExecutor implements
IConfigTaskExecutor {
if (existingModelId != null) {
req.setExistingModelId(existingModelId);
}
- final TSStatus status = ai.createTrainingTask(req);
+ final TSStatus status = aiNodeClient.createTuningTask(req);
if (TSStatusCode.SUCCESS_STATUS.getStatusCode() != status.getCode()) {
future.setException(new IoTDBException(status));
} else {
diff --git
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/executor/IConfigTaskExecutor.java
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/executor/IConfigTaskExecutor.java
index 85455756b05..2d7c0f6d1f3 100644
---
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/executor/IConfigTaskExecutor.java
+++
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/executor/IConfigTaskExecutor.java
@@ -443,7 +443,7 @@ public interface IConfigTaskExecutor {
SettableFuture<ConfigTaskResult> unloadModel(String existingModelId,
List<String> deviceIdList);
- SettableFuture<ConfigTaskResult> createTraining(
+ SettableFuture<ConfigTaskResult> createTuningTask(
String modelId,
boolean isTableModel,
Map<String, String> parameters,
diff --git
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/metadata/ai/CreateTrainingTask.java
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/metadata/ai/CreateTuningTask.java
similarity index 93%
rename from
iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/metadata/ai/CreateTrainingTask.java
rename to
iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/metadata/ai/CreateTuningTask.java
index 9c93c5b7577..a66b7e700b7 100644
---
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/metadata/ai/CreateTrainingTask.java
+++
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/metadata/ai/CreateTuningTask.java
@@ -28,7 +28,7 @@ import com.google.common.util.concurrent.ListenableFuture;
import java.util.List;
import java.util.Map;
-public class CreateTrainingTask implements IConfigTask {
+public class CreateTuningTask implements IConfigTask {
private final String modelId;
private final boolean isTableModel;
@@ -43,7 +43,7 @@ public class CreateTrainingTask implements IConfigTask {
private List<List<Long>> timeRanges;
// For table model
- public CreateTrainingTask(
+ public CreateTuningTask(
String modelId, Map<String, String> parameters, String existingModelId,
String targetSql) {
this.modelId = modelId;
this.parameters = parameters;
@@ -53,7 +53,7 @@ public class CreateTrainingTask implements IConfigTask {
}
// For tree model
- public CreateTrainingTask(
+ public CreateTuningTask(
String modelId,
Map<String, String> parameters,
List<List<Long>> timeRanges,
@@ -71,7 +71,7 @@ public class CreateTrainingTask implements IConfigTask {
@Override
public ListenableFuture<ConfigTaskResult> execute(IConfigTaskExecutor
configTaskExecutor)
throws InterruptedException {
- return configTaskExecutor.createTraining(
+ return configTaskExecutor.createTuningTask(
modelId, isTableModel, parameters, timeRanges, existingModelId,
targetSql, targetPaths);
}
}
diff --git a/iotdb-protocol/thrift-ainode/src/main/thrift/ainode.thrift
b/iotdb-protocol/thrift-ainode/src/main/thrift/ainode.thrift
index 1dc2f025f5c..cda356a948e 100644
--- a/iotdb-protocol/thrift-ainode/src/main/thrift/ainode.thrift
+++ b/iotdb-protocol/thrift-ainode/src/main/thrift/ainode.thrift
@@ -73,7 +73,7 @@ struct IDataSchema {
2: optional list<i64> timeRange
}
-struct TTrainingReq {
+struct TTuningReq {
1: required string dbType
2: required string modelId
3: required string existingModelId
@@ -131,30 +131,27 @@ struct TUnloadModelReq {
service IAINodeRPCService {
- // -------------- For Config Node --------------
common.TSStatus stopAINode()
+ TAIHeartbeatResp getAIHeartbeat(TAIHeartbeatReq req)
+
+ TShowAIDevicesResp showAIDevices()
+
TShowModelsResp showModels(TShowModelsReq req)
TShowLoadedModelsResp showLoadedModels(TShowLoadedModelsReq req)
- TShowAIDevicesResp showAIDevices()
-
common.TSStatus deleteModel(TDeleteModelReq req)
TRegisterModelResp registerModel(TRegisterModelReq req)
- TAIHeartbeatResp getAIHeartbeat(TAIHeartbeatReq req)
-
- common.TSStatus createTrainingTask(TTrainingReq req)
-
common.TSStatus loadModel(TLoadModelReq req)
common.TSStatus unloadModel(TUnloadModelReq req)
- // -------------- For Data Node --------------
-
TInferenceResp inference(TInferenceReq req)
TForecastResp forecast(TForecastReq req)
+
+ common.TSStatus createTuningTask(TTuningReq req)
}
\ No newline at end of file