This is an automated email from the ASF dual-hosted git repository.
yongzao pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/iotdb.git
The following commit(s) were added to refs/heads/master by this push:
new d8dfe1b9170 [AINode] Modify model loading (#17299)
d8dfe1b9170 is described below
commit d8dfe1b91704e88619363a154fa69a89dfef872f
Author: Leo <[email protected]>
AuthorDate: Tue Mar 17 13:41:49 2026 +0800
[AINode] Modify model loading (#17299)
---
.../ainode/it/AINodeInstanceManagementIT.java | 8 +-
.../iotdb/ainode/core/inference/pool_controller.py | 110 ++++++++++++++-------
.../pool_scheduler/basic_pool_scheduler.py | 4 +-
3 files changed, 82 insertions(+), 40 deletions(-)
diff --git a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeInstanceManagementIT.java b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeInstanceManagementIT.java
index 15ddce11ede..f8aa27ce688 100644
--- a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeInstanceManagementIT.java
+++ b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeInstanceManagementIT.java
@@ -87,16 +87,16 @@ public class AINodeInstanceManagementIT {
// Load sundial to each device
statement.execute(String.format("LOAD MODEL sundial TO DEVICES '%s'", TARGET_DEVICES));
checkModelOnSpecifiedDevice(statement, "sundial", TARGET_DEVICES.toString());
+ // Unload sundial from each device
+ statement.execute(String.format("UNLOAD MODEL sundial FROM DEVICES '%s'", TARGET_DEVICES));
+ checkModelNotOnSpecifiedDevice(statement, "sundial", TARGET_DEVICES.toString());
// Load timer_xl to each device
statement.execute(String.format("LOAD MODEL timer_xl TO DEVICES '%s'", TARGET_DEVICES));
checkModelOnSpecifiedDevice(statement, "timer_xl", TARGET_DEVICES.toString());
-
- // Clean every device
- statement.execute(String.format("UNLOAD MODEL sundial FROM DEVICES '%s'", TARGET_DEVICES));
+ // Unload timer_xl from each device
statement.execute(String.format("UNLOAD MODEL timer_xl FROM DEVICES '%s'", TARGET_DEVICES));
checkModelNotOnSpecifiedDevice(statement, "timer_xl", TARGET_DEVICES.toString());
- checkModelNotOnSpecifiedDevice(statement, "sundial", TARGET_DEVICES.toString());
}
private static final int LOOP_CNT = 10;
diff --git a/iotdb-core/ainode/iotdb/ainode/core/inference/pool_controller.py b/iotdb-core/ainode/iotdb/ainode/core/inference/pool_controller.py
index f4b3d23d36f..73bf01c3f7a 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/inference/pool_controller.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/inference/pool_controller.py
@@ -141,7 +141,9 @@ class PoolController:
model_id (str): The ID of the model to be loaded.
device_id_list (list[torch.device]): List of device_ids where the
model should be loaded.
"""
-        self._task_queue.put((self._load_model_task, (model_id, device_id_list), {}))
+ self._task_queue.put(
+ (self._load_one_model_task, (model_id, device_id_list), {})
+ )
def unload_model(self, model_id: str, device_id_list: list[torch.device]):
"""
@@ -150,7 +152,9 @@ class PoolController:
model_id (str): The ID of the model to be unloaded.
device_id_list (list[torch.device]): List of device_ids where the
model should be unloaded.
"""
-        self._task_queue.put((self._unload_model_task, (model_id, device_id_list), {}))
+ self._task_queue.put(
+ (self._unload_one_model_task, (model_id, device_id_list), {})
+ )
def show_loaded_models(
self, device_id_list: list[torch.device]
@@ -196,60 +200,92 @@ class PoolController:
finally:
self._task_queue.task_done()
-    def _load_model_task(self, model_id: str, device_id_list: list[torch.device]):
- def _load_model_on_device_task(device_id: torch.device):
- if not self.has_request_pools(model_id, device_id):
- actions = self._pool_scheduler.schedule_load_model_to_device(
- self._model_manager.get_model_info(model_id), device_id
- )
- for action in actions:
- if action.action == ScaleActionType.SCALE_UP:
- self._expand_pools_on_device(
- action.model_id, device_id, action.amount
- )
- elif action.action == ScaleActionType.SCALE_DOWN:
- self._shrink_pools_on_device(
- action.model_id, device_id, action.amount
- )
+    def _load_one_model_task(self, model_id: str, device_id_list: list[torch.device]):
+ def _load_one_model_on_device_task(device_id: torch.device):
+ if not self.has_pool_on_device(device_id):
+ self._expand_pools_on_device(model_id, device_id, 1)
else:
logger.info(
-                    f"[Inference][{device_id}] Model {model_id} is already installed."
+                    f"[Inference][{device_id}] There are already pools on this device."
)
load_model_futures = self._executor.submit_batch(
- device_id_list, _load_model_on_device_task
+ device_id_list, _load_one_model_on_device_task
)
concurrent.futures.wait(
load_model_futures, return_when=concurrent.futures.ALL_COMPLETED
)
-    def _unload_model_task(self, model_id: str, device_id_list: list[torch.device]):
- def _unload_model_on_device_task(device_id: torch.device):
+    def _unload_one_model_task(self, model_id: str, device_id_list: list[torch.device]):
+ def _unload_one_model_on_device_task(device_id: torch.device):
if self.has_request_pools(model_id, device_id):
-                actions = self._pool_scheduler.schedule_unload_model_from_device(
- self._model_manager.get_model_info(model_id), device_id
- )
- for action in actions:
- if action.action == ScaleActionType.SCALE_DOWN:
- self._shrink_pools_on_device(
- action.model_id, device_id, action.amount
- )
- elif action.action == ScaleActionType.SCALE_UP:
- self._expand_pools_on_device(
- action.model_id, device_id, action.amount
- )
+ self._shrink_pools_on_device(model_id, device_id, 1)
else:
logger.info(
                    f"[Inference][{device_id}] Model {model_id} is not installed."
)
unload_model_futures = self._executor.submit_batch(
- device_id_list, _unload_model_on_device_task
+ device_id_list, _unload_one_model_on_device_task
)
concurrent.futures.wait(
unload_model_futures, return_when=concurrent.futures.ALL_COMPLETED
)
+    # def _load_model_task(self, model_id: str, device_id_list: list[torch.device]):
+ # def _load_model_on_device_task(device_id: torch.device):
+ # if not self.has_request_pools(model_id, device_id):
+ # actions = self._pool_scheduler.schedule_load_model_to_device(
+ # self._model_manager.get_model_info(model_id), device_id
+ # )
+ # for action in actions:
+ # if action.action == ScaleActionType.SCALE_UP:
+ # self._expand_pools_on_device(
+ # action.model_id, device_id, action.amount
+ # )
+ # elif action.action == ScaleActionType.SCALE_DOWN:
+ # self._shrink_pools_on_device(
+ # action.model_id, device_id, action.amount
+ # )
+ # else:
+ # logger.info(
+    #                     f"[Inference][{device_id}] Model {model_id} is already installed."
+ # )
+ #
+ # load_model_futures = self._executor.submit_batch(
+ # device_id_list, _load_model_on_device_task
+ # )
+ # concurrent.futures.wait(
+ # load_model_futures, return_when=concurrent.futures.ALL_COMPLETED
+ # )
+ #
+    # def _unload_model_task(self, model_id: str, device_id_list: list[torch.device]):
+ # def _unload_model_on_device_task(device_id: torch.device):
+ # if self.has_request_pools(model_id, device_id):
+    #             actions = self._pool_scheduler.schedule_unload_model_from_device(
+ # self._model_manager.get_model_info(model_id), device_id
+ # )
+ # for action in actions:
+ # if action.action == ScaleActionType.SCALE_DOWN:
+ # self._shrink_pools_on_device(
+ # action.model_id, device_id, action.amount
+ # )
+ # elif action.action == ScaleActionType.SCALE_UP:
+ # self._expand_pools_on_device(
+ # action.model_id, device_id, action.amount
+ # )
+ # else:
+ # logger.info(
+    #                 f"[Inference][{device_id}] Model {model_id} is not installed."
+ # )
+ #
+ # unload_model_futures = self._executor.submit_batch(
+ # device_id_list, _unload_model_on_device_task
+ # )
+ # concurrent.futures.wait(
+    #         unload_model_futures, return_when=concurrent.futures.ALL_COMPLETED
+ # )
+
def _expand_pools_on_device(
self, model_id: str, device_id: torch.device, count: int
):
@@ -462,6 +498,12 @@ class PoolController:
return True
return False
+ def has_pool_on_device(self, device_id: torch.device) -> bool:
+ """
+ Check if there are pools on the given device_id.
+ """
+        return any(device_id in pools for pools in self._request_pool_map.values())
+
def get_request_pools_group(
self, model_id: str, device_id: torch.device
) -> Optional[PoolGroup]:
diff --git a/iotdb-core/ainode/iotdb/ainode/core/inference/pool_scheduler/basic_pool_scheduler.py b/iotdb-core/ainode/iotdb/ainode/core/inference/pool_scheduler/basic_pool_scheduler.py
index 5ce9eceba14..49aebe8a89e 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/inference/pool_scheduler/basic_pool_scheduler.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/inference/pool_scheduler/basic_pool_scheduler.py
@@ -46,7 +46,7 @@ def _estimate_shared_pool_size_by_total_mem(
new_model_info: Optional[ModelInfo] = None,
) -> Dict[str, int]:
"""
- Estimate pool counts for (existing_model_ids + new_model_id) by equally
+ Estimate pool counts for (existing_model_infos + new_model_info) by equally
splitting the device's TOTAL memory among models.
Returns:
@@ -60,7 +60,7 @@ def _estimate_shared_pool_size_by_total_mem(
)
raise ModelNotExistException(new_model_info.model_id)
- # Extract unique model IDs
+ # Extract unique model infos
all_models = existing_model_infos + (
[new_model_info] if new_model_info is not None else []
)