This is an automated email from the ASF dual-hosted git repository. Caideyipi pushed a commit to branch hotfix/2.0.9.4-sjzt in repository https://gitbox.apache.org/repos/asf/iotdb.git
commit e3652cd5810b1a5fa4a5109505e91d52cddc0363 Author: 陈荣钊 <[email protected]> AuthorDate: Wed May 27 17:57:15 2026 +0800 [TIMECHODB][AINode] Fix AINode CPU model loading and inference diagnostics --- .../core/inference/inference_request_pool.py | 64 ++++++--- .../iotdb/ainode/core/inference/pool_controller.py | 30 ++++- .../iotdb/ainode/core/manager/inference_manager.py | 11 +- .../ainode/iotdb/ainode/core/model/model_loader.py | 148 ++++++++++++++++++--- .../iotdb/ainode/core/util/batch_executor.py | 7 +- .../ainode/core/manager/inference_manager.py | 12 +- 6 files changed, 226 insertions(+), 46 deletions(-) diff --git a/iotdb-core/ainode/iotdb/ainode/core/inference/inference_request_pool.py b/iotdb-core/ainode/iotdb/ainode/core/inference/inference_request_pool.py index 8121d4fecd8..2076f804021 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/inference/inference_request_pool.py +++ b/iotdb-core/ainode/iotdb/ainode/core/inference/inference_request_pool.py @@ -19,6 +19,7 @@ import random import threading import time +import traceback from collections import defaultdict from enum import Enum @@ -68,6 +69,7 @@ class InferenceRequestPool(mp.Process): request_queue: mp.Queue, result_queue: mp.Queue, ready_event, + startup_status_queue: mp.Queue = None, **pool_kwargs, ): super().__init__() @@ -75,6 +77,7 @@ class InferenceRequestPool(mp.Process): self.model_info = model_info self.pool_kwargs = pool_kwargs self.ready_event = ready_event + self.startup_status_queue = startup_status_queue self.device = device self._threads = [] @@ -186,29 +189,46 @@ class InferenceRequestPool(mp.Process): self._logger = Logger( INFERENCE_LOG_FILE_NAME_PREFIX_TEMPLATE.format(self.device) ) - self._backend = DeviceManager() - self._request_scheduler.device = self.device - self._inference_pipeline = load_pipeline(self.model_info, self.device) - self.ready_event.set() + try: + self._backend = DeviceManager() + self._request_scheduler.device = self.device + self._logger.info( + f"[Inference][{self.device}][Pool-{self.pool_id}] Loading inference pipeline for model {self.model_info.model_id}." + ) + self._inference_pipeline = load_pipeline(self.model_info, self.device) + if self.startup_status_queue is not None: + self.startup_status_queue.put({"ok": True}) + self.ready_event.set() - activate_daemon = threading.Thread( - target=self._requests_activate_loop, daemon=True - ) - self._threads.append(activate_daemon) - activate_daemon.start() - execute_daemon = threading.Thread( - target=self._requests_execute_loop, daemon=True - ) - self._threads.append(execute_daemon) - execute_daemon.start() - self._logger.info( - f"[Inference][{self.device}][Pool-{self.pool_id}] InferenceRequestPool for model {self.model_info.model_id} is activated." - ) - for thread in self._threads: - thread.join() - self._logger.info( - f"[Inference][{self.device}][Pool-{self.pool_id}] InferenceRequestPool for model {self.model_info.model_id} exited cleanly." - ) + activate_daemon = threading.Thread( + target=self._requests_activate_loop, daemon=True + ) + self._threads.append(activate_daemon) + activate_daemon.start() + execute_daemon = threading.Thread( + target=self._requests_execute_loop, daemon=True + ) + self._threads.append(execute_daemon) + execute_daemon.start() + self._logger.info( + f"[Inference][{self.device}][Pool-{self.pool_id}] InferenceRequestPool for model {self.model_info.model_id} is activated." + ) + for thread in self._threads: + thread.join() + self._logger.info( + f"[Inference][{self.device}][Pool-{self.pool_id}] InferenceRequestPool for model {self.model_info.model_id} exited cleanly." + ) + except Exception as e: + error_traceback = traceback.format_exc() + self._logger.error( + f"[Inference][{self.device}][Pool-{self.pool_id}] Failed to start inference pool for model {self.model_info.model_id}: {e}\n{error_traceback}" + ) + if self.startup_status_queue is not None: + self.startup_status_queue.put( + {"ok": False, "error": str(e), "traceback": error_traceback} + ) + self.ready_event.set() + self._stop_event.set() def stop(self): self._stop_event.set() diff --git a/iotdb-core/ainode/iotdb/ainode/core/inference/pool_controller.py b/iotdb-core/ainode/iotdb/ainode/core/inference/pool_controller.py index 73bf01c3f7a..a0bd5dea54f 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/inference/pool_controller.py +++ b/iotdb-core/ainode/iotdb/ainode/core/inference/pool_controller.py @@ -19,6 +19,7 @@ import concurrent import queue import random import threading +import traceback from concurrent.futures import wait from queue import Empty from typing import Dict, Optional @@ -196,17 +197,17 @@ class PoolController: try: task_fn(*args, **kwargs) except Exception as e: - logger.error(f"Error executing task: {e}") + logger.error(f"Error executing task: {e}\n{traceback.format_exc()}") finally: self._task_queue.task_done() def _load_one_model_task(self, model_id: str, device_id_list: list[torch.device]): def _load_one_model_on_device_task(device_id: torch.device): - if not self.has_pool_on_device(device_id): + if not self.has_request_pools(model_id, device_id): self._expand_pools_on_device(model_id, device_id, 1) else: logger.info( - f"[Inference][{device_id}] There are already pools on this device." + f"[Inference][{device_id}] Model {model_id} is already installed." ) load_model_futures = self._executor.submit_batch( @@ -299,6 +300,7 @@ class PoolController: def _expand_pool_on_device(*_): request_queue = mp.Queue() + startup_status_queue = mp.Queue() pool_id = self._new_pool_id.get_and_increment() model_info = self._model_manager.get_model_info(model_id) pool = InferenceRequestPool( @@ -308,6 +310,7 @@ class PoolController: request_queue=request_queue, result_queue=self._result_queue, ready_event=mp.Event(), + startup_status_queue=startup_status_queue, ) pool.start() self._register_pool(model_id, device_id, pool_id, pool, request_queue) @@ -315,9 +318,30 @@ class PoolController: logger.error( f"[Inference][{device_id}][Pool-{pool_id}] Pool failed to be ready in time" ) + pool.terminate() + pool.join(timeout=5) # TODO: retry or decrease the count? this error should be better handled self._erase_pool(model_id, device_id, pool_id) else: + startup_status = {} + try: + startup_status = startup_status_queue.get(timeout=1) + except Empty: + logger.error( + f"[Inference][{device_id}][Pool-{pool_id}] Pool signaled ready without startup status for model {model_id}" + ) + if not startup_status.get("ok", False): + logger.error( + f"[Inference][{device_id}][Pool-{pool_id}] Pool failed to start for model {model_id}. " + f"error={startup_status.get('error', 'unknown')}, " + f"traceback={startup_status.get('traceback', '')}" + ) + pool.join(timeout=5) + if pool.is_alive(): + pool.terminate() + pool.join(timeout=5) + self._erase_pool(model_id, device_id, pool_id) + return self.set_state(model_id, device_id, pool_id, PoolState.RUNNING) logger.info( f"[Inference][{device_id}][Pool-{pool_id}] Pool started running for model {model_id}" diff --git a/iotdb-core/ainode/iotdb/ainode/core/manager/inference_manager.py b/iotdb-core/ainode/iotdb/ainode/core/manager/inference_manager.py index 8dcf03627dd..64bca30dae8 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/manager/inference_manager.py +++ b/iotdb-core/ainode/iotdb/ainode/core/manager/inference_manager.py @@ -18,6 +18,7 @@ import threading import time +import traceback import torch import torch.multiprocessing as mp @@ -169,7 +170,10 @@ class InferenceManager: outputs = infer_proxy.wait_for_result() return outputs except Exception as e: - logger.error(e) + logger.error( + f"[Inference][Req-{req_id}] Failed to process request for model {req.model_id}: {e}\n" + f"{traceback.format_exc()}" + ) raise InferenceModelInternalException(str(e)) finally: with self._result_wrapper_lock: @@ -258,7 +262,10 @@ class InferenceManager: ) except Exception as e: - logger.error(e) + logger.error( + f"[Inference] Failed to run forecast/inference for model {model_id}: {e}\n" + f"{traceback.format_exc()}" + ) status = get_status(TSStatusCode.AINODE_INTERNAL_ERROR, str(e)) empty = b"" if single_batch else [] return resp_cls(status, empty) diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/model_loader.py b/iotdb-core/ainode/iotdb/ainode/core/model/model_loader.py index a3e917459c7..c7ef7f484c8 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/model/model_loader.py +++ b/iotdb-core/ainode/iotdb/ainode/core/model/model_loader.py @@ -18,6 +18,7 @@ import json import os +import traceback from typing import Any, Dict, Optional, Tuple import torch @@ -51,21 +52,138 @@ logger = Logger() BACKEND = DeviceManager() -def load_model(model_info: ModelInfo, **model_kwargs) -> Any: - if model_info.auto_map is not None: - model = _load_transformers_model(model_info, **model_kwargs) - elif model_info.hub_mixin_cls is not None: - model = _load_hub_mixin_model(model_info, **model_kwargs) - else: - if model_info.model_type == "sktime": - model = create_sktime_model(model_info.model_id) - else: - model = _load_torchscript_model(model_info, **model_kwargs) +def _enum_value(value): + return getattr(value, "value", value) + + +def _format_model_info(model_info: ModelInfo) -> str: + return ( + f"model_id={model_info.model_id}, model_type={model_info.model_type}, " + f"category={_enum_value(model_info.category)}, state={_enum_value(model_info.state)}, " + f"base_model_id={model_info.base_model_id}, has_auto_map={model_info.auto_map is not None}, " + f"has_hub_mixin_cls={model_info.hub_mixin_cls is not None}" + ) + + +def _format_torch_runtime() -> str: + default_device = "unknown" + if hasattr(torch, "get_default_device"): + try: + default_device = str(torch.get_default_device()) + except Exception as e: + default_device = f"unavailable({e})" + return ( + f"torch_version={torch.__version__}, default_dtype={torch.get_default_dtype()}, " + f"default_device={default_device}, backend={BACKEND.type.value}" + ) + +def _format_model_files(model_path: str) -> str: + if not os.path.isdir(model_path): + return "model_path_missing" + target_files = ( + CONFIG_JSON, + *MODEL_WEIGHT_FILES, + ADAPTER_CONFIG, + ADAPTER_SAFETENSORS, + ADAPTER_PT, + ADAPTER_BIN, + TRAINING_STATE, + ) + existing_files = [ + file_name + for file_name in target_files + if os.path.exists(os.path.join(model_path, file_name)) + ] + return ",".join(existing_files) if existing_files else "no_known_model_files_found" + + +def _collect_meta_tensors( + model: torch.nn.Module, limit: int = 10 +) -> Tuple[int, list[str]]: + if not isinstance(model, torch.nn.Module): + return 0, [] + total = 0 + samples = [] + for tensor_type, named_tensors in ( + ("parameter", model.named_parameters(recurse=True)), + ("buffer", model.named_buffers(recurse=True)), + ): + for name, tensor in named_tensors: + if getattr(tensor, "is_meta", False): + total += 1 + if len(samples) < limit: + samples.append( + f"{tensor_type}:{name}, shape={tuple(tensor.shape)}, " + f"dtype={tensor.dtype}, device={tensor.device}" + ) + return total, samples + + +def _first_model_device(model: Any) -> str: + if not isinstance(model, torch.nn.Module): + return "cpu" + for tensor in model.parameters(): + return str(tensor.device) + for tensor in model.buffers(): + return str(tensor.device) + return "no_parameters_or_buffers" + + +def _move_model_with_diagnostics( + model: torch.nn.Module, + model_info: ModelInfo, + model_path: str, + device_map, +) -> torch.nn.Module: logger.info( - f"Model {model_info.model_id} loaded to device {next(model.parameters()).device if model_info.model_type != 'sktime' else 'cpu'} successfully." + f"Moving model to device. {_format_model_info(model_info)}, target_device={device_map}, " + f"model_path={model_path}, model_files={_format_model_files(model_path)}, {_format_torch_runtime()}" ) - return model + meta_count, meta_samples = _collect_meta_tensors(model) + if meta_count: + logger.error( + f"Detected {meta_count} meta tensors before moving model {model_info.model_id} " + f"to {device_map}. samples={meta_samples}" + ) + try: + return BACKEND.move_model(model, device_map) + except Exception as e: + logger.error( + f"Failed to move model {model_info.model_id} to {device_map}: {e}. " + f"{_format_model_info(model_info)}, model_path={model_path}, " + f"model_files={_format_model_files(model_path)}, meta_tensor_count={meta_count}, " + f"meta_tensor_samples={meta_samples}, {_format_torch_runtime()}\n{traceback.format_exc()}" + ) + raise + + +def load_model(model_info: ModelInfo, **model_kwargs) -> Any: + try: + logger.info( + f"Start loading model. {_format_model_info(model_info)}, model_kwargs={model_kwargs}, {_format_torch_runtime()}" + ) + if model_info.auto_map is not None: + model = _load_transformers_model(model_info, **model_kwargs) + elif model_info.hub_mixin_cls is not None: + model = _load_hub_mixin_model(model_info, **model_kwargs) + else: + if model_info.model_type == "sktime": + model = create_sktime_model(model_info.model_id) + else: + model = _load_torchscript_model(model_info, **model_kwargs) + + logger.info( + f"Model {model_info.model_id} loaded to device {_first_model_device(model)} successfully." + ) + return model + except Exception as e: + logger.error( + f"Failed to load model {model_info.model_id}: {e}. " + f"{_format_model_info(model_info)}, model_kwargs={model_kwargs}, {_format_torch_runtime()}\n" + f"{traceback.format_exc()}" + ) + raise def _load_transformers_model(model_info: ModelInfo, **model_kwargs): @@ -109,7 +227,7 @@ def _load_transformers_model(model_info: ModelInfo, **model_kwargs): if has_base_model: model = _apply_adapter(model, model_path) - return BACKEND.move_model(model, device_map) + return _move_model_with_diagnostics(model, model_info, model_path, device_map) def _load_hub_mixin_model(model_info: ModelInfo, **model_kwargs): @@ -121,7 +239,7 @@ def _load_hub_mixin_model(model_info: ModelInfo, **model_kwargs): raise ModelNotExistException(model_info.model_id) # Load model model = model_class.from_pretrained(model_path) - return BACKEND.move_model(model, device_map) + return _move_model_with_diagnostics(model, model_info, model_path, device_map) def _load_torchscript_model(model_info: ModelInfo, **kwargs): @@ -139,7 +257,7 @@ def _load_torchscript_model(model_info: ModelInfo, **kwargs): model = torch.compile(model) except Exception as e: logger.warning(f"acceleration failed, fallback to normal mode: {str(e)}") - return BACKEND.move_model(model, device_map) + return _move_model_with_diagnostics(model, model_info, model_path, device_map) def _apply_adapter( diff --git a/iotdb-core/ainode/iotdb/ainode/core/util/batch_executor.py b/iotdb-core/ainode/iotdb/ainode/core/util/batch_executor.py index 0629c893d8c..6048c5a6d1d 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/util/batch_executor.py +++ b/iotdb-core/ainode/iotdb/ainode/core/util/batch_executor.py @@ -18,6 +18,7 @@ import atexit import threading +import traceback from concurrent.futures import Future, ThreadPoolExecutor from typing import Any, Callable, Iterable, List, Optional @@ -113,7 +114,9 @@ class BatchExecutor: try: _ = future.result() except Exception as e: - logger.error(f"Batch task failed (item={item}), because {e}") + logger.error( + f"Batch task failed (item={item}), because {e}\n{traceback.format_exc()}" + ) def _attach_done_callback(self, fut: Future, item: Any) -> None: def _cb(f: Future, _item=item, self_ref=self): @@ -121,7 +124,7 @@ class BatchExecutor: self_ref.on_task_done(_item, f) except Exception as e: logger.error( - f"Error in on_task_done callback (item={_item}), because {e}" + f"Error in on_task_done callback (item={_item}), because {e}\n{traceback.format_exc()}" ) fut.add_done_callback(_cb) diff --git a/iotdb-core/ainode/timecho/ainode/core/manager/inference_manager.py b/iotdb-core/ainode/timecho/ainode/core/manager/inference_manager.py index ec4dbbf542c..d3b275f05aa 100644 --- a/iotdb-core/ainode/timecho/ainode/core/manager/inference_manager.py +++ b/iotdb-core/ainode/timecho/ainode/core/manager/inference_manager.py @@ -1,3 +1,5 @@ +import traceback + import torch from iotdb.ainode.core.config import AINodeDescriptor @@ -136,7 +138,10 @@ class TimechoInferenceManager(InferenceManager): [resp_list[0]] if single_batch else resp_list, ) except Exception as e: - logger.error(e) + logger.error( + f"[Inference] Failed to run classify for model {model_id}: {e}\n" + f"{traceback.format_exc()}" + ) status = get_status(TSStatusCode.AINODE_INTERNAL_ERROR, str(e)) empty = b"" if single_batch else [] return resp_cls(status, empty) @@ -190,7 +195,10 @@ class TimechoInferenceManager(InferenceManager): ) except Exception as e: - logger.error(e) + logger.error( + f"[Inference] Failed to run forecast for model {model_id}: {e}\n" + f"{traceback.format_exc()}" + ) status = get_status(TSStatusCode.AINODE_INTERNAL_ERROR, str(e)) empty = b"" if single_batch else [] return resp_cls(status, empty)
