This is an automated email from the ASF dual-hosted git repository. jackietien pushed a commit to branch force_ci/object_type in repository https://gitbox.apache.org/repos/asf/iotdb.git
commit 81b6d5922516e8c834d04b8efa79cd4fad2d585d Author: Leo <[email protected]> AuthorDate: Wed Nov 19 09:23:19 2025 +0800 [AINode] Fix bug of sundial and forecast udf (#16768) (cherry picked from commit 2b47be756ad8703ce3673973260983f10c4f94e3) --- .../iotdb/ainode/core/model/sundial/modeling_sundial.py | 13 +++++++++---- .../iotdb/ainode/core/model/timerxl/modeling_timer.py | 6 +++++- .../apache/iotdb/db/queryengine/plan/udf/UDTFForecast.java | 1 - 3 files changed, 14 insertions(+), 6 deletions(-) diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/sundial/modeling_sundial.py b/iotdb-core/ainode/iotdb/ainode/core/model/sundial/modeling_sundial.py index 544193e4d9c..3ebf516f705 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/model/sundial/modeling_sundial.py +++ b/iotdb-core/ainode/iotdb/ainode/core/model/sundial/modeling_sundial.py @@ -616,7 +616,11 @@ class SundialForPrediction(SundialPreTrainedModel, TSGenerationMixin): if attention_mask is not None and attention_mask.shape[1] > ( input_ids.shape[1] // self.config.input_token_len ): - input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] + input_ids = input_ids[ + :, + -(attention_mask.shape[1] - past_length) + * self.config.input_token_len :, + ] # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard # input_ids based on the past_length. elif past_length < (input_ids.shape[1] // self.config.input_token_len): @@ -629,9 +633,10 @@ class SundialForPrediction(SundialPreTrainedModel, TSGenerationMixin): position_ids = attention_mask.long().cumsum(-1) - 1 position_ids.masked_fill_(attention_mask == 0, 1) if past_key_values: - position_ids = position_ids[ - :, -(input_ids.shape[1] // self.config.input_token_len) : - ] + token_num = ( + input_ids.shape[1] + self.config.input_token_len - 1 + ) // self.config.input_token_len + position_ids = position_ids[:, -token_num:] # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and past_key_values is None: diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/timerxl/modeling_timer.py b/iotdb-core/ainode/iotdb/ainode/core/model/timerxl/modeling_timer.py index 37bf56dfc59..0a33c682742 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/model/timerxl/modeling_timer.py +++ b/iotdb-core/ainode/iotdb/ainode/core/model/timerxl/modeling_timer.py @@ -606,7 +606,11 @@ class TimerForPrediction(TimerPreTrainedModel, TSGenerationMixin): if attention_mask is not None and attention_mask.shape[1] > ( input_ids.shape[1] // self.config.input_token_len ): - input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] + input_ids = input_ids[ + :, + -(attention_mask.shape[1] - past_length) + * self.config.input_token_len :, + ] # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard # input_ids based on the past_length. elif past_length < (input_ids.shape[1] // self.config.input_token_len): diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/udf/UDTFForecast.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/udf/UDTFForecast.java index e77e0641ae9..260410954d4 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/udf/UDTFForecast.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/udf/UDTFForecast.java @@ -114,7 +114,6 @@ public class UDTFForecast implements UDTF { } ModelInferenceDescriptor descriptor = modelFetcher.fetchModel(this.model_id); this.targetAINode = descriptor.getTargetAINode(); - this.maxInputLength = descriptor.getModelInformation().getInputShape()[0]; this.outputInterval = parameters.getLongOrDefault(OUTPUT_INTERVAL, DEFAULT_OUTPUT_INTERVAL); this.outputLength =
