This is an automated email from the ASF dual-hosted git repository.

yongzao pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/iotdb.git


The following commit(s) were added to refs/heads/master by this push:
     new 14b4382ff44 [AINode]Fix the parameter "predict_length" (#15900)
14b4382ff44 is described below

commit 14b4382ff44054d830b1406045caf438b39197a1
Author: Leo <[email protected]>
AuthorDate: Thu Jul 10 14:35:28 2025 +0800

    [AINode]Fix the parameter "predict_length" (#15900)
---
 iotdb-core/ainode/ainode/core/manager/inference_manager.py         | 6 +++---
 iotdb-core/ainode/ainode/core/model/sundial/ts_generation_mixin.py | 5 +++++
 2 files changed, 8 insertions(+), 3 deletions(-)

diff --git a/iotdb-core/ainode/ainode/core/manager/inference_manager.py b/iotdb-core/ainode/ainode/core/manager/inference_manager.py
index 9eda1c22651..a8109e278db 100644
--- a/iotdb-core/ainode/ainode/core/manager/inference_manager.py
+++ b/iotdb-core/ainode/ainode/core/manager/inference_manager.py
@@ -84,7 +84,7 @@ class SundialStrategy(InferenceStrategy):
 
 
 class BuiltInStrategy(InferenceStrategy):
-    def infer(self, full_data):
+    def infer(self, full_data, **_):
         data = pd.DataFrame(full_data[1]).T
         output = self.model.inference(data)
         df = pd.DataFrame(output)
@@ -92,7 +92,7 @@ class BuiltInStrategy(InferenceStrategy):
 
 
 class RegisteredStrategy(InferenceStrategy):
-    def infer(self, full_data, window_interval=None, window_step=None, 
**kwargs):
+    def infer(self, full_data, window_interval=None, window_step=None, **_):
         _, dataset, _, length = full_data
         if window_interval is None or window_step is None:
             window_interval = length
@@ -159,7 +159,7 @@ class InferenceManager:
 
             # inference by strategy
             strategy = self._get_strategy(model_id, model)
-            outputs = strategy.infer(full_data)
+            outputs = strategy.infer(full_data, **inference_attrs)
 
             # construct response
             status = get_status(TSStatusCode.SUCCESS_STATUS)
diff --git a/iotdb-core/ainode/ainode/core/model/sundial/ts_generation_mixin.py b/iotdb-core/ainode/ainode/core/model/sundial/ts_generation_mixin.py
index d894d3d5ed3..04571161660 100644
--- a/iotdb-core/ainode/ainode/core/model/sundial/ts_generation_mixin.py
+++ b/iotdb-core/ainode/ainode/core/model/sundial/ts_generation_mixin.py
@@ -54,6 +54,11 @@ class TSGenerationMixin(GenerationMixin):
     ) -> Union[GenerateOutput, torch.LongTensor]:
         if len(inputs.shape) != 2:
             raise ValueError("Input shape must be: [batch_size, seq_len]")
+        batch_size, cur_len = inputs.shape
+        if cur_len < self.config.input_token_len:
+            raise ValueError(
+                f"Input length must be at least {self.config.input_token_len}"
+            )
         if revin:
             means = inputs.mean(dim=-1, keepdim=True)
             stdev = inputs.std(dim=-1, keepdim=True, unbiased=False) + 1e-5

Reply via email to