This is an automated email from the ASF dual-hosted git repository.

ycycse pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/iotdb.git


The following commit(s) were added to refs/heads/master by this push:
     new 25c933f53fc Fixed the issue of abnormal output when the input length of timerxl is not a multiple of 96 (#15495)
25c933f53fc is described below

commit 25c933f53fc84cd26c696dea65514b6d88b6805b
Author: jtmer <107352646+jt...@users.noreply.github.com>
AuthorDate: Wed May 14 00:31:29 2025 +0800

    Fixed the issue of abnormal output when the input length of timerxl is not a multiple of 96 (#15495)
    
    Co-authored-by: OswinGuai <peiz...@gmail.com>
---
 iotdb-core/ainode/ainode/TimerXL/models/timer_xl.py | 12 ++++++++++++
 1 file changed, 12 insertions(+)

diff --git a/iotdb-core/ainode/ainode/TimerXL/models/timer_xl.py b/iotdb-core/ainode/ainode/TimerXL/models/timer_xl.py
index 4e4d8588fd2..0e66542405e 100644
--- a/iotdb-core/ainode/ainode/TimerXL/models/timer_xl.py
+++ b/iotdb-core/ainode/ainode/TimerXL/models/timer_xl.py
@@ -253,6 +253,18 @@ class Model(nn.Module):
         self.eval()
         self.device = next(self.model.parameters()).device
         
+        if len(x.shape) == 2:
+            batch_size, cur_len = x.shape
+            if cur_len < self.config.input_token_len:
+                raise ValueError(
+                    f"Input length must be at least {self.config.input_token_len}")
+            elif cur_len % self.config.input_token_len != 0:
+                new_len = (cur_len // self.config.input_token_len) * \
+                    self.config.input_token_len
+                x = x[:, -new_len:]
+        else:
+            raise ValueError('Input shape must be: [batch_size, seq_len]')
+        
         use_cache = self.config.use_cache
         all_input_ids = x
         

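The length handling added in this diff can be illustrated outside the model. The following is a minimal, framework-free sketch, not the committed code: the function name `align_input_length` is hypothetical, plain Python lists stand in for the 2-D tensor `x`, and the default `input_token_len=96` is an assumption suggested by the commit title rather than read from the TimerXL config. Like the patch, it rejects non-2-D input and inputs shorter than one token, and otherwise keeps only the most recent multiple of `input_token_len` points per row.

```python
def align_input_length(x, input_token_len=96):
    """Hypothetical standalone version of the input check in this commit.

    x is a batch given as a list of equal-length rows, i.e. shape
    [batch_size, seq_len]. Returns the rows truncated on the left so that
    seq_len becomes a multiple of input_token_len.
    """
    # Mirror the patch: only 2-D input [batch_size, seq_len] is accepted.
    if not x or not all(isinstance(row, list) for row in x):
        raise ValueError("Input shape must be: [batch_size, seq_len]")
    cur_len = len(x[0])
    # Lists can be ragged, unlike tensors, so check rows agree in length.
    if any(len(row) != cur_len for row in x):
        raise ValueError("All rows must have the same length")
    if cur_len < input_token_len:
        raise ValueError(f"Input length must be at least {input_token_len}")
    # Largest multiple of input_token_len that fits in cur_len (>= 1 token).
    new_len = (cur_len // input_token_len) * input_token_len
    # Keep the last new_len points of every row, dropping the oldest ones.
    return [row[-new_len:] for row in x]


batch = [list(range(200)), list(range(200, 400))]
out = align_input_length(batch)
print(len(out[0]))  # 192
print(out[0][0])    # 8
```

Slicing from the end (`x[:, -new_len:]` in the patch) discards the oldest observations, so the model still conditions on the most recent history when the raw input length is not a multiple of the token length.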