This is an automated email from the ASF dual-hosted git repository.
yongzao pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/iotdb.git
The following commit(s) were added to refs/heads/master by this push:
new 2b47be756ad [AINode] Fix bug of sundial and forecast udf (#16768)
2b47be756ad is described below
commit 2b47be756ad8703ce3673973260983f10c4f94e3
Author: Leo <[email protected]>
AuthorDate: Wed Nov 19 09:23:19 2025 +0800
[AINode] Fix bug of sundial and forecast udf (#16768)
---
.../iotdb/ainode/core/model/sundial/modeling_sundial.py | 13 +++++++++----
.../iotdb/ainode/core/model/timerxl/modeling_timer.py | 6 +++++-
.../apache/iotdb/db/queryengine/plan/udf/UDTFForecast.java | 1 -
3 files changed, 14 insertions(+), 6 deletions(-)
diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/sundial/modeling_sundial.py b/iotdb-core/ainode/iotdb/ainode/core/model/sundial/modeling_sundial.py
index 544193e4d9c..3ebf516f705 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/model/sundial/modeling_sundial.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/model/sundial/modeling_sundial.py
@@ -616,7 +616,11 @@ class SundialForPrediction(SundialPreTrainedModel, TSGenerationMixin):
if attention_mask is not None and attention_mask.shape[1] > (
input_ids.shape[1] // self.config.input_token_len
):
- input_ids = input_ids[:, -(attention_mask.shape[1] -
past_length) :]
+ input_ids = input_ids[
+ :,
+ -(attention_mask.shape[1] - past_length)
+ * self.config.input_token_len :,
+ ]
# 2 - If the past_length is smaller than input_ids', then
input_ids holds all input tokens. We can discard
# input_ids based on the past_length.
elif past_length < (input_ids.shape[1] //
self.config.input_token_len):
@@ -629,9 +633,10 @@ class SundialForPrediction(SundialPreTrainedModel, TSGenerationMixin):
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past_key_values:
- position_ids = position_ids[
- :, -(input_ids.shape[1] // self.config.input_token_len) :
- ]
+ token_num = (
+ input_ids.shape[1] + self.config.input_token_len - 1
+ ) // self.config.input_token_len
+ position_ids = position_ids[:, -token_num:]
# if `inputs_embeds` are passed, we only want to use them in the 1st
generation step
if inputs_embeds is not None and past_key_values is None:
diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/timerxl/modeling_timer.py b/iotdb-core/ainode/iotdb/ainode/core/model/timerxl/modeling_timer.py
index 37bf56dfc59..0a33c682742 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/model/timerxl/modeling_timer.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/model/timerxl/modeling_timer.py
@@ -606,7 +606,11 @@ class TimerForPrediction(TimerPreTrainedModel, TSGenerationMixin):
if attention_mask is not None and attention_mask.shape[1] > (
input_ids.shape[1] // self.config.input_token_len
):
- input_ids = input_ids[:, -(attention_mask.shape[1] -
past_length) :]
+ input_ids = input_ids[
+ :,
+ -(attention_mask.shape[1] - past_length)
+ * self.config.input_token_len :,
+ ]
# 2 - If the past_length is smaller than input_ids', then
input_ids holds all input tokens. We can discard
# input_ids based on the past_length.
elif past_length < (input_ids.shape[1] //
self.config.input_token_len):
diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/udf/UDTFForecast.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/udf/UDTFForecast.java
index e77e0641ae9..260410954d4 100644
--- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/udf/UDTFForecast.java
+++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/udf/UDTFForecast.java
@@ -114,7 +114,6 @@ public class UDTFForecast implements UDTF {
}
ModelInferenceDescriptor descriptor =
modelFetcher.fetchModel(this.model_id);
this.targetAINode = descriptor.getTargetAINode();
- this.maxInputLength = descriptor.getModelInformation().getInputShape()[0];
this.outputInterval = parameters.getLongOrDefault(OUTPUT_INTERVAL,
DEFAULT_OUTPUT_INTERVAL);
this.outputLength =