This is an automated email from the ASF dual-hosted git repository.
yongzao pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/iotdb.git
The following commit(s) were added to refs/heads/master by this push:
new cc66a38b223 [AINode] Reconstruct inference structure (#16237)
cc66a38b223 is described below
commit cc66a38b223a00008186360e70034b7e14e8d10a
Author: Leo <[email protected]>
AuthorDate: Sun Aug 24 08:44:41 2025 +0800
[AINode] Reconstruct inference structure (#16237)
---
.../{scheduler => dispatcher}/__init__.py | 0
.../abstract_dispatcher.py} | 26 ++++
.../core/inference/dispatcher/basic_dispatcher.py | 60 ++++++++
.../core/inference/inference_request_pool.py | 12 +-
.../ainode/core/inference/pool_controller.py | 155 ++++++++++++++-------
...ference_request_pool_group.py => pool_group.py} | 32 ++++-
.../ainode/ainode/core/inference/pool_scheduler.py | 125 -----------------
.../{scheduler => pool_scheduler}/__init__.py | 0
.../pool_scheduler/abstract_pool_scheduler.py | 56 ++++++++
.../pool_scheduler/basic_pool_scheduler.py | 59 ++++++++
.../ainode/core/inference/request_controller.py | 89 ------------
.../{scheduler => request_scheduler}/__init__.py | 0
.../abstract_request_scheduler.py} | 2 +-
.../basic_request_scheduler.py} | 9 +-
.../ainode/core/manager/inference_manager.py | 56 +++++++-
iotdb-core/ainode/ainode/core/manager/utils.py | 8 +-
16 files changed, 406 insertions(+), 283 deletions(-)
diff --git a/iotdb-core/ainode/ainode/core/inference/scheduler/__init__.py b/iotdb-core/ainode/ainode/core/inference/dispatcher/__init__.py
similarity index 100%
copy from iotdb-core/ainode/ainode/core/inference/scheduler/__init__.py
copy to iotdb-core/ainode/ainode/core/inference/dispatcher/__init__.py
diff --git a/iotdb-core/ainode/ainode/core/inference/scheduler/__init__.py b/iotdb-core/ainode/ainode/core/inference/dispatcher/abstract_dispatcher.py
similarity index 50%
copy from iotdb-core/ainode/ainode/core/inference/scheduler/__init__.py
copy to iotdb-core/ainode/ainode/core/inference/dispatcher/abstract_dispatcher.py
index 2a1e720805f..18cdee14cf0 100644
--- a/iotdb-core/ainode/ainode/core/inference/scheduler/__init__.py
+++ b/iotdb-core/ainode/ainode/core/inference/dispatcher/abstract_dispatcher.py
@@ -15,3 +15,29 @@
# specific language governing permissions and limitations
# under the License.
#
+
+from abc import ABC, abstractmethod
+from typing import Dict
+
+from ainode.core.inference.inference_request import InferenceRequest
+from ainode.core.inference.inference_request_pool import PoolState
+
+
+class AbstractDispatcher(ABC):
+ """
+ Abstract base class for dispatchers that handle inference requests.
+ """
+
+ def __init__(self, pool_states: Dict[int, PoolState]):
+ """
+ Args:
+ pool_states: Dictionary containing the states of inference request
pools in the same pool group.
+ """
+ self.pool_states = pool_states
+
+ @abstractmethod
+ def dispatch_request(self, req: InferenceRequest, pool_ids: list[int]) ->
int:
+ """
+ Dispatch an inference request to the appropriate pool.
+ """
+ pass
diff --git a/iotdb-core/ainode/ainode/core/inference/dispatcher/basic_dispatcher.py b/iotdb-core/ainode/ainode/core/inference/dispatcher/basic_dispatcher.py
new file mode 100644
index 00000000000..590b37fe349
--- /dev/null
+++ b/iotdb-core/ainode/ainode/core/inference/dispatcher/basic_dispatcher.py
@@ -0,0 +1,60 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+from ainode.core.exception import (
+ InferenceModelInternalError,
+)
+from ainode.core.inference.dispatcher.abstract_dispatcher import
AbstractDispatcher
+from ainode.core.inference.inference_request import InferenceRequest
+from ainode.core.inference.inference_request_pool import PoolState
+from ainode.core.log import Logger
+
+logger = Logger()
+
+
+class BasicDispatcher(AbstractDispatcher):
+ """
+ Basic dispatcher for inference requests.
+ """
+
+ def __init__(self, pool_states):
+ super().__init__(pool_states)
+
+ def _select_pool_by_hash(self, req, pool_ids) -> int:
+ """
+ Select a pool for the given request using a hash-based approach.
+ """
+ model_id = req.model_id
+ if not pool_ids:
+ raise InferenceModelInternalError(
+ f"No available pools for model {model_id}"
+ )
+ start_idx = hash(req.req_id) % len(pool_ids)
+ n = len(pool_ids)
+ for i in range(n):
+ pool_id = pool_ids[(start_idx + i) % n]
+ state = self.pool_states[pool_id]
+ if state == PoolState.RUNNING:
+ return pool_id
+ raise InferenceModelInternalError(
+ f"No RUNNING pools available for model {model_id}"
+ )
+
+ def dispatch_request(self, req: InferenceRequest, pool_ids: list[int]) ->
int:
+ pool_idx = self._select_pool_by_hash(req, pool_ids)
+ return pool_idx
diff --git a/iotdb-core/ainode/ainode/core/inference/inference_request_pool.py b/iotdb-core/ainode/ainode/core/inference/inference_request_pool.py
index 956d0773785..9005dbe642c 100644
--- a/iotdb-core/ainode/ainode/core/inference/inference_request_pool.py
+++ b/iotdb-core/ainode/ainode/core/inference/inference_request_pool.py
@@ -29,7 +29,9 @@ from transformers import PretrainedConfig
from ainode.core.config import AINodeDescriptor
from ainode.core.constant import INFERENCE_LOG_FILE_NAME_PREFIX_TEMPLATE
-from ainode.core.inference.scheduler.basic_scheduler import BasicScheduler
+from ainode.core.inference.request_scheduler.basic_request_scheduler import (
+ BasicRequestScheduler,
+)
from ainode.core.log import Logger
from ainode.core.manager.model_manager import ModelManager
@@ -74,7 +76,7 @@ class InferenceRequestPool(mp.Process):
self._waiting_queue = request_queue # Requests that are waiting to be
processed
self._running_queue = mp.Queue() # Requests that are currently being
processed
self._finished_queue = result_queue # Requests that are finished
- self._scheduler = BasicScheduler(
+ self._request_scheduler = BasicRequestScheduler(
self._waiting_queue, self._running_queue, self._finished_queue,
self.pool_id
)
self._stop_event = mp.Event()
@@ -113,7 +115,7 @@ class InferenceRequestPool(mp.Process):
)
def _activate_requests(self):
- requests = self._scheduler.schedule_activate()
+ requests = self._request_scheduler.schedule_activate()
for request in requests:
request.inputs = request.inference_pipeline.preprocess_inputs(
request.inputs
@@ -130,7 +132,7 @@ class InferenceRequestPool(mp.Process):
self._activate_requests()
def _step(self):
- requests = self._scheduler.schedule_step()
+ requests = self._request_scheduler.schedule_step()
# TODO: We need a batcher to accelerate the concurrent inference
for request in requests:
if self.model_id == "sundial":
@@ -178,7 +180,7 @@ class InferenceRequestPool(mp.Process):
self.logger = Logger(
INFERENCE_LOG_FILE_NAME_PREFIX_TEMPLATE.format(self.device)
)
- self._scheduler.device = self.device
+ self._request_scheduler.device = self.device
self.model = self._model_manager.load_model(self.model_id,
{}).to(self.device)
self.ready_event.set()
diff --git a/iotdb-core/ainode/ainode/core/inference/pool_controller.py b/iotdb-core/ainode/ainode/core/inference/pool_controller.py
index 121d3023df9..8f4efc9ebf3 100644
--- a/iotdb-core/ainode/ainode/core/inference/pool_controller.py
+++ b/iotdb-core/ainode/ainode/core/inference/pool_controller.py
@@ -15,19 +15,24 @@
# specific language governing permissions and limitations
# under the License.
#
-from collections import defaultdict
+
+import threading
from typing import Dict, Optional
import torch
import torch.multiprocessing as mp
-from ainode.core.exception import (
- InferenceModelInternalError,
-)
from ainode.core.inference.inference_request import InferenceRequest
from ainode.core.inference.inference_request_pool import InferenceRequestPool,
PoolState
-from ainode.core.inference.inference_request_pool_group import PoolGroup
+from ainode.core.inference.pool_group import PoolGroup
+from ainode.core.inference.pool_scheduler.basic_pool_scheduler import (
+ BasicPoolScheduler,
+ ScaleActionType,
+)
from ainode.core.log import Logger
+from ainode.core.model.sundial.configuration_sundial import SundialConfig
+from ainode.core.model.timerxl.configuration_timer import TimerConfig
+from ainode.core.util.decorator import synchronized
logger = Logger()
@@ -35,53 +40,116 @@ logger = Logger()
class PoolController:
"""
A controller for handling inference request pools.
- It handles the registration of pools, adding and removing requests,
- and managing the state of each pool.
"""
DEFAULT_DEVICE = torch.device("cpu")
# DEFAULT_DEVICE = torch.device("cuda" if torch.cuda.is_available() else
"cpu")
- def __init__(self):
- # structure: {model_id: {pool_id: PoolState}}
- self.pool_states: Dict[str, Dict[int, PoolState]] = defaultdict(dict)
+ def __init__(self, result_queue: mp.Queue):
# structure: {model_id: PoolGroup}
self._request_pool_map: Dict[str, PoolGroup] = {}
-
- def dispatch_request(self, model_id, req: InferenceRequest):
- pool_idx = self._select_pool_by_hash(model_id, req.req_id)
- self.add_request(pool_idx, req)
- logger.debug(
-
f"[Inference][Device-{self.DEFAULT_DEVICE}][Pool-{pool_idx}][ID-{req.req_id}]
Request is queued for inference"
+ self._result_queue = result_queue
+ self._pool_scheduler = BasicPoolScheduler(self._request_pool_map)
+
+ @synchronized(threading.Lock())
+ def first_req_init(self, model_id: str):
+ if not self.has_request_pools(model_id):
+ actions = self._pool_scheduler.schedule(model_id)
+ for action in actions:
+ if action.action == ScaleActionType.SCALE_UP:
+ # initialize the first pool
+ self._first_pool_init(action.model_id)
+ # start a background thread to expand pools
+ expand_thread = threading.Thread(
+ target=self._expand_pools,
+ args=(action.model_id, 1, action.amount - 1),
+ daemon=True,
+ )
+ expand_thread.start()
+ elif action.action == ScaleActionType.SCALE_DOWN:
+ # TODO: implement scale down logic
+ pass
+
+ def _first_pool_init(self, model_id: str):
+ if model_id == "sundial":
+ config = SundialConfig()
+ elif model_id == "timer_xl":
+ config = TimerConfig()
+ first_queue = mp.Queue()
+ ready_event = mp.Event()
+ first_pool = InferenceRequestPool(
+ pool_id=0,
+ model_id=model_id,
+ config=config,
+ request_queue=first_queue,
+ result_queue=self._result_queue,
+ ready_event=ready_event,
)
+ first_pool.start()
+ self.register_pool(model_id, 0, first_pool, first_queue)
+ if not ready_event.wait(timeout=30):
+ self.unregister_pool(model_id, 0)
+ logger.error(
+ f"[Inference][Device-{self.DEFAULT_DEVICE}][Pool-0] First pool
failed to be ready in time"
+ )
+ else:
+ self.set_state(model_id, 0, PoolState.RUNNING)
+ logger.info(
+ f"[Inference][Device-{self.DEFAULT_DEVICE}][Pool-0]
Initialized inference request pool for model {model_id}"
+ )
- def _select_pool_by_hash(self, model_id, req_id) -> int:
- pool_ids = self.get_pool_ids(model_id)
- if not pool_ids:
- raise InferenceModelInternalError(
- f"No available pools for model {model_id}"
+ def _expand_pools(self, model_id, start_idx, count):
+ for idx in range(count):
+ queue = mp.Queue()
+ pool_id = start_idx + idx
+ if model_id == "sundial":
+ config = SundialConfig()
+ elif model_id == "timer_xl":
+ config = TimerConfig()
+ pool = InferenceRequestPool(
+ pool_id=pool_id,
+ model_id=model_id,
+ config=config,
+ request_queue=queue,
+ result_queue=self._result_queue,
+ ready_event=mp.Event(),
)
- start_idx = hash(req_id) % len(pool_ids)
- n = len(pool_ids)
- for i in range(n):
- pool_id = pool_ids[(start_idx + i) % n]
- state = self.get_state(model_id, pool_id)
- if state == PoolState.RUNNING:
- return pool_id
- raise InferenceModelInternalError(
- f"No RUNNING pools available for model {model_id}"
- )
+ pool.start()
+ self.register_pool(model_id, pool_id, pool, queue)
+ if not pool.ready_event.wait(timeout=30):
+ logger.error(
+
f"[Inference][Device-{self.DEFAULT_DEVICE}][Pool-{pool_id}] Pool failed to be
ready in time"
+ )
+ self.unregister_pool(model_id, pool_id)
+ continue
+ else:
+ self.set_state(model_id, pool_id, PoolState.RUNNING)
+ logger.info(
+
f"[Inference][Device-{self.DEFAULT_DEVICE}][Pool-{pool.pool_id}] New inference
request pool started for model {model_id}"
+ )
+
+ def add_request(self, model_id: str, req: InferenceRequest):
+ # lazy initialization for first request
+ model_id = req.model_id
+ if not self.has_request_pools(model_id):
+ self.first_req_init(model_id)
+ self._request_pool_map[model_id].dispatch_request(req)
+
+ def get_state(self, model_id, pool_id) -> PoolState:
+ return self._request_pool_map[model_id].get_state(pool_id)
+
+ def set_state(self, model_id, pool_id, state):
+ self._request_pool_map[model_id].set_state(pool_id, state)
def register_pool(self, model_id, pool_id, request_pool, request_queue):
- self.set_state(model_id, pool_id, PoolState.RUNNING)
self.set_request_pool_map(model_id, pool_id, request_pool,
request_queue)
+ pool_group = self._request_pool_map.get(model_id)
+ pool_group.set_state(pool_id, PoolState.INITIALIZING)
- def add_request(self, pool_id, req):
- req_q = self.get_request_queue(req.model_id, pool_id)
- req_q.put(req)
-
- def remove_request(self, model_id, req_id):
- pass
+ def unregister_pool(self, model_id, pool_id):
+ self._request_pool_map[model_id].remove_pool(pool_id)
+ if not self._request_pool_map[model_id].get_pool_ids():
+ self._request_pool_map.pop(model_id, None)
def get_pool_ids(self, model_id) -> list[int]:
return self._request_pool_map[model_id].get_pool_ids()
@@ -106,17 +174,8 @@ class PoolController:
self._request_pool_map[model_id] = PoolGroup(model_id)
self._request_pool_map[model_id].add_pool(pool_id, request_pool,
request_queue)
- def get_state(self, model_id, pool_id) -> PoolState:
- return self.pool_states[model_id][pool_id]
-
- def set_state(self, model_id, pool_id, state):
- self.pool_states[model_id][pool_id] = state
-
- def get_load(self, model_id, pool_id) -> int:
- pass
-
def shutdown(self):
- for model_id, pool_group in self._request_pool_map.items():
+ for pool_group in self._request_pool_map.values():
for pool_id in pool_group.get_pool_ids():
request_pool = pool_group.get_request_pool(pool_id)
request_queue = pool_group.get_request_queue(pool_id)
diff --git a/iotdb-core/ainode/ainode/core/inference/inference_request_pool_group.py b/iotdb-core/ainode/ainode/core/inference/pool_group.py
similarity index 65%
rename from iotdb-core/ainode/ainode/core/inference/inference_request_pool_group.py
rename to iotdb-core/ainode/ainode/core/inference/pool_group.py
index 3dc2929e388..99704008c22 100644
--- a/iotdb-core/ainode/ainode/core/inference/inference_request_pool_group.py
+++ b/iotdb-core/ainode/ainode/core/inference/pool_group.py
@@ -17,12 +17,14 @@
#
from typing import Dict, Tuple
+import torch
import torch.multiprocessing as mp
from ainode.core.exception import (
InferenceModelInternalError,
)
-from ainode.core.inference.inference_request_pool import InferenceRequestPool
+from ainode.core.inference.dispatcher.basic_dispatcher import BasicDispatcher
+from ainode.core.inference.inference_request_pool import InferenceRequestPool,
PoolState
from ainode.core.log import Logger
logger = Logger()
@@ -33,9 +35,16 @@ class PoolGroup:
A group of inference request pools for a specific model.
"""
+ DEFAULT_DEVICE = torch.device("cpu")
+ # DEFAULT_DEVICE = torch.device("cuda" if torch.cuda.is_available() else
"cpu")
+
def __init__(self, model_id):
+ # structure: {pool_id: (InferenceRequestPool, mp.Queue)}
self.pool_group: Dict[int, Tuple[InferenceRequestPool, mp.Queue]] = {}
+ # structure: {pool_id: PoolState}
+ self.pool_states: Dict[int, PoolState] = {}
self.model_id = model_id
+ self.request_dispatcher = BasicDispatcher(self.pool_states)
def get_pool_group(self) -> Dict[int, Tuple[InferenceRequestPool,
mp.Queue]]:
return self.pool_group
@@ -45,9 +54,21 @@ class PoolGroup:
):
self.pool_group[pool_id] = (request_pool, request_queue)
+ def remove_pool(self, pool_id: int):
+ self.pool_group.pop(pool_id, None)
+ self.pool_states.pop(pool_id, None)
+
def get_pool_ids(self) -> list[int]:
return list(self.pool_group.keys())
+ def dispatch_request(self, req):
+ pool_idx = self.request_dispatcher.dispatch_request(req,
self.get_pool_ids())
+ req_q = self.pool_group[pool_idx][1]
+ req_q.put(req)
+ logger.debug(
+
f"[Inference][Device-{self.DEFAULT_DEVICE}][Pool-{pool_idx}][ID-{req.req_id}]
Request is queued for inference"
+ )
+
def get_request_pool(self, pool_id) -> InferenceRequestPool:
if pool_id not in self.pool_group:
raise InferenceModelInternalError(
@@ -61,3 +82,12 @@ class PoolGroup:
f"Pool ID {pool_id} not found for model {self.model_id}"
)
return self.pool_group[pool_id][1]
+
+ def get_state(self, pool_id) -> PoolState:
+ return self.pool_states[pool_id]
+
+ def set_state(self, pool_id, state):
+ self.pool_states[pool_id] = state
+
+ def get_load(self, pool_id) -> int:
+ pass
diff --git a/iotdb-core/ainode/ainode/core/inference/pool_scheduler.py b/iotdb-core/ainode/ainode/core/inference/pool_scheduler.py
deleted file mode 100644
index 53aa329c494..00000000000
--- a/iotdb-core/ainode/ainode/core/inference/pool_scheduler.py
+++ /dev/null
@@ -1,125 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-#
-import threading
-
-import torch
-import torch.multiprocessing as mp
-
-from ainode.core.exception import (
- InferenceModelInternalError,
-)
-from ainode.core.inference.inference_request_pool import InferenceRequestPool,
PoolState
-from ainode.core.inference.pool_controller import PoolController
-from ainode.core.log import Logger
-from ainode.core.manager.utils import (
- _estimate_pool_size,
-)
-from ainode.core.model.sundial.configuration_sundial import SundialConfig
-from ainode.core.model.timerxl.configuration_timer import TimerConfig
-from ainode.core.util.decorator import synchronized
-
-logger = Logger()
-
-
-class PoolScheduler:
- """
- A Scheduler to init the request pools.
- It initializes the first pool and starts a background thread to expand
pools
- as needed based on the model_id.
- """
-
- DEFAULT_DEVICE = torch.device("cpu")
- # DEFAULT_DEVICE = torch.device("cuda" if torch.cuda.is_available() else
"cpu")
-
- def __init__(self, pool_controller: PoolController, result_queue:
mp.Queue):
- self._pool_controller = pool_controller
- self._result_queue = result_queue
-
- @synchronized(threading.Lock())
- def first_req_init(self, model_id: str):
- if not self._pool_controller.has_request_pools(model_id):
- pool_num = _estimate_pool_size(self.DEFAULT_DEVICE, model_id)
- if pool_num <= 0:
- raise InferenceModelInternalError(
- f"Not enough memory to run model {model_id}."
- )
- # initialize the first pool
- self._first_pool_init(model_id)
- # start a background thread to expand pools
- expand_thread = threading.Thread(
- target=self._expand_pools,
- args=(model_id, 1, pool_num - 1),
- daemon=True,
- )
- expand_thread.start()
-
- def _first_pool_init(self, model_id: str):
- if model_id == "sundial":
- config = SundialConfig()
- elif model_id == "timer_xl":
- config = TimerConfig()
- first_queue = mp.Queue()
- ready_event = mp.Event()
- first_pool = InferenceRequestPool(
- pool_id=0,
- model_id=model_id,
- config=config,
- request_queue=first_queue,
- result_queue=self._result_queue,
- ready_event=ready_event,
- )
- first_pool.start()
- self._pool_controller.set_state(model_id, 0, PoolState.INITIALIZING)
- if not ready_event.wait(timeout=30):
- logger.error(
- f"[Inference][Device-{self.DEFAULT_DEVICE}][Pool-0] First pool
failed to be ready in time"
- )
- else:
- self._pool_controller.register_pool(model_id, 0, first_pool,
first_queue)
- logger.info(
- f"[Inference][Device-{self.DEFAULT_DEVICE}][Pool-0]
Initialized inference request pool for model {model_id}"
- )
-
- def _expand_pools(self, model_id, start_idx, count):
- for idx in range(count):
- queue = mp.Queue()
- pool_id = start_idx + idx
- if model_id == "sundial":
- config = SundialConfig()
- elif model_id == "timer_xl":
- config = TimerConfig()
- pool = InferenceRequestPool(
- pool_id=pool_id,
- model_id=model_id,
- config=config,
- request_queue=queue,
- result_queue=self._result_queue,
- ready_event=mp.Event(),
- )
- pool.start()
- self._pool_controller.set_state(model_id, pool_id,
PoolState.INITIALIZING)
- if not pool.ready_event.wait(timeout=30):
- logger.error(
-
f"[Inference][Device-{self.DEFAULT_DEVICE}][Pool-{pool_id}] Pool failed to be
ready in time"
- )
- continue
- else:
- self._pool_controller.register_pool(model_id, pool_id, pool,
queue)
- logger.info(
-
f"[Inference][Device-{self.DEFAULT_DEVICE}][Pool-{pool.pool_id}] New inference
request pool started for model {model_id}"
- )
diff --git a/iotdb-core/ainode/ainode/core/inference/scheduler/__init__.py b/iotdb-core/ainode/ainode/core/inference/pool_scheduler/__init__.py
similarity index 100%
copy from iotdb-core/ainode/ainode/core/inference/scheduler/__init__.py
copy to iotdb-core/ainode/ainode/core/inference/pool_scheduler/__init__.py
diff --git a/iotdb-core/ainode/ainode/core/inference/pool_scheduler/abstract_pool_scheduler.py b/iotdb-core/ainode/ainode/core/inference/pool_scheduler/abstract_pool_scheduler.py
new file mode 100644
index 00000000000..a2140b0a8a6
--- /dev/null
+++ b/iotdb-core/ainode/ainode/core/inference/pool_scheduler/abstract_pool_scheduler.py
@@ -0,0 +1,56 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
+from enum import Enum
+from typing import Dict, List
+
+from ainode.core.inference.pool_group import PoolGroup
+
+
+class ScaleActionType(Enum):
+ SCALE_UP = "Scale Up"
+ SCALE_DOWN = "Scale Down"
+
+
+@dataclass(frozen=True)
+class ScaleAction:
+ action: ScaleActionType
+ amount: int
+ model_id: str
+
+
+class AbstractPoolScheduler(ABC):
+ """
+ Abstract base class for pool scheduling strategies.
+ """
+
+ def __init__(self, request_pool_map: Dict[str, PoolGroup]):
+ """
+ Args:
+ request_pool_map: A mapping from model IDs to their corresponding
pool groups.
+ """
+ self._request_pool_map = request_pool_map
+
+ @abstractmethod
+ def schedule(self, model_id: str) -> List[ScaleAction]:
+ """
+ Schedule a scaling action for the given model_id.
+ """
+ pass
diff --git a/iotdb-core/ainode/ainode/core/inference/pool_scheduler/basic_pool_scheduler.py b/iotdb-core/ainode/ainode/core/inference/pool_scheduler/basic_pool_scheduler.py
new file mode 100644
index 00000000000..8994a9e00b2
--- /dev/null
+++ b/iotdb-core/ainode/ainode/core/inference/pool_scheduler/basic_pool_scheduler.py
@@ -0,0 +1,59 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+from typing import Dict, List
+
+import torch
+
+from ainode.core.exception import (
+ InferenceModelInternalError,
+)
+from ainode.core.inference.pool_group import PoolGroup
+from ainode.core.inference.pool_scheduler.abstract_pool_scheduler import (
+ AbstractPoolScheduler,
+ ScaleAction,
+ ScaleActionType,
+)
+from ainode.core.log import Logger
+from ainode.core.manager.utils import estimate_pool_size
+
+logger = Logger()
+
+
+class BasicPoolScheduler(AbstractPoolScheduler):
+ """
+ A basic scheduler to init the request pools.
+ """
+
+ DEFAULT_DEVICE = torch.device("cpu")
+ # DEFAULT_DEVICE = torch.device("cuda" if torch.cuda.is_available() else
"cpu")
+
+ def __init__(self, request_pool_map: Dict[str, PoolGroup]):
+ super().__init__(request_pool_map)
+
+ def schedule(self, model_id: str) -> List[ScaleAction]:
+ """
+ Schedule a scaling action for the given model_id.
+ """
+ if model_id not in self._request_pool_map:
+ pool_num = estimate_pool_size(self.DEFAULT_DEVICE, model_id)
+ if pool_num <= 0:
+ raise InferenceModelInternalError(
+ f"Not enough memory to run model {model_id}."
+ )
+ return [ScaleAction(ScaleActionType.SCALE_UP, pool_num, model_id)]
diff --git a/iotdb-core/ainode/ainode/core/inference/request_controller.py b/iotdb-core/ainode/ainode/core/inference/request_controller.py
deleted file mode 100644
index bde513a2813..00000000000
--- a/iotdb-core/ainode/ainode/core/inference/request_controller.py
+++ /dev/null
@@ -1,89 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-#
-import threading
-import time
-
-import torch.multiprocessing as mp
-
-from ainode.core.config import AINodeDescriptor
-from ainode.core.inference.inference_request import (
- InferenceRequest,
- InferenceRequestProxy,
-)
-from ainode.core.inference.pool_controller import PoolController
-from ainode.core.inference.pool_scheduler import PoolScheduler
-from ainode.core.log import Logger
-
-logger = Logger()
-
-
-class RequestController:
- """
- Controls the lifecycle and scheduling of inference requests.
- """
-
- WAITING_INTERVAL_IN_MS = (
-
AINodeDescriptor().get_config().get_ain_inference_batch_interval_in_ms()
- ) # How often to check for requests in the result queue
-
- def __init__(self):
- self._result_queue = mp.Queue()
- self._result_wrapper_map = {}
- self._result_wrapper_lock = threading.RLock()
-
- self._stop_event = mp.Event()
- self._result_handler_thread = threading.Thread(
- target=self._handle_results, daemon=True
- )
- self._result_handler_thread.start()
- self._pool_controller = PoolController()
- self._pool_scheduler = PoolScheduler(self._pool_controller,
self._result_queue)
-
- def _handle_results(self):
- while not self._stop_event.is_set():
- if self._result_queue.empty():
- time.sleep(self.WAITING_INTERVAL_IN_MS / 1000)
- continue
- infer_req: InferenceRequest = self._result_queue.get()
- self._pool_controller.remove_request(infer_req.model_id,
infer_req.req_id)
- with self._result_wrapper_lock:
- self._result_wrapper_map[infer_req.req_id].set_result(
- infer_req.get_final_output()
- )
-
- def process_request(self, req):
- infer_proxy = InferenceRequestProxy(req.req_id)
- with self._result_wrapper_lock:
- self._result_wrapper_map[req.req_id] = infer_proxy
- # lazy initialization for first request
- model_id = req.model_id
- if not self._pool_controller.has_request_pools(model_id):
- self._pool_scheduler.first_req_init(model_id)
- # dispatch request to the pool
- self._pool_controller.dispatch_request(model_id, req)
- outputs = infer_proxy.wait_for_completion()
- with self._result_wrapper_lock:
- del self._result_wrapper_map[req.req_id]
- return outputs
-
- def shutdown(self):
- self._stop_event.set()
- self._pool_controller.shutdown()
- while not self._result_queue.empty():
- self._result_queue.get_nowait()
- self._result_queue.close()
diff --git a/iotdb-core/ainode/ainode/core/inference/scheduler/__init__.py b/iotdb-core/ainode/ainode/core/inference/request_scheduler/__init__.py
similarity index 100%
rename from iotdb-core/ainode/ainode/core/inference/scheduler/__init__.py
rename to iotdb-core/ainode/ainode/core/inference/request_scheduler/__init__.py
diff --git a/iotdb-core/ainode/ainode/core/inference/scheduler/abstract_scheduler.py b/iotdb-core/ainode/ainode/core/inference/request_scheduler/abstract_request_scheduler.py
similarity index 98%
rename from iotdb-core/ainode/ainode/core/inference/scheduler/abstract_scheduler.py
rename to iotdb-core/ainode/ainode/core/inference/request_scheduler/abstract_request_scheduler.py
index 8bc34e529c7..a6c2fe53cae 100644
--- a/iotdb-core/ainode/ainode/core/inference/scheduler/abstract_scheduler.py
+++ b/iotdb-core/ainode/ainode/core/inference/request_scheduler/abstract_request_scheduler.py
@@ -19,7 +19,7 @@
from abc import ABC, abstractmethod
-class AbstractScheduler(ABC):
+class AbstractRequestScheduler(ABC):
"""
Abstract base class for inference scheduling strategies.
diff --git a/iotdb-core/ainode/ainode/core/inference/scheduler/basic_scheduler.py b/iotdb-core/ainode/ainode/core/inference/request_scheduler/basic_request_scheduler.py
similarity index 90%
rename from iotdb-core/ainode/ainode/core/inference/scheduler/basic_scheduler.py
rename to iotdb-core/ainode/ainode/core/inference/request_scheduler/basic_request_scheduler.py
index 65dee81ca1c..248a7a88a4c 100644
--- a/iotdb-core/ainode/ainode/core/inference/scheduler/basic_scheduler.py
+++ b/iotdb-core/ainode/ainode/core/inference/request_scheduler/basic_request_scheduler.py
@@ -21,16 +21,17 @@ import os
import psutil
import torch
-from ainode.core.inference.inference_request import InferenceRequest
-from ainode.core.inference.scheduler.abstract_scheduler import
AbstractScheduler
+from ainode.core.inference.request_scheduler.abstract_request_scheduler import
(
+ AbstractRequestScheduler,
+)
from ainode.core.log import Logger
logger = Logger()
-class BasicScheduler(AbstractScheduler):
+class BasicRequestScheduler(AbstractRequestScheduler):
"""
- A simple FIFO scheduler that selects requests based on memory availability
and activation/step size.
+ A simple FIFO request scheduler that selects requests based on memory
availability and activation/step size.
"""
def __init__(
diff --git a/iotdb-core/ainode/ainode/core/manager/inference_manager.py
b/iotdb-core/ainode/ainode/core/manager/inference_manager.py
index a7f10b8b389..2d4e2088ac7 100644
--- a/iotdb-core/ainode/ainode/core/manager/inference_manager.py
+++ b/iotdb-core/ainode/ainode/core/manager/inference_manager.py
@@ -16,11 +16,14 @@
# under the License.
#
+import threading
+import time
from abc import ABC, abstractmethod
from typing import Dict
import pandas as pd
import torch
+import torch.multiprocessing as mp
from iotdb.tsfile.utils.tsblock_serde import deserialize
from ainode.core.config import AINodeDescriptor
@@ -33,8 +36,9 @@ from ainode.core.exception import (
)
from ainode.core.inference.inference_request import (
InferenceRequest,
+ InferenceRequestProxy,
)
-from ainode.core.inference.request_controller import RequestController
+from ainode.core.inference.pool_controller import PoolController
from ainode.core.inference.strategy.timer_sundial_inference_pipeline import (
TimerSundialInferencePipeline,
)
@@ -45,7 +49,7 @@ from ainode.core.inference.utils import generate_req_id
from ainode.core.log import Logger
from ainode.core.manager.model_manager import ModelManager
from ainode.core.manager.utils import (
- _measure_model_memory,
+ measure_model_memory,
)
from ainode.core.model.sundial.configuration_sundial import SundialConfig
from ainode.core.model.sundial.modeling_sundial import SundialForPrediction
@@ -143,18 +147,31 @@ class InferenceManager:
DEFAULT_DEVICE = torch.device("cpu")
# DEFAULT_DEVICE = torch.device("cuda" if torch.cuda.is_available() else
"cpu")
+ WAITING_INTERVAL_IN_MS = (
+
AINodeDescriptor().get_config().get_ain_inference_batch_interval_in_ms()
+ ) # How often to check for requests in the result queue
+
def __init__(self):
self._model_manager = ModelManager()
self._model_mem_usage_map: Dict[str, int] = (
{}
) # store model memory usage for each model
- self._request_controller = RequestController()
+ self._result_queue = mp.Queue()
+ self._result_wrapper_map = {}
+ self._result_wrapper_lock = threading.RLock()
+
+ self._stop_event = mp.Event()
+ self._result_handler_thread = threading.Thread(
+ target=self._handle_results, daemon=True
+ )
+ self._result_handler_thread.start()
+ self._pool_controller = PoolController(self._result_queue)
# self._preload_model_benchmarks()
def _preload_model_benchmarks(self):
if "cuda" in str(self.DEFAULT_DEVICE):
for model_id in self.ACCELERATE_MODEL_ID:
- mem_usage = _measure_model_memory(self.DEFAULT_DEVICE,
model_id)
+ mem_usage = measure_model_memory(self.DEFAULT_DEVICE, model_id)
self._model_mem_usage_map[model_id] = mem_usage
logger.info(
f"[Inference] Preloaded benchmark for {model_id},
mem_usage={mem_usage/1024**2:.2f} MB"
@@ -164,6 +181,29 @@ class InferenceManager:
f"[Inference] Skipped preloading benchmarks for
{self.DEFAULT_DEVICE}, only supports CUDA currently"
)
+ def _handle_results(self):
+ while not self._stop_event.is_set():
+ if self._result_queue.empty():
+ time.sleep(self.WAITING_INTERVAL_IN_MS / 1000)
+ continue
+ infer_req: InferenceRequest = self._result_queue.get()
+ with self._result_wrapper_lock:
+ self._result_wrapper_map[infer_req.req_id].set_result(
+ infer_req.get_final_output()
+ )
+
+ def process_request(self, req):
+ req_id = req.req_id
+ infer_proxy = InferenceRequestProxy(req_id)
+ with self._result_wrapper_lock:
+ self._result_wrapper_map[req_id] = infer_proxy
+ # dispatch request to the pool
+ self._pool_controller.add_request(req.model_id, req)
+ outputs = infer_proxy.wait_for_completion()
+ with self._result_wrapper_lock:
+ del self._result_wrapper_map[req_id]
+ return outputs
+
def _get_strategy(self, model_id, model):
if isinstance(model, TimerForPrediction):
return TimerXLStrategy(model)
@@ -227,7 +267,7 @@ class InferenceManager:
inference_pipeline=inference_pipeline,
max_new_tokens=predict_length,
)
- outputs = self._request_controller.process_request(infer_req)
+ outputs = self.process_request(infer_req)
outputs = convert_to_binary(pd.DataFrame(outputs[0]))
else:
# load model
@@ -278,4 +318,8 @@ class InferenceManager:
)
def shutdown(self):
- self._request_controller.shutdown()
+ self._stop_event.set()
+ self._pool_controller.shutdown()
+ while not self._result_queue.empty():
+ self._result_queue.get_nowait()
+ self._result_queue.close()
diff --git a/iotdb-core/ainode/ainode/core/manager/utils.py
b/iotdb-core/ainode/ainode/core/manager/utils.py
index 5fbd444c38c..af60ff5ebe4 100644
--- a/iotdb-core/ainode/ainode/core/manager/utils.py
+++ b/iotdb-core/ainode/ainode/core/manager/utils.py
@@ -39,7 +39,7 @@ INFERENCE_EXTRA_MEMORY_RATIO = (
) # the overhead ratio for inference, used to estimate the pool size
-def _measure_model_memory(device: torch.device, model_id: str) -> int:
+def measure_model_memory(device: torch.device, model_id: str) -> int:
# TODO: support CPU in the future
# TODO: we can estimate the memory usage by running a dummy inference
torch.cuda.empty_cache()
@@ -62,7 +62,7 @@ def _measure_model_memory(device: torch.device, model_id:
str) -> int:
return final
-def _evaluate_system_resources(device: torch.device) -> dict:
+def evaluate_system_resources(device: torch.device) -> dict:
if torch.cuda.is_available():
free_mem, total_mem = torch.cuda.mem_get_info()
logger.info(
@@ -79,7 +79,7 @@ def _evaluate_system_resources(device: torch.device) -> dict:
return {"device": "cpu", "free_mem": free_mem, "total_mem": total_mem}
-def _estimate_pool_size(device: torch.device, model_id: str) -> int:
+def estimate_pool_size(device: torch.device, model_id: str) -> int:
model_info = BUILT_IN_LTSM_MAP.get(model_id, None)
if model_info is None:
logger.error(f"[Inference][Device-{device}] Model {model_id} not
found")
@@ -90,7 +90,7 @@ def _estimate_pool_size(device: torch.device, model_id: str)
-> int:
logger.error(f"[Inference][Device-{device}] Model {model_id} not
supported now")
return 0
- system_res = _evaluate_system_resources(device)
+ system_res = evaluate_system_resources(device)
free_mem = system_res["free_mem"]
mem_usage = MODEL_MEM_USAGE_MAP[model_type] * INFERENCE_EXTRA_MEMORY_RATIO