This is an automated email from the ASF dual-hosted git repository.

yongzao pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/iotdb.git


The following commit(s) were added to refs/heads/master by this push:
     new cc66a38b223 [AINode] Reconstruct inference structure (#16237)
cc66a38b223 is described below

commit cc66a38b223a00008186360e70034b7e14e8d10a
Author: Leo <[email protected]>
AuthorDate: Sun Aug 24 08:44:41 2025 +0800

    [AINode] Reconstruct inference structure (#16237)
---
 .../{scheduler => dispatcher}/__init__.py          |   0
 .../abstract_dispatcher.py}                        |  26 ++++
 .../core/inference/dispatcher/basic_dispatcher.py  |  60 ++++++++
 .../core/inference/inference_request_pool.py       |  12 +-
 .../ainode/core/inference/pool_controller.py       | 155 ++++++++++++++-------
 ...ference_request_pool_group.py => pool_group.py} |  32 ++++-
 .../ainode/ainode/core/inference/pool_scheduler.py | 125 -----------------
 .../{scheduler => pool_scheduler}/__init__.py      |   0
 .../pool_scheduler/abstract_pool_scheduler.py      |  56 ++++++++
 .../pool_scheduler/basic_pool_scheduler.py         |  59 ++++++++
 .../ainode/core/inference/request_controller.py    |  89 ------------
 .../{scheduler => request_scheduler}/__init__.py   |   0
 .../abstract_request_scheduler.py}                 |   2 +-
 .../basic_request_scheduler.py}                    |   9 +-
 .../ainode/core/manager/inference_manager.py       |  56 +++++++-
 iotdb-core/ainode/ainode/core/manager/utils.py     |   8 +-
 16 files changed, 406 insertions(+), 283 deletions(-)

diff --git a/iotdb-core/ainode/ainode/core/inference/scheduler/__init__.py b/iotdb-core/ainode/ainode/core/inference/dispatcher/__init__.py
similarity index 100%
copy from iotdb-core/ainode/ainode/core/inference/scheduler/__init__.py
copy to iotdb-core/ainode/ainode/core/inference/dispatcher/__init__.py
diff --git a/iotdb-core/ainode/ainode/core/inference/scheduler/__init__.py b/iotdb-core/ainode/ainode/core/inference/dispatcher/abstract_dispatcher.py
similarity index 50%
copy from iotdb-core/ainode/ainode/core/inference/scheduler/__init__.py
copy to iotdb-core/ainode/ainode/core/inference/dispatcher/abstract_dispatcher.py
index 2a1e720805f..18cdee14cf0 100644
--- a/iotdb-core/ainode/ainode/core/inference/scheduler/__init__.py
+++ b/iotdb-core/ainode/ainode/core/inference/dispatcher/abstract_dispatcher.py
@@ -15,3 +15,29 @@
 # specific language governing permissions and limitations
 # under the License.
 #
+
+from abc import ABC, abstractmethod
+from typing import Dict
+
+from ainode.core.inference.inference_request import InferenceRequest
+from ainode.core.inference.inference_request_pool import PoolState
+
+
+class AbstractDispatcher(ABC):
+    """
+    Abstract base class for dispatchers that handle inference requests.
+    """
+
+    def __init__(self, pool_states: Dict[int, PoolState]):
+        """
+        Args:
+            pool_states: Dictionary containing the states of inference request 
pools in the same pool group.
+        """
+        self.pool_states = pool_states
+
+    @abstractmethod
+    def dispatch_request(self, req: InferenceRequest, pool_ids: list[int]) -> 
int:
+        """
+        Dispatch an inference request to the appropriate pool.
+        """
+        pass
diff --git a/iotdb-core/ainode/ainode/core/inference/dispatcher/basic_dispatcher.py b/iotdb-core/ainode/ainode/core/inference/dispatcher/basic_dispatcher.py
new file mode 100644
index 00000000000..590b37fe349
--- /dev/null
+++ b/iotdb-core/ainode/ainode/core/inference/dispatcher/basic_dispatcher.py
@@ -0,0 +1,60 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+from ainode.core.exception import (
+    InferenceModelInternalError,
+)
+from ainode.core.inference.dispatcher.abstract_dispatcher import 
AbstractDispatcher
+from ainode.core.inference.inference_request import InferenceRequest
+from ainode.core.inference.inference_request_pool import PoolState
+from ainode.core.log import Logger
+
+logger = Logger()
+
+
+class BasicDispatcher(AbstractDispatcher):
+    """
+    Basic dispatcher for inference requests.
+    """
+
+    def __init__(self, pool_states):
+        super().__init__(pool_states)
+
+    def _select_pool_by_hash(self, req, pool_ids) -> int:
+        """
+        Select a pool for the given request using a hash-based approach.
+        """
+        model_id = req.model_id
+        if not pool_ids:
+            raise InferenceModelInternalError(
+                f"No available pools for model {model_id}"
+            )
+        start_idx = hash(req.req_id) % len(pool_ids)
+        n = len(pool_ids)
+        for i in range(n):
+            pool_id = pool_ids[(start_idx + i) % n]
+            state = self.pool_states[pool_id]
+            if state == PoolState.RUNNING:
+                return pool_id
+        raise InferenceModelInternalError(
+            f"No RUNNING pools available for model {model_id}"
+        )
+
+    def dispatch_request(self, req: InferenceRequest, pool_ids: list[int]) -> 
int:
+        pool_idx = self._select_pool_by_hash(req, pool_ids)
+        return pool_idx
diff --git a/iotdb-core/ainode/ainode/core/inference/inference_request_pool.py b/iotdb-core/ainode/ainode/core/inference/inference_request_pool.py
index 956d0773785..9005dbe642c 100644
--- a/iotdb-core/ainode/ainode/core/inference/inference_request_pool.py
+++ b/iotdb-core/ainode/ainode/core/inference/inference_request_pool.py
@@ -29,7 +29,9 @@ from transformers import PretrainedConfig
 
 from ainode.core.config import AINodeDescriptor
 from ainode.core.constant import INFERENCE_LOG_FILE_NAME_PREFIX_TEMPLATE
-from ainode.core.inference.scheduler.basic_scheduler import BasicScheduler
+from ainode.core.inference.request_scheduler.basic_request_scheduler import (
+    BasicRequestScheduler,
+)
 from ainode.core.log import Logger
 from ainode.core.manager.model_manager import ModelManager
 
@@ -74,7 +76,7 @@ class InferenceRequestPool(mp.Process):
         self._waiting_queue = request_queue  # Requests that are waiting to be 
processed
         self._running_queue = mp.Queue()  # Requests that are currently being 
processed
         self._finished_queue = result_queue  # Requests that are finished
-        self._scheduler = BasicScheduler(
+        self._request_scheduler = BasicRequestScheduler(
             self._waiting_queue, self._running_queue, self._finished_queue, 
self.pool_id
         )
         self._stop_event = mp.Event()
@@ -113,7 +115,7 @@ class InferenceRequestPool(mp.Process):
         )
 
     def _activate_requests(self):
-        requests = self._scheduler.schedule_activate()
+        requests = self._request_scheduler.schedule_activate()
         for request in requests:
             request.inputs = request.inference_pipeline.preprocess_inputs(
                 request.inputs
@@ -130,7 +132,7 @@ class InferenceRequestPool(mp.Process):
             self._activate_requests()
 
     def _step(self):
-        requests = self._scheduler.schedule_step()
+        requests = self._request_scheduler.schedule_step()
         # TODO: We need a batcher to accelerate the concurrent inference
         for request in requests:
             if self.model_id == "sundial":
@@ -178,7 +180,7 @@ class InferenceRequestPool(mp.Process):
         self.logger = Logger(
             INFERENCE_LOG_FILE_NAME_PREFIX_TEMPLATE.format(self.device)
         )
-        self._scheduler.device = self.device
+        self._request_scheduler.device = self.device
         self.model = self._model_manager.load_model(self.model_id, 
{}).to(self.device)
         self.ready_event.set()
 
diff --git a/iotdb-core/ainode/ainode/core/inference/pool_controller.py b/iotdb-core/ainode/ainode/core/inference/pool_controller.py
index 121d3023df9..8f4efc9ebf3 100644
--- a/iotdb-core/ainode/ainode/core/inference/pool_controller.py
+++ b/iotdb-core/ainode/ainode/core/inference/pool_controller.py
@@ -15,19 +15,24 @@
 # specific language governing permissions and limitations
 # under the License.
 #
-from collections import defaultdict
+
+import threading
 from typing import Dict, Optional
 
 import torch
 import torch.multiprocessing as mp
 
-from ainode.core.exception import (
-    InferenceModelInternalError,
-)
 from ainode.core.inference.inference_request import InferenceRequest
 from ainode.core.inference.inference_request_pool import InferenceRequestPool, 
PoolState
-from ainode.core.inference.inference_request_pool_group import PoolGroup
+from ainode.core.inference.pool_group import PoolGroup
+from ainode.core.inference.pool_scheduler.basic_pool_scheduler import (
+    BasicPoolScheduler,
+    ScaleActionType,
+)
 from ainode.core.log import Logger
+from ainode.core.model.sundial.configuration_sundial import SundialConfig
+from ainode.core.model.timerxl.configuration_timer import TimerConfig
+from ainode.core.util.decorator import synchronized
 
 logger = Logger()
 
@@ -35,53 +40,116 @@ logger = Logger()
 class PoolController:
     """
     A controller for handling inference request pools.
-    It handles the registration of pools, adding and removing requests,
-    and managing the state of each pool.
     """
 
     DEFAULT_DEVICE = torch.device("cpu")
     # DEFAULT_DEVICE = torch.device("cuda" if torch.cuda.is_available() else 
"cpu")
 
-    def __init__(self):
-        # structure: {model_id: {pool_id: PoolState}}
-        self.pool_states: Dict[str, Dict[int, PoolState]] = defaultdict(dict)
+    def __init__(self, result_queue: mp.Queue):
         # structure: {model_id: PoolGroup}
         self._request_pool_map: Dict[str, PoolGroup] = {}
-
-    def dispatch_request(self, model_id, req: InferenceRequest):
-        pool_idx = self._select_pool_by_hash(model_id, req.req_id)
-        self.add_request(pool_idx, req)
-        logger.debug(
-            
f"[Inference][Device-{self.DEFAULT_DEVICE}][Pool-{pool_idx}][ID-{req.req_id}] 
Request is queued for inference"
+        self._result_queue = result_queue
+        self._pool_scheduler = BasicPoolScheduler(self._request_pool_map)
+
+    @synchronized(threading.Lock())
+    def first_req_init(self, model_id: str):
+        if not self.has_request_pools(model_id):
+            actions = self._pool_scheduler.schedule(model_id)
+            for action in actions:
+                if action.action == ScaleActionType.SCALE_UP:
+                    # initialize the first pool
+                    self._first_pool_init(action.model_id)
+                    # start a background thread to expand pools
+                    expand_thread = threading.Thread(
+                        target=self._expand_pools,
+                        args=(action.model_id, 1, action.amount - 1),
+                        daemon=True,
+                    )
+                    expand_thread.start()
+                elif action.action == ScaleActionType.SCALE_DOWN:
+                    # TODO: implement scale down logic
+                    pass
+
+    def _first_pool_init(self, model_id: str):
+        if model_id == "sundial":
+            config = SundialConfig()
+        elif model_id == "timer_xl":
+            config = TimerConfig()
+        first_queue = mp.Queue()
+        ready_event = mp.Event()
+        first_pool = InferenceRequestPool(
+            pool_id=0,
+            model_id=model_id,
+            config=config,
+            request_queue=first_queue,
+            result_queue=self._result_queue,
+            ready_event=ready_event,
         )
+        first_pool.start()
+        self.register_pool(model_id, 0, first_pool, first_queue)
+        if not ready_event.wait(timeout=30):
+            self.unregister_pool(model_id, 0)
+            logger.error(
+                f"[Inference][Device-{self.DEFAULT_DEVICE}][Pool-0] First pool 
failed to be ready in time"
+            )
+        else:
+            self.set_state(model_id, 0, PoolState.RUNNING)
+            logger.info(
+                f"[Inference][Device-{self.DEFAULT_DEVICE}][Pool-0] 
Initialized inference request pool for model {model_id}"
+            )
 
-    def _select_pool_by_hash(self, model_id, req_id) -> int:
-        pool_ids = self.get_pool_ids(model_id)
-        if not pool_ids:
-            raise InferenceModelInternalError(
-                f"No available pools for model {model_id}"
+    def _expand_pools(self, model_id, start_idx, count):
+        for idx in range(count):
+            queue = mp.Queue()
+            pool_id = start_idx + idx
+            if model_id == "sundial":
+                config = SundialConfig()
+            elif model_id == "timer_xl":
+                config = TimerConfig()
+            pool = InferenceRequestPool(
+                pool_id=pool_id,
+                model_id=model_id,
+                config=config,
+                request_queue=queue,
+                result_queue=self._result_queue,
+                ready_event=mp.Event(),
             )
-        start_idx = hash(req_id) % len(pool_ids)
-        n = len(pool_ids)
-        for i in range(n):
-            pool_id = pool_ids[(start_idx + i) % n]
-            state = self.get_state(model_id, pool_id)
-            if state == PoolState.RUNNING:
-                return pool_id
-        raise InferenceModelInternalError(
-            f"No RUNNING pools available for model {model_id}"
-        )
+            pool.start()
+            self.register_pool(model_id, pool_id, pool, queue)
+            if not pool.ready_event.wait(timeout=30):
+                logger.error(
+                    
f"[Inference][Device-{self.DEFAULT_DEVICE}][Pool-{pool_id}] Pool failed to be 
ready in time"
+                )
+                self.unregister_pool(model_id, pool_id)
+                continue
+            else:
+                self.set_state(model_id, pool_id, PoolState.RUNNING)
+                logger.info(
+                    
f"[Inference][Device-{self.DEFAULT_DEVICE}][Pool-{pool.pool_id}] New inference 
request pool started for model {model_id}"
+                )
+
+    def add_request(self, model_id: str, req: InferenceRequest):
+        # lazy initialization for first request
+        model_id = req.model_id
+        if not self.has_request_pools(model_id):
+            self.first_req_init(model_id)
+        self._request_pool_map[model_id].dispatch_request(req)
+
+    def get_state(self, model_id, pool_id) -> PoolState:
+        return self._request_pool_map[model_id].get_state(pool_id)
+
+    def set_state(self, model_id, pool_id, state):
+        self._request_pool_map[model_id].set_state(pool_id, state)
 
     def register_pool(self, model_id, pool_id, request_pool, request_queue):
-        self.set_state(model_id, pool_id, PoolState.RUNNING)
         self.set_request_pool_map(model_id, pool_id, request_pool, 
request_queue)
+        pool_group = self._request_pool_map.get(model_id)
+        pool_group.set_state(pool_id, PoolState.INITIALIZING)
 
-    def add_request(self, pool_id, req):
-        req_q = self.get_request_queue(req.model_id, pool_id)
-        req_q.put(req)
-
-    def remove_request(self, model_id, req_id):
-        pass
+    def unregister_pool(self, model_id, pool_id):
+        self._request_pool_map[model_id].remove_pool(pool_id)
+        if not self._request_pool_map[model_id].get_pool_ids():
+            self._request_pool_map.pop(model_id, None)
 
     def get_pool_ids(self, model_id) -> list[int]:
         return self._request_pool_map[model_id].get_pool_ids()
@@ -106,17 +174,8 @@ class PoolController:
             self._request_pool_map[model_id] = PoolGroup(model_id)
         self._request_pool_map[model_id].add_pool(pool_id, request_pool, 
request_queue)
 
-    def get_state(self, model_id, pool_id) -> PoolState:
-        return self.pool_states[model_id][pool_id]
-
-    def set_state(self, model_id, pool_id, state):
-        self.pool_states[model_id][pool_id] = state
-
-    def get_load(self, model_id, pool_id) -> int:
-        pass
-
     def shutdown(self):
-        for model_id, pool_group in self._request_pool_map.items():
+        for pool_group in self._request_pool_map.values():
             for pool_id in pool_group.get_pool_ids():
                 request_pool = pool_group.get_request_pool(pool_id)
                 request_queue = pool_group.get_request_queue(pool_id)
diff --git a/iotdb-core/ainode/ainode/core/inference/inference_request_pool_group.py b/iotdb-core/ainode/ainode/core/inference/pool_group.py
similarity index 65%
rename from iotdb-core/ainode/ainode/core/inference/inference_request_pool_group.py
rename to iotdb-core/ainode/ainode/core/inference/pool_group.py
index 3dc2929e388..99704008c22 100644
--- a/iotdb-core/ainode/ainode/core/inference/inference_request_pool_group.py
+++ b/iotdb-core/ainode/ainode/core/inference/pool_group.py
@@ -17,12 +17,14 @@
 #
 from typing import Dict, Tuple
 
+import torch
 import torch.multiprocessing as mp
 
 from ainode.core.exception import (
     InferenceModelInternalError,
 )
-from ainode.core.inference.inference_request_pool import InferenceRequestPool
+from ainode.core.inference.dispatcher.basic_dispatcher import BasicDispatcher
+from ainode.core.inference.inference_request_pool import InferenceRequestPool, 
PoolState
 from ainode.core.log import Logger
 
 logger = Logger()
@@ -33,9 +35,16 @@ class PoolGroup:
     A group of inference request pools for a specific model.
     """
 
+    DEFAULT_DEVICE = torch.device("cpu")
+    # DEFAULT_DEVICE = torch.device("cuda" if torch.cuda.is_available() else 
"cpu")
+
     def __init__(self, model_id):
+        # structure: {pool_id: (InferenceRequestPool, mp.Queue)}
         self.pool_group: Dict[int, Tuple[InferenceRequestPool, mp.Queue]] = {}
+        # structure: {pool_id: PoolState}
+        self.pool_states: Dict[int, PoolState] = {}
         self.model_id = model_id
+        self.request_dispatcher = BasicDispatcher(self.pool_states)
 
     def get_pool_group(self) -> Dict[int, Tuple[InferenceRequestPool, 
mp.Queue]]:
         return self.pool_group
@@ -45,9 +54,21 @@ class PoolGroup:
     ):
         self.pool_group[pool_id] = (request_pool, request_queue)
 
+    def remove_pool(self, pool_id: int):
+        self.pool_group.pop(pool_id, None)
+        self.pool_states.pop(pool_id, None)
+
     def get_pool_ids(self) -> list[int]:
         return list(self.pool_group.keys())
 
+    def dispatch_request(self, req):
+        pool_idx = self.request_dispatcher.dispatch_request(req, 
self.get_pool_ids())
+        req_q = self.pool_group[pool_idx][1]
+        req_q.put(req)
+        logger.debug(
+            
f"[Inference][Device-{self.DEFAULT_DEVICE}][Pool-{pool_idx}][ID-{req.req_id}] 
Request is queued for inference"
+        )
+
     def get_request_pool(self, pool_id) -> InferenceRequestPool:
         if pool_id not in self.pool_group:
             raise InferenceModelInternalError(
@@ -61,3 +82,12 @@ class PoolGroup:
                 f"Pool ID {pool_id} not found for model {self.model_id}"
             )
         return self.pool_group[pool_id][1]
+
+    def get_state(self, pool_id) -> PoolState:
+        return self.pool_states[pool_id]
+
+    def set_state(self, pool_id, state):
+        self.pool_states[pool_id] = state
+
+    def get_load(self, pool_id) -> int:
+        pass
diff --git a/iotdb-core/ainode/ainode/core/inference/pool_scheduler.py b/iotdb-core/ainode/ainode/core/inference/pool_scheduler.py
deleted file mode 100644
index 53aa329c494..00000000000
--- a/iotdb-core/ainode/ainode/core/inference/pool_scheduler.py
+++ /dev/null
@@ -1,125 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-#
-import threading
-
-import torch
-import torch.multiprocessing as mp
-
-from ainode.core.exception import (
-    InferenceModelInternalError,
-)
-from ainode.core.inference.inference_request_pool import InferenceRequestPool, 
PoolState
-from ainode.core.inference.pool_controller import PoolController
-from ainode.core.log import Logger
-from ainode.core.manager.utils import (
-    _estimate_pool_size,
-)
-from ainode.core.model.sundial.configuration_sundial import SundialConfig
-from ainode.core.model.timerxl.configuration_timer import TimerConfig
-from ainode.core.util.decorator import synchronized
-
-logger = Logger()
-
-
-class PoolScheduler:
-    """
-    A Scheduler to init the request pools.
-    It initializes the first pool and starts a background thread to expand 
pools
-    as needed based on the model_id.
-    """
-
-    DEFAULT_DEVICE = torch.device("cpu")
-    # DEFAULT_DEVICE = torch.device("cuda" if torch.cuda.is_available() else 
"cpu")
-
-    def __init__(self, pool_controller: PoolController, result_queue: 
mp.Queue):
-        self._pool_controller = pool_controller
-        self._result_queue = result_queue
-
-    @synchronized(threading.Lock())
-    def first_req_init(self, model_id: str):
-        if not self._pool_controller.has_request_pools(model_id):
-            pool_num = _estimate_pool_size(self.DEFAULT_DEVICE, model_id)
-            if pool_num <= 0:
-                raise InferenceModelInternalError(
-                    f"Not enough memory to run model {model_id}."
-                )
-            # initialize the first pool
-            self._first_pool_init(model_id)
-            # start a background thread to expand pools
-            expand_thread = threading.Thread(
-                target=self._expand_pools,
-                args=(model_id, 1, pool_num - 1),
-                daemon=True,
-            )
-            expand_thread.start()
-
-    def _first_pool_init(self, model_id: str):
-        if model_id == "sundial":
-            config = SundialConfig()
-        elif model_id == "timer_xl":
-            config = TimerConfig()
-        first_queue = mp.Queue()
-        ready_event = mp.Event()
-        first_pool = InferenceRequestPool(
-            pool_id=0,
-            model_id=model_id,
-            config=config,
-            request_queue=first_queue,
-            result_queue=self._result_queue,
-            ready_event=ready_event,
-        )
-        first_pool.start()
-        self._pool_controller.set_state(model_id, 0, PoolState.INITIALIZING)
-        if not ready_event.wait(timeout=30):
-            logger.error(
-                f"[Inference][Device-{self.DEFAULT_DEVICE}][Pool-0] First pool 
failed to be ready in time"
-            )
-        else:
-            self._pool_controller.register_pool(model_id, 0, first_pool, 
first_queue)
-            logger.info(
-                f"[Inference][Device-{self.DEFAULT_DEVICE}][Pool-0] 
Initialized inference request pool for model {model_id}"
-            )
-
-    def _expand_pools(self, model_id, start_idx, count):
-        for idx in range(count):
-            queue = mp.Queue()
-            pool_id = start_idx + idx
-            if model_id == "sundial":
-                config = SundialConfig()
-            elif model_id == "timer_xl":
-                config = TimerConfig()
-            pool = InferenceRequestPool(
-                pool_id=pool_id,
-                model_id=model_id,
-                config=config,
-                request_queue=queue,
-                result_queue=self._result_queue,
-                ready_event=mp.Event(),
-            )
-            pool.start()
-            self._pool_controller.set_state(model_id, pool_id, 
PoolState.INITIALIZING)
-            if not pool.ready_event.wait(timeout=30):
-                logger.error(
-                    
f"[Inference][Device-{self.DEFAULT_DEVICE}][Pool-{pool_id}] Pool failed to be 
ready in time"
-                )
-                continue
-            else:
-                self._pool_controller.register_pool(model_id, pool_id, pool, 
queue)
-                logger.info(
-                    
f"[Inference][Device-{self.DEFAULT_DEVICE}][Pool-{pool.pool_id}] New inference 
request pool started for model {model_id}"
-                )
diff --git a/iotdb-core/ainode/ainode/core/inference/scheduler/__init__.py b/iotdb-core/ainode/ainode/core/inference/pool_scheduler/__init__.py
similarity index 100%
copy from iotdb-core/ainode/ainode/core/inference/scheduler/__init__.py
copy to iotdb-core/ainode/ainode/core/inference/pool_scheduler/__init__.py
diff --git a/iotdb-core/ainode/ainode/core/inference/pool_scheduler/abstract_pool_scheduler.py b/iotdb-core/ainode/ainode/core/inference/pool_scheduler/abstract_pool_scheduler.py
new file mode 100644
index 00000000000..a2140b0a8a6
--- /dev/null
+++ b/iotdb-core/ainode/ainode/core/inference/pool_scheduler/abstract_pool_scheduler.py
@@ -0,0 +1,56 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
+from enum import Enum
+from typing import Dict, List
+
+from ainode.core.inference.pool_group import PoolGroup
+
+
+class ScaleActionType(Enum):
+    SCALE_UP = "Scale Up"
+    SCALE_DOWN = "Scale Down"
+
+
+@dataclass(frozen=True)
+class ScaleAction:
+    action: ScaleActionType
+    amount: int
+    model_id: str
+
+
+class AbstractPoolScheduler(ABC):
+    """
+    Abstract base class for pool scheduling strategies.
+    """
+
+    def __init__(self, request_pool_map: Dict[str, PoolGroup]):
+        """
+        Args:
+            request_pool_map: A mapping from model IDs to their corresponding 
pool groups.
+        """
+        self._request_pool_map = request_pool_map
+
+    @abstractmethod
+    def schedule(self, model_id: str) -> List[ScaleAction]:
+        """
+        Schedule a scaling action for the given model_id.
+        """
+        pass
diff --git a/iotdb-core/ainode/ainode/core/inference/pool_scheduler/basic_pool_scheduler.py b/iotdb-core/ainode/ainode/core/inference/pool_scheduler/basic_pool_scheduler.py
new file mode 100644
index 00000000000..8994a9e00b2
--- /dev/null
+++ b/iotdb-core/ainode/ainode/core/inference/pool_scheduler/basic_pool_scheduler.py
@@ -0,0 +1,59 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+from typing import Dict, List
+
+import torch
+
+from ainode.core.exception import (
+    InferenceModelInternalError,
+)
+from ainode.core.inference.pool_group import PoolGroup
+from ainode.core.inference.pool_scheduler.abstract_pool_scheduler import (
+    AbstractPoolScheduler,
+    ScaleAction,
+    ScaleActionType,
+)
+from ainode.core.log import Logger
+from ainode.core.manager.utils import estimate_pool_size
+
+logger = Logger()
+
+
+class BasicPoolScheduler(AbstractPoolScheduler):
+    """
+    A basic scheduler to init the request pools.
+    """
+
+    DEFAULT_DEVICE = torch.device("cpu")
+    # DEFAULT_DEVICE = torch.device("cuda" if torch.cuda.is_available() else 
"cpu")
+
+    def __init__(self, request_pool_map: Dict[str, PoolGroup]):
+        super().__init__(request_pool_map)
+
+    def schedule(self, model_id: str) -> List[ScaleAction]:
+        """
+        Schedule a scaling action for the given model_id.
+        """
+        if model_id not in self._request_pool_map:
+            pool_num = estimate_pool_size(self.DEFAULT_DEVICE, model_id)
+            if pool_num <= 0:
+                raise InferenceModelInternalError(
+                    f"Not enough memory to run model {model_id}."
+                )
+            return [ScaleAction(ScaleActionType.SCALE_UP, pool_num, model_id)]
diff --git a/iotdb-core/ainode/ainode/core/inference/request_controller.py b/iotdb-core/ainode/ainode/core/inference/request_controller.py
deleted file mode 100644
index bde513a2813..00000000000
--- a/iotdb-core/ainode/ainode/core/inference/request_controller.py
+++ /dev/null
@@ -1,89 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-#
-import threading
-import time
-
-import torch.multiprocessing as mp
-
-from ainode.core.config import AINodeDescriptor
-from ainode.core.inference.inference_request import (
-    InferenceRequest,
-    InferenceRequestProxy,
-)
-from ainode.core.inference.pool_controller import PoolController
-from ainode.core.inference.pool_scheduler import PoolScheduler
-from ainode.core.log import Logger
-
-logger = Logger()
-
-
-class RequestController:
-    """
-    Controls the lifecycle and scheduling of inference requests.
-    """
-
-    WAITING_INTERVAL_IN_MS = (
-        
AINodeDescriptor().get_config().get_ain_inference_batch_interval_in_ms()
-    )  # How often to check for requests in the result queue
-
-    def __init__(self):
-        self._result_queue = mp.Queue()
-        self._result_wrapper_map = {}
-        self._result_wrapper_lock = threading.RLock()
-
-        self._stop_event = mp.Event()
-        self._result_handler_thread = threading.Thread(
-            target=self._handle_results, daemon=True
-        )
-        self._result_handler_thread.start()
-        self._pool_controller = PoolController()
-        self._pool_scheduler = PoolScheduler(self._pool_controller, 
self._result_queue)
-
-    def _handle_results(self):
-        while not self._stop_event.is_set():
-            if self._result_queue.empty():
-                time.sleep(self.WAITING_INTERVAL_IN_MS / 1000)
-                continue
-            infer_req: InferenceRequest = self._result_queue.get()
-            self._pool_controller.remove_request(infer_req.model_id, 
infer_req.req_id)
-            with self._result_wrapper_lock:
-                self._result_wrapper_map[infer_req.req_id].set_result(
-                    infer_req.get_final_output()
-                )
-
-    def process_request(self, req):
-        infer_proxy = InferenceRequestProxy(req.req_id)
-        with self._result_wrapper_lock:
-            self._result_wrapper_map[req.req_id] = infer_proxy
-        # lazy initialization for first request
-        model_id = req.model_id
-        if not self._pool_controller.has_request_pools(model_id):
-            self._pool_scheduler.first_req_init(model_id)
-        # dispatch request to the pool
-        self._pool_controller.dispatch_request(model_id, req)
-        outputs = infer_proxy.wait_for_completion()
-        with self._result_wrapper_lock:
-            del self._result_wrapper_map[req.req_id]
-        return outputs
-
-    def shutdown(self):
-        self._stop_event.set()
-        self._pool_controller.shutdown()
-        while not self._result_queue.empty():
-            self._result_queue.get_nowait()
-        self._result_queue.close()
diff --git a/iotdb-core/ainode/ainode/core/inference/scheduler/__init__.py 
b/iotdb-core/ainode/ainode/core/inference/request_scheduler/__init__.py
similarity index 100%
rename from iotdb-core/ainode/ainode/core/inference/scheduler/__init__.py
rename to iotdb-core/ainode/ainode/core/inference/request_scheduler/__init__.py
diff --git 
a/iotdb-core/ainode/ainode/core/inference/scheduler/abstract_scheduler.py 
b/iotdb-core/ainode/ainode/core/inference/request_scheduler/abstract_request_scheduler.py
similarity index 98%
rename from 
iotdb-core/ainode/ainode/core/inference/scheduler/abstract_scheduler.py
rename to 
iotdb-core/ainode/ainode/core/inference/request_scheduler/abstract_request_scheduler.py
index 8bc34e529c7..a6c2fe53cae 100644
--- a/iotdb-core/ainode/ainode/core/inference/scheduler/abstract_scheduler.py
+++ 
b/iotdb-core/ainode/ainode/core/inference/request_scheduler/abstract_request_scheduler.py
@@ -19,7 +19,7 @@
 from abc import ABC, abstractmethod
 
 
-class AbstractScheduler(ABC):
+class AbstractRequestScheduler(ABC):
     """
     Abstract base class for inference scheduling strategies.
 
diff --git 
a/iotdb-core/ainode/ainode/core/inference/scheduler/basic_scheduler.py 
b/iotdb-core/ainode/ainode/core/inference/request_scheduler/basic_request_scheduler.py
similarity index 90%
rename from iotdb-core/ainode/ainode/core/inference/scheduler/basic_scheduler.py
rename to 
iotdb-core/ainode/ainode/core/inference/request_scheduler/basic_request_scheduler.py
index 65dee81ca1c..248a7a88a4c 100644
--- a/iotdb-core/ainode/ainode/core/inference/scheduler/basic_scheduler.py
+++ 
b/iotdb-core/ainode/ainode/core/inference/request_scheduler/basic_request_scheduler.py
@@ -21,16 +21,17 @@ import os
 import psutil
 import torch
 
-from ainode.core.inference.inference_request import InferenceRequest
-from ainode.core.inference.scheduler.abstract_scheduler import 
AbstractScheduler
+from ainode.core.inference.request_scheduler.abstract_request_scheduler import 
(
+    AbstractRequestScheduler,
+)
 from ainode.core.log import Logger
 
 logger = Logger()
 
 
-class BasicScheduler(AbstractScheduler):
+class BasicRequestScheduler(AbstractRequestScheduler):
     """
-    A simple FIFO scheduler that selects requests based on memory availability 
and activation/step size.
+    A simple FIFO request scheduler that selects requests based on memory 
availability and activation/step size.
     """
 
     def __init__(
diff --git a/iotdb-core/ainode/ainode/core/manager/inference_manager.py 
b/iotdb-core/ainode/ainode/core/manager/inference_manager.py
index a7f10b8b389..2d4e2088ac7 100644
--- a/iotdb-core/ainode/ainode/core/manager/inference_manager.py
+++ b/iotdb-core/ainode/ainode/core/manager/inference_manager.py
@@ -16,11 +16,14 @@
 # under the License.
 #
 
+import threading
+import time
 from abc import ABC, abstractmethod
 from typing import Dict
 
 import pandas as pd
 import torch
+import torch.multiprocessing as mp
 from iotdb.tsfile.utils.tsblock_serde import deserialize
 
 from ainode.core.config import AINodeDescriptor
@@ -33,8 +36,9 @@ from ainode.core.exception import (
 )
 from ainode.core.inference.inference_request import (
     InferenceRequest,
+    InferenceRequestProxy,
 )
-from ainode.core.inference.request_controller import RequestController
+from ainode.core.inference.pool_controller import PoolController
 from ainode.core.inference.strategy.timer_sundial_inference_pipeline import (
     TimerSundialInferencePipeline,
 )
@@ -45,7 +49,7 @@ from ainode.core.inference.utils import generate_req_id
 from ainode.core.log import Logger
 from ainode.core.manager.model_manager import ModelManager
 from ainode.core.manager.utils import (
-    _measure_model_memory,
+    measure_model_memory,
 )
 from ainode.core.model.sundial.configuration_sundial import SundialConfig
 from ainode.core.model.sundial.modeling_sundial import SundialForPrediction
@@ -143,18 +147,31 @@ class InferenceManager:
     DEFAULT_DEVICE = torch.device("cpu")
     # DEFAULT_DEVICE = torch.device("cuda" if torch.cuda.is_available() else 
"cpu")
 
+    # Sleep period, in milliseconds, of the result-handler thread whenever the
+    # result queue is empty; taken from the inference batch-interval config.
+    WAITING_INTERVAL_IN_MS = (
+        AINodeDescriptor().get_config().get_ain_inference_batch_interval_in_ms()
+    )
+
     def __init__(self):
         self._model_manager = ModelManager()
         self._model_mem_usage_map: Dict[str, int] = (
             {}
         )  # store model memory usage for each model
-        self._request_controller = RequestController()
+        # Queue on which finished InferenceRequests arrive; it is handed to
+        # PoolController below (presumably the pools put results on it).
+        self._result_queue = mp.Queue()
+        # req_id -> InferenceRequestProxy waiting for that request's output;
+        # shared with the handler thread, so always accessed under the lock.
+        self._result_wrapper_map = {}
+        self._result_wrapper_lock = threading.RLock()
+
+        # Set by shutdown() to make _handle_results leave its polling loop.
+        self._stop_event = mp.Event()
+        # Daemon thread that moves results from the queue to their proxies.
+        self._result_handler_thread = threading.Thread(
+            target=self._handle_results, daemon=True
+        )
+        self._result_handler_thread.start()
+        self._pool_controller = PoolController(self._result_queue)
         # self._preload_model_benchmarks()
 
     def _preload_model_benchmarks(self):
         if "cuda" in str(self.DEFAULT_DEVICE):
             for model_id in self.ACCELERATE_MODEL_ID:
-                mem_usage = _measure_model_memory(self.DEFAULT_DEVICE, 
model_id)
+                mem_usage = measure_model_memory(self.DEFAULT_DEVICE, model_id)
                 self._model_mem_usage_map[model_id] = mem_usage
                 logger.info(
                     f"[Inference] Preloaded benchmark for {model_id}, 
mem_usage={mem_usage/1024**2:.2f} MB"
@@ -164,6 +181,29 @@ class InferenceManager:
                 f"[Inference] Skipped preloading benchmarks for 
{self.DEFAULT_DEVICE}, only supports CUDA currently"
             )
 
+    def _handle_results(self):
+        """
+        Result-handler thread loop: deliver finished requests to their proxies.
+
+        Sleeps ``WAITING_INTERVAL_IN_MS`` while the result queue is empty and
+        exits once ``_stop_event`` is set. A result whose ``req_id`` has no
+        registered proxy is logged and dropped; raising ``KeyError`` here
+        would end this thread, and no later request would get its result.
+        """
+        while not self._stop_event.is_set():
+            if self._result_queue.empty():
+                time.sleep(self.WAITING_INTERVAL_IN_MS / 1000)
+                continue
+            infer_req: InferenceRequest = self._result_queue.get()
+            with self._result_wrapper_lock:
+                proxy = self._result_wrapper_map.get(infer_req.req_id)
+                if proxy is None:
+                    logger.warning(
+                        f"[Inference] Dropping result of unregistered request {infer_req.req_id}"
+                    )
+                    continue
+                proxy.set_result(infer_req.get_final_output())
+
+    def process_request(self, req):
+        """
+        Submit ``req`` to its model's request pools and block until it completes.
+
+        A proxy is registered under ``req.req_id`` before dispatch so the
+        result-handler thread can hand the output back. The registration is
+        removed in ``finally`` so a failed dispatch or wait does not leave a
+        stale entry in ``_result_wrapper_map``.
+
+        Args:
+            req: the InferenceRequest to run (uses ``req_id`` and ``model_id``).
+
+        Returns:
+            The value returned by ``InferenceRequestProxy.wait_for_completion``.
+        """
+        req_id = req.req_id
+        infer_proxy = InferenceRequestProxy(req_id)
+        with self._result_wrapper_lock:
+            self._result_wrapper_map[req_id] = infer_proxy
+        try:
+            # dispatch request to the pool
+            self._pool_controller.add_request(req.model_id, req)
+            return infer_proxy.wait_for_completion()
+        finally:
+            with self._result_wrapper_lock:
+                self._result_wrapper_map.pop(req_id, None)
+
     def _get_strategy(self, model_id, model):
         if isinstance(model, TimerForPrediction):
             return TimerXLStrategy(model)
@@ -227,7 +267,7 @@ class InferenceManager:
                     inference_pipeline=inference_pipeline,
                     max_new_tokens=predict_length,
                 )
-                outputs = self._request_controller.process_request(infer_req)
+                outputs = self.process_request(infer_req)
                 outputs = convert_to_binary(pd.DataFrame(outputs[0]))
             else:
                 # load model
@@ -278,4 +318,8 @@ class InferenceManager:
         )
 
     def shutdown(self):
-        self._request_controller.shutdown()
+        """
+        Stop result handling, shut down the request pools and release the queue.
+
+        The handler thread is joined (bounded wait) before the queue is drained
+        and closed. Otherwise it could race the drain loop for items, or call
+        ``get()`` on a queue that is being closed.
+        """
+        self._stop_event.set()
+        # The handler re-checks _stop_event after every sleep or delivered
+        # result, so a couple of polling intervals is normally enough.
+        self._result_handler_thread.join(
+            timeout=max(1.0, 2 * self.WAITING_INTERVAL_IN_MS / 1000)
+        )
+        self._pool_controller.shutdown()
+        while not self._result_queue.empty():
+            self._result_queue.get_nowait()
+        self._result_queue.close()
diff --git a/iotdb-core/ainode/ainode/core/manager/utils.py 
b/iotdb-core/ainode/ainode/core/manager/utils.py
index 5fbd444c38c..af60ff5ebe4 100644
--- a/iotdb-core/ainode/ainode/core/manager/utils.py
+++ b/iotdb-core/ainode/ainode/core/manager/utils.py
@@ -39,7 +39,7 @@ INFERENCE_EXTRA_MEMORY_RATIO = (
 )  # the overhead ratio for inference, used to estimate the pool size
 
 
-def _measure_model_memory(device: torch.device, model_id: str) -> int:
+def measure_model_memory(device: torch.device, model_id: str) -> int:
     # TODO: support CPU in the future
     # TODO: we can estimate the memory usage by running a dummy inference
     torch.cuda.empty_cache()
@@ -62,7 +62,7 @@ def _measure_model_memory(device: torch.device, model_id: 
str) -> int:
     return final
 
 
-def _evaluate_system_resources(device: torch.device) -> dict:
+def evaluate_system_resources(device: torch.device) -> dict:
     if torch.cuda.is_available():
         free_mem, total_mem = torch.cuda.mem_get_info()
         logger.info(
@@ -79,7 +79,7 @@ def _evaluate_system_resources(device: torch.device) -> dict:
         return {"device": "cpu", "free_mem": free_mem, "total_mem": total_mem}
 
 
-def _estimate_pool_size(device: torch.device, model_id: str) -> int:
+def estimate_pool_size(device: torch.device, model_id: str) -> int:
     model_info = BUILT_IN_LTSM_MAP.get(model_id, None)
     if model_info is None:
         logger.error(f"[Inference][Device-{device}] Model {model_id} not 
found")
@@ -90,7 +90,7 @@ def _estimate_pool_size(device: torch.device, model_id: str) 
-> int:
         logger.error(f"[Inference][Device-{device}] Model {model_id} not 
supported now")
         return 0
 
-    system_res = _evaluate_system_resources(device)
+    system_res = evaluate_system_resources(device)
     free_mem = system_res["free_mem"]
 
     mem_usage = MODEL_MEM_USAGE_MAP[model_type] * INFERENCE_EXTRA_MEMORY_RATIO


Reply via email to