This is an automated email from the ASF dual-hosted git repository.

yongzao pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/iotdb.git


The following commit(s) were added to refs/heads/master by this push:
     new 2b47be756ad [AINode] Fix bug of sundial and forecast udf (#16768)
2b47be756ad is described below

commit 2b47be756ad8703ce3673973260983f10c4f94e3
Author: Leo <[email protected]>
AuthorDate: Wed Nov 19 09:23:19 2025 +0800

    [AINode] Fix bug of sundial and forecast udf (#16768)
---
 .../iotdb/ainode/core/model/sundial/modeling_sundial.py     | 13 +++++++++----
 .../iotdb/ainode/core/model/timerxl/modeling_timer.py       |  6 +++++-
 .../apache/iotdb/db/queryengine/plan/udf/UDTFForecast.java  |  1 -
 3 files changed, 14 insertions(+), 6 deletions(-)

diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/sundial/modeling_sundial.py b/iotdb-core/ainode/iotdb/ainode/core/model/sundial/modeling_sundial.py
index 544193e4d9c..3ebf516f705 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/model/sundial/modeling_sundial.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/model/sundial/modeling_sundial.py
@@ -616,7 +616,11 @@ class SundialForPrediction(SundialPreTrainedModel, TSGenerationMixin):
             if attention_mask is not None and attention_mask.shape[1] > (
                 input_ids.shape[1] // self.config.input_token_len
             ):
-                input_ids = input_ids[:, -(attention_mask.shape[1] - 
past_length) :]
+                input_ids = input_ids[
+                    :,
+                    -(attention_mask.shape[1] - past_length)
+                    * self.config.input_token_len :,
+                ]
             # 2 - If the past_length is smaller than input_ids', then 
input_ids holds all input tokens. We can discard
             # input_ids based on the past_length.
             elif past_length < (input_ids.shape[1] // 
self.config.input_token_len):
@@ -629,9 +633,10 @@ class SundialForPrediction(SundialPreTrainedModel, TSGenerationMixin):
             position_ids = attention_mask.long().cumsum(-1) - 1
             position_ids.masked_fill_(attention_mask == 0, 1)
             if past_key_values:
-                position_ids = position_ids[
-                    :, -(input_ids.shape[1] // self.config.input_token_len) :
-                ]
+                token_num = (
+                    input_ids.shape[1] + self.config.input_token_len - 1
+                ) // self.config.input_token_len
+                position_ids = position_ids[:, -token_num:]
 
         # if `inputs_embeds` are passed, we only want to use them in the 1st 
generation step
         if inputs_embeds is not None and past_key_values is None:
diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/timerxl/modeling_timer.py b/iotdb-core/ainode/iotdb/ainode/core/model/timerxl/modeling_timer.py
index 37bf56dfc59..0a33c682742 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/model/timerxl/modeling_timer.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/model/timerxl/modeling_timer.py
@@ -606,7 +606,11 @@ class TimerForPrediction(TimerPreTrainedModel, TSGenerationMixin):
             if attention_mask is not None and attention_mask.shape[1] > (
                 input_ids.shape[1] // self.config.input_token_len
             ):
-                input_ids = input_ids[:, -(attention_mask.shape[1] - 
past_length) :]
+                input_ids = input_ids[
+                    :,
+                    -(attention_mask.shape[1] - past_length)
+                    * self.config.input_token_len :,
+                ]
             # 2 - If the past_length is smaller than input_ids', then 
input_ids holds all input tokens. We can discard
             # input_ids based on the past_length.
             elif past_length < (input_ids.shape[1] // 
self.config.input_token_len):
diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/udf/UDTFForecast.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/udf/UDTFForecast.java
index e77e0641ae9..260410954d4 100644
--- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/udf/UDTFForecast.java
+++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/udf/UDTFForecast.java
@@ -114,7 +114,6 @@ public class UDTFForecast implements UDTF {
     }
     ModelInferenceDescriptor descriptor = 
modelFetcher.fetchModel(this.model_id);
     this.targetAINode = descriptor.getTargetAINode();
-    this.maxInputLength = descriptor.getModelInformation().getInputShape()[0];
 
     this.outputInterval = parameters.getLongOrDefault(OUTPUT_INTERVAL, 
DEFAULT_OUTPUT_INTERVAL);
     this.outputLength =

Reply via email to