This is an automated email from the ASF dual-hosted git repository.
jackietien pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/iotdb.git
The following commit(s) were added to refs/heads/master by this push:
new 78689179359 [AINode] Support built-in inference of table/tree model
for time_xl
78689179359 is described below
commit 78689179359d6064d5d6568b4baf147fce482604
Author: YangCaiyin <[email protected]>
AuthorDate: Tue May 13 19:35:43 2025 +0800
[AINode] Support built-in inference of table/tree model for time_xl
---
.../relational/it/schema/IoTDBDatabaseIT.java | 6 +-
.../ainode/ainode/TimerXL/models/timer_xl.py | 20 +-
iotdb-core/ainode/ainode/core/handler.py | 5 +-
.../ainode/core/manager/inference_manager.py | 329 ++++++----------
.../ainode/core/model/built_in_model_factory.py | 22 +-
iotdb-core/ainode/ainode/core/util/serde.py | 413 ---------------------
.../iotdb/confignode/persistence/ModelInfo.java | 1 +
.../operator/process/ai/InferenceOperator.java | 34 +-
.../iotdb/commons/client/ainode/AINodeClient.java | 16 +-
.../thrift-ainode/src/main/thrift/ainode.thrift | 7 +-
10 files changed, 156 insertions(+), 697 deletions(-)
diff --git
a/integration-test/src/test/java/org/apache/iotdb/relational/it/schema/IoTDBDatabaseIT.java
b/integration-test/src/test/java/org/apache/iotdb/relational/it/schema/IoTDBDatabaseIT.java
index 34f1027b8b8..b19f1a1d869 100644
---
a/integration-test/src/test/java/org/apache/iotdb/relational/it/schema/IoTDBDatabaseIT.java
+++
b/integration-test/src/test/java/org/apache/iotdb/relational/it/schema/IoTDBDatabaseIT.java
@@ -620,7 +620,11 @@ public class IoTDBDatabaseIT {
"model_id,",
new HashSet<>(
Arrays.asList(
- "_STLForecaster,", "_NaiveForecaster,", "_ARIMA,",
"_ExponentialSmoothing,")));
+ "_timerxl,",
+ "_STLForecaster,",
+ "_NaiveForecaster,",
+ "_ARIMA,",
+ "_ExponentialSmoothing,")));
TestUtils.assertResultSetEqual(
statement.executeQuery(
diff --git a/iotdb-core/ainode/ainode/TimerXL/models/timer_xl.py
b/iotdb-core/ainode/ainode/TimerXL/models/timer_xl.py
index 1945fa25a2e..4e4d8588fd2 100644
--- a/iotdb-core/ainode/ainode/TimerXL/models/timer_xl.py
+++ b/iotdb-core/ainode/ainode/TimerXL/models/timer_xl.py
@@ -27,9 +27,12 @@ from ainode.TimerXL.models.configuration_timer import
TimerxlConfig
from ainode.core.util.masking import prepare_4d_causal_attention_mask
from ainode.core.util.huggingface_cache import Cache, DynamicCache
-import safetensors
+from safetensors.torch import load_file as load_safetensors
from huggingface_hub import hf_hub_download
+from ainode.core.log import Logger
+logger = Logger()
+
@dataclass
class Output:
outputs: torch.Tensor
@@ -211,12 +214,15 @@ class Model(nn.Module):
state_dict = torch.load(config.ckpt_path)
elif config.ckpt_path.endswith('.safetensors'):
if not os.path.exists(config.ckpt_path):
- print(f"[INFO] Checkpoint not found at {config.ckpt_path},
downloading from HuggingFace...")
+ logger.info(f"Checkpoint not found at {config.ckpt_path},
downloading from HuggingFace...")
repo_id = "thuml/timer-base-84m"
- filename = os.path.basename(config.ckpt_path) # eg:
model.safetensors
- config.ckpt_path = hf_hub_download(repo_id=repo_id,
filename=filename)
- print(f"[INFO] Downloaded checkpoint to
{config.ckpt_path}")
- state_dict = safetensors.torch.load_file(config.ckpt_path)
+ try:
+ config.ckpt_path = hf_hub_download(repo_id=repo_id,
filename=os.path.basename(config.ckpt_path),
local_dir=os.path.dirname(config.ckpt_path))
+ logger.info(f"Got checkpoint to {config.ckpt_path}")
+ except Exception as e:
+ logger.error(f"Failed to download checkpoint to
{config.ckpt_path} due to {e}")
+ raise e
+ state_dict = load_safetensors(config.ckpt_path)
else:
raise ValueError('unsupported model weight type')
# If there is no key beginning with 'model.model' in state_dict,
add a 'model.' before all keys. (The model code here has an additional layer of
encapsulation compared to the code on huggingface.)
@@ -234,7 +240,7 @@ class Model(nn.Module):
# change [L, C=1] to [batchsize=1, L]
self.device = next(self.model.parameters()).device
- x = torch.tensor(x.values, dtype=next(self.model.parameters()).dtype,
device=self.device)
+ x = torch.tensor(x, dtype=next(self.model.parameters()).dtype,
device=self.device)
x = x.view(1, -1)
preds = self.forward(x, max_new_tokens)
diff --git a/iotdb-core/ainode/ainode/core/handler.py
b/iotdb-core/ainode/ainode/core/handler.py
index 0405d773f69..fc8f8f1aae7 100644
--- a/iotdb-core/ainode/ainode/core/handler.py
+++ b/iotdb-core/ainode/ainode/core/handler.py
@@ -29,6 +29,7 @@ from ainode.thrift.common.ttypes import TSStatus
class AINodeRPCServiceHandler(IAINodeRPCService.Iface):
def __init__(self):
self._model_manager = ModelManager()
+ self._inference_manager =
InferenceManager(model_manager=self._model_manager)
def registerModel(self, req: TRegisterModelReq) -> TRegisterModelResp:
return self._model_manager.register_model(req)
@@ -37,10 +38,10 @@ class AINodeRPCServiceHandler(IAINodeRPCService.Iface):
return self._model_manager.delete_model(req)
def inference(self, req: TInferenceReq) -> TInferenceResp:
- return InferenceManager.inference(req, self._model_manager)
+ return self._inference_manager.inference(req)
def forecast(self, req: TForecastReq) -> TSStatus:
- return InferenceManager.forecast(req, self._model_manager)
+ return self._inference_manager.forecast(req)
def getAIHeartbeat(self, req: TAIHeartbeatReq) -> TAIHeartbeatResp:
return ClusterManager.get_heart_beat(req)
diff --git a/iotdb-core/ainode/ainode/core/manager/inference_manager.py
b/iotdb-core/ainode/ainode/core/manager/inference_manager.py
index d16e521cde6..c9815745516 100644
--- a/iotdb-core/ainode/ainode/core/manager/inference_manager.py
+++ b/iotdb-core/ainode/ainode/core/manager/inference_manager.py
@@ -15,6 +15,8 @@
# specific language governing permissions and limitations
# under the License.
#
+from abc import abstractmethod, ABC
+
import pandas as pd
import torch
from iotdb.tsfile.utils.tsblock_serde import deserialize
@@ -23,246 +25,139 @@ from ainode.core.constant import TSStatusCode
from ainode.core.exception import InvalidWindowArgumentError,
InferenceModelInternalError, runtime_error_extractor
from ainode.core.log import Logger
from ainode.core.manager.model_manager import ModelManager
-from ainode.core.util.serde import convert_to_binary, convert_to_df
+from ainode.core.util.serde import convert_to_binary
from ainode.core.util.status import get_status
from ainode.thrift.ainode.ttypes import TInferenceReq, TInferenceResp,
TForecastReq, TForecastResp
logger = Logger()
-def _process_data(full_data):
- """
- Args:
- full_data: a tuple of (data, time_stamp, type_list, column_name_list),
where the data is a DataFrame with shape
- (L, C), time_stamp is a DataFrame with shape(L, 1), type_list is a
list of data types with length C,
- column_name_list is a list of column names with length C, where L
is the number of data points, C is the
- number of variables, the data and time_stamp are aligned by index
- Returns:
- data: a tensor with shape (1, L, C)
- data_length: the number of data points
- Description:
- the process_data module will convert the input data into a tensor with
shape (1, L, C), where L is the number of
- data points, C is the number of variables, the data and time_stamp are
aligned by index. The module will also
- convert the data type of each column to the corresponding type.
- """
- data, time_stamp, type_list, _ = full_data
- data_length = time_stamp.shape[0]
- data = data.fillna(0)
- for i in range(len(type_list)):
- if type_list[i] == "TEXT":
- data[data.columns[i]] = 0
- elif type_list[i] == "BOOLEAN":
- data[data.columns[i]] = data[data.columns[i]].astype("int")
- data = torch.tensor(data.values).unsqueeze(0)
- return data, data_length
+class InferenceStrategy(ABC):
+ def __init__(self, model):
+ self.model = model
-class InferenceManager:
+ @abstractmethod
+ def infer(self, full_data, **kwargs):
+ pass
- @staticmethod
- def forecast(req: TForecastReq, model_manager:ModelManager):
- model_id = req.modelId
- logger.info(f"start to forcast by model {model_id}")
- try:
- data = deserialize(req.inputData)
- if model_id.startswith('_'):
- # built-in models
- logger.info(f"start to forecast built-in model {model_id}")
- # parse the inference attributes and create the built-in model
- options = req.options
- options['predict_length'] = req.outputLength
- model = _get_built_in_model(model_id, model_manager, options)
- inference_result =
convert_to_binary(_inference_with_built_in_model(
- model, data))
- else:
- # user-registered models
- model = _get_model(model_id, model_manager, req.options)
- _, dataset, _, dataset_length = data
- dataset = torch.tensor(dataset, dtype=torch.float).unsqueeze(2)
- inference_results = _inference_with_registered_model(
- model, dataset, dataset_length, dataset_length,
float('inf'))
- inference_result = convert_to_binary(inference_results[0])
- return TForecastResp(
- get_status(TSStatusCode.SUCCESS_STATUS),
- inference_result
- )
- except Exception as e:
- logger.warning(e)
- inference_results = []
- return
TInferenceResp(get_status(TSStatusCode.AINODE_INTERNAL_ERROR, str(e)),
inference_results)
- @staticmethod
- def inference(req: TInferenceReq, model_manager: ModelManager):
- logger.info(f"start inference registered model {req.modelId}")
- try:
- model_id, full_data, window_interval, window_step,
inference_attributes = _parse_inference_request(req)
+# [IoTDB] full data deserialized from iotdb is composed of [timestampList,
valueList, length],
+# we only get valueList currently.
+class TimerXLStrategy(InferenceStrategy):
+ def infer(self, full_data, predict_length=96, **_):
+ data = full_data[1][0]
+ if data.dtype.byteorder not in ('=', '|'):
+ data = data.byteswap().newbyteorder()
+ output = self.model.inference(data, int(predict_length))
+ df = pd.DataFrame(output[0])
+ return convert_to_binary(df)
- if model_id.startswith('_'):
- # built-in models
- logger.info(f"start inference built-in model {model_id}")
- # parse the inference attributes and create the built-in model
- model = _get_built_in_model(model_id, model_manager,
inference_attributes)
- if model_id == '_timerxl':
- inference_results = [_inference_with_timerxl(
- model, full_data,
inference_attributes.get("predict_length", 96))]
- else:
- inference_results = [_inference_with_built_in_model(
- model, full_data)]
- else:
- # user-registered models
- model = _get_model(model_id, model_manager,
inference_attributes)
- dataset, dataset_length = _process_data(full_data)
- inference_results = _inference_with_registered_model(
- model, dataset, dataset_length, window_interval,
window_step)
- for i in range(len(inference_results)):
- inference_results[i] = convert_to_binary(inference_results[i])
- return TInferenceResp(
- get_status(
- TSStatusCode.SUCCESS_STATUS),
- inference_results)
- except Exception as e:
- logger.warning(e)
- inference_results = []
- return
TInferenceResp(get_status(TSStatusCode.AINODE_INTERNAL_ERROR, str(e)),
inference_results)
-def _inference_with_registered_model(model, dataset, dataset_length,
window_interval, window_step):
- """
- Args:
- model: the user-defined model
- full_data: a tuple of (data, time_stamp, type_list, column_name_list),
where the data is a DataFrame with shape
- (L, C), time_stamp is a DataFrame with shape(L, 1), type_list is a
list of data types with length C,
- column_name_list is a list of column names with length C, where L
is the number of data points, C is the
- number of variables, the data and time_stamp are aligned by index
- window_interval: the length of each sliding window
- window_step: the step between two adjacent sliding windows
- Returns:
- outputs: a list of output DataFrames, where each DataFrame has shape
(H', C'), where H' is the output window
- interval, C' is the number of variables in the output DataFrame
- Description:
- the inference_with_registered_model function will inference with deep
learning model, which is registered in
- user register process. This module will split the input data into
several sliding windows which has the same
- shape (1, H, C), where H is the window interval, and then feed each
sliding window into the model to get the
- output, the output is a DataFrame with shape (H', C'), where H' is the
output window interval, C' is the number
- of variables in the output DataFrame. Then the inference module will
concatenate all the output DataFrames into
- a list.
- """
+class BuiltInStrategy(InferenceStrategy):
+ def infer(self, full_data, **_):
+ data = full_data[1][0]
+ if data.dtype.byteorder not in ('=', '|'):
+ data = data.byteswap().newbyteorder()
+ output = self.model.inference(data)
+ df = pd.DataFrame(output)
+ return convert_to_binary(df)
- # check the validity of window_interval and window_step, the two arguments
must be positive integers, and the
- # window_interval should not be larger than the dataset length
- if window_interval is None or window_step is None \
- or window_interval > dataset_length \
- or window_interval <= 0 or \
- window_step <= 0:
- raise InvalidWindowArgumentError(window_interval, window_step,
dataset_length)
- sliding_times = int((dataset_length - window_interval) // window_step + 1)
- outputs = []
- try:
- # split the input data into several sliding windows
- for sliding_time in range(sliding_times):
- if window_step == float('inf'):
- start_index = 0
- else:
- start_index = sliding_time * window_step
- end_index = start_index + window_interval
- # input_data: tensor, shape: (1, H, C), where H is input window
interval
- input_data = dataset[:, start_index:end_index, :]
- # output: tensor, shape: (1, H', C'), where H' is the output
window interval
- output = model(input_data)
- # output: DataFrame, shape: (H', C')
- output = pd.DataFrame(output.squeeze(0).detach().numpy())
- outputs.append(output)
- except Exception as e:
- error_msg = runtime_error_extractor(str(e))
- if error_msg != "":
- raise InferenceModelInternalError(error_msg)
- raise InferenceModelInternalError(str(e))
+class RegisteredStrategy(InferenceStrategy):
+ def infer(self, full_data, window_interval=None, window_step=None,
**kwargs):
+ _, dataset, _, length = full_data
+ if window_interval is None or window_step is None:
+ window_interval = length
+ window_step = float('inf')
- return outputs
+ if window_interval <= 0 or window_step <= 0 or window_interval >
length:
+ raise InvalidWindowArgumentError(window_interval, window_step,
length)
+ data = torch.tensor(dataset,
dtype=torch.float32).unsqueeze(0).permute(0, 2, 1)
-def _inference_with_built_in_model(model, full_data):
- """
- Args:
- model: the built-in model
- full_data: a tuple of (data, time_stamp, type_list, column_name_list),
where the data is a DataFrame with shape
- (L, C), time_stamp is a DataFrame with shape(L, 1), type_list is a
list of data types with length C,
- column_name_list is a list of column names with length C, where L
is the number of data points, C is the
- number of variables, the data and time_stamp are aligned by index
- Returns:
- outputs: a list of output DataFrames, where each DataFrame has shape
(H', C'), where H' is the output window
- interval, C' is the number of variables in the output DataFrame
- Description:
- the inference_with_built_in_model function will inference with
built-in model, which does not
- require user registration. This module will parse the inference
attributes and create the built-in model, then
- feed the input data into the model to get the output, the output is a
DataFrame with shape (H', C'), where H'
- is the output window interval, C' is the number of variables in the
output DataFrame. Then the inference module
- will concatenate all the output DataFrames into a list.
- """
+ times = int((length - window_interval) // window_step + 1)
+ results = []
+ try:
+ for i in range(times):
+ start = 0 if window_step == float('inf') else i * window_step
+ end = start + window_interval
+ window = data[:, start:end, :]
+ out = self.model(window)
+ df = pd.DataFrame(out.squeeze(0).detach().numpy())
+ results.append(df)
+ except Exception as e:
+ msg = runtime_error_extractor(str(e)) or str(e)
+ raise InferenceModelInternalError(msg)
- data, _, _, _ = full_data
- output = model.inference(data)
- # output: DataFrame, shape: (H', C')
- output = pd.DataFrame(output)
- return output
+ # concatenate or return first window for forecast
+ return [convert_to_binary(df) for df in results]
-def _inference_with_timerxl(model, full_data, pred_len):
- """
- Args:
- model: the built-in model
- full_data: a tuple of (data, time_stamp, type_list, column_name_list),
where the data is a DataFrame with shape
- (L, C), time_stamp is a DataFrame with shape(L, 1), type_list is a
list of data types with length C,
- column_name_list is a list of column names with length C, where L
is the number of data points, C is the
- number of variables, the data and time_stamp are aligned by index
- Returns:
- outputs: a list of output DataFrames, where each DataFrame has shape
(H', C'), where H' is the output window
- interval, C' is the number of variables in the output DataFrame
- Description:
- the inference_with_built_in_model function will inference with
built-in model, which does not
- require user registration. This module will parse the inference
attributes and create the built-in model, then
- feed the input data into the model to get the output, the output is a
DataFrame with shape (H', C'), where H'
- is the output window interval, C' is the number of variables in the
output DataFrame. Then the inference module
- will concatenate all the output DataFrames into a list.
- """
- data, _, _, _ = full_data
- output = model.inference(data, pred_len)
- # output: DataFrame, shape: (H', C')
- output = pd.DataFrame(output)
- return output
+def _get_strategy(model_id, model):
+ if model_id == '_timerxl':
+ return TimerXLStrategy(model)
+ if model_id.startswith('_'):
+ return BuiltInStrategy(model)
+ return RegisteredStrategy(model)
-def _get_model(model_id: str, model_manager: ModelManager,
inference_attributes: {}):
- if inference_attributes is None or 'acceleration' not in
inference_attributes:
- # if the acceleration is not specified, then the acceleration will be
set to default value False
- acceleration = False
- else:
- # if the acceleration is specified, then the acceleration will be set
to the specified value
- acceleration = (inference_attributes['acceleration'].lower() == 'true')
- return model_manager.load_model(model_id, acceleration)
+class InferenceManager:
+ def __init__(self, model_manager: ModelManager):
+ self.model_manager = model_manager
-def _get_built_in_model(model_id: str, model_manager: ModelManager,
inference_attributes: {}):
- return model_manager.load_built_in_model(model_id, inference_attributes)
+ def _run(self, req, data_getter, deserializer, extract_attrs, resp_cls,
single_output: bool):
+ model_id = req.modelId
+ logger.info(f"Start processing for {model_id}")
+ try:
+ raw = data_getter(req)
+ full_data = deserializer(raw)
+ attrs = extract_attrs(req)
+ # load model
+ if model_id.startswith('_'):
+ model = self.model_manager.load_built_in_model(model_id, attrs)
+ else:
+ accel = str(attrs.get('acceleration', '')).lower() == 'true'
+ model = self.model_manager.load_model(model_id, accel)
+
+ # inference by strategy
+ strategy = _get_strategy(model_id, model)
+ outputs = strategy.infer(full_data, **attrs)
-def _parse_inference_request(req: TInferenceReq):
- binary_dataset = req.dataset
- type_list = req.typeList
- column_name_list = req.columnNameList
- column_name_index = req.columnNameIndexMap
- data = convert_to_df(column_name_list, type_list, column_name_index,
[binary_dataset])
- time_stamp, data = data[data.columns[0:1]], data[data.columns[1:]]
- full_data = (data, time_stamp, type_list, column_name_list)
- inference_attributes = req.inferenceAttributes
- if inference_attributes is None:
- inference_attributes = {}
+ # construct response
+ status = get_status(TSStatusCode.SUCCESS_STATUS)
- window_params = req.windowParams
- if window_params is None:
- # set default window_step to infinity and window_interval to dataset
length
- window_step = float('inf')
- window_interval = data.shape[0]
- else:
- window_step = window_params.windowStep
- window_interval = window_params.windowInterval
- return req.modelId, full_data, window_interval, window_step,
inference_attributes
+ if isinstance(outputs, list):
+ return resp_cls(status, outputs[0] if single_output else
outputs)
+ return resp_cls(status, outputs if single_output else [outputs])
+
+ except Exception as e:
+ logger.error(e)
+ status = get_status(TSStatusCode.AINODE_INTERNAL_ERROR, str(e))
+ empty = b'' if single_output else []
+ return resp_cls(status, empty)
+
+ def forecast(self, req: TForecastReq):
+ return self._run(
+ req,
+ data_getter=lambda r: r.inputData,
+ deserializer=deserialize,
+ extract_attrs=lambda r: {'predict_length': r.outputLength,
**(r.options or {})},
+ resp_cls=TForecastResp,
+ single_output=True
+ )
+
+ def inference(self, req: TInferenceReq):
+ return self._run(
+ req,
+ data_getter=lambda r: r.dataset,
+ deserializer=deserialize,
+ extract_attrs=lambda r: {
+ 'window_interval': getattr(r.windowParams, 'windowInterval',
None),
+ 'window_step': getattr(r.windowParams, 'windowStep', None),
+ **(r.inferenceAttributes or {})
+ },
+ resp_cls=TInferenceResp,
+ single_output=False
+ )
diff --git a/iotdb-core/ainode/ainode/core/model/built_in_model_factory.py
b/iotdb-core/ainode/ainode/core/model/built_in_model_factory.py
index dd8c9ee5308..0d2991a7f9f 100644
--- a/iotdb-core/ainode/ainode/core/model/built_in_model_factory.py
+++ b/iotdb-core/ainode/ainode/core/model/built_in_model_factory.py
@@ -15,9 +15,9 @@
# specific language governing permissions and limitations
# under the License.
#
+import os
from abc import abstractmethod
from typing import List, Dict
-import os
import numpy as np
from sklearn.preprocessing import MinMaxScaler
@@ -28,15 +28,15 @@ from sktime.forecasting.exp_smoothing import
ExponentialSmoothing
from sktime.forecasting.naive import NaiveForecaster
from sktime.forecasting.trend import STLForecaster
+from ainode.TimerXL.models import timer_xl
+from ainode.TimerXL.models.configuration_timer import TimerxlConfig
+from ainode.core.config import AINodeDescriptor
from ainode.core.constant import AttributeName, BuiltInModelType
-from ainode.core.exception import InferenceModelInternalError,
AttributeNotSupportError
+from ainode.core.exception import InferenceModelInternalError
from ainode.core.exception import WrongAttributeTypeError,
NumericalRangeException, StringRangeException, \
ListRangeException, BuiltInModelNotSupportError
from ainode.core.log import Logger
-from ainode.TimerXL.models import timer_xl
-from ainode.TimerXL.models.configuration_timer import TimerxlConfig
-
logger = Logger()
@@ -79,11 +79,6 @@ def fetch_built_in_model(model_id, inference_attributes):
"""
attribute_map = get_model_attributes(model_id)
- # validate the inference attributes
- for attribute_name in inference_attributes:
- if attribute_name not in attribute_map:
- raise AttributeNotSupportError(model_id, attribute_name)
-
# parse the inference attributes, attributes is a Dict[str, Any]
attributes = parse_attribute(inference_attributes, attribute_map)
@@ -398,9 +393,10 @@ timerxl_attribute_map = {
),
AttributeName.TIMERXL_CKPT_PATH.value: StringAttribute(
name=AttributeName.TIMERXL_CKPT_PATH.value,
- default_value=os.path.join(os.path.dirname(os.path.abspath(__file__)),
'weights', 'timerxl', 'model.safetensors'),
-
value_choices=[os.path.join(os.path.dirname(os.path.abspath(__file__)),
'weights', 'timerxl', 'model.safetensors'), ""],
- ),
+ default_value=os.path.join(os.getcwd(),
AINodeDescriptor().get_config().get_ain_models_dir(), 'weights',
+ 'timerxl', 'model.safetensors'),
+ value_choices=['']
+ )
}
# built-in sktime model attributes
diff --git a/iotdb-core/ainode/ainode/core/util/serde.py
b/iotdb-core/ainode/ainode/core/util/serde.py
index b9edccfd03e..70b86d66095 100644
--- a/iotdb-core/ainode/ainode/core/util/serde.py
+++ b/iotdb-core/ainode/ainode/core/util/serde.py
@@ -97,164 +97,6 @@ def convert_to_binary(data_frame: pd.DataFrame):
return binary
-# convert tsBlock in binary to dataFrame
-def convert_to_df(name_list, type_list, name_index, binary_list):
- column_name_list = [TIMESTAMP_STR]
- column_type_list = [TSDataType.INT64]
- column_ordinal_dict = {TIMESTAMP_STR: 1}
-
- if name_index is not None:
- column_type_deduplicated_list = [
- None for _ in range(len(name_index))
- ]
- for i in range(len(name_list)):
- name = name_list[i]
- column_name_list.append(name)
- column_type_list.append(TSDataType[type_list[i]])
- if name not in column_ordinal_dict:
- index = name_index[name]
- column_ordinal_dict[name] = index + START_INDEX
- column_type_deduplicated_list[index] = TSDataType[type_list[i]]
- else:
- index = START_INDEX
- column_type_deduplicated_list = []
- for i in range(len(name_list)):
- name = name_list[i]
- column_name_list.append(name)
- column_type_list.append(TSDataType[type_list[i]])
- if name not in column_ordinal_dict:
- column_ordinal_dict[name] = index
- index += 1
- column_type_deduplicated_list.append(
- TSDataType[type_list[i]]
- )
-
- binary_size = len(binary_list)
- binary_index = 0
- result = {}
- for column_name in column_name_list:
- result[column_name] = None
-
- while binary_index < binary_size:
- buffer = binary_list[binary_index]
- binary_index += 1
- time_column_values, column_values, null_indicators, _ =
deserialize(buffer)
- time_array = np.frombuffer(
- time_column_values, np.dtype(np.longlong).newbyteorder(">")
- )
- if time_array.dtype.byteorder == ">":
- time_array =
time_array.byteswap().view(time_array.dtype.newbyteorder("<"))
-
- if result[TIMESTAMP_STR] is None:
- result[TIMESTAMP_STR] = time_array
- else:
- result[TIMESTAMP_STR] = np.concatenate(
- (result[TIMESTAMP_STR], time_array), axis=0
- )
- total_length = len(time_array)
-
- for i in range(len(column_values)):
- column_name = column_name_list[i + 1]
-
- location = column_ordinal_dict[column_name] - START_INDEX
- if location < 0:
- continue
-
- data_type = column_type_deduplicated_list[location]
- value_buffer = column_values[location]
- value_buffer_len = len(value_buffer)
-
- if data_type == TSDataType.DOUBLE:
- data_array = np.frombuffer(
- value_buffer, np.dtype(np.double).newbyteorder(">")
- )
- elif data_type == TSDataType.FLOAT:
- data_array = np.frombuffer(
- value_buffer, np.dtype(np.float32).newbyteorder(">")
- )
- elif data_type == TSDataType.BOOLEAN:
- data_array = []
- for index in range(len(value_buffer)):
- data_array.append(value_buffer[index])
- data_array = np.array(data_array).astype("bool")
- elif data_type == TSDataType.INT32:
- data_array = np.frombuffer(
- value_buffer, np.dtype(np.int32).newbyteorder(">")
- )
- elif data_type == TSDataType.INT64:
- data_array = np.frombuffer(
- value_buffer, np.dtype(np.int64).newbyteorder(">")
- )
- elif data_type == TSDataType.TEXT:
- index = 0
- data_array = []
- while index < value_buffer_len:
- value_bytes = value_buffer[index]
- value = value_bytes.decode("utf-8")
- data_array.append(value)
- index += 1
- data_array = np.array(data_array, dtype=object)
- else:
- raise RuntimeError("unsupported data type
{}.".format(data_type))
-
- if data_array.dtype.byteorder == ">":
- data_array =
data_array.byteswap().view(data_array.dtype.newbyteorder("<"))
-
- null_indicator = null_indicators[location]
- if len(data_array) < total_length or (data_type ==
TSDataType.BOOLEAN and null_indicator is not None):
- if data_type == TSDataType.INT32 or data_type ==
TSDataType.INT64:
- tmp_array = np.full(total_length, np.nan, np.float32)
- elif data_type == TSDataType.FLOAT or data_type ==
TSDataType.DOUBLE:
- tmp_array = np.full(total_length, np.nan, data_array.dtype)
- elif data_type == TSDataType.BOOLEAN:
- tmp_array = np.full(total_length, np.nan, np.float32)
- elif data_type == TSDataType.TEXT:
- tmp_array = np.full(total_length, np.nan,
dtype=data_array.dtype)
- else:
- raise Exception("Unsupported dataType in deserialization")
-
- if null_indicator is not None:
- indexes = [not v for v in null_indicator]
- if data_type == TSDataType.BOOLEAN:
- tmp_array[indexes] = data_array[indexes]
- else:
- tmp_array[indexes] = data_array
-
- if data_type == TSDataType.INT32:
- tmp_array = pd.Series(tmp_array).astype("Int32")
- elif data_type == TSDataType.INT64:
- tmp_array = pd.Series(tmp_array).astype("Int64")
- elif data_type == TSDataType.BOOLEAN:
- tmp_array = pd.Series(tmp_array).astype("boolean")
-
- data_array = tmp_array
-
- if result[column_name] is None:
- result[column_name] = data_array
- else:
- if isinstance(result[column_name], pd.Series):
- if not isinstance(data_array, pd.Series):
- if data_type == TSDataType.INT32:
- data_array = pd.Series(data_array).astype("Int32")
- elif data_type == TSDataType.INT64:
- data_array = pd.Series(data_array).astype("Int64")
- elif data_type == TSDataType.BOOLEAN:
- data_array =
pd.Series(data_array).astype("boolean")
- else:
- raise RuntimeError("Series Error")
- result[column_name] =
result[column_name].append(data_array)
- else:
- result[column_name] = np.concatenate(
- (result[column_name], data_array), axis=0
- )
- for k, v in result.items():
- if v is None:
- result[k] = []
- df = pd.DataFrame(result)
- df = df.reset_index(drop=True)
- return df
-
-
def _get_encoder(data_type: pd.Series):
if data_type == "bool":
return b'\x00'
@@ -284,74 +126,7 @@ def _get_type_in_byte(data_type: pd.Series):
"data_type should be in ['bool', 'int32',
'int64', 'float32', 'float64', 'text']")
-# Serialized tsBlock:
-#
+-------------+---------------+---------+------------+-----------+----------+
-# | val col cnt | val col types | pos cnt | encodings | time col | val
col |
-#
+-------------+---------------+---------+------------+-----------+----------+
-# | int32 | list[byte] | int32 | list[byte] | bytes | byte
|
-#
+-------------+---------------+---------+------------+-----------+----------+
-
-def deserialize(buffer):
- value_column_count, buffer = read_int_from_buffer(buffer)
- data_types, buffer = read_column_types(buffer, value_column_count)
-
- position_count, buffer = read_int_from_buffer(buffer)
- column_encodings, buffer = read_column_encoding(buffer, value_column_count
+ 1)
-
- time_column_values, buffer = read_time_column(buffer, position_count)
- column_values = [None] * value_column_count
- null_indicators = [None] * value_column_count
- for i in range(value_column_count):
- column_value, null_indicator, buffer = read_column(column_encodings[i
+ 1], buffer, data_types[i],
- position_count)
- column_values[i] = column_value
- null_indicators[i] = null_indicator
-
- return time_column_values, column_values, null_indicators, position_count
-
-
# General Methods
-
-def read_int_from_buffer(buffer):
- res, buffer = read_from_buffer(buffer, 4)
- return int.from_bytes(res, "big"), buffer
-
-
-def read_byte_from_buffer(buffer):
- return read_from_buffer(buffer, 1)
-
-
-def read_from_buffer(buffer, size):
- res = buffer[:size]
- buffer = buffer[size:]
- return res, buffer
-
-
-# Read ColumnType
-
-def read_column_types(buffer, value_column_count):
- data_types = []
- for _ in range(value_column_count):
- res, buffer = read_byte_from_buffer(buffer)
- data_types.append(get_data_type(res))
- return data_types, buffer
-
-
-def get_data_type(value):
- if value == b'\x00':
- return TSDataType.BOOLEAN
- elif value == b'\x01':
- return TSDataType.INT32
- elif value == b'\x02':
- return TSDataType.INT64
- elif value == b'\x03':
- return TSDataType.FLOAT
- elif value == b'\x04':
- return TSDataType.DOUBLE
- elif value == b'\x05':
- return TSDataType.TEXT
-
-
def get_data_type_byte_from_str(value):
'''
Args:
@@ -374,191 +149,3 @@ def get_data_type_byte_from_str(value):
return TSDataType.DOUBLE.value
elif value == "text":
return TSDataType.TEXT.value
-
-
-# Read ColumnEncodings
-
-def read_column_encoding(buffer, size):
- encodings = []
- for _ in range(size):
- res, buffer = read_byte_from_buffer(buffer)
- encodings.append(res)
- return encodings, buffer
-
-
-# Read Column
-
-def deserialize_null_indicators(buffer, size):
- may_have_null, buffer = read_byte_from_buffer(buffer)
- if may_have_null != b'\x00':
- return deserialize_from_boolean_array(buffer, size)
- return None, buffer
-
-
-# Serialized data layout:
-# +---------------+-----------------+-------------+
-# | may have null | null indicators | values |
-# +---------------+-----------------+-------------+
-# | byte | list[byte] | list[int64] |
-# +---------------+-----------------+-------------+
-
-def read_time_column(buffer, size):
- null_indicators, buffer = deserialize_null_indicators(buffer, size)
- if null_indicators is None:
- values, buffer = read_from_buffer(
- buffer, size * 8
- )
- else:
- raise Exception("TimeColumn should not contains null value")
- return values, buffer
-
-
-def read_int64_column(buffer, data_type, position_count):
- null_indicators, buffer = deserialize_null_indicators(buffer,
position_count)
- if null_indicators is None:
- size = position_count
- else:
- size = null_indicators.count(False)
-
- if TSDataType.INT64 == data_type or TSDataType.DOUBLE == data_type:
- values, buffer = read_from_buffer(buffer, size * 8)
- return values, null_indicators, buffer
- else:
- raise Exception("Invalid data type: " + data_type)
-
-
-# Serialized data layout:
-# +---------------+-----------------+-------------+
-# | may have null | null indicators | values |
-# +---------------+-----------------+-------------+
-# | byte | list[byte] | list[int32] |
-# +---------------+-----------------+-------------+
-
-def read_int32_column(buffer, data_type, position_count):
- null_indicators, buffer = deserialize_null_indicators(buffer,
position_count)
- if null_indicators is None:
- size = position_count
- else:
- size = null_indicators.count(False)
-
- if TSDataType.INT32 == data_type or TSDataType.FLOAT == data_type:
- values, buffer = read_from_buffer(buffer, size * 4)
- return values, null_indicators, buffer
- else:
- raise Exception("Invalid data type: " + data_type)
-
-
-# Serialized data layout:
-# +---------------+-----------------+-------------+
-# | may have null | null indicators | values |
-# +---------------+-----------------+-------------+
-# | byte | list[byte] | list[byte] |
-# +---------------+-----------------+-------------+
-
-def read_byte_column(buffer, data_type, position_count):
- if data_type != TSDataType.BOOLEAN:
- raise Exception("Invalid data type: " + data_type)
- null_indicators, buffer = deserialize_null_indicators(buffer,
position_count)
- res, buffer = deserialize_from_boolean_array(buffer, position_count)
- return res, null_indicators, buffer
-
-
-def deserialize_from_boolean_array(buffer, size):
- packed_boolean_array, buffer = read_from_buffer(buffer, (size + 7) // 8)
- current_byte = 0
- output = [None] * size
- position = 0
- # read null bits 8 at a time
- while position < (size & ~0b111):
- value = packed_boolean_array[current_byte]
- output[position] = ((value & 0b1000_0000) != 0)
- output[position + 1] = ((value & 0b0100_0000) != 0)
- output[position + 2] = ((value & 0b0010_0000) != 0)
- output[position + 3] = ((value & 0b0001_0000) != 0)
- output[position + 4] = ((value & 0b0000_1000) != 0)
- output[position + 5] = ((value & 0b0000_0100) != 0)
- output[position + 6] = ((value & 0b0000_0010) != 0)
- output[position + 7] = ((value & 0b0000_0001) != 0)
-
- position += 8
- current_byte += 1
- # read last null bits
- if (size & 0b111) > 0:
- value = packed_boolean_array[-1]
- mask = 0b1000_0000
- position = size & ~0b111
- while position < size:
- output[position] = ((value & mask) != 0)
- mask >>= 1
- position += 1
- return output, buffer
-
-
-# Serialized data layout:
-# +---------------+-----------------+-------------+
-# | may have null | null indicators | values |
-# +---------------+-----------------+-------------+
-# | byte | list[byte] | list[entry] |
-# +---------------+-----------------+-------------+
-#
-# Each entry is represented as:
-# +---------------+-------+
-# | value length | value |
-# +---------------+-------+
-# | int32 | bytes |
-# +---------------+-------+
-
-def read_binary_column(buffer, data_type, position_count):
- if data_type != TSDataType.TEXT:
- raise Exception("Invalid data type: " + data_type)
- null_indicators, buffer = deserialize_null_indicators(buffer,
position_count)
-
- if null_indicators is None:
- size = position_count
- else:
- size = null_indicators.count(False)
- values = [None] * size
- for i in range(size):
- length, buffer = read_int_from_buffer(buffer)
- res, buffer = read_from_buffer(buffer, length)
- values[i] = res
- return values, null_indicators, buffer
-
-
-def read_column(encoding, buffer, data_type, position_count):
- if encoding == b'\x00':
- return read_byte_column(buffer, data_type, position_count)
- elif encoding == b'\x01':
- return read_int32_column(buffer, data_type, position_count)
- elif encoding == b'\x02':
- return read_int64_column(buffer, data_type, position_count)
- elif encoding == b'\x03':
- return read_binary_column(buffer, data_type, position_count)
- elif encoding == b'\x04':
- return read_run_length_column(buffer, data_type, position_count)
- else:
- raise Exception("Unsupported encoding: " + encoding)
-
-
-# Serialized data layout:
-# +-----------+-------------------------+
-# | encoding | serialized inner column |
-# +-----------+-------------------------+
-# | byte | list[byte] |
-# +-----------+-------------------------+
-
-def read_run_length_column(buffer, data_type, position_count):
- encoding, buffer = read_byte_from_buffer(buffer)
- column, null_indicators, buffer = read_column(encoding, buffer, data_type,
1)
-
- return repeat(column, data_type, position_count), null_indicators *
position_count, buffer
-
-
-def repeat(buffer, data_type, position_count):
- if data_type == TSDataType.BOOLEAN or data_type == TSDataType.TEXT:
- return buffer * position_count
- else:
- res = bytes(0)
- for _ in range(position_count):
- res.join(buffer)
- return res
diff --git
a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/persistence/ModelInfo.java
b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/persistence/ModelInfo.java
index 21b317d9854..e2beede330a 100644
---
a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/persistence/ModelInfo.java
+++
b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/persistence/ModelInfo.java
@@ -75,6 +75,7 @@ public class ModelInfo implements SnapshotProcessor {
private static final Set<String> builtInAnomalyDetectionModel = new
HashSet<>();
static {
+ builtInForecastModel.add("_timerxl");
builtInForecastModel.add("_ARIMA");
builtInForecastModel.add("_NaiveForecaster");
builtInForecastModel.add("_STLForecaster");
diff --git
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/process/ai/InferenceOperator.java
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/process/ai/InferenceOperator.java
index 9bdb57dc5bd..a384be3ad24 100644
---
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/process/ai/InferenceOperator.java
+++
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/process/ai/InferenceOperator.java
@@ -38,7 +38,6 @@ import com.google.common.util.concurrent.Futures;
import com.google.common.util.concurrent.ListenableFuture;
import org.apache.tsfile.block.column.Column;
import org.apache.tsfile.block.column.ColumnBuilder;
-import org.apache.tsfile.enums.TSDataType;
import org.apache.tsfile.read.common.block.TsBlock;
import org.apache.tsfile.read.common.block.TsBlockBuilder;
import org.apache.tsfile.read.common.block.column.TimeColumnBuilder;
@@ -47,12 +46,9 @@ import org.apache.tsfile.utils.RamUsageEstimator;
import java.nio.ByteBuffer;
import java.util.Arrays;
-import java.util.HashMap;
import java.util.List;
-import java.util.Map;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
-import java.util.stream.Collectors;
import static com.google.common.util.concurrent.Futures.successfulAsList;
@@ -74,8 +70,7 @@ public class InferenceOperator implements ProcessOperator {
private final long maxRetainedSize;
private final long maxReturnSize;
- private final List<String> inputColumnNames;
- private final List<String> targetColumnNames;
+ private final int[] columnIndexes;
private long totalRow;
private int resultIndex = 0;
private List<ByteBuffer> results;
@@ -105,8 +100,11 @@ public class InferenceOperator implements ProcessOperator {
new TsBlockBuilder(
Arrays.asList(modelInferenceDescriptor.getModelInformation().getInputDataType()));
this.modelInferenceExecutor = modelInferenceExecutor;
- this.targetColumnNames = targetColumnNames;
- this.inputColumnNames = inputColumnNames;
+ this.columnIndexes = new int[inputColumnNames.size()];
+ for (int i = 0; i < inputColumnNames.size(); i++) {
+ columnIndexes[i] = targetColumnNames.indexOf(inputColumnNames.get(i));
+ }
+
this.maxRetainedSize = maxRetainedSize;
this.maxReturnSize = maxReturnSize;
this.totalRow = 0;
@@ -232,7 +230,7 @@ public class InferenceOperator implements ProcessOperator {
}
timeColumnBuilder.writeLong(timestamp);
for (int columnIndex = 0; columnIndex <
inputTsBlock.getValueColumnCount(); columnIndex++) {
- columnBuilders[columnIndex].write(inputTsBlock.getColumn(columnIndex),
i);
+
columnBuilders[columnIndexes[columnIndex]].write(inputTsBlock.getColumn(columnIndex),
i);
}
inputTsBlockBuilder.declarePosition();
}
@@ -304,12 +302,6 @@ public class InferenceOperator implements ProcessOperator {
TsBlock finalInputTsBlock = preProcess(inputTsBlock);
TWindowParams windowParams = getWindowParams();
- Map<String, Integer> columnNameIndexMap = new HashMap<>();
-
- for (int i = 0; i < inputColumnNames.size(); i++) {
- columnNameIndexMap.put(inputColumnNames.get(i), i);
- }
-
inferenceExecutionFuture =
Futures.submit(
() -> {
@@ -318,11 +310,6 @@ public class InferenceOperator implements ProcessOperator {
.borrowClient(modelInferenceDescriptor.getTargetAINode())) {
return client.inference(
modelInferenceDescriptor.getModelName(),
- targetColumnNames,
-
Arrays.stream(modelInferenceDescriptor.getModelInformation().getInputDataType())
- .map(TSDataType::toString)
- .collect(Collectors.toList()),
- columnNameIndexMap,
finalInputTsBlock,
modelInferenceDescriptor.getInferenceAttributes(),
windowParams);
@@ -367,11 +354,6 @@ public class InferenceOperator implements ProcessOperator {
+ MemoryEstimationHelper.getEstimatedSizeOfAccountableObject(child)
+
MemoryEstimationHelper.getEstimatedSizeOfAccountableObject(operatorContext)
+ inputTsBlockBuilder.getRetainedSizeInBytes()
- + (inputColumnNames == null
- ? 0
- :
inputColumnNames.stream().mapToLong(RamUsageEstimator::sizeOf).sum())
- + (targetColumnNames == null
- ? 0
- :
targetColumnNames.stream().mapToLong(RamUsageEstimator::sizeOf).sum());
+ + (long) columnIndexes.length * Integer.BYTES;
}
}
diff --git
a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/client/ainode/AINodeClient.java
b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/client/ainode/AINodeClient.java
index 3cc416f0aad..5532d283bcb 100644
---
a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/client/ainode/AINodeClient.java
+++
b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/client/ainode/AINodeClient.java
@@ -56,7 +56,6 @@ import org.slf4j.LoggerFactory;
import java.io.IOException;
import java.nio.ByteBuffer;
-import java.util.List;
import java.util.Map;
import java.util.Optional;
@@ -161,21 +160,12 @@ public class AINodeClient implements AutoCloseable,
ThriftClient {
public TInferenceResp inference(
String modelId,
- List<String> inputColumnNames,
- List<String> inputTypeList,
- Map<String, Integer> columnIndexMap,
TsBlock inputTsBlock,
Map<String, String> inferenceAttributes,
TWindowParams windowParams)
throws TException {
try {
- TInferenceReq inferenceReq =
- new TInferenceReq(
- modelId,
- tsBlockSerde.serialize(inputTsBlock),
- inputTypeList,
- inputColumnNames,
- columnIndexMap);
+ TInferenceReq inferenceReq = new TInferenceReq(modelId,
tsBlockSerde.serialize(inputTsBlock));
if (windowParams != null) {
inferenceReq.setWindowParams(windowParams);
}
@@ -184,10 +174,10 @@ public class AINodeClient implements AutoCloseable,
ThriftClient {
}
return client.inference(inferenceReq);
} catch (IOException e) {
- throw new TException("An exception occurred while serializing input
tsblock", e);
+ throw new TException("An exception occurred while serializing input
data", e);
} catch (TException e) {
logger.warn(
- "Failed to connect to AINode from DataNode when executing {}: {}",
+ "Error happens in AINode when executing {}: {}",
Thread.currentThread().getStackTrace()[1].getMethodName(),
e.getMessage());
throw new TException(MSG_CONNECTION_FAIL);
diff --git a/iotdb-protocol/thrift-ainode/src/main/thrift/ainode.thrift
b/iotdb-protocol/thrift-ainode/src/main/thrift/ainode.thrift
index 5643da743a8..db1a15c2460 100644
--- a/iotdb-protocol/thrift-ainode/src/main/thrift/ainode.thrift
+++ b/iotdb-protocol/thrift-ainode/src/main/thrift/ainode.thrift
@@ -58,11 +58,8 @@ struct TRegisterModelResp {
struct TInferenceReq {
1: required string modelId
2: required binary dataset
- 3: required list<string> typeList
- 4: required list<string> columnNameList
- 5: required map<string, i32> columnNameIndexMap
- 6: optional TWindowParams windowParams
- 7: optional map<string, string> inferenceAttributes
+ 3: optional TWindowParams windowParams
+ 4: optional map<string, string> inferenceAttributes
}
struct TWindowParams {