This is an automated email from the ASF dual-hosted git repository.

ycycse pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/iotdb.git
The following commit(s) were added to refs/heads/master by this push:
     new 25c933f53fc Fixed the issue of abnormal output when the input length of timerxl is not a multiple of 96 (#15495)

25c933f53fc is described below

commit 25c933f53fc84cd26c696dea65514b6d88b6805b
Author: jtmer <107352646+jt...@users.noreply.github.com>
AuthorDate: Wed May 14 00:31:29 2025 +0800

    Fixed the issue of abnormal output when the input length of timerxl is not a multiple of 96 (#15495)

    Co-authored-by: OswinGuai <peiz...@gmail.com>
---
 iotdb-core/ainode/ainode/TimerXL/models/timer_xl.py | 12 ++++++++++++
 1 file changed, 12 insertions(+)

diff --git a/iotdb-core/ainode/ainode/TimerXL/models/timer_xl.py b/iotdb-core/ainode/ainode/TimerXL/models/timer_xl.py
index 4e4d8588fd2..0e66542405e 100644
--- a/iotdb-core/ainode/ainode/TimerXL/models/timer_xl.py
+++ b/iotdb-core/ainode/ainode/TimerXL/models/timer_xl.py
@@ -253,6 +253,18 @@ class Model(nn.Module):
         self.eval()
         self.device = next(self.model.parameters()).device

+        if len(x.shape) == 2:
+            batch_size, cur_len = x.shape
+            if cur_len < self.config.input_token_len:
+                raise ValueError(
+                    f"Input length must be at least {self.config.input_token_len}")
+            elif cur_len % self.config.input_token_len != 0:
+                new_len = (cur_len // self.config.input_token_len) * \
+                    self.config.input_token_len
+                x = x[:, -new_len:]
+        else:
+            raise ValueError('Input shape must be: [batch_size, seq_len]')
+
         use_cache = self.config.use_cache
         all_input_ids = x
