This is an automated email from the ASF dual-hosted git repository.
yongzao pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/iotdb.git
The following commit(s) were added to refs/heads/master by this push:
new d8a0a5abda8 [AINode] Optimize model loading (#17046)
d8a0a5abda8 is described below
commit d8a0a5abda8d9e185158ba9e8c0f79b210c934cb
Author: Leo <[email protected]>
AuthorDate: Tue Jan 20 22:27:41 2026 +0800
[AINode] Optimize model loading (#17046)
---
.../ainode/iotdb/ainode/core/inference/pool_controller.py | 11 +++++++++++
.../ainode/iotdb/ainode/core/manager/inference_manager.py | 8 ++------
2 files changed, 13 insertions(+), 6 deletions(-)
diff --git a/iotdb-core/ainode/iotdb/ainode/core/inference/pool_controller.py b/iotdb-core/ainode/iotdb/ainode/core/inference/pool_controller.py
index 1eb07adfde4..5a6db12edde 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/inference/pool_controller.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/inference/pool_controller.py
@@ -443,6 +443,17 @@ class PoolController:
             return device_id in self._request_pool_map[model_id]
         return True
 
+    def has_running_pools(self, model_id: str) -> bool:
+        """
+        Check if there are running pools for the given model_id.
+        """
+        if model_id not in self._request_pool_map:
+            return False
+        for device_id, pool_group in self._request_pool_map[model_id].items():
+            if pool_group.get_running_pool_count():
+                return True
+        return False
+
     def get_request_pools_group(
         self, model_id: str, device_id: torch.device
     ) -> Optional[PoolGroup]:
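
The new has_running_pools() is stricter than the existing has_request_pools(): a model whose pools are merely registered in _request_pool_map (for example, still loading) no longer counts. Below is a minimal runnable sketch of that distinction, using a stand-in PoolGroup that exposes only the get_running_pool_count() accessor the patch relies on; the model id and the string device keys are illustrative, not taken from the real code.

    class PoolGroup:
        """Stand-in for the real PoolGroup; only the accessor used above."""

        def __init__(self, running_pool_count: int = 0):
            self._running_pool_count = running_pool_count

        def get_running_pool_count(self) -> int:
            return self._running_pool_count


    class PoolControllerSketch:
        def __init__(self):
            # Mirrors _request_pool_map: model_id -> {device_id -> PoolGroup}
            self._request_pool_map = {}

        def has_running_pools(self, model_id: str) -> bool:
            # Same logic as the added method: True only if at least one pool
            # for the model is actually running, not merely registered.
            if model_id not in self._request_pool_map:
                return False
            for _device_id, pool_group in self._request_pool_map[model_id].items():
                if pool_group.get_running_pool_count():
                    return True
            return False


    controller = PoolControllerSketch()
    controller._request_pool_map["demo_model"] = {"cuda:0": PoolGroup(0)}
    assert not controller.has_running_pools("demo_model")  # registered, none running yet
    controller._request_pool_map["demo_model"]["cuda:0"] = PoolGroup(2)
    assert controller.has_running_pools("demo_model")      # at least one pool is running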
diff --git a/iotdb-core/ainode/iotdb/ainode/core/manager/inference_manager.py b/iotdb-core/ainode/iotdb/ainode/core/manager/inference_manager.py
index ebbb036a9dc..180cc00ff49 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/manager/inference_manager.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/manager/inference_manager.py
@@ -18,7 +18,6 @@
 import threading
 import time
-from typing import Dict
 
 import torch
 import torch.multiprocessing as mp
@@ -69,9 +68,6 @@ class InferenceManager:
     def __init__(self):
         self._model_manager = ModelManager()
         self._backend = DeviceManager()
-        self._model_mem_usage_map: Dict[str, int] = (
-            {}
-        )  # store model memory usage for each model
         self._result_queue = mp.Queue()
         self._result_wrapper_map = {}
         self._result_wrapper_lock = threading.RLock()
@@ -207,14 +203,14 @@ class InferenceManager:
         ):
             raise NumericalRangeException(
                 "output_length",
+                output_length,
                 1,
                 AINodeDescriptor()
                 .get_config()
                 .get_ain_inference_max_output_length(),
-                output_length,
             )
 
-        if self._pool_controller.has_request_pools(model_id=model_id):
+        if self._pool_controller.has_running_pools(model_id):
             infer_req = InferenceRequest(
                 req_id=generate_req_id(),
                 model_id=model_id,
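
Two call-site changes land in this hunk: the NumericalRangeException arguments are reordered so the offending value follows the parameter name, and request dispatch now gates on has_running_pools() instead of has_request_pools(). Below is a hedged, self-contained sketch of the reordering; the constructor signature (name, value, min, max) is an assumption inferred from the new call order, and MAX_OUTPUT_LENGTH is a hypothetical stand-in for AINodeDescriptor().get_config().get_ain_inference_max_output_length().

    class NumericalRangeException(Exception):
        """Stand-in; signature assumed from the new call order (name, value, min, max)."""

        def __init__(self, name: str, value: int, min_value: int, max_value: int):
            super().__init__(
                f"{name}={value} is out of the allowed range [{min_value}, {max_value}]"
            )


    # Hypothetical bound standing in for the AINode config lookup.
    MAX_OUTPUT_LENGTH = 2880


    def validate_output_length(output_length: int) -> None:
        if output_length < 1 or output_length > MAX_OUTPUT_LENGTH:
            # The offending value now sits right after the parameter name,
            # matching the reordered call in the patch.
            raise NumericalRangeException("output_length", output_length, 1, MAX_OUTPUT_LENGTH)


    validate_output_length(96)  # in range, returns silently
    try:
        validate_output_length(0)
    except NumericalRangeException as e:
        print(e)  # output_length=0 is out of the allowed range [1, 2880]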