This is an automated email from the ASF dual-hosted git repository.

yongzao pushed a commit to branch aidevicemanager
in repository https://gitbox.apache.org/repos/asf/iotdb.git

commit 391e1896cf912d76824f8d9d0151a95fe4c1ad1f
Author: Yongzao <[email protected]>
AuthorDate: Thu Jan 8 10:52:21 2026 +0800

    seems finish
---
 .../iotdb/ainode/core/device/backend/base.py       |  5 ++-
 .../ainode/core/device/backend/cpu_backend.py      | 23 +----------
 .../ainode/core/device/backend/cuda_backend.py     | 19 +---------
 .../iotdb/ainode/core/device/device_utils.py       |  4 +-
 iotdb-core/ainode/iotdb/ainode/core/device/env.py  |  2 +
 .../core/inference/inference_request_pool.py       | 11 ++++--
 .../core/inference/pipeline/basic_pipeline.py      |  5 ++-
 .../iotdb/ainode/core/inference/pool_controller.py | 30 ++++++++++-----
 .../iotdb/ainode/core/manager/device_manager.py    | 44 +++++++++-------------
 .../iotdb/ainode/core/manager/inference_manager.py | 35 +++++++++--------
 .../ainode/iotdb/ainode/core/model/model_loader.py | 14 +++----
 iotdb-core/ainode/iotdb/ainode/core/rpc/handler.py | 19 ++++++++--
 12 files changed, 100 insertions(+), 111 deletions(-)

diff --git a/iotdb-core/ainode/iotdb/ainode/core/device/backend/base.py b/iotdb-core/ainode/iotdb/ainode/core/device/backend/base.py
index 3ae7587284a..dee04f7ea2f 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/device/backend/base.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/device/backend/base.py
@@ -17,9 +17,11 @@
 #
 
 from enum import Enum
-from typing import Protocol, Optional, ContextManager
+from typing import ContextManager, Optional, Protocol
+
 import torch
 
+
 class BackendType(Enum):
     """
     Different types of supported computation backends.
@@ -29,6 +31,7 @@ class BackendType(Enum):
     CUDA = "cuda"
     CPU = "cpu"
 
+
 class BackendAdapter(Protocol):
     type: BackendType
 
diff --git a/iotdb-core/ainode/iotdb/ainode/core/device/backend/cpu_backend.py b/iotdb-core/ainode/iotdb/ainode/core/device/backend/cpu_backend.py
index 48a849e2df2..b196f2c8bd1 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/device/backend/cpu_backend.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/device/backend/cpu_backend.py
@@ -17,6 +17,7 @@
 #
 
 from contextlib import nullcontext
+
 import torch
 
 from iotdb.ainode.core.device.backend.base import BackendAdapter, BackendType
@@ -36,25 +37,3 @@ class CPUBackend(BackendAdapter):
 
     def set_device(self, index: int) -> None:
         return None
-
-    def synchronize(self) -> None:
-        return None
-
-    def autocast(self, enabled: bool, dtype: torch.dtype):
-        return nullcontext()
-
-    def make_grad_scaler(self, enabled: bool):
-        class _NoopScaler:
-            def scale(self, loss): return loss
-            def step(self, optim): optim.step()
-            def update(self): return None
-            def unscale_(self, optim): return None
-            @property
-            def is_enabled(self): return False
-        return _NoopScaler()
-
-    def default_dist_backend(self) -> str:
-        return "gloo"
-
-    def supports_bf16(self) -> bool:
-        return True
diff --git a/iotdb-core/ainode/iotdb/ainode/core/device/backend/cuda_backend.py b/iotdb-core/ainode/iotdb/ainode/core/device/backend/cuda_backend.py
index 0d25c58ac8f..e5b44d69b6e 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/device/backend/cuda_backend.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/device/backend/cuda_backend.py
@@ -17,6 +17,7 @@
 #
 
 from contextlib import nullcontext
+
 import torch
 
 from iotdb.ainode.core.device.backend.base import BackendAdapter, BackendType
@@ -38,21 +39,3 @@ class CUDABackend(BackendAdapter):
 
     def set_device(self, index: int) -> None:
         torch.cuda.set_device(index)
-
-    def synchronize(self) -> None:
-        torch.cuda.synchronize()
-
-    def autocast(self, enabled: bool, dtype: torch.dtype):
-        if not enabled:
-            return nullcontext()
-        return torch.autocast(device_type="cuda", dtype=dtype, enabled=True)
-
-    def make_grad_scaler(self, enabled: bool):
-        return torch.cuda.amp.GradScaler(enabled=enabled)
-
-    def default_dist_backend(self) -> str:
-        return "nccl"
-
-    def supports_bf16(self) -> bool:
-        fn = getattr(torch.cuda, "is_bf16_supported", None)
-        return bool(fn()) if callable(fn) else True
diff --git a/iotdb-core/ainode/iotdb/ainode/core/device/device_utils.py b/iotdb-core/ainode/iotdb/ainode/core/device/device_utils.py
index f5927555703..fa60f294d32 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/device/device_utils.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/device/device_utils.py
@@ -16,17 +16,19 @@
 # under the License.
 #
 from dataclasses import dataclass
-from typing import Union, Optional
+from typing import Optional, Union
 
 import torch
 
 DeviceLike = Union[torch.device, str, int]
 
+
 @dataclass(frozen=True)
 class DeviceSpec:
     type: str
     index: Optional[int]
 
+
 def parse_device_like(x: DeviceLike) -> DeviceSpec:
     if isinstance(x, int):
         return DeviceSpec("index", x)
diff --git a/iotdb-core/ainode/iotdb/ainode/core/device/env.py b/iotdb-core/ainode/iotdb/ainode/core/device/env.py
index 091495505e5..5252cca028f 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/device/env.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/device/env.py
@@ -19,12 +19,14 @@
 import os
 from dataclasses import dataclass
 
+
 @dataclass(frozen=True)
 class DistEnv:
     rank: int
     local_rank: int
     world_size: int
 
+
 def read_dist_env() -> DistEnv:
     # torchrun:
     rank = int(os.environ.get("RANK", "0"))
diff --git a/iotdb-core/ainode/iotdb/ainode/core/inference/inference_request_pool.py b/iotdb-core/ainode/iotdb/ainode/core/inference/inference_request_pool.py
index 073fa5fedce..6520302f27c 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/inference/inference_request_pool.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/inference/inference_request_pool.py
@@ -40,6 +40,7 @@ from iotdb.ainode.core.inference.request_scheduler.basic_request_scheduler impor
     BasicRequestScheduler,
 )
 from iotdb.ainode.core.log import Logger
+from iotdb.ainode.core.manager.device_manager import DeviceManager
 from iotdb.ainode.core.model.model_storage import ModelInfo
 
 
@@ -76,6 +77,8 @@ class InferenceRequestPool(mp.Process):
         self.ready_event = ready_event
         self.device = device
 
+        self._backend = DeviceManager()
+
         self._threads = []
         self._waiting_queue = request_queue  # Requests that are waiting to be processed
         self._running_queue = mp.Queue()  # Requests that are currently being processed
@@ -119,8 +122,8 @@ class InferenceRequestPool(mp.Process):
         grouped_requests = list(grouped_requests.values())
 
         for requests in grouped_requests:
-            batch_inputs = self._batcher.batch_request(requests).to(
-                "cpu"
+            batch_inputs = self._backend.move_tensor(
+                self._batcher.batch_request(requests), self._backend.torch_device("cpu")
             )  # The input data should first load to CPU in current version
             batch_input_list = []
             for i in range(batch_inputs.size(0)):
@@ -152,7 +155,9 @@ class InferenceRequestPool(mp.Process):
 
             offset = 0
             for request in requests:
-                request.output_tensor = request.output_tensor.to(self.device)
+                request.output_tensor = self._backend.move_tensor(
+                    request.output_tensor, self.device
+                )
                 cur_batch_size = request.batch_size
                 cur_output = batch_output[offset : offset + cur_batch_size]
                 offset += cur_batch_size
diff --git a/iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/basic_pipeline.py b/iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/basic_pipeline.py
index 7ccef492b41..917c40fef83 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/basic_pipeline.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/basic_pipeline.py
@@ -21,16 +21,17 @@ from abc import ABC, abstractmethod
 import torch
 
 from iotdb.ainode.core.exception import InferenceModelInternalException
-from iotdb.ainode.core.model.model_info import ModelInfo
 from iotdb.ainode.core.manager.device_manager import DeviceManager
+from iotdb.ainode.core.model.model_info import ModelInfo
 from iotdb.ainode.core.model.model_loader import load_model
 
 BACKEND = DeviceManager()
 
+
 class BasicPipeline(ABC):
     def __init__(self, model_info: ModelInfo, **model_kwargs):
         self.model_info = model_info
-        self.device = model_kwargs.get("device", "cpu")
+        self.device = model_kwargs.get("device", BACKEND.torch_device("cpu"))
         self.model = load_model(model_info, device_map=self.device, **model_kwargs)
 
     @abstractmethod
diff --git a/iotdb-core/ainode/iotdb/ainode/core/inference/pool_controller.py b/iotdb-core/ainode/iotdb/ainode/core/inference/pool_controller.py
index 416422d578e..29018f1c59e 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/inference/pool_controller.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/inference/pool_controller.py
@@ -124,12 +124,12 @@ class PoolController:
         # if not ready_event.wait(timeout=30):
         #     self._erase_pool(model_id, device_id, 0)
         #     logger.error(
-        #         f"[Inference][Device-{device}][Pool-0] Pool failed to be ready in time"
+        #         f"[Inference][{device}][Pool-0] Pool failed to be ready in time"
         #     )
         # else:
         #     self.set_state(model_id, device_id, 0, PoolState.RUNNING)
         #     logger.info(
-        #         f"[Inference][Device-{device}][Pool-0] Pool started running for model {model_id}"
+        #         f"[Inference][{device}][Pool-0] Pool started running for model {model_id}"
         #     )
 
     # =============== Pool Management ===============
@@ -239,7 +239,9 @@ class PoolController:
             unload_model_futures, return_when=concurrent.futures.ALL_COMPLETED
         )
 
-    def _expand_pools_on_device(self, model_id: str, device_id: torch.device, count: int):
+    def _expand_pools_on_device(
+        self, model_id: str, device_id: torch.device, count: int
+    ):
         """
         Expand the pools for the given model_id and device_id sequentially.
         Args:
@@ -281,7 +283,9 @@ class PoolController:
             expand_pool_futures, return_when=concurrent.futures.ALL_COMPLETED
         )
 
-    def _shrink_pools_on_device(self, model_id: str, device_id: torch.device, count: int):
+    def _shrink_pools_on_device(
+        self, model_id: str, device_id: torch.device, count: int
+    ):
         """
         Shrink the pools for the given model_id by count sequentially.
         TODO: shrink pools in parallel
@@ -353,7 +357,7 @@ class PoolController:
             f"[Inference][{device_id}][Pool-{pool_id}] Pool initializing for model {model_id}"
         )
 
-    def _erase_pool(self, model_id: str, device_id: str, pool_id: int):
+    def _erase_pool(self, model_id: str, device_id: torch.device, pool_id: int):
         """
         Erase the specified inference request pool for the given model_id, device_id and pool_id.
         """
@@ -388,7 +392,9 @@ class PoolController:
         self._request_pool_map[model_id][device_id].dispatch_request(req, infer_proxy)
 
     # =============== Getters / Setters ===============
-    def get_state(self, model_id: str, device_id: torch.device, pool_id: int) -> Optional[PoolState]:
+    def get_state(
+        self, model_id: str, device_id: torch.device, pool_id: int
+    ) -> Optional[PoolState]:
         """
         Get the state of the specified pool based on model_id, device_id, and pool_id.
         """
@@ -397,7 +403,9 @@ class PoolController:
             return pool_group.get_state(pool_id)
         return None
 
-    def set_state(self, model_id: str, device_id: torch.device, pool_id: int, state: PoolState):
+    def set_state(
+        self, model_id: str, device_id: torch.device, pool_id: int, state: PoolState
+    ):
         """
         Set the state of the specified pool based on model_id, device_id, and pool_id.
         """
@@ -422,7 +430,7 @@ class PoolController:
             return pool_group.get_pool_ids()
         return []
 
-    def has_request_pools(self, model_id: str, device_id: Optional[torch.device]) -> bool:
+    def has_request_pools(self, model_id: str, device_id: torch.device = None) -> bool:
         """
         Check if there are request pools for the given model_id ((optional) and device_id).
         """
@@ -451,7 +459,9 @@ class PoolController:
             return pool_group.get_request_pool(pool_id)
         return None
 
-    def get_request_queue(self, model_id: str, device_id: torch.device, pool_id: int) -> Optional[mp.Queue]:
+    def get_request_queue(
+        self, model_id: str, device_id: torch.device, pool_id: int
+    ) -> Optional[mp.Queue]:
         pool_group = self.get_request_pools_group(model_id, device_id)
         if pool_group:
             return pool_group.get_request_queue(pool_id)
@@ -476,7 +486,7 @@ class PoolController:
             pool_id, request_pool, request_queue
         )
         logger.info(
-            f"[Inference][Device-{device_id}][Pool-{pool_id}] Registered pool for model {model_id}"
+            f"[Inference][{device_id}][Pool-{pool_id}] Registered pool for model {model_id}"
         )
 
     def get_load(self, model_id: str, device_id: torch.device, pool_id: int) -> int:
diff --git a/iotdb-core/ainode/iotdb/ainode/core/manager/device_manager.py b/iotdb-core/ainode/iotdb/ainode/core/manager/device_manager.py
index 2271aa4ba0e..80d493d2dee 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/manager/device_manager.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/manager/device_manager.py
@@ -17,15 +17,15 @@
 #
 
 from dataclasses import dataclass
-from typing import Optional, ContextManager
-import os
+from typing import Optional
+
 import torch
 
-from iotdb.ainode.core.device.env import read_dist_env, DistEnv
-from iotdb.ainode.core.device.device_utils import (DeviceLike, parse_device_like)
 from iotdb.ainode.core.device.backend.base import BackendAdapter, BackendType
-from iotdb.ainode.core.device.backend.cuda_backend import CUDABackend
 from iotdb.ainode.core.device.backend.cpu_backend import CPUBackend
+from iotdb.ainode.core.device.backend.cuda_backend import CUDABackend
+from iotdb.ainode.core.device.device_utils import DeviceLike, parse_device_like
+from iotdb.ainode.core.device.env import DistEnv, read_dist_env
 from iotdb.ainode.core.util.decorator import singleton
 
 
@@ -33,6 +33,7 @@ from iotdb.ainode.core.util.decorator import singleton
 class DeviceManagerConfig:
     use_local_rank_if_distributed: bool = True
 
+
 @singleton
 class DeviceManager:
     """
@@ -41,6 +42,7 @@ class DeviceManager:
     - Parse device expression (None/int/str/torch.device/DeviceSpec)
     - Provide device, autocast, grad scaler, synchronize, dist backend recommendation, etc.
     """
+
     def __init__(self, cfg: DeviceManagerConfig):
         self.cfg = cfg
         self.env: DistEnv = read_dist_env()
@@ -87,13 +89,13 @@ class DeviceManager:
             return []
         return list(range(self.backend.device_count()))
 
-    def str_device_ids_with_cpu(self) -> list[str]:
+    def available_devices_with_cpu(self) -> list[torch.device]:
         """
-        Returns a list of available device IDs as strings, including "cpu".
+        Returns the list of available torch.devices, including "cpu".
         """
         device_id_list = self.device_ids()
-        device_id_list = [str(device_id) for device_id in device_id_list]
-        device_id_list.append("cpu")
+        device_id_list = [self.torch_device(device_id) for device_id in device_id_list]
+        device_id_list.append(self.torch_device("cpu"))
         return device_id_list
 
     def torch_device(self, device: DeviceLike) -> torch.device:
@@ -113,24 +115,12 @@ class DeviceManager:
             return torch.device("cpu")
         return self.backend.make_device(spec.index)
 
-    def move_model(self, model: torch.nn.Module, device: DeviceLike = None) -> torch.nn.Module:
+    def move_model(
+        self, model: torch.nn.Module, device: DeviceLike = None
+    ) -> torch.nn.Module:
         return model.to(self.torch_device(device))
 
-    def move_tensor(self, tensor: torch.Tensor, device: DeviceLike = None) -> torch.Tensor:
+    def move_tensor(
+        self, tensor: torch.Tensor, device: DeviceLike = None
+    ) -> torch.Tensor:
         return tensor.to(self.torch_device(device))
-
-    def synchronize(self) -> None:
-        self.backend.synchronize()
-
-    def autocast(self, enabled: bool, dtype: torch.dtype) -> ContextManager:
-        return self.backend.autocast(enabled=enabled, dtype=dtype)
-
-    def make_grad_scaler(self, enabled: bool):
-        return self.backend.make_grad_scaler(enabled=enabled)
-
-    def default_dist_backend(self) -> str:
-        # allow user override
-        return os.environ.get("TORCH_DIST_BACKEND", self.backend.default_dist_backend())
-
-    def supports_bf16(self) -> bool:
-        return self.backend.supports_bf16()
diff --git a/iotdb-core/ainode/iotdb/ainode/core/manager/inference_manager.py b/iotdb-core/ainode/iotdb/ainode/core/manager/inference_manager.py
index d5482bc99b0..46ad37e2a08 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/manager/inference_manager.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/manager/inference_manager.py
@@ -54,10 +54,7 @@ from iotdb.thrift.ainode.ttypes import (
     TForecastResp,
     TInferenceReq,
     TInferenceResp,
-    TLoadModelReq,
-    TShowLoadedModelsReq,
     TShowLoadedModelsResp,
-    TUnloadModelReq,
 )
 from iotdb.thrift.common.ttypes import TSStatus
 
@@ -86,7 +83,9 @@ class InferenceManager:
         self._result_handler_thread.start()
         self._pool_controller = PoolController(self._result_queue)
 
-    def load_model(self, existing_model_id: str, device_id_list: list[torch.device]) -> TSStatus:
+    def load_model(
+        self, existing_model_id: str, device_id_list: list[torch.device]
+    ) -> TSStatus:
         """
         Load a model to specified devices.
         Args:
@@ -116,35 +115,39 @@ class InferenceManager:
             message='Successfully submitted load model task, please use "SHOW LOADED MODELS" to check progress.',
         )
 
-    def unload_model(self, req: TUnloadModelReq) -> TSStatus:
+    def unload_model(
+        self, model_id: str, device_id_list: list[torch.device]
+    ) -> TSStatus:
         devices_to_be_processed = []
         devices_not_to_be_processed = []
-        for device_id in req.deviceIdList:
+        for device_id in device_id_list:
             if self._pool_controller.has_request_pools(
-                model_id=req.modelId, device_id=device_id
+                model_id=model_id, device_id=device_id
             ):
                 devices_to_be_processed.append(device_id)
             else:
                 devices_not_to_be_processed.append(device_id)
         if len(devices_to_be_processed) > 0:
             self._pool_controller.unload_model(
-                model_id=req.modelId, device_id_list=req.deviceIdList
+                model_id=model_id, device_id_list=device_id_list
             )
         logger.info(
-            f"[Inference] Start unloading model [{req.modelId}] from devices [{devices_to_be_processed}], skipped devices [{devices_not_to_be_processed}] cause they haven't loaded this model."
+            f"[Inference] Start unloading model [{model_id}] from devices [{devices_to_be_processed}], skipped devices [{devices_not_to_be_processed}] cause they haven't loaded this model."
         )
         return TSStatus(
             code=TSStatusCode.SUCCESS_STATUS.value,
             message='Successfully submitted unload model task, please use "SHOW LOADED MODELS" to check progress.',
         )
 
-    def show_loaded_models(self, req: TShowLoadedModelsReq) -> TShowLoadedModelsResp:
+    def show_loaded_models(
+        self, device_id_list: list[torch.device]
+    ) -> TShowLoadedModelsResp:
         return TShowLoadedModelsResp(
             status=get_status(TSStatusCode.SUCCESS_STATUS),
             deviceLoadedModelsMap=self._pool_controller.show_loaded_models(
-                req.deviceIdList
-                if len(req.deviceIdList) > 0
-                else self._backend.str_device_ids_with_cpu()
+                device_id_list
+                if len(device_id_list) > 0
+                else self._backend.available_devices_with_cpu()
             ),
         )
 
@@ -211,7 +214,7 @@ class InferenceManager:
                     output_length,
                 )
 
-            if self._pool_controller.has_request_pools(model_id):
+            if self._pool_controller.has_request_pools(model_id=model_id):
                 infer_req = InferenceRequest(
                     req_id=generate_req_id(),
                     model_id=model_id,
@@ -223,7 +226,9 @@ class InferenceManager:
                 outputs = self._process_request(infer_req)
             else:
                 model_info = self._model_manager.get_model_info(model_id)
-                inference_pipeline = load_pipeline(model_info, device="cpu")
+                inference_pipeline = load_pipeline(
+                    model_info, device=self._backend.torch_device("cpu")
+                )
                 inputs = inference_pipeline.preprocess(
                     model_inputs_list, output_length=output_length
                 )
diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/model_loader.py b/iotdb-core/ainode/iotdb/ainode/core/model/model_loader.py
index 605620d4261..289786c8aa3 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/model/model_loader.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/model/model_loader.py
@@ -34,12 +34,14 @@ from transformers import (
 from iotdb.ainode.core.config import AINodeDescriptor
 from iotdb.ainode.core.exception import ModelNotExistException
 from iotdb.ainode.core.log import Logger
+from iotdb.ainode.core.manager.device_manager import DeviceManager
 from iotdb.ainode.core.model.model_constants import ModelCategory
 from iotdb.ainode.core.model.model_info import ModelInfo
 from iotdb.ainode.core.model.sktime.modeling_sktime import create_sktime_model
 from iotdb.ainode.core.model.utils import import_class_from_path, temporary_sys_path
 
 logger = Logger()
+BACKEND = DeviceManager()
 
 
 def load_model(model_info: ModelInfo, **model_kwargs) -> Any:
@@ -105,17 +107,13 @@ def load_model_from_transformers(model_info: ModelInfo, **model_kwargs):
             model_cls = AutoModelForCausalLM
 
     if train_from_scratch:
-        model = model_cls.from_config(
-            config_cls, trust_remote_code=trust_remote_code, device_map=device_map
-        )
+        model = model_cls.from_config(config_cls, trust_remote_code=trust_remote_code)
     else:
         model = model_cls.from_pretrained(
-            model_path,
-            trust_remote_code=trust_remote_code,
-            device_map=device_map,
+            model_path, trust_remote_code=trust_remote_code
         )
 
-    return model
+    return BACKEND.move_model(model, device_map)
 
 
 def load_model_from_pt(model_info: ModelInfo, **kwargs):
@@ -138,7 +136,7 @@ def load_model_from_pt(model_info: ModelInfo, **kwargs):
         model = torch.compile(model)
     except Exception as e:
         logger.warning(f"acceleration failed, fallback to normal mode: {str(e)}")
-    return model.to(device_map)
+    return BACKEND.move_model(model, device_map)
 
 
 def load_model_for_efficient_inference():
diff --git a/iotdb-core/ainode/iotdb/ainode/core/rpc/handler.py b/iotdb-core/ainode/iotdb/ainode/core/rpc/handler.py
index 97059f7f169..7cf00082982 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/rpc/handler.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/rpc/handler.py
@@ -95,7 +95,10 @@ class AINodeRPCServiceHandler(IAINodeRPCService.Iface):
         status = self._ensure_device_id_is_available(req.deviceIdList)
         if status.code != TSStatusCode.SUCCESS_STATUS.value:
             return status
-        return self._inference_manager.load_model(req)
+        return self._inference_manager.load_model(
+            req.existingModelId,
+            [self._backend.torch_device(device_id) for device_id in req.deviceIdList],
+        )
 
     def unloadModel(self, req: TUnloadModelReq) -> TSStatus:
         status = self._ensure_model_is_registered(req.modelId)
@@ -104,13 +107,18 @@ class AINodeRPCServiceHandler(IAINodeRPCService.Iface):
         status = self._ensure_device_id_is_available(req.deviceIdList)
         if status.code != TSStatusCode.SUCCESS_STATUS.value:
             return status
-        return self._inference_manager.unload_model(req)
+        return self._inference_manager.unload_model(
+            req.modelId,
+            [self._backend.torch_device(device_id) for device_id in req.deviceIdList],
+        )
 
     def showLoadedModels(self, req: TShowLoadedModelsReq) -> TShowLoadedModelsResp:
         status = self._ensure_device_id_is_available(req.deviceIdList)
         if status.code != TSStatusCode.SUCCESS_STATUS.value:
             return TShowLoadedModelsResp(status=status, deviceLoadedModelsMap={})
-        return self._inference_manager.show_loaded_models(req)
+        return self._inference_manager.show_loaded_models(
+            [self._backend.torch_device(device_id) for device_id in req.deviceIdList]
+        )
 
     def _ensure_model_is_registered(self, model_id: str) -> TSStatus:
         if not self._model_manager.is_model_registered(model_id):
@@ -142,7 +150,10 @@ class AINodeRPCServiceHandler(IAINodeRPCService.Iface):
         """
         available_devices = self._backend.device_ids()
         for device_id in device_id_list:
-            if device_id != "cpu" and int(device_id) not in available_devices:
+            try:
+                if device_id != "cpu" and int(device_id) not in available_devices:
+                    raise ValueError(f"Invalid device ID [{device_id}]")
+            except ValueError:
                 return TSStatus(
                     code=TSStatusCode.UNAVAILABLE_AI_DEVICE_ERROR.value,
                     message=f"AIDevice ID [{device_id}] is not available. You can use 'SHOW AI_DEVICES' to retrieve the available devices.",

Reply via email to