This is an automated email from the ASF dual-hosted git repository.

yongzao pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/iotdb.git


The following commit(s) were added to refs/heads/master by this push:
     new 3ba8b0c17ac [AINode] Integrate Sundial as built-in model (#15586)
3ba8b0c17ac is described below

commit 3ba8b0c17ac1842278770b78ad053ee572440948
Author: Yongzao <[email protected]>
AuthorDate: Thu May 29 21:32:05 2025 +0800

    [AINode] Integrate Sundial as built-in model (#15586)
---
 .../relational/it/schema/IoTDBDatabaseIT.java      |   2 +
 iotdb-core/ainode/ainode/core/constant.py          |  11 +-
 .../ainode/core/manager/inference_manager.py       |  13 +
 .../ainode/core/model/built_in_model_factory.py    | 104 +++-
 .../ainode/ainode/core/model/sundial/__init__.py   |  17 +
 .../core/model/sundial/configuration_sundial.py    |  66 +++
 .../ainode/ainode/core/model/sundial/flow_loss.py  | 254 +++++++++
 .../ainode/core/model/sundial/modeling_sundial.py  | 612 +++++++++++++++++++++
 .../core/model/sundial/ts_generation_mixin.py      | 306 +++++++++++
 iotdb-core/ainode/pyproject.toml                   |   1 +
 .../iotdb/confignode/persistence/ModelInfo.java    |   1 +
 11 files changed, 1384 insertions(+), 3 deletions(-)

diff --git 
a/integration-test/src/test/java/org/apache/iotdb/relational/it/schema/IoTDBDatabaseIT.java
 
b/integration-test/src/test/java/org/apache/iotdb/relational/it/schema/IoTDBDatabaseIT.java
index 0af5ffd5a1d..1956aa58222 100644
--- 
a/integration-test/src/test/java/org/apache/iotdb/relational/it/schema/IoTDBDatabaseIT.java
+++ 
b/integration-test/src/test/java/org/apache/iotdb/relational/it/schema/IoTDBDatabaseIT.java
@@ -577,6 +577,7 @@ public class IoTDBDatabaseIT {
               Arrays.asList(
                   "_STLForecaster,",
                   "_NaiveForecaster,",
+                  "_sundial,",
                   "_HoltWinters,",
                   "_TimerXL,",
                   "_ExponentialSmoothing,",
@@ -700,6 +701,7 @@ public class IoTDBDatabaseIT {
           "model_id,",
           new HashSet<>(
               Arrays.asList(
+                  "_sundial,",
                   "_TimerXL,",
                   "_STLForecaster,",
                   "_NaiveForecaster,",
diff --git a/iotdb-core/ainode/ainode/core/constant.py 
b/iotdb-core/ainode/ainode/core/constant.py
index 467ffbfdec7..1078836ff65 100644
--- a/iotdb-core/ainode/ainode/core/constant.py
+++ b/iotdb-core/ainode/ainode/core/constant.py
@@ -152,6 +152,9 @@ class BuiltInModelType(Enum):
     # timerxl
     TIMER_XL = "_timerxl"
 
+    # sundial
+    SUNDIAL = "_sundial"
+
     @classmethod
     def values(cls) -> List[str]:
         values = []
@@ -259,7 +262,13 @@ class AttributeName(Enum):
     ATTENTION_DROPOUT = "attention_dropout"
     INITIALIZER_RANGE = "initializer_range"
     MAX_POSITION_EMBEDDINGS = "max_position_embeddings"
-    TIMERXL_CKPT_PATH = "ckpt_path"
+    CKPT_PATH = "ckpt_path"
+
+    # sundial
+    DROPOUT_RATE = "dropout_rate"
+    FLOW_LOSS_DEPTH = "flow_loss_depth"
+    NUM_SAMPLING_STEPS = "num_sampling_steps"
+    DIFFUSION_BATCH_MUL = "diffusion_batch_mul"
 
     def name(self) -> str:
         return self.value
diff --git a/iotdb-core/ainode/ainode/core/manager/inference_manager.py 
b/iotdb-core/ainode/ainode/core/manager/inference_manager.py
index e5142b0205b..9df60fe1a40 100644
--- a/iotdb-core/ainode/ainode/core/manager/inference_manager.py
+++ b/iotdb-core/ainode/ainode/core/manager/inference_manager.py
@@ -52,6 +52,17 @@ class TimerXLStrategy(InferenceStrategy):
         df = pd.DataFrame(output[0])
         return convert_to_binary(df)
 
+class SundialStrategy(InferenceStrategy):  # inference path for the built-in '_sundial' model (selected in _get_strategy)
+    def infer(self, full_data, predict_length=96, **_):
+        data = full_data[1][0]  # NOTE(review): presumably the first input series; confirm against the caller
+        if data.dtype.byteorder not in ('=', '|'):  # non-native byte order must be normalized before torch.tensor()
+            data = data.byteswap().view(data.dtype.newbyteorder())  # ndarray.newbyteorder() was removed in NumPy 2.0
+        seqs = torch.tensor(data).unsqueeze(0).float()  # add a leading batch axis of size 1
+        # TODO: unify model inference input
+        output = self.model.generate(seqs, max_new_tokens=predict_length, num_samples=10, revin=True)
+        df = pd.DataFrame(output[0].mean(dim=0))  # presumably averages the 10 samples of batch item 0 — confirm generate()'s output shape
+        return convert_to_binary(df)
+
 
 class BuiltInStrategy(InferenceStrategy):
     def infer(self, full_data, **_):
@@ -94,6 +105,8 @@ class RegisteredStrategy(InferenceStrategy):
 def _get_strategy(model_id, model):
     if model_id == '_timerxl':
         return TimerXLStrategy(model)
+    if model_id == '_sundial':  # must precede the generic '_' prefix branch below, which would otherwise match it
+        return SundialStrategy(model)
     if model_id.startswith('_'):
         return BuiltInStrategy(model)
     return RegisteredStrategy(model)
diff --git a/iotdb-core/ainode/ainode/core/model/built_in_model_factory.py 
b/iotdb-core/ainode/ainode/core/model/built_in_model_factory.py
index f395e0bae9e..f699934e0c5 100644
--- a/iotdb-core/ainode/ainode/core/model/built_in_model_factory.py
+++ b/iotdb-core/ainode/ainode/core/model/built_in_model_factory.py
@@ -30,12 +30,14 @@ from sktime.forecasting.trend import STLForecaster
 
 from ainode.TimerXL.models import timer_xl
 from ainode.TimerXL.models.configuration_timer import TimerxlConfig
+from ainode.core.model.sundial import modeling_sundial
 from ainode.core.config import AINodeDescriptor
 from ainode.core.constant import AttributeName, BuiltInModelType
 from ainode.core.exception import InferenceModelInternalError
 from ainode.core.exception import WrongAttributeTypeError, 
NumericalRangeException, StringRangeException, \
     ListRangeException, BuiltInModelNotSupportError
 from ainode.core.log import Logger
+from ainode.core.model.sundial.configuration_sundial import SundialConfig
 
 logger = Logger()
 
@@ -57,6 +59,8 @@ def get_model_attributes(model_id: str):
         attribute_map = stray_attribute_map
     elif model_id == BuiltInModelType.TIMER_XL.value:
         attribute_map = timerxl_attribute_map
+    elif model_id == BuiltInModelType.SUNDIAL.value:
+        attribute_map = sundial_attribute_map
     else:
         raise BuiltInModelNotSupportError(model_id)
     return attribute_map
@@ -99,6 +103,8 @@ def fetch_built_in_model(model_id, inference_attributes):
         model = STRAYModel(attributes)
     elif model_id == BuiltInModelType.TIMER_XL.value:
         model = timer_xl.Model(TimerxlConfig.from_dict(attributes))
+    elif model_id == BuiltInModelType.SUNDIAL.value:
+        model = 
modeling_sundial.SundialForPrediction(SundialConfig.from_dict(attributes))
     else:
         raise BuiltInModelNotSupportError(model_id)
 
@@ -321,6 +327,100 @@ def parse_attribute(input_attributes: Dict[str, str], 
attribute_map: Dict[str, A
                 raise e
     return attributes
 
+sundial_attribute_map = {  # per-attribute defaults and allowed ranges for the built-in '_sundial' model
+    AttributeName.INPUT_TOKEN_LEN.value: IntAttribute(
+        name=AttributeName.INPUT_TOKEN_LEN.value,
+        default_value=16,
+        default_low=1,
+        default_high=5000
+    ),
+    AttributeName.HIDDEN_SIZE.value: IntAttribute(
+        name=AttributeName.HIDDEN_SIZE.value,
+        default_value=768,
+        default_low=1,
+        default_high=5000
+    ),
+    AttributeName.INTERMEDIATE_SIZE.value: IntAttribute(
+        name=AttributeName.INTERMEDIATE_SIZE.value,
+        default_value=3072,
+        default_low=1,
+        default_high=5000
+    ),
+    AttributeName.OUTPUT_TOKEN_LENS.value: ListAttribute(
+        name=AttributeName.OUTPUT_TOKEN_LENS.value,
+        default_value=[720],
+        value_type=int
+    ),
+    AttributeName.NUM_HIDDEN_LAYERS.value: IntAttribute(
+        name=AttributeName.NUM_HIDDEN_LAYERS.value,
+        default_value=12,
+        default_low=1,
+        default_high=16
+    ),
+    AttributeName.NUM_ATTENTION_HEADS.value: IntAttribute(
+        name=AttributeName.NUM_ATTENTION_HEADS.value,
+        default_value=12,
+        default_low=1,
+        default_high=192  # NOTE(review): SundialAttention also needs hidden_size % num_attention_heads == 0; this range does not enforce it
+    ),
+    AttributeName.HIDDEN_ACT.value: StringAttribute(
+        name=AttributeName.HIDDEN_ACT.value,
+        default_value="silu",
+        value_choices=["relu", "gelu", "silu", "tanh"],
+    ),
+    AttributeName.USE_CACHE.value: BooleanAttribute(
+        name=AttributeName.USE_CACHE.value,
+        default_value=True,
+    ),
+    AttributeName.ROPE_THETA.value: IntAttribute(
+        name=AttributeName.ROPE_THETA.value,
+        default_value=10000,
+        default_low=1000,
+        default_high=50000
+    ),
+    AttributeName.DROPOUT_RATE.value: FloatAttribute(
+        name=AttributeName.DROPOUT_RATE.value,
+        default_value=0.1,
+        default_low=0.0,
+        default_high=1.0
+    ),
+    AttributeName.INITIALIZER_RANGE.value: FloatAttribute(
+        name=AttributeName.INITIALIZER_RANGE.value,
+        default_value=0.02,
+        default_low=0.0,
+        default_high=1.0
+    ),
+    AttributeName.MAX_POSITION_EMBEDDINGS.value: IntAttribute(
+        name=AttributeName.MAX_POSITION_EMBEDDINGS.value,
+        default_value=10000,
+        default_low=1,
+        default_high=50000
+    ),
+    AttributeName.FLOW_LOSS_DEPTH.value: IntAttribute(
+        name=AttributeName.FLOW_LOSS_DEPTH.value,
+        default_value=3,
+        default_low=1,
+        default_high=50
+    ),
+    AttributeName.NUM_SAMPLING_STEPS.value: IntAttribute(
+        name=AttributeName.NUM_SAMPLING_STEPS.value,
+        default_value=50,
+        default_low=1,
+        default_high=5000
+    ),
+    AttributeName.DIFFUSION_BATCH_MUL.value: IntAttribute(
+        name=AttributeName.DIFFUSION_BATCH_MUL.value,
+        default_value=4,
+        default_low=1,
+        default_high=5000
+    ),
+    AttributeName.CKPT_PATH.value: StringAttribute(
+        name=AttributeName.CKPT_PATH.value,
+        default_value=os.path.join(os.getcwd(), 
AINodeDescriptor().get_config().get_ain_models_dir(), 'weights',
+                                   'sundial'),  # a directory, whereas the timerxl entry names a model.safetensors file
+        value_choices=['']
+    )
+}
 
 timerxl_attribute_map = {
     AttributeName.INPUT_TOKEN_LEN.value: IntAttribute(
@@ -391,8 +491,8 @@ timerxl_attribute_map = {
         default_low=1,
         default_high=50000
     ),
-    AttributeName.TIMERXL_CKPT_PATH.value: StringAttribute(
-        name=AttributeName.TIMERXL_CKPT_PATH.value,
+    AttributeName.CKPT_PATH.value: StringAttribute(
+        name=AttributeName.CKPT_PATH.value,
         default_value=os.path.join(os.getcwd(), 
AINodeDescriptor().get_config().get_ain_models_dir(), 'weights',
                                    'timerxl', 'model.safetensors'),
         value_choices=['']
diff --git a/iotdb-core/ainode/ainode/core/model/sundial/__init__.py 
b/iotdb-core/ainode/ainode/core/model/sundial/__init__.py
new file mode 100644
index 00000000000..4b8ee97fad2
--- /dev/null
+++ b/iotdb-core/ainode/ainode/core/model/sundial/__init__.py
@@ -0,0 +1,17 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
\ No newline at end of file
diff --git 
a/iotdb-core/ainode/ainode/core/model/sundial/configuration_sundial.py 
b/iotdb-core/ainode/ainode/core/model/sundial/configuration_sundial.py
new file mode 100644
index 00000000000..41c54ff4a72
--- /dev/null
+++ b/iotdb-core/ainode/ainode/core/model/sundial/configuration_sundial.py
@@ -0,0 +1,66 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+from typing import List
+from transformers import PretrainedConfig
+
+
+class SundialConfig(PretrainedConfig):  # hyperparameters for the Sundial model, built on transformers' PretrainedConfig
+    model_type = "sundial"
+    keys_to_ignore_at_inference = ["past_key_values"]
+
+    def __init__(
+        self,
+        input_token_len: int = 16,
+        hidden_size: int = 768,
+        intermediate_size: int = 3072,
+        output_token_lens: List[int] = [720],  # safe shared default: __init__ stores a copy, never this list itself
+        num_hidden_layers: int = 12,
+        num_attention_heads: int = 12,
+        hidden_act: str = "silu",
+        use_cache: bool = True,
+        rope_theta: int = 10000,
+        dropout_rate: float = 0.1,
+        initializer_range: float = 0.02,
+        max_position_embeddings: int = 10000,
+        flow_loss_depth: int = 3,
+        num_sampling_steps: int = 50,
+        diffusion_batch_mul: int = 4,
+        ckpt_path: str = None,              # weight path
+        **kwargs,
+    ):
+        self.input_token_len = input_token_len
+        self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_attention_heads
+        self.hidden_act = hidden_act
+        self.output_token_lens = None if output_token_lens is None else list(output_token_lens)  # copy: no aliasing of the default or caller's list
+        self.use_cache = use_cache
+        self.rope_theta = rope_theta
+        self.dropout_rate = dropout_rate
+        self.initializer_range = initializer_range
+        self.max_position_embeddings = max_position_embeddings
+        self.flow_loss_depth = flow_loss_depth
+        self.num_sampling_steps = num_sampling_steps
+        self.diffusion_batch_mul = diffusion_batch_mul
+        self.ckpt_path = ckpt_path
+
+        super().__init__(
+            **kwargs,
+        )
diff --git a/iotdb-core/ainode/ainode/core/model/sundial/flow_loss.py 
b/iotdb-core/ainode/ainode/core/model/sundial/flow_loss.py
new file mode 100644
index 00000000000..79fcd73c155
--- /dev/null
+++ b/iotdb-core/ainode/ainode/core/model/sundial/flow_loss.py
@@ -0,0 +1,254 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+import torch
+import torch.nn as nn
+import math
+
+
+class FlowLoss(nn.Module):
+    """Flow-matching head: regresses `target` from a noise/target blend; samples by Euler steps."""
+
+    def __init__(self, target_channels, z_channels, depth, width, 
num_sampling_steps):
+        super(FlowLoss, self).__init__()
+        self.in_channels = target_channels
+        self.net = SimpleMLPAdaLN(
+            in_channels=target_channels,
+            model_channels=width,
+            out_channels=target_channels,
+            z_channels=z_channels,
+            num_res_blocks=depth
+        )
+        self.num_sampling_steps = num_sampling_steps  # number of Euler steps used by sample()
+
+    def forward(self, target, z, mask=None, mask_y=None):
+        noise = torch.randn_like(target)
+        t = torch.rand(target.shape[0], device=target.device)  # one interpolation time in [0, 1) per row
+
+        noised_target = t[:, None] * target + (1 - t[:, None]) * noise  # t=0: pure noise, t=1: target
+
+        predict_v = self.net(noised_target, t * 1000, z)  # timesteps scaled by 1000 before embedding; output is compared to `target`
+
+        weights = 1.0 / \
+            torch.arange(1, self.in_channels + 1, dtype=torch.float32, 
device=target.device)
+        if mask_y is not None:  # both branches weight channel k (1-based) by 1/k before summing over the last dim
+            loss = (mask_y * weights * (predict_v - target) ** 2).sum(dim=-1)
+        else:
+            loss = (weights * (predict_v - target) ** 2).sum(dim=-1)
+
+        if mask is not None:
+            loss = (loss * mask).sum() / mask.sum()  # masked mean over rows
+        return loss.mean()  # no-op if the masked branch already reduced the loss to a scalar
+
+    def sample(self, z, num_samples=1):
+        z = z.repeat(num_samples, 1)  # NOTE(review): assumes z is 2-D (N, z_channels); stacks num_samples copies along dim 0
+        noise = torch.randn(z.shape[0], self.in_channels).to(z.device)
+        x = noise
+        dt = 1.0 / self.num_sampling_steps
+        for i in range(self.num_sampling_steps):
+            t = (torch.ones((x.shape[0])) * i /
+                 self.num_sampling_steps).to(x.device)  # t = i / num_sampling_steps for every row
+            pred = self.net(x, t * 1000, z)
+            x = x + (pred - noise) * dt  # Euler step with velocity = predicted target minus the initial noise
+        x = x.reshape(num_samples, -1, self.in_channels).transpose(0, 1)  # -> (N, num_samples, in_channels) when z was 2-D
+        return x
+
+
+def modulate(x, shift, scale):  # adaLN-style affine modulation: x * (1 + scale) + shift
+    return x * (1 + scale) + shift
+
+
+class TimestepEmbedder(nn.Module):
+    """
+    Embeds scalar timesteps: sinusoidal features, then a Linear-SiLU-Linear MLP.
+    """
+
+    def __init__(self, hidden_size, frequency_embedding_size=256):
+        super().__init__()
+        self.mlp = nn.Sequential(
+            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
+            nn.SiLU(),
+            nn.Linear(hidden_size, hidden_size, bias=True),
+        )
+        self.frequency_embedding_size = frequency_embedding_size
+
+    @staticmethod
+    def timestep_embedding(t, dim, max_period=10000):
+        """
+        Create sinusoidal timestep embeddings.
+        :param t: a 1-D Tensor of N indices, one per batch element.
+                          These may be fractional.
+        :param dim: the dimension of the output.
+        :param max_period: controls the minimum frequency of the embeddings.
+        :return: an (N, D) Tensor of positional embeddings.
+        """
+        # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
+        half = dim // 2  # first half cos features, second half sin features
+        freqs = torch.exp(
+            -math.log(max_period) * torch.arange(start=0,
+                                                 end=half, 
dtype=torch.float32) / half
+        ).to(device=t.device)
+        args = t[:, None].float() * freqs[None]  # (N, half)
+        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+        if dim % 2:  # odd dim: append one zero column so the width equals dim
+            embedding = torch.cat(
+                [embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
+        return embedding
+
+    def forward(self, t):
+        t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
+        t_emb = self.mlp(t_freq)
+        return t_emb
+
+
+class ResBlock(nn.Module):
+    """
+    A residual MLP block with adaLN modulation; the channel count is unchanged.
+    :param channels: the number of input and output channels (also the width of the conditioning y).
+    """
+
+    def __init__(
+        self,
+        channels
+    ):
+        super().__init__()
+        self.channels = channels
+
+        self.in_ln = nn.LayerNorm(channels, eps=1e-6)
+        self.mlp = nn.Sequential(
+            nn.Linear(channels, channels, bias=True),
+            nn.SiLU(),
+            nn.Linear(channels, channels, bias=True),
+        )
+
+        self.adaLN_modulation = nn.Sequential(
+            nn.SiLU(),
+            nn.Linear(channels, 3 * channels, bias=True)
+        )
+
+    def forward(self, x, y):
+        shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(
+            y).chunk(3, dim=-1)  # per-feature shift / scale / gate derived from the condition y
+        h = modulate(self.in_ln(x), shift_mlp, scale_mlp)
+        h = self.mlp(h)
+        return x + gate_mlp * h  # gated residual update
+
+
+class FinalLayer(nn.Module):
+    """
+    The final layer adopted from DiT: adaLN-modulated LayerNorm, then a linear projection to out_channels.
+    """
+
+    def __init__(self, model_channels, out_channels):
+        super().__init__()
+        self.norm_final = nn.LayerNorm(
+            model_channels, elementwise_affine=False, eps=1e-6)  # scale/shift come from adaLN_modulation instead
+        self.linear = nn.Linear(model_channels, out_channels, bias=True)
+        self.adaLN_modulation = nn.Sequential(
+            nn.SiLU(),
+            nn.Linear(model_channels, 2 * model_channels, bias=True)
+        )
+
+    def forward(self, x, c):
+        shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
+        x = modulate(self.norm_final(x), shift, scale)
+        x = self.linear(x)
+        return x
+
+
+class SimpleMLPAdaLN(nn.Module):
+    """
+    Conditional MLP used as the network inside FlowLoss (timestep and condition injected via adaLN).
+    :param in_channels: channels in the input Tensor.
+    :param model_channels: base channel count for the model.
+    :param out_channels: channels in the output Tensor.
+    :param z_channels: channels in the condition.
+    :param num_res_blocks: number of stacked ResBlocks (there is no downsampling).
+    """
+
+    def __init__(
+        self,
+        in_channels,
+        model_channels,
+        out_channels,
+        z_channels,
+        num_res_blocks,
+    ):
+        super().__init__()
+
+        self.in_channels = in_channels
+        self.model_channels = model_channels
+        self.out_channels = out_channels
+        self.num_res_blocks = num_res_blocks
+
+        self.time_embed = TimestepEmbedder(model_channels)
+        self.cond_embed = nn.Linear(z_channels, model_channels)
+
+        self.input_proj = nn.Linear(in_channels, model_channels)
+
+        res_blocks = []
+        for i in range(num_res_blocks):
+            res_blocks.append(ResBlock(
+                model_channels,
+            ))
+
+        self.res_blocks = nn.ModuleList(res_blocks)
+        self.final_layer = FinalLayer(model_channels, out_channels)
+
+        self.initialize_weights()
+
+    def initialize_weights(self):
+        def _basic_init(module):
+            if isinstance(module, nn.Linear):
+                torch.nn.init.xavier_uniform_(module.weight)
+                if module.bias is not None:
+                    nn.init.constant_(module.bias, 0)
+        self.apply(_basic_init)
+
+        # Initialize timestep embedding MLP
+        nn.init.normal_(self.time_embed.mlp[0].weight, std=0.02)
+        nn.init.normal_(self.time_embed.mlp[2].weight, std=0.02)
+
+        # Zero-out adaLN modulation layers: gate = 0, so every ResBlock starts as an identity map
+        for block in self.res_blocks:
+            nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
+            nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
+
+        # Zero-out output layers so the untrained network outputs zeros
+        nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
+        nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
+        nn.init.constant_(self.final_layer.linear.weight, 0)
+        nn.init.constant_(self.final_layer.linear.bias, 0)
+
+    def forward(self, x, t, c):
+        """
+        Apply the model to an input batch.
+        :param x: an [N x C] Tensor of inputs.
+        :param t: a 1-D batch of timesteps.
+        :param c: conditioning from AR transformer.
+        :return: an [N x C] Tensor of outputs.
+        """
+        x = self.input_proj(x)
+        t = self.time_embed(t)
+        c = self.cond_embed(c)
+        y = t + c  # timestep and condition embeddings are summed into one conditioning vector
+
+        for block in self.res_blocks:
+            x = block(x, y)
+
+        return self.final_layer(x, y)
diff --git a/iotdb-core/ainode/ainode/core/model/sundial/modeling_sundial.py 
b/iotdb-core/ainode/ainode/core/model/sundial/modeling_sundial.py
new file mode 100644
index 00000000000..bdb853cd72d
--- /dev/null
+++ b/iotdb-core/ainode/ainode/core/model/sundial/modeling_sundial.py
@@ -0,0 +1,612 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+import os
+from typing import Optional, Tuple, List, Union
+import torch
+from torch import nn
+import torch.nn.functional as F
+from transformers import PreTrainedModel, Cache, DynamicCache
+from transformers.activations import ACT2FN
+from transformers.modeling_attn_mask_utils import 
_prepare_4d_causal_attention_mask
+from transformers.modeling_outputs import MoeModelOutputWithPast, 
MoeCausalLMOutputWithPast
+from ainode.core.model.sundial.configuration_sundial import SundialConfig
+from ainode.core.model.sundial.ts_generation_mixin import TSGenerationMixin
+from ainode.core.model.sundial.flow_loss import FlowLoss
+
+from safetensors.torch import load_file as load_safetensors
+from huggingface_hub import hf_hub_download
+
+from ainode.core.log import Logger
+logger = Logger()
+
+def rotate_half(x):  # split the last dim into halves (x1, x2) and return cat(-x2, x1)
+    x1 = x[..., : x.shape[-1] // 2]
+    x2 = x[..., x.shape[-1] // 2:]
+    return torch.cat((-x2, x1), dim=-1)
+
+
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):  # rotary position embedding applied to q and k
+    cos = cos[position_ids].unsqueeze(unsqueeze_dim)  # NOTE(review): with position_ids (bsz, seq) the new axis broadcasts over heads in (bsz, heads, seq, head_dim)
+    sin = sin[position_ids].unsqueeze(unsqueeze_dim)
+    q_embed = (q * cos) + (rotate_half(q) * sin)
+    k_embed = (k * cos) + (rotate_half(k) * sin)
+    return q_embed, k_embed
+
+
+class SundialPatchEmbedding(nn.Module):  # cuts the last dim into input_token_len patches and embeds (values, mask) per patch
+    def __init__(self, config: SundialConfig):
+        super().__init__()
+        self.dropout = nn.Dropout(config.dropout_rate)
+        self.hidden_layer = nn.Linear(
+            config.input_token_len * 2, config.intermediate_size)
+        self.act = ACT2FN[config.hidden_act]
+        self.output_layer = nn.Linear(
+            config.intermediate_size, config.hidden_size)
+        self.residual_layer = nn.Linear(
+            config.input_token_len * 2, config.hidden_size)
+        self.input_token_len = config.input_token_len
+
+    def forward(self, x):
+        mask = torch.ones_like(x, dtype=torch.float32)  # 1 for every real value; left padding below gets 0
+        input_length = x.shape[-1]
+        padding_length = (self.input_token_len - (input_length %
+                          self.input_token_len)) % self.input_token_len  # left-pad to a multiple of input_token_len
+        x = F.pad(x, (padding_length, 0))
+        mask = F.pad(mask, (padding_length, 0))
+        x = x.unfold(dimension=-1, size=self.input_token_len,
+                     step=self.input_token_len)  # non-overlapping patches of input_token_len values
+        mask = mask.unfold(
+            dimension=-1, size=self.input_token_len, step=self.input_token_len)
+
+        x = torch.cat([x, mask], dim=-1)  # 2 * input_token_len features per patch, matching the Linear inputs
+        hid = self.act(self.hidden_layer(x))
+        out = self.dropout(self.output_layer(hid))
+        res = self.residual_layer(x)  # linear residual path
+        out = out + res
+        return out
+
+
+class SundialRotaryEmbedding(torch.nn.Module):  # precomputed RoPE cos/sin tables, regenerated when a longer seq_len is requested
+    def __init__(self, dim, max_position_embeddings=10000, base=10000, 
device=None):
+        super().__init__()
+        self.dim = dim
+        self.max_position_embeddings = max_position_embeddings
+        self.base = base
+        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim,
+                          2, dtype=torch.int64).float().to(device) / self.dim))  # inv_freq[i] = base^(-2i/dim)
+        self.register_buffer("inv_freq", inv_freq, persistent=False)
+
+        # Build here to make `torch.jit.trace` work.
+        self._set_cos_sin_cache(
+            seq_len=max_position_embeddings, device=self.inv_freq.device, 
dtype=torch.get_default_dtype()
+        )
+
+    def _set_cos_sin_cache(self, seq_len, device, dtype):
+        self.max_seq_len_cached = seq_len
+        t = torch.arange(self.max_seq_len_cached, device=device,
+                         dtype=torch.int64).type_as(self.inv_freq)
+
+        freqs = torch.outer(t, self.inv_freq)
+        # Different from paper, but it uses a different permutation in order 
to obtain the same calculation
+        emb = torch.cat((freqs, freqs), dim=-1)
+        self.register_buffer(
+            "cos_cached", emb.cos().to(dtype), persistent=False)
+        self.register_buffer(
+            "sin_cached", emb.sin().to(dtype), persistent=False)
+
+    def forward(self, x, seq_len=None):
+        # x: [bs, num_attention_heads, seq_len, head_size]
+        if seq_len > self.max_seq_len_cached:  # NOTE(review): seq_len defaults to None but is compared directly, so callers must always pass it
+            self._set_cos_sin_cache(
+                seq_len=seq_len, device=x.device, dtype=x.dtype)
+
+        return (
+            self.cos_cached[:seq_len].to(dtype=x.dtype),
+            self.sin_cached[:seq_len].to(dtype=x.dtype),
+        )
+
+
+class SundialAttention(nn.Module):  # multi-head self-attention with rotary embeddings and an optional KV cache
+    def __init__(self, config: SundialConfig, layer_idx: Optional[int] = None):
+        super().__init__()
+        self.layer_idx = layer_idx
+        self.hidden_size = config.hidden_size
+        self.num_heads = config.num_attention_heads
+        self.head_dim = self.hidden_size // self.num_heads  # NOTE(review): assumes hidden_size % num_heads == 0; the .view() calls in forward fail otherwise
+        self.attention_dropout = config.dropout_rate
+        self.q_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=True)
+        self.k_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=True)
+        self.v_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=True)
+        self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
+        self.rotary_emb = SundialRotaryEmbedding(
+            self.head_dim, max_position_embeddings=config.max_position_embeddings)
+
+    def forward(
+            self,
+            hidden_states: torch.Tensor,
+            attention_mask: Optional[torch.Tensor] = None,
+            position_ids: Optional[torch.LongTensor] = None,
+            past_key_value: Optional[Cache] = None,
+            output_attentions: bool = False,
+            **kwargs,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        bsz, q_len, _ = hidden_states.size()
+
+        query_states = self.q_proj(hidden_states)
+        key_states = self.k_proj(hidden_states)
+        value_states = self.v_proj(hidden_states)
+
+        query_states = query_states.view(
+            bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+        key_states = key_states.view(
+            bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+        value_states = value_states.view(
+            bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+
+        kv_seq_len = key_states.shape[-2]  # new tokens only; the cached length is added below when a cache is present
+        if past_key_value is not None:
+            kv_seq_len += past_key_value.get_usable_length(
+                kv_seq_len, self.layer_idx)
+        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)  # value_states only supplies device/dtype
+        query_states, key_states = apply_rotary_pos_emb(
+            query_states, key_states, cos, sin, position_ids)
+
+        if past_key_value is not None:
+            key_states, value_states = past_key_value.update(
+                key_states, value_states, self.layer_idx)  # presumably returns cached + new K/V (transformers Cache.update)
+
+        attn_output = F.scaled_dot_product_attention(
+            query_states, key_states, value_states, attention_mask, dropout_p=(self.attention_dropout if self.training else 0.0))
+
+        attn_output = attn_output.transpose(1, 2).contiguous()
+        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+        attn_output = self.o_proj(attn_output)
+
+        # SDPA never returns attention probabilities, so there are none to expose (also when output_attentions=True).
+        attn_weights = None
+
+        return attn_output, attn_weights, past_key_value
+
+
+class SundialMLP(nn.Module):  # gated MLP: down_proj(act(gate_proj(x)) * up_proj(x)), all projections bias-free
+    def __init__(self, hidden_size: int, intermediate_size: int, hidden_act: 
str):
+        super().__init__()
+        self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size
+        self.gate_proj = nn.Linear(
+            self.hidden_size, self.intermediate_size, bias=False)
+        self.up_proj = nn.Linear(
+            self.hidden_size, self.intermediate_size, bias=False)
+        self.down_proj = nn.Linear(
+            self.intermediate_size, self.hidden_size, bias=False)
+        self.act_fn = ACT2FN[hidden_act]
+
+    def forward(self, hidden_state):
+        return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * 
self.up_proj(hidden_state))
+
+
+class SundialDecoderLayer(nn.Module):
+    def __init__(self, config: SundialConfig, layer_idx: int):
+        super().__init__()
+        self.self_attn = SundialAttention(config, layer_idx)
+
+        self.ffn_layer = SundialMLP(
+            hidden_size=config.hidden_size,
+            intermediate_size=config.intermediate_size,
+            hidden_act=config.hidden_act,
+        )
+        self.norm1 = torch.nn.LayerNorm(config.hidden_size)
+        self.norm2 = torch.nn.LayerNorm(config.hidden_size)
+
+    def forward(
+            self,
+            hidden_states: torch.Tensor,
+            attention_mask: Optional[torch.Tensor] = None,
+            position_ids: Optional[torch.LongTensor] = None,
+            past_key_value: Optional[Tuple[torch.Tensor]] = None,
+            output_attentions: Optional[bool] = False,
+            use_cache: Optional[bool] = False,
+            **kwargs,
+    ) -> Tuple[torch.FloatTensor, torch.FloatTensor, 
Optional[torch.FloatTensor], Optional[torch.FloatTensor]]:
+        residual = hidden_states
+
+        hidden_states = self.norm1(hidden_states)
+
+        # Self Attention
+        hidden_states, self_attn_weights, present_key_value = self.self_attn(
+            hidden_states=hidden_states,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_value=past_key_value,
+            output_attentions=output_attentions,
+            use_cache=use_cache,
+        )
+        hidden_states = residual + hidden_states
+
+        # Fully Connected
+        residual = hidden_states
+        hidden_states = self.norm2(hidden_states)
+        hidden_states = self.ffn_layer(hidden_states)
+        hidden_states = residual + hidden_states
+
+        if not output_attentions:
+            self_attn_weights = None
+
+        if not use_cache:
+            present_key_value = None
+        return hidden_states, self_attn_weights, present_key_value
+
+
+class SundialPreTrainedModel(PreTrainedModel):
+    config_class = SundialConfig
+    base_model_prefix = "model"
+    supports_gradient_checkpointing = True
+    _no_split_modules = ["SundialDecoderLayer"]
+    _skip_keys_device_placement = "past_key_values"
+    _supports_flash_attn_2 = True
+    _supports_sdpa = False
+    _supports_cache_class = True
+
+    def _init_weights(self, module):
+        std = self.config.initializer_range
+        if isinstance(module, torch.nn.Linear):
+            module.weight.data.normal_(mean=0.0, std=std)
+            if module.bias is not None:
+                module.bias.data.zero_()
+        elif isinstance(module, torch.nn.Embedding):
+            module.weight.data.normal_(mean=0.0, std=std)
+            if module.padding_idx is not None:
+                module.weight.data[module.padding_idx].zero_()
+
+
+class SundialModel(SundialPreTrainedModel):
+    def __init__(self, config: SundialConfig):
+        super().__init__(config)
+        self.embed_layer = SundialPatchEmbedding(config)
+        self.layers = nn.ModuleList(
+            [SundialDecoderLayer(config, layer_idx)
+             for layer_idx in range(config.num_hidden_layers)]
+        )
+        self.norm = torch.nn.LayerNorm(config.hidden_size)
+        self.gradient_checkpointing = False
+
+    def forward(
+            self,
+            input_ids: torch.FloatTensor = None,
+            attention_mask: Optional[torch.Tensor] = None,
+            position_ids: Optional[torch.LongTensor] = None,
+            past_key_values: Optional[List[torch.FloatTensor]] = None,
+            inputs_embeds: Optional[torch.FloatTensor] = None,
+            use_cache: Optional[bool] = None,
+            output_attentions: Optional[bool] = None,
+            output_hidden_states: Optional[bool] = None,
+            return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, MoeModelOutputWithPast]:
+        # input_ids is the input of time series, its shape is [batch_size, 
seq_len]
+        output_attentions = output_attentions if output_attentions is not None 
else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else 
self.config.output_hidden_states
+        )
+        use_cache = use_cache if use_cache is not None else 
self.config.use_cache
+
+        return_dict = return_dict if return_dict is not None else 
self.config.use_return_dict
+
+        # retrieve input_ids and inputs_embeds
+        if input_ids is not None and inputs_embeds is not None:
+            raise ValueError(
+                "You cannot specify both decoder_input_ids and 
decoder_inputs_embeds at the same time")
+        elif input_ids is not None:
+            batch_size, seq_length = input_ids.shape
+        elif inputs_embeds is not None:
+            batch_size, seq_length, _ = inputs_embeds.shape
+        else:
+            raise ValueError(
+                "You have to specify either decoder_input_ids or 
decoder_inputs_embeds")
+
+        if inputs_embeds is None:
+            inputs_embeds = self.embed_layer(input_ids)
+            seq_length = inputs_embeds.shape[1]
+
+        if self.gradient_checkpointing and self.training:
+            if use_cache:
+                use_cache = False
+
+        past_key_values_length = 0
+
+        if use_cache:
+            use_legacy_cache = not isinstance(past_key_values, Cache)
+            if use_legacy_cache:
+                past_key_values = DynamicCache.from_legacy_cache(
+                    past_key_values)
+            past_key_values_length = past_key_values.get_usable_length(
+                seq_length)
+
+        if position_ids is None:
+            device = input_ids.device if input_ids is not None else 
inputs_embeds.device
+            position_ids = torch.arange(
+                past_key_values_length, seq_length + past_key_values_length, 
dtype=torch.long, device=device
+            )
+            # position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
+            position_ids = position_ids.view(-1, seq_length)
+        else:
+            position_ids = position_ids.view(-1, seq_length).long()
+
+        # 4d mask is passed through the layers
+        attention_mask = _prepare_4d_causal_attention_mask(
+            attention_mask,
+            (batch_size, seq_length),
+            inputs_embeds,
+            past_key_values_length,
+            sliding_window=None,
+        )
+
+        hidden_states = inputs_embeds
+
+        # decoder layers
+        all_hidden_states = () if output_hidden_states else None
+        all_self_attns = () if output_attentions else None
+        next_decoder_cache = None
+
+        for decoder_layer in self.layers:
+            if output_hidden_states:
+                all_hidden_states += (hidden_states,)
+
+            if self.gradient_checkpointing and self.training:
+                layer_outputs = self._gradient_checkpointing_func(
+                    decoder_layer.__call__,
+                    hidden_states,
+                    attention_mask,
+                    position_ids,
+                    past_key_values,
+                    output_attentions,
+                    use_cache,
+                )
+            else:
+                layer_outputs = decoder_layer(
+                    hidden_states,
+                    attention_mask=attention_mask,
+                    position_ids=position_ids,
+                    past_key_value=past_key_values,
+                    output_attentions=output_attentions,
+                    use_cache=use_cache,
+                )
+
+            hidden_states = layer_outputs[0]
+
+            if output_attentions:
+                all_self_attns += (layer_outputs[1],)
+
+            if use_cache:
+                next_decoder_cache = layer_outputs[2]
+
+        hidden_states = self.norm(hidden_states)
+        # add hidden states from the last decoder layer
+        if output_hidden_states:
+            all_hidden_states += (hidden_states,)
+
+        next_cache = None
+        if use_cache:
+            next_cache = next_decoder_cache.to_legacy_cache(
+            ) if use_legacy_cache else next_decoder_cache
+
+        if not return_dict:
+            return tuple(
+                v
+                for v in [hidden_states, next_cache, all_hidden_states, 
all_self_attns]
+                if v is not None
+            )
+        return MoeModelOutputWithPast(
+            last_hidden_state=hidden_states,
+            past_key_values=next_cache,
+            hidden_states=all_hidden_states,
+            attentions=all_self_attns,
+        )
+
+
+class SundialForPrediction(SundialPreTrainedModel, TSGenerationMixin):
+    def __init__(self, config: SundialConfig):
+        super().__init__(config)
+        self.config = config
+        self.model = SundialModel(self.config)
+        self.flow_loss = FlowLoss(self.config.output_token_lens[-1], 
self.config.hidden_size,
+                                  self.config.flow_loss_depth, 
self.config.hidden_size, self.config.num_sampling_steps)
+        # TODO: Unify data loader
+        if not os.path.exists(config.ckpt_path):
+            os.mkdir(config.ckpt_path)
+        weights_path = os.path.join(config.ckpt_path, "model.safetensors")
+        if not os.path.exists(weights_path):
+            logger.info(f"Weight not found at {weights_path}, downloading from 
HuggingFace...")
+            repo_id = "thuml/sundial-base-128m"
+            try:
+                hf_hub_download(repo_id=repo_id, filename="model.safetensors", 
local_dir=config.ckpt_path)
+                logger.info(f"Got weight to {weights_path}")
+            except Exception as e:
+                logger.error(f"Failed to download weight to {weights_path} due 
to {e}")
+                raise e
+        state_dict = load_safetensors(weights_path)
+        self.load_state_dict(state_dict, strict=True)
+
+    def set_decoder(self, decoder):
+        self.model = decoder
+
+    def get_decoder(self):
+        return self.model
+
+    def forward(
+            self,
+            input_ids: torch.FloatTensor = None,
+            attention_mask: Optional[torch.Tensor] = None,
+            position_ids: Optional[torch.LongTensor] = None,
+            past_key_values: Optional[List[torch.FloatTensor]] = None,
+            inputs_embeds: Optional[torch.FloatTensor] = None,
+            labels: Optional[torch.FloatTensor] = None,
+            loss_masks: Optional[torch.FloatTensor] = None,
+            mask_y: Optional[torch.FloatTensor] = None,
+            use_cache: Optional[bool] = None,
+            output_attentions: Optional[bool] = None,
+            output_hidden_states: Optional[bool] = None,
+            return_dict: Optional[bool] = None,
+            max_output_length: Optional[int] = None,
+            revin: Optional[bool] = False,
+            num_samples: Optional[int] = 1,
+    ) -> Union[Tuple, MoeCausalLMOutputWithPast]:
+
+        output_attentions = output_attentions if output_attentions is not None 
else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else 
self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else 
self.config.use_return_dict
+
+        if revin:
+            means = input_ids.mean(1, keepdim=True).detach()
+            stdev = input_ids.std(dim=1, keepdim=True, unbiased=False).detach()
+            stdev = torch.where(stdev > 1e-2, stdev, torch.tensor(1.0, 
device=input_ids.device))
+            input_ids = (input_ids - means) / stdev
+        outputs = self.model(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        hidden_states = outputs[0] if not return_dict else 
outputs.last_hidden_state
+        predictions = None
+
+        loss = None
+        if labels is not None:
+            if revin:
+                labels = (labels - means) / stdev
+            output_token_len = self.config.output_token_lens[-1]
+            seq_len = hidden_states.shape[1] * self.config.input_token_len
+            labels = labels[:, :seq_len -
+                            self.config.input_token_len + output_token_len]
+            shift_labels = labels.unfold(
+                dimension=-1, size=output_token_len, 
step=self.config.input_token_len)
+
+            bsz, L, _ = shift_labels.shape
+            shift_labels = shift_labels.reshape(
+                bsz * L, -1).repeat(self.config.diffusion_batch_mul, 1)
+            hidden_states = hidden_states.reshape(
+                bsz * L, -1).repeat(self.config.diffusion_batch_mul, 1)
+            loss_masks = loss_masks.reshape(
+                bsz * L).repeat(self.config.diffusion_batch_mul)
+            mask_y = mask_y.repeat(L * self.config.diffusion_batch_mul, 1)
+
+            loss = self.flow_loss(shift_labels, hidden_states, loss_masks, 
mask_y)
+        else:
+            if max_output_length is None:
+                output_token_len = self.config.output_token_lens[0]
+                max_output_length = output_token_len
+            else:
+                output_token_len = self.config.output_token_lens[0]
+                for h in self.config.output_token_lens[1:]:
+                    if h > max_output_length:
+                        break
+                    else:
+                        output_token_len = h
+
+            bsz = hidden_states.shape[0]
+            hidden_states = hidden_states[:, -1, :]
+            predictions = self.flow_loss.sample(hidden_states, num_samples)
+            if output_token_len > max_output_length:
+                predictions = predictions[:, :, :max_output_length]
+            if revin:
+                predictions = predictions * stdev + means
+        if not return_dict:
+            output = (predictions,) + outputs[1:]
+            return (loss) + output if loss is not None else output
+
+        return MoeCausalLMOutputWithPast(
+            loss=loss,
+            logits=predictions,
+            past_key_values=outputs.past_key_values,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
+
+    def prepare_inputs_for_generation(
+            self, input_ids, past_key_values=None, attention_mask=None, 
inputs_embeds=None, revin=False, num_samples=1, **kwargs
+    ):
+        # Omit tokens covered by past_key_values
+        if past_key_values is not None:
+            if isinstance(past_key_values, Cache):
+                cache_length = past_key_values.get_seq_length()
+                if isinstance(past_key_values, DynamicCache):
+                    past_length = past_key_values.seen_tokens
+                else:
+                    past_length = cache_length
+
+                max_cache_length = past_key_values.get_max_length()
+            else:
+                cache_length = past_length = past_key_values[0][0].shape[2]
+                max_cache_length = None
+
+            # Keep only the unprocessed tokens:
+            # 1 - If the length of the attention_mask exceeds the length of 
input_ids, then we are in a setting where
+            # some of the inputs are exclusively passed as part of the cache 
(e.g. when passing input_embeds as
+            # input)
+            if attention_mask is not None and attention_mask.shape[1] > 
(input_ids.shape[1] // self.config.input_token_len):
+                input_ids = input_ids[:, -
+                                      (attention_mask.shape[1] - past_length) 
* self.config.input_token_len:]
+            # 2 - If the past_length is smaller than input_ids', then 
input_ids holds all input tokens. We can discard
+            # input_ids based on the past_length.
+            elif past_length < (input_ids.shape[1] // 
self.config.input_token_len):
+                input_ids = input_ids[:, past_length *
+                                      self.config.input_token_len:]
+            # 3 - Otherwise (past_length >= (input_ids.shape[1] // 
self.config.input_token_len)), let's assume input_ids only has unprocessed 
tokens.
+
+            # If we are about to go beyond the maximum cache length, we need 
to crop the input attention mask.
+            if (
+                    max_cache_length is not None
+                    and attention_mask is not None
+                    and cache_length + (input_ids.shape[1] // 
self.config.input_token_len) > max_cache_length
+            ):
+                attention_mask = attention_mask[:, -max_cache_length:]
+
+        position_ids = kwargs.get("position_ids", None)
+        if attention_mask is not None and position_ids is None:
+            # create position_ids on the fly for batch generation
+            position_ids = attention_mask.long().cumsum(-1) - 1
+            position_ids.masked_fill_(attention_mask == 0, 1)
+            if past_key_values:
+                position_ids = position_ids[:, -
+                                            (input_ids.shape[1] // 
self.config.input_token_len):]
+
+        # if `inputs_embeds` are passed, we only want to use them in the 1st 
generation step
+        if inputs_embeds is not None and past_key_values is None:
+            model_inputs = {"inputs_embeds": inputs_embeds}
+        else:
+            model_inputs = {"input_ids": input_ids}
+
+        model_inputs.update(
+            {
+                "position_ids": position_ids,
+                "past_key_values": past_key_values,
+                "use_cache": kwargs.get("use_cache"),
+                "attention_mask": attention_mask,
+                "revin": revin,
+                "num_samples": num_samples,
+            }
+        )
+        return model_inputs
\ No newline at end of file
diff --git a/iotdb-core/ainode/ainode/core/model/sundial/ts_generation_mixin.py 
b/iotdb-core/ainode/ainode/core/model/sundial/ts_generation_mixin.py
new file mode 100644
index 00000000000..723b4a1332a
--- /dev/null
+++ b/iotdb-core/ainode/ainode/core/model/sundial/ts_generation_mixin.py
@@ -0,0 +1,306 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+import warnings
+from typing import Any, Dict, List, Optional, Union, Callable
+import torch
+from transformers import GenerationMixin, LogitsProcessorList, 
StoppingCriteriaList
+from transformers.generation import validate_stopping_criteria, 
EosTokenCriteria
+from transformers.generation.utils import GenerateNonBeamOutput, 
GenerateEncoderDecoderOutput, \
+    GenerateDecoderOnlyOutput, GenerationConfig, GenerateOutput
+from transformers.utils import ModelOutput
+
+
+class TSGenerationMixin(GenerationMixin):
+    @torch.no_grad()
+    def generate(
+            self,
+            inputs: Optional[torch.Tensor] = None,
+            generation_config: Optional[GenerationConfig] = None,
+            logits_processor: Optional[LogitsProcessorList] = None,
+            stopping_criteria: Optional[StoppingCriteriaList] = None,
+            prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], 
List[int]]] = None,
+            synced_gpus: Optional[bool] = None,
+            assistant_model: Optional["PreTrainedModel"] = None,
+            streamer: Optional["BaseStreamer"] = None,
+            negative_prompt_ids: Optional[torch.Tensor] = None,
+            negative_prompt_attention_mask: Optional[torch.Tensor] = None,
+            revin: Optional[bool] = True,
+            num_samples: Optional[int] = 1,
+            **kwargs,
+    ) -> Union[GenerateOutput, torch.LongTensor]:
+        if len(inputs.shape) != 2:
+            raise ValueError('Input shape must be: [batch_size, seq_len]')
+        if revin:
+            means = inputs.mean(dim=-1, keepdim=True)
+            stdev = inputs.std(dim=-1, keepdim=True, unbiased=False) + 1e-5
+            inputs = (inputs - means) / stdev
+        outputs = super().generate(inputs=inputs, 
generation_config=generation_config,
+                                   logits_processor=logits_processor, 
stopping_criteria=stopping_criteria,
+                                   
prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, synced_gpus=synced_gpus,
+                                   assistant_model=assistant_model, 
streamer=streamer,
+                                   negative_prompt_ids=negative_prompt_ids,
+                                   
negative_prompt_attention_mask=negative_prompt_attention_mask,
+                                   num_samples=num_samples, **kwargs)
+        if revin:
+            stdev = stdev.unsqueeze(1).repeat(1, num_samples, 1)
+            means = means.unsqueeze(1).repeat(1, num_samples, 1)
+            outputs = (outputs * stdev) + means
+        return outputs
+
+    def _greedy_search(
+            self,
+            input_ids: torch.Tensor,
+            logits_processor: Optional[LogitsProcessorList] = None,
+            stopping_criteria: Optional[StoppingCriteriaList] = None,
+            max_length: Optional[int] = None,
+            pad_token_id: Optional[int] = None,
+            eos_token_id: Optional[Union[int, List[int]]] = None,
+            output_attentions: Optional[bool] = None,
+            output_hidden_states: Optional[bool] = None,
+            output_scores: Optional[bool] = None,
+            output_logits: Optional[bool] = None,
+            return_dict_in_generate: Optional[bool] = None,
+            synced_gpus: bool = False,
+            streamer: Optional["BaseStreamer"] = None,
+            **model_kwargs,
+    ) -> Union[GenerateNonBeamOutput, torch.Tensor]:
+        input_ids = input_ids.to(self.device)
+        batch_size, cur_len = input_ids.shape
+        # init values
+        logits_processor = logits_processor if logits_processor is not None 
else LogitsProcessorList()
+        stopping_criteria = stopping_criteria if stopping_criteria is not None 
else StoppingCriteriaList()
+        if max_length is not None:
+            warnings.warn(
+                "`max_length` is deprecated in this function, use"
+                " 
`stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])`
 instead.",
+                UserWarning,
+            )
+            stopping_criteria = validate_stopping_criteria(
+                stopping_criteria, max_length)
+        pad_token_id = pad_token_id if pad_token_id is not None else 
self.generation_config.pad_token_id
+        if eos_token_id is not None:
+            stopping_criteria.append(
+                EosTokenCriteria(eos_token_id=eos_token_id))
+        else:
+            # remove when the method is totally private
+            # need to get `eos_token_id` and add stopping criteria, so that 
generation does not go forever
+            eos_token_id = [
+                criteria.eos_token_id.tolist() for criteria in 
stopping_criteria if hasattr(criteria, "eos_token_id")
+            ]
+            eos_token_id = eos_token_id[0] if eos_token_id else None
+            if eos_token_id is None and self.generation_config.eos_token_id is 
not None:
+                eos_token_id = self.generation_config.eos_token_id
+                stopping_criteria.append(
+                    EosTokenCriteria(eos_token_id=eos_token_id))
+
+        if isinstance(eos_token_id, int):
+            eos_token_id = [eos_token_id]
+        output_scores = output_scores if output_scores is not None else 
self.generation_config.output_scores
+        output_attentions = (
+            output_attentions if output_attentions is not None else 
self.generation_config.output_attentions
+        )
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else 
self.generation_config.output_hidden_states
+        )
+        return_dict_in_generate = (
+            return_dict_in_generate
+            if return_dict_in_generate is not None
+            else self.generation_config.return_dict_in_generate
+        )
+
+        # init attention / hidden states / scores tuples
+        raw_logits = () if (return_dict_in_generate and output_logits) else 
None
+        scores = () if (return_dict_in_generate and output_scores) else None
+        decoder_attentions = () if (return_dict_in_generate and 
output_attentions) else None
+        cross_attentions = () if (return_dict_in_generate and 
output_attentions) else None
+        decoder_hidden_states = () if (
+                return_dict_in_generate and output_hidden_states) else None
+
+        # if model is an encoder-decoder, retrieve encoder attention weights 
and hidden states
+        if return_dict_in_generate and self.config.is_encoder_decoder:
+            encoder_attentions = model_kwargs["encoder_outputs"].get(
+                "attentions") if output_attentions else None
+            encoder_hidden_states = (
+                model_kwargs["encoder_outputs"].get(
+                    "hidden_states") if output_hidden_states else None
+            )
+
+        # keep track of which sequences are already finished
+        if "inputs_embeds" in model_kwargs:
+            cur_len = model_kwargs["inputs_embeds"].shape[1]
+        this_peer_finished = False
+        unfinished_sequences = torch.ones(
+            batch_size, dtype=torch.long, device=input_ids.device)
+        model_kwargs["cache_position"] = torch.arange(
+            cur_len, device=input_ids.device)
+        true_seq_len = (cur_len + self.config.input_token_len - 1) // 
self.config.input_token_len
+        model_kwargs["attention_mask"] = model_kwargs["attention_mask"][:, 
-true_seq_len:]
+        max_length = stopping_criteria.max_length
+        generate_results = None
+        while self._has_unfinished_sequences(this_peer_finished, synced_gpus, 
device=input_ids.device):
+            # prepare model inputs
+            model_inputs = self.prepare_inputs_for_generation(
+                input_ids, **model_kwargs)
+
+            input_length = input_ids.shape[1]
+
+            # forward pass to get next token
+            outputs = self(
+                **model_inputs,
+                return_dict=True,
+                output_attentions=output_attentions,
+                output_hidden_states=output_hidden_states,
+                max_output_length=max_length - input_length,
+            )
+
+            if synced_gpus and this_peer_finished:
+                continue  # don't waste resources running the code we don't 
need
+            next_token_logits = outputs.logits
+
+            # pre-process distribution
+            next_tokens_scores = logits_processor(input_ids, next_token_logits)
+
+            # Store scores, attentions and hidden_states when required
+            if return_dict_in_generate:
+                if output_scores:
+                    scores += (next_tokens_scores,)
+                if output_logits:
+                    raw_logits += (next_token_logits,)
+                if output_attentions:
+                    decoder_attentions += (
+                        (outputs.decoder_attentions,) if 
self.config.is_encoder_decoder else (
+                            outputs.attentions,)
+                    )
+                    if self.config.is_encoder_decoder:
+                        cross_attentions += (outputs.cross_attentions,)
+
+                if output_hidden_states:
+                    decoder_hidden_states += (
+                        (outputs.decoder_hidden_states,)
+                        if self.config.is_encoder_decoder
+                        else (outputs.hidden_states,)
+                    )
+
+            # argmax
+            # next_tokens = torch.argmax(next_tokens_scores, dim=-1)
+            next_tokens = next_tokens_scores
+
+            # finished sentences should have their next token be a padding 
token
+            if eos_token_id is not None:
+                if pad_token_id is None:
+                    raise ValueError(
+                        "If `eos_token_id` is defined, make sure that 
`pad_token_id` is defined.")
+                next_tokens = next_tokens * unfinished_sequences + \
+                              pad_token_id * (1 - unfinished_sequences)
+
+            # update generated ids, model inputs, and length for next step
+            horizon_length = next_tokens.shape[-1] // 
self.config.input_token_len
+
+            past_key_values = model_kwargs.get("past_key_values")
+            if past_key_values is None:
+                generate_results = next_tokens
+            else:
+                generate_results = torch.cat([generate_results, next_tokens], 
dim=-1)
+            input_ids = torch.cat([input_ids, next_tokens.median(dim=1)[0]], 
dim=-1)
+
+            if streamer is not None:
+                streamer.put(next_tokens.cpu())
+            model_kwargs = self._update_model_kwargs_for_generation(
+                outputs,
+                model_kwargs,
+                horizon_length=horizon_length,
+                is_encoder_decoder=self.config.is_encoder_decoder,
+            )
+            unfinished_sequences = unfinished_sequences & ~stopping_criteria(
+                input_ids, scores)
+            this_peer_finished = unfinished_sequences.max() == 0
+
+        if input_ids.shape[-1] > max_length:
+            input_ids = input_ids[:, :max_length]
+
+        if streamer is not None:
+            streamer.end()
+
+        if return_dict_in_generate:
+            if self.config.is_encoder_decoder:
+                return GenerateEncoderDecoderOutput(
+                    sequences=input_ids,
+                    scores=scores,
+                    logits=raw_logits,
+                    encoder_attentions=encoder_attentions,
+                    encoder_hidden_states=encoder_hidden_states,
+                    decoder_attentions=decoder_attentions,
+                    cross_attentions=cross_attentions,
+                    decoder_hidden_states=decoder_hidden_states,
+                    past_key_values=model_kwargs.get("past_key_values"),
+                )
+            else:
+                return GenerateDecoderOnlyOutput(
+                    sequences=input_ids,
+                    scores=scores,
+                    logits=raw_logits,
+                    attentions=decoder_attentions,
+                    hidden_states=decoder_hidden_states,
+                    past_key_values=model_kwargs.get("past_key_values"),
+                )
+        else:
+            return generate_results[:, :, :(max_length - cur_len)]
+
+    def _update_model_kwargs_for_generation(
+            self,
+            outputs: ModelOutput,
+            model_kwargs: Dict[str, Any],
+            horizon_length: int = 1,
+            is_encoder_decoder: bool = False,
+            standardize_cache_format: bool = False,
+    ) -> Dict[str, Any]:
+        # update past_key_values
+        model_kwargs["past_key_values"] = self._extract_past_from_model_output(
+            outputs, standardize_cache_format=standardize_cache_format
+        )
+        if getattr(outputs, "state", None) is not None:
+            model_kwargs["state"] = outputs.state
+
+        # update token_type_ids with last value
+        if "token_type_ids" in model_kwargs:
+            token_type_ids = model_kwargs["token_type_ids"]
+            model_kwargs["token_type_ids"] = torch.cat(
+                [token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], dim=-1)
+
+        if not is_encoder_decoder:
+            # update attention mask
+            if "attention_mask" in model_kwargs:
+                attention_mask = model_kwargs["attention_mask"]
+                model_kwargs["attention_mask"] = torch.cat(
+                    [attention_mask, 
attention_mask.new_ones((attention_mask.shape[0], horizon_length))], dim=-1
+                )
+        else:
+            # update decoder attention mask
+            if "decoder_attention_mask" in model_kwargs:
+                decoder_attention_mask = model_kwargs["decoder_attention_mask"]
+                model_kwargs["decoder_attention_mask"] = torch.cat(
+                    [decoder_attention_mask, decoder_attention_mask.new_ones(
+                        (decoder_attention_mask.shape[0], horizon_length))],
+                    dim=-1,
+                )
+
+        if "cache_position" in model_kwargs and model_kwargs["cache_position"] 
is not None:
+            model_kwargs["cache_position"] = 
model_kwargs["cache_position"][-1:] + horizon_length
+
+        return model_kwargs
\ No newline at end of file
diff --git a/iotdb-core/ainode/pyproject.toml b/iotdb-core/ainode/pyproject.toml
index 70557d4f2c0..eaa418b623f 100644
--- a/iotdb-core/ainode/pyproject.toml
+++ b/iotdb-core/ainode/pyproject.toml
@@ -63,6 +63,7 @@ apache-iotdb = "2.0.4.dev0"
 einops = "^0.8.1"
 safetensors = "^0.5.1"
 huggingface_hub = "^0.30.1"
+transformers = "==4.40.1"
 
 [tool.poetry.scripts]
 ainode = "ainode.core.script:main"
diff --git 
a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/persistence/ModelInfo.java
 
b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/persistence/ModelInfo.java
index 44093d10053..8f00fde8773 100644
--- 
a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/persistence/ModelInfo.java
+++ 
b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/persistence/ModelInfo.java
@@ -83,6 +83,7 @@ public class ModelInfo implements SnapshotProcessor {
     builtInForecastModel.add("_STLForecaster");
     builtInForecastModel.add("_HoltWinters");
     builtInForecastModel.add("_ExponentialSmoothing");
+    builtInForecastModel.add("_sundial");
     builtInAnomalyDetectionModel.add("_GaussianHMM");
     builtInAnomalyDetectionModel.add("_GMMHMM");
     builtInAnomalyDetectionModel.add("_Stray");

Reply via email to