RkGrit commented on code in PR #16768:
URL: https://github.com/apache/iotdb/pull/16768#discussion_r2536909900
##########
iotdb-core/ainode/iotdb/ainode/core/model/sundial/modeling_sundial.py:
##########
@@ -616,7 +616,11 @@ def prepare_inputs_for_generation(
if attention_mask is not None and attention_mask.shape[1] > (
input_ids.shape[1] // self.config.input_token_len
):
- input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
+ input_ids = input_ids[
+ :,
+ -(attention_mask.shape[1] - past_length)
+ * self.config.input_token_len :,
+ ]
# 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
# input_ids based on the past_length.
elif past_length < (input_ids.shape[1] // self.config.input_token_len):
Review Comment:
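The same as the comment on the `if` condition: the token count of input_ids should be rounded up here too, not floor-divided.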
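The slice change in this hunk is needed because `attention_mask` and `past_length` count tokens, while each row of `input_ids` holds raw time-series points, `input_token_len` of them per token. The following is a minimal sketch of the corrected branch, not the AINode code itself: it assumes plain Python lists in place of 2-D tensors, and `trim_to_unprocessed` is a hypothetical helper name.

```python
from typing import List


def trim_to_unprocessed(
    input_ids: List[List[int]],
    attention_mask_len: int,
    past_length: int,
    input_token_len: int,
) -> List[List[int]]:
    """Keep only the raw points whose tokens are not yet in the KV cache.

    attention_mask_len and past_length are measured in tokens, while each
    row of input_ids is measured in raw time-series points.
    """
    new_tokens = attention_mask_len - past_length
    if new_tokens <= 0:
        # Illustrative guard (an assumption): row[-0:] returns the whole
        # row rather than an empty slice, so handle this case explicitly.
        return [[] for _ in input_ids]
    # Convert the token count into a point count before slicing.
    new_points = new_tokens * input_token_len
    return [row[-new_points:] for row in input_ids]


ids = [list(range(8))]  # one series of 8 points = 2 tokens of 4 points
print(trim_to_unprocessed(ids, attention_mask_len=3, past_length=2, input_token_len=4))
# [[4, 5, 6, 7]]
```

With a 3-token mask and 2 cached tokens, one token (4 points) is kept. The pre-change slice `row[-(3 - 2):]` would have kept only the last point, `[7]`.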
##########
iotdb-core/ainode/iotdb/ainode/core/model/timerxl/modeling_timer.py:
##########
@@ -606,7 +606,11 @@ def prepare_inputs_for_generation(
if attention_mask is not None and attention_mask.shape[1] > (
input_ids.shape[1] // self.config.input_token_len
):
- input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
+ input_ids = input_ids[
+ :,
+ -(attention_mask.shape[1] - past_length)
+ * self.config.input_token_len :,
+ ]
# 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
# input_ids based on the past_length.
elif past_length < (input_ids.shape[1] // self.config.input_token_len):
Review Comment:
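The same as the comment on the `if` condition: the token count of input_ids should be rounded up here too, not floor-divided.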
##########
iotdb-core/ainode/iotdb/ainode/core/model/sundial/modeling_sundial.py:
##########
@@ -616,7 +616,11 @@ def prepare_inputs_for_generation(
if attention_mask is not None and attention_mask.shape[1] > (
input_ids.shape[1] // self.config.input_token_len
Review Comment:
When calculating the number of tokens in input_ids, note that Sundial pads the input in the embedding stage, so a trailing partial fragment cannot be discarded; the count should be rounded up instead. For example:
token_num = (
input_ids.shape[1] + self.config.input_token_len - 1
) // self.config.input_token_len
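The effect of the suggested rounding can be checked with small integers. This sketch assumes sequence lengths are plain `int`s; `num_tokens_floor`, `num_tokens_ceil`, and `padded_length` are hypothetical helper names, and `padded_length` models only the padded length, not where Sundial actually inserts the padding.

```python
def num_tokens_floor(seq_len: int, token_len: int) -> int:
    # Floor division, as in the current condition: a trailing partial
    # patch is silently dropped from the count.
    return seq_len // token_len


def num_tokens_ceil(seq_len: int, token_len: int) -> int:
    # Ceiling division, as suggested in the review: a trailing partial
    # patch still counts as one token.
    return (seq_len + token_len - 1) // token_len


def padded_length(seq_len: int, token_len: int) -> int:
    # Hypothetical stand-in for embedding-stage padding: extend the
    # sequence to the next multiple of token_len (only the length matters).
    remainder = seq_len % token_len
    return seq_len if remainder == 0 else seq_len + token_len - remainder


seq_len, token_len = 10, 4  # 2 full patches plus a 2-point fragment
print(num_tokens_floor(seq_len, token_len))            # 2
print(num_tokens_ceil(seq_len, token_len))             # 3
print(padded_length(seq_len, token_len) // token_len)  # 3

# With a 3-token attention mask, the floor count makes the mask look longer
# than input_ids (the first branch's condition); the ceiling count does not.
mask_len = 3
print(mask_len > num_tokens_floor(seq_len, token_len))  # True
print(mask_len > num_tokens_ceil(seq_len, token_len))   # False
```

The ceiling form always agrees with the number of patches after padding, which is why the fragment has to be counted rather than dropped.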
##########
iotdb-core/ainode/iotdb/ainode/core/model/timerxl/modeling_timer.py:
##########
@@ -606,7 +606,11 @@ def prepare_inputs_for_generation(
if attention_mask is not None and attention_mask.shape[1] > (
input_ids.shape[1] // self.config.input_token_len
Review Comment:
The same as above
--
This is an automated message from the Apache Git Service.