This is an automated email from the ASF dual-hosted git repository.

yongzao pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/iotdb.git


The following commit(s) were added to refs/heads/master by this push:
     new 76d02a58f00 Modify pipeline and automap (#16923)
76d02a58f00 is described below

commit 76d02a58f00a61b3b19fefcdc80f0f857e05e7dd
Author: Gewu <[email protected]>
AuthorDate: Wed Dec 17 23:03:44 2025 +0800

    Modify pipeline and automap (#16923)
---
 .../ainode/core/inference/inference_request.py     | 19 ++++-----
 .../core/inference/inference_request_pool.py       |  4 ++
 .../core/inference/pipeline/basic_pipeline.py      | 43 ++++++++++++++------
 .../iotdb/ainode/core/manager/inference_manager.py | 46 ++++++++++------------
 .../ainode/iotdb/ainode/core/model/model_info.py   | 16 ++++----
 .../ainode/iotdb/ainode/core/model/model_loader.py | 17 ++++----
 .../ainode/core/model/sktime/pipeline_sktime.py    | 45 ++++++++++++---------
 .../ainode/core/model/sundial/pipeline_sundial.py  | 30 +++++++++-----
 .../ainode/core/model/timer_xl/pipeline_timer.py   | 29 +++++++++-----
 iotdb-core/ainode/iotdb/ainode/core/util/serde.py  | 24 +++++++++--
 10 files changed, 167 insertions(+), 106 deletions(-)

diff --git a/iotdb-core/ainode/iotdb/ainode/core/inference/inference_request.py b/iotdb-core/ainode/iotdb/ainode/core/inference/inference_request.py
index 50634914c27..93887477aa5 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/inference/inference_request.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/inference/inference_request.py
@@ -42,7 +42,7 @@ class InferenceRequest:
         output_length: int = 96,
         **infer_kwargs,
     ):
-        if inputs.ndim == 1:
+        while inputs.ndim < 3:
             inputs = inputs.unsqueeze(0)
 
         self.req_id = req_id
@@ -54,6 +54,7 @@ class InferenceRequest:
         )
 
         self.batch_size = inputs.size(0)
+        self.variable_size = inputs.size(1)
         self.state = InferenceRequestState.WAITING
         self.cur_step_idx = 0  # Current write position in the output step index
         self.assigned_pool_id = -1  # The pool handling this request
@@ -61,8 +62,8 @@ class InferenceRequest:
 
         # Preallocate output buffer [batch_size, max_new_tokens]
         self.output_tensor = torch.zeros(
-            self.batch_size, output_length, device="cpu"
-        )  # shape: [self.batch_size, max_new_steps]
+            self.batch_size, self.variable_size, output_length, device="cpu"
+        )  # shape: [batch_size, target_count, predict_length]
 
     def mark_running(self):
         self.state = InferenceRequestState.RUNNING
@@ -77,26 +78,26 @@ class InferenceRequest:
         )
 
     def write_step_output(self, step_output: torch.Tensor):
-        if step_output.ndim == 1:
+        while step_output.ndim < 3:
             step_output = step_output.unsqueeze(0)
 
-        batch_size, step_size = step_output.shape
+        batch_size, variable_size, step_size = step_output.shape
         end_idx = self.cur_step_idx + step_size
 
         if end_idx > self.output_length:
-            self.output_tensor[:, self.cur_step_idx :] = step_output[
-                :, : self.output_length - self.cur_step_idx
+            self.output_tensor[:, :, self.cur_step_idx :] = step_output[
+                :, :, : self.output_length - self.cur_step_idx
             ]
             self.cur_step_idx = self.output_length
         else:
-            self.output_tensor[:, self.cur_step_idx : end_idx] = step_output
+            self.output_tensor[:, :, self.cur_step_idx : end_idx] = step_output
             self.cur_step_idx = end_idx
 
         if self.is_finished():
             self.mark_finished()
 
     def get_final_output(self) -> torch.Tensor:
-        return self.output_tensor[:, : self.cur_step_idx]
+        return self.output_tensor[:, :, : self.cur_step_idx]
 
 
 class InferenceRequestProxy:
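
Worth pausing on: the output buffer is now three-dimensional and write_step_output clamps the final step at output_length. A minimal standalone replay of that indexing (plain torch; only the logic from the hunk above, not the full class):

    import torch

    # [batch_size, variable_size, output_length] buffer, filled left to right.
    batch_size, variable_size, output_length = 2, 1, 8
    output_tensor = torch.zeros(batch_size, variable_size, output_length)
    cur_step_idx = 0

    def write_step_output(step_output):
        global cur_step_idx
        while step_output.ndim < 3:  # same promotion rule as the diff
            step_output = step_output.unsqueeze(0)
        step_size = step_output.shape[2]
        end_idx = cur_step_idx + step_size
        if end_idx > output_length:  # clamp an overlong final step
            output_tensor[:, :, cur_step_idx:] = step_output[:, :, : output_length - cur_step_idx]
            cur_step_idx = output_length
        else:
            output_tensor[:, :, cur_step_idx:end_idx] = step_output
            cur_step_idx = end_idx

    write_step_output(torch.ones(batch_size, variable_size, 5))
    write_step_output(torch.full((batch_size, variable_size, 5), 2.0))  # only 3 steps fit
    print(output_tensor[:, :, :cur_step_idx].shape)  # torch.Size([2, 1, 8])
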
diff --git a/iotdb-core/ainode/iotdb/ainode/core/inference/inference_request_pool.py b/iotdb-core/ainode/iotdb/ainode/core/inference/inference_request_pool.py
index c31bcd3d762..3cca9b183c8 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/inference/inference_request_pool.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/inference/inference_request_pool.py
@@ -123,6 +123,7 @@ class InferenceRequestPool(mp.Process):
             batch_inputs = self._batcher.batch_request(requests).to(
                 "cpu"
             )  # The input data should first load to CPU in current version
+            batch_inputs = self._inference_pipeline.preprocess(batch_inputs)
             if isinstance(self._inference_pipeline, ForecastPipeline):
                 batch_output = self._inference_pipeline.forecast(
                     batch_inputs,
@@ -140,7 +141,10 @@ class InferenceRequestPool(mp.Process):
                     # more infer kwargs can be added here
                 )
             else:
+                batch_output = None
                 self._logger.error("[Inference] Unsupported pipeline type.")
+            batch_output = self._inference_pipeline.postprocess(batch_output)
+
             offset = 0
             for request in requests:
                 request.output_tensor = request.output_tensor.to(self.device)
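
The pool change is small but structural: every batch now passes through preprocess before dispatch and postprocess after it. A hedged sketch of that flow against a stub pipeline (StubForecastPipeline is invented here; only the call order and the 3D shape contract mirror the diff):

    import torch

    class StubForecastPipeline:
        """Stand-in exposing the same preprocess/forecast/postprocess surface."""
        def preprocess(self, x):
            assert x.ndim == 3, "expects [batch_size, target_count, sequence_length]"
            return x
        def forecast(self, x, predict_length=96):
            b, t, _ = x.shape
            return torch.zeros(b, t, predict_length)
        def postprocess(self, y):
            assert y.ndim == 3, "returns [batch_size, target_count, predict_length]"
            return y

    pipeline = StubForecastPipeline()
    batch_inputs = torch.randn(4, 1, 96).to("cpu")   # batched on CPU, as in the pool
    batch_inputs = pipeline.preprocess(batch_inputs)
    batch_output = pipeline.forecast(batch_inputs, predict_length=32)
    batch_output = pipeline.postprocess(batch_output)
    print(batch_output.shape)  # torch.Size([4, 1, 32])
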
diff --git a/iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/basic_pipeline.py b/iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/basic_pipeline.py
index 82601e39805..d7345d3140f 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/basic_pipeline.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/basic_pipeline.py
@@ -20,6 +20,7 @@ from abc import ABC, abstractmethod
 
 import torch
 
+from iotdb.ainode.core.exception import InferenceModelInternalException
 from iotdb.ainode.core.model.model_loader import load_model
 
 
@@ -29,59 +30,75 @@ class BasicPipeline(ABC):
         self.device = model_kwargs.get("device", "cpu")
         self.model = load_model(model_info, device_map=self.device, **model_kwargs)
 
-    def _preprocess(self, inputs):
+    @abstractmethod
+    def preprocess(self, inputs):
         """
         Preprocess the input before inference, including shape validation and value transformation.
         """
-        return inputs
+        raise NotImplementedError("preprocess not implemented")
 
-    def _postprocess(self, output: torch.Tensor):
+    @abstractmethod
+    def postprocess(self, outputs: torch.Tensor):
         """
         Post-process the outputs after the entire inference task.
         """
-        return output
+        raise NotImplementedError("postprocess not implemented")
 
 
 class ForecastPipeline(BasicPipeline):
     def __init__(self, model_info, **model_kwargs):
         super().__init__(model_info, model_kwargs=model_kwargs)
 
-    def _preprocess(self, inputs):
+    def preprocess(self, inputs):
+        """
+        The inputs should be a 3D tensor: [batch_size, target_count, sequence_length].
+        """
+        if len(inputs.shape) != 3:
+            raise InferenceModelInternalException(
+                f"[Inference] Input must be: [batch_size, target_count, 
sequence_length], but receives {inputs.shape}"
+            )
         return inputs
 
     @abstractmethod
     def forecast(self, inputs, **infer_kwargs):
         pass
 
-    def _postprocess(self, output: torch.Tensor):
-        return output
+    def postprocess(self, outputs: torch.Tensor):
+        """
+        The outputs should be a 3D tensor: [batch_size, target_count, predict_length].
+        """
+        if len(outputs.shape) != 3:
+            raise InferenceModelInternalException(
+                f"[Inference] Output must be: [batch_size, target_count, 
predict_length], but receives {outputs.shape}"
+            )
+        return outputs
 
 
 class ClassificationPipeline(BasicPipeline):
     def __init__(self, model_info, **model_kwargs):
         super().__init__(model_info, model_kwargs=model_kwargs)
 
-    def _preprocess(self, inputs):
+    def preprocess(self, inputs):
         return inputs
 
     @abstractmethod
     def classify(self, inputs, **kwargs):
         pass
 
-    def _postprocess(self, output: torch.Tensor):
-        return output
+    def postprocess(self, outputs: torch.Tensor):
+        return outputs
 
 
 class ChatPipeline(BasicPipeline):
     def __init__(self, model_info, **model_kwargs):
         super().__init__(model_info, model_kwargs=model_kwargs)
 
-    def _preprocess(self, inputs):
+    def preprocess(self, inputs):
         return inputs
 
     @abstractmethod
     def chat(self, inputs, **kwargs):
         pass
 
-    def _postprocess(self, output: torch.Tensor):
-        return output
+    def postprocess(self, outputs: torch.Tensor):
+        return outputs
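
Since preprocess and postprocess are now @abstractmethod, a pipeline that forgets one of them fails at instantiation rather than silently inheriting a no-op. A toy demonstration of that contract (classes here are illustrative, not the IoTDB ones):

    from abc import ABC, abstractmethod

    class Pipeline(ABC):  # same shape of contract as BasicPipeline
        @abstractmethod
        def preprocess(self, inputs): ...
        @abstractmethod
        def postprocess(self, outputs): ...

    class Incomplete(Pipeline):  # implements only one hook
        def preprocess(self, inputs):
            return inputs

    try:
        Incomplete()
    except TypeError as e:
        print(e)  # Can't instantiate abstract class Incomplete ...

    class Complete(Incomplete):
        def postprocess(self, outputs):
            return outputs

    Complete()  # fine once both hooks exist
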
diff --git a/iotdb-core/ainode/iotdb/ainode/core/manager/inference_manager.py b/iotdb-core/ainode/iotdb/ainode/core/manager/inference_manager.py
index 34c315274f5..d3f77e993a1 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/manager/inference_manager.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/manager/inference_manager.py
@@ -46,7 +46,10 @@ from iotdb.ainode.core.log import Logger
 from iotdb.ainode.core.manager.model_manager import ModelManager
 from iotdb.ainode.core.rpc.status import get_status
 from iotdb.ainode.core.util.gpu_mapping import get_available_devices
-from iotdb.ainode.core.util.serde import convert_to_binary
+from iotdb.ainode.core.util.serde import (
+    convert_tensor_to_tsblock,
+    convert_tsblock_to_tensor,
+)
 from iotdb.thrift.ainode.ttypes import (
     TForecastReq,
     TForecastResp,
@@ -58,7 +61,6 @@ from iotdb.thrift.ainode.ttypes import (
     TUnloadModelReq,
 )
 from iotdb.thrift.common.ttypes import TSStatus
-from iotdb.tsfile.utils.tsblock_serde import deserialize
 
 logger = Logger()
 
@@ -170,23 +172,14 @@ class InferenceManager:
         self,
         req,
         data_getter,
-        deserializer,
         extract_attrs,
         resp_cls,
-        single_output: bool,
+        single_batch: bool,
     ):
         model_id = req.modelId
         try:
             raw = data_getter(req)
-            # full data deserialized from iotdb is composed of [timestampList, valueList, None, length], we only get valueList currently.
-            full_data = deserializer(raw)
-            # TODO: TSBlock -> Tensor codes should be unified
-            data = full_data[1][0]  # get valueList in ndarray
-            if data.dtype.byteorder not in ("=", "|"):
-                np_data = data.byteswap()
-                data = np_data.view(np_data.dtype.newbyteorder())
-            # the inputs should be on CPU before passing to the inference request
-            inputs = torch.tensor(data).unsqueeze(0).float().to("cpu")
+            inputs = convert_tsblock_to_tensor(raw)
 
             inference_attrs = extract_attrs(req)
             output_length = int(inference_attrs.pop("output_length", 96))
@@ -211,10 +204,10 @@ class InferenceManager:
                     output_length=output_length,
                 )
                 outputs = self._process_request(infer_req)
-                outputs = convert_to_binary(pd.DataFrame(outputs[0]))
             else:
                 model_info = self._model_manager.get_model_info(model_id)
                 inference_pipeline = load_pipeline(model_info, device="cpu")
+                inputs = inference_pipeline.preprocess(inputs)
                 if isinstance(inference_pipeline, ForecastPipeline):
                     outputs = inference_pipeline.forecast(
                         inputs, predict_length=output_length, **inference_attrs
@@ -224,46 +217,49 @@ class InferenceManager:
                 elif isinstance(inference_pipeline, ChatPipeline):
                     outputs = inference_pipeline.chat(inputs)
                 else:
+                    outputs = None
                     logger.error("[Inference] Unsupported pipeline type.")
-                outputs = convert_to_binary(pd.DataFrame(outputs[0]))
+                outputs = inference_pipeline.postprocess(outputs)
 
-            # construct response
-            status = get_status(TSStatusCode.SUCCESS_STATUS)
+            # convert the output tensor of each batch into a tsblock
+            output_list = []
+            for batch_idx in range(outputs.size(0)):
+                output = convert_tensor_to_tsblock(outputs[batch_idx])
+                output_list.append(output)
 
-            if isinstance(outputs, list):
-                return resp_cls(status, outputs[0] if single_output else outputs)
-            return resp_cls(status, outputs if single_output else [outputs])
+            return resp_cls(
+                get_status(TSStatusCode.SUCCESS_STATUS),
+                output_list[0] if single_batch else output_list,
+            )
 
         except Exception as e:
             logger.error(e)
             status = get_status(TSStatusCode.AINODE_INTERNAL_ERROR, str(e))
-            empty = b"" if single_output else []
+            empty = b"" if single_batch else []
             return resp_cls(status, empty)
 
     def forecast(self, req: TForecastReq):
         return self._run(
             req,
             data_getter=lambda r: r.inputData,
-            deserializer=deserialize,
             extract_attrs=lambda r: {
                 "output_length": r.outputLength,
                 **(r.options or {}),
             },
             resp_cls=TForecastResp,
-            single_output=True,
+            single_batch=True,
         )
 
     def inference(self, req: TInferenceReq):
         return self._run(
             req,
             data_getter=lambda r: r.dataset,
-            deserializer=deserialize,
             extract_attrs=lambda r: {
                 "output_length": int(r.inferenceAttributes.pop("outputLength", 
96)),
                 **(r.inferenceAttributes or {}),
             },
             resp_cls=TInferenceResp,
-            single_output=False,
+            single_batch=False,
         )
 
     def stop(self):
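
On the response side, the manager now slices the 3D output along the batch axis and serializes each slice; forecast (single_batch=True) returns the first block, while inference returns the whole list. A sketch with a stand-in serializer (to_tsblock below is hypothetical; the real call is convert_tensor_to_tsblock):

    import torch

    def to_tsblock(t):
        # Hypothetical stand-in for convert_tensor_to_tsblock.
        return f"tsblock{tuple(t.shape)}".encode()

    outputs = torch.zeros(3, 1, 96)  # [batch_size, target_count, predict_length]
    output_list = [to_tsblock(outputs[i]) for i in range(outputs.size(0))]

    single_batch = True  # the forecast() path; inference() uses False
    payload = output_list[0] if single_batch else output_list
    print(payload)  # b'tsblock(1, 96)'
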
diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/model_info.py b/iotdb-core/ainode/iotdb/ainode/core/model/model_info.py
index 718ead530dd..697b3671275 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/model/model_info.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/model/model_info.py
@@ -28,8 +28,6 @@ class ModelInfo:
         category: ModelCategory,
         state: ModelStates,
         model_type: str = "",
-        config_cls: str = "",
-        model_cls: str = "",
         pipeline_cls: str = "",
         repo_id: str = "",
         auto_map: Optional[Dict] = None,
@@ -39,8 +37,6 @@ class ModelInfo:
         self.model_type = model_type
         self.category = category
         self.state = state
-        self.config_cls = config_cls
-        self.model_cls = model_cls
         self.pipeline_cls = pipeline_cls
         self.repo_id = repo_id
         self.auto_map = auto_map  # If exists, indicates it's a Transformers model
@@ -114,19 +110,23 @@ BUILTIN_HF_TRANSFORMERS_MODEL_MAP = {
         category=ModelCategory.BUILTIN,
         state=ModelStates.INACTIVE,
         model_type="timer",
-        config_cls="configuration_timer.TimerConfig",
-        model_cls="modeling_timer.TimerForPrediction",
         pipeline_cls="pipeline_timer.TimerPipeline",
         repo_id="thuml/timer-base-84m",
+        auto_map={
+            "AutoConfig": "configuration_timer.TimerConfig",
+            "AutoModelForCausalLM": "modeling_timer.TimerForPrediction",
+        },
     ),
     "sundial": ModelInfo(
         model_id="sundial",
         category=ModelCategory.BUILTIN,
         state=ModelStates.INACTIVE,
         model_type="sundial",
-        config_cls="configuration_sundial.SundialConfig",
-        model_cls="modeling_sundial.SundialForPrediction",
         pipeline_cls="pipeline_sundial.SundialPipeline",
         repo_id="thuml/sundial-base-128m",
+        auto_map={
+            "AutoConfig": "configuration_sundial.SundialConfig",
+            "AutoModelForCausalLM": "modeling_sundial.SundialForPrediction",
+        },
     ),
 }
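
With config_cls/model_cls folded into auto_map, built-in models now describe their classes with the same keys Transformers uses in config.json auto_map entries. A small illustration of the lookup the loader performs (plain dict handling; the rsplit is one plausible way such a dotted path gets resolved, not necessarily what import_class_from_path does):

    auto_map = {
        "AutoConfig": "configuration_timer.TimerConfig",
        "AutoModelForCausalLM": "modeling_timer.TimerForPrediction",
    }

    # Same .get(...) defaults as load_model_from_transformers in the diff:
    config_str = auto_map.get("AutoConfig", "")
    model_str = auto_map.get("AutoModelForCausalLM", "")

    module_name, class_name = config_str.rsplit(".", 1)
    print(module_name, class_name)  # configuration_timer TimerConfig
    print(bool(model_str and config_str))  # True -> takes the custom-class branch
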
diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/model_loader.py b/iotdb-core/ainode/iotdb/ainode/core/model/model_loader.py
index 29a7c14c972..605620d4261 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/model/model_loader.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/model/model_loader.py
@@ -69,23 +69,22 @@ def load_model_from_transformers(model_info: ModelInfo, **model_kwargs):
         model_info.model_id,
     )
 
+    config_str = model_info.auto_map.get("AutoConfig", "")
+    model_str = model_info.auto_map.get("AutoModelForCausalLM", "")
+
     if model_info.category == ModelCategory.BUILTIN:
         module_name = (
             AINodeDescriptor().get_config().get_ain_models_builtin_dir()
             + "."
             + model_info.model_id
         )
-        config_cls = import_class_from_path(module_name, model_info.config_cls)
-        model_cls = import_class_from_path(module_name, model_info.model_cls)
-    elif model_info.model_cls and model_info.config_cls:
+        config_cls = import_class_from_path(module_name, config_str)
+        model_cls = import_class_from_path(module_name, model_str)
+    elif model_str and config_str:
         module_parent = str(Path(model_path).parent.absolute())
         with temporary_sys_path(module_parent):
-            config_cls = import_class_from_path(
-                model_info.model_id, model_info.config_cls
-            )
-            model_cls = import_class_from_path(
-                model_info.model_id, model_info.model_cls
-            )
+            config_cls = import_class_from_path(model_info.model_id, config_str)
+            model_cls = import_class_from_path(model_info.model_id, model_str)
     else:
         config_cls = AutoConfig.from_pretrained(model_path)
         if type(config_cls) in AutoModelForTimeSeriesPrediction._model_mapping.keys():
diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/sktime/pipeline_sktime.py b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/pipeline_sktime.py
index ced21f29a2b..a10a0a134a2 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/model/sktime/pipeline_sktime.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/pipeline_sktime.py
@@ -20,6 +20,7 @@ import numpy as np
 import pandas as pd
 import torch
 
+from iotdb.ainode.core.exception import InferenceModelInternalException
 from iotdb.ainode.core.inference.pipeline.basic_pipeline import ForecastPipeline
 
 
@@ -28,41 +29,47 @@ class SktimePipeline(ForecastPipeline):
         model_kwargs.pop("device", None)  # sktime models run on CPU
         super().__init__(model_info, model_kwargs=model_kwargs)
 
-    def _preprocess(self, inputs):
+    def preprocess(self, inputs):
+        inputs = super().preprocess(inputs)
+        if inputs.shape[1] != 1:
+            raise InferenceModelInternalException(
+                f"[Inference] Sktime model only supports univarate forecast, 
but receives {inputs.shape[1]} target variables."
+            )
+        inputs = inputs.squeeze(1)
         return inputs
 
     def forecast(self, inputs, **infer_kwargs):
         predict_length = infer_kwargs.get("predict_length", 96)
-        input_ids = self._preprocess(inputs)
 
         # Convert to pandas Series for sktime (sktime expects Series or DataFrame)
         # Handle batch dimension: if batch_size > 1, process each sample separately
-        if len(input_ids.shape) == 2 and input_ids.shape[0] > 1:
+        if len(inputs.shape) == 2 and inputs.shape[0] > 1:
             # Batch processing: convert each row to Series
             outputs = []
-            for i in range(input_ids.shape[0]):
+            for i in range(inputs.shape[0]):
                 series = pd.Series(
-                    input_ids[i].cpu().numpy()
-                    if isinstance(input_ids, torch.Tensor)
-                    else input_ids[i]
+                    inputs[i].cpu().numpy()
+                    if isinstance(inputs, torch.Tensor)
+                    else inputs[i]
                 )
                 output = self.model.generate(series, predict_length=predict_length)
                 outputs.append(output)
-            output = np.array(outputs)
+            outputs = np.array(outputs)
         else:
             # Single sample: convert to Series
-            if isinstance(input_ids, torch.Tensor):
-                series = pd.Series(input_ids.squeeze().cpu().numpy())
+            if isinstance(inputs, torch.Tensor):
+                series = pd.Series(inputs.squeeze().cpu().numpy())
             else:
-                series = pd.Series(input_ids.squeeze())
-            output = self.model.generate(series, predict_length=predict_length)
+                series = pd.Series(inputs.squeeze())
+            outputs = self.model.generate(series, predict_length=predict_length)
             # Add batch dimension if needed
-            if len(output.shape) == 1:
-                output = output[np.newaxis, :]
+            if len(outputs.shape) == 1:
+                outputs = outputs[np.newaxis, :]
 
-        return self._postprocess(output)
+        return outputs
 
-    def _postprocess(self, output):
-        if isinstance(output, np.ndarray):
-            return torch.from_numpy(output).float()
-        return output
+    def postprocess(self, outputs):
+        if isinstance(outputs, np.ndarray):
+            outputs = torch.from_numpy(outputs).float()
+        outputs = super().postprocess(outputs.unsqueeze(1))
+        return outputs
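
sktime forecasters consume pandas Series, so the batched path converts each row, forecasts, then stacks the results back into one array before returning to torch. A runnable sketch of that loop with a trivial forecaster standing in for self.model.generate (naive_forecast is invented):

    import numpy as np
    import pandas as pd
    import torch

    def naive_forecast(series, predict_length):
        # Invented stand-in: repeat the last observed value.
        return np.full(predict_length, series.iloc[-1])

    inputs = torch.arange(8.0).reshape(2, 4)  # [batch_size, sequence_length] after squeeze(1)
    outputs = []
    for i in range(inputs.shape[0]):
        series = pd.Series(inputs[i].cpu().numpy())  # row -> Series, as in the diff
        outputs.append(naive_forecast(series, predict_length=3))
    outputs = np.array(outputs)  # stack rows back: shape (2, 3)

    # postprocess: numpy -> torch, then restore the target axis.
    result = torch.from_numpy(outputs).float().unsqueeze(1)
    print(result.shape)  # torch.Size([2, 1, 3])
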
diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/sundial/pipeline_sundial.py b/iotdb-core/ainode/iotdb/ainode/core/model/sundial/pipeline_sundial.py
index ee128802d24..69422dfadb2 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/model/sundial/pipeline_sundial.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/model/sundial/pipeline_sundial.py
@@ -26,11 +26,17 @@ class SundialPipeline(ForecastPipeline):
     def __init__(self, model_info, **model_kwargs):
         super().__init__(model_info, model_kwargs=model_kwargs)
 
-    def _preprocess(self, inputs):
-        if len(inputs.shape) != 2:
+    def preprocess(self, inputs):
+        """
+        The input shape should be 3D, but Sundial only supports a 2D tensor: [batch_size, sequence_length],
+        so we squeeze the target_count dimension.
+        """
+        inputs = super().preprocess(inputs)
+        if inputs.shape[1] != 1:
             raise InferenceModelInternalException(
-                f"[Inference] Input shape must be: [batch_size, seq_len], but 
receives {inputs.shape}"
+                f"[Inference] Model sundial only supports univarate forecast, 
but receives {inputs.shape[1]} target variables."
             )
+        inputs = inputs.squeeze(1)
         return inputs
 
     def forecast(self, inputs, **infer_kwargs):
@@ -38,14 +44,18 @@ class SundialPipeline(ForecastPipeline):
         num_samples = infer_kwargs.get("num_samples", 10)
         revin = infer_kwargs.get("revin", True)
 
-        input_ids = self._preprocess(inputs)
-        output = self.model.generate(
-            input_ids,
+        outputs = self.model.generate(
+            inputs,
             max_new_tokens=predict_length,
             num_samples=num_samples,
             revin=revin,
         )
-        return self._postprocess(output)
-
-    def _postprocess(self, output: torch.Tensor):
-        return output.mean(dim=1)
+        return outputs
+
+    def postprocess(self, outputs: torch.Tensor):
+        """
+        The outputs should be 3D, so we take the mean across the num_samples dimension and expand dims.
+        """
+        outputs = outputs.mean(dim=1).unsqueeze(1)
+        outputs = super().postprocess(outputs)
+        return outputs
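
Sundial's generate emits num_samples candidate trajectories per series, so postprocess averages that dimension before restoring the target axis. The reduction in isolation (shapes assumed from the diff, random data in place of real model output):

    import torch

    batch_size, num_samples, predict_length = 2, 10, 96
    raw = torch.randn(batch_size, num_samples, predict_length)  # per-sample forecasts

    outputs = raw.mean(dim=1)       # average over num_samples -> [2, 96]
    outputs = outputs.unsqueeze(1)  # restore target axis -> [2, 1, 96]
    print(outputs.shape)            # torch.Size([2, 1, 96])
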
diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/timer_xl/pipeline_timer.py b/iotdb-core/ainode/iotdb/ainode/core/model/timer_xl/pipeline_timer.py
index 65c6cdd74cd..041cb8c7cba 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/model/timer_xl/pipeline_timer.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/model/timer_xl/pipeline_timer.py
@@ -26,22 +26,31 @@ class TimerPipeline(ForecastPipeline):
     def __init__(self, model_info, **model_kwargs):
         super().__init__(model_info, model_kwargs=model_kwargs)
 
-    def _preprocess(self, inputs):
-        if len(inputs.shape) != 2:
+    def preprocess(self, inputs):
+        """
+        The input shape should be 3D, but Timer-XL only supports a 2D tensor: [batch_size, sequence_length],
+        so we squeeze the target_count dimension.
+        """
+        inputs = super().preprocess(inputs)
+        if inputs.shape[1] != 1:
             raise InferenceModelInternalException(
-                f"[Inference] Input shape must be: [batch_size, seq_len], but 
receives {inputs.shape}"
+                f"[Inference] Model timer_xl only supports univarate forecast, 
but receives {inputs.shape[1]} target variables."
             )
+        inputs = inputs.squeeze(1)
         return inputs
 
     def forecast(self, inputs, **infer_kwargs):
         predict_length = infer_kwargs.get("predict_length", 96)
         revin = infer_kwargs.get("revin", True)
 
-        input_ids = self._preprocess(inputs)
-        output = self.model.generate(
-            input_ids, max_new_tokens=predict_length, revin=revin
+        outputs = self.model.generate(
+            inputs, max_new_tokens=predict_length, revin=revin
         )
-        return self._postprocess(output)
-
-    def _postprocess(self, output: torch.Tensor):
-        return output
+        return outputs
+
+    def postprocess(self, outputs: torch.Tensor):
+        """
+        The outputs should be 3D, so we expand dims.
+        """
+        outputs = super().postprocess(outputs.unsqueeze(1))
+        return outputs
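
Timer-XL, Sundial, and sktime all follow the same univariate adapter: reject target_count != 1, squeeze that axis going in, unsqueeze it coming out. The round trip in a few lines (pure tensor reshaping; no model involved):

    import torch

    inputs = torch.randn(4, 1, 96)   # [batch_size, target_count=1, sequence_length]
    assert inputs.shape[1] == 1, "univariate pipelines reject target_count > 1"
    x = inputs.squeeze(1)            # -> [4, 96], the 2D shape the model expects
    y = torch.zeros(x.shape[0], 32)  # pretend model output: [batch, predict_length]
    outputs = y.unsqueeze(1)         # -> [4, 1, 32], back to the 3D contract
    print(outputs.shape)
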
diff --git a/iotdb-core/ainode/iotdb/ainode/core/util/serde.py b/iotdb-core/ainode/iotdb/ainode/core/util/serde.py
index 9c6020019fc..a61032ba26f 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/util/serde.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/util/serde.py
@@ -20,8 +20,10 @@ from enum import Enum
 
 import numpy as np
 import pandas as pd
+import torch
 
 from iotdb.ainode.core.exception import BadConfigValueException
+from iotdb.tsfile.utils.tsblock_serde import deserialize
 
 
 class TSDataType(Enum):
@@ -55,9 +57,25 @@ TIMESTAMP_STR = "Time"
 START_INDEX = 2
 
 
-# convert dataFrame to tsBlock in binary
-# input shouldn't contain time column
-def convert_to_binary(data_frame: pd.DataFrame):
+# Full data deserialized from an iotdb tsblock is composed of [timestampList, multiple valueLists, None, length].
+# We only use the valueLists currently.
+def convert_tsblock_to_tensor(tsblock_data: bytes):
+    full_data = deserialize(tsblock_data)
+    # ensure the byteorder is correct.
+    for i, data in enumerate(full_data[1]):
+        if data.dtype.byteorder not in ("=", "|"):
+            np_data = data.byteswap()
+            full_data[1][i] = np_data.view(np_data.dtype.newbyteorder())
+    # the size should be [batch_size, target_count, sequence_length]
+    tensor_data = torch.from_numpy(np.stack(full_data[1], axis=0)).unsqueeze(0).float()
+    # data should be on CPU before passing to the inference request
+    return tensor_data.to("cpu")
+
+
+# Convert a DataFrame to a TsBlock in binary; the input shouldn't contain a time column
+# and may contain multiple value columns.
+def convert_tensor_to_tsblock(data_tensor: torch.Tensor):
+    data_frame = pd.DataFrame(data_tensor).T
     data_shape = data_frame.shape
     value_column_size = data_shape[1]
     position_count = data_shape[0]
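
The byte-order guard in convert_tsblock_to_tensor exists because torch.from_numpy rejects non-native-endian arrays. A self-contained sketch with a hand-built big-endian column in place of real deserialized TsBlock data (no IoTDB dependency; ">f4" is non-native on little-endian hosts):

    import numpy as np
    import torch

    # Pretend these are two deserialized value columns, one big-endian.
    cols = [np.array([1.0, 2.0, 3.0], dtype=">f4"),
            np.array([4.0, 5.0, 6.0], dtype=np.float32)]

    for i, data in enumerate(cols):
        if data.dtype.byteorder not in ("=", "|"):  # same guard as the diff
            swapped = data.byteswap()  # swap raw bytes, then reinterpret the dtype
            cols[i] = swapped.view(swapped.dtype.newbyteorder())  # values unchanged

    # Stack columns into [target_count, sequence_length], then add the batch axis.
    tensor = torch.from_numpy(np.stack(cols, axis=0)).unsqueeze(0).float()
    print(tensor.shape)  # torch.Size([1, 2, 3])
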
