This is an automated email from the ASF dual-hosted git repository.

yongzao pushed a commit to branch concurrent-sundial
in repository https://gitbox.apache.org/repos/asf/iotdb.git

commit d91ed56e9d2df63964f1e0f4048cea7b2e9b2057
Author: Yongzao <[email protected]>
AuthorDate: Wed Jul 9 20:15:44 2025 +0800

    Concurrent Sundial inference appears complete; request handling is faster
---
 iotdb-core/ainode/ainode/core/config.py            |  17 ++
 iotdb-core/ainode/ainode/core/constant.py          |  11 +-
 .../ainode/core/inference/inference_request.py     |  61 +++---
 .../core/inference/inference_request_pool.py       | 221 +++++----------------
 iotdb-core/ainode/ainode/core/inference/utils.py   |   2 +-
 .../ainode/core/manager/inference_manager.py       |  86 +++++---
 6 files changed, 171 insertions(+), 227 deletions(-)

diff --git a/iotdb-core/ainode/ainode/core/config.py b/iotdb-core/ainode/ainode/core/config.py
index b4694cb9c3b..6f0336ad6ae 100644
--- a/iotdb-core/ainode/ainode/core/config.py
+++ b/iotdb-core/ainode/ainode/core/config.py
@@ -30,6 +30,7 @@ from ainode.core.constant import (
     AINODE_CONF_FILE_NAME,
     AINODE_CONF_GIT_FILE_NAME,
     AINODE_CONF_POM_FILE_NAME,
+    AINODE_INFERENCE_BATCH_INTERVAL_IN_MS,
     AINODE_INFERENCE_RPC_ADDRESS,
     AINODE_INFERENCE_RPC_PORT,
     AINODE_LOG_DIR,
@@ -55,6 +56,9 @@ class AINodeConfig(object):
         # Used for connection of DataNode/ConfigNode clients
         self._ain_inference_rpc_address: str = AINODE_INFERENCE_RPC_ADDRESS
         self._ain_inference_rpc_port: int = AINODE_INFERENCE_RPC_PORT
+        self._ain_inference_batch_interval_in_ms: int = (
+            AINODE_INFERENCE_BATCH_INTERVAL_IN_MS
+        )
 
         # log directory
         self._ain_logs_dir: str = AINODE_LOG_DIR
@@ -132,6 +136,14 @@ class AINodeConfig(object):
     def set_ain_inference_rpc_port(self, ain_inference_rpc_port: int) -> None:
         self._ain_inference_rpc_port = ain_inference_rpc_port
 
+    def get_ain_inference_batch_interval_in_ms(self) -> int:
+        return self._ain_inference_batch_interval_in_ms
+
+    def set_ain_inference_batch_interval_in_ms(
+        self, ain_inference_batch_interval_in_ms: int
+    ) -> None:
+        self._ain_inference_batch_interval_in_ms = 
ain_inference_batch_interval_in_ms
+
     def get_ain_logs_dir(self) -> str:
         return self._ain_logs_dir
 
@@ -273,6 +285,11 @@ class AINodeDescriptor(object):
                     int(file_configs["ain_inference_rpc_port"])
                 )
 
+            if "ain_inference_batch_interval_in_ms" in config_keys:
+                self._config.set_ain_inference_batch_interval_in_ms(
+                    int(file_configs["ain_inference_batch_interval_in_ms"])
+                )
+
             if "ain_models_dir" in config_keys:
                 self._config.set_ain_models_dir(file_configs["ain_models_dir"])
 
diff --git a/iotdb-core/ainode/ainode/core/constant.py b/iotdb-core/ainode/ainode/core/constant.py
index c7b75103d03..1858d56ad8c 100644
--- a/iotdb-core/ainode/ainode/core/constant.py
+++ b/iotdb-core/ainode/ainode/core/constant.py
@@ -29,25 +29,30 @@ AINODE_CONF_FILE_NAME = "iotdb-ainode.properties"
 AINODE_CONF_GIT_FILE_NAME = "git.properties"
 AINODE_CONF_POM_FILE_NAME = "pom.properties"
 AINODE_SYSTEM_FILE_NAME = "system.properties"
+
 # inference_rpc_address
 AINODE_INFERENCE_RPC_ADDRESS = "127.0.0.1"
-AINODE_INFERENCE_RPC_PORT = 11810
+AINODE_INFERENCE_RPC_PORT = 10810
+AINODE_INFERENCE_BATCH_INTERVAL_IN_MS = 15
+
 # AINode folder structure
 AINODE_MODELS_DIR = "data/ainode/models"
 AINODE_BUILTIN_MODELS_DIR = "data/ainode/models/weights"  # For built-in 
models, we only need to store their weights and config.
 AINODE_SYSTEM_DIR = "data/ainode/system"
 AINODE_LOG_DIR = "logs/ainode"
 AINODE_THRIFT_COMPRESSION_ENABLED = False
+
 # use for node management
-AINODE_CLUSTER_NAME = "yongzaoCluster"
+AINODE_CLUSTER_NAME = "defaultCluster"
 AINODE_VERSION_INFO = "UNKNOWN"
 AINODE_BUILD_INFO = "UNKNOWN"
 AINODE_ROOT_DIR = os.path.dirname(
     os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
 )
+
 # connect IoTDB cluster
 AINODE_CLUSTER_INGRESS_ADDRESS = "127.0.0.1"
-AINODE_CLUSTER_INGRESS_PORT = 7667
+AINODE_CLUSTER_INGRESS_PORT = 6667
 AINODE_CLUSTER_INGRESS_USERNAME = "root"
 AINODE_CLUSTER_INGRESS_PASSWORD = "root"
 AINODE_CLUSTER_INGRESS_TIME_ZONE = "UTC+8"
diff --git a/iotdb-core/ainode/ainode/core/inference/inference_request.py b/iotdb-core/ainode/ainode/core/inference/inference_request.py
index ccd3fb3e542..4dab2aee6b6 100644
--- a/iotdb-core/ainode/ainode/core/inference/inference_request.py
+++ b/iotdb-core/ainode/ainode/core/inference/inference_request.py
@@ -21,6 +21,9 @@ from typing import Any
 import torch
 
 from ainode.core.inference.strategy.abstract_strategy import AbstractStrategy
+from ainode.core.log import Logger
+
+logger = Logger()
 
 
 class InferenceRequestState:
@@ -32,7 +35,7 @@ class InferenceRequestState:
 class InferenceRequest:
     def __init__(
         self,
-        req_id: int,
+        req_id: str,
         inputs: torch.Tensor,
         strategy: AbstractStrategy,
         max_new_tokens: int = 96,
@@ -41,7 +44,7 @@ class InferenceRequest:
         if inputs.ndim == 1:
             inputs = inputs.unsqueeze(0)
 
-        self.id = req_id
+        self.req_id = req_id
         self.inputs = inputs
         self.infer_kwargs = infer_kwargs
         self.strategy = strategy
@@ -59,9 +62,6 @@ class InferenceRequest:
             self.batch_size, max_new_tokens, device=device
         )  # shape: [self.batch_size, max_new_steps]
 
-        self._lock = threading.Lock()
-        self._condition = threading.Condition(self._lock)
-
     def mark_running(self):
         self.state = InferenceRequestState.RUNNING
 
@@ -75,34 +75,45 @@ class InferenceRequest:
         )
 
     def write_step_output(self, step_output: torch.Tensor):
-        with self._lock:
-            if step_output.ndim == 1:
-                step_output = step_output.unsqueeze(0)
+        if step_output.ndim == 1:
+            step_output = step_output.unsqueeze(0)
 
-            batch_size, step_size = step_output.shape
-            end_idx = self.cur_step_idx + step_size
+        batch_size, step_size = step_output.shape
+        end_idx = self.cur_step_idx + step_size
 
-            if end_idx > self.max_new_tokens:
-                self.output_tensor[:, self.cur_step_idx :] = step_output[
-                    :, : self.max_new_tokens - self.cur_step_idx
-                ]
-                self.cur_step_idx = self.max_new_tokens
-            else:
-                self.output_tensor[:, self.cur_step_idx : end_idx] = 
step_output
-                self.cur_step_idx = end_idx
+        if end_idx > self.max_new_tokens:
+            self.output_tensor[:, self.cur_step_idx :] = step_output[
+                :, : self.max_new_tokens - self.cur_step_idx
+            ]
+            self.cur_step_idx = self.max_new_tokens
+        else:
+            self.output_tensor[:, self.cur_step_idx : end_idx] = step_output
+            self.cur_step_idx = end_idx
 
-            if self.is_finished():
-                self.mark_finished()
+        if self.is_finished():
+            self.mark_finished()
 
     def get_final_output(self) -> torch.Tensor:
-        with self._lock:
-            return self.output_tensor[:, : self.cur_step_idx]
+        return self.output_tensor[:, : self.cur_step_idx]
+
+
+class InferenceRequestProxy:
+    """
+    Wrap the raw request for handling multiprocess processing.
+    """
+
+    def __init__(self, req_id: str):
+        self.req_id = req_id
+        self.result = None
+        self._lock = threading.Lock()
+        self._condition = threading.Condition(self._lock)
 
-    def notify_completion(self):
+    def set_result(self, result: Any):
         with self._lock:
+            self.result = result
             self._condition.notify_all()
 
     def wait_for_completion(self) -> Any:
         with self._lock:
-            while self.state != InferenceRequestState.FINISHED:
-                self._condition.wait()
+            self._condition.wait()
+            return self.result
diff --git a/iotdb-core/ainode/ainode/core/inference/inference_request_pool.py b/iotdb-core/ainode/ainode/core/inference/inference_request_pool.py
index b541c877f65..69757e0a7c1 100644
--- a/iotdb-core/ainode/ainode/core/inference/inference_request_pool.py
+++ b/iotdb-core/ainode/ainode/core/inference/inference_request_pool.py
@@ -15,53 +15,55 @@
 # specific language governing permissions and limitations
 # under the License.
 #
-import multiprocessing
-import queue
+
 import random
 import threading
 import time
-from multiprocessing import Process
 
 import numpy as np
 import torch
+import torch.multiprocessing as mp
 from transformers import PretrainedConfig, PreTrainedModel
 
+from ainode.core.config import AINodeDescriptor
 from ainode.core.inference.inference_request import InferenceRequest
 from ainode.core.log import Logger
 
 logger = Logger()
 
 
-class InferenceRequestPool(Process):
+class InferenceRequestPool(mp.Process):
     """
     The request pool to handle inference for a specific model.
     """
 
     FIX_SEED = 2021
     WAITING_INTERVAL_IN_MS = (
-        15  # How often to check for requests in the waiting/running queue
-    )
+        
AINodeDescriptor().get_config().get_ain_inference_batch_interval_in_ms()
+    )  # How often to check for requests in the waiting/running queue
 
     def __init__(
         self,
+        pool_id: int,
         model: PreTrainedModel,
         config: PretrainedConfig,
-        request_queue: multiprocessing.Queue,
+        request_queue: mp.Queue,
+        result_queue: mp.Queue,
         **pool_kwargs,
     ):
         super().__init__()
+        self.pool_id = pool_id
         self.model = model
         self.device = self.model.device
         self.config = config
         self.pool_kwargs = pool_kwargs
 
         # TODO: A scheduler is necessary for better handling following queues
-        self.waiting_queue = request_queue  # Requests that are waiting to be 
processed
-        self.running_queue = (
-            queue.Queue()
-        )  # Requests that are currently being processed, TODO: we might need 
coroutine to accelerate different stages
-        self.finished_queue = queue.Queue()  # Requests that are finished
-        self._stop_event = multiprocessing.Event()
+        self._threads = []
+        self._waiting_queue = request_queue  # Requests that are waiting to be 
processed
+        self._running_queue = mp.Queue()  # Requests that are currently being 
processed
+        self._finished_queue = result_queue  # Requests that are finished
+        self._stop_event = mp.Event()
 
         # Fix inference seed
         random.seed(self.FIX_SEED)
@@ -73,12 +75,16 @@ class InferenceRequestPool(Process):
         pass
 
     def _activate_requests(self):
-        while not self.waiting_queue.empty():
-            request: InferenceRequest = self.waiting_queue.get()
-            # TODO: Check memory size before activating requests
-            request.inputs = request.strategy.preprocess_inputs(request.inputs)
-            request.mark_running()
-            self.running_queue.put(request)
+        if self._waiting_queue.empty():
+            return
+        request: InferenceRequest = self._waiting_queue.get()
+        # TODO: Check memory size before activating requests
+        request.inputs = request.strategy.preprocess_inputs(request.inputs)
+        request.mark_running()
+        logger.debug(
+            
f"[Inference][Device-{self.device}][Pool-{self.pool_id}][ID-{request.req_id}] 
Request is activated with inputs shape {request.inputs.shape}"
+        )
+        self._running_queue.put(request)
 
     def _requests_activate_loop(self):
         while not self._stop_event.is_set():
@@ -86,11 +92,11 @@ class InferenceRequestPool(Process):
             self._activate_requests()
 
     def _step(self):
-        if self.running_queue.empty():
+        if self._running_queue.empty():
             return
         # TODO: We need a batcher to accelerate the concurrent inference
         # TODO: Check memory size before executing requests
-        request: InferenceRequest = self.running_queue.get()
+        request: InferenceRequest = self._running_queue.get()
         output = self.model.generate(
             request.inputs,
             max_new_tokens=request.max_new_tokens,
@@ -100,170 +106,35 @@ class InferenceRequestPool(Process):
         request.write_step_output(output[0].mean(dim=0))
         request.strategy.post_decode()
         if request.is_finished():
-            self.finished_queue.put(request)
+            request.strategy.post_inference()
+            logger.debug(
+                
f"[Inference][Device-{self.device}][Pool-{self.pool_id}][ID-{request.req_id}] 
Request is finished"
+            )
+            self._finished_queue.put(request)
         else:
-            self.waiting_queue.put(request)
+            logger.debug(
+                
f"[Inference][Device-{self.device}][Pool-{self.pool_id}][ID-{request.req_id}] 
Request is not finished, re-queueing"
+            )
+            self._waiting_queue.put(request)
 
     def _requests_execute_loop(self):
         while not self._stop_event.is_set():
             time.sleep(self.WAITING_INTERVAL_IN_MS / 1000)
             self._step()
 
-    def _finish(self):
-        while not self.finished_queue.empty():
-            request: InferenceRequest = self.finished_queue.get()
-            request.strategy.post_inference()
-            request.notify_completion()
-
-    def _requests_finish_loop(self):
-        while not self._stop_event.is_set():
-            time.sleep(self.WAITING_INTERVAL_IN_MS / 1000)
-            self._finish()
-
     def run(self):
-        activate_daemon = threading.Thread(target=self._activate_requests)
-        activate_daemon.daemon = True
+        activate_daemon = threading.Thread(
+            target=self._requests_activate_loop, daemon=True
+        )
+        self._threads.append(activate_daemon)
         activate_daemon.start()
-        execute_daemon = threading.Thread(target=self._requests_execute_loop)
-        execute_daemon.daemon = True
+        execute_daemon = threading.Thread(
+            target=self._requests_execute_loop, daemon=True
+        )
+        self._threads.append(execute_daemon)
         execute_daemon.start()
-        finish_daemon = threading.Thread(target=self._requests_finish_loop)
-        finish_daemon.daemon = True
-        finish_daemon.start()
-        activate_daemon.join()
-        execute_daemon.join()
-        finish_daemon.join()
+        for thread in self._threads:
+            thread.join()
 
     def stop(self):
         self._stop_event.set()
-
-
-def pool_worker(p, done_event):
-    while not done_event.is_set():
-        p._step()
-        time.sleep(0.001)
-
-
-"""
-The following code is used to test the difference in inference speed and the 
difference in result values when using and not using requestPool
-"""
-if __name__ == "__main__":
-    config = TimerConfig()
-    config.ckpt_path = "/data/mahaoke/AINode/ainode/TimerXL/model.safetensors"
-    model = TimerForPrediction(config).eval()
-    if config.ckpt_path is not None and config.ckpt_path != "":
-        state_dict = load_file(config.ckpt_path)
-        model.load_state_dict(state_dict, strict=True)
-
-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-    model = model.to(device)
-
-    BATCH = 1
-    INPUT_LEN = config.input_token_len * 7  # 例如 4 × 96
-    x1 = torch.randn(BATCH, INPUT_LEN, device=device)
-    x2 = torch.randn(BATCH, INPUT_LEN, device=device)
-    x3 = torch.randn(BATCH, INPUT_LEN, device=device)
-
-    pool = InferenceRequestPool(model, config, total_memory_availble=24 * 1024)
-
-    def _always_true(self, req):
-        return True
-
-    InferenceRequestPool.memory_is_availble = _always_true
-
-    def prepare_inputs(model, x, max_new_steps: int = 96, **model_kwargs):
-        model_inputs = model.prepare_inputs_for_generation(x, **model_kwargs)
-        return model_inputs
-
-    def baseline_generate(model, inp: torch.Tensor, max_steps: int, 
**model_kwargs):
-        cur_ids = inp
-        preds = []
-        remain = max_steps
-
-        model_kwargs["attention_mask"] = 
pool.prepare_attention_mask_for_generation(inp)
-
-        batch_size, cur_len = inp.shape
-
-        model_kwargs["unfinished_sequences"] = torch.ones(
-            batch_size, dtype=torch.long, device=inp.device
-        )
-        model_kwargs["cache_position"] = torch.arange(cur_len, 
device=inp.device)
-        true_seq_len = cur_len // config.input_token_len
-        model_kwargs["attention_mask"] = model_kwargs["attention_mask"][
-            :, -true_seq_len:
-        ]
-        model_kwargs["past_key_values"] = None
-        model_kwargs["position_ids"] = None
-        model_kwargs["is_encoder_decoder"] = getattr(
-            config, "is_encoder_decoder", False
-        )
-        model_kwargs["max_output_length"] = max_steps
-
-        while remain > 0:
-            chunk = 96
-            model_inputs = prepare_inputs(model, cur_ids, max_steps, 
**model_kwargs)
-            out = model(**model_inputs)
-            # [B, chunk]
-            tok = out.logits.detach()
-            preds.append(tok.cpu())
-            cur_ids = torch.cat([cur_ids, tok.to(device)], dim=-1)
-
-            horizon_len = 96 // config.input_token_len
-            model_kwargs = pool._update_model_kwargs_for_generation(
-                out, model_kwargs, horizon_len, False
-            )
-
-            remain -= chunk
-        return torch.cat(preds, dim=-1)  # [B, max_steps]
-
-    # warm up
-    for i in range(3):
-        base_res1 = baseline_generate(model, x1, 192)
-
-    torch.cuda.synchronize()
-    t_base_start = time.perf_counter()
-    base_res1 = baseline_generate(model, x1, 192)
-    base_res2 = baseline_generate(model, x2, 192)
-    base_res3 = baseline_generate(model, x3, 192)
-    base_reses = [base_res1, base_res2, base_res3]
-    # print(f'base_reses:{base_reses}')
-    torch.cuda.synchronize()
-    t_base_end = time.perf_counter()
-    base_time = t_base_end - t_base_start
-    print(f"[Baseline]    total time: {base_time*1000:.1f} ms")
-
-    done_event = threading.Event()
-    threading.Thread(target=pool_worker, args=(pool, done_event), 
daemon=True).start()
-
-    torch.cuda.synchronize()
-    t_pool_start = time.perf_counter()
-    pool.add_request(1, x1, max_new_steps=192)
-    # time.sleep(0.010)
-    pool.add_request(2, x2, max_new_steps=192)
-    # time.sleep(0.010)
-    pool.add_request(3, x3, max_new_steps=192)
-    pool_results = []
-    while len(pool_results) < 3:
-        pool_results.append(pool.results_queue.get())
-    torch.cuda.synchronize()
-    t_pool_end = time.perf_counter()
-    pool_time = t_pool_end - t_pool_start
-    print(f"[RequestPool] total time: {pool_time*1000:.1f} ms")
-
-    done_event.set()  # stop pool
-
-    def mae(a, b):
-        return (a - b).abs().mean().item()
-
-    diff1 = mae(
-        pool_results[0][1].to("cpu"), base_reses[pool_results[0][0] - 
1].to("cpu")
-    )
-    diff2 = mae(
-        pool_results[1][1].to("cpu"), base_reses[pool_results[1][0] - 
1].to("cpu")
-    )
-    diff3 = mae(
-        pool_results[2][1].to("cpu"), base_reses[pool_results[2][0] - 
1].to("cpu")
-    )
-
-    print(f"MAE diff (req1/2/3): {diff1:.6f}, {diff2:.6f}, {diff3:.6f}")
-    print(f"Speed-up: {base_time/pool_time:.2f}× faster with RequestPool")
diff --git a/iotdb-core/ainode/ainode/core/inference/utils.py b/iotdb-core/ainode/ainode/core/inference/utils.py
index 04199389205..c2a618d716c 100644
--- a/iotdb-core/ainode/ainode/core/inference/utils.py
+++ b/iotdb-core/ainode/ainode/core/inference/utils.py
@@ -22,7 +22,7 @@ import torch
 from transformers.modeling_outputs import MoeCausalLMOutputWithPast
 
 
-def _generate_req_id(length=10, charset=string.ascii_letters + string.digits):
+def _generate_req_id(length=10, charset=string.ascii_letters + string.digits) 
-> str:
     """
     Generate a random req_id string of specified length.
     The length is 10 by default, with 10^{17} possible combinations.
diff --git a/iotdb-core/ainode/ainode/core/manager/inference_manager.py b/iotdb-core/ainode/ainode/core/manager/inference_manager.py
index a1983adef3f..deab44b4c81 100644
--- a/iotdb-core/ainode/ainode/core/manager/inference_manager.py
+++ b/iotdb-core/ainode/ainode/core/manager/inference_manager.py
@@ -15,21 +15,27 @@
 # specific language governing permissions and limitations
 # under the License.
 #
-import multiprocessing
+import threading
+import time
 from abc import ABC, abstractmethod
 from typing import Dict, List
 
 import pandas as pd
 import torch
+import torch.multiprocessing as mp
 from iotdb.tsfile.utils.tsblock_serde import deserialize
 
+from ainode.core.config import AINodeDescriptor
 from ainode.core.constant import TSStatusCode
 from ainode.core.exception import (
     InferenceModelInternalError,
     InvalidWindowArgumentError,
     runtime_error_extractor,
 )
-from ainode.core.inference.inference_request import InferenceRequest
+from ainode.core.inference.inference_request import (
+    InferenceRequest,
+    InferenceRequestProxy,
+)
 from ainode.core.inference.inference_request_pool import InferenceRequestPool
 from ainode.core.inference.strategy.timer_sundial_strategy import 
TimerSundialStrategy
 from ainode.core.inference.utils import _generate_req_id
@@ -130,44 +136,67 @@ class InferenceManager:
     DEFAULT_DEVICE = "cpu"
     # DEFAULT_DEVICE = torch.device("cuda" if torch.cuda.is_available() else 
"cpu")
     DEFAULT_POOL_SIZE = (
-        1  # TODO: Remove these parameter by sampling model inference 
consumption
+        10  # TODO: Remove these parameter by sampling model inference 
consumption
     )
+    WAITING_INTERVAL_IN_MS = (
+        
AINodeDescriptor().get_config().get_ain_inference_batch_interval_in_ms()
+    )  # How often to check for requests in the result queue
 
     def __init__(self, model_manager: ModelManager):
-        self.model_manager = model_manager
+        self._model_manager = model_manager
+        self._result_queue = mp.Queue()
+        self._result_wrapper_map = {}
+        self._result_wrapper_lock = threading.RLock()
         # structure: {model_id: [(InferenceRequestPool, request_queue), ...]}
-        self.request_pool_map: Dict[
-            str, List[(InferenceRequestPool, multiprocessing.Queue)]
-        ] = {}
-        self.result_queue = multiprocessing.Queue()
+        self._request_pool_map: Dict[str, List[(InferenceRequestPool, 
mp.Queue)]] = {}
+        self._stop_event = mp.Event()
         self._init_inference_request_pool()
+        self._result_handler_thread = threading.Thread(
+            target=self._handle_results, daemon=True
+        )
+        self._result_handler_thread.start()
 
     def _init_inference_request_pool(self):
         """
         Initialize the inference request pool for each model.
         TODO: This is a temporary solution, we need a automatic algorithm to 
adjust the pool size for different models
         """
-        self.request_pool_map[self.ACCELERATE_MODEL_ID] = []
-        for _ in range(self.DEFAULT_POOL_SIZE):
-            sundial_model = self.model_manager.load_model(
+        self._request_pool_map[self.ACCELERATE_MODEL_ID] = []
+        for idx in range(self.DEFAULT_POOL_SIZE):
+            sundial_model = self._model_manager.load_model(
                 self.ACCELERATE_MODEL_ID, {}
             ).to(self.DEFAULT_DEVICE)
             sundial_config = SundialConfig()
-            request_queue = multiprocessing.Queue()
+            request_queue = mp.Queue()
             request_pool = InferenceRequestPool(
-                sundial_model, sundial_config, request_queue
+                pool_id=idx,
+                model=sundial_model,
+                config=sundial_config,
+                request_queue=request_queue,
+                result_queue=self._result_queue,
             )
             request_pool.start()
-            self.request_pool_map[self.ACCELERATE_MODEL_ID].append(
+            self._request_pool_map[self.ACCELERATE_MODEL_ID].append(
                 (request_pool, request_queue)
             )
 
+    def _handle_results(self):
+        while not self._stop_event.is_set():
+            if self._result_queue.empty():
+                time.sleep(self.WAITING_INTERVAL_IN_MS / 1000)
+                continue
+            infer_req: InferenceRequest = self._result_queue.get()
+            with self._result_wrapper_lock:
+                self._result_wrapper_map[infer_req.req_id].set_result(
+                    infer_req.get_final_output()
+                )
+
     def _get_strategy(self, model_id, model):
         if isinstance(model, TimerForPrediction):
             return TimerXLStrategy(model)
         if isinstance(model, SundialForPrediction):
             return SundialStrategy(model)
-        if 
self.model_manager.model_storage._is_built_in_or_fine_tuned(model_id):
+        if 
self._model_manager.model_storage._is_built_in_or_fine_tuned(model_id):
             return BuiltInStrategy(model)
         return RegisteredStrategy(model)
 
@@ -181,13 +210,14 @@ class InferenceManager:
         single_output: bool,
     ):
         model_id = req.modelId
-        logger.info(f"Start processing for {model_id}")
         try:
             raw = data_getter(req)
             full_data = deserializer(raw)
             inference_attrs = extract_attrs(req)
 
             if model_id == self.ACCELERATE_MODEL_ID and self.DEFAULT_POOL_SIZE 
> 0:
+                # TODO: Logic in this branch shall handle all LTSM inferences
+                # TODO: TSBlock -> Tensor codes should be unified
                 data = full_data[1][0]
                 if data.dtype.byteorder not in ("=", "|"):
                     data = data.byteswap().newbyteorder()
@@ -198,14 +228,21 @@ class InferenceManager:
                     strategy=TimerSundialStrategy(SundialConfig()),
                     max_new_tokens=96,
                 )
-                pool_idx = hash(infer_req.id) % 
len(self.request_pool_map[model_id])
-                self.request_pool_map[model_id][pool_idx][1].put(infer_req)
-                infer_req.wait_for_completion()
-                outputs = 
convert_to_binary(pd.DataFrame(infer_req.get_final_output()))
+                infer_proxy = InferenceRequestProxy(infer_req.req_id)
+                with self._result_wrapper_lock:
+                    self._result_wrapper_map[infer_req.req_id] = infer_proxy
+                pool_idx = hash(infer_req.req_id) % len(
+                    self._request_pool_map[model_id]
+                )
+                self._request_pool_map[model_id][pool_idx][1].put(infer_req)
+                outputs = infer_proxy.wait_for_completion()
+                outputs = convert_to_binary(pd.DataFrame(outputs[0]))
+                with self._result_wrapper_lock:
+                    del self._result_wrapper_map[infer_req.req_id]
             else:
                 # load model
                 accel = str(inference_attrs.get("acceleration", "")).lower() 
== "true"
-                model = self.model_manager.load_model(model_id, 
inference_attrs, accel)
+                model = self._model_manager.load_model(model_id, 
inference_attrs, accel)
 
                 # inference by strategy
                 strategy = self._get_strategy(model_id, model)
@@ -252,10 +289,13 @@ class InferenceManager:
         )
 
     def shutdown(self):
-        for model_id, pools in self.request_pool_map.items():
+        self._stop_event.set()
+        for model_id, pools in self._request_pool_map.items():
             for requestPool, requestQueue in pools:
                 requestPool.stop()
                 while not requestQueue.empty():
                     requestQueue.get_nowait()
                 requestQueue.close()
-                requestQueue.join_thread()
+        while not self._result_queue.empty():
+            self._result_queue.get_nowait()
+        self._result_queue.close()

Reply via email to