This is an automated email from the ASF dual-hosted git repository.

jackietien pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/iotdb.git


The following commit(s) were added to refs/heads/master by this push:
     new 78689179359 [AINode] Support built-in inference of table/tree model for time_xl
78689179359 is described below

commit 78689179359d6064d5d6568b4baf147fce482604
Author: YangCaiyin <[email protected]>
AuthorDate: Tue May 13 19:35:43 2025 +0800

    [AINode] Support built-in inference of table/tree model for time_xl
---
 .../relational/it/schema/IoTDBDatabaseIT.java      |   6 +-
 .../ainode/ainode/TimerXL/models/timer_xl.py       |  20 +-
 iotdb-core/ainode/ainode/core/handler.py           |   5 +-
 .../ainode/core/manager/inference_manager.py       | 329 ++++++----------
 .../ainode/core/model/built_in_model_factory.py    |  22 +-
 iotdb-core/ainode/ainode/core/util/serde.py        | 413 ---------------------
 .../iotdb/confignode/persistence/ModelInfo.java    |   1 +
 .../operator/process/ai/InferenceOperator.java     |  34 +-
 .../iotdb/commons/client/ainode/AINodeClient.java  |  16 +-
 .../thrift-ainode/src/main/thrift/ainode.thrift    |   7 +-
 10 files changed, 156 insertions(+), 697 deletions(-)

diff --git 
a/integration-test/src/test/java/org/apache/iotdb/relational/it/schema/IoTDBDatabaseIT.java
 
b/integration-test/src/test/java/org/apache/iotdb/relational/it/schema/IoTDBDatabaseIT.java
index 34f1027b8b8..b19f1a1d869 100644
--- 
a/integration-test/src/test/java/org/apache/iotdb/relational/it/schema/IoTDBDatabaseIT.java
+++ 
b/integration-test/src/test/java/org/apache/iotdb/relational/it/schema/IoTDBDatabaseIT.java
@@ -620,7 +620,11 @@ public class IoTDBDatabaseIT {
           "model_id,",
           new HashSet<>(
               Arrays.asList(
-                  "_STLForecaster,", "_NaiveForecaster,", "_ARIMA,", 
"_ExponentialSmoothing,")));
+                  "_timerxl,",
+                  "_STLForecaster,",
+                  "_NaiveForecaster,",
+                  "_ARIMA,",
+                  "_ExponentialSmoothing,")));
 
       TestUtils.assertResultSetEqual(
           statement.executeQuery(
diff --git a/iotdb-core/ainode/ainode/TimerXL/models/timer_xl.py 
b/iotdb-core/ainode/ainode/TimerXL/models/timer_xl.py
index 1945fa25a2e..4e4d8588fd2 100644
--- a/iotdb-core/ainode/ainode/TimerXL/models/timer_xl.py
+++ b/iotdb-core/ainode/ainode/TimerXL/models/timer_xl.py
@@ -27,9 +27,12 @@ from ainode.TimerXL.models.configuration_timer import 
TimerxlConfig
 from ainode.core.util.masking import prepare_4d_causal_attention_mask
 from ainode.core.util.huggingface_cache import Cache, DynamicCache
 
-import safetensors
+from safetensors.torch import load_file as load_safetensors
 from huggingface_hub import hf_hub_download
 
+from ainode.core.log import Logger
+logger = Logger()
+
 @dataclass
 class Output:
     outputs: torch.Tensor
@@ -211,12 +214,15 @@ class Model(nn.Module):
                 state_dict = torch.load(config.ckpt_path)
             elif config.ckpt_path.endswith('.safetensors'):
                 if not os.path.exists(config.ckpt_path):
-                    print(f"[INFO] Checkpoint not found at {config.ckpt_path}, 
downloading from HuggingFace...")
+                    logger.info(f"Checkpoint not found at {config.ckpt_path}, 
downloading from HuggingFace...")
                     repo_id = "thuml/timer-base-84m"
-                    filename = os.path.basename(config.ckpt_path)  # eg: 
model.safetensors
-                    config.ckpt_path = hf_hub_download(repo_id=repo_id, 
filename=filename)
-                    print(f"[INFO] Downloaded checkpoint to 
{config.ckpt_path}")
-                state_dict = safetensors.torch.load_file(config.ckpt_path)
+                    try:
+                        config.ckpt_path = hf_hub_download(repo_id=repo_id, 
filename=os.path.basename(config.ckpt_path), 
local_dir=os.path.dirname(config.ckpt_path))
+                        logger.info(f"Got checkpoint to {config.ckpt_path}")
+                    except Exception as e:
+                        logger.error(f"Failed to download checkpoint to 
{config.ckpt_path} due to {e}")
+                        raise e
+                state_dict = load_safetensors(config.ckpt_path)
             else:
                 raise ValueError('unsupported model weight type')
             # If there is no key beginning with 'model.model' in state_dict, 
add a 'model.' before all keys. (The model code here has an additional layer of 
encapsulation compared to the code on huggingface.)
@@ -234,7 +240,7 @@ class Model(nn.Module):
         # change [L, C=1] to [batchsize=1, L]
         self.device = next(self.model.parameters()).device
         
-        x = torch.tensor(x.values, dtype=next(self.model.parameters()).dtype, 
device=self.device)
+        x = torch.tensor(x, dtype=next(self.model.parameters()).dtype, 
device=self.device)
         x = x.view(1, -1)
 
         preds = self.forward(x, max_new_tokens)
diff --git a/iotdb-core/ainode/ainode/core/handler.py 
b/iotdb-core/ainode/ainode/core/handler.py
index 0405d773f69..fc8f8f1aae7 100644
--- a/iotdb-core/ainode/ainode/core/handler.py
+++ b/iotdb-core/ainode/ainode/core/handler.py
@@ -29,6 +29,7 @@ from ainode.thrift.common.ttypes import TSStatus
 class AINodeRPCServiceHandler(IAINodeRPCService.Iface):
     def __init__(self):
         self._model_manager = ModelManager()
+        self._inference_manager = 
InferenceManager(model_manager=self._model_manager)
 
     def registerModel(self, req: TRegisterModelReq) -> TRegisterModelResp:
         return self._model_manager.register_model(req)
@@ -37,10 +38,10 @@ class AINodeRPCServiceHandler(IAINodeRPCService.Iface):
         return self._model_manager.delete_model(req)
 
     def inference(self, req: TInferenceReq) -> TInferenceResp:
-        return InferenceManager.inference(req, self._model_manager)
+        return self._inference_manager.inference(req)
 
     def forecast(self, req: TForecastReq) -> TSStatus:
-        return InferenceManager.forecast(req, self._model_manager)
+        return self._inference_manager.forecast(req)
 
     def getAIHeartbeat(self, req: TAIHeartbeatReq) -> TAIHeartbeatResp:
         return ClusterManager.get_heart_beat(req)
diff --git a/iotdb-core/ainode/ainode/core/manager/inference_manager.py 
b/iotdb-core/ainode/ainode/core/manager/inference_manager.py
index d16e521cde6..c9815745516 100644
--- a/iotdb-core/ainode/ainode/core/manager/inference_manager.py
+++ b/iotdb-core/ainode/ainode/core/manager/inference_manager.py
@@ -15,6 +15,8 @@
 # specific language governing permissions and limitations
 # under the License.
 #
+from abc import abstractmethod, ABC
+
 import pandas as pd
 import torch
 from iotdb.tsfile.utils.tsblock_serde import deserialize
@@ -23,246 +25,139 @@ from ainode.core.constant import TSStatusCode
 from ainode.core.exception import InvalidWindowArgumentError, 
InferenceModelInternalError, runtime_error_extractor
 from ainode.core.log import Logger
 from ainode.core.manager.model_manager import ModelManager
-from ainode.core.util.serde import convert_to_binary, convert_to_df
+from ainode.core.util.serde import convert_to_binary
 from ainode.core.util.status import get_status
 from ainode.thrift.ainode.ttypes import TInferenceReq, TInferenceResp, 
TForecastReq, TForecastResp
 
 logger = Logger()
 
-def _process_data(full_data):
-    """
-    Args:
-        full_data: a tuple of (data, time_stamp, type_list, column_name_list), 
where the data is a DataFrame with shape
-            (L, C), time_stamp is a DataFrame with shape(L, 1), type_list is a 
list of data types with length C,
-            column_name_list is a list of column names with length C, where L 
is the number of data points, C is the
-            number of variables, the data and time_stamp are aligned by index
-    Returns:
-        data: a tensor with shape (1, L, C)
-        data_length: the number of data points
-    Description:
-        the process_data module will convert the input data into a tensor with 
shape (1, L, C), where L is the number of
-        data points, C is the number of variables, the data and time_stamp are 
aligned by index. The module will also
-        convert the data type of each column to the corresponding type.
-    """
-    data, time_stamp, type_list, _ = full_data
-    data_length = time_stamp.shape[0]
-    data = data.fillna(0)
-    for i in range(len(type_list)):
-        if type_list[i] == "TEXT":
-            data[data.columns[i]] = 0
-        elif type_list[i] == "BOOLEAN":
-            data[data.columns[i]] = data[data.columns[i]].astype("int")
-    data = torch.tensor(data.values).unsqueeze(0)
-    return data, data_length
 
+class InferenceStrategy(ABC):
+    def __init__(self, model):
+        self.model = model
 
-class InferenceManager:
+    @abstractmethod
+    def infer(self, full_data, **kwargs):
+        pass
 
-    @staticmethod
-    def forecast(req: TForecastReq, model_manager:ModelManager):
-        model_id = req.modelId
-        logger.info(f"start to forcast by model {model_id}")
-        try:
-            data = deserialize(req.inputData)
-            if model_id.startswith('_'):
-                # built-in models
-                logger.info(f"start to forecast built-in model {model_id}")
-                # parse the inference attributes and create the built-in model
-                options = req.options
-                options['predict_length'] = req.outputLength
-                model = _get_built_in_model(model_id, model_manager, options)
-                inference_result = 
convert_to_binary(_inference_with_built_in_model(
-                    model, data))
-            else:
-                # user-registered models
-                model = _get_model(model_id, model_manager, req.options)
-                _, dataset, _, dataset_length = data
-                dataset = torch.tensor(dataset, dtype=torch.float).unsqueeze(2)
-                inference_results = _inference_with_registered_model(
-                    model, dataset, dataset_length, dataset_length, 
float('inf'))
-                inference_result = convert_to_binary(inference_results[0])
-            return TForecastResp(
-                get_status(TSStatusCode.SUCCESS_STATUS),
-                inference_result
-            )
-        except Exception as e:
-            logger.warning(e)
-            inference_results = []
-            return 
TInferenceResp(get_status(TSStatusCode.AINODE_INTERNAL_ERROR, str(e)), 
inference_results)
 
-    @staticmethod
-    def inference(req: TInferenceReq, model_manager: ModelManager):
-        logger.info(f"start inference registered model {req.modelId}")
-        try:
-            model_id, full_data, window_interval, window_step, 
inference_attributes = _parse_inference_request(req)
+# [IoTDB] full data deserialized from iotdb is composed of [timestampList, 
valueList, length],
+# we only get valueList currently.
+class TimerXLStrategy(InferenceStrategy):
+    def infer(self, full_data, predict_length=96, **_):
+        data = full_data[1][0]
+        if data.dtype.byteorder not in ('=', '|'):
+            data = data.byteswap().newbyteorder()
+        output = self.model.inference(data, int(predict_length))
+        df = pd.DataFrame(output[0])
+        return convert_to_binary(df)
 
-            if model_id.startswith('_'):
-                # built-in models
-                logger.info(f"start inference built-in model {model_id}")
-                # parse the inference attributes and create the built-in model
-                model = _get_built_in_model(model_id, model_manager, 
inference_attributes)
-                if model_id == '_timerxl':
-                    inference_results = [_inference_with_timerxl(
-                        model, full_data, 
inference_attributes.get("predict_length", 96))]
-                else:
-                    inference_results = [_inference_with_built_in_model(
-                        model, full_data)]
-            else:
-                # user-registered models
-                model = _get_model(model_id, model_manager, 
inference_attributes)
-                dataset, dataset_length = _process_data(full_data)
-                inference_results = _inference_with_registered_model(
-                    model, dataset, dataset_length, window_interval, 
window_step)
-            for i in range(len(inference_results)):
-                inference_results[i] = convert_to_binary(inference_results[i])
-            return TInferenceResp(
-                get_status(
-                    TSStatusCode.SUCCESS_STATUS),
-                inference_results)
-        except Exception as e:
-            logger.warning(e)
-            inference_results = []
-            return 
TInferenceResp(get_status(TSStatusCode.AINODE_INTERNAL_ERROR, str(e)), 
inference_results)
 
-def _inference_with_registered_model(model, dataset, dataset_length, 
window_interval, window_step):
-    """
-    Args:
-        model: the user-defined model
-        full_data: a tuple of (data, time_stamp, type_list, column_name_list), 
where the data is a DataFrame with shape
-            (L, C), time_stamp is a DataFrame with shape(L, 1), type_list is a 
list of data types with length C,
-            column_name_list is a list of column names with length C, where L 
is the number of data points, C is the
-            number of variables, the data and time_stamp are aligned by index
-        window_interval: the length of each sliding window
-        window_step: the step between two adjacent sliding windows
-    Returns:
-        outputs: a list of output DataFrames, where each DataFrame has shape 
(H', C'), where H' is the output window
-            interval, C' is the number of variables in the output DataFrame
-    Description:
-        the inference_with_registered_model function will inference with deep 
learning model, which is registered in
-        user register process. This module will split the input data into 
several sliding windows which has the same
-        shape (1, H, C), where H is the window interval, and then feed each 
sliding window into the model to get the
-        output, the output is a DataFrame with shape (H', C'), where H' is the 
output window interval, C' is the number
-        of variables in the output DataFrame. Then the inference module will 
concatenate all the output DataFrames into
-        a list.
-    """
+class BuiltInStrategy(InferenceStrategy):
+    def infer(self, full_data, **_):
+        data = full_data[1][0]
+        if data.dtype.byteorder not in ('=', '|'):
+            data = data.byteswap().newbyteorder()
+        output = self.model.inference(data)
+        df = pd.DataFrame(output)
+        return convert_to_binary(df)
 
-    # check the validity of window_interval and window_step, the two arguments 
must be positive integers, and the
-    # window_interval should not be larger than the dataset length
-    if window_interval is None or window_step is None \
-            or window_interval > dataset_length \
-            or window_interval <= 0 or \
-            window_step <= 0:
-        raise InvalidWindowArgumentError(window_interval, window_step, 
dataset_length)
 
-    sliding_times = int((dataset_length - window_interval) // window_step + 1)
-    outputs = []
-    try:
-        # split the input data into several sliding windows
-        for sliding_time in range(sliding_times):
-            if window_step == float('inf'):
-                start_index = 0
-            else:
-                start_index = sliding_time * window_step
-            end_index = start_index + window_interval
-            # input_data: tensor, shape: (1, H, C), where H is input window 
interval
-            input_data = dataset[:, start_index:end_index, :]
-            # output: tensor, shape: (1, H', C'), where H' is the output 
window interval
-            output = model(input_data)
-            # output: DataFrame, shape: (H', C')
-            output = pd.DataFrame(output.squeeze(0).detach().numpy())
-            outputs.append(output)
-    except Exception as e:
-        error_msg = runtime_error_extractor(str(e))
-        if error_msg != "":
-            raise InferenceModelInternalError(error_msg)
-        raise InferenceModelInternalError(str(e))
+class RegisteredStrategy(InferenceStrategy):
+    def infer(self, full_data, window_interval=None, window_step=None, 
**kwargs):
+        _, dataset, _, length = full_data
+        if window_interval is None or window_step is None:
+            window_interval = length
+            window_step = float('inf')
 
-    return outputs
+        if window_interval <= 0 or window_step <= 0 or window_interval > 
length:
+            raise InvalidWindowArgumentError(window_interval, window_step, 
length)
 
+        data = torch.tensor(dataset, 
dtype=torch.float32).unsqueeze(0).permute(0, 2, 1)
 
-def _inference_with_built_in_model(model, full_data):
-    """
-    Args:
-        model: the built-in model
-        full_data: a tuple of (data, time_stamp, type_list, column_name_list), 
where the data is a DataFrame with shape
-            (L, C), time_stamp is a DataFrame with shape(L, 1), type_list is a 
list of data types with length C,
-            column_name_list is a list of column names with length C, where L 
is the number of data points, C is the
-            number of variables, the data and time_stamp are aligned by index
-    Returns:
-        outputs: a list of output DataFrames, where each DataFrame has shape 
(H', C'), where H' is the output window
-            interval, C' is the number of variables in the output DataFrame
-    Description:
-        the inference_with_built_in_model function will inference with 
built-in model, which does not
-        require user registration. This module will parse the inference 
attributes and create the built-in model, then
-        feed the input data into the model to get the output, the output is a 
DataFrame with shape (H', C'), where H'
-        is the output window interval, C' is the number of variables in the 
output DataFrame. Then the inference module
-        will concatenate all the output DataFrames into a list.
-    """
+        times = int((length - window_interval) // window_step + 1)
+        results = []
+        try:
+            for i in range(times):
+                start = 0 if window_step == float('inf') else i * window_step
+                end = start + window_interval
+                window = data[:, start:end, :]
+                out = self.model(window)
+                df = pd.DataFrame(out.squeeze(0).detach().numpy())
+                results.append(df)
+        except Exception as e:
+            msg = runtime_error_extractor(str(e)) or str(e)
+            raise InferenceModelInternalError(msg)
 
-    data, _, _, _ = full_data
-    output = model.inference(data)
-    # output: DataFrame, shape: (H', C')
-    output = pd.DataFrame(output)
-    return output
+        # concatenate or return first window for forecast
+        return [convert_to_binary(df) for df in results]
 
-def _inference_with_timerxl(model, full_data, pred_len):
-    """
-    Args:
-        model: the built-in model
-        full_data: a tuple of (data, time_stamp, type_list, column_name_list), 
where the data is a DataFrame with shape
-            (L, C), time_stamp is a DataFrame with shape(L, 1), type_list is a 
list of data types with length C,
-            column_name_list is a list of column names with length C, where L 
is the number of data points, C is the
-            number of variables, the data and time_stamp are aligned by index
-    Returns:
-        outputs: a list of output DataFrames, where each DataFrame has shape 
(H', C'), where H' is the output window
-            interval, C' is the number of variables in the output DataFrame
-    Description:
-        the inference_with_built_in_model function will inference with 
built-in model, which does not
-        require user registration. This module will parse the inference 
attributes and create the built-in model, then
-        feed the input data into the model to get the output, the output is a 
DataFrame with shape (H', C'), where H'
-        is the output window interval, C' is the number of variables in the 
output DataFrame. Then the inference module
-        will concatenate all the output DataFrames into a list.
-    """
 
-    data, _, _, _ = full_data
-    output = model.inference(data, pred_len)
-    # output: DataFrame, shape: (H', C')
-    output = pd.DataFrame(output)
-    return output
+def _get_strategy(model_id, model):
+    if model_id == '_timerxl':
+        return TimerXLStrategy(model)
+    if model_id.startswith('_'):
+        return BuiltInStrategy(model)
+    return RegisteredStrategy(model)
 
 
-def _get_model(model_id: str, model_manager: ModelManager, 
inference_attributes: {}):
-    if inference_attributes is None or 'acceleration' not in 
inference_attributes:
-        # if the acceleration is not specified, then the acceleration will be 
set to default value False
-        acceleration = False
-    else:
-        # if the acceleration is specified, then the acceleration will be set 
to the specified value
-        acceleration = (inference_attributes['acceleration'].lower() == 'true')
-    return model_manager.load_model(model_id, acceleration)
+class InferenceManager:
 
+    def __init__(self, model_manager: ModelManager):
+        self.model_manager = model_manager
 
-def _get_built_in_model(model_id: str, model_manager: ModelManager, 
inference_attributes: {}):
-    return model_manager.load_built_in_model(model_id, inference_attributes)
+    def _run(self, req, data_getter, deserializer, extract_attrs, resp_cls, 
single_output: bool):
+        model_id = req.modelId
+        logger.info(f"Start processing for {model_id}")
+        try:
+            raw = data_getter(req)
+            full_data = deserializer(raw)
+            attrs = extract_attrs(req)
 
+            # load model
+            if model_id.startswith('_'):
+                model = self.model_manager.load_built_in_model(model_id, attrs)
+            else:
+                accel = str(attrs.get('acceleration', '')).lower() == 'true'
+                model = self.model_manager.load_model(model_id, accel)
+
+            # inference by strategy
+            strategy = _get_strategy(model_id, model)
+            outputs = strategy.infer(full_data, **attrs)
 
-def _parse_inference_request(req: TInferenceReq):
-    binary_dataset = req.dataset
-    type_list = req.typeList
-    column_name_list = req.columnNameList
-    column_name_index = req.columnNameIndexMap
-    data = convert_to_df(column_name_list, type_list, column_name_index, 
[binary_dataset])
-    time_stamp, data = data[data.columns[0:1]], data[data.columns[1:]]
-    full_data = (data, time_stamp, type_list, column_name_list)
-    inference_attributes = req.inferenceAttributes
-    if inference_attributes is None:
-        inference_attributes = {}
+            # construct response
+            status = get_status(TSStatusCode.SUCCESS_STATUS)
 
-    window_params = req.windowParams
-    if window_params is None:
-        # set default window_step to infinity and window_interval to dataset 
length
-        window_step = float('inf')
-        window_interval = data.shape[0]
-    else:
-        window_step = window_params.windowStep
-        window_interval = window_params.windowInterval
-    return req.modelId, full_data, window_interval, window_step, 
inference_attributes
+            if isinstance(outputs, list):
+                return resp_cls(status, outputs[0] if single_output else 
outputs)
+            return resp_cls(status, outputs if single_output else [outputs])
+
+        except Exception as e:
+            logger.error(e)
+            status = get_status(TSStatusCode.AINODE_INTERNAL_ERROR, str(e))
+            empty = b'' if single_output else []
+            return resp_cls(status, empty)
+
+    def forecast(self, req: TForecastReq):
+        return self._run(
+            req,
+            data_getter=lambda r: r.inputData,
+            deserializer=deserialize,
+            extract_attrs=lambda r: {'predict_length': r.outputLength, 
**(r.options or {})},
+            resp_cls=TForecastResp,
+            single_output=True
+        )
+
+    def inference(self, req: TInferenceReq):
+        return self._run(
+            req,
+            data_getter=lambda r: r.dataset,
+            deserializer=deserialize,
+            extract_attrs=lambda r: {
+                'window_interval': getattr(r.windowParams, 'windowInterval', 
None),
+                'window_step': getattr(r.windowParams, 'windowStep', None),
+                **(r.inferenceAttributes or {})
+            },
+            resp_cls=TInferenceResp,
+            single_output=False
+        )
diff --git a/iotdb-core/ainode/ainode/core/model/built_in_model_factory.py 
b/iotdb-core/ainode/ainode/core/model/built_in_model_factory.py
index dd8c9ee5308..0d2991a7f9f 100644
--- a/iotdb-core/ainode/ainode/core/model/built_in_model_factory.py
+++ b/iotdb-core/ainode/ainode/core/model/built_in_model_factory.py
@@ -15,9 +15,9 @@
 # specific language governing permissions and limitations
 # under the License.
 #
+import os
 from abc import abstractmethod
 from typing import List, Dict
-import os
 
 import numpy as np
 from sklearn.preprocessing import MinMaxScaler
@@ -28,15 +28,15 @@ from sktime.forecasting.exp_smoothing import 
ExponentialSmoothing
 from sktime.forecasting.naive import NaiveForecaster
 from sktime.forecasting.trend import STLForecaster
 
+from ainode.TimerXL.models import timer_xl
+from ainode.TimerXL.models.configuration_timer import TimerxlConfig
+from ainode.core.config import AINodeDescriptor
 from ainode.core.constant import AttributeName, BuiltInModelType
-from ainode.core.exception import InferenceModelInternalError, 
AttributeNotSupportError
+from ainode.core.exception import InferenceModelInternalError
 from ainode.core.exception import WrongAttributeTypeError, 
NumericalRangeException, StringRangeException, \
     ListRangeException, BuiltInModelNotSupportError
 from ainode.core.log import Logger
 
-from ainode.TimerXL.models import timer_xl
-from ainode.TimerXL.models.configuration_timer import TimerxlConfig
-
 logger = Logger()
 
 
@@ -79,11 +79,6 @@ def fetch_built_in_model(model_id, inference_attributes):
     """
     attribute_map = get_model_attributes(model_id)
 
-    # validate the inference attributes
-    for attribute_name in inference_attributes:
-        if attribute_name not in attribute_map:
-            raise AttributeNotSupportError(model_id, attribute_name)
-
     # parse the inference attributes, attributes is a Dict[str, Any]
     attributes = parse_attribute(inference_attributes, attribute_map)
 
@@ -398,9 +393,10 @@ timerxl_attribute_map = {
     ),
     AttributeName.TIMERXL_CKPT_PATH.value: StringAttribute(
         name=AttributeName.TIMERXL_CKPT_PATH.value,
-        default_value=os.path.join(os.path.dirname(os.path.abspath(__file__)), 
'weights', 'timerxl', 'model.safetensors'),
-        
value_choices=[os.path.join(os.path.dirname(os.path.abspath(__file__)), 
'weights', 'timerxl', 'model.safetensors'), ""],
-    ),
+        default_value=os.path.join(os.getcwd(), 
AINodeDescriptor().get_config().get_ain_models_dir(), 'weights',
+                                   'timerxl', 'model.safetensors'),
+        value_choices=['']
+    )
 }
 
 # built-in sktime model attributes
diff --git a/iotdb-core/ainode/ainode/core/util/serde.py 
b/iotdb-core/ainode/ainode/core/util/serde.py
index b9edccfd03e..70b86d66095 100644
--- a/iotdb-core/ainode/ainode/core/util/serde.py
+++ b/iotdb-core/ainode/ainode/core/util/serde.py
@@ -97,164 +97,6 @@ def convert_to_binary(data_frame: pd.DataFrame):
     return binary
 
 
-# convert tsBlock in binary to dataFrame
-def convert_to_df(name_list, type_list, name_index, binary_list):
-    column_name_list = [TIMESTAMP_STR]
-    column_type_list = [TSDataType.INT64]
-    column_ordinal_dict = {TIMESTAMP_STR: 1}
-
-    if name_index is not None:
-        column_type_deduplicated_list = [
-            None for _ in range(len(name_index))
-        ]
-        for i in range(len(name_list)):
-            name = name_list[i]
-            column_name_list.append(name)
-            column_type_list.append(TSDataType[type_list[i]])
-            if name not in column_ordinal_dict:
-                index = name_index[name]
-                column_ordinal_dict[name] = index + START_INDEX
-                column_type_deduplicated_list[index] = TSDataType[type_list[i]]
-    else:
-        index = START_INDEX
-        column_type_deduplicated_list = []
-        for i in range(len(name_list)):
-            name = name_list[i]
-            column_name_list.append(name)
-            column_type_list.append(TSDataType[type_list[i]])
-            if name not in column_ordinal_dict:
-                column_ordinal_dict[name] = index
-                index += 1
-                column_type_deduplicated_list.append(
-                    TSDataType[type_list[i]]
-                )
-
-    binary_size = len(binary_list)
-    binary_index = 0
-    result = {}
-    for column_name in column_name_list:
-        result[column_name] = None
-
-    while binary_index < binary_size:
-        buffer = binary_list[binary_index]
-        binary_index += 1
-        time_column_values, column_values, null_indicators, _ = 
deserialize(buffer)
-        time_array = np.frombuffer(
-            time_column_values, np.dtype(np.longlong).newbyteorder(">")
-        )
-        if time_array.dtype.byteorder == ">":
-            time_array = 
time_array.byteswap().view(time_array.dtype.newbyteorder("<"))
-
-        if result[TIMESTAMP_STR] is None:
-            result[TIMESTAMP_STR] = time_array
-        else:
-            result[TIMESTAMP_STR] = np.concatenate(
-                (result[TIMESTAMP_STR], time_array), axis=0
-            )
-        total_length = len(time_array)
-
-        for i in range(len(column_values)):
-            column_name = column_name_list[i + 1]
-
-            location = column_ordinal_dict[column_name] - START_INDEX
-            if location < 0:
-                continue
-
-            data_type = column_type_deduplicated_list[location]
-            value_buffer = column_values[location]
-            value_buffer_len = len(value_buffer)
-
-            if data_type == TSDataType.DOUBLE:
-                data_array = np.frombuffer(
-                    value_buffer, np.dtype(np.double).newbyteorder(">")
-                )
-            elif data_type == TSDataType.FLOAT:
-                data_array = np.frombuffer(
-                    value_buffer, np.dtype(np.float32).newbyteorder(">")
-                )
-            elif data_type == TSDataType.BOOLEAN:
-                data_array = []
-                for index in range(len(value_buffer)):
-                    data_array.append(value_buffer[index])
-                data_array = np.array(data_array).astype("bool")
-            elif data_type == TSDataType.INT32:
-                data_array = np.frombuffer(
-                    value_buffer, np.dtype(np.int32).newbyteorder(">")
-                )
-            elif data_type == TSDataType.INT64:
-                data_array = np.frombuffer(
-                    value_buffer, np.dtype(np.int64).newbyteorder(">")
-                )
-            elif data_type == TSDataType.TEXT:
-                index = 0
-                data_array = []
-                while index < value_buffer_len:
-                    value_bytes = value_buffer[index]
-                    value = value_bytes.decode("utf-8")
-                    data_array.append(value)
-                    index += 1
-                data_array = np.array(data_array, dtype=object)
-            else:
-                raise RuntimeError("unsupported data type 
{}.".format(data_type))
-
-            if data_array.dtype.byteorder == ">":
-                data_array = 
data_array.byteswap().view(data_array.dtype.newbyteorder("<"))
-
-            null_indicator = null_indicators[location]
-            if len(data_array) < total_length or (data_type == 
TSDataType.BOOLEAN and null_indicator is not None):
-                if data_type == TSDataType.INT32 or data_type == 
TSDataType.INT64:
-                    tmp_array = np.full(total_length, np.nan, np.float32)
-                elif data_type == TSDataType.FLOAT or data_type == 
TSDataType.DOUBLE:
-                    tmp_array = np.full(total_length, np.nan, data_array.dtype)
-                elif data_type == TSDataType.BOOLEAN:
-                    tmp_array = np.full(total_length, np.nan, np.float32)
-                elif data_type == TSDataType.TEXT:
-                    tmp_array = np.full(total_length, np.nan, 
dtype=data_array.dtype)
-                else:
-                    raise Exception("Unsupported dataType in deserialization")
-
-                if null_indicator is not None:
-                    indexes = [not v for v in null_indicator]
-                    if data_type == TSDataType.BOOLEAN:
-                        tmp_array[indexes] = data_array[indexes]
-                    else:
-                        tmp_array[indexes] = data_array
-
-                if data_type == TSDataType.INT32:
-                    tmp_array = pd.Series(tmp_array).astype("Int32")
-                elif data_type == TSDataType.INT64:
-                    tmp_array = pd.Series(tmp_array).astype("Int64")
-                elif data_type == TSDataType.BOOLEAN:
-                    tmp_array = pd.Series(tmp_array).astype("boolean")
-
-                data_array = tmp_array
-
-            if result[column_name] is None:
-                result[column_name] = data_array
-            else:
-                if isinstance(result[column_name], pd.Series):
-                    if not isinstance(data_array, pd.Series):
-                        if data_type == TSDataType.INT32:
-                            data_array = pd.Series(data_array).astype("Int32")
-                        elif data_type == TSDataType.INT64:
-                            data_array = pd.Series(data_array).astype("Int64")
-                        elif data_type == TSDataType.BOOLEAN:
-                            data_array = 
pd.Series(data_array).astype("boolean")
-                        else:
-                            raise RuntimeError("Series Error")
-                    result[column_name] = 
result[column_name].append(data_array)
-                else:
-                    result[column_name] = np.concatenate(
-                        (result[column_name], data_array), axis=0
-                    )
-    for k, v in result.items():
-        if v is None:
-            result[k] = []
-    df = pd.DataFrame(result)
-    df = df.reset_index(drop=True)
-    return df
-
-
 def _get_encoder(data_type: pd.Series):
     if data_type == "bool":
         return b'\x00'
@@ -284,74 +126,7 @@ def _get_type_in_byte(data_type: pd.Series):
                                   "data_type should be in ['bool', 'int32', 
'int64', 'float32', 'float64', 'text']")
 
 
-# Serialized tsBlock:
-#    
+-------------+---------------+---------+------------+-----------+----------+
-#    | val col cnt | val col types | pos cnt | encodings  | time col  | val 
col  |
-#    
+-------------+---------------+---------+------------+-----------+----------+
-#    | int32       | list[byte]    | int32   | list[byte] |  bytes    | byte   
  |
-#    
+-------------+---------------+---------+------------+-----------+----------+
-
-def deserialize(buffer):
-    value_column_count, buffer = read_int_from_buffer(buffer)
-    data_types, buffer = read_column_types(buffer, value_column_count)
-
-    position_count, buffer = read_int_from_buffer(buffer)
-    column_encodings, buffer = read_column_encoding(buffer, value_column_count 
+ 1)
-
-    time_column_values, buffer = read_time_column(buffer, position_count)
-    column_values = [None] * value_column_count
-    null_indicators = [None] * value_column_count
-    for i in range(value_column_count):
-        column_value, null_indicator, buffer = read_column(column_encodings[i 
+ 1], buffer, data_types[i],
-                                                           position_count)
-        column_values[i] = column_value
-        null_indicators[i] = null_indicator
-
-    return time_column_values, column_values, null_indicators, position_count
-
-
 # General Methods
-
-def read_int_from_buffer(buffer):
-    res, buffer = read_from_buffer(buffer, 4)
-    return int.from_bytes(res, "big"), buffer
-
-
-def read_byte_from_buffer(buffer):
-    return read_from_buffer(buffer, 1)
-
-
-def read_from_buffer(buffer, size):
-    res = buffer[:size]
-    buffer = buffer[size:]
-    return res, buffer
-
-
-# Read ColumnType
-
-def read_column_types(buffer, value_column_count):
-    data_types = []
-    for _ in range(value_column_count):
-        res, buffer = read_byte_from_buffer(buffer)
-        data_types.append(get_data_type(res))
-    return data_types, buffer
-
-
-def get_data_type(value):
-    if value == b'\x00':
-        return TSDataType.BOOLEAN
-    elif value == b'\x01':
-        return TSDataType.INT32
-    elif value == b'\x02':
-        return TSDataType.INT64
-    elif value == b'\x03':
-        return TSDataType.FLOAT
-    elif value == b'\x04':
-        return TSDataType.DOUBLE
-    elif value == b'\x05':
-        return TSDataType.TEXT
-
-
 def get_data_type_byte_from_str(value):
     '''
     Args:
@@ -374,191 +149,3 @@ def get_data_type_byte_from_str(value):
         return TSDataType.DOUBLE.value
     elif value == "text":
         return TSDataType.TEXT.value
-
-
-# Read ColumnEncodings
-
-def read_column_encoding(buffer, size):
-    encodings = []
-    for _ in range(size):
-        res, buffer = read_byte_from_buffer(buffer)
-        encodings.append(res)
-    return encodings, buffer
-
-
-# Read Column
-
-def deserialize_null_indicators(buffer, size):
-    may_have_null, buffer = read_byte_from_buffer(buffer)
-    if may_have_null != b'\x00':
-        return deserialize_from_boolean_array(buffer, size)
-    return None, buffer
-
-
-# Serialized data layout:
-#    +---------------+-----------------+-------------+
-#    | may have null | null indicators |   values    |
-#    +---------------+-----------------+-------------+
-#    | byte          | list[byte]      | list[int64] |
-#    +---------------+-----------------+-------------+
-
-def read_time_column(buffer, size):
-    null_indicators, buffer = deserialize_null_indicators(buffer, size)
-    if null_indicators is None:
-        values, buffer = read_from_buffer(
-            buffer, size * 8
-        )
-    else:
-        raise Exception("TimeColumn should not contains null value")
-    return values, buffer
-
-
-def read_int64_column(buffer, data_type, position_count):
-    null_indicators, buffer = deserialize_null_indicators(buffer, 
position_count)
-    if null_indicators is None:
-        size = position_count
-    else:
-        size = null_indicators.count(False)
-
-    if TSDataType.INT64 == data_type or TSDataType.DOUBLE == data_type:
-        values, buffer = read_from_buffer(buffer, size * 8)
-        return values, null_indicators, buffer
-    else:
-        raise Exception("Invalid data type: " + data_type)
-
-
-# Serialized data layout:
-#    +---------------+-----------------+-------------+
-#    | may have null | null indicators |   values    |
-#    +---------------+-----------------+-------------+
-#    | byte          | list[byte]      | list[int32] |
-#    +---------------+-----------------+-------------+
-
-def read_int32_column(buffer, data_type, position_count):
-    null_indicators, buffer = deserialize_null_indicators(buffer, 
position_count)
-    if null_indicators is None:
-        size = position_count
-    else:
-        size = null_indicators.count(False)
-
-    if TSDataType.INT32 == data_type or TSDataType.FLOAT == data_type:
-        values, buffer = read_from_buffer(buffer, size * 4)
-        return values, null_indicators, buffer
-    else:
-        raise Exception("Invalid data type: " + data_type)
-
-
-# Serialized data layout:
-#    +---------------+-----------------+-------------+
-#    | may have null | null indicators |   values    |
-#    +---------------+-----------------+-------------+
-#    | byte          | list[byte]      | list[byte] |
-#    +---------------+-----------------+-------------+
-
-def read_byte_column(buffer, data_type, position_count):
-    if data_type != TSDataType.BOOLEAN:
-        raise Exception("Invalid data type: " + data_type)
-    null_indicators, buffer = deserialize_null_indicators(buffer, 
position_count)
-    res, buffer = deserialize_from_boolean_array(buffer, position_count)
-    return res, null_indicators, buffer
-
-
-def deserialize_from_boolean_array(buffer, size):
-    packed_boolean_array, buffer = read_from_buffer(buffer, (size + 7) // 8)
-    current_byte = 0
-    output = [None] * size
-    position = 0
-    # read null bits 8 at a time
-    while position < (size & ~0b111):
-        value = packed_boolean_array[current_byte]
-        output[position] = ((value & 0b1000_0000) != 0)
-        output[position + 1] = ((value & 0b0100_0000) != 0)
-        output[position + 2] = ((value & 0b0010_0000) != 0)
-        output[position + 3] = ((value & 0b0001_0000) != 0)
-        output[position + 4] = ((value & 0b0000_1000) != 0)
-        output[position + 5] = ((value & 0b0000_0100) != 0)
-        output[position + 6] = ((value & 0b0000_0010) != 0)
-        output[position + 7] = ((value & 0b0000_0001) != 0)
-
-        position += 8
-        current_byte += 1
-    # read last null bits
-    if (size & 0b111) > 0:
-        value = packed_boolean_array[-1]
-        mask = 0b1000_0000
-        position = size & ~0b111
-        while position < size:
-            output[position] = ((value & mask) != 0)
-            mask >>= 1
-            position += 1
-    return output, buffer
-
-
-# Serialized data layout:
-#    +---------------+-----------------+-------------+
-#    | may have null | null indicators |   values    |
-#    +---------------+-----------------+-------------+
-#    | byte          | list[byte]      | list[entry] |
-#    +---------------+-----------------+-------------+
-#
-# Each entry is represented as:
-#    +---------------+-------+
-#    | value length  | value |
-#    +---------------+-------+
-#    | int32         | bytes |
-#    +---------------+-------+
-
-def read_binary_column(buffer, data_type, position_count):
-    if data_type != TSDataType.TEXT:
-        raise Exception("Invalid data type: " + data_type)
-    null_indicators, buffer = deserialize_null_indicators(buffer, 
position_count)
-
-    if null_indicators is None:
-        size = position_count
-    else:
-        size = null_indicators.count(False)
-    values = [None] * size
-    for i in range(size):
-        length, buffer = read_int_from_buffer(buffer)
-        res, buffer = read_from_buffer(buffer, length)
-        values[i] = res
-    return values, null_indicators, buffer
-
-
-def read_column(encoding, buffer, data_type, position_count):
-    if encoding == b'\x00':
-        return read_byte_column(buffer, data_type, position_count)
-    elif encoding == b'\x01':
-        return read_int32_column(buffer, data_type, position_count)
-    elif encoding == b'\x02':
-        return read_int64_column(buffer, data_type, position_count)
-    elif encoding == b'\x03':
-        return read_binary_column(buffer, data_type, position_count)
-    elif encoding == b'\x04':
-        return read_run_length_column(buffer, data_type, position_count)
-    else:
-        raise Exception("Unsupported encoding: " + encoding)
-
-
-# Serialized data layout:
-#    +-----------+-------------------------+
-#    | encoding  | serialized inner column |
-#    +-----------+-------------------------+
-#    | byte      | list[byte]              |
-#    +-----------+-------------------------+
-
-def read_run_length_column(buffer, data_type, position_count):
-    encoding, buffer = read_byte_from_buffer(buffer)
-    column, null_indicators, buffer = read_column(encoding, buffer, data_type, 
1)
-
-    return repeat(column, data_type, position_count), null_indicators * 
position_count, buffer
-
-
-def repeat(buffer, data_type, position_count):
-    if data_type == TSDataType.BOOLEAN or data_type == TSDataType.TEXT:
-        return buffer * position_count
-    else:
-        res = bytes(0)
-        for _ in range(position_count):
-            res.join(buffer)
-        return res
diff --git 
a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/persistence/ModelInfo.java
 
b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/persistence/ModelInfo.java
index 21b317d9854..e2beede330a 100644
--- 
a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/persistence/ModelInfo.java
+++ 
b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/persistence/ModelInfo.java
@@ -75,6 +75,7 @@ public class ModelInfo implements SnapshotProcessor {
   private static final Set<String> builtInAnomalyDetectionModel = new 
HashSet<>();
 
   static {
+    builtInForecastModel.add("_timerxl");
     builtInForecastModel.add("_ARIMA");
     builtInForecastModel.add("_NaiveForecaster");
     builtInForecastModel.add("_STLForecaster");
diff --git 
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/process/ai/InferenceOperator.java
 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/process/ai/InferenceOperator.java
index 9bdb57dc5bd..a384be3ad24 100644
--- 
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/process/ai/InferenceOperator.java
+++ 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/process/ai/InferenceOperator.java
@@ -38,7 +38,6 @@ import com.google.common.util.concurrent.Futures;
 import com.google.common.util.concurrent.ListenableFuture;
 import org.apache.tsfile.block.column.Column;
 import org.apache.tsfile.block.column.ColumnBuilder;
-import org.apache.tsfile.enums.TSDataType;
 import org.apache.tsfile.read.common.block.TsBlock;
 import org.apache.tsfile.read.common.block.TsBlockBuilder;
 import org.apache.tsfile.read.common.block.column.TimeColumnBuilder;
@@ -47,12 +46,9 @@ import org.apache.tsfile.utils.RamUsageEstimator;
 
 import java.nio.ByteBuffer;
 import java.util.Arrays;
-import java.util.HashMap;
 import java.util.List;
-import java.util.Map;
 import java.util.concurrent.ExecutionException;
 import java.util.concurrent.ExecutorService;
-import java.util.stream.Collectors;
 
 import static com.google.common.util.concurrent.Futures.successfulAsList;
 
@@ -74,8 +70,7 @@ public class InferenceOperator implements ProcessOperator {
 
   private final long maxRetainedSize;
   private final long maxReturnSize;
-  private final List<String> inputColumnNames;
-  private final List<String> targetColumnNames;
+  private final int[] columnIndexes;
   private long totalRow;
   private int resultIndex = 0;
   private List<ByteBuffer> results;
@@ -105,8 +100,11 @@ public class InferenceOperator implements ProcessOperator {
         new TsBlockBuilder(
             
Arrays.asList(modelInferenceDescriptor.getModelInformation().getInputDataType()));
     this.modelInferenceExecutor = modelInferenceExecutor;
-    this.targetColumnNames = targetColumnNames;
-    this.inputColumnNames = inputColumnNames;
+    this.columnIndexes = new int[inputColumnNames.size()];
+    for (int i = 0; i < inputColumnNames.size(); i++) {
+      columnIndexes[i] = targetColumnNames.indexOf(inputColumnNames.get(i));
+    }
+
     this.maxRetainedSize = maxRetainedSize;
     this.maxReturnSize = maxReturnSize;
     this.totalRow = 0;
@@ -232,7 +230,7 @@ public class InferenceOperator implements ProcessOperator {
       }
       timeColumnBuilder.writeLong(timestamp);
       for (int columnIndex = 0; columnIndex < 
inputTsBlock.getValueColumnCount(); columnIndex++) {
-        columnBuilders[columnIndex].write(inputTsBlock.getColumn(columnIndex), 
i);
+        
columnBuilders[columnIndexes[columnIndex]].write(inputTsBlock.getColumn(columnIndex),
 i);
       }
       inputTsBlockBuilder.declarePosition();
     }
@@ -304,12 +302,6 @@ public class InferenceOperator implements ProcessOperator {
     TsBlock finalInputTsBlock = preProcess(inputTsBlock);
     TWindowParams windowParams = getWindowParams();
 
-    Map<String, Integer> columnNameIndexMap = new HashMap<>();
-
-    for (int i = 0; i < inputColumnNames.size(); i++) {
-      columnNameIndexMap.put(inputColumnNames.get(i), i);
-    }
-
     inferenceExecutionFuture =
         Futures.submit(
             () -> {
@@ -318,11 +310,6 @@ public class InferenceOperator implements ProcessOperator {
                       
.borrowClient(modelInferenceDescriptor.getTargetAINode())) {
                 return client.inference(
                     modelInferenceDescriptor.getModelName(),
-                    targetColumnNames,
-                    
Arrays.stream(modelInferenceDescriptor.getModelInformation().getInputDataType())
-                        .map(TSDataType::toString)
-                        .collect(Collectors.toList()),
-                    columnNameIndexMap,
                     finalInputTsBlock,
                     modelInferenceDescriptor.getInferenceAttributes(),
                     windowParams);
@@ -367,11 +354,6 @@ public class InferenceOperator implements ProcessOperator {
         + MemoryEstimationHelper.getEstimatedSizeOfAccountableObject(child)
         + 
MemoryEstimationHelper.getEstimatedSizeOfAccountableObject(operatorContext)
         + inputTsBlockBuilder.getRetainedSizeInBytes()
-        + (inputColumnNames == null
-            ? 0
-            : 
inputColumnNames.stream().mapToLong(RamUsageEstimator::sizeOf).sum())
-        + (targetColumnNames == null
-            ? 0
-            : 
targetColumnNames.stream().mapToLong(RamUsageEstimator::sizeOf).sum());
+        + (long) columnIndexes.length * Integer.BYTES;
   }
 }
diff --git 
a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/client/ainode/AINodeClient.java
 
b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/client/ainode/AINodeClient.java
index 3cc416f0aad..5532d283bcb 100644
--- 
a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/client/ainode/AINodeClient.java
+++ 
b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/client/ainode/AINodeClient.java
@@ -56,7 +56,6 @@ import org.slf4j.LoggerFactory;
 
 import java.io.IOException;
 import java.nio.ByteBuffer;
-import java.util.List;
 import java.util.Map;
 import java.util.Optional;
 
@@ -161,21 +160,12 @@ public class AINodeClient implements AutoCloseable, 
ThriftClient {
 
   public TInferenceResp inference(
       String modelId,
-      List<String> inputColumnNames,
-      List<String> inputTypeList,
-      Map<String, Integer> columnIndexMap,
       TsBlock inputTsBlock,
       Map<String, String> inferenceAttributes,
       TWindowParams windowParams)
       throws TException {
     try {
-      TInferenceReq inferenceReq =
-          new TInferenceReq(
-              modelId,
-              tsBlockSerde.serialize(inputTsBlock),
-              inputTypeList,
-              inputColumnNames,
-              columnIndexMap);
+      TInferenceReq inferenceReq = new TInferenceReq(modelId, 
tsBlockSerde.serialize(inputTsBlock));
       if (windowParams != null) {
         inferenceReq.setWindowParams(windowParams);
       }
@@ -184,10 +174,10 @@ public class AINodeClient implements AutoCloseable, 
ThriftClient {
       }
       return client.inference(inferenceReq);
     } catch (IOException e) {
-      throw new TException("An exception occurred while serializing input 
tsblock", e);
+      throw new TException("An exception occurred while serializing input 
data", e);
     } catch (TException e) {
       logger.warn(
-          "Failed to connect to AINode from DataNode when executing {}: {}",
+          "Error happens in AINode when executing {}: {}",
           Thread.currentThread().getStackTrace()[1].getMethodName(),
           e.getMessage());
       throw new TException(MSG_CONNECTION_FAIL);
diff --git a/iotdb-protocol/thrift-ainode/src/main/thrift/ainode.thrift 
b/iotdb-protocol/thrift-ainode/src/main/thrift/ainode.thrift
index 5643da743a8..db1a15c2460 100644
--- a/iotdb-protocol/thrift-ainode/src/main/thrift/ainode.thrift
+++ b/iotdb-protocol/thrift-ainode/src/main/thrift/ainode.thrift
@@ -58,11 +58,8 @@ struct TRegisterModelResp {
 struct TInferenceReq {
   1: required string modelId
   2: required binary dataset
-  3: required list<string> typeList
-  4: required list<string> columnNameList
-  5: required map<string, i32> columnNameIndexMap
-  6: optional TWindowParams windowParams
-  7: optional map<string, string> inferenceAttributes
+  3: optional TWindowParams windowParams
+  4: optional map<string, string> inferenceAttributes
 }
 
 struct TWindowParams {

Reply via email to