This is an automated email from the ASF dual-hosted git repository.

yongzao pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/iotdb.git


The following commit(s) were added to refs/heads/master by this push:
     new d8dfe1b9170 [AINode] Modify model loading (#17299)
d8dfe1b9170 is described below

commit d8dfe1b91704e88619363a154fa69a89dfef872f
Author: Leo <[email protected]>
AuthorDate: Tue Mar 17 13:41:49 2026 +0800

    [AINode] Modify model loading (#17299)
---
 .../ainode/it/AINodeInstanceManagementIT.java      |   8 +-
 .../iotdb/ainode/core/inference/pool_controller.py | 110 ++++++++++++++-------
 .../pool_scheduler/basic_pool_scheduler.py         |   4 +-
 3 files changed, 82 insertions(+), 40 deletions(-)

diff --git 
a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeInstanceManagementIT.java
 
b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeInstanceManagementIT.java
index 15ddce11ede..f8aa27ce688 100644
--- 
a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeInstanceManagementIT.java
+++ 
b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeInstanceManagementIT.java
@@ -87,16 +87,16 @@ public class AINodeInstanceManagementIT {
     // Load sundial to each device
     statement.execute(String.format("LOAD MODEL sundial TO DEVICES '%s'", 
TARGET_DEVICES));
     checkModelOnSpecifiedDevice(statement, "sundial", 
TARGET_DEVICES.toString());
+    // Unload sundial from each device
+    statement.execute(String.format("UNLOAD MODEL sundial FROM DEVICES '%s'", 
TARGET_DEVICES));
+    checkModelNotOnSpecifiedDevice(statement, "sundial", 
TARGET_DEVICES.toString());
 
     // Load timer_xl to each device
     statement.execute(String.format("LOAD MODEL timer_xl TO DEVICES '%s'", 
TARGET_DEVICES));
     checkModelOnSpecifiedDevice(statement, "timer_xl", 
TARGET_DEVICES.toString());
-
-    // Clean every device
-    statement.execute(String.format("UNLOAD MODEL sundial FROM DEVICES '%s'", 
TARGET_DEVICES));
+    // Unload timer_xl from each device
     statement.execute(String.format("UNLOAD MODEL timer_xl FROM DEVICES '%s'", 
TARGET_DEVICES));
     checkModelNotOnSpecifiedDevice(statement, "timer_xl", 
TARGET_DEVICES.toString());
-    checkModelNotOnSpecifiedDevice(statement, "sundial", 
TARGET_DEVICES.toString());
   }
 
   private static final int LOOP_CNT = 10;
diff --git a/iotdb-core/ainode/iotdb/ainode/core/inference/pool_controller.py 
b/iotdb-core/ainode/iotdb/ainode/core/inference/pool_controller.py
index f4b3d23d36f..73bf01c3f7a 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/inference/pool_controller.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/inference/pool_controller.py
@@ -141,7 +141,9 @@ class PoolController:
             model_id (str): The ID of the model to be loaded.
             device_id_list (list[torch.device]): List of device_ids where the 
model should be loaded.
         """
-        self._task_queue.put((self._load_model_task, (model_id, 
device_id_list), {}))
+        self._task_queue.put(
+            (self._load_one_model_task, (model_id, device_id_list), {})
+        )
 
     def unload_model(self, model_id: str, device_id_list: list[torch.device]):
         """
@@ -150,7 +152,9 @@ class PoolController:
             model_id (str): The ID of the model to be unloaded.
             device_id_list (list[torch.device]): List of device_ids where the 
model should be unloaded.
         """
-        self._task_queue.put((self._unload_model_task, (model_id, 
device_id_list), {}))
+        self._task_queue.put(
+            (self._unload_one_model_task, (model_id, device_id_list), {})
+        )
 
     def show_loaded_models(
         self, device_id_list: list[torch.device]
@@ -196,60 +200,92 @@ class PoolController:
             finally:
                 self._task_queue.task_done()
 
-    def _load_model_task(self, model_id: str, device_id_list: 
list[torch.device]):
-        def _load_model_on_device_task(device_id: torch.device):
-            if not self.has_request_pools(model_id, device_id):
-                actions = self._pool_scheduler.schedule_load_model_to_device(
-                    self._model_manager.get_model_info(model_id), device_id
-                )
-                for action in actions:
-                    if action.action == ScaleActionType.SCALE_UP:
-                        self._expand_pools_on_device(
-                            action.model_id, device_id, action.amount
-                        )
-                    elif action.action == ScaleActionType.SCALE_DOWN:
-                        self._shrink_pools_on_device(
-                            action.model_id, device_id, action.amount
-                        )
+    def _load_one_model_task(self, model_id: str, device_id_list: 
list[torch.device]):
+        def _load_one_model_on_device_task(device_id: torch.device):
+            if not self.has_pool_on_device(device_id):
+                self._expand_pools_on_device(model_id, device_id, 1)
             else:
                 logger.info(
-                    f"[Inference][{device_id}] Model {model_id} is already 
installed."
+                    f"[Inference][{device_id}] There are already pools on this 
device."
                 )
 
         load_model_futures = self._executor.submit_batch(
-            device_id_list, _load_model_on_device_task
+            device_id_list, _load_one_model_on_device_task
         )
         concurrent.futures.wait(
             load_model_futures, return_when=concurrent.futures.ALL_COMPLETED
         )
 
-    def _unload_model_task(self, model_id: str, device_id_list: 
list[torch.device]):
-        def _unload_model_on_device_task(device_id: torch.device):
+    def _unload_one_model_task(self, model_id: str, device_id_list: 
list[torch.device]):
+        def _unload_one_model_on_device_task(device_id: torch.device):
             if self.has_request_pools(model_id, device_id):
-                actions = 
self._pool_scheduler.schedule_unload_model_from_device(
-                    self._model_manager.get_model_info(model_id), device_id
-                )
-                for action in actions:
-                    if action.action == ScaleActionType.SCALE_DOWN:
-                        self._shrink_pools_on_device(
-                            action.model_id, device_id, action.amount
-                        )
-                    elif action.action == ScaleActionType.SCALE_UP:
-                        self._expand_pools_on_device(
-                            action.model_id, device_id, action.amount
-                        )
+                self._shrink_pools_on_device(model_id, device_id, 1)
             else:
                 logger.info(
                     f"[Inference][{device_id}] Model {model_id} is not 
installed."
                 )
 
         unload_model_futures = self._executor.submit_batch(
-            device_id_list, _unload_model_on_device_task
+            device_id_list, _unload_one_model_on_device_task
         )
         concurrent.futures.wait(
             unload_model_futures, return_when=concurrent.futures.ALL_COMPLETED
         )
 
+    # def _load_model_task(self, model_id: str, device_id_list: 
list[torch.device]):
+    #     def _load_model_on_device_task(device_id: torch.device):
+    #         if not self.has_request_pools(model_id, device_id):
+    #             actions = self._pool_scheduler.schedule_load_model_to_device(
+    #                 self._model_manager.get_model_info(model_id), device_id
+    #             )
+    #             for action in actions:
+    #                 if action.action == ScaleActionType.SCALE_UP:
+    #                     self._expand_pools_on_device(
+    #                         action.model_id, device_id, action.amount
+    #                     )
+    #                 elif action.action == ScaleActionType.SCALE_DOWN:
+    #                     self._shrink_pools_on_device(
+    #                         action.model_id, device_id, action.amount
+    #                     )
+    #         else:
+    #             logger.info(
+    #                 f"[Inference][{device_id}] Model {model_id} is already 
installed."
+    #             )
+    #
+    #     load_model_futures = self._executor.submit_batch(
+    #         device_id_list, _load_model_on_device_task
+    #     )
+    #     concurrent.futures.wait(
+    #         load_model_futures, return_when=concurrent.futures.ALL_COMPLETED
+    #     )
+    #
+    # def _unload_model_task(self, model_id: str, device_id_list: 
list[torch.device]):
+    #     def _unload_model_on_device_task(device_id: torch.device):
+    #         if self.has_request_pools(model_id, device_id):
+    #             actions = 
self._pool_scheduler.schedule_unload_model_from_device(
+    #                 self._model_manager.get_model_info(model_id), device_id
+    #             )
+    #             for action in actions:
+    #                 if action.action == ScaleActionType.SCALE_DOWN:
+    #                     self._shrink_pools_on_device(
+    #                         action.model_id, device_id, action.amount
+    #                     )
+    #                 elif action.action == ScaleActionType.SCALE_UP:
+    #                     self._expand_pools_on_device(
+    #                         action.model_id, device_id, action.amount
+    #                     )
+    #         else:
+    #             logger.info(
+    #                 f"[Inference][{device_id}] Model {model_id} is not 
installed."
+    #             )
+    #
+    #     unload_model_futures = self._executor.submit_batch(
+    #         device_id_list, _unload_model_on_device_task
+    #     )
+    #     concurrent.futures.wait(
+    #         unload_model_futures, 
return_when=concurrent.futures.ALL_COMPLETED
+    #     )
+
     def _expand_pools_on_device(
         self, model_id: str, device_id: torch.device, count: int
     ):
@@ -462,6 +498,12 @@ class PoolController:
                 return True
         return False
 
+    def has_pool_on_device(self, device_id: torch.device) -> bool:
+        """
+        Check if there are pools on the given device_id.
+        """
+        return any(device_id in pools for pools in 
self._request_pool_map.values())
+
     def get_request_pools_group(
         self, model_id: str, device_id: torch.device
     ) -> Optional[PoolGroup]:
diff --git 
a/iotdb-core/ainode/iotdb/ainode/core/inference/pool_scheduler/basic_pool_scheduler.py
 
b/iotdb-core/ainode/iotdb/ainode/core/inference/pool_scheduler/basic_pool_scheduler.py
index 5ce9eceba14..49aebe8a89e 100644
--- 
a/iotdb-core/ainode/iotdb/ainode/core/inference/pool_scheduler/basic_pool_scheduler.py
+++ 
b/iotdb-core/ainode/iotdb/ainode/core/inference/pool_scheduler/basic_pool_scheduler.py
@@ -46,7 +46,7 @@ def _estimate_shared_pool_size_by_total_mem(
     new_model_info: Optional[ModelInfo] = None,
 ) -> Dict[str, int]:
     """
-    Estimate pool counts for (existing_model_ids + new_model_id) by equally
+    Estimate pool counts for (existing_model_infos + new_model_info) by equally
     splitting the device's TOTAL memory among models.
 
     Returns:
@@ -60,7 +60,7 @@ def _estimate_shared_pool_size_by_total_mem(
         )
         raise ModelNotExistException(new_model_info.model_id)
 
-    # Extract unique model IDs
+    # Extract unique model infos
     all_models = existing_model_infos + (
         [new_model_info] if new_model_info is not None else []
     )

Reply via email to the mailing list.