This is an automated email from the ASF dual-hosted git repository.
yongzao pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/iotdb.git
The following commit(s) were added to refs/heads/master by this push:
new 412fec5a7c1 [AINode] Update forecast interface (#16978)
412fec5a7c1 is described below
commit 412fec5a7c1c6858c365784f3ddc02f824932c77
Author: Leo <[email protected]>
AuthorDate: Wed Jan 7 18:39:26 2026 +0800
[AINode] Update forecast interface (#16978)
---
.../apache/iotdb/ainode/it/AINodeForecastIT.java | 2 +-
.../core/inference/inference_request_pool.py | 12 +-
.../core/inference/pipeline/basic_pipeline.py | 163 ++++++++++++++++++---
.../iotdb/ainode/core/manager/inference_manager.py | 22 ++-
.../core/model/chronos2/pipeline_chronos2.py | 105 +++++++++++--
.../ainode/core/model/sktime/arima/config.json | 2 +-
.../core/model/sktime/configuration_sktime.py | 8 +-
.../model/sktime/exponential_smoothing/config.json | 2 +-
.../ainode/core/model/sktime/modeling_sktime.py | 18 +--
.../core/model/sktime/naive_forecaster/config.json | 2 +-
.../ainode/core/model/sktime/pipeline_sktime.py | 110 +++++++++-----
.../core/model/sktime/stl_forecaster/config.json | 2 +-
.../ainode/core/model/sundial/pipeline_sundial.py | 74 ++++++++--
.../ainode/core/model/timer_xl/pipeline_timer.py | 74 ++++++++--
.../function/tvf/ForecastTableFunction.java | 66 ++++-----
.../thrift-ainode/src/main/thrift/ainode.thrift | 4 +-
16 files changed, 514 insertions(+), 152 deletions(-)
diff --git a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeForecastIT.java b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeForecastIT.java
index c2114ac9499..e2d759d4bcf 100644
--- a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeForecastIT.java
+++ b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeForecastIT.java
@@ -113,7 +113,7 @@ public class AINodeForecastIT {
}
}
- public void forecastTableFunctionErrorTest(
+ public static void forecastTableFunctionErrorTest(
Statement statement, AINodeTestUtils.FakeModelInfo modelInfo) throws SQLException {
// OUTPUT_START_TIME error
String invalidOutputStartTimeSQL =
diff --git a/iotdb-core/ainode/iotdb/ainode/core/inference/inference_request_pool.py b/iotdb-core/ainode/iotdb/ainode/core/inference/inference_request_pool.py
index 3cca9b183c8..fb03e0af520 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/inference/inference_request_pool.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/inference/inference_request_pool.py
@@ -123,11 +123,16 @@ class InferenceRequestPool(mp.Process):
batch_inputs = self._batcher.batch_request(requests).to(
"cpu"
) # The input data should first load to CPU in current version
- batch_inputs = self._inference_pipeline.preprocess(batch_inputs)
+ batch_input_list = []
+ for i in range(batch_inputs.size(0)):
+ batch_input_list.append({"targets": batch_inputs[i]})
+ batch_inputs = self._inference_pipeline.preprocess(
+ batch_input_list, output_length=requests[0].output_length
+ )
if isinstance(self._inference_pipeline, ForecastPipeline):
batch_output = self._inference_pipeline.forecast(
batch_inputs,
- predict_length=requests[0].output_length,
+ output_length=requests[0].output_length,
revin=True,
)
elif isinstance(self._inference_pipeline, ClassificationPipeline):
@@ -143,7 +148,8 @@ class InferenceRequestPool(mp.Process):
else:
batch_output = None
self._logger.error("[Inference] Unsupported pipeline type.")
- batch_output = self._inference_pipeline.postprocess(batch_output)
+ batch_output_list = self._inference_pipeline.postprocess(batch_output)
+ batch_output = torch.stack([output for output in batch_output_list], dim=0)
offset = 0
for request in requests:
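A minimal sketch of the new per-sample calling convention used above; the pipeline object, tensor sizes, and output_length value are illustrative assumptions, not values from this commit:

    import torch

    batch_inputs = torch.randn(4, 1, 96)  # hypothetical batch: 4 univariate series, 96 points each
    # wrap each sample in a {"targets": ...} dict, as the pool now does
    batch_input_list = [{"targets": batch_inputs[i]} for i in range(batch_inputs.size(0))]
    # downstream calls take `output_length` (formerly `predict_length`), e.g.:
    # pipeline.preprocess(batch_input_list, output_length=24)
    # pipeline.forecast(..., output_length=24, revin=True)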
diff --git a/iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/basic_pipeline.py b/iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/basic_pipeline.py
index d7345d3140f..f1704fb90c4 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/basic_pipeline.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/basic_pipeline.py
@@ -21,24 +21,25 @@ from abc import ABC, abstractmethod
import torch
from iotdb.ainode.core.exception import InferenceModelInternalException
+from iotdb.ainode.core.model.model_info import ModelInfo
from iotdb.ainode.core.model.model_loader import load_model
class BasicPipeline(ABC):
- def __init__(self, model_info, **model_kwargs):
+ def __init__(self, model_info: ModelInfo, **model_kwargs):
self.model_info = model_info
self.device = model_kwargs.get("device", "cpu")
self.model = load_model(model_info, device_map=self.device, **model_kwargs)
@abstractmethod
- def preprocess(self, inputs):
+ def preprocess(self, inputs, **infer_kwargs):
"""
Preprocess the input before inference, including shape validation and value transformation.
"""
raise NotImplementedError("preprocess not implemented")
@abstractmethod
- def postprocess(self, outputs: torch.Tensor):
+ def postprocess(self, outputs, **infer_kwargs):
"""
Post-process the outputs after the entire inference task.
"""
@@ -46,59 +47,181 @@ class BasicPipeline(ABC):
class ForecastPipeline(BasicPipeline):
- def __init__(self, model_info, **model_kwargs):
+ def __init__(self, model_info: ModelInfo, **model_kwargs):
super().__init__(model_info, model_kwargs=model_kwargs)
- def preprocess(self, inputs):
+ def preprocess(
+ self,
+ inputs: list[dict[str, dict[str, torch.Tensor] | torch.Tensor]],
+ **infer_kwargs,
+ ):
"""
- The inputs should be 3D tensor: [batch_size, target_count, sequence_length].
+ Preprocess the input data before passing it to the model for inference, validating the shape and type of the input data.
+
+ Args:
+ inputs (list[dict]):
+ The input data, a list of dictionaries, where each dictionary contains:
+ - 'targets': A tensor (1D or 2D) of shape (input_length,) or (target_count, input_length).
+ - 'past_covariates': A dictionary of tensors (optional), where each tensor has shape (input_length,).
+ - 'future_covariates': A dictionary of tensors (optional), where each tensor has shape (output_length,).
+
+ infer_kwargs (dict, optional): Additional keyword arguments for inference, such as:
+ - `output_length`(int): Used to validate 'future_covariates' if provided.
+
+ Raises:
+ ValueError: If the input format is incorrect (e.g., missing keys, invalid tensor shapes).
+
+ Returns:
+ The preprocessed inputs, validated and ready for model inference.
"""
- if len(inputs.shape) != 3:
- raise InferenceModelInternalException(
- f"[Inference] Input must be: [batch_size, target_count,
sequence_length], but receives {inputs.shape}"
+
+ if isinstance(inputs, list):
+ output_length = infer_kwargs.get("output_length", 96)
+ for idx, input_dict in enumerate(inputs):
+ # Check if the dictionary contains the expected keys
+ if not isinstance(input_dict, dict):
+ raise ValueError(f"Input at index {idx} is not a
dictionary.")
+
+ required_keys = ["targets"]
+ for key in required_keys:
+ if key not in input_dict:
+ raise ValueError(
+ f"Key '{key}' is missing in input at index {idx}."
+ )
+
+ # Check 'targets' is torch.Tensor and has the correct shape
+ targets = input_dict["targets"]
+ if not isinstance(targets, torch.Tensor):
+ raise ValueError(
+ f"'targets' must be torch.Tensor, but got
{type(targets)} at index {idx}."
+ )
+ if targets.ndim not in [1, 2]:
+ raise ValueError(
+ f"'targets' must have 1 or 2 dimensions, but got
{targets.ndim} dimensions at index {idx}."
+ )
+ # If targets is 2-d, check if the second dimension is input_length
+ if targets.ndim == 2:
+ n_variates, input_length = targets.shape
+ else:
+ input_length = targets.shape[
+ 0
+ ] # for 1-d targets, shape should be (input_length,)
+
+ # Check 'past_covariates' if it exists (optional)
+ past_covariates = input_dict.get("past_covariates", {})
+ if not isinstance(past_covariates, dict):
+ raise ValueError(
+ f"'past_covariates' must be a dictionary, but got
{type(past_covariates)} at index {idx}."
+ )
+ for cov_key, cov_value in past_covariates.items():
+ if not isinstance(cov_value, torch.Tensor):
+ raise ValueError(
+ f"Each value in 'past_covariates' must be
torch.Tensor, but got {type(cov_value)} for key '{cov_key}' at index {idx}."
+ )
+ if cov_value.ndim != 1 or cov_value.shape[0] != input_length:
+ raise ValueError(
+ f"Each covariate in 'past_covariates' must have
shape ({input_length},), but got shape {cov_value.shape} for key '{cov_key}' at
index {idx}."
+ )
+
+ # Check 'future_covariates' if it exists (optional)
+ future_covariates = input_dict.get("future_covariates", {})
+ if not isinstance(future_covariates, dict):
+ raise ValueError(
+ f"'future_covariates' must be a dictionary, but got
{type(future_covariates)} at index {idx}."
+ )
+ # If future_covariates exists, check if they are a subset of past_covariates
+ if future_covariates:
+ for cov_key, cov_value in future_covariates.items():
+ if cov_key not in past_covariates:
+ raise ValueError(
+ f"Key '{cov_key}' in 'future_covariates' is
not in 'past_covariates' at index {idx}."
+ )
+ if not isinstance(cov_value, torch.Tensor):
+ raise ValueError(
+ f"Each value in 'future_covariates' must be
torch.Tensor, but got {type(cov_value)} for key '{cov_key}' at index {idx}."
+ )
+ if cov_value.ndim != 1 or cov_value.shape[0] != output_length:
+ raise ValueError(
+ f"Each covariate in 'future_covariates' must
have shape ({output_length},), but got shape {cov_value.shape} for key
'{cov_key}' at index {idx}."
+ )
+ else:
+ raise ValueError(
+ f"The inputs must be a list of dictionaries, but got
{type(inputs)}."
)
return inputs
@abstractmethod
def forecast(self, inputs, **infer_kwargs):
+ """
+ Perform forecasting on the given inputs.
+
+ Parameters:
+ inputs: The input data used for making predictions. The type and structure
+ depend on the specific implementation of the model.
+ **infer_kwargs: Additional inference parameters such as:
+ - `output_length`(int): The number of time points that the model should generate.
+
+ Returns:
+ The forecasted output, which will depend on the specific model's implementation.
+ """
pass
- def postprocess(self, outputs: torch.Tensor):
+ def postprocess(
+ self, outputs: list[torch.Tensor], **infer_kwargs
+ ) -> list[torch.Tensor]:
"""
- The outputs should be 3D tensor: [batch_size, target_count, predict_length].
+ Postprocess the model outputs after inference, validating the shape of the output data and ensuring it matches the expected dimensions.
+
+ Args:
+ outputs:
+ The model outputs, a list of 2D tensors, where each tensor has shape `[target_count, output_length]`.
+
+ Raises:
+ InferenceModelInternalException: If the output tensor has an invalid shape (e.g., wrong number of dimensions).
+ ValueError: If the output format is incorrect.
+
+ Returns:
+ list[torch.Tensor]:
+ The postprocessed outputs, which will be a list of 2D tensors.
"""
- if len(outputs.shape) != 3:
- raise InferenceModelInternalException(
- f"[Inference] Output must be: [batch_size, target_count,
predict_length], but receives {outputs.shape}"
+ if isinstance(outputs, list):
+ for idx, output in enumerate(outputs):
+ if output.ndim != 2:
+ raise InferenceModelInternalException(
+ f"Output in outputs_list should be 2D-tensor, but
receives {output.ndim} dims at index {idx}."
+ )
+ else:
+ raise ValueError(
+ f"The outputs should be a list of 2D-tensors, but got
{type(outputs)}."
)
return outputs
class ClassificationPipeline(BasicPipeline):
- def __init__(self, model_info, **model_kwargs):
+ def __init__(self, model_info: ModelInfo, **model_kwargs):
super().__init__(model_info, model_kwargs=model_kwargs)
- def preprocess(self, inputs):
+ def preprocess(self, inputs, **kwargs):
return inputs
@abstractmethod
def classify(self, inputs, **kwargs):
pass
- def postprocess(self, outputs: torch.Tensor):
+ def postprocess(self, outputs, **kwargs):
return outputs
class ChatPipeline(BasicPipeline):
- def __init__(self, model_info, **model_kwargs):
+ def __init__(self, model_info: ModelInfo, **model_kwargs):
super().__init__(model_info, model_kwargs=model_kwargs)
- def preprocess(self, inputs):
+ def preprocess(self, inputs, **kwargs):
return inputs
@abstractmethod
def chat(self, inputs, **kwargs):
pass
- def postprocess(self, outputs: torch.Tensor):
+ def postprocess(self, outputs, **kwargs):
return outputs
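As a reference for the contract established above, here is a minimal input dict that passes the new ForecastPipeline.preprocess validation; the covariate name "temp" and the lengths are illustrative assumptions:

    import torch

    input_length, output_length = 96, 24
    sample = {
        # (target_count, input_length); a 1D tensor of shape (input_length,) also passes
        "targets": torch.randn(2, input_length),
        # each past covariate must be a 1D tensor of shape (input_length,)
        "past_covariates": {"temp": torch.randn(input_length)},
        # future covariates must be a subset of past covariates, each of shape (output_length,)
        "future_covariates": {"temp": torch.randn(output_length)},
    }
    # pipeline.preprocess([sample], output_length=output_length) would accept this input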
diff --git a/iotdb-core/ainode/iotdb/ainode/core/manager/inference_manager.py b/iotdb-core/ainode/iotdb/ainode/core/manager/inference_manager.py
index d3f77e993a1..ada641dd54c 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/manager/inference_manager.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/manager/inference_manager.py
@@ -20,7 +20,6 @@ import threading
import time
from typing import Dict
-import pandas as pd
import torch
import torch.multiprocessing as mp
@@ -183,6 +182,13 @@ class InferenceManager:
inference_attrs = extract_attrs(req)
output_length = int(inference_attrs.pop("output_length", 96))
+
+ # model_inputs_list: Each element is a dict, which contains the following keys:
+ # `targets`: The input tensor for the target variable(s), whose shape is [target_count, input_length].
+ model_inputs_list: list[
+ dict[str, torch.Tensor | dict[str, torch.Tensor]]
+ ] = [{"targets": inputs[0]}]
+
if (
output_length
> AINodeDescriptor().get_config().get_ain_inference_max_output_length()
@@ -200,17 +206,21 @@ class InferenceManager:
infer_req = InferenceRequest(
req_id=generate_req_id(),
model_id=model_id,
- inputs=inputs,
+ inputs=torch.stack(
+ [data["targets"] for data in model_inputs_list], dim=0
+ ),
output_length=output_length,
)
outputs = self._process_request(infer_req)
else:
model_info = self._model_manager.get_model_info(model_id)
inference_pipeline = load_pipeline(model_info, device="cpu")
- inputs = inference_pipeline.preprocess(inputs)
+ inputs = inference_pipeline.preprocess(
+ model_inputs_list, output_length=output_length
+ )
if isinstance(inference_pipeline, ForecastPipeline):
outputs = inference_pipeline.forecast(
- inputs, predict_length=output_length, **inference_attrs
+ inputs, output_length=output_length, **inference_attrs
)
elif isinstance(inference_pipeline, ClassificationPipeline):
outputs = inference_pipeline.classify(inputs)
@@ -223,8 +233,8 @@ class InferenceManager:
# convert tensor into tsblock for the output in each batch
output_list = []
- for batch_idx in range(outputs.size(0)):
- output = convert_tensor_to_tsblock(outputs[batch_idx])
+ for batch_idx, output in enumerate(outputs):
+ output = convert_tensor_to_tsblock(output)
output_list.append(output)
return resp_cls(
diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/chronos2/pipeline_chronos2.py b/iotdb-core/ainode/iotdb/ainode/core/model/chronos2/pipeline_chronos2.py
index b99a9307841..3fdc7b41b17 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/model/chronos2/pipeline_chronos2.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/model/chronos2/pipeline_chronos2.py
@@ -21,7 +21,6 @@ import torch
from einops import rearrange, repeat
from torch.utils.data import DataLoader
-from iotdb.ainode.core.exception import InferenceModelInternalException
from iotdb.ainode.core.inference.pipeline.basic_pipeline import ForecastPipeline
from iotdb.ainode.core.log import Logger
from iotdb.ainode.core.model.chronos2.dataset import Chronos2Dataset, DatasetMode
@@ -37,8 +36,35 @@ class Chronos2Pipeline(ForecastPipeline):
def __init__(self, model_info, **model_kwargs):
super().__init__(model_info, model_kwargs=model_kwargs)
- def preprocess(self, inputs):
- inputs = super().preprocess(inputs)
+ def preprocess(self, inputs, **infer_kwargs):
+ """
+ Preprocess input data of chronos2.
+
+ Parameters
+ ----------
+ inputs : list of dict
+ A list of dictionaries containing input data. Each dictionary contains:
+ - 'targets': A tensor (1D or 2D) of shape (input_length,) or (target_count, input_length).
+ - 'past_covariates': A dictionary of tensors (optional), where each tensor has shape (input_length,).
+ - 'future_covariates': A dictionary of tensors (optional), where each tensor has shape (output_length,).
+
+ infer_kwargs: Additional keyword arguments for inference, such as:
+ - `output_length`(int): Used to validate 'future_covariates' if provided.
+
+ Returns
+ -------
+ list of dict
+ Processed inputs with the following structure for each dictionary:
+ - 'target': torch.Tensor
+ The renamed target tensor (originally 'targets'), shape [target_count, input_length].
+ - 'past_covariates' (optional): dict of str to torch.Tensor
+ Unchanged past covariates.
+ - 'future_covariates' (optional): dict of str to torch.Tensor
+ Unchanged future covariates.
+ """
+ super().preprocess(inputs, **infer_kwargs)
+ for item in inputs:
+ item["target"] = item.pop("targets")
return inputs
@property
@@ -206,9 +232,28 @@ class Chronos2Pipeline(ForecastPipeline):
return prediction, context, future_covariates
- def forecast(self, inputs, **infer_kwargs):
+ def forecast(self, inputs, **infer_kwargs) -> list[torch.Tensor]:
+ """
+ Generate forecasts for the input time series.
+
+ Parameters
+ ----------
+ inputs :
+ - A 3D `torch.Tensor` or `np.ndarray` of shape (batch_size, n_variates, history_length).
+ - A list of 1D or 2D `torch.Tensor` or `np.ndarray`, where each element has shape (history_length,) or (n_variates, history_length).
+ - A list of dictionaries, each with:
+ - `target` (required): 1D or 2D `torch.Tensor` or `np.ndarray` of shape (history_length,) or (n_variates, history_length).
+ - `past_covariates` (optional): dict of past-only covariates with 1D `torch.Tensor` or `np.ndarray`.
+ - `future_covariates` (optional): dict of future covariates with 1D `torch.Tensor` or `np.ndarray`.
+
+ **infer_kwargs** : Additional arguments for inference.
+
+ Returns
+ -------
+ list of torch.Tensor : The model's predictions, each of shape (n_variates, n_quantiles, prediction_length).
+ """
model_prediction_length = self.model_prediction_length
- prediction_length = infer_kwargs.get("predict_length", 96)
+ prediction_length = infer_kwargs.get("output_length", 96)
# The maximum number of output patches to generate in a single forward pass before the long-horizon heuristic kicks in. Note: A value larger
# than the model's default max_output_patches may lead to degradation in forecast accuracy, defaults to a model-specific value
max_output_patches = infer_kwargs.get(
@@ -218,10 +263,13 @@ class Chronos2Pipeline(ForecastPipeline):
# are appended to the historical context and input into the model autoregressively to generate long-horizon predictions. Note that the
# effective batch size increases by a factor of `len(unrolled_quantiles)` when making long-horizon predictions,
# by default [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
+
+ # Note that this parameter is used for long-horizon forecasting; the default quantile_levels
+ # are [0.01, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5,
+ # 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 0.99]
unrolled_quantiles = infer_kwargs.get(
"unrolled_quantiles", [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
)
-
if not set(unrolled_quantiles).issubset(self.quantiles):
raise ValueError(
f"Unrolled quantiles must be a subset of the model's
quantiles. "
@@ -236,7 +284,9 @@ class Chronos2Pipeline(ForecastPipeline):
)
logger.warning(msg)
- context_length = inputs.shape[-1]
+ # The maximum context length used during inference,
+ # by default set to the model's default context length
+ context_length = infer_kwargs.get("context_length", self.model_context_length)
if context_length > self.model_context_length:
logger.warning(
f"The specified context_length {context_length} is greater
than the model's default context length {self.model_context_length}. "
@@ -244,11 +294,16 @@ class Chronos2Pipeline(ForecastPipeline):
)
context_length = self.model_context_length
+ # The batch size used for prediction.
+ # Note that the batch size here means the number of time series,
+ # including target(s) and covariates, which are input into the model.
+ batch_size = infer_kwargs.get("batch_size", 256)
+
test_dataset = Chronos2Dataset.convert_inputs(
inputs=inputs,
context_length=context_length,
prediction_length=prediction_length,
- batch_size=256,
+ batch_size=batch_size,
output_patch_size=self.model_output_patch_size,
mode=DatasetMode.TEST,
)
@@ -268,6 +323,13 @@ class Chronos2Pipeline(ForecastPipeline):
batch_future_covariates = batch["future_covariates"]
batch_target_idx_ranges = batch["target_idx_ranges"]
+ # If True, cross-learning is enabled,
+ # i.e., all the tasks in `inputs` will be predicted jointly and the model will share information across all inputs,
+ # by default False
+ predict_batches_jointly = infer_kwargs.get("predict_batches_jointly", False)
+ if predict_batches_jointly:
+ batch_group_ids = torch.zeros_like(batch_group_ids)
+
batch_prediction = self._predict_batch(
context=batch_context,
group_ids=batch_group_ids,
@@ -387,5 +449,28 @@ class Chronos2Pipeline(ForecastPipeline):
return prediction
- def postprocess(self, output: torch.Tensor):
- return output[0].mean(dim=1, keepdim=True)
+ def postprocess(
+ self, outputs: list[torch.Tensor], **infer_kwargs
+ ) -> list[torch.Tensor]:
+ """
+ Postprocesses the model's forecast outputs by selecting the 0.5 quantile or averaging over quantiles.
+
+ Args:
+ outputs (list[torch.Tensor]): List of forecast outputs, where each output is a 3D-tensor with shape [target_count, quantile_count, output_length].
+
+ Returns:
+ list[torch.Tensor]: Processed list of forecast outputs, each a 2D-tensor with shape [target_count, output_length].
+ """
+ outputs_list = []
+ for output in outputs:
+ # Check if 0.5 quantile is available
+ if 0.5 in self.quantiles:
+ idx = self.quantiles.index(0.5)
+ # Get the 0.5 quantile value
+ outputs_list.append(output[:, idx, :])
+ else:
+ # If 0.5 quantile is not provided,
+ # get the mean of all quantiles
+ outputs_list.append(output.mean(dim=1))
+ super().postprocess(outputs_list, **infer_kwargs)
+ return outputs_list
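The new Chronos-2 postprocess reduces each 3D quantile forecast to a point forecast. A standalone sketch of that selection logic (the quantile list and shapes are assumptions):

    import torch

    quantiles = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
    output = torch.randn(1, len(quantiles), 24)  # [target_count, quantile_count, output_length]
    if 0.5 in quantiles:
        point = output[:, quantiles.index(0.5), :]  # take the median quantile
    else:
        point = output.mean(dim=1)  # otherwise average over all quantiles
    # point: [target_count, output_length]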
diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/sktime/arima/config.json b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/arima/config.json
index 1561124badd..c9f80477643 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/model/sktime/arima/config.json
+++ b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/arima/config.json
@@ -1,7 +1,7 @@
{
"model_type": "sktime",
"model_id": "arima",
- "predict_length": 1,
+ "output_length": 1,
"order": [1, 0, 0],
"seasonal_order": [0, 0, 0, 0],
"start_params": null,
diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/sktime/configuration_sktime.py b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/configuration_sktime.py
index d9d20545af6..f6d0c94ee39 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/model/sktime/configuration_sktime.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/configuration_sktime.py
@@ -153,7 +153,7 @@ class AttributeConfig:
# Model configuration definitions - using concise dictionary format
MODEL_CONFIGS = {
"NAIVE_FORECASTER": {
- "predict_length": AttributeConfig("predict_length", 1, "int", 1, 5000),
+ "output_length": AttributeConfig("output_length", 1, "int", 1, 5000),
"strategy": AttributeConfig(
"strategy", "last", "str", choices=["last", "mean", "drift"]
),
@@ -161,7 +161,7 @@ MODEL_CONFIGS = {
"sp": AttributeConfig("sp", 1, "int", 1, 5000),
},
"EXPONENTIAL_SMOOTHING": {
- "predict_length": AttributeConfig("predict_length", 1, "int", 1, 5000),
+ "output_length": AttributeConfig("output_length", 1, "int", 1, 5000),
"damped_trend": AttributeConfig("damped_trend", False, "bool"),
"initialization_method": AttributeConfig(
"initialization_method",
@@ -174,7 +174,7 @@ MODEL_CONFIGS = {
"use_brute": AttributeConfig("use_brute", False, "bool"),
},
"ARIMA": {
- "predict_length": AttributeConfig("predict_length", 1, "int", 1, 5000),
+ "output_length": AttributeConfig("output_length", 1, "int", 1, 5000),
"order": AttributeConfig("order", (1, 0, 0), "tuple", value_type=int),
"seasonal_order": AttributeConfig(
"seasonal_order", (0, 0, 0, 0), "tuple", value_type=int
@@ -212,7 +212,7 @@ MODEL_CONFIGS = {
"concentrate_scale": AttributeConfig("concentrate_scale", False,
"bool"),
},
"STL_FORECASTER": {
- "predict_length": AttributeConfig("predict_length", 1, "int", 1, 5000),
+ "output_length": AttributeConfig("output_length", 1, "int", 1, 5000),
"sp": AttributeConfig("sp", 2, "int", 1, 5000),
"seasonal": AttributeConfig("seasonal", 7, "int", 1, 5000),
"trend": AttributeConfig("trend", None, "int"),
diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/sktime/exponential_smoothing/config.json b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/exponential_smoothing/config.json
index 4126d9de857..a0550548bbe 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/model/sktime/exponential_smoothing/config.json
+++ b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/exponential_smoothing/config.json
@@ -1,7 +1,7 @@
{
"model_type": "sktime",
"model_id": "exponential_smoothing",
- "predict_length": 1,
+ "output_length": 1,
"damped_trend": false,
"initialization_method": "estimated",
"optimized": true,
diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/sktime/modeling_sktime.py b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/modeling_sktime.py
index 9ddbcab286f..11277eb26e5 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/model/sktime/modeling_sktime.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/modeling_sktime.py
@@ -59,11 +59,11 @@ class ForecastingModel(SktimeModel):
def generate(self, data, **kwargs):
"""Execute forecasting"""
try:
- predict_length = kwargs.get(
- "predict_length", self._attributes["predict_length"]
+ output_length = kwargs.get(
+ "output_length", self._attributes["output_length"]
)
self._model.fit(data)
- output = self._model.predict(fh=range(predict_length))
+ output = self._model.predict(fh=range(output_length))
return np.array(output, dtype=np.float64)
except Exception as e:
raise InferenceModelInternalException(str(e))
@@ -75,8 +75,8 @@ class DetectionModel(SktimeModel):
def generate(self, data, **kwargs):
"""Execute detection"""
try:
- predict_length = kwargs.get("predict_length", data.size)
- output = self._model.fit_transform(data[:predict_length])
+ output_length = kwargs.get("output_length", data.size)
+ output = self._model.fit_transform(data[:output_length])
if isinstance(output, pd.DataFrame):
return np.array(output["labels"], dtype=np.int32)
else:
@@ -91,7 +91,7 @@ class ArimaModel(ForecastingModel):
def __init__(self, attributes: Dict[str, Any]):
super().__init__(attributes)
self._model = ARIMA(
- **{k: v for k, v in attributes.items() if k != "predict_length"}
+ **{k: v for k, v in attributes.items() if k != "output_length"}
)
@@ -101,7 +101,7 @@ class ExponentialSmoothingModel(ForecastingModel):
def __init__(self, attributes: Dict[str, Any]):
super().__init__(attributes)
self._model = ExponentialSmoothing(
- **{k: v for k, v in attributes.items() if k != "predict_length"}
+ **{k: v for k, v in attributes.items() if k != "output_length"}
)
@@ -111,7 +111,7 @@ class NaiveForecasterModel(ForecastingModel):
def __init__(self, attributes: Dict[str, Any]):
super().__init__(attributes)
self._model = NaiveForecaster(
- **{k: v for k, v in attributes.items() if k != "predict_length"}
+ **{k: v for k, v in attributes.items() if k != "output_length"}
)
@@ -121,7 +121,7 @@ class STLForecasterModel(ForecastingModel):
def __init__(self, attributes: Dict[str, Any]):
super().__init__(attributes)
self._model = STLForecaster(
- **{k: v for k, v in attributes.items() if k != "predict_length"}
+ **{k: v for k, v in attributes.items() if k != "output_length"}
)
diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/sktime/naive_forecaster/config.json b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/naive_forecaster/config.json
index 3dadd7c3b1e..797a2fd4b98 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/model/sktime/naive_forecaster/config.json
+++ b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/naive_forecaster/config.json
@@ -1,7 +1,7 @@
{
"model_type": "sktime",
"model_id": "naive_forecaster",
- "predict_length": 1,
+ "output_length": 1,
"strategy": "last",
"window_length": null,
"sp": 1
diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/sktime/pipeline_sktime.py b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/pipeline_sktime.py
index a10a0a134a2..964ab156e26 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/model/sktime/pipeline_sktime.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/pipeline_sktime.py
@@ -22,54 +22,94 @@ import torch
from iotdb.ainode.core.exception import InferenceModelInternalException
from iotdb.ainode.core.inference.pipeline.basic_pipeline import ForecastPipeline
+from iotdb.ainode.core.log import Logger
+from iotdb.ainode.core.model.model_info import ModelInfo
+
+logger = Logger()
class SktimePipeline(ForecastPipeline):
- def __init__(self, model_info, **model_kwargs):
+ def __init__(self, model_info: ModelInfo, **model_kwargs):
model_kwargs.pop("device", None) # sktime models run on CPU
super().__init__(model_info, model_kwargs=model_kwargs)
- def preprocess(self, inputs):
- inputs = super().preprocess(inputs)
+ def preprocess(
+ self,
+ inputs: list[dict[str, dict[str, torch.Tensor] | torch.Tensor]],
+ **infer_kwargs,
+ ) -> list[pd.Series]:
+ """
+ Preprocess the input data for forecasting.
+
+ Parameters:
+ inputs (list): A list of dictionaries containing input data with key 'targets'.
+
+ Returns:
+ list of pd.Series: Processed inputs for the model, each of shape [input_length,].
+ """
+ model_id = self.model_info.model_id
+
+ inputs = super().preprocess(inputs, **infer_kwargs)
+
+ # Here, we assume each element in the list has the same history_length;
+ # otherwise, the model cannot proceed
+ if inputs[0].get("past_covariates", None) or inputs[0].get(
+ "future_covariates", None
+ ):
+ logger.warning(
+ f"[Inference] Past_covariates and future_covariates will be
ignored, as they are not supported for model {model_id}."
+ )
+
+ # stack the data and get a 3D-tensor: [batch_size, target_count(1), input_length]
+ inputs = torch.stack([data["targets"] for data in inputs], dim=0)
if inputs.shape[1] != 1:
raise InferenceModelInternalException(
- f"[Inference] Sktime model only supports univarate forecast,
but receives {inputs.shape[1]} target variables."
+ f"Model {model_id} only supports univariate forecast, but
receives {inputs.shape[1]} target variables."
)
+ # Transform into a 2D-tensor: [batch_size, input_length]
inputs = inputs.squeeze(1)
+ # Transform into a list of Series with each of shape [input_length,]
+ inputs = [pd.Series(data.cpu().numpy()) for i, data in enumerate(inputs)]
+
return inputs
- def forecast(self, inputs, **infer_kwargs):
- predict_length = infer_kwargs.get("predict_length", 96)
-
- # Convert to pandas Series for sktime (sktime expects Series or DataFrame)
- # Handle batch dimension: if batch_size > 1, process each sample separately
- if len(inputs.shape) == 2 and inputs.shape[0] > 1:
- # Batch processing: convert each row to Series
- outputs = []
- for i in range(inputs.shape[0]):
- series = pd.Series(
- inputs[i].cpu().numpy()
- if isinstance(inputs, torch.Tensor)
- else inputs[i]
- )
- output = self.model.generate(series, predict_length=predict_length)
- outputs.append(output)
- outputs = np.array(outputs)
- else:
- # Single sample: convert to Series
- if isinstance(inputs, torch.Tensor):
- series = pd.Series(inputs.squeeze().cpu().numpy())
- else:
- series = pd.Series(inputs.squeeze())
- outputs = self.model.generate(series, predict_length=predict_length)
- # Add batch dimension if needed
- if len(outputs.shape) == 1:
- outputs = outputs[np.newaxis, :]
+ def forecast(self, inputs: list[pd.Series], **infer_kwargs) -> np.ndarray:
+ """
+ Generate forecasts from the model for given inputs.
+
+ Parameters:
+ inputs (list[Series]): A list of input data for forecasting, each of shape [input_length,].
+ **infer_kwargs: Additional inference parameters such as:
+ - 'output_length'(int): The number of time points that the model should generate.
+
+ Returns:
+ np.ndarray: Forecasted outputs.
+ """
+ output_length = infer_kwargs.get("output_length", 96)
+
+ # Batch processing
+ outputs = []
+ for series in inputs:
+ output = self.model.generate(series, output_length=output_length)
+ outputs.append(output)
+ outputs = np.array(outputs)
return outputs
- def postprocess(self, outputs):
- if isinstance(outputs, np.ndarray):
- outputs = torch.from_numpy(outputs).float()
- outputs = super().postprocess(outputs.unsqueeze(1))
+ def postprocess(self, outputs: np.ndarray, **infer_kwargs) -> list[torch.Tensor]:
+ """
+ Postprocess the model's outputs.
+
+ Parameters:
+ outputs (np.ndarray): Model output to be processed.
+ **infer_kwargs: Additional inference parameters.
+
+ Returns:
+ list of torch.Tensor: List of 2D-tensors with shape [target_count(1), output_length].
+ """
+
+ # Transform outputs into a 2D-tensor: [batch_size, output_length]
+ outputs = torch.from_numpy(outputs).float()
+ outputs = [outputs[i].unsqueeze(0) for i in range(outputs.size(0))]
+ outputs = super().postprocess(outputs, **infer_kwargs)
return outputs
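A minimal sketch of the sktime preprocessing path above, which stacks the per-sample 'targets', enforces the univariate constraint, and hands the model a list of pandas Series (the sizes are assumptions):

    import pandas as pd
    import torch

    inputs = [{"targets": torch.randn(1, 96)} for _ in range(3)]  # target_count must be 1
    stacked = torch.stack([d["targets"] for d in inputs], dim=0)  # [3, 1, 96]
    series_list = [pd.Series(t.numpy()) for t in stacked.squeeze(1)]  # 3 Series of length 96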
diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/sktime/stl_forecaster/config.json b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/stl_forecaster/config.json
index bfe71dbc486..dff4d5d15a1 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/model/sktime/stl_forecaster/config.json
+++ b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/stl_forecaster/config.json
@@ -1,7 +1,7 @@
{
"model_type": "sktime",
"model_id": "stl_forecaster",
- "predict_length": 1,
+ "output_length": 1,
"sp": 2,
"seasonal": 7,
"trend": null,
diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/sundial/pipeline_sundial.py b/iotdb-core/ainode/iotdb/ainode/core/model/sundial/pipeline_sundial.py
index 69422dfadb2..1715f190e32 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/model/sundial/pipeline_sundial.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/model/sundial/pipeline_sundial.py
@@ -20,42 +20,92 @@ import torch
from iotdb.ainode.core.exception import InferenceModelInternalException
from iotdb.ainode.core.inference.pipeline.basic_pipeline import ForecastPipeline
+from iotdb.ainode.core.log import Logger
+from iotdb.ainode.core.model.model_info import ModelInfo
+
+logger = Logger()
class SundialPipeline(ForecastPipeline):
- def __init__(self, model_info, **model_kwargs):
+ def __init__(self, model_info: ModelInfo, **model_kwargs):
super().__init__(model_info, model_kwargs=model_kwargs)
- def preprocess(self, inputs):
+ def preprocess(self, inputs, **infer_kwargs) -> torch.Tensor:
"""
- The inputs shape should be 3D, but Sundial only supports 2D tensor: [batch_size, sequence_length],
- we need to squeeze the target_count dimension.
+ Preprocess the input data by converting it to a 2D tensor (Sundial only supports 2D inputs).
+
+ Parameters:
+ inputs (list): A list of dictionaries containing input data,
+ where each dictionary includes a "targets" key.
+ **infer_kwargs: Additional keyword arguments passed to the method.
+
+ Returns:
+ torch.Tensor: A 2D tensor with shape [batch_size, input_length] after squeezing
+ the target_count dimension.
+
+ Raises:
+ InferenceModelInternalException: If the model receives more than one target variable
+ (i.e., when inputs.shape[1] != 1).
"""
- inputs = super().preprocess(inputs)
+ model_id = self.model_info.model_id
+ inputs = super().preprocess(inputs, **infer_kwargs)
+ # Here, we assume each element in the list has the same history_length;
+ # otherwise, the model cannot proceed
+ if inputs[0].get("past_covariates", None) or inputs[0].get(
+ "future_covariates", None
+ ):
+ logger.warning(
+ f"[Inference] Past_covariates and future_covariates will be
ignored, as they are not supported for model {model_id}."
+ )
+
+ # stack the data and get a 3D-tensor: [batch_size, target_count(1), input_length]
+ inputs = torch.stack([data["targets"] for data in inputs], dim=0)
if inputs.shape[1] != 1:
raise InferenceModelInternalException(
- f"[Inference] Model sundial only supports univarate forecast,
but receives {inputs.shape[1]} target variables."
+ f"Model {model_id} only supports univariate forecast, but
receives {inputs.shape[1]} target variables."
)
inputs = inputs.squeeze(1)
return inputs
- def forecast(self, inputs, **infer_kwargs):
- predict_length = infer_kwargs.get("predict_length", 96)
+ def forecast(self, inputs: torch.Tensor, **infer_kwargs) -> torch.Tensor:
+ """
+ Generate forecasted outputs using the Sundial model based on the provided inputs.
+
+ Parameters:
+ inputs (torch.Tensor): A 2D tensor of shape [batch_size, input_length].
+ **infer_kwargs: Additional inference parameters:
+ - output_length (int): The length of the forecast output (default: 96).
+ - num_samples (int): The number of samples to generate (default: 10).
+ - revin (bool): Whether to apply revin (default: True).
+
+ Returns:
+ torch.Tensor: A tensor containing the forecasted outputs with shape [batch_size, num_samples, output_length].
+ """
+ output_length = infer_kwargs.get("output_length", 96)
num_samples = infer_kwargs.get("num_samples", 10)
revin = infer_kwargs.get("revin", True)
outputs = self.model.generate(
inputs,
- max_new_tokens=predict_length,
+ max_new_tokens=output_length,
num_samples=num_samples,
revin=revin,
)
return outputs
- def postprocess(self, outputs: torch.Tensor):
+ def postprocess(self, outputs: torch.Tensor, **infer_kwargs) -> list[torch.Tensor]:
"""
- The outputs shape should be 3D, we need to take the mean value across num_samples dimension and expand dims.
+ Postprocess the model's output by averaging across the num_samples dimension and
+ expanding the dimensions to match the expected shape.
+
+ Parameters:
+ outputs (torch.Tensor): The raw output 3D-tensor from the model with shape [batch_size, num_samples, output_length].
+ **infer_kwargs: Additional inference parameters passed to the method.
+
+ Returns:
+ list of torch.Tensor: A list of 2D tensors with shape [target_count(1), output_length].
"""
outputs = outputs.mean(dim=1).unsqueeze(1)
- outputs = super().postprocess(outputs)
+ outputs = [outputs[i] for i in range(outputs.size(0))]
+ outputs = super().postprocess(outputs, **infer_kwargs)
return outputs
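The Sundial postprocess now averages the sampled trajectories and returns the shared list-of-2D-tensors contract. A self-contained sketch with assumed shapes:

    import torch

    raw = torch.randn(2, 10, 96)  # [batch_size, num_samples, output_length] from generate()
    averaged = raw.mean(dim=1).unsqueeze(1)  # [batch_size, 1, output_length]
    outputs = [averaged[i] for i in range(averaged.size(0))]  # list of [1, output_length] tensors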
diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/timer_xl/pipeline_timer.py b/iotdb-core/ainode/iotdb/ainode/core/model/timer_xl/pipeline_timer.py
index 041cb8c7cba..bb54eed4ec6 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/model/timer_xl/pipeline_timer.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/model/timer_xl/pipeline_timer.py
@@ -20,37 +20,83 @@ import torch
from iotdb.ainode.core.exception import InferenceModelInternalException
from iotdb.ainode.core.inference.pipeline.basic_pipeline import ForecastPipeline
+from iotdb.ainode.core.log import Logger
+from iotdb.ainode.core.model.model_info import ModelInfo
+
+logger = Logger()
class TimerPipeline(ForecastPipeline):
- def __init__(self, model_info, **model_kwargs):
+ def __init__(self, model_info: ModelInfo, **model_kwargs):
super().__init__(model_info, model_kwargs=model_kwargs)
- def preprocess(self, inputs):
+ def preprocess(self, inputs, **infer_kwargs) -> torch.Tensor:
"""
- The inputs shape should be 3D, but Timer-XL only supports 2D tensor: [batch_size, sequence_length],
- we need to squeeze the target_count dimension.
+ Preprocess the input data by converting it to a 2D tensor (Timer-XL only supports 2D inputs).
+
+ Parameters:
+ inputs (list): A list of dictionaries containing input data,
+ where each dictionary should include a "targets" key.
+ **infer_kwargs: Additional keyword arguments passed to the method.
+
+ Returns:
+ torch.Tensor: A 2D tensor of shape [batch_size, input_length] after squeezing
+ the target_count dimension.
+
+ Raises:
+ InferenceModelInternalException: If the model receives more than one target variable
+ (i.e., when inputs.shape[1] != 1).
"""
- inputs = super().preprocess(inputs)
+ model_id = self.model_info.model_id
+ inputs = super().preprocess(inputs, **infer_kwargs)
+ # Here, we assume each element in the list has the same history_length;
+ # otherwise, the model cannot proceed
+ if inputs[0].get("past_covariates", None) or inputs[0].get(
+ "future_covariates", None
+ ):
+ logger.warning(
+ f"[Inference] Past_covariates and future_covariates will be
ignored, as they are not supported for model {model_id}."
+ )
+
+ # stack the data and get a 3D-tensor: [batch_size, target_count(1), input_length]
+ inputs = torch.stack([data["targets"] for data in inputs], dim=0)
if inputs.shape[1] != 1:
raise InferenceModelInternalException(
- f"[Inference] Model timer_xl only supports univarate forecast,
but receives {inputs.shape[1]} target variables."
+ f"Model {model_id} only supports univariate forecast, but
receives {inputs.shape[1]} target variables."
)
inputs = inputs.squeeze(1)
return inputs
- def forecast(self, inputs, **infer_kwargs):
- predict_length = infer_kwargs.get("predict_length", 96)
+ def forecast(self, inputs: torch.Tensor, **infer_kwargs) -> torch.Tensor:
+ """
+ Generate forecasted outputs using the model based on the provided inputs.
+
+ Parameters:
+ inputs (torch.Tensor): A 2D tensor of shape [batch_size, input_length].
+ **infer_kwargs: Additional inference parameters:
+ - output_length (int): The length of the forecast output (default: 96).
+ - revin (bool): Whether to apply revin (default: True).
+
+ Returns:
+ torch.Tensor: A tensor containing the forecasted outputs, with shape [batch_size, output_length].
+ """
+ output_length = infer_kwargs.get("output_length", 96)
revin = infer_kwargs.get("revin", True)
- outputs = self.model.generate(
- inputs, max_new_tokens=predict_length, revin=revin
- )
+ outputs = self.model.generate(inputs, max_new_tokens=output_length, revin=revin)
return outputs
- def postprocess(self, outputs: torch.Tensor):
+ def postprocess(self, outputs: torch.Tensor, **infer_kwargs) -> list[torch.Tensor]:
"""
- The outputs shape should be 3D, so we need to expand dims.
+ Postprocess the model's output by expanding its dimensions to match the expected shape.
+
+ Parameters:
+ outputs (torch.Tensor): The raw output 2D-tensor from the model with shape [batch_size, output_length].
+ **infer_kwargs: Additional inference parameters passed to the method.
+
+ Returns:
+ list of torch.Tensor: A list of 2D tensors with shape [target_count(1), output_length].
"""
- outputs = super().postprocess(outputs.unsqueeze(1))
+ outputs = [outputs[i].unsqueeze(0) for i in range(outputs.size(0))]
+ outputs = super().postprocess(outputs, **infer_kwargs)
return outputs
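Timer-XL's generate() returns a 2D batch, so its postprocess only splits the batch and restores the target dimension. A sketch of that conversion with assumed shapes:

    import torch

    batch = torch.randn(4, 24)  # [batch_size, output_length] from generate()
    outputs = [batch[i].unsqueeze(0) for i in range(batch.size(0))]
    assert all(o.shape == (1, 24) for o in outputs)  # each [target_count(1), output_length]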
diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/function/tvf/ForecastTableFunction.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/function/tvf/ForecastTableFunction.java
index 950d8c464e8..dcb27825e31 100644
--- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/function/tvf/ForecastTableFunction.java
+++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/function/tvf/ForecastTableFunction.java
@@ -77,14 +77,14 @@ import static org.apache.iotdb.rpc.TSStatusCode.CAN_NOT_CONNECT_AINODE;
public class ForecastTableFunction implements TableFunction {
public static class ForecastTableFunctionHandle implements TableFunctionHandle {
- String modelId;
- int maxInputLength;
- int outputLength;
- long outputStartTime;
- long outputInterval;
- boolean keepInput;
- Map<String, String> options;
- List<Type> targetColumntypes;
+ protected String modelId;
+ protected int maxInputLength;
+ protected int outputLength;
+ protected long outputStartTime;
+ protected long outputInterval;
+ protected boolean keepInput;
+ protected Map<String, String> options;
+ protected List<Type> targetColumntypes;
public ForecastTableFunctionHandle() {}
@@ -182,22 +182,22 @@ public class ForecastTableFunction implements TableFunction {
}
}
- private static final String TARGETS_PARAMETER_NAME = "TARGETS";
- private static final String MODEL_ID_PARAMETER_NAME = "MODEL_ID";
- private static final String OUTPUT_LENGTH_PARAMETER_NAME = "OUTPUT_LENGTH";
- private static final int DEFAULT_OUTPUT_LENGTH = 96;
- private static final String OUTPUT_START_TIME = "OUTPUT_START_TIME";
+ protected static final String TARGETS_PARAMETER_NAME = "TARGETS";
+ protected static final String MODEL_ID_PARAMETER_NAME = "MODEL_ID";
+ protected static final String OUTPUT_LENGTH_PARAMETER_NAME = "OUTPUT_LENGTH";
+ protected static final int DEFAULT_OUTPUT_LENGTH = 96;
+ protected static final String OUTPUT_START_TIME = "OUTPUT_START_TIME";
public static final long DEFAULT_OUTPUT_START_TIME = Long.MIN_VALUE;
- private static final String OUTPUT_INTERVAL = "OUTPUT_INTERVAL";
+ protected static final String OUTPUT_INTERVAL = "OUTPUT_INTERVAL";
public static final long DEFAULT_OUTPUT_INTERVAL = 0L;
public static final String TIMECOL_PARAMETER_NAME = "TIMECOL";
- private static final String DEFAULT_TIME_COL = "time";
- private static final String KEEP_INPUT_PARAMETER_NAME = "PRESERVE_INPUT";
- private static final Boolean DEFAULT_KEEP_INPUT = Boolean.FALSE;
- private static final String IS_INPUT_COLUMN_NAME = "is_input";
- private static final String OPTIONS_PARAMETER_NAME = "MODEL_OPTIONS";
- private static final String DEFAULT_OPTIONS = "";
- private static final int MAX_INPUT_LENGTH = 2880;
+ protected static final String DEFAULT_TIME_COL = "time";
+ protected static final String KEEP_INPUT_PARAMETER_NAME = "PRESERVE_INPUT";
+ protected static final Boolean DEFAULT_KEEP_INPUT = Boolean.FALSE;
+ protected static final String IS_INPUT_COLUMN_NAME = "is_input";
+ protected static final String OPTIONS_PARAMETER_NAME = "MODEL_OPTIONS";
+ protected static final String DEFAULT_OPTIONS = "";
+ protected static final int MAX_INPUT_LENGTH = 2880;
private static final String INVALID_OPTIONS_FORMAT = "Invalid options: %s";
@@ -366,7 +366,7 @@ public class ForecastTableFunction implements TableFunction {
}
// only allow for INT32, INT64, FLOAT, DOUBLE
- private void checkType(Type type, String columnName) {
+ public void checkType(Type type, String columnName) {
if (!ALLOWED_INPUT_TYPES.contains(type)) {
throw new SemanticException(
String.format(
@@ -375,7 +375,7 @@ public class ForecastTableFunction implements TableFunction {
}
}
- private static Map<String, String> parseOptions(String options) {
+ public static Map<String, String> parseOptions(String options) {
if (options.isEmpty()) {
return Collections.emptyMap();
}
@@ -397,22 +397,22 @@ public class ForecastTableFunction implements TableFunction {
return optionsMap;
}
- private static class ForecastDataProcessor implements TableFunctionDataProcessor {
+ protected static class ForecastDataProcessor implements TableFunctionDataProcessor {
- private static final TsBlockSerde SERDE = new TsBlockSerde();
- private static final IClientManager<Integer, AINodeClient> CLIENT_MANAGER =
+ protected static final TsBlockSerde SERDE = new TsBlockSerde();
+ protected static final IClientManager<Integer, AINodeClient> CLIENT_MANAGER =
AINodeClientManager.getInstance();
- private final String modelId;
+ protected final String modelId;
private final int maxInputLength;
- private final int outputLength;
+ protected final int outputLength;
private final long outputStartTime;
private final long outputInterval;
private final boolean keepInput;
- private final Map<String, String> options;
- private final LinkedList<Record> inputRecords;
- private final List<ResultColumnAppender> resultColumnAppenderList;
- private final TsBlockBuilder inputTsBlockBuilder;
+ protected final Map<String, String> options;
+ protected final LinkedList<Record> inputRecords;
+ protected final List<ResultColumnAppender> resultColumnAppenderList;
+ protected final TsBlockBuilder inputTsBlockBuilder;
public ForecastDataProcessor(ForecastTableFunctionHandle functionHandle) {
this.modelId = functionHandle.modelId;
@@ -536,7 +536,7 @@ public class ForecastTableFunction implements TableFunction {
}
}
- private TsBlock forecast() {
+ protected TsBlock forecast() {
// construct inputTSBlock for AINode
while (!inputRecords.isEmpty()) {
Record row = inputRecords.removeFirst();
diff --git a/iotdb-protocol/thrift-ainode/src/main/thrift/ainode.thrift b/iotdb-protocol/thrift-ainode/src/main/thrift/ainode.thrift
index ea32f01b6e2..8a5971823ec 100644
--- a/iotdb-protocol/thrift-ainode/src/main/thrift/ainode.thrift
+++ b/iotdb-protocol/thrift-ainode/src/main/thrift/ainode.thrift
@@ -85,7 +85,9 @@ struct TForecastReq {
1: required string modelId
2: required binary inputData
3: required i32 outputLength
- 4: optional map<string, string> options
+ 4: optional string historyCovs
+ 5: optional string futureCovs
+ 6: optional map<string, string> options
}
struct TForecastResp {