This is an automated email from the ASF dual-hosted git repository.

Caideyipi pushed a commit to branch patch-2094
in repository https://gitbox.apache.org/repos/asf/iotdb.git

commit e3652cd5810b1a5fa4a5109505e91d52cddc0363
Author: 陈荣钊 <[email protected]>
AuthorDate: Wed May 27 17:57:15 2026 +0800

    [TIMECHODB][AINode] Fix AINode CPU model loading and inference diagnostics
---
 .../core/inference/inference_request_pool.py       |  64 ++++++---
 .../iotdb/ainode/core/inference/pool_controller.py |  30 ++++-
 .../iotdb/ainode/core/manager/inference_manager.py |  11 +-
 .../ainode/iotdb/ainode/core/model/model_loader.py | 148 ++++++++++++++++++---
 .../iotdb/ainode/core/util/batch_executor.py       |   7 +-
 .../ainode/core/manager/inference_manager.py       |  12 +-
 6 files changed, 226 insertions(+), 46 deletions(-)

diff --git a/iotdb-core/ainode/iotdb/ainode/core/inference/inference_request_pool.py b/iotdb-core/ainode/iotdb/ainode/core/inference/inference_request_pool.py
index 8121d4fecd8..2076f804021 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/inference/inference_request_pool.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/inference/inference_request_pool.py
@@ -19,6 +19,7 @@
 import random
 import threading
 import time
+import traceback
 from collections import defaultdict
 from enum import Enum
 
@@ -68,6 +69,7 @@ class InferenceRequestPool(mp.Process):
         request_queue: mp.Queue,
         result_queue: mp.Queue,
         ready_event,
+        startup_status_queue: mp.Queue = None,
         **pool_kwargs,
     ):
         super().__init__()
@@ -75,6 +77,7 @@ class InferenceRequestPool(mp.Process):
         self.model_info = model_info
         self.pool_kwargs = pool_kwargs
         self.ready_event = ready_event
+        self.startup_status_queue = startup_status_queue
         self.device = device
 
         self._threads = []
@@ -186,29 +189,46 @@ class InferenceRequestPool(mp.Process):
         self._logger = Logger(
             INFERENCE_LOG_FILE_NAME_PREFIX_TEMPLATE.format(self.device)
         )
-        self._backend = DeviceManager()
-        self._request_scheduler.device = self.device
-        self._inference_pipeline = load_pipeline(self.model_info, self.device)
-        self.ready_event.set()
+        try:
+            self._backend = DeviceManager()
+            self._request_scheduler.device = self.device
+            self._logger.info(
+                f"[Inference][{self.device}][Pool-{self.pool_id}] Loading inference pipeline for model {self.model_info.model_id}."
+            )
+            self._inference_pipeline = load_pipeline(self.model_info, self.device)
+            if self.startup_status_queue is not None:
+                self.startup_status_queue.put({"ok": True})
+            self.ready_event.set()
 
-        activate_daemon = threading.Thread(
-            target=self._requests_activate_loop, daemon=True
-        )
-        self._threads.append(activate_daemon)
-        activate_daemon.start()
-        execute_daemon = threading.Thread(
-            target=self._requests_execute_loop, daemon=True
-        )
-        self._threads.append(execute_daemon)
-        execute_daemon.start()
-        self._logger.info(
-            f"[Inference][{self.device}][Pool-{self.pool_id}] InferenceRequestPool for model {self.model_info.model_id} is activated."
-        )
-        for thread in self._threads:
-            thread.join()
-        self._logger.info(
-            f"[Inference][{self.device}][Pool-{self.pool_id}] InferenceRequestPool for model {self.model_info.model_id} exited cleanly."
-        )
+            activate_daemon = threading.Thread(
+                target=self._requests_activate_loop, daemon=True
+            )
+            self._threads.append(activate_daemon)
+            activate_daemon.start()
+            execute_daemon = threading.Thread(
+                target=self._requests_execute_loop, daemon=True
+            )
+            self._threads.append(execute_daemon)
+            execute_daemon.start()
+            self._logger.info(
+                f"[Inference][{self.device}][Pool-{self.pool_id}] InferenceRequestPool for model {self.model_info.model_id} is activated."
+            )
+            for thread in self._threads:
+                thread.join()
+            self._logger.info(
+                f"[Inference][{self.device}][Pool-{self.pool_id}] InferenceRequestPool for model {self.model_info.model_id} exited cleanly."
+            )
+        except Exception as e:
+            error_traceback = traceback.format_exc()
+            self._logger.error(
+                f"[Inference][{self.device}][Pool-{self.pool_id}] Failed to start inference pool for model {self.model_info.model_id}: {e}\n{error_traceback}"
+            )
+            if self.startup_status_queue is not None:
+                self.startup_status_queue.put(
+                    {"ok": False, "error": str(e), "traceback": error_traceback}
+                )
+            self.ready_event.set()
+            self._stop_event.set()
 
     def stop(self):
         self._stop_event.set()
diff --git a/iotdb-core/ainode/iotdb/ainode/core/inference/pool_controller.py b/iotdb-core/ainode/iotdb/ainode/core/inference/pool_controller.py
index 73bf01c3f7a..a0bd5dea54f 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/inference/pool_controller.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/inference/pool_controller.py
@@ -19,6 +19,7 @@ import concurrent
 import queue
 import random
 import threading
+import traceback
 from concurrent.futures import wait
 from queue import Empty
 from typing import Dict, Optional
@@ -196,17 +197,17 @@ class PoolController:
             try:
                 task_fn(*args, **kwargs)
             except Exception as e:
-                logger.error(f"Error executing task: {e}")
+                logger.error(f"Error executing task: {e}\n{traceback.format_exc()}")
             finally:
                 self._task_queue.task_done()
 
     def _load_one_model_task(self, model_id: str, device_id_list: list[torch.device]):
         def _load_one_model_on_device_task(device_id: torch.device):
-            if not self.has_pool_on_device(device_id):
+            if not self.has_request_pools(model_id, device_id):
                 self._expand_pools_on_device(model_id, device_id, 1)
             else:
                 logger.info(
-                    f"[Inference][{device_id}] There are already pools on this device."
+                    f"[Inference][{device_id}] Model {model_id} is already installed."
                 )
 
         load_model_futures = self._executor.submit_batch(
@@ -299,6 +300,7 @@ class PoolController:
 
         def _expand_pool_on_device(*_):
             request_queue = mp.Queue()
+            startup_status_queue = mp.Queue()
             pool_id = self._new_pool_id.get_and_increment()
             model_info = self._model_manager.get_model_info(model_id)
             pool = InferenceRequestPool(
@@ -308,6 +310,7 @@ class PoolController:
                 request_queue=request_queue,
                 result_queue=self._result_queue,
                 ready_event=mp.Event(),
+                startup_status_queue=startup_status_queue,
             )
             pool.start()
             self._register_pool(model_id, device_id, pool_id, pool, request_queue)
@@ -315,9 +318,30 @@ class PoolController:
                 logger.error(
                     f"[Inference][{device_id}][Pool-{pool_id}] Pool failed to be ready in time"
                 )
+                pool.terminate()
+                pool.join(timeout=5)
                 # TODO: retry or decrease the count? this error should be better handled
                 self._erase_pool(model_id, device_id, pool_id)
             else:
+                startup_status = {}
+                try:
+                    startup_status = startup_status_queue.get(timeout=1)
+                except Empty:
+                    logger.error(
+                        f"[Inference][{device_id}][Pool-{pool_id}] Pool signaled ready without startup status for model {model_id}"
+                    )
+                if not startup_status.get("ok", False):
+                    logger.error(
+                        f"[Inference][{device_id}][Pool-{pool_id}] Pool failed to start for model {model_id}. "
+                        f"error={startup_status.get('error', 'unknown')}, "
+                        f"traceback={startup_status.get('traceback', '')}"
+                    )
+                    pool.join(timeout=5)
+                    if pool.is_alive():
+                        pool.terminate()
+                        pool.join(timeout=5)
+                    self._erase_pool(model_id, device_id, pool_id)
+                    return
                 self.set_state(model_id, device_id, pool_id, PoolState.RUNNING)
                 logger.info(
                     f"[Inference][{device_id}][Pool-{pool_id}] Pool started running for model {model_id}"
diff --git a/iotdb-core/ainode/iotdb/ainode/core/manager/inference_manager.py b/iotdb-core/ainode/iotdb/ainode/core/manager/inference_manager.py
index 8dcf03627dd..64bca30dae8 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/manager/inference_manager.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/manager/inference_manager.py
@@ -18,6 +18,7 @@
 
 import threading
 import time
+import traceback
 
 import torch
 import torch.multiprocessing as mp
@@ -169,7 +170,10 @@ class InferenceManager:
             outputs = infer_proxy.wait_for_result()
             return outputs
         except Exception as e:
-            logger.error(e)
+            logger.error(
+                f"[Inference][Req-{req_id}] Failed to process request for model {req.model_id}: {e}\n"
+                f"{traceback.format_exc()}"
+            )
             raise InferenceModelInternalException(str(e))
         finally:
             with self._result_wrapper_lock:
@@ -258,7 +262,10 @@ class InferenceManager:
             )
 
         except Exception as e:
-            logger.error(e)
+            logger.error(
+                f"[Inference] Failed to run forecast/inference for model {model_id}: {e}\n"
+                f"{traceback.format_exc()}"
+            )
             status = get_status(TSStatusCode.AINODE_INTERNAL_ERROR, str(e))
             empty = b"" if single_batch else []
             return resp_cls(status, empty)
diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/model_loader.py b/iotdb-core/ainode/iotdb/ainode/core/model/model_loader.py
index a3e917459c7..c7ef7f484c8 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/model/model_loader.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/model/model_loader.py
@@ -18,6 +18,7 @@
 
 import json
 import os
+import traceback
 from typing import Any, Dict, Optional, Tuple
 
 import torch
@@ -51,21 +52,138 @@ logger = Logger()
 BACKEND = DeviceManager()
 
 
-def load_model(model_info: ModelInfo, **model_kwargs) -> Any:
-    if model_info.auto_map is not None:
-        model = _load_transformers_model(model_info, **model_kwargs)
-    elif model_info.hub_mixin_cls is not None:
-        model = _load_hub_mixin_model(model_info, **model_kwargs)
-    else:
-        if model_info.model_type == "sktime":
-            model = create_sktime_model(model_info.model_id)
-        else:
-            model = _load_torchscript_model(model_info, **model_kwargs)
+def _enum_value(value):
+    return getattr(value, "value", value)
+
+
+def _format_model_info(model_info: ModelInfo) -> str:
+    return (
+        f"model_id={model_info.model_id}, model_type={model_info.model_type}, "
+        f"category={_enum_value(model_info.category)}, state={_enum_value(model_info.state)}, "
+        f"base_model_id={model_info.base_model_id}, has_auto_map={model_info.auto_map is not None}, "
+        f"has_hub_mixin_cls={model_info.hub_mixin_cls is not None}"
+    )
+
+
+def _format_torch_runtime() -> str:
+    default_device = "unknown"
+    if hasattr(torch, "get_default_device"):
+        try:
+            default_device = str(torch.get_default_device())
+        except Exception as e:
+            default_device = f"unavailable({e})"
+    return (
+        f"torch_version={torch.__version__}, default_dtype={torch.get_default_dtype()}, "
+        f"default_device={default_device}, backend={BACKEND.type.value}"
+    )
 
+
+def _format_model_files(model_path: str) -> str:
+    if not os.path.isdir(model_path):
+        return "model_path_missing"
+    target_files = (
+        CONFIG_JSON,
+        *MODEL_WEIGHT_FILES,
+        ADAPTER_CONFIG,
+        ADAPTER_SAFETENSORS,
+        ADAPTER_PT,
+        ADAPTER_BIN,
+        TRAINING_STATE,
+    )
+    existing_files = [
+        file_name
+        for file_name in target_files
+        if os.path.exists(os.path.join(model_path, file_name))
+    ]
+    return ",".join(existing_files) if existing_files else "no_known_model_files_found"
+
+
+def _collect_meta_tensors(
+    model: torch.nn.Module, limit: int = 10
+) -> Tuple[int, list[str]]:
+    if not isinstance(model, torch.nn.Module):
+        return 0, []
+    total = 0
+    samples = []
+    for tensor_type, named_tensors in (
+        ("parameter", model.named_parameters(recurse=True)),
+        ("buffer", model.named_buffers(recurse=True)),
+    ):
+        for name, tensor in named_tensors:
+            if getattr(tensor, "is_meta", False):
+                total += 1
+                if len(samples) < limit:
+                    samples.append(
+                        f"{tensor_type}:{name}, shape={tuple(tensor.shape)}, "
+                        f"dtype={tensor.dtype}, device={tensor.device}"
+                    )
+    return total, samples
+
+
+def _first_model_device(model: Any) -> str:
+    if not isinstance(model, torch.nn.Module):
+        return "cpu"
+    for tensor in model.parameters():
+        return str(tensor.device)
+    for tensor in model.buffers():
+        return str(tensor.device)
+    return "no_parameters_or_buffers"
+
+
+def _move_model_with_diagnostics(
+    model: torch.nn.Module,
+    model_info: ModelInfo,
+    model_path: str,
+    device_map,
+) -> torch.nn.Module:
     logger.info(
-        f"Model {model_info.model_id} loaded to device {next(model.parameters()).device if model_info.model_type != 'sktime' else 'cpu'} successfully."
+        f"Moving model to device. {_format_model_info(model_info)}, target_device={device_map}, "
+        f"model_path={model_path}, model_files={_format_model_files(model_path)}, {_format_torch_runtime()}"
     )
-    return model
+    meta_count, meta_samples = _collect_meta_tensors(model)
+    if meta_count:
+        logger.error(
+            f"Detected {meta_count} meta tensors before moving model {model_info.model_id} "
+            f"to {device_map}. samples={meta_samples}"
+        )
+    try:
+        return BACKEND.move_model(model, device_map)
+    except Exception as e:
+        logger.error(
+            f"Failed to move model {model_info.model_id} to {device_map}: {e}. "
+            f"{_format_model_info(model_info)}, model_path={model_path}, "
+            f"model_files={_format_model_files(model_path)}, meta_tensor_count={meta_count}, "
+            f"meta_tensor_samples={meta_samples}, {_format_torch_runtime()}\n{traceback.format_exc()}"
+        )
+        raise
+
+
+def load_model(model_info: ModelInfo, **model_kwargs) -> Any:
+    try:
+        logger.info(
+            f"Start loading model. {_format_model_info(model_info)}, model_kwargs={model_kwargs}, {_format_torch_runtime()}"
+        )
+        if model_info.auto_map is not None:
+            model = _load_transformers_model(model_info, **model_kwargs)
+        elif model_info.hub_mixin_cls is not None:
+            model = _load_hub_mixin_model(model_info, **model_kwargs)
+        else:
+            if model_info.model_type == "sktime":
+                model = create_sktime_model(model_info.model_id)
+            else:
+                model = _load_torchscript_model(model_info, **model_kwargs)
+
+        logger.info(
+            f"Model {model_info.model_id} loaded to device {_first_model_device(model)} successfully."
+        )
+        return model
+    except Exception as e:
+        logger.error(
+            f"Failed to load model {model_info.model_id}: {e}. "
+            f"{_format_model_info(model_info)}, model_kwargs={model_kwargs}, {_format_torch_runtime()}\n"
+            f"{traceback.format_exc()}"
+        )
+        raise
 
 
 def _load_transformers_model(model_info: ModelInfo, **model_kwargs):
@@ -109,7 +227,7 @@ def _load_transformers_model(model_info: ModelInfo, **model_kwargs):
     if has_base_model:
         model = _apply_adapter(model, model_path)
 
-    return BACKEND.move_model(model, device_map)
+    return _move_model_with_diagnostics(model, model_info, model_path, device_map)
 
 
 def _load_hub_mixin_model(model_info: ModelInfo, **model_kwargs):
@@ -121,7 +239,7 @@ def _load_hub_mixin_model(model_info: ModelInfo, **model_kwargs):
         raise ModelNotExistException(model_info.model_id)
     # Load model
     model = model_class.from_pretrained(model_path)
-    return BACKEND.move_model(model, device_map)
+    return _move_model_with_diagnostics(model, model_info, model_path, device_map)
 
 
 def _load_torchscript_model(model_info: ModelInfo, **kwargs):
@@ -139,7 +257,7 @@ def _load_torchscript_model(model_info: ModelInfo, **kwargs):
         model = torch.compile(model)
     except Exception as e:
         logger.warning(f"acceleration failed, fallback to normal mode: {str(e)}")
-    return BACKEND.move_model(model, device_map)
+    return _move_model_with_diagnostics(model, model_info, model_path, device_map)
 
 
 def _apply_adapter(
diff --git a/iotdb-core/ainode/iotdb/ainode/core/util/batch_executor.py b/iotdb-core/ainode/iotdb/ainode/core/util/batch_executor.py
index 0629c893d8c..6048c5a6d1d 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/util/batch_executor.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/util/batch_executor.py
@@ -18,6 +18,7 @@
 
 import atexit
 import threading
+import traceback
 from concurrent.futures import Future, ThreadPoolExecutor
 from typing import Any, Callable, Iterable, List, Optional
 
@@ -113,7 +114,9 @@ class BatchExecutor:
         try:
             _ = future.result()
         except Exception as e:
-            logger.error(f"Batch task failed (item={item}), because {e}")
+            logger.error(
+                f"Batch task failed (item={item}), because {e}\n{traceback.format_exc()}"
+            )
 
     def _attach_done_callback(self, fut: Future, item: Any) -> None:
         def _cb(f: Future, _item=item, self_ref=self):
@@ -121,7 +124,7 @@ class BatchExecutor:
                 self_ref.on_task_done(_item, f)
             except Exception as e:
                 logger.error(
-                    f"Error in on_task_done callback (item={_item}), because {e}"
+                    f"Error in on_task_done callback (item={_item}), because {e}\n{traceback.format_exc()}"
                 )
 
         fut.add_done_callback(_cb)
diff --git a/iotdb-core/ainode/timecho/ainode/core/manager/inference_manager.py b/iotdb-core/ainode/timecho/ainode/core/manager/inference_manager.py
index ec4dbbf542c..d3b275f05aa 100644
--- a/iotdb-core/ainode/timecho/ainode/core/manager/inference_manager.py
+++ b/iotdb-core/ainode/timecho/ainode/core/manager/inference_manager.py
@@ -1,3 +1,5 @@
+import traceback
+
 import torch
 
 from iotdb.ainode.core.config import AINodeDescriptor
@@ -136,7 +138,10 @@ class TimechoInferenceManager(InferenceManager):
                 [resp_list[0]] if single_batch else resp_list,
             )
         except Exception as e:
-            logger.error(e)
+            logger.error(
+                f"[Inference] Failed to run classify for model {model_id}: {e}\n"
+                f"{traceback.format_exc()}"
+            )
             status = get_status(TSStatusCode.AINODE_INTERNAL_ERROR, str(e))
             empty = b"" if single_batch else []
             return resp_cls(status, empty)
@@ -190,7 +195,10 @@ class TimechoInferenceManager(InferenceManager):
             )
 
         except Exception as e:
-            logger.error(e)
+            logger.error(
+                f"[Inference] Failed to run forecast for model {model_id}: {e}\n"
+                f"{traceback.format_exc()}"
+            )
             status = get_status(TSStatusCode.AINODE_INTERNAL_ERROR, str(e))
             empty = b"" if single_batch else []
             return resp_cls(status, empty)

Reply via email to