This is an automated email from the ASF dual-hosted git repository. yongzao pushed a commit to branch concurrent-sundial in repository https://gitbox.apache.org/repos/asf/iotdb.git
commit 09f62aa050b295b70d817a51fc823c6f764343e8 Author: Yongzao <[email protected]> AuthorDate: Wed Jul 9 17:24:47 2025 +0800 stash changes --- iotdb-core/ainode/ainode/core/constant.py | 6 +- .../ainode/ainode/core/inference/__init__.py | 17 ++ .../ainode/core/inference/inference_request.py | 108 +++++++++ .../core/inference/inference_request_pool.py | 269 +++++++++++++++++++++ .../ainode/core/inference/strategy/__init__.py | 17 ++ .../core/inference/strategy/abstract_strategy.py | 60 +++++ .../inference/strategy/timer_sundial_strategy.py | 49 ++++ iotdb-core/ainode/ainode/core/inference/utils.py | 80 ++++++ .../ainode/core/manager/inference_manager.py | 83 ++++++- .../core/model/sundial/configuration_sundial.py | 2 - .../ainode/core/model/timerxl/modeling_timer.py | 3 - 11 files changed, 674 insertions(+), 20 deletions(-) diff --git a/iotdb-core/ainode/ainode/core/constant.py b/iotdb-core/ainode/ainode/core/constant.py index c307dbafe63..c7b75103d03 100644 --- a/iotdb-core/ainode/ainode/core/constant.py +++ b/iotdb-core/ainode/ainode/core/constant.py @@ -31,7 +31,7 @@ AINODE_CONF_POM_FILE_NAME = "pom.properties" AINODE_SYSTEM_FILE_NAME = "system.properties" # inference_rpc_address AINODE_INFERENCE_RPC_ADDRESS = "127.0.0.1" -AINODE_INFERENCE_RPC_PORT = 10810 +AINODE_INFERENCE_RPC_PORT = 11810 # AINode folder structure AINODE_MODELS_DIR = "data/ainode/models" AINODE_BUILTIN_MODELS_DIR = "data/ainode/models/weights" # For built-in models, we only need to store their weights and config. 
@@ -39,7 +39,7 @@ AINODE_SYSTEM_DIR = "data/ainode/system" AINODE_LOG_DIR = "logs/ainode" AINODE_THRIFT_COMPRESSION_ENABLED = False # use for node management -AINODE_CLUSTER_NAME = "defaultCluster" +AINODE_CLUSTER_NAME = "yongzaoCluster" AINODE_VERSION_INFO = "UNKNOWN" AINODE_BUILD_INFO = "UNKNOWN" AINODE_ROOT_DIR = os.path.dirname( @@ -47,7 +47,7 @@ AINODE_ROOT_DIR = os.path.dirname( ) # connect IoTDB cluster AINODE_CLUSTER_INGRESS_ADDRESS = "127.0.0.1" -AINODE_CLUSTER_INGRESS_PORT = 6667 +AINODE_CLUSTER_INGRESS_PORT = 7667 AINODE_CLUSTER_INGRESS_USERNAME = "root" AINODE_CLUSTER_INGRESS_PASSWORD = "root" AINODE_CLUSTER_INGRESS_TIME_ZONE = "UTC+8" diff --git a/iotdb-core/ainode/ainode/core/inference/__init__.py b/iotdb-core/ainode/ainode/core/inference/__init__.py new file mode 100644 index 00000000000..2a1e720805f --- /dev/null +++ b/iotdb-core/ainode/ainode/core/inference/__init__.py @@ -0,0 +1,17 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# diff --git a/iotdb-core/ainode/ainode/core/inference/inference_request.py b/iotdb-core/ainode/ainode/core/inference/inference_request.py new file mode 100644 index 00000000000..ccd3fb3e542 --- /dev/null +++ b/iotdb-core/ainode/ainode/core/inference/inference_request.py @@ -0,0 +1,108 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# +import threading +from typing import Any + +import torch + +from ainode.core.inference.strategy.abstract_strategy import AbstractStrategy + + +class InferenceRequestState: + WAITING = "waiting" + RUNNING = "running" + FINISHED = "finished" + + +class InferenceRequest: + def __init__( + self, + req_id: int, + inputs: torch.Tensor, + strategy: AbstractStrategy, + max_new_tokens: int = 96, + **infer_kwargs, + ): + if inputs.ndim == 1: + inputs = inputs.unsqueeze(0) + + self.id = req_id + self.inputs = inputs + self.infer_kwargs = infer_kwargs + self.strategy = strategy + self.max_new_tokens = ( + max_new_tokens # Number of time series data points to generate + ) + + self.batch_size = inputs.size(0) + self.state = InferenceRequestState.WAITING + self.cur_step_idx = 0 # Current write position in the output step index + + # Preallocate output buffer [batch_size, max_new_tokens] + device = inputs.device + self.output_tensor = torch.zeros( + self.batch_size, max_new_tokens, device=device + ) # shape: [self.batch_size, max_new_steps] + + self._lock = threading.Lock() + self._condition = threading.Condition(self._lock) + + def mark_running(self): + self.state = InferenceRequestState.RUNNING + + def mark_finished(self): + self.state = InferenceRequestState.FINISHED + + def is_finished(self) -> bool: + return ( + self.state == InferenceRequestState.FINISHED + or self.cur_step_idx >= self.max_new_tokens + ) + + def write_step_output(self, step_output: torch.Tensor): + with self._lock: + if step_output.ndim == 1: + step_output = step_output.unsqueeze(0) + + batch_size, step_size = step_output.shape + end_idx = self.cur_step_idx + step_size + + if end_idx > self.max_new_tokens: + self.output_tensor[:, self.cur_step_idx :] = step_output[ + :, : self.max_new_tokens - self.cur_step_idx + ] + self.cur_step_idx = self.max_new_tokens + else: + self.output_tensor[:, self.cur_step_idx : end_idx] = step_output + self.cur_step_idx = end_idx + + if self.is_finished(): + 
self.mark_finished() + + def get_final_output(self) -> torch.Tensor: + with self._lock: + return self.output_tensor[:, : self.cur_step_idx] + + def notify_completion(self): + with self._lock: + self._condition.notify_all() + + def wait_for_completion(self) -> Any: + with self._lock: + while self.state != InferenceRequestState.FINISHED: + self._condition.wait() diff --git a/iotdb-core/ainode/ainode/core/inference/inference_request_pool.py b/iotdb-core/ainode/ainode/core/inference/inference_request_pool.py new file mode 100644 index 00000000000..b541c877f65 --- /dev/null +++ b/iotdb-core/ainode/ainode/core/inference/inference_request_pool.py @@ -0,0 +1,269 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +import multiprocessing +import queue +import random +import threading +import time +from multiprocessing import Process + +import numpy as np +import torch +from transformers import PretrainedConfig, PreTrainedModel + +from ainode.core.inference.inference_request import InferenceRequest +from ainode.core.log import Logger + +logger = Logger() + + +class InferenceRequestPool(Process): + """ + The request pool to handle inference for a specific model. 
+ """ + + FIX_SEED = 2021 + WAITING_INTERVAL_IN_MS = ( + 15 # How often to check for requests in the waiting/running queue + ) + + def __init__( + self, + model: PreTrainedModel, + config: PretrainedConfig, + request_queue: multiprocessing.Queue, + **pool_kwargs, + ): + super().__init__() + self.model = model + self.device = self.model.device + self.config = config + self.pool_kwargs = pool_kwargs + + # TODO: A scheduler is necessary for better handling following queues + self.waiting_queue = request_queue # Requests that are waiting to be processed + self.running_queue = ( + queue.Queue() + ) # Requests that are currently being processed, TODO: we might need coroutine to accelerate different stages + self.finished_queue = queue.Queue() # Requests that are finished + self._stop_event = multiprocessing.Event() + + # Fix inference seed + random.seed(self.FIX_SEED) + torch.manual_seed(self.FIX_SEED) + np.random.seed(self.FIX_SEED) + + def memory_is_available(self, request): + # need test with several rounds of dummy data + pass + + def _activate_requests(self): + while not self.waiting_queue.empty(): + request: InferenceRequest = self.waiting_queue.get() + # TODO: Check memory size before activating requests + request.inputs = request.strategy.preprocess_inputs(request.inputs) + request.mark_running() + self.running_queue.put(request) + + def _requests_activate_loop(self): + while not self._stop_event.is_set(): + time.sleep(self.WAITING_INTERVAL_IN_MS / 1000) + self._activate_requests() + + def _step(self): + if self.running_queue.empty(): + return + # TODO: We need a batcher to accelerate the concurrent inference + # TODO: Check memory size before executing requests + request: InferenceRequest = self.running_queue.get() + output = self.model.generate( + request.inputs, + max_new_tokens=request.max_new_tokens, + num_samples=10, + revin=True, + ) + request.write_step_output(output[0].mean(dim=0)) + request.strategy.post_decode() + if request.is_finished(): + 
self.finished_queue.put(request) + else: + self.waiting_queue.put(request) + + def _requests_execute_loop(self): + while not self._stop_event.is_set(): + time.sleep(self.WAITING_INTERVAL_IN_MS / 1000) + self._step() + + def _finish(self): + while not self.finished_queue.empty(): + request: InferenceRequest = self.finished_queue.get() + request.strategy.post_inference() + request.notify_completion() + + def _requests_finish_loop(self): + while not self._stop_event.is_set(): + time.sleep(self.WAITING_INTERVAL_IN_MS / 1000) + self._finish() + + def run(self): + activate_daemon = threading.Thread(target=self._activate_requests) + activate_daemon.daemon = True + activate_daemon.start() + execute_daemon = threading.Thread(target=self._requests_execute_loop) + execute_daemon.daemon = True + execute_daemon.start() + finish_daemon = threading.Thread(target=self._requests_finish_loop) + finish_daemon.daemon = True + finish_daemon.start() + activate_daemon.join() + execute_daemon.join() + finish_daemon.join() + + def stop(self): + self._stop_event.set() + + +def pool_worker(p, done_event): + while not done_event.is_set(): + p._step() + time.sleep(0.001) + + +""" +The following code is used to test the difference in inference speed and the difference in result values when using and not using requestPool +""" +if __name__ == "__main__": + config = TimerConfig() + config.ckpt_path = "/data/mahaoke/AINode/ainode/TimerXL/model.safetensors" + model = TimerForPrediction(config).eval() + if config.ckpt_path is not None and config.ckpt_path != "": + state_dict = load_file(config.ckpt_path) + model.load_state_dict(state_dict, strict=True) + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + model = model.to(device) + + BATCH = 1 + INPUT_LEN = config.input_token_len * 7 # 例如 4 × 96 + x1 = torch.randn(BATCH, INPUT_LEN, device=device) + x2 = torch.randn(BATCH, INPUT_LEN, device=device) + x3 = torch.randn(BATCH, INPUT_LEN, device=device) + + pool = 
InferenceRequestPool(model, config, total_memory_availble=24 * 1024) + + def _always_true(self, req): + return True + + InferenceRequestPool.memory_is_availble = _always_true + + def prepare_inputs(model, x, max_new_steps: int = 96, **model_kwargs): + model_inputs = model.prepare_inputs_for_generation(x, **model_kwargs) + return model_inputs + + def baseline_generate(model, inp: torch.Tensor, max_steps: int, **model_kwargs): + cur_ids = inp + preds = [] + remain = max_steps + + model_kwargs["attention_mask"] = pool.prepare_attention_mask_for_generation(inp) + + batch_size, cur_len = inp.shape + + model_kwargs["unfinished_sequences"] = torch.ones( + batch_size, dtype=torch.long, device=inp.device + ) + model_kwargs["cache_position"] = torch.arange(cur_len, device=inp.device) + true_seq_len = cur_len // config.input_token_len + model_kwargs["attention_mask"] = model_kwargs["attention_mask"][ + :, -true_seq_len: + ] + model_kwargs["past_key_values"] = None + model_kwargs["position_ids"] = None + model_kwargs["is_encoder_decoder"] = getattr( + config, "is_encoder_decoder", False + ) + model_kwargs["max_output_length"] = max_steps + + while remain > 0: + chunk = 96 + model_inputs = prepare_inputs(model, cur_ids, max_steps, **model_kwargs) + out = model(**model_inputs) + # [B, chunk] + tok = out.logits.detach() + preds.append(tok.cpu()) + cur_ids = torch.cat([cur_ids, tok.to(device)], dim=-1) + + horizon_len = 96 // config.input_token_len + model_kwargs = pool._update_model_kwargs_for_generation( + out, model_kwargs, horizon_len, False + ) + + remain -= chunk + return torch.cat(preds, dim=-1) # [B, max_steps] + + # warm up + for i in range(3): + base_res1 = baseline_generate(model, x1, 192) + + torch.cuda.synchronize() + t_base_start = time.perf_counter() + base_res1 = baseline_generate(model, x1, 192) + base_res2 = baseline_generate(model, x2, 192) + base_res3 = baseline_generate(model, x3, 192) + base_reses = [base_res1, base_res2, base_res3] + # 
print(f'base_reses:{base_reses}') + torch.cuda.synchronize() + t_base_end = time.perf_counter() + base_time = t_base_end - t_base_start + print(f"[Baseline] total time: {base_time*1000:.1f} ms") + + done_event = threading.Event() + threading.Thread(target=pool_worker, args=(pool, done_event), daemon=True).start() + + torch.cuda.synchronize() + t_pool_start = time.perf_counter() + pool.add_request(1, x1, max_new_steps=192) + # time.sleep(0.010) + pool.add_request(2, x2, max_new_steps=192) + # time.sleep(0.010) + pool.add_request(3, x3, max_new_steps=192) + pool_results = [] + while len(pool_results) < 3: + pool_results.append(pool.results_queue.get()) + torch.cuda.synchronize() + t_pool_end = time.perf_counter() + pool_time = t_pool_end - t_pool_start + print(f"[RequestPool] total time: {pool_time*1000:.1f} ms") + + done_event.set() # stop pool + + def mae(a, b): + return (a - b).abs().mean().item() + + diff1 = mae( + pool_results[0][1].to("cpu"), base_reses[pool_results[0][0] - 1].to("cpu") + ) + diff2 = mae( + pool_results[1][1].to("cpu"), base_reses[pool_results[1][0] - 1].to("cpu") + ) + diff3 = mae( + pool_results[2][1].to("cpu"), base_reses[pool_results[2][0] - 1].to("cpu") + ) + + print(f"MAE diff (req1/2/3): {diff1:.6f}, {diff2:.6f}, {diff3:.6f}") + print(f"Speed-up: {base_time/pool_time:.2f}× faster with RequestPool") diff --git a/iotdb-core/ainode/ainode/core/inference/strategy/__init__.py b/iotdb-core/ainode/ainode/core/inference/strategy/__init__.py new file mode 100644 index 00000000000..2a1e720805f --- /dev/null +++ b/iotdb-core/ainode/ainode/core/inference/strategy/__init__.py @@ -0,0 +1,17 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# diff --git a/iotdb-core/ainode/ainode/core/inference/strategy/abstract_strategy.py b/iotdb-core/ainode/ainode/core/inference/strategy/abstract_strategy.py new file mode 100644 index 00000000000..a77cff95453 --- /dev/null +++ b/iotdb-core/ainode/ainode/core/inference/strategy/abstract_strategy.py @@ -0,0 +1,60 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +from abc import ABC, abstractmethod + +import torch + + +class AbstractStrategy(ABC): + """ + Abstract assistance strategy class for model inference. + This class defines the inference interface for a specific model. 
+ """ + + def __init__(self, model_config, **infer_kwargs): + self.model_config = model_config + self.infer_kwargs = infer_kwargs + + @abstractmethod + def preprocess_inputs(self, inputs: torch.Tensor): + """ + Preprocess the inputs before inference, including shape validation and value transformation. + + Args: + inputs (torch.Tensor): The input tensor to be preprocessed. + + Returns: + torch.Tensor: The preprocessed input tensor. + """ + # TODO: Integrate with the data processing pipeline operators + pass + + @abstractmethod + def post_decode(self): + """ + Post-process the outputs after each decode step. + """ + pass + + @abstractmethod + def post_inference(self): + """ + Post-process the outputs after the entire inference task. + """ + pass diff --git a/iotdb-core/ainode/ainode/core/inference/strategy/timer_sundial_strategy.py b/iotdb-core/ainode/ainode/core/inference/strategy/timer_sundial_strategy.py new file mode 100644 index 00000000000..e936fc038cd --- /dev/null +++ b/iotdb-core/ainode/ainode/core/inference/strategy/timer_sundial_strategy.py @@ -0,0 +1,49 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# + +import torch + +from ainode.core.exception import InferenceModelInternalError +from ainode.core.inference.strategy.abstract_strategy import AbstractStrategy +from ainode.core.model.sundial.configuration_sundial import SundialConfig + + +class TimerSundialStrategy(AbstractStrategy): + """ + Strategy for Timer-Sundial model inference. + """ + + def __init__(self, model_config: SundialConfig, **infer_kwargs): + super().__init__(model_config, infer_kwargs=infer_kwargs) + + def preprocess_inputs(self, inputs: torch.Tensor): + super().preprocess_inputs(inputs) + if len(inputs.shape) != 2: + raise InferenceModelInternalError( + f"[Inference] Input shape must be: [batch_size, seq_len], but receives {inputs.shape}" + ) + # TODO: Disassemble and adapt with Sundial's ts_generation_mixin.py + return inputs + + def post_decode(self): + # TODO: Disassemble and adapt with Sundial's ts_generation_mixin.py + pass + + def post_inference(self): + # TODO: Disassemble and adapt with Sundial's ts_generation_mixin.py + pass diff --git a/iotdb-core/ainode/ainode/core/inference/utils.py b/iotdb-core/ainode/ainode/core/inference/utils.py new file mode 100644 index 00000000000..04199389205 --- /dev/null +++ b/iotdb-core/ainode/ainode/core/inference/utils.py @@ -0,0 +1,80 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. +# +import secrets +import string + +import torch +from transformers.modeling_outputs import MoeCausalLMOutputWithPast + + +def _generate_req_id(length=10, charset=string.ascii_letters + string.digits): + """ + Generate a random req_id string of specified length. + The length is 10 by default, giving 62^10 (about 8.4 * 10^17) possible combinations. + """ + return "".join(secrets.choice(charset) for _ in range(length)) + + +def _slice_tensor(t, s, e): + return None if t is None else t[s:e] + + +def _slice_tuple_of_tensors(tup, s, e): + """ + hidden_states / attentions: Tuple[layer0, layer1, ...] + each layer may be a Tensor or None. + """ + if tup is None: + return None + sliced = [] + for x in tup: + sliced.append(_slice_tensor(x, s, e) if torch.is_tensor(x) else x) + return tuple(sliced) + + +def _slice_pkv(pkv, s, e): + if pkv is None: + return None + out = [] + for layer in pkv: # layer: Tuple[key, value, ...] 
+ out.append(tuple(x[s:e] for x in layer)) + return out + + +def split_moe_output(batch_out: MoeCausalLMOutputWithPast, split_sizes): + """ + Split batch_out (type: MoeCausalLMOutputWithPast) into len(split_sizes) parts; + split_sizes[i] is the i-th request's batch size. + """ + outs = [] + start = 0 + for bsz in split_sizes: + end = start + bsz + outs.append( + MoeCausalLMOutputWithPast( + loss=_slice_tensor(batch_out.loss, start, end), + logits=batch_out.logits[start:end], + past_key_values=_slice_pkv(batch_out.past_key_values, start, end), + hidden_states=_slice_tuple_of_tensors( + batch_out.hidden_states, start, end + ), + attentions=_slice_tuple_of_tensors(batch_out.attentions, start, end), + ) + ) + start = end + return outs diff --git a/iotdb-core/ainode/ainode/core/manager/inference_manager.py b/iotdb-core/ainode/ainode/core/manager/inference_manager.py index 9eda1c22651..a1983adef3f 100644 --- a/iotdb-core/ainode/ainode/core/manager/inference_manager.py +++ b/iotdb-core/ainode/ainode/core/manager/inference_manager.py @@ -15,10 +15,10 @@ # specific language governing permissions and limitations # under the License. 
# -import random +import multiprocessing from abc import ABC, abstractmethod +from typing import Dict, List -import numpy as np import pandas as pd import torch from iotdb.tsfile.utils.tsblock_serde import deserialize @@ -29,8 +29,13 @@ from ainode.core.exception import ( InvalidWindowArgumentError, runtime_error_extractor, ) +from ainode.core.inference.inference_request import InferenceRequest +from ainode.core.inference.inference_request_pool import InferenceRequestPool +from ainode.core.inference.strategy.timer_sundial_strategy import TimerSundialStrategy +from ainode.core.inference.utils import _generate_req_id from ainode.core.log import Logger from ainode.core.manager.model_manager import ModelManager +from ainode.core.model.sundial.configuration_sundial import SundialConfig from ainode.core.model.sundial.modeling_sundial import SundialForPrediction from ainode.core.model.timerxl.modeling_timer import TimerForPrediction from ainode.core.util.serde import convert_to_binary @@ -43,7 +48,6 @@ from ainode.thrift.ainode.ttypes import ( ) logger = Logger() -FIX_SEED = 2021 class InferenceStrategy(ABC): @@ -122,8 +126,41 @@ class RegisteredStrategy(InferenceStrategy): class InferenceManager: + ACCELERATE_MODEL_ID = "sundial" + DEFAULT_DEVICE = "cpu" + # DEFAULT_DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") + DEFAULT_POOL_SIZE = ( + 1 # TODO: Remove these parameter by sampling model inference consumption + ) + def __init__(self, model_manager: ModelManager): self.model_manager = model_manager + # structure: {model_id: [(InferenceRequestPool, request_queue), ...]} + self.request_pool_map: Dict[ + str, List[(InferenceRequestPool, multiprocessing.Queue)] + ] = {} + self.result_queue = multiprocessing.Queue() + self._init_inference_request_pool() + + def _init_inference_request_pool(self): + """ + Initialize the inference request pool for each model. 
+ TODO: This is a temporary solution, we need a automatic algorithm to adjust the pool size for different models + """ + self.request_pool_map[self.ACCELERATE_MODEL_ID] = [] + for _ in range(self.DEFAULT_POOL_SIZE): + sundial_model = self.model_manager.load_model( + self.ACCELERATE_MODEL_ID, {} + ).to(self.DEFAULT_DEVICE) + sundial_config = SundialConfig() + request_queue = multiprocessing.Queue() + request_pool = InferenceRequestPool( + sundial_model, sundial_config, request_queue + ) + request_pool.start() + self.request_pool_map[self.ACCELERATE_MODEL_ID].append( + (request_pool, request_queue) + ) def _get_strategy(self, model_id, model): if isinstance(model, TimerForPrediction): @@ -145,21 +182,34 @@ class InferenceManager: ): model_id = req.modelId logger.info(f"Start processing for {model_id}") - random.seed(FIX_SEED) - torch.manual_seed(FIX_SEED) - np.random.seed(FIX_SEED) try: raw = data_getter(req) full_data = deserializer(raw) inference_attrs = extract_attrs(req) - # load model - accel = str(inference_attrs.get("acceleration", "")).lower() == "true" - model = self.model_manager.load_model(model_id, inference_attrs, accel) + if model_id == self.ACCELERATE_MODEL_ID and self.DEFAULT_POOL_SIZE > 0: + data = full_data[1][0] + if data.dtype.byteorder not in ("=", "|"): + data = data.byteswap().newbyteorder() + inputs = torch.tensor(data).unsqueeze(0).float().to(self.DEFAULT_DEVICE) + infer_req = InferenceRequest( + req_id=_generate_req_id(), + inputs=inputs, + strategy=TimerSundialStrategy(SundialConfig()), + max_new_tokens=96, + ) + pool_idx = hash(infer_req.id) % len(self.request_pool_map[model_id]) + self.request_pool_map[model_id][pool_idx][1].put(infer_req) + infer_req.wait_for_completion() + outputs = convert_to_binary(pd.DataFrame(infer_req.get_final_output())) + else: + # load model + accel = str(inference_attrs.get("acceleration", "")).lower() == "true" + model = self.model_manager.load_model(model_id, inference_attrs, accel) - # inference by strategy 
- strategy = self._get_strategy(model_id, model) - outputs = strategy.infer(full_data) + # inference by strategy + strategy = self._get_strategy(model_id, model) + outputs = strategy.infer(full_data) # construct response status = get_status(TSStatusCode.SUCCESS_STATUS) @@ -200,3 +250,12 @@ class InferenceManager: resp_cls=TInferenceResp, single_output=False, ) + + def shutdown(self): + for model_id, pools in self.request_pool_map.items(): + for requestPool, requestQueue in pools: + requestPool.stop() + while not requestQueue.empty(): + requestQueue.get_nowait() + requestQueue.close() + requestQueue.join_thread() diff --git a/iotdb-core/ainode/ainode/core/model/sundial/configuration_sundial.py b/iotdb-core/ainode/ainode/core/model/sundial/configuration_sundial.py index c903ce3e9dd..21eefef2933 100644 --- a/iotdb-core/ainode/ainode/core/model/sundial/configuration_sundial.py +++ b/iotdb-core/ainode/ainode/core/model/sundial/configuration_sundial.py @@ -42,7 +42,6 @@ class SundialConfig(PretrainedConfig): flow_loss_depth: int = 3, num_sampling_steps: int = 50, diffusion_batch_mul: int = 4, - ckpt_path: str = None, # weight path **kwargs, ): self.input_token_len = input_token_len @@ -60,7 +59,6 @@ class SundialConfig(PretrainedConfig): self.flow_loss_depth = flow_loss_depth self.num_sampling_steps = num_sampling_steps self.diffusion_batch_mul = diffusion_batch_mul - self.ckpt_path = ckpt_path super().__init__( **kwargs, diff --git a/iotdb-core/ainode/ainode/core/model/timerxl/modeling_timer.py b/iotdb-core/ainode/ainode/core/model/timerxl/modeling_timer.py index 42566c0e1c9..4aed1696af7 100644 --- a/iotdb-core/ainode/ainode/core/model/timerxl/modeling_timer.py +++ b/iotdb-core/ainode/ainode/core/model/timerxl/modeling_timer.py @@ -16,13 +16,10 @@ # under the License. 
# -import os from typing import List, Optional, Tuple, Union import torch import torch.nn.functional as F -from huggingface_hub import hf_hub_download -from safetensors.torch import load_file as load_safetensors from torch import nn from transformers import Cache, DynamicCache, PreTrainedModel from transformers.activations import ACT2FN
