This is an automated email from the ASF dual-hosted git repository.

yongzao pushed a commit to branch aidevicemanager
in repository https://gitbox.apache.org/repos/asf/iotdb.git

commit ece7eb71d777fd34ee5ddde9caf532c3f356be74
Author: Yongzao <[email protected]>
AuthorDate: Wed Jan 7 13:39:44 2026 +0800

    stash
---
 .../ainode/iotdb/ainode/core/device/__init__.py    |   0
 .../iotdb/ainode/core/device/backend/__init__.py   |   0
 .../iotdb/ainode/core/device/backend/base.py       |  48 ++++++++
 .../ainode/core/device/backend/cpu_backend.py      |  60 +++++++++
 .../ainode/core/device/backend/cuda_backend.py     |  58 +++++++++
 .../iotdb/ainode/core/device/device_utils.py       |  47 +++++++
 iotdb-core/ainode/iotdb/ainode/core/device/env.py  |  37 ++++++
 .../core/inference/inference_request_pool.py       |  17 ++-
 .../core/inference/pipeline/basic_pipeline.py      |   2 +
 .../core/inference/pipeline/pipeline_loader.py     |   4 +-
 .../iotdb/ainode/core/inference/pool_controller.py |  67 +++++-----
 .../pool_scheduler/abstract_pool_scheduler.py      |  12 +-
 .../pool_scheduler/basic_pool_scheduler.py         |  39 +++---
 .../iotdb/ainode/core/manager/device_manager.py    | 136 +++++++++++++++++++++
 .../iotdb/ainode/core/manager/inference_manager.py |  27 ++--
 iotdb-core/ainode/iotdb/ainode/core/rpc/handler.py |  43 ++++---
 .../ainode/iotdb/ainode/core/util/gpu_mapping.py   |  93 --------------
 iotdb-core/ainode/pyproject.toml                   |   6 +-
 .../config/metadata/ai/ShowAIDevicesTask.java      |   6 +-
 .../schema/column/ColumnHeaderConstant.java        |   5 +-
 .../thrift-ainode/src/main/thrift/ainode.thrift    |   2 +-
 21 files changed, 509 insertions(+), 200 deletions(-)

diff --git a/iotdb-core/ainode/iotdb/ainode/core/device/__init__.py b/iotdb-core/ainode/iotdb/ainode/core/device/__init__.py
new file mode 100644
index 00000000000..e69de29bb2d
diff --git a/iotdb-core/ainode/iotdb/ainode/core/device/backend/__init__.py b/iotdb-core/ainode/iotdb/ainode/core/device/backend/__init__.py
new file mode 100644
index 00000000000..e69de29bb2d
diff --git a/iotdb-core/ainode/iotdb/ainode/core/device/backend/base.py b/iotdb-core/ainode/iotdb/ainode/core/device/backend/base.py
new file mode 100644
index 00000000000..3ae7587284a
--- /dev/null
+++ b/iotdb-core/ainode/iotdb/ainode/core/device/backend/base.py
@@ -0,0 +1,48 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+from enum import Enum
+from typing import Protocol, Optional, ContextManager
+import torch
+
+class BackendType(Enum):
+    """
+    Different types of supported computation backends.
+    AINode will automatically select the available backend according to the order defined here.
+    """
+
+    CUDA = "cuda"
+    CPU = "cpu"
+
+class BackendAdapter(Protocol):
+    type: BackendType
+
+    # device basics
+    def is_available(self) -> bool: ...
+    def device_count(self) -> int: ...
+    def make_device(self, index: Optional[int]) -> torch.device: ...
+    def set_device(self, index: int) -> None: ...
+    def synchronize(self) -> None: ...
+
+    # precision / amp
+    def autocast(self, enabled: bool, dtype: torch.dtype) -> ContextManager: ...
+    def make_grad_scaler(self, enabled: bool): ...
+
+    # distributed defaults/capabilities
+    def default_dist_backend(self) -> str: ...
+    def supports_bf16(self) -> bool: ...
diff --git a/iotdb-core/ainode/iotdb/ainode/core/device/backend/cpu_backend.py b/iotdb-core/ainode/iotdb/ainode/core/device/backend/cpu_backend.py
new file mode 100644
index 00000000000..48a849e2df2
--- /dev/null
+++ b/iotdb-core/ainode/iotdb/ainode/core/device/backend/cpu_backend.py
@@ -0,0 +1,60 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+from contextlib import nullcontext
+import torch
+
+from iotdb.ainode.core.device.backend.base import BackendAdapter, BackendType
+
+
+class CPUBackend(BackendAdapter):
+    type = BackendType.CPU
+
+    def is_available(self) -> bool:
+        return True
+
+    def device_count(self) -> int:
+        return 1
+
+    def make_device(self, index: int | None) -> torch.device:
+        return torch.device("cpu")
+
+    def set_device(self, index: int) -> None:
+        return None
+
+    def synchronize(self) -> None:
+        return None
+
+    def autocast(self, enabled: bool, dtype: torch.dtype):
+        return nullcontext()
+
+    def make_grad_scaler(self, enabled: bool):
+        class _NoopScaler:
+            def scale(self, loss): return loss
+            def step(self, optim): optim.step()
+            def update(self): return None
+            def unscale_(self, optim): return None
+            @property
+            def is_enabled(self): return False
+        return _NoopScaler()
+
+    def default_dist_backend(self) -> str:
+        return "gloo"
+
+    def supports_bf16(self) -> bool:
+        return True
diff --git a/iotdb-core/ainode/iotdb/ainode/core/device/backend/cuda_backend.py b/iotdb-core/ainode/iotdb/ainode/core/device/backend/cuda_backend.py
new file mode 100644
index 00000000000..0d25c58ac8f
--- /dev/null
+++ b/iotdb-core/ainode/iotdb/ainode/core/device/backend/cuda_backend.py
@@ -0,0 +1,58 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+from contextlib import nullcontext
+import torch
+
+from iotdb.ainode.core.device.backend.base import BackendAdapter, BackendType
+
+
+class CUDABackend(BackendAdapter):
+    type = BackendType.CUDA
+
+    def is_available(self) -> bool:
+        return torch.cuda.is_available()
+
+    def device_count(self) -> int:
+        return torch.cuda.device_count()
+
+    def make_device(self, index: int | None) -> torch.device:
+        if index is None:
+            raise ValueError("CUDA backend requires a valid device index")
+        return torch.device(f"cuda:{index}")
+
+    def set_device(self, index: int) -> None:
+        torch.cuda.set_device(index)
+
+    def synchronize(self) -> None:
+        torch.cuda.synchronize()
+
+    def autocast(self, enabled: bool, dtype: torch.dtype):
+        if not enabled:
+            return nullcontext()
+        return torch.autocast(device_type="cuda", dtype=dtype, enabled=True)
+
+    def make_grad_scaler(self, enabled: bool):
+        return torch.cuda.amp.GradScaler(enabled=enabled)
+
+    def default_dist_backend(self) -> str:
+        return "nccl"
+
+    def supports_bf16(self) -> bool:
+        fn = getattr(torch.cuda, "is_bf16_supported", None)
+        return bool(fn()) if callable(fn) else True
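
[Editor's sketch, not part of this patch] The adapters above expose autocast/grad-scaler hooks that slot into an ordinary torch AMP step; the snippet below only uses the classes introduced in this commit and falls back to the CPU adapter when CUDA is unavailable.

import torch

from iotdb.ainode.core.device.backend.base import BackendType
from iotdb.ainode.core.device.backend.cpu_backend import CPUBackend
from iotdb.ainode.core.device.backend.cuda_backend import CUDABackend

# Pick the CUDA adapter when a GPU is present, otherwise the CPU adapter
# (whose autocast is a nullcontext and whose grad scaler is a no-op).
backend = CUDABackend()
if not backend.is_available():
    backend = CPUBackend()

device = backend.make_device(0)  # CPUBackend ignores the index
model = torch.nn.Linear(8, 1).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
scaler = backend.make_grad_scaler(enabled=backend.type == BackendType.CUDA)

x = torch.randn(4, 8, device=device)
with backend.autocast(enabled=True, dtype=torch.float16):
    loss = model(x).sum()
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()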
diff --git a/iotdb-core/ainode/iotdb/ainode/core/device/device_utils.py b/iotdb-core/ainode/iotdb/ainode/core/device/device_utils.py
new file mode 100644
index 00000000000..f5927555703
--- /dev/null
+++ b/iotdb-core/ainode/iotdb/ainode/core/device/device_utils.py
@@ -0,0 +1,47 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+from dataclasses import dataclass
+from typing import Union, Optional
+
+import torch
+
+DeviceLike = Union[torch.device, str, int]
+
+@dataclass(frozen=True)
+class DeviceSpec:
+    type: str
+    index: Optional[int]
+
+def parse_device_like(x: DeviceLike) -> DeviceSpec:
+    if isinstance(x, int):
+        return DeviceSpec("index", x)
+
+    if isinstance(x, str):
+        try:
+            return DeviceSpec("index", int(x))
+        except ValueError:
+            s = x.strip().lower()
+            if ":" in s:
+                t, idx = s.split(":", 1)
+                return DeviceSpec(t, int(idx))
+            return DeviceSpec(s, None)
+
+    if isinstance(x, torch.device):
+        return DeviceSpec(x.type, x.index)
+
+    raise TypeError(f"Unsupported device: {x!r}")
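
[Editor's sketch, not part of this patch] For orientation, parse_device_like maps the accepted inputs onto DeviceSpec values as follows:

import torch
from iotdb.ainode.core.device.device_utils import parse_device_like

print(parse_device_like(1))                       # DeviceSpec(type='index', index=1)
print(parse_device_like("3"))                     # DeviceSpec(type='index', index=3)
print(parse_device_like("cuda:0"))                # DeviceSpec(type='cuda', index=0)
print(parse_device_like("cpu"))                   # DeviceSpec(type='cpu', index=None)
print(parse_device_like(torch.device("cuda:2")))  # DeviceSpec(type='cuda', index=2)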
diff --git a/iotdb-core/ainode/iotdb/ainode/core/device/env.py b/iotdb-core/ainode/iotdb/ainode/core/device/env.py
new file mode 100644
index 00000000000..091495505e5
--- /dev/null
+++ b/iotdb-core/ainode/iotdb/ainode/core/device/env.py
@@ -0,0 +1,37 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+import os
+from dataclasses import dataclass
+
+@dataclass(frozen=True)
+class DistEnv:
+    rank: int
+    local_rank: int
+    world_size: int
+
+def read_dist_env() -> DistEnv:
+    # torchrun:
+    rank = int(os.environ.get("RANK", "0"))
+    world_size = int(os.environ.get("WORLD_SIZE", "1"))
+
+    # torchrun provides LOCAL_RANK; slurm often provides SLURM_LOCALID
+    local_rank = os.environ.get("LOCAL_RANK", os.environ.get("SLURM_LOCALID", "0"))
+    local_rank = int(local_rank)
+
+    return DistEnv(rank=rank, local_rank=local_rank, world_size=world_size)
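
[Editor's sketch, not part of this patch] Under torchrun-style environment variables, read_dist_env resolves to plain integers, e.g.:

import os
from iotdb.ainode.core.device.env import read_dist_env

# Simulate what torchrun exports before launching each worker process.
os.environ.update({"RANK": "3", "WORLD_SIZE": "4", "LOCAL_RANK": "1"})
print(read_dist_env())  # DistEnv(rank=3, local_rank=1, world_size=4)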
diff --git a/iotdb-core/ainode/iotdb/ainode/core/inference/inference_request_pool.py b/iotdb-core/ainode/iotdb/ainode/core/inference/inference_request_pool.py
index fb03e0af520..073fa5fedce 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/inference/inference_request_pool.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/inference/inference_request_pool.py
@@ -41,7 +41,6 @@ from iotdb.ainode.core.inference.request_scheduler.basic_request_scheduler impor
 )
 from iotdb.ainode.core.log import Logger
 from iotdb.ainode.core.model.model_storage import ModelInfo
-from iotdb.ainode.core.util.gpu_mapping import convert_device_id_to_torch_device
 
 
 class PoolState(Enum):
@@ -64,7 +63,7 @@ class InferenceRequestPool(mp.Process):
         self,
         pool_id: int,
         model_info: ModelInfo,
-        device: str,
+        device: torch.device,
         request_queue: mp.Queue,
         result_queue: mp.Queue,
         ready_event,
@@ -75,7 +74,7 @@ class InferenceRequestPool(mp.Process):
         self.model_info = model_info
         self.pool_kwargs = pool_kwargs
         self.ready_event = ready_event
-        self.device = convert_device_id_to_torch_device(device)
+        self.device = device
 
         self._threads = []
         self._waiting_queue = request_queue  # Requests that are waiting to be processed
@@ -102,7 +101,7 @@ class InferenceRequestPool(mp.Process):
             request.mark_running()
             self._running_queue.put(request)
             self._logger.debug(
-                f"[Inference][Device-{self.device}][Pool-{self.pool_id}][Req-{request.req_id}] Request is activated with inputs shape {request.inputs.shape}"
+                f"[Inference][{self.device}][Pool-{self.pool_id}][Req-{request.req_id}] Request is activated with inputs shape {request.inputs.shape}"
             )
 
     def _requests_activate_loop(self):
@@ -164,12 +163,12 @@ class InferenceRequestPool(mp.Process):
                     request.output_tensor = request.output_tensor.cpu()
                     self._finished_queue.put(request)
                     self._logger.debug(
-                        f"[Inference][Device-{self.device}][Pool-{self.pool_id}][ID-{request.req_id}] Request is finished"
+                        f"[Inference][{self.device}][Pool-{self.pool_id}][ID-{request.req_id}] Request is finished"
                     )
                 else:
                     self._waiting_queue.put(request)
                     self._logger.debug(
-                        f"[Inference][Device-{self.device}][Pool-{self.pool_id}][ID-{request.req_id}] Request is not finished, re-queueing"
+                        f"[Inference][{self.device}][Pool-{self.pool_id}][ID-{request.req_id}] Request is not finished, re-queueing"
                     )
         return
 
@@ -183,7 +182,7 @@ class InferenceRequestPool(mp.Process):
             INFERENCE_LOG_FILE_NAME_PREFIX_TEMPLATE.format(self.device)
         )
         self._request_scheduler.device = self.device
-        self._inference_pipeline = load_pipeline(self.model_info, str(self.device))
+        self._inference_pipeline = load_pipeline(self.model_info, self.device)
         self.ready_event.set()
 
         activate_daemon = threading.Thread(
@@ -197,12 +196,12 @@ class InferenceRequestPool(mp.Process):
         self._threads.append(execute_daemon)
         execute_daemon.start()
         self._logger.info(
-            f"[Inference][Device-{self.device}][Pool-{self.pool_id}] InferenceRequestPool for model {self.model_info.model_id} is activated."
+            f"[Inference][{self.device}][Pool-{self.pool_id}] InferenceRequestPool for model {self.model_info.model_id} is activated."
         )
         for thread in self._threads:
             thread.join()
         self._logger.info(
-            f"[Inference][Device-{self.device}][Pool-{self.pool_id}] InferenceRequestPool for model {self.model_info.model_id} exited cleanly."
+            f"[Inference][{self.device}][Pool-{self.pool_id}] InferenceRequestPool for model {self.model_info.model_id} exited cleanly."
         )
 
     def stop(self):
diff --git a/iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/basic_pipeline.py b/iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/basic_pipeline.py
index f1704fb90c4..7ccef492b41 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/basic_pipeline.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/basic_pipeline.py
@@ -22,8 +22,10 @@ import torch
 
 from iotdb.ainode.core.exception import InferenceModelInternalException
 from iotdb.ainode.core.model.model_info import ModelInfo
+from iotdb.ainode.core.manager.device_manager import DeviceManager
 from iotdb.ainode.core.model.model_loader import load_model
 
+BACKEND = DeviceManager()
 
 class BasicPipeline(ABC):
     def __init__(self, model_info: ModelInfo, **model_kwargs):
diff --git a/iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/pipeline_loader.py b/iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/pipeline_loader.py
index a30038dd5fe..865a449aa32 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/pipeline_loader.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/pipeline_loader.py
@@ -19,6 +19,8 @@
 import os
 from pathlib import Path
 
+import torch
+
 from iotdb.ainode.core.config import AINodeDescriptor
 from iotdb.ainode.core.log import Logger
 from iotdb.ainode.core.model.model_constants import ModelCategory
@@ -28,7 +30,7 @@ from iotdb.ainode.core.model.utils import import_class_from_path, temporary_sys_
 logger = Logger()
 
 
-def load_pipeline(model_info: ModelInfo, device: str, **model_kwargs):
+def load_pipeline(model_info: ModelInfo, device: torch.device, **model_kwargs):
     if model_info.model_type == "sktime":
         from iotdb.ainode.core.model.sktime.pipeline_sktime import SktimePipeline
 
diff --git a/iotdb-core/ainode/iotdb/ainode/core/inference/pool_controller.py b/iotdb-core/ainode/iotdb/ainode/core/inference/pool_controller.py
index c580a89916d..416422d578e 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/inference/pool_controller.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/inference/pool_controller.py
@@ -22,6 +22,7 @@ import threading
 from concurrent.futures import wait
 from typing import Dict, Optional
 
+import torch
 import torch.multiprocessing as mp
 
 from iotdb.ainode.core.exception import InferenceModelInternalException
@@ -56,7 +57,7 @@ class PoolController:
     def __init__(self, result_queue: mp.Queue):
         self._model_manager = ModelManager()
         # structure: {model_id: {device_id: PoolGroup}}
-        self._request_pool_map: Dict[str, Dict[str, PoolGroup]] = {}
+        self._request_pool_map: Dict[str, Dict[torch.device, PoolGroup]] = {}
         self._new_pool_id = AtomicInt()
         self._result_queue = result_queue
         self._pool_scheduler = BasicPoolScheduler(self._request_pool_map)
@@ -132,31 +133,31 @@ class PoolController:
         #     )
 
     # =============== Pool Management ===============
-    def load_model(self, model_id: str, device_id_list: list[str]):
+    def load_model(self, model_id: str, device_id_list: list[torch.device]):
         """
         Load the model to the specified devices asynchronously.
         Args:
             model_id (str): The ID of the model to be loaded.
-            device_id_list (list[str]): List of device_ids where the model should be loaded.
+            device_id_list (list[torch.device]): List of device_ids where the model should be loaded.
         """
         self._task_queue.put((self._load_model_task, (model_id, device_id_list), {}))
 
-    def unload_model(self, model_id: str, device_id_list: list[str]):
+    def unload_model(self, model_id: str, device_id_list: list[torch.device]):
         """
         Unload the model from the specified devices asynchronously.
         Args:
             model_id (str): The ID of the model to be unloaded.
-            device_id_list (list[str]): List of device_ids where the model should be unloaded.
+            device_id_list (list[torch.device]): List of device_ids where the model should be unloaded.
         """
         self._task_queue.put((self._unload_model_task, (model_id, device_id_list), {}))
 
     def show_loaded_models(
-        self, device_id_list: list[str]
+        self, device_id_list: list[torch.device]
     ) -> Dict[str, Dict[str, int]]:
         """
         Show loaded model instances on the specified devices.
         Args:
-            device_id_list (list[str]): List of device_ids where to examine loaded instances.
+            device_id_list (list[torch.device]): List of device_ids where to examine loaded instances.
         Return:
             Dict[str, Dict[str, int]]: Dict[device_id, Dict[model_id, Count(instances)]].
         """
@@ -167,7 +168,7 @@ class PoolController:
                 if device_id in device_map:
                     pool_group = device_map[device_id]
                     device_models[model_id] = pool_group.get_running_pool_count()
-            result[device_id] = device_models
+            result[str(device_id.index)] = device_models
         return result
 
     def _worker_loop(self):
@@ -184,8 +185,8 @@ class PoolController:
             finally:
                 self._task_queue.task_done()
 
-    def _load_model_task(self, model_id: str, device_id_list: list[str]):
-        def _load_model_on_device_task(device_id: str):
+    def _load_model_task(self, model_id: str, device_id_list: list[torch.device]):
+        def _load_model_on_device_task(device_id: torch.device):
             if not self.has_request_pools(model_id, device_id):
                 actions = self._pool_scheduler.schedule_load_model_to_device(
                     self._model_manager.get_model_info(model_id), device_id
@@ -201,7 +202,7 @@ class PoolController:
                         )
             else:
                 logger.info(
-                    f"[Inference][Device-{device_id}] Model {model_id} is already installed."
+                    f"[Inference][{device_id}] Model {model_id} is already installed."
                 )
 
         load_model_futures = self._executor.submit_batch(
@@ -211,8 +212,8 @@ class PoolController:
             load_model_futures, return_when=concurrent.futures.ALL_COMPLETED
         )
 
-    def _unload_model_task(self, model_id: str, device_id_list: list[str]):
-        def _unload_model_on_device_task(device_id: str):
+    def _unload_model_task(self, model_id: str, device_id_list: list[torch.device]):
+        def _unload_model_on_device_task(device_id: torch.device):
             if self.has_request_pools(model_id, device_id):
                 actions = self._pool_scheduler.schedule_unload_model_from_device(
                     self._model_manager.get_model_info(model_id), device_id
@@ -228,7 +229,7 @@ class PoolController:
                         )
             else:
                 logger.info(
-                    f"[Inference][Device-{device_id}] Model {model_id} is not installed."
+                    f"[Inference][{device_id}] Model {model_id} is not installed."
                 )
 
         unload_model_futures = self._executor.submit_batch(
@@ -238,12 +239,12 @@ class PoolController:
             unload_model_futures, return_when=concurrent.futures.ALL_COMPLETED
         )
 
-    def _expand_pools_on_device(self, model_id: str, device_id: str, count: int):
+    def _expand_pools_on_device(self, model_id: str, device_id: torch.device, count: int):
         """
         Expand the pools for the given model_id and device_id sequentially.
         Args:
             model_id (str): The ID of the model.
-            device_id (str): The ID of the device.
+            device_id (torch.device): The ID of the device.
             count (int): The number of pools to be expanded.
         """
 
@@ -263,14 +264,14 @@ class PoolController:
             self._register_pool(model_id, device_id, pool_id, pool, request_queue)
             if not pool.ready_event.wait(timeout=300):
                 logger.error(
-                    f"[Inference][Device-{device_id}][Pool-{pool_id}] Pool failed to be ready in time"
+                    f"[Inference][{device_id}][Pool-{pool_id}] Pool failed to be ready in time"
                 )
                 # TODO: retry or decrease the count? this error should be better handled
                 self._erase_pool(model_id, device_id, pool_id)
             else:
                 self.set_state(model_id, device_id, pool_id, PoolState.RUNNING)
                 logger.info(
-                    f"[Inference][Device-{device_id}][Pool-{pool_id}] Pool started running for model {model_id}"
+                    f"[Inference][{device_id}][Pool-{pool_id}] Pool started running for model {model_id}"
                 )
 
         expand_pool_futures = self._executor.submit_batch(
@@ -280,7 +281,7 @@ class PoolController:
             expand_pool_futures, return_when=concurrent.futures.ALL_COMPLETED
         )
 
-    def _shrink_pools_on_device(self, model_id: str, device_id: str, count):
+    def _shrink_pools_on_device(self, model_id: str, device_id: torch.device, count: int):
         """
         Shrink the pools for the given model_id by count sequentially.
         TODO: shrink pools in parallel
@@ -335,7 +336,7 @@ class PoolController:
     def _register_pool(
         self,
         model_id: str,
-        device_id: str,
+        device_id: torch.device,
         pool_id: int,
         request_pool: InferenceRequestPool,
         request_queue: mp.Queue,
@@ -349,7 +350,7 @@ class PoolController:
         pool_group: PoolGroup = self.get_request_pools_group(model_id, device_id)
         pool_group.set_state(pool_id, PoolState.INITIALIZING)
         logger.info(
-            f"[Inference][Device-{device_id}][Pool-{pool_id}] Pool initializing for model {model_id}"
+            f"[Inference][{device_id}][Pool-{pool_id}] Pool initializing for model {model_id}"
         )
 
     def _erase_pool(self, model_id: str, device_id: str, pool_id: int):
@@ -360,7 +361,7 @@ class PoolController:
         if pool_group:
             pool_group.remove_pool(pool_id)
             logger.info(
-                f"[Inference][Device-{device_id}][Pool-{pool_id}] Erase pool for model {model_id}"
+                f"[Inference][{device_id}][Pool-{pool_id}] Erase pool for model {model_id}"
             )
         # Clean up empty structures
         if pool_group and not pool_group.get_pool_ids():
@@ -387,7 +388,7 @@ class PoolController:
         self._request_pool_map[model_id][device_id].dispatch_request(req, infer_proxy)
 
     # =============== Getters / Setters ===============
-    def get_state(self, model_id, device_id, pool_id) -> Optional[PoolState]:
+    def get_state(self, model_id: str, device_id: torch.device, pool_id: int) -> Optional[PoolState]:
         """
         Get the state of the specified pool based on model_id, device_id, and pool_id.
         """
@@ -396,7 +397,7 @@ class PoolController:
             return pool_group.get_state(pool_id)
         return None
 
-    def set_state(self, model_id, device_id, pool_id, state):
+    def set_state(self, model_id: str, device_id: torch.device, pool_id: int, state: PoolState):
         """
         Set the state of the specified pool based on model_id, device_id, and pool_id.
         """
@@ -404,7 +405,7 @@ class PoolController:
         if pool_group:
             pool_group.set_state(pool_id, state)
 
-    def get_device_ids(self, model_id) -> list[str]:
+    def get_device_ids(self, model_id) -> list[torch.device]:
         """
         Get the list of device IDs for the given model_id, where the corresponding instances are loaded.
         """
@@ -412,7 +413,7 @@ class PoolController:
             return list(self._request_pool_map[model_id].keys())
         return []
 
-    def get_pool_ids(self, model_id: str, device_id: str) -> list[int]:
+    def get_pool_ids(self, model_id: str, device_id: torch.device) -> list[int]:
         """
         Get the list of pool IDs for the given model_id and device_id.
         """
@@ -421,9 +422,9 @@ class PoolController:
             return pool_group.get_pool_ids()
         return []
 
-    def has_request_pools(self, model_id: str, device_id: Optional[str] = None) -> bool:
+    def has_request_pools(self, model_id: str, device_id: Optional[torch.device]) -> bool:
         """
-        Check if there are request pools for the given model_id and device_id (optional).
+        Check if there are request pools for the given model_id and (optional) device_id.
         """
         if model_id not in self._request_pool_map:
             return False
@@ -432,7 +433,7 @@ class PoolController:
         return True
 
     def get_request_pools_group(
-        self, model_id: str, device_id: str
+        self, model_id: str, device_id: torch.device
     ) -> Optional[PoolGroup]:
         if (
             model_id in self._request_pool_map
@@ -443,14 +444,14 @@ class PoolController:
             return None
 
     def get_request_pool(
-        self, model_id, device_id, pool_id
+        self, model_id: str, device_id: torch.device, pool_id: int
     ) -> Optional[InferenceRequestPool]:
         pool_group = self.get_request_pools_group(model_id, device_id)
         if pool_group:
             return pool_group.get_request_pool(pool_id)
         return None
 
-    def get_request_queue(self, model_id, device_id, pool_id) -> Optional[mp.Queue]:
+    def get_request_queue(self, model_id: str, device_id: torch.device, pool_id: int) -> Optional[mp.Queue]:
         pool_group = self.get_request_pools_group(model_id, device_id)
         if pool_group:
             return pool_group.get_request_queue(pool_id)
@@ -459,7 +460,7 @@ class PoolController:
     def set_request_pool_map(
         self,
         model_id: str,
-        device_id: str,
+        device_id: torch.device,
         pool_id: int,
         request_pool: InferenceRequestPool,
         request_queue: mp.Queue,
@@ -478,7 +479,7 @@ class PoolController:
             f"[Inference][Device-{device_id}][Pool-{pool_id}] Registered pool for model {model_id}"
         )
 
-    def get_load(self, model_id, device_id, pool_id) -> int:
+    def get_load(self, model_id: str, device_id: torch.device, pool_id: int) -> int:
         """
         Get the current load of the specified pool.
         """
diff --git a/iotdb-core/ainode/iotdb/ainode/core/inference/pool_scheduler/abstract_pool_scheduler.py b/iotdb-core/ainode/iotdb/ainode/core/inference/pool_scheduler/abstract_pool_scheduler.py
index 19d21f5822d..7e74a6c62b3 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/inference/pool_scheduler/abstract_pool_scheduler.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/inference/pool_scheduler/abstract_pool_scheduler.py
@@ -21,6 +21,8 @@ from dataclasses import dataclass
 from enum import Enum
 from typing import Dict, List
 
+import torch
+
 from iotdb.ainode.core.inference.pool_group import PoolGroup
 from iotdb.ainode.core.model.model_info import ModelInfo
 
@@ -35,7 +37,7 @@ class ScaleAction:
     action: ScaleActionType
     amount: int
     model_id: str
-    device_id: str
+    device_id: torch.device
 
 
 class AbstractPoolScheduler(ABC):
@@ -43,10 +45,10 @@ class AbstractPoolScheduler(ABC):
     Abstract base class for pool scheduling strategies.
     """
 
-    def __init__(self, request_pool_map: Dict[str, Dict[str, PoolGroup]]):
+    def __init__(self, request_pool_map: Dict[str, Dict[torch.device, PoolGroup]]):
         """
         Args:
-            request_pool_map: Dict["model_id", Dict["device_id", PoolGroup]].
+            request_pool_map: Dict["model_id", Dict[device_id, PoolGroup]].
         """
         self._request_pool_map = request_pool_map
 
@@ -59,7 +61,7 @@ class AbstractPoolScheduler(ABC):
 
     @abstractmethod
     def schedule_load_model_to_device(
-        self, model_info: ModelInfo, device_id: str
+        self, model_info: ModelInfo, device_id: torch.device
     ) -> List[ScaleAction]:
         """
         Schedule a series of actions to load the model to the device.
@@ -73,7 +75,7 @@ class AbstractPoolScheduler(ABC):
 
     @abstractmethod
     def schedule_unload_model_from_device(
-        self, model_info: ModelInfo, device_id: str
+        self, model_info: ModelInfo, device_id: torch.device
     ) -> List[ScaleAction]:
         """
         Schedule a series of actions to unload the model from the device.
diff --git a/iotdb-core/ainode/iotdb/ainode/core/inference/pool_scheduler/basic_pool_scheduler.py b/iotdb-core/ainode/iotdb/ainode/core/inference/pool_scheduler/basic_pool_scheduler.py
index 21140cafb1f..65aa7714393 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/inference/pool_scheduler/basic_pool_scheduler.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/inference/pool_scheduler/basic_pool_scheduler.py
@@ -20,7 +20,6 @@ from typing import Dict, List, Optional
 
 import torch
 
-from iotdb.ainode.core.exception import InferenceModelInternalException
 from iotdb.ainode.core.inference.pool_group import PoolGroup
 from iotdb.ainode.core.inference.pool_scheduler.abstract_pool_scheduler import (
     AbstractPoolScheduler,
@@ -33,11 +32,9 @@ from iotdb.ainode.core.manager.utils import (
     INFERENCE_EXTRA_MEMORY_RATIO,
     INFERENCE_MEMORY_USAGE_RATIO,
     MODEL_MEM_USAGE_MAP,
-    estimate_pool_size,
     evaluate_system_resources,
 )
 from iotdb.ainode.core.model.model_info import ModelInfo
-from iotdb.ainode.core.util.gpu_mapping import convert_device_id_to_torch_device
 
 logger = Logger()
 
@@ -74,7 +71,7 @@ def _estimate_shared_pool_size_by_total_mem(
     usable_mem = total_mem * INFERENCE_MEMORY_USAGE_RATIO
     if usable_mem <= 0:
         logger.error(
-            f"[Inference][Device-{device}] No usable memory on device. total={total_mem / 1024 ** 2:.2f} MB, usable={usable_mem / 1024 ** 2:.2f} MB"
+            f"[Inference][{device}] No usable memory on device. total={total_mem / 1024 ** 2:.2f} MB, usable={usable_mem / 1024 ** 2:.2f} MB"
         )
 
     # Each model gets an equal share of the TOTAL memory
@@ -87,39 +84,32 @@ def _estimate_shared_pool_size_by_total_mem(
         pool_num = int(per_model_share // mem_usages[model_info.model_id])
         if pool_num <= 0:
             logger.warning(
-                f"[Inference][Device-{device}] Not enough TOTAL memory to guarantee at least 1 pool for model {model_info.model_id}, no pool will be scheduled for this model. "
+                f"[Inference][{device}] Not enough TOTAL memory to guarantee at least 1 pool for model {model_info.model_id}, no pool will be scheduled for this model. "
                 f"Per-model share={per_model_share / 1024 ** 2:.2f} MB, need>={mem_usages[model_info.model_id] / 1024 ** 2:.2f} MB"
             )
         allocation[model_info.model_id] = pool_num
     logger.info(
-        f"[Inference][Device-{device}] Shared pool allocation (by TOTAL memory): {allocation}"
+        f"[Inference][{device}] Shared pool allocation (by TOTAL memory): {allocation}"
     )
     return allocation
 
 
 class BasicPoolScheduler(AbstractPoolScheduler):
     """
-    A basic scheduler to init the request pools. In short, different kind of models will equally share the available resource of the located device, and scale down actions are always ahead of scale up.
+    A basic scheduler to init the request pools. In short,
+    different kind of models will equally share the available resource of the located device,
+    and scale down actions are always ahead of scale up.
     """
 
-    def __init__(self, request_pool_map: Dict[str, Dict[str, PoolGroup]]):
+    def __init__(self, request_pool_map: Dict[str, Dict[torch.device, PoolGroup]]):
         super().__init__(request_pool_map)
         self._model_manager = ModelManager()
 
     def schedule(self, model_id: str) -> List[ScaleAction]:
-        """
-        Schedule a scaling action for the given model_id.
-        """
-        if model_id not in self._request_pool_map:
-            pool_num = estimate_pool_size(self.DEFAULT_DEVICE, model_id)
-            if pool_num <= 0:
-                raise InferenceModelInternalException(
-                    f"Not enough memory to run model {model_id}."
-                )
-            return [ScaleAction(ScaleActionType.SCALE_UP, pool_num, model_id)]
+        pass
 
     def schedule_load_model_to_device(
-        self, model_info: ModelInfo, device_id: str
+        self, model_info: ModelInfo, device_id: torch.device
     ) -> List[ScaleAction]:
         existing_model_infos = [
             self._model_manager.get_model_info(existing_model_id)
@@ -127,7 +117,7 @@ class BasicPoolScheduler(AbstractPoolScheduler):
             if existing_model_id != model_info.model_id and device_id in pool_group_map
         ]
         allocation_result = _estimate_shared_pool_size_by_total_mem(
-            device=convert_device_id_to_torch_device(device_id),
+            device=device_id,
             existing_model_infos=existing_model_infos,
             new_model_info=model_info,
         )
@@ -136,7 +126,7 @@ class BasicPoolScheduler(AbstractPoolScheduler):
         )
 
     def schedule_unload_model_from_device(
-        self, model_info: ModelInfo, device_id: str
+        self, model_info: ModelInfo, device_id: torch.device
     ) -> List[ScaleAction]:
         existing_model_infos = [
             self._model_manager.get_model_info(existing_model_id)
@@ -145,7 +135,7 @@ class BasicPoolScheduler(AbstractPoolScheduler):
         ]
         allocation_result = (
             _estimate_shared_pool_size_by_total_mem(
-                device=convert_device_id_to_torch_device(device_id),
+                device=device_id,
                 existing_model_infos=existing_model_infos,
                 new_model_info=None,
             )
@@ -159,10 +149,11 @@ class BasicPoolScheduler(AbstractPoolScheduler):
         )
 
     def _convert_allocation_result_to_scale_actions(
-        self, allocation_result: Dict[str, int], device_id: str
+        self, allocation_result: Dict[str, int], device_id: torch.device
     ) -> List[ScaleAction]:
         """
-        Convert the model allocation result to List[ScaleAction], where the scale down actions are always ahead of the scale up.
+        Convert the model allocation result to List[ScaleAction],
+        where the scale down actions are always ahead of the scale up.
         """
         actions = []
         for model_id, target_num in allocation_result.items():
diff --git a/iotdb-core/ainode/iotdb/ainode/core/manager/device_manager.py b/iotdb-core/ainode/iotdb/ainode/core/manager/device_manager.py
new file mode 100644
index 00000000000..2271aa4ba0e
--- /dev/null
+++ b/iotdb-core/ainode/iotdb/ainode/core/manager/device_manager.py
@@ -0,0 +1,136 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+from dataclasses import dataclass
+from typing import Optional, ContextManager
+import os
+import torch
+
+from iotdb.ainode.core.device.env import read_dist_env, DistEnv
+from iotdb.ainode.core.device.device_utils import (DeviceLike, parse_device_like)
+from iotdb.ainode.core.device.backend.base import BackendAdapter, BackendType
+from iotdb.ainode.core.device.backend.cuda_backend import CUDABackend
+from iotdb.ainode.core.device.backend.cpu_backend import CPUBackend
+from iotdb.ainode.core.util.decorator import singleton
+
+
+@dataclass(frozen=True)
+class DeviceManagerConfig:
+    use_local_rank_if_distributed: bool = True
+
+@singleton
+class DeviceManager:
+    """
+    Unified device entry point:
+    - Select backend (cuda/npu/cpu)
+    - Parse device expression (None/int/str/torch.device/DeviceSpec)
+    - Provide device, autocast, grad scaler, synchronize, dist backend recommendation, etc.
+    """
+    def __init__(self, cfg: DeviceManagerConfig = DeviceManagerConfig()):  # default config so bare DeviceManager() calls work
+        self.cfg = cfg
+        self.env: DistEnv = read_dist_env()
+
+        self.backends: dict[BackendType, BackendAdapter] = {
+            BackendType.CUDA: CUDABackend(),
+            BackendType.CPU: CPUBackend(),
+        }
+
+        self.type: BackendType
+        self.backend: BackendAdapter = self._auto_select_backend()
+        self.default_index: Optional[int] = self._select_default_index()
+
+        # ensure process uses correct device early
+        self._set_device_for_process()
+        self.device: torch.device = self.backend.make_device(self.default_index)
+
+    # ==================== selection ====================
+    def _auto_select_backend(self) -> BackendAdapter:
+        for name in BackendType:
+            backend = self.backends.get(name)
+            if backend is not None and backend.is_available():
+                self.type = backend.type
+                return backend
+        return self.backends[BackendType.CPU]
+
+    def _select_default_index(self) -> Optional[int]:
+        if self.backend.type == BackendType.CPU:
+            return None
+        if self.cfg.use_local_rank_if_distributed and self.env.world_size > 1:
+            return self.env.local_rank
+        return 0
+
+    def _set_device_for_process(self) -> None:
+        if self.backend.type in (BackendType.CUDA,) and self.default_index is not None:
+            self.backend.set_device(self.default_index)
+
+    # ==================== public API ====================
+    def device_ids(self) -> list[int]:
+        """
+        Returns a list of available device IDs for the current backend.
+        """
+        if self.backend.type == BackendType.CPU:
+            return []
+        return list(range(self.backend.device_count()))
+
+    def str_device_ids_with_cpu(self) -> list[str]:
+        """
+        Returns a list of available device IDs as strings, including "cpu".
+        """
+        device_id_list = self.device_ids()
+        device_id_list = [str(device_id) for device_id in device_id_list]
+        device_id_list.append("cpu")
+        return device_id_list
+
+    def torch_device(self, device: DeviceLike) -> torch.device:
+        """
+        Convert a DeviceLike specification into a torch.device object.
+        If device is None, returns the default device of current process.
+        Args:
+            device: Could be any of the following formats:
+                an integer (e.g., 0, 1, ...),
+                a string (e.g., "0", "cuda:0", "cpu", ...),
+                a torch.device object, return itself if so.
+        """
+        if isinstance(device, torch.device):
+            return device
+        spec = parse_device_like(device if device is not None else self.device)
+        if spec.type == "cpu":
+            return torch.device("cpu")
+        return self.backend.make_device(spec.index)
+
+    def move_model(self, model: torch.nn.Module, device: DeviceLike = None) -> torch.nn.Module:
+        return model.to(self.torch_device(device))
+
+    def move_tensor(self, tensor: torch.Tensor, device: DeviceLike = None) -> torch.Tensor:
+        return tensor.to(self.torch_device(device))
+
+    def synchronize(self) -> None:
+        self.backend.synchronize()
+
+    def autocast(self, enabled: bool, dtype: torch.dtype) -> ContextManager:
+        return self.backend.autocast(enabled=enabled, dtype=dtype)
+
+    def make_grad_scaler(self, enabled: bool):
+        return self.backend.make_grad_scaler(enabled=enabled)
+
+    def default_dist_backend(self) -> str:
+        # allow user override
+        return os.environ.get("TORCH_DIST_BACKEND", self.backend.default_dist_backend())
+
+    def supports_bf16(self) -> bool:
+        return self.backend.supports_bf16()
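
[Editor's sketch, not part of this patch] Typical DeviceManager usage as implied by the call sites changed later in this commit; it assumes the @singleton decorator permits the no-argument DeviceManager() calls used there.

import torch
from iotdb.ainode.core.manager.device_manager import DeviceManager

dm = DeviceManager()                      # singleton; selects CUDA if available, else CPU
print(dm.device, dm.device_ids())         # e.g. cuda:0 [0, 1]  or  cpu []
print(dm.str_device_ids_with_cpu())       # e.g. ['0', '1', 'cpu']

model = dm.move_model(torch.nn.Linear(4, 1), "cpu")  # accepts int, "0", "cuda:1", "cpu", torch.device
x = dm.move_tensor(torch.randn(2, 4), "cpu")
with dm.autocast(enabled=True, dtype=torch.float16):
    y = model(x)
print(dm.default_dist_backend())          # "nccl" on CUDA, "gloo" on CPU; TORCH_DIST_BACKEND overrides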
diff --git a/iotdb-core/ainode/iotdb/ainode/core/manager/inference_manager.py b/iotdb-core/ainode/iotdb/ainode/core/manager/inference_manager.py
index ada641dd54c..d5482bc99b0 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/manager/inference_manager.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/manager/inference_manager.py
@@ -42,9 +42,9 @@ from iotdb.ainode.core.inference.pipeline.pipeline_loader import load_pipeline
 from iotdb.ainode.core.inference.pool_controller import PoolController
 from iotdb.ainode.core.inference.utils import generate_req_id
 from iotdb.ainode.core.log import Logger
+from iotdb.ainode.core.manager.device_manager import DeviceManager
 from iotdb.ainode.core.manager.model_manager import ModelManager
 from iotdb.ainode.core.rpc.status import get_status
-from iotdb.ainode.core.util.gpu_mapping import get_available_devices
 from iotdb.ainode.core.util.serde import (
     convert_tensor_to_tsblock,
     convert_tsblock_to_tensor,
@@ -71,6 +71,7 @@ class InferenceManager:
 
     def __init__(self):
         self._model_manager = ModelManager()
+        self._backend = DeviceManager()
         self._model_mem_usage_map: Dict[str, int] = (
             {}
         )  # store model memory usage for each model
@@ -85,22 +86,30 @@ class InferenceManager:
         self._result_handler_thread.start()
         self._pool_controller = PoolController(self._result_queue)
 
-    def load_model(self, req: TLoadModelReq) -> TSStatus:
-        devices_to_be_processed = []
-        devices_not_to_be_processed = []
-        for device_id in req.deviceIdList:
+    def load_model(self, existing_model_id: str, device_id_list: list[torch.device]) -> TSStatus:
+        """
+        Load a model to specified devices.
+        Args:
+            existing_model_id (str): The ID of the model to be loaded.
+            device_id_list (list[torch.device]): List of device IDs to load the model onto.
+        Returns:
+            TSStatus: The status of the load model operation.
+        """
+        devices_to_be_processed: list[torch.device] = []
+        devices_not_to_be_processed: list[torch.device] = []
+        for device_id in device_id_list:
             if self._pool_controller.has_request_pools(
-                model_id=req.existingModelId, device_id=device_id
+                model_id=existing_model_id, device_id=device_id
             ):
                 devices_not_to_be_processed.append(device_id)
             else:
                 devices_to_be_processed.append(device_id)
         if len(devices_to_be_processed) > 0:
             self._pool_controller.load_model(
-                model_id=req.existingModelId, device_id_list=devices_to_be_processed
+                model_id=existing_model_id, device_id_list=devices_to_be_processed
             )
         logger.info(
-            f"[Inference] Start loading model [{req.existingModelId}] to devices [{devices_to_be_processed}], skipped devices [{devices_not_to_be_processed}] cause they have already loaded this model."
+            f"[Inference] Start loading model [{existing_model_id}] to devices [{devices_to_be_processed}], skipped devices [{devices_not_to_be_processed}] because they have already loaded this model."
         )
         return TSStatus(
             code=TSStatusCode.SUCCESS_STATUS.value,
@@ -135,7 +144,7 @@ class InferenceManager:
             deviceLoadedModelsMap=self._pool_controller.show_loaded_models(
                 req.deviceIdList
                 if len(req.deviceIdList) > 0
-                else get_available_devices()
+                else self._backend.str_device_ids_with_cpu()
             ),
         )
 
diff --git a/iotdb-core/ainode/iotdb/ainode/core/rpc/handler.py b/iotdb-core/ainode/iotdb/ainode/core/rpc/handler.py
index 492802fc060..97059f7f169 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/rpc/handler.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/rpc/handler.py
@@ -19,10 +19,10 @@
 from iotdb.ainode.core.constant import TSStatusCode
 from iotdb.ainode.core.log import Logger
 from iotdb.ainode.core.manager.cluster_manager import ClusterManager
+from iotdb.ainode.core.manager.device_manager import DeviceManager
 from iotdb.ainode.core.manager.inference_manager import InferenceManager
 from iotdb.ainode.core.manager.model_manager import ModelManager
 from iotdb.ainode.core.rpc.status import get_status
-from iotdb.ainode.core.util.gpu_mapping import get_available_devices
 from iotdb.thrift.ainode import IAINodeRPCService
 from iotdb.thrift.ainode.ttypes import (
     TAIHeartbeatReq,
@@ -48,25 +48,12 @@ from iotdb.thrift.common.ttypes import TSStatus
 logger = Logger()
 
 
-def _ensure_device_id_is_available(device_id_list: list[str]) -> TSStatus:
-    """
-    Ensure that the device IDs in the provided list are available.
-    """
-    available_devices = get_available_devices()
-    for device_id in device_id_list:
-        if device_id not in available_devices:
-            return TSStatus(
-                code=TSStatusCode.UNAVAILABLE_AI_DEVICE_ERROR.value,
-                message=f"AIDevice ID [{device_id}] is not available. You can use 'SHOW AI_DEVICES' to retrieve the available devices.",
-            )
-    return TSStatus(code=TSStatusCode.SUCCESS_STATUS.value)
-
-
 class AINodeRPCServiceHandler(IAINodeRPCService.Iface):
     def __init__(self, ainode):
         self._ainode = ainode
         self._model_manager = ModelManager()
         self._inference_manager = InferenceManager()
+        self._backend = DeviceManager()
 
     # ==================== Cluster Management ====================
 
@@ -82,9 +69,12 @@ class AINodeRPCServiceHandler(IAINodeRPCService.Iface):
         return ClusterManager.get_heart_beat(req)
 
     def showAIDevices(self) -> TShowAIDevicesResp:
+        device_id_map = {"cpu": "cpu"}
+        for device_id in self._backend.device_ids():
+            device_id_map[str(device_id)] = self._backend.type.value
         return TShowAIDevicesResp(
             status=TSStatus(code=TSStatusCode.SUCCESS_STATUS.value),
-            deviceIdList=get_available_devices(),
+            deviceIdMap=device_id_map,
         )
 
     # ==================== Model Management ====================
@@ -102,7 +92,7 @@ class AINodeRPCServiceHandler(IAINodeRPCService.Iface):
         status = self._ensure_model_is_registered(req.existingModelId)
         if status.code != TSStatusCode.SUCCESS_STATUS.value:
             return status
-        status = _ensure_device_id_is_available(req.deviceIdList)
+        status = self._ensure_device_id_is_available(req.deviceIdList)
         if status.code != TSStatusCode.SUCCESS_STATUS.value:
             return status
         return self._inference_manager.load_model(req)
@@ -111,13 +101,13 @@ class AINodeRPCServiceHandler(IAINodeRPCService.Iface):
         status = self._ensure_model_is_registered(req.modelId)
         if status.code != TSStatusCode.SUCCESS_STATUS.value:
             return status
-        status = _ensure_device_id_is_available(req.deviceIdList)
+        status = self._ensure_device_id_is_available(req.deviceIdList)
         if status.code != TSStatusCode.SUCCESS_STATUS.value:
             return status
         return self._inference_manager.unload_model(req)
 
     def showLoadedModels(self, req: TShowLoadedModelsReq) -> TShowLoadedModelsResp:
-        status = _ensure_device_id_is_available(req.deviceIdList)
+        status = self._ensure_device_id_is_available(req.deviceIdList)
         if status.code != TSStatusCode.SUCCESS_STATUS.value:
             return TShowLoadedModelsResp(status=status, deviceLoadedModelsMap={})
         return self._inference_manager.show_loaded_models(req)
@@ -144,6 +134,21 @@ class AINodeRPCServiceHandler(IAINodeRPCService.Iface):
             return TForecastResp(status, [])
         return self._inference_manager.forecast(req)
 
+    # ==================== Internal API ====================
+
+    def _ensure_device_id_is_available(self, device_id_list: list[str]) -> TSStatus:
+        """
+        Ensure that the device IDs in the provided list are available.
+        """
+        available_devices = self._backend.device_ids()
+        for device_id in device_id_list:
+            if device_id != "cpu" and int(device_id) not in available_devices:
+                return TSStatus(
+                    code=TSStatusCode.UNAVAILABLE_AI_DEVICE_ERROR.value,
+                    message=f"AIDevice ID [{device_id}] is not available. You can use 'SHOW AI_DEVICES' to retrieve the available devices.",
+                )
+        return TSStatus(code=TSStatusCode.SUCCESS_STATUS.value)
+
     # ==================== Tuning ====================
 
     def createTuningTask(self, req: TTuningReq) -> TSStatus:
diff --git a/iotdb-core/ainode/iotdb/ainode/core/util/gpu_mapping.py b/iotdb-core/ainode/iotdb/ainode/core/util/gpu_mapping.py
deleted file mode 100644
index 72b056adb87..00000000000
--- a/iotdb-core/ainode/iotdb/ainode/core/util/gpu_mapping.py
+++ /dev/null
@@ -1,93 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-#
-
-import torch
-
-
-def convert_device_id_to_torch_device(device_id: str) -> torch.device:
-    """
-    Converts a device ID string to a torch.device object.
-
-    Args:
-        device_id (str): The device ID string. It can be "cpu" or a GPU index like "0", "1", etc.
-
-    Returns:
-        torch.device: The corresponding torch.device object.
-
-    Raises:
-        ValueError: If the device_id is not "cpu" or a valid integer string.
-    """
-    if device_id.lower() == "cpu":
-        return torch.device("cpu")
-    try:
-        gpu_index = int(device_id)
-        if gpu_index < 0:
-            raise ValueError
-        return torch.device(f"cuda:{gpu_index}")
-    except ValueError:
-        raise ValueError(
-            f"Invalid device_id '{device_id}'. It should be 'cpu' or a non-negative integer string."
-        )
-
-
-def get_available_gpus() -> list[int]:
-    """
-    Returns a list of available GPU indices if CUDA is available, otherwise returns an empty list.
-    """
-
-    if not torch.cuda.is_available():
-        return []
-    return list(range(torch.cuda.device_count()))
-
-
-def get_available_devices() -> list[str]:
-    """
-    Returns: a list of available device IDs as strings, including "cpu".
-    """
-    device_id_list = get_available_gpus()
-    device_id_list = [str(device_id) for device_id in device_id_list]
-    device_id_list.append("cpu")
-    return device_id_list
-
-
-def parse_devices(devices):
-    """
-    Parses the input string of GPU devices and returns a comma-separated string of valid GPU indices.
-
-    Args:
-        devices (str): A comma-separated string of GPU indices (e.g., "0,1,2").
-    Returns:
-        str: A comma-separated string of valid GPU indices corresponding to the input. All available GPUs if no input is provided.
-    Exceptions:
-        RuntimeError: If no GPUs are available.
-        ValueError: If any of the provided GPU indices are not available.
-    """
-    if devices is None or devices == "":
-        gpu_ids = get_available_gpus()
-        if not gpu_ids:
-            raise RuntimeError("No available GPU")
-        return ",".join(map(str, gpu_ids))
-    else:
-        gpu_ids = [int(gpu) for gpu in devices.split(",")]
-        available_gpus = get_available_gpus()
-        for gpu_id in gpu_ids:
-            if gpu_id not in available_gpus:
-                raise ValueError(
-                    f"GPU {gpu_id} is not available, the available choices are: {available_gpus}"
-                )
-        return devices
diff --git a/iotdb-core/ainode/pyproject.toml b/iotdb-core/ainode/pyproject.toml
index 844bde45bd6..6e2d6984eee 100644
--- a/iotdb-core/ainode/pyproject.toml
+++ b/iotdb-core/ainode/pyproject.toml
@@ -76,7 +76,7 @@ exclude = [
 ]
 
 [tool.poetry.dependencies]
-python = ">=3.11.0,<3.14.0"
+python = ">=3.11.0,<3.12.0"
 
 # ---- DL / HF stack ----
 torch = "^2.8.0,<2.9.0"
@@ -88,9 +88,9 @@ safetensors = "^0.6.2"
 einops = "^0.8.1"
 
 # ---- Core scientific stack ----
-numpy = "^2.3.2"
+numpy = ">=2.0,<2.4.0"
+pandas = ">=2.0,<2.4.0"
 scipy = "^1.12.0"
-pandas = "^2.3.2"
 scikit-learn = "^1.7.1"
 statsmodels = "^0.14.5"
 sktime = "0.40.1"
diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/metadata/ai/ShowAIDevicesTask.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/metadata/ai/ShowAIDevicesTask.java
index 690f6f9485f..2f856e846b1 100644
--- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/metadata/ai/ShowAIDevicesTask.java
+++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/metadata/ai/ShowAIDevicesTask.java
@@ -36,6 +36,7 @@ import org.apache.tsfile.read.common.block.TsBlockBuilder;
 import org.apache.tsfile.utils.BytesUtils;
 
 import java.util.List;
+import java.util.Map;
 import java.util.stream.Collectors;
 
 public class ShowAIDevicesTask implements IConfigTask {
@@ -53,9 +54,10 @@ public class ShowAIDevicesTask implements IConfigTask {
             .map(ColumnHeader::getColumnType)
             .collect(Collectors.toList());
     TsBlockBuilder builder = new TsBlockBuilder(outputDataTypes);
-    for (String deviceId : resp.getDeviceIdList()) {
+    for (Map.Entry<String, String> deviceEntry : resp.getDeviceIdMap().entrySet()) {
       builder.getTimeColumnBuilder().writeLong(0L);
-      builder.getColumnBuilder(0).writeBinary(BytesUtils.valueOf(deviceId));
+      builder.getColumnBuilder(0).writeBinary(BytesUtils.valueOf(deviceEntry.getKey()));
+      builder.getColumnBuilder(1).writeBinary(BytesUtils.valueOf(deviceEntry.getValue()));
       builder.declarePosition();
     }
     DatasetHeader datasetHeader = DatasetHeaderFactory.getShowAIDevicesHeader();
diff --git a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/schema/column/ColumnHeaderConstant.java b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/schema/column/ColumnHeaderConstant.java
index 0459d4d2c86..dba2c2e368d 100644
--- a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/schema/column/ColumnHeaderConstant.java
+++ b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/schema/column/ColumnHeaderConstant.java
@@ -36,6 +36,7 @@ public class ColumnHeaderConstant {
   public static final String VALUE = "Value";
   public static final String DEVICE = "Device";
   public static final String DEVICE_ID = "DeviceId";
+  public static final String DEVICE_TYPE = "DeviceType";
   public static final String EXPLAIN_ANALYZE = "Explain Analyze";
 
   // column names for schema statement
@@ -660,7 +661,9 @@ public class ColumnHeaderConstant {
           new ColumnHeader(COUNT_INSTANCES, TSDataType.INT32));
 
   public static final List<ColumnHeader> showAIDevicesColumnHeaders =
-      ImmutableList.of(new ColumnHeader(DEVICE_ID, TSDataType.TEXT));
+      ImmutableList.of(
+          new ColumnHeader(DEVICE_ID, TSDataType.TEXT),
+          new ColumnHeader(DEVICE_TYPE, TSDataType.TEXT));
 
   public static final List<ColumnHeader> showLogicalViewColumnHeaders =
       ImmutableList.of(
diff --git a/iotdb-protocol/thrift-ainode/src/main/thrift/ainode.thrift b/iotdb-protocol/thrift-ainode/src/main/thrift/ainode.thrift
index 8a5971823ec..1cb585f0323 100644
--- a/iotdb-protocol/thrift-ainode/src/main/thrift/ainode.thrift
+++ b/iotdb-protocol/thrift-ainode/src/main/thrift/ainode.thrift
@@ -118,7 +118,7 @@ struct TShowLoadedModelsResp {
 
 struct TShowAIDevicesResp {
     1: required common.TSStatus status
-    2: required list<string> deviceIdList
+    2: required map<string, string> deviceIdMap
 }
 
 struct TLoadModelReq {
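
[Editor's note, not part of this patch] With this schema change, AINodeRPCServiceHandler.showAIDevices fills deviceIdMap with device-index keys and backend-type values, roughly:

# Shape of the new deviceIdMap payload on a host with two CUDA devices.
device_id_map = {"cpu": "cpu", "0": "cuda", "1": "cuda"}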
