This is an automated email from the ASF dual-hosted git repository.
yongzao pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/iotdb.git
The following commit(s) were added to refs/heads/master by this push:
new f71aabf60b6 [AINode] Integrate device manager framework (#16998)
f71aabf60b6 is described below
commit f71aabf60b6769bfc6280f69d254329325b3ab0a
Author: Yongzao <[email protected]>
AuthorDate: Mon Jan 12 21:12:24 2026 +0800
[AINode] Integrate device manager framework (#16998)
---
.../ainode/it/AINodeConcurrentForecastIT.java | 9 +-
.../iotdb/ainode/it/AINodeDeviceManageIT.java | 96 ++++++++++++++++++
iotdb-core/ainode/iotdb/ainode/core/constant.py | 2 +-
.../ainode/iotdb/ainode/core/device/__init__.py | 17 ++++
.../iotdb/ainode/core/device/backend/__init__.py | 17 ++++
.../iotdb/ainode/core/device/backend/base.py | 42 ++++++++
.../ainode/core/device/backend/cpu_backend.py | 37 +++++++
.../ainode/core/device/backend/cuda_backend.py | 39 ++++++++
.../iotdb/ainode/core/device/device_utils.py | 49 ++++++++++
iotdb-core/ainode/iotdb/ainode/core/device/env.py | 39 ++++++++
.../core/inference/inference_request_pool.py | 31 +++---
.../core/inference/pipeline/basic_pipeline.py | 11 ++-
.../core/inference/pipeline/pipeline_loader.py | 4 +-
.../iotdb/ainode/core/inference/pool_controller.py | 88 ++++++++++-------
.../pool_scheduler/abstract_pool_scheduler.py | 12 ++-
.../pool_scheduler/basic_pool_scheduler.py | 39 +++-----
.../iotdb/ainode/core/manager/device_manager.py | 108 +++++++++++++++++++++
.../iotdb/ainode/core/manager/inference_manager.py | 58 ++++++-----
.../ainode/iotdb/ainode/core/manager/utils.py | 6 +-
.../core/model/chronos2/pipeline_chronos2.py | 2 +-
.../ainode/iotdb/ainode/core/model/model_loader.py | 14 ++-
.../ainode/core/model/sktime/pipeline_sktime.py | 2 +-
.../ainode/core/model/sundial/pipeline_sundial.py | 2 +-
.../ainode/core/model/timer_xl/pipeline_timer.py | 2 +-
iotdb-core/ainode/iotdb/ainode/core/rpc/handler.py | 62 +++++++-----
.../ainode/iotdb/ainode/core/util/gpu_mapping.py | 93 ------------------
iotdb-core/ainode/pyproject.toml | 6 +-
.../ainode/resources/conf/iotdb-ainode.properties | 2 +-
.../config/metadata/ai/ShowAIDevicesTask.java | 15 ++-
.../schema/column/ColumnHeaderConstant.java | 5 +-
.../thrift-ainode/src/main/thrift/ainode.thrift | 2 +-
31 files changed, 658 insertions(+), 253 deletions(-)
diff --git a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeConcurrentForecastIT.java b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeConcurrentForecastIT.java
index 7b465d10051..fd021099d5f 100644
--- a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeConcurrentForecastIT.java
+++ b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeConcurrentForecastIT.java
@@ -83,13 +83,15 @@ public class AINodeConcurrentForecastIT {
}
@Test
- public void concurrentGPUForecastTest() throws SQLException, InterruptedException {
+ public void concurrentForecastTest() throws SQLException, InterruptedException {
for (AINodeTestUtils.FakeModelInfo modelInfo : MODEL_LIST) {
- concurrentGPUForecastTest(modelInfo);
+ concurrentGPUForecastTest(modelInfo, "0,1");
+ // TODO: Enable the CPU test after optimizing memory consumption
+ // concurrentGPUForecastTest(modelInfo, "cpu");
}
}
- public void concurrentGPUForecastTest(AINodeTestUtils.FakeModelInfo modelInfo)
+ public void concurrentGPUForecastTest(AINodeTestUtils.FakeModelInfo modelInfo, String devices)
throws SQLException, InterruptedException {
final int forecastLength = 512;
try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TABLE_SQL_DIALECT);
@@ -100,7 +102,6 @@ public class AINodeConcurrentForecastIT {
FORECAST_TABLE_FUNCTION_SQL_TEMPLATE, modelInfo.getModelId(), forecastLength);
final int threadCnt = 10;
final int loop = 100;
- final String devices = "0,1";
statement.execute(
String.format("LOAD MODEL %s TO DEVICES '%s'",
modelInfo.getModelId(), devices));
checkModelOnSpecifiedDevice(statement, modelInfo.getModelId(), devices);
diff --git a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeDeviceManageIT.java b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeDeviceManageIT.java
new file mode 100644
index 00000000000..bbffd3cffb0
--- /dev/null
+++ b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeDeviceManageIT.java
@@ -0,0 +1,96 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.iotdb.ainode.it;
+
+import org.apache.iotdb.it.env.EnvFactory;
+import org.apache.iotdb.it.framework.IoTDBTestRunner;
+import org.apache.iotdb.itbase.category.AIClusterIT;
+import org.apache.iotdb.itbase.env.BaseEnv;
+
+import org.junit.AfterClass;
+import org.junit.Assert;
+import org.junit.BeforeClass;
+import org.junit.Test;
+import org.junit.experimental.categories.Category;
+import org.junit.runner.RunWith;
+
+import java.sql.Connection;
+import java.sql.ResultSet;
+import java.sql.ResultSetMetaData;
+import java.sql.SQLException;
+import java.sql.Statement;
+import java.util.Arrays;
+import java.util.LinkedList;
+import java.util.List;
+
+import static org.apache.iotdb.ainode.utils.AINodeTestUtils.checkHeader;
+import static org.apache.iotdb.ainode.utils.AINodeTestUtils.prepareDataInTable;
+import static org.apache.iotdb.ainode.utils.AINodeTestUtils.prepareDataInTree;
+
+@RunWith(IoTDBTestRunner.class)
+@Category({AIClusterIT.class})
+public class AINodeDeviceManageIT {
+
+ @BeforeClass
+ public static void setUp() throws Exception {
+ // Init 1C1D1A cluster environment
+ EnvFactory.getEnv().initClusterEnvironment(1, 1);
+ prepareDataInTree();
+ prepareDataInTable();
+ }
+
+ @AfterClass
+ public static void tearDown() throws Exception {
+ EnvFactory.getEnv().cleanClusterEnvironment();
+ }
+
+ @Test
+ public void showAIDeviceTestInTree() throws SQLException {
+ try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TREE_SQL_DIALECT);
+ Statement statement = connection.createStatement()) {
+ showAIDevicesTest(statement);
+ }
+ }
+
+ @Test
+ public void showAIDeviceTestInTable() throws SQLException {
+ try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TABLE_SQL_DIALECT);
+ Statement statement = connection.createStatement()) {
+ showAIDevicesTest(statement);
+ }
+ }
+
+ private void showAIDevicesTest(Statement statement) throws SQLException {
+ final String showSql = "SHOW AI_DEVICES";
+ final List<String> expectedDeviceIdList = new
LinkedList<>(Arrays.asList("0", "1", "cpu"));
+ final List<String> expectedDeviceTypeList =
+ new LinkedList<>(Arrays.asList("cuda", "cuda", "cpu"));
+ try (ResultSet resultSet = statement.executeQuery(showSql)) {
+ ResultSetMetaData resultSetMetaData = resultSet.getMetaData();
+ checkHeader(resultSetMetaData, "DeviceId,DeviceType");
+ while (resultSet.next()) {
+ String deviceId = resultSet.getString(1);
+ String deviceType = resultSet.getString(2);
+ Assert.assertEquals(expectedDeviceIdList.remove(0), deviceId);
+ Assert.assertEquals(expectedDeviceTypeList.remove(0), deviceType);
+ }
+ }
+ }
+}
diff --git a/iotdb-core/ainode/iotdb/ainode/core/constant.py b/iotdb-core/ainode/iotdb/ainode/core/constant.py
index 44e76840f73..8a83c981437 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/constant.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/constant.py
@@ -56,7 +56,7 @@ AINODE_INFERENCE_MODEL_MEM_USAGE_MAP = {
"timer": 856 * 1024**2, # 856 MiB
} # the memory usage of each model in bytes
-AINODE_INFERENCE_MEMORY_USAGE_RATIO = 0.4 # the device space allocated for inference
+AINODE_INFERENCE_MEMORY_USAGE_RATIO = 0.2 # the device space allocated for inference
AINODE_INFERENCE_EXTRA_MEMORY_RATIO = (
1.2 # the overhead ratio for inference, used to estimate the pool size
)
diff --git a/iotdb-core/ainode/iotdb/ainode/core/device/__init__.py b/iotdb-core/ainode/iotdb/ainode/core/device/__init__.py
new file mode 100644
index 00000000000..2a1e720805f
--- /dev/null
+++ b/iotdb-core/ainode/iotdb/ainode/core/device/__init__.py
@@ -0,0 +1,17 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
diff --git a/iotdb-core/ainode/iotdb/ainode/core/device/backend/__init__.py b/iotdb-core/ainode/iotdb/ainode/core/device/backend/__init__.py
new file mode 100644
index 00000000000..2a1e720805f
--- /dev/null
+++ b/iotdb-core/ainode/iotdb/ainode/core/device/backend/__init__.py
@@ -0,0 +1,17 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
diff --git a/iotdb-core/ainode/iotdb/ainode/core/device/backend/base.py b/iotdb-core/ainode/iotdb/ainode/core/device/backend/base.py
new file mode 100644
index 00000000000..bf85a93a0c3
--- /dev/null
+++ b/iotdb-core/ainode/iotdb/ainode/core/device/backend/base.py
@@ -0,0 +1,42 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+from enum import Enum
+from typing import ContextManager, Optional, Protocol
+
+import torch
+
+
+class BackendType(Enum):
+ """
+ Different types of supported computation backends.
+ AINode will automatically select the available backend according to the order defined here.
+ """
+
+ CUDA = "cuda"
+ CPU = "cpu"
+
+
+class BackendAdapter(Protocol):
+ type: BackendType
+
+ # device basics
+ def is_available(self) -> bool: ...
+ def device_count(self) -> int: ...
+ def make_device(self, index: Optional[int]) -> torch.device: ...
+ def set_device(self, index: int) -> None: ...
diff --git a/iotdb-core/ainode/iotdb/ainode/core/device/backend/cpu_backend.py b/iotdb-core/ainode/iotdb/ainode/core/device/backend/cpu_backend.py
new file mode 100644
index 00000000000..f8c63817c5e
--- /dev/null
+++ b/iotdb-core/ainode/iotdb/ainode/core/device/backend/cpu_backend.py
@@ -0,0 +1,37 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+import torch
+
+from iotdb.ainode.core.device.backend.base import BackendAdapter, BackendType
+
+
+class CPUBackend(BackendAdapter):
+ type = BackendType.CPU
+
+ def is_available(self) -> bool:
+ return True
+
+ def device_count(self) -> int:
+ return 1
+
+ def make_device(self, index: int | None) -> torch.device:
+ return torch.device("cpu")
+
+ def set_device(self, index: int) -> None:
+ return None
diff --git a/iotdb-core/ainode/iotdb/ainode/core/device/backend/cuda_backend.py b/iotdb-core/ainode/iotdb/ainode/core/device/backend/cuda_backend.py
new file mode 100644
index 00000000000..c7533cc4dd7
--- /dev/null
+++ b/iotdb-core/ainode/iotdb/ainode/core/device/backend/cuda_backend.py
@@ -0,0 +1,39 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+import torch
+
+from iotdb.ainode.core.device.backend.base import BackendAdapter, BackendType
+
+
+class CUDABackend(BackendAdapter):
+ type = BackendType.CUDA
+
+ def is_available(self) -> bool:
+ return torch.cuda.is_available()
+
+ def device_count(self) -> int:
+ return torch.cuda.device_count()
+
+ def make_device(self, index: int | None) -> torch.device:
+ if index is None:
+ raise ValueError("CUDA backend requires a valid device index")
+ return torch.device(f"cuda:{index}")
+
+ def set_device(self, index: int) -> None:
+ torch.cuda.set_device(index)
diff --git a/iotdb-core/ainode/iotdb/ainode/core/device/device_utils.py b/iotdb-core/ainode/iotdb/ainode/core/device/device_utils.py
new file mode 100644
index 00000000000..fa60f294d32
--- /dev/null
+++ b/iotdb-core/ainode/iotdb/ainode/core/device/device_utils.py
@@ -0,0 +1,49 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+from dataclasses import dataclass
+from typing import Optional, Union
+
+import torch
+
+DeviceLike = Union[torch.device, str, int]
+
+
+@dataclass(frozen=True)
+class DeviceSpec:
+ type: str
+ index: Optional[int]
+
+
+def parse_device_like(x: DeviceLike) -> DeviceSpec:
+ if isinstance(x, int):
+ return DeviceSpec("index", x)
+
+ if isinstance(x, str):
+ try:
+ return DeviceSpec("index", int(x))
+ except ValueError:
+ s = x.strip().lower()
+ if ":" in s:
+ t, idx = s.split(":", 1)
+ return DeviceSpec(t, int(idx))
+ return DeviceSpec(s, None)
+
+ if isinstance(x, torch.device):
+ return DeviceSpec(x.type, x.index)
+
+ raise TypeError(f"Unsupported device: {x!r}")
diff --git a/iotdb-core/ainode/iotdb/ainode/core/device/env.py b/iotdb-core/ainode/iotdb/ainode/core/device/env.py
new file mode 100644
index 00000000000..5252cca028f
--- /dev/null
+++ b/iotdb-core/ainode/iotdb/ainode/core/device/env.py
@@ -0,0 +1,39 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+import os
+from dataclasses import dataclass
+
+
+@dataclass(frozen=True)
+class DistEnv:
+ rank: int
+ local_rank: int
+ world_size: int
+
+
+def read_dist_env() -> DistEnv:
+ # torchrun:
+ rank = int(os.environ.get("RANK", "0"))
+ world_size = int(os.environ.get("WORLD_SIZE", "1"))
+
+ # torchrun provides LOCAL_RANK; slurm often provides SLURM_LOCALID
+ local_rank = os.environ.get("LOCAL_RANK", os.environ.get("SLURM_LOCALID", "0"))
+ local_rank = int(local_rank)
+
+ return DistEnv(rank=rank, local_rank=local_rank, world_size=world_size)
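read_dist_env only reads environment variables, so its behavior is easy to pin down. A sketch under torchrun-style variables, illustrative and not part of the patch:

    import os
    from iotdb.ainode.core.device.env import DistEnv, read_dist_env

    # torchrun exports RANK/WORLD_SIZE/LOCAL_RANK per worker
    os.environ.update({"RANK": "3", "WORLD_SIZE": "8", "LOCAL_RANK": "1"})
    assert read_dist_env() == DistEnv(rank=3, local_rank=1, world_size=8)

    # with no launcher variables set, everything falls back to a single process
    for key in ("RANK", "WORLD_SIZE", "LOCAL_RANK", "SLURM_LOCALID"):
        os.environ.pop(key, None)
    assert read_dist_env() == DistEnv(rank=0, local_rank=0, world_size=1)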
diff --git a/iotdb-core/ainode/iotdb/ainode/core/inference/inference_request_pool.py b/iotdb-core/ainode/iotdb/ainode/core/inference/inference_request_pool.py
index fb03e0af520..516c1d07c2c 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/inference/inference_request_pool.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/inference/inference_request_pool.py
@@ -40,8 +40,8 @@ from iotdb.ainode.core.inference.request_scheduler.basic_request_scheduler impor
BasicRequestScheduler,
)
from iotdb.ainode.core.log import Logger
+from iotdb.ainode.core.manager.device_manager import DeviceManager
from iotdb.ainode.core.model.model_storage import ModelInfo
-from iotdb.ainode.core.util.gpu_mapping import convert_device_id_to_torch_device
class PoolState(Enum):
@@ -64,7 +64,7 @@ class InferenceRequestPool(mp.Process):
self,
pool_id: int,
model_info: ModelInfo,
- device: str,
+ device: torch.device,
request_queue: mp.Queue,
result_queue: mp.Queue,
ready_event,
@@ -75,7 +75,7 @@ class InferenceRequestPool(mp.Process):
self.model_info = model_info
self.pool_kwargs = pool_kwargs
self.ready_event = ready_event
- self.device = convert_device_id_to_torch_device(device)
+ self.device = device
self._threads = []
self._waiting_queue = request_queue # Requests that are waiting to be processed
@@ -87,8 +87,8 @@ class InferenceRequestPool(mp.Process):
self._batcher = BasicBatcher()
self._stop_event = mp.Event()
+ self._backend = None
self._inference_pipeline = None
-
self._logger = None
# Fix inference seed
@@ -102,7 +102,7 @@ class InferenceRequestPool(mp.Process):
request.mark_running()
self._running_queue.put(request)
self._logger.debug(
- f"[Inference][Device-{self.device}][Pool-{self.pool_id}][Req-{request.req_id}] Request is activated with inputs shape {request.inputs.shape}"
+ f"[Inference][{self.device}][Pool-{self.pool_id}][Req-{request.req_id}] Request is activated with inputs shape {request.inputs.shape}"
)
def _requests_activate_loop(self):
@@ -120,9 +120,9 @@ class InferenceRequestPool(mp.Process):
grouped_requests = list(grouped_requests.values())
for requests in grouped_requests:
- batch_inputs = self._batcher.batch_request(requests).to(
- "cpu"
- ) # The input data should first load to CPU in current version
+ batch_inputs = self._backend.move_tensor(
+ self._batcher.batch_request(requests), self.device
+ )
batch_input_list = []
for i in range(batch_inputs.size(0)):
batch_input_list.append({"targets": batch_inputs[i]})
@@ -153,7 +153,9 @@ class InferenceRequestPool(mp.Process):
offset = 0
for request in requests:
- request.output_tensor = request.output_tensor.to(self.device)
+ request.output_tensor = self._backend.move_tensor(
+ request.output_tensor, self.device
+ )
cur_batch_size = request.batch_size
cur_output = batch_output[offset : offset + cur_batch_size]
offset += cur_batch_size
@@ -164,12 +166,12 @@ class InferenceRequestPool(mp.Process):
request.output_tensor = request.output_tensor.cpu()
self._finished_queue.put(request)
self._logger.debug(
- f"[Inference][Device-{self.device}][Pool-{self.pool_id}][ID-{request.req_id}] Request is finished"
+ f"[Inference][{self.device}][Pool-{self.pool_id}][ID-{request.req_id}] Request is finished"
)
else:
self._waiting_queue.put(request)
self._logger.debug(
- f"[Inference][Device-{self.device}][Pool-{self.pool_id}][ID-{request.req_id}] Request is not finished, re-queueing"
+ f"[Inference][{self.device}][Pool-{self.pool_id}][ID-{request.req_id}] Request is not finished, re-queueing"
)
return
@@ -182,8 +184,9 @@ class InferenceRequestPool(mp.Process):
self._logger = Logger(
INFERENCE_LOG_FILE_NAME_PREFIX_TEMPLATE.format(self.device)
)
+ self._backend = DeviceManager()
self._request_scheduler.device = self.device
- self._inference_pipeline = load_pipeline(self.model_info, str(self.device))
+ self._inference_pipeline = load_pipeline(self.model_info, self.device)
self.ready_event.set()
activate_daemon = threading.Thread(
@@ -197,12 +200,12 @@ class InferenceRequestPool(mp.Process):
self._threads.append(execute_daemon)
execute_daemon.start()
self._logger.info(
- f"[Inference][Device-{self.device}][Pool-{self.pool_id}]
InferenceRequestPool for model {self.model_info.model_id} is activated."
+ f"[Inference][{self.device}][Pool-{self.pool_id}]
InferenceRequestPool for model {self.model_info.model_id} is activated."
)
for thread in self._threads:
thread.join()
self._logger.info(
- f"[Inference][Device-{self.device}][Pool-{self.pool_id}]
InferenceRequestPool for model {self.model_info.model_id} exited cleanly."
+ f"[Inference][{self.device}][Pool-{self.pool_id}]
InferenceRequestPool for model {self.model_info.model_id} exited cleanly."
)
def stop(self):
diff --git a/iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/basic_pipeline.py b/iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/basic_pipeline.py
index f1704fb90c4..ece395bf697 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/basic_pipeline.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/basic_pipeline.py
@@ -21,14 +21,17 @@ from abc import ABC, abstractmethod
import torch
from iotdb.ainode.core.exception import InferenceModelInternalException
+from iotdb.ainode.core.manager.device_manager import DeviceManager
from iotdb.ainode.core.model.model_info import ModelInfo
from iotdb.ainode.core.model.model_loader import load_model
+BACKEND = DeviceManager()
+
class BasicPipeline(ABC):
def __init__(self, model_info: ModelInfo, **model_kwargs):
self.model_info = model_info
- self.device = model_kwargs.get("device", "cpu")
+ self.device = model_kwargs.get("device", BACKEND.torch_device("cpu"))
self.model = load_model(model_info, device_map=self.device, **model_kwargs)
@abstractmethod
@@ -48,7 +51,7 @@ class BasicPipeline(ABC):
class ForecastPipeline(BasicPipeline):
def __init__(self, model_info: ModelInfo, **model_kwargs):
- super().__init__(model_info, model_kwargs=model_kwargs)
+ super().__init__(model_info, **model_kwargs)
def preprocess(
self,
@@ -199,7 +202,7 @@ class ForecastPipeline(BasicPipeline):
class ClassificationPipeline(BasicPipeline):
def __init__(self, model_info: ModelInfo, **model_kwargs):
- super().__init__(model_info, model_kwargs=model_kwargs)
+ super().__init__(model_info, **model_kwargs)
def preprocess(self, inputs, **kwargs):
return inputs
@@ -214,7 +217,7 @@ class ClassificationPipeline(BasicPipeline):
class ChatPipeline(BasicPipeline):
def __init__(self, model_info: ModelInfo, **model_kwargs):
- super().__init__(model_info, model_kwargs=model_kwargs)
+ super().__init__(model_info, **model_kwargs)
def preprocess(self, inputs, **kwargs):
return inputs
diff --git a/iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/pipeline_loader.py b/iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/pipeline_loader.py
index a30038dd5fe..865a449aa32 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/pipeline_loader.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/pipeline_loader.py
@@ -19,6 +19,8 @@
import os
from pathlib import Path
+import torch
+
from iotdb.ainode.core.config import AINodeDescriptor
from iotdb.ainode.core.log import Logger
from iotdb.ainode.core.model.model_constants import ModelCategory
@@ -28,7 +28,7 @@ from iotdb.ainode.core.model.utils import import_class_from_path, temporary_sys_
logger = Logger()
-def load_pipeline(model_info: ModelInfo, device: str, **model_kwargs):
+def load_pipeline(model_info: ModelInfo, device: torch.device, **model_kwargs):
if model_info.model_type == "sktime":
from iotdb.ainode.core.model.sktime.pipeline_sktime import SktimePipeline
diff --git a/iotdb-core/ainode/iotdb/ainode/core/inference/pool_controller.py b/iotdb-core/ainode/iotdb/ainode/core/inference/pool_controller.py
index c580a89916d..1eb07adfde4 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/inference/pool_controller.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/inference/pool_controller.py
@@ -22,6 +22,7 @@ import threading
from concurrent.futures import wait
from typing import Dict, Optional
+import torch
import torch.multiprocessing as mp
from iotdb.ainode.core.exception import InferenceModelInternalException
@@ -56,7 +57,7 @@ class PoolController:
def __init__(self, result_queue: mp.Queue):
self._model_manager = ModelManager()
# structure: {model_id: {device_id: PoolGroup}}
- self._request_pool_map: Dict[str, Dict[str, PoolGroup]] = {}
+ self._request_pool_map: Dict[str, Dict[torch.device, PoolGroup]] = {}
self._new_pool_id = AtomicInt()
self._result_queue = result_queue
self._pool_scheduler = BasicPoolScheduler(self._request_pool_map)
@@ -123,40 +124,40 @@ class PoolController:
# if not ready_event.wait(timeout=30):
# self._erase_pool(model_id, device_id, 0)
# logger.error(
- # f"[Inference][Device-{device}][Pool-0] Pool failed to be
ready in time"
+ # f"[Inference][{device}][Pool-0] Pool failed to be ready in
time"
# )
# else:
# self.set_state(model_id, device_id, 0, PoolState.RUNNING)
# logger.info(
- # f"[Inference][Device-{device}][Pool-0] Pool started running
for model {model_id}"
+ # f"[Inference][{device}][Pool-0] Pool started running for
model {model_id}"
# )
# =============== Pool Management ===============
- def load_model(self, model_id: str, device_id_list: list[str]):
+ def load_model(self, model_id: str, device_id_list: list[torch.device]):
"""
Load the model to the specified devices asynchronously.
Args:
model_id (str): The ID of the model to be loaded.
- device_id_list (list[str]): List of device_ids where the model should be loaded.
+ device_id_list (list[torch.device]): List of device_ids where the model should be loaded.
"""
self._task_queue.put((self._load_model_task, (model_id, device_id_list), {}))
- def unload_model(self, model_id: str, device_id_list: list[str]):
+ def unload_model(self, model_id: str, device_id_list: list[torch.device]):
"""
Unload the model from the specified devices asynchronously.
Args:
model_id (str): The ID of the model to be unloaded.
- device_id_list (list[str]): List of device_ids where the model should be unloaded.
+ device_id_list (list[torch.device]): List of device_ids where the model should be unloaded.
"""
self._task_queue.put((self._unload_model_task, (model_id, device_id_list), {}))
def show_loaded_models(
- self, device_id_list: list[str]
+ self, device_id_list: list[torch.device]
) -> Dict[str, Dict[str, int]]:
"""
Show loaded model instances on the specified devices.
Args:
- device_id_list (list[str]): List of device_ids where to examine loaded instances.
+ device_id_list (list[torch.device]): List of device_ids on which to examine loaded instances.
Return:
Dict[str, Dict[str, int]]: Dict[device_id, Dict[model_id, Count(instances)]].
"""
@@ -167,7 +168,10 @@ class PoolController:
if device_id in device_map:
pool_group = device_map[device_id]
device_models[model_id] = pool_group.get_running_pool_count()
- result[device_id] = device_models
+ device_key = (
+ device_id.type if device_id.index is None else str(device_id.index)
+ )
+ result[device_key] = device_models
return result
def _worker_loop(self):
@@ -184,8 +188,8 @@ class PoolController:
finally:
self._task_queue.task_done()
- def _load_model_task(self, model_id: str, device_id_list: list[str]):
- def _load_model_on_device_task(device_id: str):
+ def _load_model_task(self, model_id: str, device_id_list: list[torch.device]):
+ def _load_model_on_device_task(device_id: torch.device):
if not self.has_request_pools(model_id, device_id):
actions = self._pool_scheduler.schedule_load_model_to_device(
self._model_manager.get_model_info(model_id), device_id
@@ -201,7 +205,7 @@ class PoolController:
)
else:
logger.info(
- f"[Inference][Device-{device_id}] Model {model_id} is
already installed."
+ f"[Inference][{device_id}] Model {model_id} is already
installed."
)
load_model_futures = self._executor.submit_batch(
@@ -211,8 +215,8 @@ class PoolController:
load_model_futures, return_when=concurrent.futures.ALL_COMPLETED
)
- def _unload_model_task(self, model_id: str, device_id_list: list[str]):
- def _unload_model_on_device_task(device_id: str):
+ def _unload_model_task(self, model_id: str, device_id_list: list[torch.device]):
+ def _unload_model_on_device_task(device_id: torch.device):
if self.has_request_pools(model_id, device_id):
actions = self._pool_scheduler.schedule_unload_model_from_device(
self._model_manager.get_model_info(model_id), device_id
@@ -228,7 +232,7 @@ class PoolController:
)
else:
logger.info(
- f"[Inference][Device-{device_id}] Model {model_id} is not
installed."
+ f"[Inference][{device_id}] Model {model_id} is not
installed."
)
unload_model_futures = self._executor.submit_batch(
@@ -238,12 +242,14 @@ class PoolController:
unload_model_futures, return_when=concurrent.futures.ALL_COMPLETED
)
- def _expand_pools_on_device(self, model_id: str, device_id: str, count: int):
+ def _expand_pools_on_device(
+ self, model_id: str, device_id: torch.device, count: int
+ ):
"""
Expand the pools for the given model_id and device_id sequentially.
Args:
model_id (str): The ID of the model.
- device_id (str): The ID of the device.
+ device_id (torch.device): The ID of the device.
count (int): The number of pools to be expanded.
"""
@@ -263,14 +269,14 @@ class PoolController:
self._register_pool(model_id, device_id, pool_id, pool, request_queue)
if not pool.ready_event.wait(timeout=300):
logger.error(
- f"[Inference][Device-{device_id}][Pool-{pool_id}] Pool
failed to be ready in time"
+ f"[Inference][{device_id}][Pool-{pool_id}] Pool failed to
be ready in time"
)
# TODO: retry or decrease the count? this error should be better handled
self._erase_pool(model_id, device_id, pool_id)
else:
self.set_state(model_id, device_id, pool_id, PoolState.RUNNING)
logger.info(
- f"[Inference][Device-{device_id}][Pool-{pool_id}] Pool
started running for model {model_id}"
+ f"[Inference][{device_id}][Pool-{pool_id}] Pool started
running for model {model_id}"
)
expand_pool_futures = self._executor.submit_batch(
@@ -280,7 +286,9 @@ class PoolController:
expand_pool_futures, return_when=concurrent.futures.ALL_COMPLETED
)
- def _shrink_pools_on_device(self, model_id: str, device_id: str, count):
+ def _shrink_pools_on_device(
+ self, model_id: str, device_id: torch.device, count: int
+ ):
"""
Shrink the pools for the given model_id by count sequentially.
TODO: shrink pools in parallel
@@ -335,7 +343,7 @@ class PoolController:
def _register_pool(
self,
model_id: str,
- device_id: str,
+ device_id: torch.device,
pool_id: int,
request_pool: InferenceRequestPool,
request_queue: mp.Queue,
@@ -349,10 +357,10 @@ class PoolController:
pool_group: PoolGroup = self.get_request_pools_group(model_id, device_id)
pool_group.set_state(pool_id, PoolState.INITIALIZING)
logger.info(
- f"[Inference][Device-{device_id}][Pool-{pool_id}] Pool
initializing for model {model_id}"
+ f"[Inference][{device_id}][Pool-{pool_id}] Pool initializing for
model {model_id}"
)
- def _erase_pool(self, model_id: str, device_id: str, pool_id: int):
+ def _erase_pool(self, model_id: str, device_id: torch.device, pool_id: int):
"""
Erase the specified inference request pool for the given model_id, device_id and pool_id.
"""
@@ -360,7 +368,7 @@ class PoolController:
if pool_group:
pool_group.remove_pool(pool_id)
logger.info(
- f"[Inference][Device-{device_id}][Pool-{pool_id}] Erase pool
for model {model_id}"
+ f"[Inference][{device_id}][Pool-{pool_id}] Erase pool for
model {model_id}"
)
# Clean up empty structures
if pool_group and not pool_group.get_pool_ids():
@@ -387,7 +395,9 @@ class PoolController:
self._request_pool_map[model_id][device_id].dispatch_request(req, infer_proxy)
# =============== Getters / Setters ===============
- def get_state(self, model_id, device_id, pool_id) -> Optional[PoolState]:
+ def get_state(
+ self, model_id: str, device_id: torch.device, pool_id: int
+ ) -> Optional[PoolState]:
"""
Get the state of the specified pool based on model_id, device_id, and pool_id.
"""
@@ -396,7 +406,9 @@ class PoolController:
return pool_group.get_state(pool_id)
return None
- def set_state(self, model_id, device_id, pool_id, state):
+ def set_state(
+ self, model_id: str, device_id: torch.device, pool_id: int, state: PoolState
+ ):
"""
Set the state of the specified pool based on model_id, device_id, and pool_id.
"""
@@ -404,7 +416,7 @@ class PoolController:
if pool_group:
pool_group.set_state(pool_id, state)
- def get_device_ids(self, model_id) -> list[str]:
+ def get_device_ids(self, model_id) -> list[torch.device]:
"""
Get the list of device IDs for the given model_id, where the corresponding instances are loaded.
"""
@@ -412,7 +424,7 @@ class PoolController:
return list(self._request_pool_map[model_id].keys())
return []
- def get_pool_ids(self, model_id: str, device_id: str) -> list[int]:
+ def get_pool_ids(self, model_id: str, device_id: torch.device) ->
list[int]:
"""
Get the list of pool IDs for the given model_id and device_id.
"""
@@ -421,9 +433,9 @@ class PoolController:
return pool_group.get_pool_ids()
return []
- def has_request_pools(self, model_id: str, device_id: Optional[str] = None) -> bool:
+ def has_request_pools(self, model_id: str, device_id: torch.device = None) -> bool:
"""
- Check if there are request pools for the given model_id and device_id (optional).
+ Check if there are request pools for the given model_id and, optionally, device_id.
"""
if model_id not in self._request_pool_map:
return False
@@ -432,7 +444,7 @@ class PoolController:
return True
def get_request_pools_group(
- self, model_id: str, device_id: str
+ self, model_id: str, device_id: torch.device
) -> Optional[PoolGroup]:
if (
model_id in self._request_pool_map
@@ -443,14 +455,16 @@ class PoolController:
return None
def get_request_pool(
- self, model_id, device_id, pool_id
+ self, model_id: str, device_id: torch.device, pool_id: int
) -> Optional[InferenceRequestPool]:
pool_group = self.get_request_pools_group(model_id, device_id)
if pool_group:
return pool_group.get_request_pool(pool_id)
return None
- def get_request_queue(self, model_id, device_id, pool_id) ->
Optional[mp.Queue]:
+ def get_request_queue(
+ self, model_id: str, device_id: torch.device, pool_id: int
+ ) -> Optional[mp.Queue]:
pool_group = self.get_request_pools_group(model_id, device_id)
if pool_group:
return pool_group.get_request_queue(pool_id)
@@ -459,7 +473,7 @@ class PoolController:
def set_request_pool_map(
self,
model_id: str,
- device_id: str,
+ device_id: torch.device,
pool_id: int,
request_pool: InferenceRequestPool,
request_queue: mp.Queue,
@@ -475,10 +489,10 @@ class PoolController:
pool_id, request_pool, request_queue
)
logger.info(
- f"[Inference][Device-{device_id}][Pool-{pool_id}] Registered pool
for model {model_id}"
+ f"[Inference][{device_id}][Pool-{pool_id}] Registered pool for
model {model_id}"
)
- def get_load(self, model_id, device_id, pool_id) -> int:
+ def get_load(self, model_id: str, device_id: torch.device, pool_id: int) -> int:
"""
Get the current load of the specified pool.
"""
diff --git a/iotdb-core/ainode/iotdb/ainode/core/inference/pool_scheduler/abstract_pool_scheduler.py b/iotdb-core/ainode/iotdb/ainode/core/inference/pool_scheduler/abstract_pool_scheduler.py
index 19d21f5822d..7e74a6c62b3 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/inference/pool_scheduler/abstract_pool_scheduler.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/inference/pool_scheduler/abstract_pool_scheduler.py
@@ -21,6 +21,8 @@ from dataclasses import dataclass
from enum import Enum
from typing import Dict, List
+import torch
+
from iotdb.ainode.core.inference.pool_group import PoolGroup
from iotdb.ainode.core.model.model_info import ModelInfo
@@ -35,7 +37,7 @@ class ScaleAction:
action: ScaleActionType
amount: int
model_id: str
- device_id: str
+ device_id: torch.device
class AbstractPoolScheduler(ABC):
@@ -43,10 +45,10 @@ class AbstractPoolScheduler(ABC):
Abstract base class for pool scheduling strategies.
"""
- def __init__(self, request_pool_map: Dict[str, Dict[str, PoolGroup]]):
+ def __init__(self, request_pool_map: Dict[str, Dict[torch.device, PoolGroup]]):
"""
Args:
- request_pool_map: Dict["model_id", Dict["device_id", PoolGroup]].
+ request_pool_map: Dict["model_id", Dict[device_id, PoolGroup]].
"""
self._request_pool_map = request_pool_map
@@ -59,7 +61,7 @@ class AbstractPoolScheduler(ABC):
@abstractmethod
def schedule_load_model_to_device(
- self, model_info: ModelInfo, device_id: str
+ self, model_info: ModelInfo, device_id: torch.device
) -> List[ScaleAction]:
"""
Schedule a series of actions to load the model to the device.
@@ -73,7 +75,7 @@ class AbstractPoolScheduler(ABC):
@abstractmethod
def schedule_unload_model_from_device(
- self, model_info: ModelInfo, device_id: str
+ self, model_info: ModelInfo, device_id: torch.device
) -> List[ScaleAction]:
"""
Schedule a series of actions to unload the model from the device.
diff --git a/iotdb-core/ainode/iotdb/ainode/core/inference/pool_scheduler/basic_pool_scheduler.py b/iotdb-core/ainode/iotdb/ainode/core/inference/pool_scheduler/basic_pool_scheduler.py
index 21140cafb1f..65aa7714393 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/inference/pool_scheduler/basic_pool_scheduler.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/inference/pool_scheduler/basic_pool_scheduler.py
@@ -20,7 +20,6 @@ from typing import Dict, List, Optional
import torch
-from iotdb.ainode.core.exception import InferenceModelInternalException
from iotdb.ainode.core.inference.pool_group import PoolGroup
from iotdb.ainode.core.inference.pool_scheduler.abstract_pool_scheduler import (
AbstractPoolScheduler,
@@ -33,11 +32,9 @@ from iotdb.ainode.core.manager.utils import (
INFERENCE_EXTRA_MEMORY_RATIO,
INFERENCE_MEMORY_USAGE_RATIO,
MODEL_MEM_USAGE_MAP,
- estimate_pool_size,
evaluate_system_resources,
)
from iotdb.ainode.core.model.model_info import ModelInfo
-from iotdb.ainode.core.util.gpu_mapping import convert_device_id_to_torch_device
logger = Logger()
@@ -74,7 +71,7 @@ def _estimate_shared_pool_size_by_total_mem(
usable_mem = total_mem * INFERENCE_MEMORY_USAGE_RATIO
if usable_mem <= 0:
logger.error(
- f"[Inference][Device-{device}] No usable memory on device.
total={total_mem / 1024 ** 2:.2f} MB, usable={usable_mem / 1024 ** 2:.2f} MB"
+ f"[Inference][{device}] No usable memory on device.
total={total_mem / 1024 ** 2:.2f} MB, usable={usable_mem / 1024 ** 2:.2f} MB"
)
# Each model gets an equal share of the TOTAL memory
@@ -87,39 +84,32 @@ def _estimate_shared_pool_size_by_total_mem(
pool_num = int(per_model_share // mem_usages[model_info.model_id])
if pool_num <= 0:
logger.warning(
- f"[Inference][Device-{device}] Not enough TOTAL memory to
guarantee at least 1 pool for model {model_info.model_id}, no pool will be
scheduled for this model. "
+ f"[Inference][{device}] Not enough TOTAL memory to guarantee
at least 1 pool for model {model_info.model_id}, no pool will be scheduled for
this model. "
f"Per-model share={per_model_share / 1024 ** 2:.2f} MB,
need>={mem_usages[model_info.model_id] / 1024 ** 2:.2f} MB"
)
allocation[model_info.model_id] = pool_num
logger.info(
- f"[Inference][Device-{device}] Shared pool allocation (by TOTAL
memory): {allocation}"
+ f"[Inference][{device}] Shared pool allocation (by TOTAL memory):
{allocation}"
)
return allocation
class BasicPoolScheduler(AbstractPoolScheduler):
"""
- A basic scheduler to init the request pools. In short, different kind of models will equally share the available resource of the located device, and scale down actions are always ahead of scale up.
+ A basic scheduler to init the request pools. In short,
+ different kinds of models will equally share the available resource of the located device,
+ and scale-down actions are always ahead of scale-up.
"""
- def __init__(self, request_pool_map: Dict[str, Dict[str, PoolGroup]]):
+ def __init__(self, request_pool_map: Dict[str, Dict[torch.device, PoolGroup]]):
super().__init__(request_pool_map)
self._model_manager = ModelManager()
def schedule(self, model_id: str) -> List[ScaleAction]:
- """
- Schedule a scaling action for the given model_id.
- """
- if model_id not in self._request_pool_map:
- pool_num = estimate_pool_size(self.DEFAULT_DEVICE, model_id)
- if pool_num <= 0:
- raise InferenceModelInternalException(
- f"Not enough memory to run model {model_id}."
- )
- return [ScaleAction(ScaleActionType.SCALE_UP, pool_num, model_id)]
+ pass
def schedule_load_model_to_device(
- self, model_info: ModelInfo, device_id: str
+ self, model_info: ModelInfo, device_id: torch.device
) -> List[ScaleAction]:
existing_model_infos = [
self._model_manager.get_model_info(existing_model_id)
@@ -127,7 +117,7 @@ class BasicPoolScheduler(AbstractPoolScheduler):
if existing_model_id != model_info.model_id and device_id in pool_group_map
]
allocation_result = _estimate_shared_pool_size_by_total_mem(
- device=convert_device_id_to_torch_device(device_id),
+ device=device_id,
existing_model_infos=existing_model_infos,
new_model_info=model_info,
)
@@ -136,7 +126,7 @@ class BasicPoolScheduler(AbstractPoolScheduler):
)
def schedule_unload_model_from_device(
- self, model_info: ModelInfo, device_id: str
+ self, model_info: ModelInfo, device_id: torch.device
) -> List[ScaleAction]:
existing_model_infos = [
self._model_manager.get_model_info(existing_model_id)
@@ -145,7 +135,7 @@ class BasicPoolScheduler(AbstractPoolScheduler):
]
allocation_result = (
_estimate_shared_pool_size_by_total_mem(
- device=convert_device_id_to_torch_device(device_id),
+ device=device_id,
existing_model_infos=existing_model_infos,
new_model_info=None,
)
@@ -159,10 +149,11 @@ class BasicPoolScheduler(AbstractPoolScheduler):
)
def _convert_allocation_result_to_scale_actions(
- self, allocation_result: Dict[str, int], device_id: str
+ self, allocation_result: Dict[str, int], device_id: torch.device
) -> List[ScaleAction]:
"""
- Convert the model allocation result to List[ScaleAction], where the scale down actions are always ahead of the scale up.
+ Convert the model allocation result to List[ScaleAction],
+ where the scale down actions are always ahead of the scale up.
"""
actions = []
for model_id, target_num in allocation_result.items():
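To make the shared-allocation math above concrete: usable memory is the device total times AINODE_INFERENCE_MEMORY_USAGE_RATIO (now 0.2, per the constant.py change earlier in this commit), each model gets an equal share of the total, and that share is divided by the model's per-pool footprint. A worked sketch with illustrative numbers; the sundial figure is hypothetical:

    total_mem = 24 * 1024**3                      # assume a 24 GiB device
    usable = total_mem * 0.2                      # 4.8 GiB reserved for inference
    mem_usage = {
        "timer": 856 * 1024**2,                   # from AINODE_INFERENCE_MODEL_MEM_USAGE_MAP
        "sundial": 1024 * 1024**2,                # hypothetical footprint, for illustration
    }
    per_model_share = usable / len(mem_usage)     # 2.4 GiB per model
    pools = {m: int(per_model_share // u) for m, u in mem_usage.items()}
    assert pools == {"timer": 2, "sundial": 2}    # pool count per model on this device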
diff --git a/iotdb-core/ainode/iotdb/ainode/core/manager/device_manager.py b/iotdb-core/ainode/iotdb/ainode/core/manager/device_manager.py
new file mode 100644
index 00000000000..daac19b8a42
--- /dev/null
+++ b/iotdb-core/ainode/iotdb/ainode/core/manager/device_manager.py
@@ -0,0 +1,108 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+import torch
+
+from iotdb.ainode.core.device.backend.base import BackendAdapter, BackendType
+from iotdb.ainode.core.device.backend.cpu_backend import CPUBackend
+from iotdb.ainode.core.device.backend.cuda_backend import CUDABackend
+from iotdb.ainode.core.device.device_utils import DeviceLike, parse_device_like
+from iotdb.ainode.core.device.env import DistEnv, read_dist_env
+from iotdb.ainode.core.util.decorator import singleton
+
+
+@singleton
+class DeviceManager:
+ """
+ Unified device entry point:
+ - Select the first available backend (cuda/cpu, in BackendType order)
+ - Parse device expressions (int/str/torch.device/DeviceSpec)
+ - Provide torch.device construction and model/tensor movement helpers
+ """
+
+ use_local_rank_if_distributed: bool = True
+
+ def __init__(self):
+ self.env: DistEnv = read_dist_env()
+
+ self.backends: dict[BackendType, BackendAdapter] = {
+ BackendType.CUDA: CUDABackend(),
+ BackendType.CPU: CPUBackend(),
+ }
+
+ self.type: BackendType
+ self.backend: BackendAdapter = self._auto_select_backend()
+
+ # ==================== selection ====================
+ def _auto_select_backend(self) -> BackendAdapter:
+ for name in BackendType:
+ backend = self.backends.get(name)
+ if backend is not None and backend.is_available():
+ self.type = backend.type
+ return backend
+ return self.backends[BackendType.CPU]
+
+ # ==================== public API ====================
+ def device_ids(self) -> list[int]:
+ """
+ Returns a list of available device IDs for the current backend.
+ """
+ if self.backend.type == BackendType.CPU:
+ return []
+ return list(range(self.backend.device_count()))
+
+ def available_devices_with_cpu(self) -> list[torch.device]:
+ """
+ Returns the list of available torch.devices, including "cpu".
+ """
+ device_id_list = self.device_ids()
+ device_id_list = [self.torch_device(device_id) for device_id in device_id_list]
+ device_id_list.append(self.torch_device("cpu"))
+ return device_id_list
+
+ def torch_device(self, device: DeviceLike) -> torch.device:
+ """
+ Convert a DeviceLike specification into a torch.device object.
+ Args:
+ device: Any of the following formats:
+ an integer (e.g., 0, 1, ...),
+ a string (e.g., "0", "cuda:0", "cpu", ...),
+ or a torch.device object, which is returned as-is.
+ Raises:
+ ValueError: If the device is None or invalid.
+ """
+ if device is None:
+ raise ValueError(
+ "Device must be specified explicitly; None is not allowed."
+ )
+ if isinstance(device, torch.device):
+ return device
+ spec = parse_device_like(device)
+ if spec.type == "cpu":
+ return torch.device("cpu")
+ return self.backend.make_device(spec.index)
+
+ def move_model(
+ self, model: torch.nn.Module, device: DeviceLike = None
+ ) -> torch.nn.Module:
+ return model.to(self.torch_device(device))
+
+ def move_tensor(
+ self, tensor: torch.Tensor, device: DeviceLike = None
+ ) -> torch.Tensor:
+ return tensor.to(self.torch_device(device))
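Taken together, callers elsewhere in this commit (InferenceManager, the pipelines, model_loader) go through this singleton instead of touching torch.cuda directly. A usage sketch, illustrative and not part of the patch:

    import torch
    from iotdb.ainode.core.manager.device_manager import DeviceManager

    dm = DeviceManager()                        # @singleton: every call yields the same instance
    devices = dm.available_devices_with_cpu()   # e.g. [cuda:0, cuda:1, cpu], or just [cpu]
    dev = dm.torch_device("cpu")                # also accepts 0, "0", "cuda:0", torch.device
    t = dm.move_tensor(torch.zeros(4), dev)     # tensor.to(...) via the unified entry point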
diff --git a/iotdb-core/ainode/iotdb/ainode/core/manager/inference_manager.py b/iotdb-core/ainode/iotdb/ainode/core/manager/inference_manager.py
index b758908c5e4..addcfad6cfb 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/manager/inference_manager.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/manager/inference_manager.py
@@ -42,9 +42,9 @@ from iotdb.ainode.core.inference.pipeline.pipeline_loader import load_pipeline
from iotdb.ainode.core.inference.pool_controller import PoolController
from iotdb.ainode.core.inference.utils import generate_req_id
from iotdb.ainode.core.log import Logger
+from iotdb.ainode.core.manager.device_manager import DeviceManager
from iotdb.ainode.core.manager.model_manager import ModelManager
from iotdb.ainode.core.rpc.status import get_status
-from iotdb.ainode.core.util.gpu_mapping import get_available_devices
from iotdb.ainode.core.util.serde import (
convert_tensor_to_tsblock,
convert_tsblock_to_tensor,
@@ -54,10 +54,7 @@ from iotdb.thrift.ainode.ttypes import (
TForecastResp,
TInferenceReq,
TInferenceResp,
- TLoadModelReq,
- TShowLoadedModelsReq,
TShowLoadedModelsResp,
- TUnloadModelReq,
)
from iotdb.thrift.common.ttypes import TSStatus
@@ -71,6 +68,7 @@ class InferenceManager:
def __init__(self):
self._model_manager = ModelManager()
+ self._backend = DeviceManager()
self._model_mem_usage_map: Dict[str, int] = (
{}
) # store model memory usage for each model
@@ -85,57 +83,71 @@ class InferenceManager:
self._result_handler_thread.start()
self._pool_controller = PoolController(self._result_queue)
- def load_model(self, req: TLoadModelReq) -> TSStatus:
- devices_to_be_processed = []
- devices_not_to_be_processed = []
- for device_id in req.deviceIdList:
+ def load_model(
+ self, existing_model_id: str, device_id_list: list[torch.device]
+ ) -> TSStatus:
+ """
+ Load a model to specified devices.
+ Args:
+ existing_model_id (str): The ID of the model to be loaded.
+ device_id_list (list[torch.device]): List of device IDs to load the model onto.
+ Returns:
+ TSStatus: The status of the load model operation.
+ """
+ devices_to_be_processed: list[torch.device] = []
+ devices_not_to_be_processed: list[torch.device] = []
+ for device_id in device_id_list:
if self._pool_controller.has_request_pools(
- model_id=req.existingModelId, device_id=device_id
+ model_id=existing_model_id, device_id=device_id
):
devices_not_to_be_processed.append(device_id)
else:
devices_to_be_processed.append(device_id)
if len(devices_to_be_processed) > 0:
self._pool_controller.load_model(
- model_id=req.existingModelId, device_id_list=devices_to_be_processed
+ model_id=existing_model_id, device_id_list=devices_to_be_processed
)
logger.info(
- f"[Inference] Start loading model [{req.existingModelId}] to
devices [{devices_to_be_processed}], skipped devices
[{devices_not_to_be_processed}] cause they have already loaded this model."
+ f"[Inference] Start loading model [{existing_model_id}] to devices
[{devices_to_be_processed}], skipped devices [{devices_not_to_be_processed}]
cause they have already loaded this model."
)
return TSStatus(
code=TSStatusCode.SUCCESS_STATUS.value,
message='Successfully submitted load model task, please use "SHOW
LOADED MODELS" to check progress.',
)
- def unload_model(self, req: TUnloadModelReq) -> TSStatus:
+ def unload_model(
+ self, model_id: str, device_id_list: list[torch.device]
+ ) -> TSStatus:
devices_to_be_processed = []
devices_not_to_be_processed = []
- for device_id in req.deviceIdList:
+ for device_id in device_id_list:
if self._pool_controller.has_request_pools(
- model_id=req.modelId, device_id=device_id
+ model_id=model_id, device_id=device_id
):
devices_to_be_processed.append(device_id)
else:
devices_not_to_be_processed.append(device_id)
if len(devices_to_be_processed) > 0:
self._pool_controller.unload_model(
- model_id=req.modelId, device_id_list=req.deviceIdList
+ model_id=model_id, device_id_list=device_id_list
)
logger.info(
- f"[Inference] Start unloading model [{req.modelId}] from devices
[{devices_to_be_processed}], skipped devices [{devices_not_to_be_processed}]
cause they haven't loaded this model."
+ f"[Inference] Start unloading model [{model_id}] from devices
[{devices_to_be_processed}], skipped devices [{devices_not_to_be_processed}]
cause they haven't loaded this model."
)
return TSStatus(
code=TSStatusCode.SUCCESS_STATUS.value,
message='Successfully submitted unload model task, please use "SHOW LOADED MODELS" to check progress.',
)
- def show_loaded_models(self, req: TShowLoadedModelsReq) ->
TShowLoadedModelsResp:
+ def show_loaded_models(
+ self, device_id_list: list[torch.device]
+ ) -> TShowLoadedModelsResp:
return TShowLoadedModelsResp(
status=get_status(TSStatusCode.SUCCESS_STATUS),
deviceLoadedModelsMap=self._pool_controller.show_loaded_models(
- req.deviceIdList
- if len(req.deviceIdList) > 0
- else get_available_devices()
+ device_id_list
+ if len(device_id_list) > 0
+ else self._backend.available_devices_with_cpu()
),
)
@@ -202,7 +214,7 @@ class InferenceManager:
output_length,
)
- if self._pool_controller.has_request_pools(model_id):
+ if self._pool_controller.has_request_pools(model_id=model_id):
infer_req = InferenceRequest(
req_id=generate_req_id(),
model_id=model_id,
@@ -214,7 +226,9 @@ class InferenceManager:
outputs = self._process_request(infer_req)
else:
model_info = self._model_manager.get_model_info(model_id)
- inference_pipeline = load_pipeline(model_info, device="cpu")
+ inference_pipeline = load_pipeline(
+ model_info, device=self._backend.torch_device("cpu")
+ )
inputs = inference_pipeline.preprocess(
model_inputs_list, output_length=output_length
)
diff --git a/iotdb-core/ainode/iotdb/ainode/core/manager/utils.py b/iotdb-core/ainode/iotdb/ainode/core/manager/utils.py
index 17516876201..41bc6ec91c8 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/manager/utils.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/manager/utils.py
@@ -65,17 +65,17 @@ def measure_model_memory(device: torch.device, model_id: str) -> int:
def evaluate_system_resources(device: torch.device) -> dict:
- if torch.cuda.is_available():
+ if device.type == "cuda":
free_mem, total_mem = torch.cuda.mem_get_info()
logger.info(
- f"[Inference][Device-{device}] CUDA device memory:
free={free_mem/1024**2:.2f} MB, total={total_mem/1024**2:.2f} MB"
+ f"[Inference][{device}] CUDA device memory:
free={free_mem/1024**2:.2f} MB, total={total_mem/1024**2:.2f} MB"
)
return {"device": "cuda", "free_mem": free_mem, "total_mem": total_mem}
else:
free_mem = psutil.virtual_memory().available
total_mem = psutil.virtual_memory().total
logger.info(
- f"[Inference][Device-{device}] CPU memory:
free={free_mem/1024**2:.2f} MB, total={total_mem/1024**2:.2f} MB"
+ f"[Inference][{device}] CPU memory: free={free_mem/1024**2:.2f}
MB, total={total_mem/1024**2:.2f} MB"
)
return {"device": "cpu", "free_mem": free_mem, "total_mem": total_mem}
diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/chronos2/pipeline_chronos2.py b/iotdb-core/ainode/iotdb/ainode/core/model/chronos2/pipeline_chronos2.py
index 3fdc7b41b17..b28f8f35a66 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/model/chronos2/pipeline_chronos2.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/model/chronos2/pipeline_chronos2.py
@@ -34,7 +34,7 @@ logger = Logger()
class Chronos2Pipeline(ForecastPipeline):
def __init__(self, model_info, **model_kwargs):
- super().__init__(model_info, model_kwargs=model_kwargs)
+ super().__init__(model_info, **model_kwargs)
def preprocess(self, inputs, **infer_kwargs):
"""
diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/model_loader.py b/iotdb-core/ainode/iotdb/ainode/core/model/model_loader.py
index 605620d4261..289786c8aa3 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/model/model_loader.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/model/model_loader.py
@@ -34,12 +34,14 @@ from transformers import (
from iotdb.ainode.core.config import AINodeDescriptor
from iotdb.ainode.core.exception import ModelNotExistException
from iotdb.ainode.core.log import Logger
+from iotdb.ainode.core.manager.device_manager import DeviceManager
from iotdb.ainode.core.model.model_constants import ModelCategory
from iotdb.ainode.core.model.model_info import ModelInfo
from iotdb.ainode.core.model.sktime.modeling_sktime import create_sktime_model
from iotdb.ainode.core.model.utils import import_class_from_path, temporary_sys_path
logger = Logger()
+BACKEND = DeviceManager()
def load_model(model_info: ModelInfo, **model_kwargs) -> Any:
@@ -105,17 +107,13 @@ def load_model_from_transformers(model_info: ModelInfo, **model_kwargs):
model_cls = AutoModelForCausalLM
if train_from_scratch:
- model = model_cls.from_config(
- config_cls, trust_remote_code=trust_remote_code, device_map=device_map
- )
+ model = model_cls.from_config(config_cls, trust_remote_code=trust_remote_code)
else:
model = model_cls.from_pretrained(
- model_path,
- trust_remote_code=trust_remote_code,
- device_map=device_map,
+ model_path, trust_remote_code=trust_remote_code
)
- return model
+ return BACKEND.move_model(model, device_map)
def load_model_from_pt(model_info: ModelInfo, **kwargs):
@@ -138,7 +136,7 @@ def load_model_from_pt(model_info: ModelInfo, **kwargs):
model = torch.compile(model)
except Exception as e:
logger.warning(f"acceleration failed, fallback to normal mode:
{str(e)}")
- return model.to(device_map)
+ return BACKEND.move_model(model, device_map)
def load_model_for_efficient_inference():
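
Both loading paths now delegate device placement to BACKEND.move_model(model, device_map) instead of passing device_map to transformers or calling .to() inline, so placement policy lives in one place. A minimal sketch of what such a helper might do, assuming it simply normalizes the target and wraps .to() (the actual DeviceManager implementation is in device_manager.py, not this hunk):

    import torch

    def move_model(model: torch.nn.Module, device_map) -> torch.nn.Module:
        # Hypothetical sketch: accept a torch.device or a device string and
        # delegate to .to(); a real backend could add device-specific setup.
        device = device_map if isinstance(device_map, torch.device) else torch.device(device_map)
        return model.to(device)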
diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/sktime/pipeline_sktime.py b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/pipeline_sktime.py
index 964ab156e26..12b2668543e 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/model/sktime/pipeline_sktime.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/pipeline_sktime.py
@@ -31,7 +31,7 @@ logger = Logger()
class SktimePipeline(ForecastPipeline):
def __init__(self, model_info: ModelInfo, **model_kwargs):
model_kwargs.pop("device", None) # sktime models run on CPU
- super().__init__(model_info, model_kwargs=model_kwargs)
+ super().__init__(model_info, **model_kwargs)
def preprocess(
self,
diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/sundial/pipeline_sundial.py b/iotdb-core/ainode/iotdb/ainode/core/model/sundial/pipeline_sundial.py
index 1715f190e32..8aa9b175169 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/model/sundial/pipeline_sundial.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/model/sundial/pipeline_sundial.py
@@ -28,7 +28,7 @@ logger = Logger()
class SundialPipeline(ForecastPipeline):
def __init__(self, model_info: ModelInfo, **model_kwargs):
- super().__init__(model_info, model_kwargs=model_kwargs)
+ super().__init__(model_info, **model_kwargs)
def preprocess(self, inputs, **infer_kwargs) -> torch.Tensor:
"""
diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/timer_xl/pipeline_timer.py b/iotdb-core/ainode/iotdb/ainode/core/model/timer_xl/pipeline_timer.py
index bb54eed4ec6..213e6102c8b 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/model/timer_xl/pipeline_timer.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/model/timer_xl/pipeline_timer.py
@@ -28,7 +28,7 @@ logger = Logger()
class TimerPipeline(ForecastPipeline):
def __init__(self, model_info: ModelInfo, **model_kwargs):
- super().__init__(model_info, model_kwargs=model_kwargs)
+ super().__init__(model_info, **model_kwargs)
def preprocess(self, inputs, **infer_kwargs) -> torch.Tensor:
"""
diff --git a/iotdb-core/ainode/iotdb/ainode/core/rpc/handler.py b/iotdb-core/ainode/iotdb/ainode/core/rpc/handler.py
index 492802fc060..e925c3791b8 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/rpc/handler.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/rpc/handler.py
@@ -19,10 +19,10 @@
from iotdb.ainode.core.constant import TSStatusCode
from iotdb.ainode.core.log import Logger
from iotdb.ainode.core.manager.cluster_manager import ClusterManager
+from iotdb.ainode.core.manager.device_manager import DeviceManager
from iotdb.ainode.core.manager.inference_manager import InferenceManager
from iotdb.ainode.core.manager.model_manager import ModelManager
from iotdb.ainode.core.rpc.status import get_status
-from iotdb.ainode.core.util.gpu_mapping import get_available_devices
from iotdb.thrift.ainode import IAINodeRPCService
from iotdb.thrift.ainode.ttypes import (
TAIHeartbeatReq,
@@ -48,25 +48,12 @@ from iotdb.thrift.common.ttypes import TSStatus
logger = Logger()
-def _ensure_device_id_is_available(device_id_list: list[str]) -> TSStatus:
- """
- Ensure that the device IDs in the provided list are available.
- """
- available_devices = get_available_devices()
- for device_id in device_id_list:
- if device_id not in available_devices:
- return TSStatus(
- code=TSStatusCode.UNAVAILABLE_AI_DEVICE_ERROR.value,
- message=f"AIDevice ID [{device_id}] is not available. You can
use 'SHOW AI_DEVICES' to retrieve the available devices.",
- )
- return TSStatus(code=TSStatusCode.SUCCESS_STATUS.value)
-
-
class AINodeRPCServiceHandler(IAINodeRPCService.Iface):
def __init__(self, ainode):
self._ainode = ainode
self._model_manager = ModelManager()
self._inference_manager = InferenceManager()
+ self._backend = DeviceManager()
# ==================== Cluster Management ====================
@@ -82,9 +69,12 @@ class AINodeRPCServiceHandler(IAINodeRPCService.Iface):
return ClusterManager.get_heart_beat(req)
def showAIDevices(self) -> TShowAIDevicesResp:
+ device_id_map = {"cpu": "cpu"}
+ for device_id in self._backend.device_ids():
+ device_id_map[str(device_id)] = self._backend.type.value
return TShowAIDevicesResp(
status=TSStatus(code=TSStatusCode.SUCCESS_STATUS.value),
- deviceIdList=get_available_devices(),
+ deviceIdMap=device_id_map,
)
# ==================== Model Management ====================
@@ -102,25 +92,33 @@ class AINodeRPCServiceHandler(IAINodeRPCService.Iface):
status = self._ensure_model_is_registered(req.existingModelId)
if status.code != TSStatusCode.SUCCESS_STATUS.value:
return status
- status = _ensure_device_id_is_available(req.deviceIdList)
+ status = self._ensure_device_id_is_available(req.deviceIdList)
if status.code != TSStatusCode.SUCCESS_STATUS.value:
return status
- return self._inference_manager.load_model(req)
+ return self._inference_manager.load_model(
+ req.existingModelId,
+ [self._backend.torch_device(device_id) for device_id in req.deviceIdList],
+ )
def unloadModel(self, req: TUnloadModelReq) -> TSStatus:
status = self._ensure_model_is_registered(req.modelId)
if status.code != TSStatusCode.SUCCESS_STATUS.value:
return status
- status = _ensure_device_id_is_available(req.deviceIdList)
+ status = self._ensure_device_id_is_available(req.deviceIdList)
if status.code != TSStatusCode.SUCCESS_STATUS.value:
return status
- return self._inference_manager.unload_model(req)
+ return self._inference_manager.unload_model(
+ req.modelId,
+ [self._backend.torch_device(device_id) for device_id in req.deviceIdList],
+ )
def showLoadedModels(self, req: TShowLoadedModelsReq) -> TShowLoadedModelsResp:
- status = _ensure_device_id_is_available(req.deviceIdList)
+ status = self._ensure_device_id_is_available(req.deviceIdList)
if status.code != TSStatusCode.SUCCESS_STATUS.value:
return TShowLoadedModelsResp(status=status, deviceLoadedModelsMap={})
- return self._inference_manager.show_loaded_models(req)
+ return self._inference_manager.show_loaded_models(
+ [self._backend.torch_device(device_id) for device_id in req.deviceIdList]
+ )
def _ensure_model_is_registered(self, model_id: str) -> TSStatus:
if not self._model_manager.is_model_registered(model_id):
@@ -144,6 +142,26 @@ class AINodeRPCServiceHandler(IAINodeRPCService.Iface):
return TForecastResp(status, [])
return self._inference_manager.forecast(req)
+ # ==================== Internal API ====================
+
+ def _ensure_device_id_is_available(self, device_id_list: list[str]) -> TSStatus:
+ """
+ Ensure that the device IDs in the provided list are available.
+ """
+ available_devices = self._backend.device_ids()
+ for device_id in device_id_list:
+ try:
+ if device_id == "cpu":
+ continue
+ if int(device_id) not in available_devices:
+ raise ValueError(f"Invalid device ID [{device_id}]")
+ except (TypeError, ValueError):
+ return TSStatus(
+ code=TSStatusCode.UNAVAILABLE_AI_DEVICE_ERROR.value,
+ message=f"AIDevice ID [{device_id}] is not available. You
can use 'SHOW AI_DEVICES' to retrieve the available devices.",
+ )
+ return TSStatus(code=TSStatusCode.SUCCESS_STATUS.value)
+
# ==================== Tuning ====================
def createTuningTask(self, req: TTuningReq) -> TSStatus:
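
The validator above accepts "cpu" unconditionally and otherwise requires an integer index present in the backend's device_ids(); non-numeric IDs fail through the (TypeError, ValueError) catch. A standalone restatement of the check, for illustration only:

    def is_available(device_id: str, available_ids: list[int]) -> bool:
        # Mirrors _ensure_device_id_is_available above, minus the Thrift status.
        if device_id == "cpu":
            return True
        try:
            return int(device_id) in available_ids
        except (TypeError, ValueError):
            return False

    assert is_available("cpu", [0, 1])
    assert is_available("1", [0, 1])
    assert not is_available("2", [0, 1])       # index not reported by the backend
    assert not is_available("cuda:0", [0, 1])  # non-numeric IDs are rejected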
diff --git a/iotdb-core/ainode/iotdb/ainode/core/util/gpu_mapping.py b/iotdb-core/ainode/iotdb/ainode/core/util/gpu_mapping.py
deleted file mode 100644
index 72b056adb87..00000000000
--- a/iotdb-core/ainode/iotdb/ainode/core/util/gpu_mapping.py
+++ /dev/null
@@ -1,93 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-#
-
-import torch
-
-
-def convert_device_id_to_torch_device(device_id: str) -> torch.device:
- """
- Converts a device ID string to a torch.device object.
-
- Args:
- device_id (str): The device ID string. It can be "cpu" or a GPU index like "0", "1", etc.
-
- Returns:
- torch.device: The corresponding torch.device object.
-
- Raises:
- ValueError: If the device_id is not "cpu" or a valid integer string.
- """
- if device_id.lower() == "cpu":
- return torch.device("cpu")
- try:
- gpu_index = int(device_id)
- if gpu_index < 0:
- raise ValueError
- return torch.device(f"cuda:{gpu_index}")
- except ValueError:
- raise ValueError(
- f"Invalid device_id '{device_id}'. It should be 'cpu' or a
non-negative integer string."
- )
-
-
-def get_available_gpus() -> list[int]:
- """
- Returns a list of available GPU indices if CUDA is available, otherwise returns an empty list.
- """
-
- if not torch.cuda.is_available():
- return []
- return list(range(torch.cuda.device_count()))
-
-
-def get_available_devices() -> list[str]:
- """
- Returns: a list of available device IDs as strings, including "cpu".
- """
- device_id_list = get_available_gpus()
- device_id_list = [str(device_id) for device_id in device_id_list]
- device_id_list.append("cpu")
- return device_id_list
-
-
-def parse_devices(devices):
- """
- Parses the input string of GPU devices and returns a comma-separated string of valid GPU indices.
-
- Args:
- devices (str): A comma-separated string of GPU indices (e.g., "0,1,2").
- Returns:
- str: A comma-separated string of valid GPU indices corresponding to the input. All available GPUs if no input is provided.
- Exceptions:
- RuntimeError: If no GPUs are available.
- ValueError: If any of the provided GPU indices are not available.
- """
- if devices is None or devices == "":
- gpu_ids = get_available_gpus()
- if not gpu_ids:
- raise RuntimeError("No available GPU")
- return ",".join(map(str, gpu_ids))
- else:
- gpu_ids = [int(gpu) for gpu in devices.split(",")]
- available_gpus = get_available_gpus()
- for gpu_id in gpu_ids:
- if gpu_id not in available_gpus:
- raise ValueError(
- f"GPU {gpu_id} is not available, the available choices
are: {available_gpus}"
- )
- return devices
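
The deleted helpers are superseded by the new device layer; going by the call sites elsewhere in this commit, the rough correspondence is:

    # Old (gpu_mapping.py)                    ->  New (DeviceManager, per call sites above)
    # convert_device_id_to_torch_device("0")  ->  DeviceManager().torch_device("0")
    # get_available_devices()                 ->  DeviceManager().available_devices_with_cpu()
    # get_available_gpus()                    ->  DeviceManager().device_ids()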
diff --git a/iotdb-core/ainode/pyproject.toml b/iotdb-core/ainode/pyproject.toml
index fc2068e66e4..c3965d9d099 100644
--- a/iotdb-core/ainode/pyproject.toml
+++ b/iotdb-core/ainode/pyproject.toml
@@ -76,7 +76,7 @@ exclude = [
]
[tool.poetry.dependencies]
-python = ">=3.11.0,<3.14.0"
+python = ">=3.11.0,<3.12.0"
# ---- DL / HF stack ----
torch = "^2.8.0,<2.9.0"
@@ -88,9 +88,9 @@ safetensors = "^0.6.2"
einops = "^0.8.1"
# ---- Core scientific stack ----
-numpy = "^2.3.2"
+numpy = ">=2.0,<2.4.0"
+pandas = ">=2.0,<2.4.0"
scipy = "^1.12.0"
-pandas = "^2.3.2"
scikit-learn = "^1.7.1"
statsmodels = "^0.14.5"
sktime = "0.40.1"
diff --git a/iotdb-core/ainode/resources/conf/iotdb-ainode.properties b/iotdb-core/ainode/resources/conf/iotdb-ainode.properties
index 88948138340..fc569b27807 100644
--- a/iotdb-core/ainode/resources/conf/iotdb-ainode.properties
+++ b/iotdb-core/ainode/resources/conf/iotdb-ainode.properties
@@ -58,7 +58,7 @@ ain_cluster_ingress_time_zone=UTC+8
# The device space allocated for inference
# Datatype: Float
-ain_inference_memory_usage_ratio=0.4
+ain_inference_memory_usage_ratio=0.2
# The overhead ratio for inference, used to estimate the pool size
# Datatype: Float
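
Halving ain_inference_memory_usage_ratio from 0.4 to 0.2 shrinks the memory budget available to inference pools. Assuming the ratio simply scales the free memory reported by evaluate_system_resources (a plausible reading; the pool-sizing code is not in this hunk):

    free_mem = 16 * 1024**3      # e.g. 16 GiB free on the device
    old_budget = free_mem * 0.4  # ~6.4 GiB
    new_budget = free_mem * 0.2  # ~3.2 GiB
    print(f"{old_budget / 1024**3:.1f} GiB -> {new_budget / 1024**3:.1f} GiB")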
diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/metadata/ai/ShowAIDevicesTask.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/metadata/ai/ShowAIDevicesTask.java
index 690f6f9485f..3ccad1e24d5 100644
--- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/metadata/ai/ShowAIDevicesTask.java
+++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/metadata/ai/ShowAIDevicesTask.java
@@ -36,6 +36,7 @@ import org.apache.tsfile.read.common.block.TsBlockBuilder;
import org.apache.tsfile.utils.BytesUtils;
import java.util.List;
+import java.util.Map;
import java.util.stream.Collectors;
public class ShowAIDevicesTask implements IConfigTask {
@@ -53,11 +54,15 @@ public class ShowAIDevicesTask implements IConfigTask {
.map(ColumnHeader::getColumnType)
.collect(Collectors.toList());
TsBlockBuilder builder = new TsBlockBuilder(outputDataTypes);
- for (String deviceId : resp.getDeviceIdList()) {
- builder.getTimeColumnBuilder().writeLong(0L);
- builder.getColumnBuilder(0).writeBinary(BytesUtils.valueOf(deviceId));
- builder.declarePosition();
- }
+ resp.getDeviceIdMap().entrySet().stream()
+ .sorted(Map.Entry.comparingByKey())
+ .forEach(
+ deviceEntry -> {
+ builder.getTimeColumnBuilder().writeLong(0L);
+ builder.getColumnBuilder(0).writeBinary(BytesUtils.valueOf(deviceEntry.getKey()));
+ builder.getColumnBuilder(1).writeBinary(BytesUtils.valueOf(deviceEntry.getValue()));
+ builder.declarePosition();
+ });
DatasetHeader datasetHeader = DatasetHeaderFactory.getShowAIDevicesHeader();
future.set(new ConfigTaskResult(TSStatusCode.SUCCESS_STATUS, builder.build(), datasetHeader));
}
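
With the second column, SHOW AI_DEVICES now reports the backend type per device, sorted by ID. For a host with two CUDA devices, showAIDevices above would build {"cpu": "cpu", "0": "cuda", "1": "cuda"} (the exact type strings come from the backend enum), which would render roughly as:

    +--------+----------+
    |DeviceId|DeviceType|
    +--------+----------+
    |       0|      cuda|
    |       1|      cuda|
    |     cpu|       cpu|
    +--------+----------+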
diff --git a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/schema/column/ColumnHeaderConstant.java b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/schema/column/ColumnHeaderConstant.java
index 0459d4d2c86..dba2c2e368d 100644
--- a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/schema/column/ColumnHeaderConstant.java
+++ b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/schema/column/ColumnHeaderConstant.java
@@ -36,6 +36,7 @@ public class ColumnHeaderConstant {
public static final String VALUE = "Value";
public static final String DEVICE = "Device";
public static final String DEVICE_ID = "DeviceId";
+ public static final String DEVICE_TYPE = "DeviceType";
public static final String EXPLAIN_ANALYZE = "Explain Analyze";
// column names for schema statement
@@ -660,7 +661,9 @@ public class ColumnHeaderConstant {
new ColumnHeader(COUNT_INSTANCES, TSDataType.INT32));
public static final List<ColumnHeader> showAIDevicesColumnHeaders =
- ImmutableList.of(new ColumnHeader(DEVICE_ID, TSDataType.TEXT));
+ ImmutableList.of(
+ new ColumnHeader(DEVICE_ID, TSDataType.TEXT),
+ new ColumnHeader(DEVICE_TYPE, TSDataType.TEXT));
public static final List<ColumnHeader> showLogicalViewColumnHeaders =
ImmutableList.of(
diff --git a/iotdb-protocol/thrift-ainode/src/main/thrift/ainode.thrift b/iotdb-protocol/thrift-ainode/src/main/thrift/ainode.thrift
index 8a5971823ec..1cb585f0323 100644
--- a/iotdb-protocol/thrift-ainode/src/main/thrift/ainode.thrift
+++ b/iotdb-protocol/thrift-ainode/src/main/thrift/ainode.thrift
@@ -118,7 +118,7 @@ struct TShowLoadedModelsResp {
struct TShowAIDevicesResp {
1: required common.TSStatus status
- 2: required list<string> deviceIdList
+ 2: required map<string, string> deviceIdMap
}
struct TLoadModelReq {