This is an automated email from the ASF dual-hosted git repository.

yongzao pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/iotdb.git


The following commit(s) were added to refs/heads/master by this push:
     new e271df7c4db [AINode] Decoupling inference manager into request_manager, pool_manager (#16131)
e271df7c4db is described below

commit e271df7c4db2e5a6b60f2d7d394e22a6b825f0a3
Author: Leo <[email protected]>
AuthorDate: Wed Aug 20 07:58:00 2025 +0800

    [AINode] Decoupling inference manager into request_manager, pool_manager (#16131)
---
 .../ainode/core/inference/inference_request.py     |   2 +
 .../core/inference/inference_request_pool.py       |   7 ++
 .../core/inference/inference_request_pool_group.py |  63 ++++++++++
 .../ainode/core/inference/pool_controller.py       | 129 +++++++++++++++++++
 .../ainode/ainode/core/inference/pool_scheduler.py | 125 ++++++++++++++++++
 .../ainode/core/inference/request_controller.py    |  89 +++++++++++++
 .../ainode/core/manager/inference_manager.py       | 140 ++-------------------
 iotdb-core/ainode/ainode/core/util/decorator.py    |  15 +++
 8 files changed, 442 insertions(+), 128 deletions(-)

diff --git a/iotdb-core/ainode/ainode/core/inference/inference_request.py b/iotdb-core/ainode/ainode/core/inference/inference_request.py
index 2c45826fd26..40cede3b435 100644
--- a/iotdb-core/ainode/ainode/core/inference/inference_request.py
+++ b/iotdb-core/ainode/ainode/core/inference/inference_request.py
@@ -38,6 +38,7 @@ class InferenceRequest:
     def __init__(
         self,
         req_id: str,
+        model_id: str,
         inputs: torch.Tensor,
         inference_pipeline: AbstractInferencePipeline,
         max_new_tokens: int = 96,
@@ -47,6 +48,7 @@ class InferenceRequest:
             inputs = inputs.unsqueeze(0)
 
         self.req_id = req_id
+        self.model_id = model_id
         self.inputs = inputs
         self.infer_kwargs = infer_kwargs
         self.inference_pipeline = inference_pipeline
diff --git a/iotdb-core/ainode/ainode/core/inference/inference_request_pool.py b/iotdb-core/ainode/ainode/core/inference/inference_request_pool.py
index 4bf594860f7..956d0773785 100644
--- a/iotdb-core/ainode/ainode/core/inference/inference_request_pool.py
+++ b/iotdb-core/ainode/ainode/core/inference/inference_request_pool.py
@@ -20,6 +20,7 @@ import gc
 import random
 import threading
 import time
+from enum import Enum
 
 import numpy as np
 import torch
@@ -33,6 +34,12 @@ from ainode.core.log import Logger
 from ainode.core.manager.model_manager import ModelManager
 
 
+class PoolState(Enum):
+    """
+    Lifecycle states recorded for an inference request pool.
+    """
+    # Set right after the pool process is started, before it reports readiness.
+    INITIALIZING = "INITIALIZING"
+    # Set when the pool is registered; only RUNNING pools are picked for dispatch.
+    RUNNING = "RUNNING"
+    # NOTE(review): not assigned anywhere in this change; presumably marks a
+    # pool that is being shut down - confirm.
+    STOPPING = "STOPPING"
+
+
 class InferenceRequestPool(mp.Process):
     """
     The request pool to handle inference for a specific model.
diff --git a/iotdb-core/ainode/ainode/core/inference/inference_request_pool_group.py b/iotdb-core/ainode/ainode/core/inference/inference_request_pool_group.py
new file mode 100644
index 00000000000..3dc2929e388
--- /dev/null
+++ b/iotdb-core/ainode/ainode/core/inference/inference_request_pool_group.py
@@ -0,0 +1,63 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+from typing import Dict, Tuple
+
+import torch.multiprocessing as mp
+
+from ainode.core.exception import (
+    InferenceModelInternalError,
+)
+from ainode.core.inference.inference_request_pool import InferenceRequestPool
+from ainode.core.log import Logger
+
+logger = Logger()
+
+
+class PoolGroup:
+    """
+    The inference request pools that serve one model.
+
+    Each entry is keyed by pool id and stored as a
+    (request_pool, request_queue) pair.
+    """
+
+    def __init__(self, model_id):
+        self.model_id = model_id
+        self.pool_group: Dict[int, Tuple[InferenceRequestPool, mp.Queue]] = {}
+
+    def _entry(self, pool_id) -> Tuple[InferenceRequestPool, mp.Queue]:
+        # Shared lookup behind both accessors; an unknown pool id is reported
+        # together with the model id of this group.
+        if pool_id in self.pool_group:
+            return self.pool_group[pool_id]
+        raise InferenceModelInternalError(
+            f"Pool ID {pool_id} not found for model {self.model_id}"
+        )
+
+    def get_pool_group(self) -> Dict[int, Tuple[InferenceRequestPool, mp.Queue]]:
+        """Return the underlying {pool_id: (pool, queue)} mapping itself."""
+        return self.pool_group
+
+    def add_pool(
+        self, pool_id: int, request_pool: InferenceRequestPool, request_queue: mp.Queue
+    ):
+        """Store (or overwrite) the pool and its queue under pool_id."""
+        entry = (request_pool, request_queue)
+        self.pool_group[pool_id] = entry
+
+    def get_pool_ids(self) -> list[int]:
+        """Return a new list holding the ids of all stored pools."""
+        return list(self.pool_group)
+
+    def get_request_pool(self, pool_id) -> InferenceRequestPool:
+        """Return the pool stored under pool_id, raising if it is unknown."""
+        request_pool, _ = self._entry(pool_id)
+        return request_pool
+
+    def get_request_queue(self, pool_id) -> mp.Queue:
+        """Return the queue stored under pool_id, raising if it is unknown."""
+        _, request_queue = self._entry(pool_id)
+        return request_queue
diff --git a/iotdb-core/ainode/ainode/core/inference/pool_controller.py b/iotdb-core/ainode/ainode/core/inference/pool_controller.py
new file mode 100644
index 00000000000..121d3023df9
--- /dev/null
+++ b/iotdb-core/ainode/ainode/core/inference/pool_controller.py
@@ -0,0 +1,129 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+from collections import defaultdict
+from typing import Dict, Optional
+
+import torch
+import torch.multiprocessing as mp
+
+from ainode.core.exception import (
+    InferenceModelInternalError,
+)
+from ainode.core.inference.inference_request import InferenceRequest
+from ainode.core.inference.inference_request_pool import InferenceRequestPool, 
PoolState
+from ainode.core.inference.inference_request_pool_group import PoolGroup
+from ainode.core.log import Logger
+
+logger = Logger()
+
+
+class PoolController:
+    """
+    A controller for handling inference request pools.
+    It handles the registration of pools, routes requests to the queue of a
+    RUNNING pool, and tracks the state of each pool.
+    """
+
+    DEFAULT_DEVICE = torch.device("cpu")
+    # DEFAULT_DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+    def __init__(self):
+        # structure: {model_id: {pool_id: PoolState}}
+        self.pool_states: Dict[str, Dict[int, PoolState]] = defaultdict(dict)
+        # structure: {model_id: PoolGroup}
+        self._request_pool_map: Dict[str, PoolGroup] = {}
+
+    def dispatch_request(self, model_id, req: InferenceRequest):
+        """
+        Select a RUNNING pool of the model and put the request on its queue.
+
+        Raises InferenceModelInternalError when the model has no registered
+        pool or none of its pools is RUNNING.
+        """
+        pool_idx = self._select_pool_by_hash(model_id, req.req_id)
+        self.add_request(pool_idx, req)
+        logger.debug(
+            f"[Inference][Device-{self.DEFAULT_DEVICE}][Pool-{pool_idx}][ID-{req.req_id}] Request is queued for inference"
+        )
+
+    def _select_pool_by_hash(self, model_id, req_id) -> int:
+        """
+        Return the id of a RUNNING pool, probing the registered pool ids in
+        order starting from the slot that hash(req_id) maps to.
+        """
+        pool_ids = self.get_pool_ids(model_id)
+        if not pool_ids:
+            raise InferenceModelInternalError(
+                f"No available pools for model {model_id}"
+            )
+        n = len(pool_ids)
+        start_idx = hash(req_id) % n
+        for i in range(n):
+            pool_id = pool_ids[(start_idx + i) % n]
+            state = self.get_state(model_id, pool_id)
+            if state == PoolState.RUNNING:
+                return pool_id
+        raise InferenceModelInternalError(
+            f"No RUNNING pools available for model {model_id}"
+        )
+
+    def register_pool(self, model_id, pool_id, request_pool, request_queue):
+        """Mark the pool RUNNING and add it (with its queue) to the model's group."""
+        # The state is written before the pool becomes visible in the map, so
+        # every pool id found in the map already has a state entry.
+        self.set_state(model_id, pool_id, PoolState.RUNNING)
+        self.set_request_pool_map(model_id, pool_id, request_pool, request_queue)
+
+    def add_request(self, pool_id, req):
+        """Put req on the queue of pool_id; the model is taken from req.model_id."""
+        req_q = self.get_request_queue(req.model_id, pool_id)
+        req_q.put(req)
+
+    def remove_request(self, model_id, req_id):
+        """Placeholder: currently does nothing."""
+        pass
+
+    def get_pool_ids(self, model_id) -> list[int]:
+        """
+        Return the ids of the pools registered for model_id.
+
+        An empty list is returned when no pool has been registered for the
+        model, so callers can report "no available pools" instead of failing
+        with a KeyError.
+        """
+        pool_group = self._request_pool_map.get(model_id)
+        if pool_group is None:
+            return []
+        return pool_group.get_pool_ids()
+
+    def has_request_pools(self, model_id) -> bool:
+        """Return True once at least one pool was registered for model_id."""
+        return model_id in self._request_pool_map
+
+    def get_request_pool_map(self) -> Dict[str, PoolGroup]:
+        """Return the {model_id: PoolGroup} mapping itself."""
+        return self._request_pool_map
+
+    def get_request_pools_group(self, model_id) -> Optional[PoolGroup]:
+        """Return the PoolGroup of model_id, or None when there is none."""
+        return self._request_pool_map.get(model_id, None)
+
+    def get_request_pool(self, model_id, pool_id) -> InferenceRequestPool:
+        """Return the pool registered as (model_id, pool_id)."""
+        return self._request_pool_map[model_id].get_request_pool(pool_id)
+
+    def get_request_queue(self, model_id, pool_id) -> mp.Queue:
+        """Return the request queue registered for (model_id, pool_id)."""
+        return self._request_pool_map[model_id].get_request_queue(pool_id)
+
+    def set_request_pool_map(self, model_id, pool_id, request_pool, request_queue):
+        """Add the pool to the model's group, creating the group on first use."""
+        if model_id not in self._request_pool_map:
+            self._request_pool_map[model_id] = PoolGroup(model_id)
+        self._request_pool_map[model_id].add_pool(pool_id, request_pool, request_queue)
+
+    def get_state(self, model_id, pool_id) -> PoolState:
+        """Return the recorded state; raises KeyError if none was ever set."""
+        return self.pool_states[model_id][pool_id]
+
+    def set_state(self, model_id, pool_id, state):
+        """Record state for (model_id, pool_id)."""
+        self.pool_states[model_id][pool_id] = state
+
+    def get_load(self, model_id, pool_id) -> int:
+        """Placeholder: not implemented yet, currently returns None."""
+        pass
+
+    def shutdown(self):
+        """
+        Stop every registered pool, discard its pending requests, close its
+        queue and finally wait (up to 10 seconds each) for the pools to exit.
+        """
+        from queue import Empty
+
+        for model_id, pool_group in self._request_pool_map.items():
+            pool_ids = pool_group.get_pool_ids()
+            for pool_id in pool_ids:
+                request_pool = pool_group.get_request_pool(pool_id)
+                request_queue = pool_group.get_request_queue(pool_id)
+                # Take the pool out of dispatch rotation before stopping it so
+                # that no request is routed to a queue that is about to close.
+                self.set_state(model_id, pool_id, PoolState.STOPPING)
+                request_pool.stop()
+                # empty() is only a snapshot for multiprocessing queues, so
+                # drain until get_nowait() itself reports the queue as empty.
+                try:
+                    while True:
+                        request_queue.get_nowait()
+                except Empty:
+                    pass
+                request_queue.close()
+            for pool_id in pool_ids:
+                pool_group.get_request_pool(pool_id).join(timeout=10)
diff --git a/iotdb-core/ainode/ainode/core/inference/pool_scheduler.py b/iotdb-core/ainode/ainode/core/inference/pool_scheduler.py
new file mode 100644
index 00000000000..53aa329c494
--- /dev/null
+++ b/iotdb-core/ainode/ainode/core/inference/pool_scheduler.py
@@ -0,0 +1,125 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+import threading
+
+import torch
+import torch.multiprocessing as mp
+
+from ainode.core.exception import (
+    InferenceModelInternalError,
+)
+from ainode.core.inference.inference_request_pool import InferenceRequestPool, 
PoolState
+from ainode.core.inference.pool_controller import PoolController
+from ainode.core.log import Logger
+from ainode.core.manager.utils import (
+    _estimate_pool_size,
+)
+from ainode.core.model.sundial.configuration_sundial import SundialConfig
+from ainode.core.model.timerxl.configuration_timer import TimerConfig
+from ainode.core.util.decorator import synchronized
+
+logger = Logger()
+
+
+class PoolScheduler:
+    """
+    A Scheduler to init the request pools.
+    It initializes the first pool and starts a background thread to expand pools
+    as needed based on the model_id.
+    """
+
+    DEFAULT_DEVICE = torch.device("cpu")
+    # DEFAULT_DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+    def __init__(self, pool_controller: PoolController, result_queue: mp.Queue):
+        self._pool_controller = pool_controller
+        self._result_queue = result_queue
+
+    @staticmethod
+    def _build_config(model_id: str):
+        """
+        Create a fresh model configuration for model_id.
+
+        Raises InferenceModelInternalError for a model id that has no known
+        configuration, instead of leaving the config variable unbound.
+        """
+        if model_id == "sundial":
+            return SundialConfig()
+        if model_id == "timer_xl":
+            return TimerConfig()
+        raise InferenceModelInternalError(f"Unsupported model_id: {model_id}")
+
+    # The lock is created once, when the class body is evaluated, so it is
+    # shared by every PoolScheduler instance.
+    @synchronized(threading.Lock())
+    def first_req_init(self, model_id: str):
+        """
+        Lazily create the pools of model_id on its first request.
+
+        The first pool is started synchronously; the remaining
+        (estimated pool size - 1) pools are started by a daemon thread.
+        Does nothing when the model already has registered pools.
+
+        Raises InferenceModelInternalError when the estimated pool size is
+        not positive or the model id is not supported.
+        """
+        if not self._pool_controller.has_request_pools(model_id):
+            pool_num = _estimate_pool_size(self.DEFAULT_DEVICE, model_id)
+            if pool_num <= 0:
+                raise InferenceModelInternalError(
+                    f"Not enough memory to run model {model_id}."
+                )
+            # initialize the first pool
+            self._first_pool_init(model_id)
+            # start a background thread to expand pools
+            expand_thread = threading.Thread(
+                target=self._expand_pools,
+                args=(model_id, 1, pool_num - 1),
+                daemon=True,
+            )
+            expand_thread.start()
+
+    def _first_pool_init(self, model_id: str):
+        """
+        Start pool 0 for model_id and register it once it reports readiness.
+
+        If the pool is not ready within 30 seconds only an error is logged and
+        the pool is left unregistered.
+        """
+        # Resolve the config first so an unsupported model id fails before
+        # any queue or process is created.
+        config = self._build_config(model_id)
+        first_queue = mp.Queue()
+        ready_event = mp.Event()
+        first_pool = InferenceRequestPool(
+            pool_id=0,
+            model_id=model_id,
+            config=config,
+            request_queue=first_queue,
+            result_queue=self._result_queue,
+            ready_event=ready_event,
+        )
+        first_pool.start()
+        self._pool_controller.set_state(model_id, 0, PoolState.INITIALIZING)
+        if not ready_event.wait(timeout=30):
+            logger.error(
+                f"[Inference][Device-{self.DEFAULT_DEVICE}][Pool-0] First pool failed to be ready in time"
+            )
+        else:
+            self._pool_controller.register_pool(model_id, 0, first_pool, first_queue)
+            logger.info(
+                f"[Inference][Device-{self.DEFAULT_DEVICE}][Pool-0] Initialized inference request pool for model {model_id}"
+            )
+
+    def _expand_pools(self, model_id, start_idx, count):
+        """
+        Start `count` further pools with ids start_idx, start_idx + 1, ...
+
+        Pools are started one after another; a pool that is not ready within
+        30 seconds is logged and skipped, the others are registered.
+        """
+        for idx in range(count):
+            pool_id = start_idx + idx
+            # A new config object is built for every pool, as before.
+            config = self._build_config(model_id)
+            queue = mp.Queue()
+            pool = InferenceRequestPool(
+                pool_id=pool_id,
+                model_id=model_id,
+                config=config,
+                request_queue=queue,
+                result_queue=self._result_queue,
+                ready_event=mp.Event(),
+            )
+            pool.start()
+            self._pool_controller.set_state(model_id, pool_id, PoolState.INITIALIZING)
+            if not pool.ready_event.wait(timeout=30):
+                logger.error(
+                    f"[Inference][Device-{self.DEFAULT_DEVICE}][Pool-{pool_id}] Pool failed to be ready in time"
+                )
+                continue
+            else:
+                self._pool_controller.register_pool(model_id, pool_id, pool, queue)
+                logger.info(
+                    f"[Inference][Device-{self.DEFAULT_DEVICE}][Pool-{pool.pool_id}] New inference request pool started for model {model_id}"
+                )
diff --git a/iotdb-core/ainode/ainode/core/inference/request_controller.py b/iotdb-core/ainode/ainode/core/inference/request_controller.py
new file mode 100644
index 00000000000..bde513a2813
--- /dev/null
+++ b/iotdb-core/ainode/ainode/core/inference/request_controller.py
@@ -0,0 +1,89 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+import threading
+import time
+
+import torch.multiprocessing as mp
+
+from ainode.core.config import AINodeDescriptor
+from ainode.core.inference.inference_request import (
+    InferenceRequest,
+    InferenceRequestProxy,
+)
+from ainode.core.inference.pool_controller import PoolController
+from ainode.core.inference.pool_scheduler import PoolScheduler
+from ainode.core.log import Logger
+
+logger = Logger()
+
+
+class RequestController:
+    """
+    Controls the lifecycle and scheduling of inference requests.
+
+    A request is registered under its req_id, dispatched to a pool through the
+    PoolController, and the calling thread then waits until the background
+    result handler delivers the output taken from the shared result queue.
+    """
+
+    WAITING_INTERVAL_IN_MS = (
+        AINodeDescriptor().get_config().get_ain_inference_batch_interval_in_ms()
+    )  # How often to check for requests in the result queue
+
+    def __init__(self):
+        self._result_queue = mp.Queue()
+        self._result_wrapper_map = {}
+        self._result_wrapper_lock = threading.RLock()
+
+        self._stop_event = mp.Event()
+        # The result handler reads self._pool_controller, so the controllers
+        # are created before the handler thread is started.
+        self._pool_controller = PoolController()
+        self._pool_scheduler = PoolScheduler(self._pool_controller, self._result_queue)
+        self._result_handler_thread = threading.Thread(
+            target=self._handle_results, daemon=True
+        )
+        self._result_handler_thread.start()
+
+    def _handle_results(self):
+        """
+        Poll the result queue until the stop event is set and hand every
+        finished request over to the proxy its caller is waiting on.
+        """
+        while not self._stop_event.is_set():
+            if self._result_queue.empty():
+                time.sleep(self.WAITING_INTERVAL_IN_MS / 1000)
+                continue
+            infer_req: InferenceRequest = self._result_queue.get()
+            self._pool_controller.remove_request(infer_req.model_id, infer_req.req_id)
+            with self._result_wrapper_lock:
+                infer_proxy = self._result_wrapper_map.get(infer_req.req_id)
+                if infer_proxy is None:
+                    # The waiter is already gone (e.g. its request failed
+                    # earlier). Dropping the result keeps this thread alive;
+                    # a KeyError here would stall every later request.
+                    logger.error(
+                        f"[Inference][ID-{infer_req.req_id}] Received a result for an unknown request, dropped"
+                    )
+                    continue
+                infer_proxy.set_result(infer_req.get_final_output())
+
+    def process_request(self, req):
+        """
+        Run one inference request and return its outputs.
+
+        Blocks until the result handler delivers the result. Exceptions raised
+        during pool initialization or dispatch are propagated to the caller.
+        """
+        infer_proxy = InferenceRequestProxy(req.req_id)
+        with self._result_wrapper_lock:
+            self._result_wrapper_map[req.req_id] = infer_proxy
+        try:
+            # lazy initialization for first request
+            model_id = req.model_id
+            if not self._pool_controller.has_request_pools(model_id):
+                self._pool_scheduler.first_req_init(model_id)
+            # dispatch request to the pool
+            self._pool_controller.dispatch_request(model_id, req)
+            return infer_proxy.wait_for_completion()
+        finally:
+            # Always unregister the proxy, also when initialization or
+            # dispatch failed, so the map cannot accumulate dead entries.
+            with self._result_wrapper_lock:
+                self._result_wrapper_map.pop(req.req_id, None)
+
+    def shutdown(self):
+        """
+        Stop the result handler, shut down all pools, then discard pending
+        results and close the result queue.
+        """
+        from queue import Empty
+
+        self._stop_event.set()
+        self._pool_controller.shutdown()
+        # empty() is only a snapshot for multiprocessing queues, so drain
+        # until get_nowait() itself reports the queue as empty.
+        try:
+            while True:
+                self._result_queue.get_nowait()
+        except Empty:
+            pass
+        self._result_queue.close()
diff --git a/iotdb-core/ainode/ainode/core/manager/inference_manager.py b/iotdb-core/ainode/ainode/core/manager/inference_manager.py
index 19d0cc340b6..a7f10b8b389 100644
--- a/iotdb-core/ainode/ainode/core/manager/inference_manager.py
+++ b/iotdb-core/ainode/ainode/core/manager/inference_manager.py
@@ -15,16 +15,12 @@
 # specific language governing permissions and limitations
 # under the License.
 #
-import gc
-import threading
-import time
+
 from abc import ABC, abstractmethod
-from typing import Dict, List
+from typing import Dict
 
 import pandas as pd
-import psutil
 import torch
-import torch.multiprocessing as mp
 from iotdb.tsfile.utils.tsblock_serde import deserialize
 
 from ainode.core.config import AINodeDescriptor
@@ -37,9 +33,8 @@ from ainode.core.exception import (
 )
 from ainode.core.inference.inference_request import (
     InferenceRequest,
-    InferenceRequestProxy,
 )
-from ainode.core.inference.inference_request_pool import InferenceRequestPool
+from ainode.core.inference.request_controller import RequestController
 from ainode.core.inference.strategy.timer_sundial_inference_pipeline import (
     TimerSundialInferencePipeline,
 )
@@ -50,7 +45,6 @@ from ainode.core.inference.utils import generate_req_id
 from ainode.core.log import Logger
 from ainode.core.manager.model_manager import ModelManager
 from ainode.core.manager.utils import (
-    _estimate_pool_size,
     _measure_model_memory,
 )
 from ainode.core.model.sundial.configuration_sundial import SundialConfig
@@ -146,27 +140,15 @@ class RegisteredStrategy(InferenceStrategy):
 
 class InferenceManager:
     ACCELERATE_MODEL_ID = ["sundial", "timer_xl"]
-    DEFAULT_DEVICE = "cpu"
+    DEFAULT_DEVICE = torch.device("cpu")
     # DEFAULT_DEVICE = torch.device("cuda" if torch.cuda.is_available() else 
"cpu")
-    WAITING_INTERVAL_IN_MS = (
-        
AINodeDescriptor().get_config().get_ain_inference_batch_interval_in_ms()
-    )  # How often to check for requests in the result queue
 
     def __init__(self):
         self._model_manager = ModelManager()
-        self._result_queue = mp.Queue()
-        self._result_wrapper_map = {}
-        self._result_wrapper_lock = threading.RLock()
-        # structure: {model_id: [(InferenceRequestPool, request_queue), ...]}
-        self._request_pool_map: Dict[str, List[(InferenceRequestPool, 
mp.Queue)]] = {}
-        self._stop_event = mp.Event()
-        self._result_handler_thread = threading.Thread(
-            target=self._handle_results, daemon=True
-        )
-        self._result_handler_thread.start()
         self._model_mem_usage_map: Dict[str, int] = (
             {}
         )  # store model memory usage for each model
+        self._request_controller = RequestController()
         # self._preload_model_benchmarks()
 
     def _preload_model_benchmarks(self):
@@ -182,70 +164,6 @@ class InferenceManager:
                 f"[Inference] Skipped preloading benchmarks for 
{self.DEFAULT_DEVICE}, only supports CUDA currently"
             )
 
-    def _first_pool_init(self, model_id: str):
-        if model_id == "sundial":
-            config = SundialConfig()
-        elif model_id == "timer_xl":
-            config = TimerConfig()
-        first_queue = mp.Queue()
-        ready_event = mp.Event()
-        first_pool = InferenceRequestPool(
-            pool_id=0,
-            model_id=model_id,
-            config=config,
-            request_queue=first_queue,
-            result_queue=self._result_queue,
-            ready_event=ready_event,
-        )
-        first_pool.start()
-        if not ready_event.wait(timeout=30):
-            logger.error(
-                f"[Inference][Device-{self.DEFAULT_DEVICE}][Pool-0] First pool 
failed to be ready in time"
-            )
-        else:
-            self._request_pool_map[model_id] = [(first_pool, first_queue)]
-            logger.info(
-                f"[Inference][Device-{self.DEFAULT_DEVICE}][Pool-0] 
Initialized inference request pool for model {model_id}"
-            )
-
-    def _expand_pools(self, model_id, start_idx, count):
-        for idx in range(count):
-            queue = mp.Queue()
-            if model_id == "sundial":
-                config = SundialConfig()
-            elif model_id == "timer_xl":
-                config = TimerConfig()
-            pool = InferenceRequestPool(
-                pool_id=start_idx + idx,
-                model_id=model_id,
-                config=config,
-                request_queue=queue,
-                result_queue=self._result_queue,
-                ready_event=mp.Event(),
-            )
-            pool.start()
-            if not pool.ready_event.wait(timeout=30):
-                logger.error(
-                    
f"[Inference][Device-{self.DEFAULT_DEVICE}][Pool-{start_idx + idx}] Pool failed 
to be ready in time"
-                )
-                continue
-            else:
-                self._request_pool_map[model_id].append((pool, queue))
-                logger.info(
-                    
f"[Inference][Device-{self.DEFAULT_DEVICE}][Pool-{pool.pool_id}] New inference 
request pool started for model {model_id}"
-                )
-
-    def _handle_results(self):
-        while not self._stop_event.is_set():
-            if self._result_queue.empty():
-                time.sleep(self.WAITING_INTERVAL_IN_MS / 1000)
-                continue
-            infer_req: InferenceRequest = self._result_queue.get()
-            with self._result_wrapper_lock:
-                self._result_wrapper_map[infer_req.req_id].set_result(
-                    infer_req.get_final_output()
-                )
-
     def _get_strategy(self, model_id, model):
         if isinstance(model, TimerForPrediction):
             return TimerXLStrategy(model)
@@ -287,22 +205,6 @@ class InferenceManager:
             if model_id in self.ACCELERATE_MODEL_ID and "cuda" in str(
                 self.DEFAULT_DEVICE
             ):
-                # lazy initialization for first request
-                if model_id not in self._request_pool_map:
-                    pool_num = _estimate_pool_size(self.DEFAULT_DEVICE, 
model_id)
-                    if pool_num <= 0:
-                        raise InferenceModelInternalError(
-                            f"Not enough memory to run model {model_id}."
-                        )
-                    # initialize the first pool
-                    self._first_pool_init(model_id)
-                    # start a background thread to expand pools
-                    expand_thread = threading.Thread(
-                        target=self._expand_pools,
-                        args=(model_id, 1, pool_num - 1),
-                        daemon=True,
-                    )
-                    expand_thread.start()
                 # TODO: Logic in this branch shall handle all LTSM inferences
                 # TODO: TSBlock -> Tensor codes should be unified
                 data = full_data[1][0]
@@ -314,26 +216,19 @@ class InferenceManager:
                     inference_pipeline = 
TimerSundialInferencePipeline(SundialConfig())
                 elif model_id == "timer_xl":
                     inference_pipeline = 
TimerXLInferencePipeline(TimerConfig())
+                else:
+                    raise InferenceModelInternalError(
+                        f"Unsupported model_id: {model_id}"
+                    )
                 infer_req = InferenceRequest(
                     req_id=generate_req_id(),
+                    model_id=model_id,
                     inputs=inputs,
                     inference_pipeline=inference_pipeline,
                     max_new_tokens=predict_length,
                 )
-                infer_proxy = InferenceRequestProxy(infer_req.req_id)
-                with self._result_wrapper_lock:
-                    self._result_wrapper_map[infer_req.req_id] = infer_proxy
-                pool_idx = hash(infer_req.req_id) % len(
-                    self._request_pool_map[model_id]
-                )
-                self._request_pool_map[model_id][pool_idx][1].put(infer_req)
-                logger.debug(
-                    
f"[Inference][Device-{self.DEFAULT_DEVICE}][Pool-{pool_idx}][ID-{infer_req.req_id}]
 Request is queued for inference"
-                )
-                outputs = infer_proxy.wait_for_completion()
+                outputs = self._request_controller.process_request(infer_req)
                 outputs = convert_to_binary(pd.DataFrame(outputs[0]))
-                with self._result_wrapper_lock:
-                    del self._result_wrapper_map[infer_req.req_id]
             else:
                 # load model
                 accel = str(inference_attrs.get("acceleration", "")).lower() 
== "true"
@@ -383,15 +278,4 @@ class InferenceManager:
         )
 
     def shutdown(self):
-        self._stop_event.set()
-        for model_id, pools in self._request_pool_map.items():
-            for requestPool, requestQueue in pools:
-                requestPool.stop()
-                while not requestQueue.empty():
-                    requestQueue.get_nowait()
-                requestQueue.close()
-            for requestPool, _ in pools:
-                requestPool.join(timeout=10)
-        while not self._result_queue.empty():
-            self._result_queue.get_nowait()
-        self._result_queue.close()
+        self._request_controller.shutdown()
diff --git a/iotdb-core/ainode/ainode/core/util/decorator.py b/iotdb-core/ainode/ainode/core/util/decorator.py
index 33b9f4835ac..5a84c3d6bb2 100644
--- a/iotdb-core/ainode/ainode/core/util/decorator.py
+++ b/iotdb-core/ainode/ainode/core/util/decorator.py
@@ -15,6 +15,9 @@
 # specific language governing permissions and limitations
 # under the License.
 #
+from functools import wraps
+
+
 def singleton(cls):
     instances = {}
 
@@ -24,3 +27,15 @@ def singleton(cls):
         return instances[cls]
 
     return get_instance
+
+
+def synchronized(lock):
+    """
+    Build a decorator that runs the decorated callable while holding `lock`.
+
+    `lock` is entered with a `with` statement around every call, so it is
+    released when the call returns or raises. All callables decorated with the
+    returned decorator share this one lock object.
+    """
+
+    def apply_lock(target):
+        @wraps(target)
+        def guarded(*call_args, **call_kwargs):
+            with lock:
+                return target(*call_args, **call_kwargs)
+
+        return guarded
+
+    return apply_lock

Reply via email to