This is an automated email from the ASF dual-hosted git repository.

yongzao pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/iotdb.git


The following commit(s) were added to refs/heads/master by this push:
     new f1870cb9b06 [AINode] Limit max inference length (#15982)
f1870cb9b06 is described below

commit f1870cb9b06c2fe9cedaa2a3586fd706e256791d
Author: Yongzao <[email protected]>
AuthorDate: Wed Jul 23 12:05:42 2025 +0800

    [AINode] Limit max inference length (#15982)
---
 iotdb-core/ainode/ainode/core/ainode.py             |  2 +-
 iotdb-core/ainode/ainode/core/config.py             | 12 ++++++++++++
 iotdb-core/ainode/ainode/core/constant.py           |  1 +
 .../ainode/core/inference/inference_request_pool.py |  5 +++--
 iotdb-core/ainode/ainode/core/inference/utils.py    |  2 +-
 .../ainode/ainode/core/manager/inference_manager.py | 21 ++++++++++++++++++---
 iotdb-core/ainode/ainode/core/rpc/handler.py        |  6 +++---
 7 files changed, 39 insertions(+), 10 deletions(-)

diff --git a/iotdb-core/ainode/ainode/core/ainode.py b/iotdb-core/ainode/ainode/core/ainode.py
index d18ac21464e..82c6f6988e2 100644
--- a/iotdb-core/ainode/ainode/core/ainode.py
+++ b/iotdb-core/ainode/ainode/core/ainode.py
@@ -134,7 +134,7 @@ class AINode:
                 raise e
 
         # Start the RPC service
-        self._rpc_handler = AINodeRPCServiceHandler(aiNode=self)
+        self._rpc_handler = AINodeRPCServiceHandler(ainode=self)
         self._rpc_service = AINodeRPCService(self._rpc_handler)
         self._rpc_service.start()
         self._rpc_service.join(1)
diff --git a/iotdb-core/ainode/ainode/core/config.py b/iotdb-core/ainode/ainode/core/config.py
index d347379570b..5126d0e53e3 100644
--- a/iotdb-core/ainode/ainode/core/config.py
+++ b/iotdb-core/ainode/ainode/core/config.py
@@ -31,6 +31,7 @@ from ainode.core.constant import (
     AINODE_CONF_GIT_FILE_NAME,
     AINODE_CONF_POM_FILE_NAME,
     AINODE_INFERENCE_BATCH_INTERVAL_IN_MS,
+    AINODE_INFERENCE_MAX_PREDICT_LENGTH,
     AINODE_LOG_DIR,
     AINODE_MODELS_DIR,
     AINODE_ROOT_CONF_DIRECTORY_NAME,
@@ -72,6 +73,9 @@ class AINodeConfig(object):
         self._ain_inference_batch_interval_in_ms: int = (
             AINODE_INFERENCE_BATCH_INTERVAL_IN_MS
         )
+        self._ain_inference_max_predict_length: int = (
+            AINODE_INFERENCE_MAX_PREDICT_LENGTH
+        )
 
         # log directory
         self._ain_logs_dir: str = AINODE_LOG_DIR
@@ -140,6 +144,14 @@ class AINodeConfig(object):
     ) -> None:
         self._ain_inference_batch_interval_in_ms = ain_inference_batch_interval_in_ms
 
+    def get_ain_inference_max_predict_length(self) -> int:
+        return self._ain_inference_max_predict_length
+
+    def set_ain_inference_max_predict_length(
+        self, ain_inference_max_predict_length: int
+    ) -> None:
+        self._ain_inference_max_predict_length = ain_inference_max_predict_length
+
     def get_ain_logs_dir(self) -> str:
         return self._ain_logs_dir
 
diff --git a/iotdb-core/ainode/ainode/core/constant.py b/iotdb-core/ainode/ainode/core/constant.py
index bd414ce1253..bd5646b3513 100644
--- a/iotdb-core/ainode/ainode/core/constant.py
+++ b/iotdb-core/ainode/ainode/core/constant.py
@@ -50,6 +50,7 @@ DEFAULT_RECONNECT_TIMES = 3
 
 # AINode inference configuration
 AINODE_INFERENCE_BATCH_INTERVAL_IN_MS = 15
+AINODE_INFERENCE_MAX_PREDICT_LENGTH = 2880
 
 # AINode folder structure
 AINODE_ROOT_DIR = os.path.dirname(
diff --git a/iotdb-core/ainode/ainode/core/inference/inference_request_pool.py b/iotdb-core/ainode/ainode/core/inference/inference_request_pool.py
index ce40bd16859..d0cac2760ab 100644
--- a/iotdb-core/ainode/ainode/core/inference/inference_request_pool.py
+++ b/iotdb-core/ainode/ainode/core/inference/inference_request_pool.py
@@ -23,7 +23,7 @@ import time
 import numpy as np
 import torch
 import torch.multiprocessing as mp
-from transformers import PretrainedConfig, PreTrainedModel
+from transformers import PretrainedConfig
 
 from ainode.core.config import AINodeDescriptor
 from ainode.core.inference.inference_request import InferenceRequest
@@ -46,7 +46,7 @@ class InferenceRequestPool(mp.Process):
     def __init__(
         self,
         pool_id: int,
-        model_id: int,
+        model_id: str,
         config: PretrainedConfig,
         request_queue: mp.Queue,
         result_queue: mp.Queue,
@@ -58,6 +58,7 @@ class InferenceRequestPool(mp.Process):
         self.config = config
         self.pool_kwargs = pool_kwargs
         self.model = None
+        self._model_manager = None
         self.device = None
 
         # TODO: A scheduler is necessary for better handling following queues
diff --git a/iotdb-core/ainode/ainode/core/inference/utils.py b/iotdb-core/ainode/ainode/core/inference/utils.py
index c2a618d716c..cf10b5b2cd4 100644
--- a/iotdb-core/ainode/ainode/core/inference/utils.py
+++ b/iotdb-core/ainode/ainode/core/inference/utils.py
@@ -22,7 +22,7 @@ import torch
 from transformers.modeling_outputs import MoeCausalLMOutputWithPast
 
 
-def _generate_req_id(length=10, charset=string.ascii_letters + string.digits) -> str:
+def generate_req_id(length=10, charset=string.ascii_letters + string.digits) -> str:
     """
     Generate a random req_id string of specified length.
     The length is 10 by default, with 10^{17} possible combinations.
diff --git a/iotdb-core/ainode/ainode/core/manager/inference_manager.py b/iotdb-core/ainode/ainode/core/manager/inference_manager.py
index bed39c48d08..5a853ac4e72 100644
--- a/iotdb-core/ainode/ainode/core/manager/inference_manager.py
+++ b/iotdb-core/ainode/ainode/core/manager/inference_manager.py
@@ -30,6 +30,7 @@ from ainode.core.constant import TSStatusCode
 from ainode.core.exception import (
     InferenceModelInternalError,
     InvalidWindowArgumentError,
+    NumericalRangeException,
     runtime_error_extractor,
 )
 from ainode.core.inference.inference_request import (
@@ -40,7 +41,7 @@ from ainode.core.inference.inference_request_pool import InferenceRequestPool
 from ainode.core.inference.strategy.timer_sundial_inference_pipeline import (
     TimerSundialInferencePipeline,
 )
-from ainode.core.inference.utils import _generate_req_id
+from ainode.core.inference.utils import generate_req_id
 from ainode.core.log import Logger
 from ainode.core.manager.model_manager import ModelManager
 from ainode.core.model.sundial.configuration_sundial import SundialConfig
@@ -214,6 +215,20 @@ class InferenceManager:
             full_data = deserializer(raw)
             inference_attrs = extract_attrs(req)
 
+            predict_length = inference_attrs.get("predict_length", 96)
+            if (
+                predict_length
+                > AINodeDescriptor().get_config().get_ain_inference_max_predict_length()
+            ):
+                raise NumericalRangeException(
+                    "output_length",
+                    1,
+                    AINodeDescriptor()
+                    .get_config()
+                    .get_ain_inference_max_predict_length(),
+                    predict_length,
+                )
+
             if model_id == self.ACCELERATE_MODEL_ID and self.DEFAULT_POOL_SIZE > 0:
                 # TODO: Logic in this branch shall handle all LTSM inferences
                 # TODO: TSBlock -> Tensor codes should be unified
@@ -223,10 +238,10 @@ class InferenceManager:
                 # the inputs should be on CPU before passing to the inference request
                 inputs = torch.tensor(data).unsqueeze(0).float().to("cpu")
                 infer_req = InferenceRequest(
-                    req_id=_generate_req_id(),
+                    req_id=generate_req_id(),
                     inputs=inputs,
                     inference_pipeline=TimerSundialInferencePipeline(SundialConfig()),
-                    max_new_tokens=inference_attrs.get("predict_length", 96),
+                    max_new_tokens=predict_length,
                 )
                 infer_proxy = InferenceRequestProxy(infer_req.req_id)
                 with self._result_wrapper_lock:
diff --git a/iotdb-core/ainode/ainode/core/rpc/handler.py b/iotdb-core/ainode/ainode/core/rpc/handler.py
index cb25420ae00..d3948020ab3 100644
--- a/iotdb-core/ainode/ainode/core/rpc/handler.py
+++ b/iotdb-core/ainode/ainode/core/rpc/handler.py
@@ -41,13 +41,13 @@ logger = Logger()
 
 
 class AINodeRPCServiceHandler(IAINodeRPCService.Iface):
-    def __init__(self, aiNode):
-        self._aiNode = aiNode
+    def __init__(self, ainode):
+        self._ainode = ainode
         self._model_manager = ModelManager()
         self._inference_manager = InferenceManager()
 
     def stopAINode(self) -> TSStatus:
-        self._aiNode.stop()
+        self._ainode.stop()
         return get_status(TSStatusCode.SUCCESS_STATUS, "AINode stopped successfully.")
 
     def registerModel(self, req: TRegisterModelReq) -> TRegisterModelResp:

Reply via email to