This is an automated email from the ASF dual-hosted git repository.
yongzao pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/iotdb.git
The following commit(s) were added to refs/heads/master by this push:
new e271df7c4db [AINode] Decoupling inference manager into request_manager, pool_manager (#16131)
e271df7c4db is described below
commit e271df7c4db2e5a6b60f2d7d394e22a6b825f0a3
Author: Leo <[email protected]>
AuthorDate: Wed Aug 20 07:58:00 2025 +0800
[AINode] Decoupling inference manager into request_manager, pool_manager (#16131)
---
.../ainode/core/inference/inference_request.py | 2 +
.../core/inference/inference_request_pool.py | 7 ++
.../core/inference/inference_request_pool_group.py | 63 ++++++++++
.../ainode/core/inference/pool_controller.py | 129 +++++++++++++++++++
.../ainode/ainode/core/inference/pool_scheduler.py | 125 ++++++++++++++++++
.../ainode/core/inference/request_controller.py | 89 +++++++++++++
.../ainode/core/manager/inference_manager.py | 140 ++-------------------
iotdb-core/ainode/ainode/core/util/decorator.py | 15 +++
8 files changed, 442 insertions(+), 128 deletions(-)
diff --git a/iotdb-core/ainode/ainode/core/inference/inference_request.py b/iotdb-core/ainode/ainode/core/inference/inference_request.py
index 2c45826fd26..40cede3b435 100644
--- a/iotdb-core/ainode/ainode/core/inference/inference_request.py
+++ b/iotdb-core/ainode/ainode/core/inference/inference_request.py
@@ -38,6 +38,7 @@ class InferenceRequest:
def __init__(
self,
req_id: str,
+ model_id: str,
inputs: torch.Tensor,
inference_pipeline: AbstractInferencePipeline,
max_new_tokens: int = 96,
@@ -47,6 +48,7 @@ class InferenceRequest:
inputs = inputs.unsqueeze(0)
self.req_id = req_id
+ self.model_id = model_id
self.inputs = inputs
self.infer_kwargs = infer_kwargs
self.inference_pipeline = inference_pipeline
diff --git a/iotdb-core/ainode/ainode/core/inference/inference_request_pool.py b/iotdb-core/ainode/ainode/core/inference/inference_request_pool.py
index 4bf594860f7..956d0773785 100644
--- a/iotdb-core/ainode/ainode/core/inference/inference_request_pool.py
+++ b/iotdb-core/ainode/ainode/core/inference/inference_request_pool.py
@@ -20,6 +20,7 @@ import gc
import random
import threading
import time
+from enum import Enum
import numpy as np
import torch
@@ -33,6 +34,12 @@ from ainode.core.log import Logger
from ainode.core.manager.model_manager import ModelManager
+class PoolState(Enum):
+ INITIALIZING = "INITIALIZING"
+ RUNNING = "RUNNING"
+ STOPPING = "STOPPING"
+
+
class InferenceRequestPool(mp.Process):
"""
The request pool to handle inference for a specific model.
diff --git a/iotdb-core/ainode/ainode/core/inference/inference_request_pool_group.py b/iotdb-core/ainode/ainode/core/inference/inference_request_pool_group.py
new file mode 100644
index 00000000000..3dc2929e388
--- /dev/null
+++ b/iotdb-core/ainode/ainode/core/inference/inference_request_pool_group.py
@@ -0,0 +1,63 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+from typing import Dict, Tuple
+
+import torch.multiprocessing as mp
+
+from ainode.core.exception import (
+ InferenceModelInternalError,
+)
+from ainode.core.inference.inference_request_pool import InferenceRequestPool
+from ainode.core.log import Logger
+
+logger = Logger()
+
+
+class PoolGroup:
+ """
+ A group of inference request pools for a specific model.
+ """
+
+ def __init__(self, model_id):
+ self.pool_group: Dict[int, Tuple[InferenceRequestPool, mp.Queue]] = {}
+ self.model_id = model_id
+
+ def get_pool_group(self) -> Dict[int, Tuple[InferenceRequestPool,
mp.Queue]]:
+ return self.pool_group
+
+ def add_pool(
+ self, pool_id: int, request_pool: InferenceRequestPool, request_queue:
mp.Queue
+ ):
+ self.pool_group[pool_id] = (request_pool, request_queue)
+
+ def get_pool_ids(self) -> list[int]:
+ return list(self.pool_group.keys())
+
+ def get_request_pool(self, pool_id) -> InferenceRequestPool:
+ if pool_id not in self.pool_group:
+ raise InferenceModelInternalError(
+ f"Pool ID {pool_id} not found for model {self.model_id}"
+ )
+ return self.pool_group[pool_id][0]
+
+ def get_request_queue(self, pool_id) -> mp.Queue:
+ if pool_id not in self.pool_group:
+ raise InferenceModelInternalError(
+ f"Pool ID {pool_id} not found for model {self.model_id}"
+ )
+ return self.pool_group[pool_id][1]
diff --git a/iotdb-core/ainode/ainode/core/inference/pool_controller.py b/iotdb-core/ainode/ainode/core/inference/pool_controller.py
new file mode 100644
index 00000000000..121d3023df9
--- /dev/null
+++ b/iotdb-core/ainode/ainode/core/inference/pool_controller.py
@@ -0,0 +1,129 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+from collections import defaultdict
+from typing import Dict, Optional
+
+import torch
+import torch.multiprocessing as mp
+
+from ainode.core.exception import (
+ InferenceModelInternalError,
+)
+from ainode.core.inference.inference_request import InferenceRequest
+from ainode.core.inference.inference_request_pool import InferenceRequestPool,
PoolState
+from ainode.core.inference.inference_request_pool_group import PoolGroup
+from ainode.core.log import Logger
+
+logger = Logger()
+
+
+class PoolController:
+ """
+ A controller for handling inference request pools.
+ It handles the registration of pools, adding and removing requests,
+ and managing the state of each pool.
+ """
+
+ DEFAULT_DEVICE = torch.device("cpu")
+ # DEFAULT_DEVICE = torch.device("cuda" if torch.cuda.is_available() else
"cpu")
+
+ def __init__(self):
+ # structure: {model_id: {pool_id: PoolState}}
+ self.pool_states: Dict[str, Dict[int, PoolState]] = defaultdict(dict)
+ # structure: {model_id: PoolGroup}
+ self._request_pool_map: Dict[str, PoolGroup] = {}
+
+ def dispatch_request(self, model_id, req: InferenceRequest):
+ pool_idx = self._select_pool_by_hash(model_id, req.req_id)
+ self.add_request(pool_idx, req)
+ logger.debug(
+
f"[Inference][Device-{self.DEFAULT_DEVICE}][Pool-{pool_idx}][ID-{req.req_id}]
Request is queued for inference"
+ )
+
+ def _select_pool_by_hash(self, model_id, req_id) -> int:
+ pool_ids = self.get_pool_ids(model_id)
+ if not pool_ids:
+ raise InferenceModelInternalError(
+ f"No available pools for model {model_id}"
+ )
+ start_idx = hash(req_id) % len(pool_ids)
+ n = len(pool_ids)
+ for i in range(n):
+ pool_id = pool_ids[(start_idx + i) % n]
+ state = self.get_state(model_id, pool_id)
+ if state == PoolState.RUNNING:
+ return pool_id
+ raise InferenceModelInternalError(
+ f"No RUNNING pools available for model {model_id}"
+ )
+
+ def register_pool(self, model_id, pool_id, request_pool, request_queue):
+ self.set_state(model_id, pool_id, PoolState.RUNNING)
+ self.set_request_pool_map(model_id, pool_id, request_pool,
request_queue)
+
+ def add_request(self, pool_id, req):
+ req_q = self.get_request_queue(req.model_id, pool_id)
+ req_q.put(req)
+
+ def remove_request(self, model_id, req_id):
+ pass
+
+ def get_pool_ids(self, model_id) -> list[int]:
+ return self._request_pool_map[model_id].get_pool_ids()
+
+ def has_request_pools(self, model_id) -> bool:
+ return model_id in self._request_pool_map
+
+ def get_request_pool_map(self) -> Dict[str, PoolGroup]:
+ return self._request_pool_map
+
+ def get_request_pools_group(self, model_id) -> Optional[PoolGroup]:
+ return self._request_pool_map.get(model_id, None)
+
+ def get_request_pool(self, model_id, pool_id) -> InferenceRequestPool:
+ return self._request_pool_map[model_id].get_request_pool(pool_id)
+
+ def get_request_queue(self, model_id, pool_id) -> mp.Queue:
+ return self._request_pool_map[model_id].get_request_queue(pool_id)
+
+ def set_request_pool_map(self, model_id, pool_id, request_pool,
request_queue):
+ if model_id not in self._request_pool_map:
+ self._request_pool_map[model_id] = PoolGroup(model_id)
+ self._request_pool_map[model_id].add_pool(pool_id, request_pool,
request_queue)
+
+ def get_state(self, model_id, pool_id) -> PoolState:
+ return self.pool_states[model_id][pool_id]
+
+ def set_state(self, model_id, pool_id, state):
+ self.pool_states[model_id][pool_id] = state
+
+ def get_load(self, model_id, pool_id) -> int:
+ pass
+
+ def shutdown(self):
+ for model_id, pool_group in self._request_pool_map.items():
+ for pool_id in pool_group.get_pool_ids():
+ request_pool = pool_group.get_request_pool(pool_id)
+ request_queue = pool_group.get_request_queue(pool_id)
+ request_pool.stop()
+ while not request_queue.empty():
+ request_queue.get_nowait()
+ request_queue.close()
+ for pool_id in pool_group.get_pool_ids():
+ request_pool = pool_group.get_request_pool(pool_id)
+ request_pool.join(timeout=10)
diff --git a/iotdb-core/ainode/ainode/core/inference/pool_scheduler.py b/iotdb-core/ainode/ainode/core/inference/pool_scheduler.py
new file mode 100644
index 00000000000..53aa329c494
--- /dev/null
+++ b/iotdb-core/ainode/ainode/core/inference/pool_scheduler.py
@@ -0,0 +1,125 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+import threading
+
+import torch
+import torch.multiprocessing as mp
+
+from ainode.core.exception import (
+ InferenceModelInternalError,
+)
+from ainode.core.inference.inference_request_pool import InferenceRequestPool,
PoolState
+from ainode.core.inference.pool_controller import PoolController
+from ainode.core.log import Logger
+from ainode.core.manager.utils import (
+ _estimate_pool_size,
+)
+from ainode.core.model.sundial.configuration_sundial import SundialConfig
+from ainode.core.model.timerxl.configuration_timer import TimerConfig
+from ainode.core.util.decorator import synchronized
+
+logger = Logger()
+
+
+class PoolScheduler:
+ """
+ A Scheduler to init the request pools.
+ It initializes the first pool and starts a background thread to expand
pools
+ as needed based on the model_id.
+ """
+
+ DEFAULT_DEVICE = torch.device("cpu")
+ # DEFAULT_DEVICE = torch.device("cuda" if torch.cuda.is_available() else
"cpu")
+
+ def __init__(self, pool_controller: PoolController, result_queue:
mp.Queue):
+ self._pool_controller = pool_controller
+ self._result_queue = result_queue
+
+ @synchronized(threading.Lock())
+ def first_req_init(self, model_id: str):
+ if not self._pool_controller.has_request_pools(model_id):
+ pool_num = _estimate_pool_size(self.DEFAULT_DEVICE, model_id)
+ if pool_num <= 0:
+ raise InferenceModelInternalError(
+ f"Not enough memory to run model {model_id}."
+ )
+ # initialize the first pool
+ self._first_pool_init(model_id)
+ # start a background thread to expand pools
+ expand_thread = threading.Thread(
+ target=self._expand_pools,
+ args=(model_id, 1, pool_num - 1),
+ daemon=True,
+ )
+ expand_thread.start()
+
+ def _first_pool_init(self, model_id: str):
+ if model_id == "sundial":
+ config = SundialConfig()
+ elif model_id == "timer_xl":
+ config = TimerConfig()
+ first_queue = mp.Queue()
+ ready_event = mp.Event()
+ first_pool = InferenceRequestPool(
+ pool_id=0,
+ model_id=model_id,
+ config=config,
+ request_queue=first_queue,
+ result_queue=self._result_queue,
+ ready_event=ready_event,
+ )
+ first_pool.start()
+ self._pool_controller.set_state(model_id, 0, PoolState.INITIALIZING)
+ if not ready_event.wait(timeout=30):
+ logger.error(
+ f"[Inference][Device-{self.DEFAULT_DEVICE}][Pool-0] First pool
failed to be ready in time"
+ )
+ else:
+ self._pool_controller.register_pool(model_id, 0, first_pool,
first_queue)
+ logger.info(
+ f"[Inference][Device-{self.DEFAULT_DEVICE}][Pool-0]
Initialized inference request pool for model {model_id}"
+ )
+
+ def _expand_pools(self, model_id, start_idx, count):
+ for idx in range(count):
+ queue = mp.Queue()
+ pool_id = start_idx + idx
+ if model_id == "sundial":
+ config = SundialConfig()
+ elif model_id == "timer_xl":
+ config = TimerConfig()
+ pool = InferenceRequestPool(
+ pool_id=pool_id,
+ model_id=model_id,
+ config=config,
+ request_queue=queue,
+ result_queue=self._result_queue,
+ ready_event=mp.Event(),
+ )
+ pool.start()
+ self._pool_controller.set_state(model_id, pool_id,
PoolState.INITIALIZING)
+ if not pool.ready_event.wait(timeout=30):
+ logger.error(
+
f"[Inference][Device-{self.DEFAULT_DEVICE}][Pool-{pool_id}] Pool failed to be
ready in time"
+ )
+ continue
+ else:
+ self._pool_controller.register_pool(model_id, pool_id, pool,
queue)
+ logger.info(
+
f"[Inference][Device-{self.DEFAULT_DEVICE}][Pool-{pool.pool_id}] New inference
request pool started for model {model_id}"
+ )
diff --git a/iotdb-core/ainode/ainode/core/inference/request_controller.py b/iotdb-core/ainode/ainode/core/inference/request_controller.py
new file mode 100644
index 00000000000..bde513a2813
--- /dev/null
+++ b/iotdb-core/ainode/ainode/core/inference/request_controller.py
@@ -0,0 +1,89 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+import threading
+import time
+
+import torch.multiprocessing as mp
+
+from ainode.core.config import AINodeDescriptor
+from ainode.core.inference.inference_request import (
+ InferenceRequest,
+ InferenceRequestProxy,
+)
+from ainode.core.inference.pool_controller import PoolController
+from ainode.core.inference.pool_scheduler import PoolScheduler
+from ainode.core.log import Logger
+
+logger = Logger()
+
+
+class RequestController:
+ """
+ Controls the lifecycle and scheduling of inference requests.
+ """
+
+ WAITING_INTERVAL_IN_MS = (
+
AINodeDescriptor().get_config().get_ain_inference_batch_interval_in_ms()
+ ) # How often to check for requests in the result queue
+
+ def __init__(self):
+ self._result_queue = mp.Queue()
+ self._result_wrapper_map = {}
+ self._result_wrapper_lock = threading.RLock()
+
+ self._stop_event = mp.Event()
+ self._result_handler_thread = threading.Thread(
+ target=self._handle_results, daemon=True
+ )
+ self._result_handler_thread.start()
+ self._pool_controller = PoolController()
+ self._pool_scheduler = PoolScheduler(self._pool_controller,
self._result_queue)
+
+ def _handle_results(self):
+ while not self._stop_event.is_set():
+ if self._result_queue.empty():
+ time.sleep(self.WAITING_INTERVAL_IN_MS / 1000)
+ continue
+ infer_req: InferenceRequest = self._result_queue.get()
+ self._pool_controller.remove_request(infer_req.model_id,
infer_req.req_id)
+ with self._result_wrapper_lock:
+ self._result_wrapper_map[infer_req.req_id].set_result(
+ infer_req.get_final_output()
+ )
+
+ def process_request(self, req):
+ infer_proxy = InferenceRequestProxy(req.req_id)
+ with self._result_wrapper_lock:
+ self._result_wrapper_map[req.req_id] = infer_proxy
+ # lazy initialization for first request
+ model_id = req.model_id
+ if not self._pool_controller.has_request_pools(model_id):
+ self._pool_scheduler.first_req_init(model_id)
+ # dispatch request to the pool
+ self._pool_controller.dispatch_request(model_id, req)
+ outputs = infer_proxy.wait_for_completion()
+ with self._result_wrapper_lock:
+ del self._result_wrapper_map[req.req_id]
+ return outputs
+
+ def shutdown(self):
+ self._stop_event.set()
+ self._pool_controller.shutdown()
+ while not self._result_queue.empty():
+ self._result_queue.get_nowait()
+ self._result_queue.close()
diff --git a/iotdb-core/ainode/ainode/core/manager/inference_manager.py b/iotdb-core/ainode/ainode/core/manager/inference_manager.py
index 19d0cc340b6..a7f10b8b389 100644
--- a/iotdb-core/ainode/ainode/core/manager/inference_manager.py
+++ b/iotdb-core/ainode/ainode/core/manager/inference_manager.py
@@ -15,16 +15,12 @@
# specific language governing permissions and limitations
# under the License.
#
-import gc
-import threading
-import time
+
from abc import ABC, abstractmethod
-from typing import Dict, List
+from typing import Dict
import pandas as pd
-import psutil
import torch
-import torch.multiprocessing as mp
from iotdb.tsfile.utils.tsblock_serde import deserialize
from ainode.core.config import AINodeDescriptor
@@ -37,9 +33,8 @@ from ainode.core.exception import (
)
from ainode.core.inference.inference_request import (
InferenceRequest,
- InferenceRequestProxy,
)
-from ainode.core.inference.inference_request_pool import InferenceRequestPool
+from ainode.core.inference.request_controller import RequestController
from ainode.core.inference.strategy.timer_sundial_inference_pipeline import (
TimerSundialInferencePipeline,
)
@@ -50,7 +45,6 @@ from ainode.core.inference.utils import generate_req_id
from ainode.core.log import Logger
from ainode.core.manager.model_manager import ModelManager
from ainode.core.manager.utils import (
- _estimate_pool_size,
_measure_model_memory,
)
from ainode.core.model.sundial.configuration_sundial import SundialConfig
@@ -146,27 +140,15 @@ class RegisteredStrategy(InferenceStrategy):
class InferenceManager:
ACCELERATE_MODEL_ID = ["sundial", "timer_xl"]
- DEFAULT_DEVICE = "cpu"
+ DEFAULT_DEVICE = torch.device("cpu")
# DEFAULT_DEVICE = torch.device("cuda" if torch.cuda.is_available() else
"cpu")
- WAITING_INTERVAL_IN_MS = (
-
AINodeDescriptor().get_config().get_ain_inference_batch_interval_in_ms()
- ) # How often to check for requests in the result queue
def __init__(self):
self._model_manager = ModelManager()
- self._result_queue = mp.Queue()
- self._result_wrapper_map = {}
- self._result_wrapper_lock = threading.RLock()
- # structure: {model_id: [(InferenceRequestPool, request_queue), ...]}
- self._request_pool_map: Dict[str, List[(InferenceRequestPool,
mp.Queue)]] = {}
- self._stop_event = mp.Event()
- self._result_handler_thread = threading.Thread(
- target=self._handle_results, daemon=True
- )
- self._result_handler_thread.start()
self._model_mem_usage_map: Dict[str, int] = (
{}
) # store model memory usage for each model
+ self._request_controller = RequestController()
# self._preload_model_benchmarks()
def _preload_model_benchmarks(self):
@@ -182,70 +164,6 @@ class InferenceManager:
f"[Inference] Skipped preloading benchmarks for
{self.DEFAULT_DEVICE}, only supports CUDA currently"
)
- def _first_pool_init(self, model_id: str):
- if model_id == "sundial":
- config = SundialConfig()
- elif model_id == "timer_xl":
- config = TimerConfig()
- first_queue = mp.Queue()
- ready_event = mp.Event()
- first_pool = InferenceRequestPool(
- pool_id=0,
- model_id=model_id,
- config=config,
- request_queue=first_queue,
- result_queue=self._result_queue,
- ready_event=ready_event,
- )
- first_pool.start()
- if not ready_event.wait(timeout=30):
- logger.error(
- f"[Inference][Device-{self.DEFAULT_DEVICE}][Pool-0] First pool
failed to be ready in time"
- )
- else:
- self._request_pool_map[model_id] = [(first_pool, first_queue)]
- logger.info(
- f"[Inference][Device-{self.DEFAULT_DEVICE}][Pool-0]
Initialized inference request pool for model {model_id}"
- )
-
- def _expand_pools(self, model_id, start_idx, count):
- for idx in range(count):
- queue = mp.Queue()
- if model_id == "sundial":
- config = SundialConfig()
- elif model_id == "timer_xl":
- config = TimerConfig()
- pool = InferenceRequestPool(
- pool_id=start_idx + idx,
- model_id=model_id,
- config=config,
- request_queue=queue,
- result_queue=self._result_queue,
- ready_event=mp.Event(),
- )
- pool.start()
- if not pool.ready_event.wait(timeout=30):
- logger.error(
-
f"[Inference][Device-{self.DEFAULT_DEVICE}][Pool-{start_idx + idx}] Pool failed
to be ready in time"
- )
- continue
- else:
- self._request_pool_map[model_id].append((pool, queue))
- logger.info(
-
f"[Inference][Device-{self.DEFAULT_DEVICE}][Pool-{pool.pool_id}] New inference
request pool started for model {model_id}"
- )
-
- def _handle_results(self):
- while not self._stop_event.is_set():
- if self._result_queue.empty():
- time.sleep(self.WAITING_INTERVAL_IN_MS / 1000)
- continue
- infer_req: InferenceRequest = self._result_queue.get()
- with self._result_wrapper_lock:
- self._result_wrapper_map[infer_req.req_id].set_result(
- infer_req.get_final_output()
- )
-
def _get_strategy(self, model_id, model):
if isinstance(model, TimerForPrediction):
return TimerXLStrategy(model)
@@ -287,22 +205,6 @@ class InferenceManager:
if model_id in self.ACCELERATE_MODEL_ID and "cuda" in str(
self.DEFAULT_DEVICE
):
- # lazy initialization for first request
- if model_id not in self._request_pool_map:
- pool_num = _estimate_pool_size(self.DEFAULT_DEVICE,
model_id)
- if pool_num <= 0:
- raise InferenceModelInternalError(
- f"Not enough memory to run model {model_id}."
- )
- # initialize the first pool
- self._first_pool_init(model_id)
- # start a background thread to expand pools
- expand_thread = threading.Thread(
- target=self._expand_pools,
- args=(model_id, 1, pool_num - 1),
- daemon=True,
- )
- expand_thread.start()
# TODO: Logic in this branch shall handle all LTSM inferences
# TODO: TSBlock -> Tensor codes should be unified
data = full_data[1][0]
@@ -314,26 +216,19 @@ class InferenceManager:
inference_pipeline =
TimerSundialInferencePipeline(SundialConfig())
elif model_id == "timer_xl":
inference_pipeline =
TimerXLInferencePipeline(TimerConfig())
+ else:
+ raise InferenceModelInternalError(
+ f"Unsupported model_id: {model_id}"
+ )
infer_req = InferenceRequest(
req_id=generate_req_id(),
+ model_id=model_id,
inputs=inputs,
inference_pipeline=inference_pipeline,
max_new_tokens=predict_length,
)
- infer_proxy = InferenceRequestProxy(infer_req.req_id)
- with self._result_wrapper_lock:
- self._result_wrapper_map[infer_req.req_id] = infer_proxy
- pool_idx = hash(infer_req.req_id) % len(
- self._request_pool_map[model_id]
- )
- self._request_pool_map[model_id][pool_idx][1].put(infer_req)
- logger.debug(
-
f"[Inference][Device-{self.DEFAULT_DEVICE}][Pool-{pool_idx}][ID-{infer_req.req_id}]
Request is queued for inference"
- )
- outputs = infer_proxy.wait_for_completion()
+ outputs = self._request_controller.process_request(infer_req)
outputs = convert_to_binary(pd.DataFrame(outputs[0]))
- with self._result_wrapper_lock:
- del self._result_wrapper_map[infer_req.req_id]
else:
# load model
accel = str(inference_attrs.get("acceleration", "")).lower()
== "true"
@@ -383,15 +278,4 @@ class InferenceManager:
)
def shutdown(self):
- self._stop_event.set()
- for model_id, pools in self._request_pool_map.items():
- for requestPool, requestQueue in pools:
- requestPool.stop()
- while not requestQueue.empty():
- requestQueue.get_nowait()
- requestQueue.close()
- for requestPool, _ in pools:
- requestPool.join(timeout=10)
- while not self._result_queue.empty():
- self._result_queue.get_nowait()
- self._result_queue.close()
+ self._request_controller.shutdown()
diff --git a/iotdb-core/ainode/ainode/core/util/decorator.py b/iotdb-core/ainode/ainode/core/util/decorator.py
index 33b9f4835ac..5a84c3d6bb2 100644
--- a/iotdb-core/ainode/ainode/core/util/decorator.py
+++ b/iotdb-core/ainode/ainode/core/util/decorator.py
@@ -15,6 +15,9 @@
# specific language governing permissions and limitations
# under the License.
#
+from functools import wraps
+
+
def singleton(cls):
instances = {}
@@ -24,3 +27,15 @@ def singleton(cls):
return instances[cls]
return get_instance
+
+
+def synchronized(lock):
+ def decorator(func):
+ @wraps(func)
+ def wrapper(*args, **kwargs):
+ with lock:
+ return func(*args, **kwargs)
+
+ return wrapper
+
+ return decorator