This is an automated email from the ASF dual-hosted git repository.
yongzao pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/iotdb.git
The following commit(s) were added to refs/heads/master by this push:
new 393aba20740 [AINode] Support concurrent inference for Timer-Sundial (#15897)
393aba20740 is described below
commit 393aba207407c5c4849faf67bb7c0781c7584a87
Author: Yongzao <[email protected]>
AuthorDate: Wed Jul 16 10:17:33 2025 +0800
[AINode] Support concurrent inference for Timer-Sundial (#15897)
---
iotdb-core/ainode/ainode/core/config.py | 17 +++
iotdb-core/ainode/ainode/core/constant.py | 5 +
.../ainode/ainode/core/inference/__init__.py | 17 +++
.../ainode/core/inference/inference_request.py | 121 ++++++++++++++++++
.../core/inference/inference_request_pool.py | 140 +++++++++++++++++++++
.../ainode/core/inference/strategy/__init__.py | 17 +++
.../strategy/abstract_inference_pipeline.py | 60 +++++++++
.../strategy/timer_sundial_inference_pipeline.py | 51 ++++++++
iotdb-core/ainode/ainode/core/inference/utils.py | 80 ++++++++++++
.../ainode/core/manager/inference_manager.py | 132 ++++++++++++++++---
.../core/model/sundial/configuration_sundial.py | 2 -
.../ainode/core/model/timerxl/modeling_timer.py | 3 -
12 files changed, 624 insertions(+), 21 deletions(-)
diff --git a/iotdb-core/ainode/ainode/core/config.py b/iotdb-core/ainode/ainode/core/config.py
index b4694cb9c3b..6f0336ad6ae 100644
--- a/iotdb-core/ainode/ainode/core/config.py
+++ b/iotdb-core/ainode/ainode/core/config.py
@@ -30,6 +30,7 @@ from ainode.core.constant import (
AINODE_CONF_FILE_NAME,
AINODE_CONF_GIT_FILE_NAME,
AINODE_CONF_POM_FILE_NAME,
+ AINODE_INFERENCE_BATCH_INTERVAL_IN_MS,
AINODE_INFERENCE_RPC_ADDRESS,
AINODE_INFERENCE_RPC_PORT,
AINODE_LOG_DIR,
@@ -55,6 +56,9 @@ class AINodeConfig(object):
# Used for connection of DataNode/ConfigNode clients
self._ain_inference_rpc_address: str = AINODE_INFERENCE_RPC_ADDRESS
self._ain_inference_rpc_port: int = AINODE_INFERENCE_RPC_PORT
+ self._ain_inference_batch_interval_in_ms: int = (
+ AINODE_INFERENCE_BATCH_INTERVAL_IN_MS
+ )
# log directory
self._ain_logs_dir: str = AINODE_LOG_DIR
@@ -132,6 +136,14 @@ class AINodeConfig(object):
def set_ain_inference_rpc_port(self, ain_inference_rpc_port: int) -> None:
self._ain_inference_rpc_port = ain_inference_rpc_port
+ def get_ain_inference_batch_interval_in_ms(self) -> int:
+ return self._ain_inference_batch_interval_in_ms
+
+ def set_ain_inference_batch_interval_in_ms(
+ self, ain_inference_batch_interval_in_ms: int
+ ) -> None:
+ self._ain_inference_batch_interval_in_ms = ain_inference_batch_interval_in_ms
+
def get_ain_logs_dir(self) -> str:
return self._ain_logs_dir
@@ -273,6 +285,11 @@ class AINodeDescriptor(object):
int(file_configs["ain_inference_rpc_port"])
)
+ if "ain_inference_batch_interval_in_ms" in config_keys:
+ self._config.set_ain_inference_batch_interval_in_ms(
+ int(file_configs["ain_inference_batch_interval_in_ms"])
+ )
+
if "ain_models_dir" in config_keys:
self._config.set_ain_models_dir(file_configs["ain_models_dir"])
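The new interval is parsed from the same properties file as the existing RPC options (AINODE_CONF_FILE_NAME, i.e. iotdb-ainode.properties, per constant.py below), so presumably it can be overridden there; a minimal sketch, assuming the default of 15 ms:

    # Scheduling interval for the concurrent inference request pools
    ain_inference_batch_interval_in_ms=15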
diff --git a/iotdb-core/ainode/ainode/core/constant.py b/iotdb-core/ainode/ainode/core/constant.py
index c307dbafe63..1858d56ad8c 100644
--- a/iotdb-core/ainode/ainode/core/constant.py
+++ b/iotdb-core/ainode/ainode/core/constant.py
@@ -29,15 +29,19 @@ AINODE_CONF_FILE_NAME = "iotdb-ainode.properties"
AINODE_CONF_GIT_FILE_NAME = "git.properties"
AINODE_CONF_POM_FILE_NAME = "pom.properties"
AINODE_SYSTEM_FILE_NAME = "system.properties"
+
# inference_rpc_address
AINODE_INFERENCE_RPC_ADDRESS = "127.0.0.1"
AINODE_INFERENCE_RPC_PORT = 10810
+AINODE_INFERENCE_BATCH_INTERVAL_IN_MS = 15
+
# AINode folder structure
AINODE_MODELS_DIR = "data/ainode/models"
AINODE_BUILTIN_MODELS_DIR = "data/ainode/models/weights" # For built-in models, we only need to store their weights and config.
AINODE_SYSTEM_DIR = "data/ainode/system"
AINODE_LOG_DIR = "logs/ainode"
AINODE_THRIFT_COMPRESSION_ENABLED = False
+
# use for node management
AINODE_CLUSTER_NAME = "defaultCluster"
AINODE_VERSION_INFO = "UNKNOWN"
@@ -45,6 +49,7 @@ AINODE_BUILD_INFO = "UNKNOWN"
AINODE_ROOT_DIR = os.path.dirname(
os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
)
+
# connect IoTDB cluster
AINODE_CLUSTER_INGRESS_ADDRESS = "127.0.0.1"
AINODE_CLUSTER_INGRESS_PORT = 6667
diff --git a/iotdb-core/ainode/ainode/core/inference/__init__.py b/iotdb-core/ainode/ainode/core/inference/__init__.py
new file mode 100644
index 00000000000..2a1e720805f
--- /dev/null
+++ b/iotdb-core/ainode/ainode/core/inference/__init__.py
@@ -0,0 +1,17 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
diff --git a/iotdb-core/ainode/ainode/core/inference/inference_request.py b/iotdb-core/ainode/ainode/core/inference/inference_request.py
new file mode 100644
index 00000000000..4cf8e4992dc
--- /dev/null
+++ b/iotdb-core/ainode/ainode/core/inference/inference_request.py
@@ -0,0 +1,121 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+import threading
+from typing import Any
+
+import torch
+
+from ainode.core.inference.strategy.abstract_inference_pipeline import (
+ AbstractInferencePipeline,
+)
+from ainode.core.log import Logger
+
+logger = Logger()
+
+
+class InferenceRequestState:
+ WAITING = "waiting"
+ RUNNING = "running"
+ FINISHED = "finished"
+
+
+class InferenceRequest:
+ def __init__(
+ self,
+ req_id: str,
+ inputs: torch.Tensor,
+ inference_pipeline: AbstractInferencePipeline,
+ max_new_tokens: int = 96,
+ **infer_kwargs,
+ ):
+ if inputs.ndim == 1:
+ inputs = inputs.unsqueeze(0)
+
+ self.req_id = req_id
+ self.inputs = inputs
+ self.infer_kwargs = infer_kwargs
+ self.inference_pipeline = inference_pipeline
+ self.max_new_tokens = (
+ max_new_tokens # Number of time series data points to generate
+ )
+
+ self.batch_size = inputs.size(0)
+ self.state = InferenceRequestState.WAITING
+ self.cur_step_idx = 0 # Current write position in the output buffer
+
+ # Preallocate output buffer [batch_size, max_new_tokens]
+ device = inputs.device
+ self.output_tensor = torch.zeros(
+ self.batch_size, max_new_tokens, device=device
+ ) # shape: [batch_size, max_new_tokens]
+
+ def mark_running(self):
+ self.state = InferenceRequestState.RUNNING
+
+ def mark_finished(self):
+ self.state = InferenceRequestState.FINISHED
+
+ def is_finished(self) -> bool:
+ return (
+ self.state == InferenceRequestState.FINISHED
+ or self.cur_step_idx >= self.max_new_tokens
+ )
+
+ def write_step_output(self, step_output: torch.Tensor):
+ if step_output.ndim == 1:
+ step_output = step_output.unsqueeze(0)
+
+ batch_size, step_size = step_output.shape
+ end_idx = self.cur_step_idx + step_size
+
+ if end_idx > self.max_new_tokens:
+ self.output_tensor[:, self.cur_step_idx :] = step_output[
+ :, : self.max_new_tokens - self.cur_step_idx
+ ]
+ self.cur_step_idx = self.max_new_tokens
+ else:
+ self.output_tensor[:, self.cur_step_idx : end_idx] = step_output
+ self.cur_step_idx = end_idx
+
+ if self.is_finished():
+ self.mark_finished()
+
+ def get_final_output(self) -> torch.Tensor:
+ return self.output_tensor[:, : self.cur_step_idx]
+
+
+class InferenceRequestProxy:
+ """
+ Wraps the raw request so its result can be handed back across processes.
+ """
+
+ def __init__(self, req_id: str):
+ self.req_id = req_id
+ self.result = None
+ self._lock = threading.Lock()
+ self._condition = threading.Condition(self._lock)
+
+ def set_result(self, result: Any):
+ with self._lock:
+ self.result = result
+ self._condition.notify_all()
+
+ def wait_for_completion(self) -> Any:
+ with self._lock:
+ # wait_for() also returns promptly if the result was set before waiting
+ self._condition.wait_for(lambda: self.result is not None)
+ return self.result
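For orientation, a minimal sketch of the output-buffer semantics above (illustrative shapes and values, not part of the commit; the pipeline is the Sundial one introduced below):

    import torch
    from ainode.core.inference.inference_request import InferenceRequest
    from ainode.core.inference.strategy.timer_sundial_inference_pipeline import (
        TimerSundialInferencePipeline,
    )
    from ainode.core.model.sundial.configuration_sundial import SundialConfig

    # A 1-D input is promoted to [1, seq_len]; the output buffer is
    # preallocated with zeros as [batch_size, max_new_tokens].
    req = InferenceRequest(
        req_id="demo",
        inputs=torch.randn(480),
        inference_pipeline=TimerSundialInferencePipeline(SundialConfig()),
        max_new_tokens=96,
    )
    req.write_step_output(torch.randn(1, 48))  # cur_step_idx -> 48
    req.write_step_output(torch.randn(1, 64))  # overflow clipped at 96; request finishes
    assert req.is_finished()
    print(req.get_final_output().shape)  # torch.Size([1, 96])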
diff --git a/iotdb-core/ainode/ainode/core/inference/inference_request_pool.py b/iotdb-core/ainode/ainode/core/inference/inference_request_pool.py
new file mode 100644
index 00000000000..8d6d4060ab2
--- /dev/null
+++ b/iotdb-core/ainode/ainode/core/inference/inference_request_pool.py
@@ -0,0 +1,140 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+import random
+import threading
+import time
+
+import numpy as np
+import torch
+import torch.multiprocessing as mp
+from transformers import PretrainedConfig, PreTrainedModel
+
+from ainode.core.config import AINodeDescriptor
+from ainode.core.inference.inference_request import InferenceRequest
+from ainode.core.log import Logger
+
+logger = Logger()
+
+
+class InferenceRequestPool(mp.Process):
+ """
+ The request pool to handle inference for a specific model.
+ """
+
+ FIX_SEED = 2021
+ WAITING_INTERVAL_IN_MS = (
+ AINodeDescriptor().get_config().get_ain_inference_batch_interval_in_ms()
+ ) # How often to check for requests in the waiting/running queue
+
+ def __init__(
+ self,
+ pool_id: int,
+ model: PreTrainedModel,
+ config: PretrainedConfig,
+ request_queue: mp.Queue,
+ result_queue: mp.Queue,
+ **pool_kwargs,
+ ):
+ super().__init__()
+ self.pool_id = pool_id
+ self.model = model
+ self.device = self.model.device
+ self.config = config
+ self.pool_kwargs = pool_kwargs
+
+ # TODO: A scheduler is necessary for better handling following queues
+ self._threads = []
+ self._waiting_queue = request_queue # Requests that are waiting to be processed
+ self._running_queue = mp.Queue() # Requests that are currently being processed
+ self._finished_queue = result_queue # Requests that are finished
+ self._stop_event = mp.Event()
+
+ # Fix inference seed
+ random.seed(self.FIX_SEED)
+ torch.manual_seed(self.FIX_SEED)
+ np.random.seed(self.FIX_SEED)
+
+ def memory_is_available(self, request):
+ # TODO: needs testing with several rounds of dummy data
+ pass
+
+ def _activate_requests(self):
+ if self._waiting_queue.empty():
+ return
+ request: InferenceRequest = self._waiting_queue.get()
+ # TODO: Check memory size before activating requests
+ request.inputs = request.inference_pipeline.preprocess_inputs(request.inputs)
+ request.mark_running()
+ logger.debug(
+ f"[Inference][Device-{self.device}][Pool-{self.pool_id}][ID-{request.req_id}] Request is activated with inputs shape {request.inputs.shape}"
+ )
+ self._running_queue.put(request)
+
+ def _requests_activate_loop(self):
+ while not self._stop_event.is_set():
+ time.sleep(self.WAITING_INTERVAL_IN_MS / 1000)
+ self._activate_requests()
+
+ def _step(self):
+ if self._running_queue.empty():
+ return
+ # TODO: We need a batcher to accelerate the concurrent inference
+ # TODO: Check memory size before executing requests
+ request: InferenceRequest = self._running_queue.get()
+ output = self.model.generate(
+ request.inputs,
+ max_new_tokens=request.max_new_tokens,
+ num_samples=10,
+ revin=True,
+ )
+ request.write_step_output(output[0].mean(dim=0))
+ request.inference_pipeline.post_decode()
+ if request.is_finished():
+ request.inference_pipeline.post_inference()
+ logger.debug(
+ f"[Inference][Device-{self.device}][Pool-{self.pool_id}][ID-{request.req_id}] Request is finished"
+ )
+ self._finished_queue.put(request)
+ else:
+ logger.debug(
+ f"[Inference][Device-{self.device}][Pool-{self.pool_id}][ID-{request.req_id}] Request is not finished, re-queueing"
+ )
+ self._waiting_queue.put(request)
+
+ def _requests_execute_loop(self):
+ while not self._stop_event.is_set():
+ time.sleep(self.WAITING_INTERVAL_IN_MS / 1000)
+ self._step()
+
+ def run(self):
+ activate_daemon = threading.Thread(
+ target=self._requests_activate_loop, daemon=True
+ )
+ self._threads.append(activate_daemon)
+ activate_daemon.start()
+ execute_daemon = threading.Thread(
+ target=self._requests_execute_loop, daemon=True
+ )
+ self._threads.append(execute_daemon)
+ execute_daemon.start()
+ for thread in self._threads:
+ thread.join()
+
+ def stop(self):
+ self._stop_event.set()
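A sketch of how a pool is wired up, mirroring what InferenceManager does below; `sundial_model` and `infer_req` are assumptions standing in for a loaded SundialForPrediction and an InferenceRequest built as in the previous file:

    import torch.multiprocessing as mp
    from ainode.core.inference.inference_request_pool import InferenceRequestPool
    from ainode.core.model.sundial.configuration_sundial import SundialConfig

    request_queue, result_queue = mp.Queue(), mp.Queue()
    pool = InferenceRequestPool(
        pool_id=0,
        model=sundial_model,  # assumed loaded elsewhere (e.g. via ModelManager)
        config=SundialConfig(),
        request_queue=request_queue,
        result_queue=result_queue,
    )
    pool.start()  # run() spawns the activate/execute daemon threads
    request_queue.put(infer_req)  # waiting -> activated -> stepped until finished
    finished = result_queue.get()  # the same request with its output buffer filled
    pool.stop()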
diff --git a/iotdb-core/ainode/ainode/core/inference/strategy/__init__.py b/iotdb-core/ainode/ainode/core/inference/strategy/__init__.py
new file mode 100644
index 00000000000..2a1e720805f
--- /dev/null
+++ b/iotdb-core/ainode/ainode/core/inference/strategy/__init__.py
@@ -0,0 +1,17 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
diff --git a/iotdb-core/ainode/ainode/core/inference/strategy/abstract_inference_pipeline.py b/iotdb-core/ainode/ainode/core/inference/strategy/abstract_inference_pipeline.py
new file mode 100644
index 00000000000..2300169a6ee
--- /dev/null
+++ b/iotdb-core/ainode/ainode/core/inference/strategy/abstract_inference_pipeline.py
@@ -0,0 +1,60 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+from abc import ABC, abstractmethod
+
+import torch
+
+
+class AbstractInferencePipeline(ABC):
+ """
+ Abstract strategy class that assists model inference.
+ It defines the processing interface implemented for each specific model.
+ """
+
+ def __init__(self, model_config, **infer_kwargs):
+ self.model_config = model_config
+ self.infer_kwargs = infer_kwargs
+
+ @abstractmethod
+ def preprocess_inputs(self, inputs: torch.Tensor):
+ """
+ Preprocess the inputs before inference, including shape validation and value transformation.
+
+ Args:
+ inputs (torch.Tensor): The input tensor to be preprocessed.
+
+ Returns:
+ torch.Tensor: The preprocessed input tensor.
+ """
+ # TODO: Integrate with the data processing pipeline operators
+ pass
+
+ @abstractmethod
+ def post_decode(self):
+ """
+ Post-process the outputs after each decode step.
+ """
+ pass
+
+ @abstractmethod
+ def post_inference(self):
+ """
+ Post-process the outputs after the entire inference task.
+ """
+ pass
diff --git a/iotdb-core/ainode/ainode/core/inference/strategy/timer_sundial_inference_pipeline.py b/iotdb-core/ainode/ainode/core/inference/strategy/timer_sundial_inference_pipeline.py
new file mode 100644
index 00000000000..ffa76751713
--- /dev/null
+++ b/iotdb-core/ainode/ainode/core/inference/strategy/timer_sundial_inference_pipeline.py
@@ -0,0 +1,51 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+import torch
+
+from ainode.core.exception import InferenceModelInternalError
+from ainode.core.inference.strategy.abstract_inference_pipeline import (
+ AbstractInferencePipeline,
+)
+from ainode.core.model.sundial.configuration_sundial import SundialConfig
+
+
+class TimerSundialInferencePipeline(AbstractInferencePipeline):
+ """
+ Strategy for Timer-Sundial model inference.
+ """
+
+ def __init__(self, model_config: SundialConfig, **infer_kwargs):
+ super().__init__(model_config, infer_kwargs=infer_kwargs)
+
+ def preprocess_inputs(self, inputs: torch.Tensor):
+ super().preprocess_inputs(inputs)
+ if len(inputs.shape) != 2:
+ raise InferenceModelInternalError(
+ f"[Inference] Input shape must be: [batch_size, seq_len], but
receives {inputs.shape}"
+ )
+ # TODO: Disassemble and adapt with Sundial's ts_generation_mixin.py
+ return inputs
+
+ def post_decode(self):
+ # TODO: Disassemble and adapt with Sundial's ts_generation_mixin.py
+ pass
+
+ def post_inference(self):
+ # TODO: Disassemble and adapt with Sundial's ts_generation_mixin.py
+ pass
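The shape contract enforced by preprocess_inputs, illustrated with throwaway tensors (note that InferenceRequest already unsqueezes 1-D inputs before they reach the pipeline):

    import torch
    from ainode.core.inference.strategy.timer_sundial_inference_pipeline import (
        TimerSundialInferencePipeline,
    )
    from ainode.core.model.sundial.configuration_sundial import SundialConfig

    pipeline = TimerSundialInferencePipeline(SundialConfig())
    pipeline.preprocess_inputs(torch.randn(4, 480))  # [batch_size, seq_len]: accepted
    pipeline.preprocess_inputs(torch.randn(480))  # 1-D: raises InferenceModelInternalError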
diff --git a/iotdb-core/ainode/ainode/core/inference/utils.py b/iotdb-core/ainode/ainode/core/inference/utils.py
new file mode 100644
index 00000000000..c2a618d716c
--- /dev/null
+++ b/iotdb-core/ainode/ainode/core/inference/utils.py
@@ -0,0 +1,80 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+import secrets
+import string
+
+import torch
+from transformers.modeling_outputs import MoeCausalLMOutputWithPast
+
+
+def _generate_req_id(length=10, charset=string.ascii_letters + string.digits) -> str:
+ """
+ Generate a random req_id string of specified length.
+ The default length of 10 over 62 characters yields 62^10 (about 8 * 10^17) combinations.
+ """
+ return "".join(secrets.choice(charset) for _ in range(length))
+
+
+def _slice_tensor(t, s, e):
+ return None if t is None else t[s:e]
+
+
+def _slice_tuple_of_tensors(tup, s, e):
+ """
+ hidden_states / attentions: Tuple[layer0, layer1, ...];
+ each layer may be a Tensor or None.
+ """
+ if tup is None:
+ return None
+ sliced = []
+ for x in tup:
+ sliced.append(_slice_tensor(x, s, e) if torch.is_tensor(x) else x)
+ return tuple(sliced)
+
+
+def _slice_pkv(pkv, s, e):
+ if pkv is None:
+ return None
+ out = []
+ for layer in pkv: # layer: Tuple[key, value, ...]
+ out.append(tuple(x[s:e] for x in layer))
+ return out
+
+
+def split_moe_output(batch_out: MoeCausalLMOutputWithPast, split_sizes):
+ """
+ Split a batched MoeCausalLMOutputWithPast into len(split_sizes) outputs,
+ where split_sizes[i] is the i-th request's batch size.
+ """
+ outs = []
+ start = 0
+ for bsz in split_sizes:
+ end = start + bsz
+ outs.append(
+ MoeCausalLMOutputWithPast(
+ loss=_slice_tensor(batch_out.loss, start, end),
+ logits=batch_out.logits[start:end],
+ past_key_values=_slice_pkv(batch_out.past_key_values, start, end),
+ hidden_states=_slice_tuple_of_tensors(
+ batch_out.hidden_states, start, end
+ ),
+ attentions=_slice_tuple_of_tensors(batch_out.attentions, start, end),
+ )
+ )
+ start = end
+ return outs
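A worked example of split_moe_output, assuming two requests with batch sizes 2 and 3 were batched into one forward pass (only logits populated here; the other fields slice the same way or pass through as None):

    import torch
    from transformers.modeling_outputs import MoeCausalLMOutputWithPast
    from ainode.core.inference.utils import split_moe_output

    batched = MoeCausalLMOutputWithPast(logits=torch.randn(5, 96))
    per_request = split_moe_output(batched, split_sizes=[2, 3])
    print(per_request[0].logits.shape)  # torch.Size([2, 96])
    print(per_request[1].logits.shape)  # torch.Size([3, 96])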
diff --git a/iotdb-core/ainode/ainode/core/manager/inference_manager.py b/iotdb-core/ainode/ainode/core/manager/inference_manager.py
index 37e43898c21..d62c61279a2 100644
--- a/iotdb-core/ainode/ainode/core/manager/inference_manager.py
+++ b/iotdb-core/ainode/ainode/core/manager/inference_manager.py
@@ -15,22 +15,35 @@
# specific language governing permissions and limitations
# under the License.
#
-import random
+import threading
+import time
from abc import ABC, abstractmethod
+from typing import Dict, List, Tuple
-import numpy as np
import pandas as pd
import torch
+import torch.multiprocessing as mp
from iotdb.tsfile.utils.tsblock_serde import deserialize
+from ainode.core.config import AINodeDescriptor
from ainode.core.constant import TSStatusCode
from ainode.core.exception import (
InferenceModelInternalError,
InvalidWindowArgumentError,
runtime_error_extractor,
)
+from ainode.core.inference.inference_request import (
+ InferenceRequest,
+ InferenceRequestProxy,
+)
+from ainode.core.inference.inference_request_pool import InferenceRequestPool
+from ainode.core.inference.strategy.timer_sundial_inference_pipeline import (
+ TimerSundialInferencePipeline,
+)
+from ainode.core.inference.utils import _generate_req_id
from ainode.core.log import Logger
from ainode.core.manager.model_manager import ModelManager
+from ainode.core.model.sundial.configuration_sundial import SundialConfig
from ainode.core.model.sundial.modeling_sundial import SundialForPrediction
from ainode.core.model.timerxl.modeling_timer import TimerForPrediction
from ainode.core.rpc.status import get_status
@@ -43,7 +56,6 @@ from ainode.thrift.ainode.ttypes import (
)
logger = Logger()
-FIX_SEED = 2021
class InferenceStrategy(ABC):
@@ -122,15 +134,71 @@ class RegisteredStrategy(InferenceStrategy):
class InferenceManager:
+ ACCELERATE_MODEL_ID = "sundial"
+ DEFAULT_DEVICE = "cpu"
+ # DEFAULT_DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ DEFAULT_POOL_SIZE = (
+ 0 # TODO: Remove this parameter by sampling model inference consumption
+ )
+ )
+ WAITING_INTERVAL_IN_MS = (
+ AINodeDescriptor().get_config().get_ain_inference_batch_interval_in_ms()
+ ) # How often to check for requests in the result queue
+
def __init__(self, model_manager: ModelManager):
- self.model_manager = model_manager
+ self._model_manager = model_manager
+ self._result_queue = mp.Queue()
+ self._result_wrapper_map = {}
+ self._result_wrapper_lock = threading.RLock()
+ # structure: {model_id: [(InferenceRequestPool, request_queue), ...]}
+ self._request_pool_map: Dict[str, List[Tuple[InferenceRequestPool, mp.Queue]]] = {}
+ self._stop_event = mp.Event()
+ self._init_inference_request_pool()
+ self._result_handler_thread = threading.Thread(
+ target=self._handle_results, daemon=True
+ )
+ self._result_handler_thread.start()
+
+ def _init_inference_request_pool(self):
+ """
+ Initialize the inference request pool for each model.
+ TODO: This is a temporary solution; we need an automatic algorithm to adjust the pool size for different models
+ """
+ self._request_pool_map[self.ACCELERATE_MODEL_ID] = []
+ for idx in range(self.DEFAULT_POOL_SIZE):
+ sundial_model = self._model_manager.load_model(
+ self.ACCELERATE_MODEL_ID, {}
+ ).to(self.DEFAULT_DEVICE)
+ sundial_config = SundialConfig()
+ request_queue = mp.Queue()
+ request_pool = InferenceRequestPool(
+ pool_id=idx,
+ model=sundial_model,
+ config=sundial_config,
+ request_queue=request_queue,
+ result_queue=self._result_queue,
+ )
+ request_pool.start()
+ self._request_pool_map[self.ACCELERATE_MODEL_ID].append(
+ (request_pool, request_queue)
+ )
+
+ def _handle_results(self):
+ while not self._stop_event.is_set():
+ if self._result_queue.empty():
+ time.sleep(self.WAITING_INTERVAL_IN_MS / 1000)
+ continue
+ infer_req: InferenceRequest = self._result_queue.get()
+ with self._result_wrapper_lock:
+ self._result_wrapper_map[infer_req.req_id].set_result(
+ infer_req.get_final_output()
+ )
def _get_strategy(self, model_id, model):
if isinstance(model, TimerForPrediction):
return TimerXLStrategy(model)
if isinstance(model, SundialForPrediction):
return SundialStrategy(model)
- if self.model_manager.model_storage._is_built_in_or_fine_tuned(model_id):
+ if self._model_manager.model_storage._is_built_in_or_fine_tuned(model_id):
return BuiltInStrategy(model)
return RegisteredStrategy(model)
@@ -144,22 +212,42 @@ class InferenceManager:
single_output: bool,
):
model_id = req.modelId
- logger.info(f"Start processing for {model_id}")
- random.seed(FIX_SEED)
- torch.manual_seed(FIX_SEED)
- np.random.seed(FIX_SEED)
try:
raw = data_getter(req)
full_data = deserializer(raw)
inference_attrs = extract_attrs(req)
- # load model
- accel = str(inference_attrs.get("acceleration", "")).lower() ==
"true"
- model = self.model_manager.load_model(model_id, inference_attrs,
accel)
-
- # inference by strategy
- strategy = self._get_strategy(model_id, model)
- outputs = strategy.infer(full_data, **inference_attrs)
+ if model_id == self.ACCELERATE_MODEL_ID and self.DEFAULT_POOL_SIZE > 0:
+ # TODO: Logic in this branch shall handle all LTSM inferences
+ # TODO: TSBlock -> Tensor codes should be unified
+ data = full_data[1][0]
+ if data.dtype.byteorder not in ("=", "|"):
+ data = data.byteswap().newbyteorder()
+ inputs = torch.tensor(data).unsqueeze(0).float().to(self.DEFAULT_DEVICE)
+ infer_req = InferenceRequest(
+ req_id=_generate_req_id(),
+ inputs=inputs,
+ inference_pipeline=TimerSundialInferencePipeline(SundialConfig()),
+ max_new_tokens=inference_attrs.get("predict_length", 96),
+ )
+ infer_proxy = InferenceRequestProxy(infer_req.req_id)
+ with self._result_wrapper_lock:
+ self._result_wrapper_map[infer_req.req_id] = infer_proxy
+ pool_idx = hash(infer_req.req_id) % len(
+ self._request_pool_map[model_id]
+ )
+ self._request_pool_map[model_id][pool_idx][1].put(infer_req)
+ outputs = infer_proxy.wait_for_completion()
+ outputs = convert_to_binary(pd.DataFrame(outputs[0]))
+ with self._result_wrapper_lock:
+ del self._result_wrapper_map[infer_req.req_id]
+ else:
+ # load model
+ accel = str(inference_attrs.get("acceleration", "")).lower()
== "true"
+ model = self._model_manager.load_model(model_id,
inference_attrs, accel)
+ # inference by strategy
+ strategy = self._get_strategy(model_id, model)
+ outputs = strategy.infer(full_data, **inference_attrs)
# construct response
status = get_status(TSStatusCode.SUCCESS_STATUS)
@@ -200,3 +288,15 @@ class InferenceManager:
resp_cls=TInferenceResp,
single_output=False,
)
+
+ def shutdown(self):
+ self._stop_event.set()
+ for model_id, pools in self._request_pool_map.items():
+ for requestPool, requestQueue in pools:
+ requestPool.stop()
+ while not requestQueue.empty():
+ requestQueue.get_nowait()
+ requestQueue.close()
+ while not self._result_queue.empty():
+ self._result_queue.get_nowait()
+ self._result_queue.close()
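The proxy handshake that _handle_results relies on can be exercised without any model; a minimal sketch where a timer thread stands in for the result-handler loop:

    import threading
    from ainode.core.inference.inference_request import InferenceRequestProxy

    proxy = InferenceRequestProxy("demo-req")
    # stands in for _handle_results() publishing the finished output
    threading.Timer(0.1, proxy.set_result, args=("forecast",)).start()
    print(proxy.wait_for_completion())  # blocks until set_result fires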
diff --git a/iotdb-core/ainode/ainode/core/model/sundial/configuration_sundial.py b/iotdb-core/ainode/ainode/core/model/sundial/configuration_sundial.py
index c903ce3e9dd..21eefef2933 100644
--- a/iotdb-core/ainode/ainode/core/model/sundial/configuration_sundial.py
+++ b/iotdb-core/ainode/ainode/core/model/sundial/configuration_sundial.py
@@ -42,7 +42,6 @@ class SundialConfig(PretrainedConfig):
flow_loss_depth: int = 3,
num_sampling_steps: int = 50,
diffusion_batch_mul: int = 4,
- ckpt_path: str = None, # weight path
**kwargs,
):
self.input_token_len = input_token_len
@@ -60,7 +59,6 @@ class SundialConfig(PretrainedConfig):
self.flow_loss_depth = flow_loss_depth
self.num_sampling_steps = num_sampling_steps
self.diffusion_batch_mul = diffusion_batch_mul
- self.ckpt_path = ckpt_path
super().__init__(
**kwargs,
diff --git a/iotdb-core/ainode/ainode/core/model/timerxl/modeling_timer.py b/iotdb-core/ainode/ainode/core/model/timerxl/modeling_timer.py
index 42566c0e1c9..4aed1696af7 100644
--- a/iotdb-core/ainode/ainode/core/model/timerxl/modeling_timer.py
+++ b/iotdb-core/ainode/ainode/core/model/timerxl/modeling_timer.py
@@ -16,13 +16,10 @@
# under the License.
#
-import os
from typing import List, Optional, Tuple, Union
import torch
import torch.nn.functional as F
-from huggingface_hub import hf_hub_download
-from safetensors.torch import load_file as load_safetensors
from torch import nn
from transformers import Cache, DynamicCache, PreTrainedModel
from transformers.activations import ACT2FN