This is an automated email from the ASF dual-hosted git repository.
yongzao pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/iotdb.git
The following commit(s) were added to refs/heads/master by this push:
new 14b4382ff44 [AINode]Fix the parameter "predict_length" (#15900)
14b4382ff44 is described below
commit 14b4382ff44054d830b1406045caf438b39197a1
Author: Leo <[email protected]>
AuthorDate: Thu Jul 10 14:35:28 2025 +0800
[AINode]Fix the parameter "predict_length" (#15900)
---
iotdb-core/ainode/ainode/core/manager/inference_manager.py | 6 +++---
iotdb-core/ainode/ainode/core/model/sundial/ts_generation_mixin.py | 5 +++++
2 files changed, 8 insertions(+), 3 deletions(-)
diff --git a/iotdb-core/ainode/ainode/core/manager/inference_manager.py b/iotdb-core/ainode/ainode/core/manager/inference_manager.py
index 9eda1c22651..a8109e278db 100644
--- a/iotdb-core/ainode/ainode/core/manager/inference_manager.py
+++ b/iotdb-core/ainode/ainode/core/manager/inference_manager.py
@@ -84,7 +84,7 @@ class SundialStrategy(InferenceStrategy):
class BuiltInStrategy(InferenceStrategy):
- def infer(self, full_data):
+ def infer(self, full_data, **_):
data = pd.DataFrame(full_data[1]).T
output = self.model.inference(data)
df = pd.DataFrame(output)
@@ -92,7 +92,7 @@ class BuiltInStrategy(InferenceStrategy):
class RegisteredStrategy(InferenceStrategy):
- def infer(self, full_data, window_interval=None, window_step=None, **kwargs):
+ def infer(self, full_data, window_interval=None, window_step=None, **_):
_, dataset, _, length = full_data
if window_interval is None or window_step is None:
window_interval = length
@@ -159,7 +159,7 @@ class InferenceManager:
# inference by strategy
strategy = self._get_strategy(model_id, model)
- outputs = strategy.infer(full_data)
+ outputs = strategy.infer(full_data, **inference_attrs)
# construct response
status = get_status(TSStatusCode.SUCCESS_STATUS)
diff --git a/iotdb-core/ainode/ainode/core/model/sundial/ts_generation_mixin.py b/iotdb-core/ainode/ainode/core/model/sundial/ts_generation_mixin.py
index d894d3d5ed3..04571161660 100644
--- a/iotdb-core/ainode/ainode/core/model/sundial/ts_generation_mixin.py
+++ b/iotdb-core/ainode/ainode/core/model/sundial/ts_generation_mixin.py
@@ -54,6 +54,11 @@ class TSGenerationMixin(GenerationMixin):
) -> Union[GenerateOutput, torch.LongTensor]:
if len(inputs.shape) != 2:
raise ValueError("Input shape must be: [batch_size, seq_len]")
+ batch_size, cur_len = inputs.shape
+ if cur_len < self.config.input_token_len:
+ raise ValueError(
+ f"Input length must be at least {self.config.input_token_len}"
+ )
if revin:
means = inputs.mean(dim=-1, keepdim=True)
stdev = inputs.std(dim=-1, keepdim=True, unbiased=False) + 1e-5