This is an automated email from the ASF dual-hosted git repository.

yongzao pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/iotdb.git


The following commit(s) were added to refs/heads/master by this push:
     new d8a0a5abda8 [AINode] Optimize model loading (#17046)
d8a0a5abda8 is described below

commit d8a0a5abda8d9e185158ba9e8c0f79b210c934cb
Author: Leo <[email protected]>
AuthorDate: Tue Jan 20 22:27:41 2026 +0800

    [AINode] Optimize model loading (#17046)
---
 .../ainode/iotdb/ainode/core/inference/pool_controller.py     | 11 +++++++++++
 .../ainode/iotdb/ainode/core/manager/inference_manager.py     |  8 ++------
 2 files changed, 13 insertions(+), 6 deletions(-)

diff --git a/iotdb-core/ainode/iotdb/ainode/core/inference/pool_controller.py b/iotdb-core/ainode/iotdb/ainode/core/inference/pool_controller.py
index 1eb07adfde4..5a6db12edde 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/inference/pool_controller.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/inference/pool_controller.py
@@ -443,6 +443,17 @@ class PoolController:
             return device_id in self._request_pool_map[model_id]
         return True
 
+    def has_running_pools(self, model_id: str) -> bool:
+        """
+        Return True if at least one pool for the given model_id is currently running.
+        """
+        if model_id not in self._request_pool_map:
+            return False
+        for pool_group in self._request_pool_map[model_id].values():
+            if pool_group.get_running_pool_count() > 0:
+                return True
+        return False
+
     def get_request_pools_group(
         self, model_id: str, device_id: torch.device
     ) -> Optional[PoolGroup]:
diff --git a/iotdb-core/ainode/iotdb/ainode/core/manager/inference_manager.py b/iotdb-core/ainode/iotdb/ainode/core/manager/inference_manager.py
index ebbb036a9dc..180cc00ff49 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/manager/inference_manager.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/manager/inference_manager.py
@@ -18,7 +18,6 @@
 
 import threading
 import time
-from typing import Dict
 
 import torch
 import torch.multiprocessing as mp
@@ -69,9 +68,6 @@ class InferenceManager:
     def __init__(self):
         self._model_manager = ModelManager()
         self._backend = DeviceManager()
-        self._model_mem_usage_map: Dict[str, int] = (
-            {}
-        )  # store model memory usage for each model
         self._result_queue = mp.Queue()
         self._result_wrapper_map = {}
         self._result_wrapper_lock = threading.RLock()
@@ -207,14 +203,14 @@ class InferenceManager:
             ):
                 raise NumericalRangeException(
                     "output_length",
+                    output_length,
                     1,
                     AINodeDescriptor()
                     .get_config()
                     .get_ain_inference_max_output_length(),
-                    output_length,
                 )
 
-            if self._pool_controller.has_request_pools(model_id=model_id):
+            if self._pool_controller.has_running_pools(model_id):
                 infer_req = InferenceRequest(
                     req_id=generate_req_id(),
                     model_id=model_id,
