yunbow30944 commented on code in PR #16768:
URL: https://github.com/apache/iotdb/pull/16768#discussion_r2536176058


##########
iotdb-core/ainode/iotdb/ainode/core/model/sundial/modeling_sundial.py:
##########
@@ -603,35 +603,42 @@ def prepare_inputs_for_generation(
         **kwargs,
     ):
         # Omit tokens covered by past_key_values
+        past_length = 0
+        token_num = (
+            input_ids.shape[1] + self.config.input_token_len - 1
+        ) // self.config.input_token_len
+
         if past_key_values is not None:
             if isinstance(past_key_values, Cache):
                 past_length = past_key_values.get_seq_length()
             else:
                 past_length = past_key_values[0][0].shape[2]
 
+        if past_key_values is not None and past_length > 0:
             # Keep only the unprocessed tokens:
             # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
             # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
             # input)
-            if attention_mask is not None and attention_mask.shape[1] > (
-                input_ids.shape[1] // self.config.input_token_len
-            ):
+            if attention_mask is not None and attention_mask.shape[1] > token_num:
                 input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
             # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
             # input_ids based on the past_length.
-            elif past_length < (input_ids.shape[1] // self.config.input_token_len):
-                input_ids = input_ids[:, past_length * self.config.input_token_len :]
+            elif past_length < token_num:
+                # TODO: Actually, we need to know the output_token_lens used in the last generation step.
+                #  Sundial will pad the input when it is non-divisible, so we cannot use past_length to slice input_ids
+                input_ids = input_ids[:, -self.config.output_token_lens[0] :]
             # 3 - Otherwise (past_length >= (input_ids.shape[1] // self.config.input_token_len)), let's assume input_ids only has unprocessed tokens.
 
         position_ids = kwargs.get("position_ids", None)
         if attention_mask is not None and position_ids is None:
             # create position_ids on the fly for batch generation
             position_ids = attention_mask.long().cumsum(-1) - 1
             position_ids.masked_fill_(attention_mask == 0, 1)
-            if past_key_values:
-                position_ids = position_ids[
-                    :, -(input_ids.shape[1] // self.config.input_token_len) :
-                ]
+            if past_key_values is not None and past_length > 0:
+                token_num = (
+                    input_ids.shape[1] + self.config.input_token_len - 1
+                ) // self.config.input_token_len

Review Comment:
   Because input_ids is sliced above, the position_ids should be kept consistent with the new length of the input.
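
   For illustration only, here is a small runnable sketch of this point (the config value and tensor shapes below are hypothetical, not taken from the PR): once input_ids has been trimmed, the position_ids slice should be derived from the token count of the trimmed input_ids.

   ```python
   import torch

   # Hypothetical standalone demo, mirroring the diff's variable names.
   input_token_len = 16  # assumed value of self.config.input_token_len
   sliced_points = 48    # length of input_ids after the slicing above

   # Ceiling division, as in the diff's token_num computation.
   token_num = (sliced_points + input_token_len - 1) // input_token_len  # -> 3

   attention_mask = torch.ones(1, 10, dtype=torch.long)
   position_ids = attention_mask.long().cumsum(-1) - 1
   position_ids.masked_fill_(attention_mask == 0, 1)

   # Keep only the positions for the tokens that will actually be fed to the model.
   position_ids = position_ids[:, -token_num:]
   print(position_ids)  # tensor([[7, 8, 9]])
   ```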


