This is an automated email from the ASF dual-hosted git repository. yongzao pushed a commit to branch inference_max_length in repository https://gitbox.apache.org/repos/asf/iotdb.git
commit fa8ccb17dbe144d300cc54f05ff8f3d3b1800afa Author: Yongzao <[email protected]> AuthorDate: Mon Jul 21 14:45:43 2025 +0800 finish --- iotdb-core/ainode/ainode/core/ainode.py | 2 +- iotdb-core/ainode/ainode/core/config.py | 12 ++++++++++++ iotdb-core/ainode/ainode/core/constant.py | 1 + .../ainode/core/inference/inference_request_pool.py | 5 +++-- iotdb-core/ainode/ainode/core/inference/utils.py | 2 +- .../ainode/ainode/core/manager/inference_manager.py | 21 ++++++++++++++++++--- iotdb-core/ainode/ainode/core/rpc/handler.py | 6 +++--- iotdb-core/ainode/poetry.lock | 10 +++++----- 8 files changed, 44 insertions(+), 15 deletions(-) diff --git a/iotdb-core/ainode/ainode/core/ainode.py b/iotdb-core/ainode/ainode/core/ainode.py index 6380094b98a..8888c3cb853 100644 --- a/iotdb-core/ainode/ainode/core/ainode.py +++ b/iotdb-core/ainode/ainode/core/ainode.py @@ -136,7 +136,7 @@ class AINode: raise e # Start the RPC service - self._rpc_handler = AINodeRPCServiceHandler(aiNode=self) + self._rpc_handler = AINodeRPCServiceHandler(ainode=self) self._rpc_service = AINodeRPCService(self._rpc_handler) self._rpc_service.start() self._rpc_service.join(1) diff --git a/iotdb-core/ainode/ainode/core/config.py b/iotdb-core/ainode/ainode/core/config.py index 6f0336ad6ae..2c73794faaa 100644 --- a/iotdb-core/ainode/ainode/core/config.py +++ b/iotdb-core/ainode/ainode/core/config.py @@ -31,6 +31,7 @@ from ainode.core.constant import ( AINODE_CONF_GIT_FILE_NAME, AINODE_CONF_POM_FILE_NAME, AINODE_INFERENCE_BATCH_INTERVAL_IN_MS, + AINODE_INFERENCE_MAX_PREDICT_LENGTH, AINODE_INFERENCE_RPC_ADDRESS, AINODE_INFERENCE_RPC_PORT, AINODE_LOG_DIR, @@ -59,6 +60,9 @@ class AINodeConfig(object): self._ain_inference_batch_interval_in_ms: int = ( AINODE_INFERENCE_BATCH_INTERVAL_IN_MS ) + self._ain_inference_max_predict_length: int = ( + AINODE_INFERENCE_MAX_PREDICT_LENGTH + ) # log directory self._ain_logs_dir: str = AINODE_LOG_DIR @@ -144,6 +148,14 @@ class AINodeConfig(object): ) -> None: self._ain_inference_batch_interval_in_ms = ain_inference_batch_interval_in_ms + def get_ain_inference_max_predict_length(self) -> int: + return self._ain_inference_max_predict_length + + def set_ain_inference_max_predict_length( + self, ain_inference_max_predict_length: int + ) -> None: + self._ain_inference_max_predict_length = ain_inference_max_predict_length + def get_ain_logs_dir(self) -> str: return self._ain_logs_dir diff --git a/iotdb-core/ainode/ainode/core/constant.py b/iotdb-core/ainode/ainode/core/constant.py index 1521ca52fe2..8d3bd6b9592 100644 --- a/iotdb-core/ainode/ainode/core/constant.py +++ b/iotdb-core/ainode/ainode/core/constant.py @@ -34,6 +34,7 @@ AINODE_SYSTEM_FILE_NAME = "system.properties" AINODE_INFERENCE_RPC_ADDRESS = "127.0.0.1" AINODE_INFERENCE_RPC_PORT = 10810 AINODE_INFERENCE_BATCH_INTERVAL_IN_MS = 15 +AINODE_INFERENCE_MAX_PREDICT_LENGTH = 2880 # AINode folder structure AINODE_MODELS_DIR = "data/ainode/models" diff --git a/iotdb-core/ainode/ainode/core/inference/inference_request_pool.py b/iotdb-core/ainode/ainode/core/inference/inference_request_pool.py index ce40bd16859..d0cac2760ab 100644 --- a/iotdb-core/ainode/ainode/core/inference/inference_request_pool.py +++ b/iotdb-core/ainode/ainode/core/inference/inference_request_pool.py @@ -23,7 +23,7 @@ import time import numpy as np import torch import torch.multiprocessing as mp -from transformers import PretrainedConfig, PreTrainedModel +from transformers import PretrainedConfig from ainode.core.config import AINodeDescriptor from ainode.core.inference.inference_request import InferenceRequest @@ -46,7 +46,7 @@ class InferenceRequestPool(mp.Process): def __init__( self, pool_id: int, - model_id: int, + model_id: str, config: PretrainedConfig, request_queue: mp.Queue, result_queue: mp.Queue, @@ -58,6 +58,7 @@ class InferenceRequestPool(mp.Process): self.config = config self.pool_kwargs = pool_kwargs self.model = None + self._model_manager = None self.device = None # TODO: A scheduler is necessary for better handling following queues diff --git a/iotdb-core/ainode/ainode/core/inference/utils.py b/iotdb-core/ainode/ainode/core/inference/utils.py index c2a618d716c..cf10b5b2cd4 100644 --- a/iotdb-core/ainode/ainode/core/inference/utils.py +++ b/iotdb-core/ainode/ainode/core/inference/utils.py @@ -22,7 +22,7 @@ import torch from transformers.modeling_outputs import MoeCausalLMOutputWithPast -def _generate_req_id(length=10, charset=string.ascii_letters + string.digits) -> str: +def generate_req_id(length=10, charset=string.ascii_letters + string.digits) -> str: """ Generate a random req_id string of specified length. The length is 10 by default, with 10^{17} possible combinations. diff --git a/iotdb-core/ainode/ainode/core/manager/inference_manager.py b/iotdb-core/ainode/ainode/core/manager/inference_manager.py index bed39c48d08..5a853ac4e72 100644 --- a/iotdb-core/ainode/ainode/core/manager/inference_manager.py +++ b/iotdb-core/ainode/ainode/core/manager/inference_manager.py @@ -30,6 +30,7 @@ from ainode.core.constant import TSStatusCode from ainode.core.exception import ( InferenceModelInternalError, InvalidWindowArgumentError, + NumericalRangeException, runtime_error_extractor, ) from ainode.core.inference.inference_request import ( @@ -40,7 +41,7 @@ from ainode.core.inference.inference_request_pool import InferenceRequestPool from ainode.core.inference.strategy.timer_sundial_inference_pipeline import ( TimerSundialInferencePipeline, ) -from ainode.core.inference.utils import _generate_req_id +from ainode.core.inference.utils import generate_req_id from ainode.core.log import Logger from ainode.core.manager.model_manager import ModelManager from ainode.core.model.sundial.configuration_sundial import SundialConfig @@ -214,6 +215,20 @@ class InferenceManager: full_data = deserializer(raw) inference_attrs = extract_attrs(req) + predict_length = inference_attrs.get("predict_length", 96) + if ( + predict_length + > AINodeDescriptor().get_config().get_ain_inference_max_predict_length() + ): + raise NumericalRangeException( + "output_length", + 1, + AINodeDescriptor() + .get_config() + .get_ain_inference_max_predict_length(), + predict_length, + ) + if model_id == self.ACCELERATE_MODEL_ID and self.DEFAULT_POOL_SIZE > 0: # TODO: Logic in this branch shall handle all LTSM inferences # TODO: TSBlock -> Tensor codes should be unified @@ -223,10 +238,10 @@ class InferenceManager: # the inputs should be on CPU before passing to the inference request inputs = torch.tensor(data).unsqueeze(0).float().to("cpu") infer_req = InferenceRequest( - req_id=_generate_req_id(), + req_id=generate_req_id(), inputs=inputs, inference_pipeline=TimerSundialInferencePipeline(SundialConfig()), - max_new_tokens=inference_attrs.get("predict_length", 96), + max_new_tokens=predict_length, ) infer_proxy = InferenceRequestProxy(infer_req.req_id) with self._result_wrapper_lock: diff --git a/iotdb-core/ainode/ainode/core/rpc/handler.py b/iotdb-core/ainode/ainode/core/rpc/handler.py index cb25420ae00..d3948020ab3 100644 --- a/iotdb-core/ainode/ainode/core/rpc/handler.py +++ b/iotdb-core/ainode/ainode/core/rpc/handler.py @@ -41,13 +41,13 @@ logger = Logger() class AINodeRPCServiceHandler(IAINodeRPCService.Iface): - def __init__(self, aiNode): - self._aiNode = aiNode + def __init__(self, ainode): + self._ainode = ainode self._model_manager = ModelManager() self._inference_manager = InferenceManager() def stopAINode(self) -> TSStatus: - self._aiNode.stop() + self._ainode.stop() return get_status(TSStatusCode.SUCCESS_STATUS, "AINode stopped successfully.") def registerModel(self, req: TRegisterModelReq) -> TRegisterModelResp: diff --git a/iotdb-core/ainode/poetry.lock b/iotdb-core/ainode/poetry.lock index 7d8b035e94f..988df41bd80 100644 --- a/iotdb-core/ainode/poetry.lock +++ b/iotdb-core/ainode/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 2.1.2 and should not be changed by hand. +# This file is automatically @generated by Poetry 2.1.3 and should not be changed by hand. [[package]] name = "alembic" @@ -809,7 +809,7 @@ description = "Fundamental package for array computing in Python" optional = false python-versions = ">=3.8" groups = ["main"] -markers = "python_version <= \"3.11\"" +markers = "python_version < \"3.11\"" files = [ {file = "numpy-1.24.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:c0bfb52d2169d58c1cdb8cc1f16989101639b34c7d3ce60ed70b19c63eba0b64"}, {file = "numpy-1.24.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:ed094d4f0c177b1b8e7aa9cba7d6ceed51c0e569a5318ac0ca9a090680a6a1b1"}, @@ -848,7 +848,7 @@ description = "Fundamental package for array computing in Python" optional = false python-versions = ">=3.9" groups = ["main"] -markers = "python_version == \"3.12\"" +markers = "python_version >= \"3.11\"" files = [ {file = "numpy-1.26.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:9ff0f4f29c51e2803569d7a51c2304de5554655a60c5d776e35b4a41413830d0"}, {file = "numpy-1.26.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:2e4ee3380d6de9c9ec04745830fd9e2eccb3e6cf790d39d7b98ffd19b0dd754a"}, @@ -1653,7 +1653,7 @@ description = "Fundamental algorithms for scientific computing in Python" optional = false python-versions = ">=3.8" groups = ["main"] -markers = "python_version <= \"3.11\"" +markers = "python_version < \"3.11\"" files = [ {file = "scipy-1.9.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:1884b66a54887e21addf9c16fb588720a8309a57b2e258ae1c7986d4444d3bc0"}, {file = "scipy-1.9.3-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:83b89e9586c62e787f5012e8475fbb12185bafb996a03257e9675cd73d3736dd"}, @@ -1693,7 +1693,7 @@ description = "Fundamental algorithms for scientific computing in Python" optional = false python-versions = ">=3.11" groups = ["main"] -markers = "python_version == \"3.12\"" +markers = "python_version >= \"3.11\"" files = [ {file = "scipy-1.16.0-cp311-cp311-macosx_10_14_x86_64.whl", hash = "sha256:deec06d831b8f6b5fb0b652433be6a09db29e996368ce5911faf673e78d20085"}, {file = "scipy-1.16.0-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:d30c0fe579bb901c61ab4bb7f3eeb7281f0d4c4a7b52dbf563c89da4fd2949be"},
