This is an automated email from the ASF dual-hosted git repository.
yongzao pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/iotdb.git
The following commit(s) were added to refs/heads/master by this push:
new 3ba8b0c17ac [AINode] Integrate Sundial as built-in model (#15586)
3ba8b0c17ac is described below
commit 3ba8b0c17ac1842278770b78ad053ee572440948
Author: Yongzao <[email protected]>
AuthorDate: Thu May 29 21:32:05 2025 +0800
[AINode] Integrate Sundial as built-in model (#15586)
---
.../relational/it/schema/IoTDBDatabaseIT.java | 2 +
iotdb-core/ainode/ainode/core/constant.py | 11 +-
.../ainode/core/manager/inference_manager.py | 13 +
.../ainode/core/model/built_in_model_factory.py | 104 +++-
.../ainode/ainode/core/model/sundial/__init__.py | 17 +
.../core/model/sundial/configuration_sundial.py | 66 +++
.../ainode/ainode/core/model/sundial/flow_loss.py | 254 +++++++++
.../ainode/core/model/sundial/modeling_sundial.py | 612 +++++++++++++++++++++
.../core/model/sundial/ts_generation_mixin.py | 306 +++++++++++
iotdb-core/ainode/pyproject.toml | 1 +
.../iotdb/confignode/persistence/ModelInfo.java | 1 +
11 files changed, 1384 insertions(+), 3 deletions(-)
diff --git
a/integration-test/src/test/java/org/apache/iotdb/relational/it/schema/IoTDBDatabaseIT.java
b/integration-test/src/test/java/org/apache/iotdb/relational/it/schema/IoTDBDatabaseIT.java
index 0af5ffd5a1d..1956aa58222 100644
---
a/integration-test/src/test/java/org/apache/iotdb/relational/it/schema/IoTDBDatabaseIT.java
+++
b/integration-test/src/test/java/org/apache/iotdb/relational/it/schema/IoTDBDatabaseIT.java
@@ -577,6 +577,7 @@ public class IoTDBDatabaseIT {
Arrays.asList(
"_STLForecaster,",
"_NaiveForecaster,",
+ "_sundial,",
"_HoltWinters,",
"_TimerXL,",
"_ExponentialSmoothing,",
@@ -700,6 +701,7 @@ public class IoTDBDatabaseIT {
"model_id,",
new HashSet<>(
Arrays.asList(
+ "_sundial,",
"_TimerXL,",
"_STLForecaster,",
"_NaiveForecaster,",
diff --git a/iotdb-core/ainode/ainode/core/constant.py
b/iotdb-core/ainode/ainode/core/constant.py
index 467ffbfdec7..1078836ff65 100644
--- a/iotdb-core/ainode/ainode/core/constant.py
+++ b/iotdb-core/ainode/ainode/core/constant.py
@@ -152,6 +152,9 @@ class BuiltInModelType(Enum):
# timerxl
TIMER_XL = "_timerxl"
+ # sundial
+ SUNDIAL = "_sundial"
+
@classmethod
def values(cls) -> List[str]:
values = []
@@ -259,7 +262,13 @@ class AttributeName(Enum):
ATTENTION_DROPOUT = "attention_dropout"
INITIALIZER_RANGE = "initializer_range"
MAX_POSITION_EMBEDDINGS = "max_position_embeddings"
- TIMERXL_CKPT_PATH = "ckpt_path"
+ CKPT_PATH = "ckpt_path"
+
+ # sundial
+ DROPOUT_RATE = "dropout_rate"
+ FLOW_LOSS_DEPTH = "flow_loss_depth"
+ NUM_SAMPLING_STEPS = "num_sampling_steps"
+ DIFFUSION_BATCH_MUL = "diffusion_batch_mul"
def name(self) -> str:
return self.value
diff --git a/iotdb-core/ainode/ainode/core/manager/inference_manager.py
b/iotdb-core/ainode/ainode/core/manager/inference_manager.py
index e5142b0205b..9df60fe1a40 100644
--- a/iotdb-core/ainode/ainode/core/manager/inference_manager.py
+++ b/iotdb-core/ainode/ainode/core/manager/inference_manager.py
@@ -52,6 +52,17 @@ class TimerXLStrategy(InferenceStrategy):
df = pd.DataFrame(output[0])
return convert_to_binary(df)
+class SundialStrategy(InferenceStrategy):
+ def infer(self, full_data, predict_length=96, **_):
+ data = full_data[1][0]
+ if data.dtype.byteorder not in ('=', '|'):
+ data = data.byteswap().newbyteorder()
+ seqs = torch.tensor(data).unsqueeze(0).float()
+ # TODO: unify model inference input
+ output = self.model.generate(seqs, max_new_tokens=predict_length,
num_samples=10, revin=True)
+ df = pd.DataFrame(output[0].mean(dim=0))
+ return convert_to_binary(df)
+
class BuiltInStrategy(InferenceStrategy):
def infer(self, full_data, **_):
@@ -94,6 +105,8 @@ class RegisteredStrategy(InferenceStrategy):
def _get_strategy(model_id, model):
if model_id == '_timerxl':
return TimerXLStrategy(model)
+ if model_id == '_sundial':
+ return SundialStrategy(model)
if model_id.startswith('_'):
return BuiltInStrategy(model)
return RegisteredStrategy(model)
diff --git a/iotdb-core/ainode/ainode/core/model/built_in_model_factory.py
b/iotdb-core/ainode/ainode/core/model/built_in_model_factory.py
index f395e0bae9e..f699934e0c5 100644
--- a/iotdb-core/ainode/ainode/core/model/built_in_model_factory.py
+++ b/iotdb-core/ainode/ainode/core/model/built_in_model_factory.py
@@ -30,12 +30,14 @@ from sktime.forecasting.trend import STLForecaster
from ainode.TimerXL.models import timer_xl
from ainode.TimerXL.models.configuration_timer import TimerxlConfig
+from ainode.core.model.sundial import modeling_sundial
from ainode.core.config import AINodeDescriptor
from ainode.core.constant import AttributeName, BuiltInModelType
from ainode.core.exception import InferenceModelInternalError
from ainode.core.exception import WrongAttributeTypeError,
NumericalRangeException, StringRangeException, \
ListRangeException, BuiltInModelNotSupportError
from ainode.core.log import Logger
+from ainode.core.model.sundial.configuration_sundial import SundialConfig
logger = Logger()
@@ -57,6 +59,8 @@ def get_model_attributes(model_id: str):
attribute_map = stray_attribute_map
elif model_id == BuiltInModelType.TIMER_XL.value:
attribute_map = timerxl_attribute_map
+ elif model_id == BuiltInModelType.SUNDIAL.value:
+ attribute_map = sundial_attribute_map
else:
raise BuiltInModelNotSupportError(model_id)
return attribute_map
@@ -99,6 +103,8 @@ def fetch_built_in_model(model_id, inference_attributes):
model = STRAYModel(attributes)
elif model_id == BuiltInModelType.TIMER_XL.value:
model = timer_xl.Model(TimerxlConfig.from_dict(attributes))
+ elif model_id == BuiltInModelType.SUNDIAL.value:
+ model =
modeling_sundial.SundialForPrediction(SundialConfig.from_dict(attributes))
else:
raise BuiltInModelNotSupportError(model_id)
@@ -321,6 +327,100 @@ def parse_attribute(input_attributes: Dict[str, str],
attribute_map: Dict[str, A
raise e
return attributes
+sundial_attribute_map = {
+ AttributeName.INPUT_TOKEN_LEN.value: IntAttribute(
+ name=AttributeName.INPUT_TOKEN_LEN.value,
+ default_value=16,
+ default_low=1,
+ default_high=5000
+ ),
+ AttributeName.HIDDEN_SIZE.value: IntAttribute(
+ name=AttributeName.HIDDEN_SIZE.value,
+ default_value=768,
+ default_low=1,
+ default_high=5000
+ ),
+ AttributeName.INTERMEDIATE_SIZE.value: IntAttribute(
+ name=AttributeName.INTERMEDIATE_SIZE.value,
+ default_value=3072,
+ default_low=1,
+ default_high=5000
+ ),
+ AttributeName.OUTPUT_TOKEN_LENS.value: ListAttribute(
+ name=AttributeName.OUTPUT_TOKEN_LENS.value,
+ default_value=[720],
+ value_type=int
+ ),
+ AttributeName.NUM_HIDDEN_LAYERS.value: IntAttribute(
+ name=AttributeName.NUM_HIDDEN_LAYERS.value,
+ default_value=12,
+ default_low=1,
+ default_high=16
+ ),
+ AttributeName.NUM_ATTENTION_HEADS.value: IntAttribute(
+ name=AttributeName.NUM_ATTENTION_HEADS.value,
+ default_value=12,
+ default_low=1,
+ default_high=192
+ ),
+ AttributeName.HIDDEN_ACT.value: StringAttribute(
+ name=AttributeName.HIDDEN_ACT.value,
+ default_value="silu",
+ value_choices=["relu", "gelu", "silu", "tanh"],
+ ),
+ AttributeName.USE_CACHE.value: BooleanAttribute(
+ name=AttributeName.USE_CACHE.value,
+ default_value=True,
+ ),
+ AttributeName.ROPE_THETA.value: IntAttribute(
+ name=AttributeName.ROPE_THETA.value,
+ default_value=10000,
+ default_low=1000,
+ default_high=50000
+ ),
+ AttributeName.DROPOUT_RATE.value: FloatAttribute(
+ name=AttributeName.DROPOUT_RATE.value,
+ default_value=0.1,
+ default_low=0.0,
+ default_high=1.0
+ ),
+ AttributeName.INITIALIZER_RANGE.value: FloatAttribute(
+ name=AttributeName.INITIALIZER_RANGE.value,
+ default_value=0.02,
+ default_low=0.0,
+ default_high=1.0
+ ),
+ AttributeName.MAX_POSITION_EMBEDDINGS.value: IntAttribute(
+ name=AttributeName.MAX_POSITION_EMBEDDINGS.value,
+ default_value=10000,
+ default_low=1,
+ default_high=50000
+ ),
+ AttributeName.FLOW_LOSS_DEPTH.value: IntAttribute(
+ name=AttributeName.FLOW_LOSS_DEPTH.value,
+ default_value=3,
+ default_low=1,
+ default_high=50
+ ),
+ AttributeName.NUM_SAMPLING_STEPS.value: IntAttribute(
+ name=AttributeName.NUM_SAMPLING_STEPS.value,
+ default_value=50,
+ default_low=1,
+ default_high=5000
+ ),
+ AttributeName.DIFFUSION_BATCH_MUL.value: IntAttribute(
+ name=AttributeName.DIFFUSION_BATCH_MUL.value,
+ default_value=4,
+ default_low=1,
+ default_high=5000
+ ),
+ AttributeName.CKPT_PATH.value: StringAttribute(
+ name=AttributeName.CKPT_PATH.value,
+ default_value=os.path.join(os.getcwd(),
AINodeDescriptor().get_config().get_ain_models_dir(), 'weights',
+ 'sundial'),
+ value_choices=['']
+ )
+}
timerxl_attribute_map = {
AttributeName.INPUT_TOKEN_LEN.value: IntAttribute(
@@ -391,8 +491,8 @@ timerxl_attribute_map = {
default_low=1,
default_high=50000
),
- AttributeName.TIMERXL_CKPT_PATH.value: StringAttribute(
- name=AttributeName.TIMERXL_CKPT_PATH.value,
+ AttributeName.CKPT_PATH.value: StringAttribute(
+ name=AttributeName.CKPT_PATH.value,
default_value=os.path.join(os.getcwd(),
AINodeDescriptor().get_config().get_ain_models_dir(), 'weights',
'timerxl', 'model.safetensors'),
value_choices=['']
diff --git a/iotdb-core/ainode/ainode/core/model/sundial/__init__.py
b/iotdb-core/ainode/ainode/core/model/sundial/__init__.py
new file mode 100644
index 00000000000..4b8ee97fad2
--- /dev/null
+++ b/iotdb-core/ainode/ainode/core/model/sundial/__init__.py
@@ -0,0 +1,17 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
\ No newline at end of file
diff --git
a/iotdb-core/ainode/ainode/core/model/sundial/configuration_sundial.py
b/iotdb-core/ainode/ainode/core/model/sundial/configuration_sundial.py
new file mode 100644
index 00000000000..41c54ff4a72
--- /dev/null
+++ b/iotdb-core/ainode/ainode/core/model/sundial/configuration_sundial.py
@@ -0,0 +1,66 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+from typing import List
+from transformers import PretrainedConfig
+
+
+class SundialConfig(PretrainedConfig):
+ model_type = "sundial"
+ keys_to_ignore_at_inference = ["past_key_values"]
+
+ def __init__(
+ self,
+ input_token_len: int = 16,
+ hidden_size: int = 768,
+ intermediate_size: int = 3072,
+ output_token_lens: List[int] = [720],
+ num_hidden_layers: int = 12,
+ num_attention_heads: int = 12,
+ hidden_act: str = "silu",
+ use_cache: bool = True,
+ rope_theta: int = 10000,
+ dropout_rate: float = 0.1,
+ initializer_range: float = 0.02,
+ max_position_embeddings: int = 10000,
+ flow_loss_depth: int = 3,
+ num_sampling_steps: int = 50,
+ diffusion_batch_mul: int = 4,
+ ckpt_path: str = None, # weight path
+ **kwargs,
+ ):
+ self.input_token_len = input_token_len
+ self.hidden_size = hidden_size
+ self.intermediate_size = intermediate_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.hidden_act = hidden_act
+ self.output_token_lens = output_token_lens
+ self.use_cache = use_cache
+ self.rope_theta = rope_theta
+ self.dropout_rate = dropout_rate
+ self.initializer_range = initializer_range
+ self.max_position_embeddings = max_position_embeddings
+ self.flow_loss_depth = flow_loss_depth
+ self.num_sampling_steps = num_sampling_steps
+ self.diffusion_batch_mul = diffusion_batch_mul
+ self.ckpt_path = ckpt_path
+
+ super().__init__(
+ **kwargs,
+ )
diff --git a/iotdb-core/ainode/ainode/core/model/sundial/flow_loss.py
b/iotdb-core/ainode/ainode/core/model/sundial/flow_loss.py
new file mode 100644
index 00000000000..79fcd73c155
--- /dev/null
+++ b/iotdb-core/ainode/ainode/core/model/sundial/flow_loss.py
@@ -0,0 +1,254 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+import torch
+import torch.nn as nn
+import math
+
+
+class FlowLoss(nn.Module):
+ """Flow Loss"""
+
+ def __init__(self, target_channels, z_channels, depth, width,
num_sampling_steps):
+ super(FlowLoss, self).__init__()
+ self.in_channels = target_channels
+ self.net = SimpleMLPAdaLN(
+ in_channels=target_channels,
+ model_channels=width,
+ out_channels=target_channels,
+ z_channels=z_channels,
+ num_res_blocks=depth
+ )
+ self.num_sampling_steps = num_sampling_steps
+
+ def forward(self, target, z, mask=None, mask_y=None):
+ noise = torch.randn_like(target)
+ t = torch.rand(target.shape[0], device=target.device)
+
+ noised_target = t[:, None] * target + (1 - t[:, None]) * noise
+
+ predict_v = self.net(noised_target, t * 1000, z)
+
+ weights = 1.0 / \
+ torch.arange(1, self.in_channels + 1, dtype=torch.float32,
device=target.device)
+ if mask_y is not None:
+ loss = (mask_y * weights * (predict_v - target) ** 2).sum(dim=-1)
+ else:
+ loss = (weights * (predict_v - target) ** 2).sum(dim=-1)
+
+ if mask is not None:
+ loss = (loss * mask).sum() / mask.sum()
+ return loss.mean()
+
+ def sample(self, z, num_samples=1):
+ z = z.repeat(num_samples, 1)
+ noise = torch.randn(z.shape[0], self.in_channels).to(z.device)
+ x = noise
+ dt = 1.0 / self.num_sampling_steps
+ for i in range(self.num_sampling_steps):
+ t = (torch.ones((x.shape[0])) * i /
+ self.num_sampling_steps).to(x.device)
+ pred = self.net(x, t * 1000, z)
+ x = x + (pred - noise) * dt
+ x = x.reshape(num_samples, -1, self.in_channels).transpose(0, 1)
+ return x
+
+
+def modulate(x, shift, scale):
+ return x * (1 + scale) + shift
+
+
+class TimestepEmbedder(nn.Module):
+ """
+ Embeds scalar timesteps into vector representations.
+ """
+
+ def __init__(self, hidden_size, frequency_embedding_size=256):
+ super().__init__()
+ self.mlp = nn.Sequential(
+ nn.Linear(frequency_embedding_size, hidden_size, bias=True),
+ nn.SiLU(),
+ nn.Linear(hidden_size, hidden_size, bias=True),
+ )
+ self.frequency_embedding_size = frequency_embedding_size
+
+ @staticmethod
+ def timestep_embedding(t, dim, max_period=10000):
+ """
+ Create sinusoidal timestep embeddings.
+ :param t: a 1-D Tensor of N indices, one per batch element.
+ These may be fractional.
+ :param dim: the dimension of the output.
+ :param max_period: controls the minimum frequency of the embeddings.
+ :return: an (N, D) Tensor of positional embeddings.
+ """
+ # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
+ half = dim // 2
+ freqs = torch.exp(
+ -math.log(max_period) * torch.arange(start=0,
+ end=half,
dtype=torch.float32) / half
+ ).to(device=t.device)
+ args = t[:, None].float() * freqs[None]
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+ if dim % 2:
+ embedding = torch.cat(
+ [embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
+ return embedding
+
+ def forward(self, t):
+ t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
+ t_emb = self.mlp(t_freq)
+ return t_emb
+
+
+class ResBlock(nn.Module):
+ """
+ A residual block that can optionally change the number of channels.
+ :param channels: the number of input channels.
+ """
+
+ def __init__(
+ self,
+ channels
+ ):
+ super().__init__()
+ self.channels = channels
+
+ self.in_ln = nn.LayerNorm(channels, eps=1e-6)
+ self.mlp = nn.Sequential(
+ nn.Linear(channels, channels, bias=True),
+ nn.SiLU(),
+ nn.Linear(channels, channels, bias=True),
+ )
+
+ self.adaLN_modulation = nn.Sequential(
+ nn.SiLU(),
+ nn.Linear(channels, 3 * channels, bias=True)
+ )
+
+ def forward(self, x, y):
+ shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(
+ y).chunk(3, dim=-1)
+ h = modulate(self.in_ln(x), shift_mlp, scale_mlp)
+ h = self.mlp(h)
+ return x + gate_mlp * h
+
+
+class FinalLayer(nn.Module):
+ """
+ The final layer adopted from DiT.
+ """
+
+ def __init__(self, model_channels, out_channels):
+ super().__init__()
+ self.norm_final = nn.LayerNorm(
+ model_channels, elementwise_affine=False, eps=1e-6)
+ self.linear = nn.Linear(model_channels, out_channels, bias=True)
+ self.adaLN_modulation = nn.Sequential(
+ nn.SiLU(),
+ nn.Linear(model_channels, 2 * model_channels, bias=True)
+ )
+
+ def forward(self, x, c):
+ shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
+ x = modulate(self.norm_final(x), shift, scale)
+ x = self.linear(x)
+ return x
+
+
+class SimpleMLPAdaLN(nn.Module):
+ """
+ The MLP for Diffusion Loss.
+ :param in_channels: channels in the input Tensor.
+ :param model_channels: base channel count for the model.
+ :param out_channels: channels in the output Tensor.
+ :param z_channels: channels in the condition.
+ :param num_res_blocks: number of residual blocks per downsample.
+ """
+
+ def __init__(
+ self,
+ in_channels,
+ model_channels,
+ out_channels,
+ z_channels,
+ num_res_blocks,
+ ):
+ super().__init__()
+
+ self.in_channels = in_channels
+ self.model_channels = model_channels
+ self.out_channels = out_channels
+ self.num_res_blocks = num_res_blocks
+
+ self.time_embed = TimestepEmbedder(model_channels)
+ self.cond_embed = nn.Linear(z_channels, model_channels)
+
+ self.input_proj = nn.Linear(in_channels, model_channels)
+
+ res_blocks = []
+ for i in range(num_res_blocks):
+ res_blocks.append(ResBlock(
+ model_channels,
+ ))
+
+ self.res_blocks = nn.ModuleList(res_blocks)
+ self.final_layer = FinalLayer(model_channels, out_channels)
+
+ self.initialize_weights()
+
+ def initialize_weights(self):
+ def _basic_init(module):
+ if isinstance(module, nn.Linear):
+ torch.nn.init.xavier_uniform_(module.weight)
+ if module.bias is not None:
+ nn.init.constant_(module.bias, 0)
+ self.apply(_basic_init)
+
+ # Initialize timestep embedding MLP
+ nn.init.normal_(self.time_embed.mlp[0].weight, std=0.02)
+ nn.init.normal_(self.time_embed.mlp[2].weight, std=0.02)
+
+ # Zero-out adaLN modulation layers
+ for block in self.res_blocks:
+ nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
+ nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
+
+ # Zero-out output layers
+ nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
+ nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
+ nn.init.constant_(self.final_layer.linear.weight, 0)
+ nn.init.constant_(self.final_layer.linear.bias, 0)
+
+ def forward(self, x, t, c):
+ """
+ Apply the model to an input batch.
+ :param x: an [N x C] Tensor of inputs.
+ :param t: a 1-D batch of timesteps.
+ :param c: conditioning from AR transformer.
+ :return: an [N x C] Tensor of outputs.
+ """
+ x = self.input_proj(x)
+ t = self.time_embed(t)
+ c = self.cond_embed(c)
+ y = t + c
+
+ for block in self.res_blocks:
+ x = block(x, y)
+
+ return self.final_layer(x, y)
diff --git a/iotdb-core/ainode/ainode/core/model/sundial/modeling_sundial.py
b/iotdb-core/ainode/ainode/core/model/sundial/modeling_sundial.py
new file mode 100644
index 00000000000..bdb853cd72d
--- /dev/null
+++ b/iotdb-core/ainode/ainode/core/model/sundial/modeling_sundial.py
@@ -0,0 +1,612 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+import os
+from typing import Optional, Tuple, List, Union
+import torch
+from torch import nn
+import torch.nn.functional as F
+from transformers import PreTrainedModel, Cache, DynamicCache
+from transformers.activations import ACT2FN
+from transformers.modeling_attn_mask_utils import
_prepare_4d_causal_attention_mask
+from transformers.modeling_outputs import MoeModelOutputWithPast,
MoeCausalLMOutputWithPast
+from ainode.core.model.sundial.configuration_sundial import SundialConfig
+from ainode.core.model.sundial.ts_generation_mixin import TSGenerationMixin
+from ainode.core.model.sundial.flow_loss import FlowLoss
+
+from safetensors.torch import load_file as load_safetensors
+from huggingface_hub import hf_hub_download
+
+from ainode.core.log import Logger
+logger = Logger()
+
+def rotate_half(x):
+ x1 = x[..., : x.shape[-1] // 2]
+ x2 = x[..., x.shape[-1] // 2:]
+ return torch.cat((-x2, x1), dim=-1)
+
+
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
+ cos = cos[position_ids].unsqueeze(unsqueeze_dim)
+ sin = sin[position_ids].unsqueeze(unsqueeze_dim)
+ q_embed = (q * cos) + (rotate_half(q) * sin)
+ k_embed = (k * cos) + (rotate_half(k) * sin)
+ return q_embed, k_embed
+
+
+class SundialPatchEmbedding(nn.Module):
+ def __init__(self, config: SundialConfig):
+ super().__init__()
+ self.dropout = nn.Dropout(config.dropout_rate)
+ self.hidden_layer = nn.Linear(
+ config.input_token_len * 2, config.intermediate_size)
+ self.act = ACT2FN[config.hidden_act]
+ self.output_layer = nn.Linear(
+ config.intermediate_size, config.hidden_size)
+ self.residual_layer = nn.Linear(
+ config.input_token_len * 2, config.hidden_size)
+ self.input_token_len = config.input_token_len
+
+ def forward(self, x):
+ mask = torch.ones_like(x, dtype=torch.float32)
+ input_length = x.shape[-1]
+ padding_length = (self.input_token_len - (input_length %
+ self.input_token_len)) % self.input_token_len
+ x = F.pad(x, (padding_length, 0))
+ mask = F.pad(mask, (padding_length, 0))
+ x = x.unfold(dimension=-1, size=self.input_token_len,
+ step=self.input_token_len)
+ mask = mask.unfold(
+ dimension=-1, size=self.input_token_len, step=self.input_token_len)
+
+ x = torch.cat([x, mask], dim=-1)
+ hid = self.act(self.hidden_layer(x))
+ out = self.dropout(self.output_layer(hid))
+ res = self.residual_layer(x)
+ out = out + res
+ return out
+
+
+class SundialRotaryEmbedding(torch.nn.Module):
+ def __init__(self, dim, max_position_embeddings=10000, base=10000,
device=None):
+ super().__init__()
+ self.dim = dim
+ self.max_position_embeddings = max_position_embeddings
+ self.base = base
+ inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim,
+ 2, dtype=torch.int64).float().to(device) / self.dim))
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
+
+ # Build here to make `torch.jit.trace` work.
+ self._set_cos_sin_cache(
+ seq_len=max_position_embeddings, device=self.inv_freq.device,
dtype=torch.get_default_dtype()
+ )
+
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
+ self.max_seq_len_cached = seq_len
+ t = torch.arange(self.max_seq_len_cached, device=device,
+ dtype=torch.int64).type_as(self.inv_freq)
+
+ freqs = torch.outer(t, self.inv_freq)
+ # Different from paper, but it uses a different permutation in order
to obtain the same calculation
+ emb = torch.cat((freqs, freqs), dim=-1)
+ self.register_buffer(
+ "cos_cached", emb.cos().to(dtype), persistent=False)
+ self.register_buffer(
+ "sin_cached", emb.sin().to(dtype), persistent=False)
+
+ def forward(self, x, seq_len=None):
+ # x: [bs, num_attention_heads, seq_len, head_size]
+ if seq_len > self.max_seq_len_cached:
+ self._set_cos_sin_cache(
+ seq_len=seq_len, device=x.device, dtype=x.dtype)
+
+ return (
+ self.cos_cached[:seq_len].to(dtype=x.dtype),
+ self.sin_cached[:seq_len].to(dtype=x.dtype),
+ )
+
+
+class SundialAttention(nn.Module):
+ def __init__(self, config: SundialConfig, layer_idx: Optional[int] = None):
+ super().__init__()
+ self.layer_idx = layer_idx
+ self.hidden_size = config.hidden_size
+ self.num_heads = config.num_attention_heads
+ self.head_dim = self.hidden_size // self.num_heads
+ self.attention_dropout = config.dropout_rate
+ self.q_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=True)
+ self.k_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=True)
+ self.v_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=True)
+ self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
+ self.rotary_emb = SundialRotaryEmbedding(
+ self.head_dim,
max_position_embeddings=config.max_position_embeddings)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: bool = False,
+ **kwargs,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor],
Optional[Tuple[torch.Tensor]]]:
+ bsz, q_len, _ = hidden_states.size()
+
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
+
+ query_states = query_states.view(
+ bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(
+ bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(
+ bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+
+ kv_seq_len = key_states.shape[-2]
+ if past_key_value is not None:
+ kv_seq_len += past_key_value.get_usable_length(
+ kv_seq_len, self.layer_idx)
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+ query_states, key_states = apply_rotary_pos_emb(
+ query_states, key_states, cos, sin, position_ids)
+
+ if past_key_value is not None:
+ key_states, value_states = past_key_value.update(
+ key_states, value_states, self.layer_idx)
+
+ attn_output = F.scaled_dot_product_attention(
+ query_states, key_states, value_states, attention_mask,
dropout_p=(self.attention_dropout if self.training else 0.0))
+
+ attn_output = attn_output.transpose(1, 2).contiguous()
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+ attn_output = self.o_proj(attn_output)
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights, past_key_value
+
+
+class SundialMLP(nn.Module):
+ def __init__(self, hidden_size: int, intermediate_size: int, hidden_act:
str):
+ super().__init__()
+ self.hidden_size = hidden_size
+ self.intermediate_size = intermediate_size
+ self.gate_proj = nn.Linear(
+ self.hidden_size, self.intermediate_size, bias=False)
+ self.up_proj = nn.Linear(
+ self.hidden_size, self.intermediate_size, bias=False)
+ self.down_proj = nn.Linear(
+ self.intermediate_size, self.hidden_size, bias=False)
+ self.act_fn = ACT2FN[hidden_act]
+
+ def forward(self, hidden_state):
+ return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) *
self.up_proj(hidden_state))
+
+
+class SundialDecoderLayer(nn.Module):
+ def __init__(self, config: SundialConfig, layer_idx: int):
+ super().__init__()
+ self.self_attn = SundialAttention(config, layer_idx)
+
+ self.ffn_layer = SundialMLP(
+ hidden_size=config.hidden_size,
+ intermediate_size=config.intermediate_size,
+ hidden_act=config.hidden_act,
+ )
+ self.norm1 = torch.nn.LayerNorm(config.hidden_size)
+ self.norm2 = torch.nn.LayerNorm(config.hidden_size)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ output_attentions: Optional[bool] = False,
+ use_cache: Optional[bool] = False,
+ **kwargs,
+ ) -> Tuple[torch.FloatTensor, torch.FloatTensor,
Optional[torch.FloatTensor], Optional[torch.FloatTensor]]:
+ residual = hidden_states
+
+ hidden_states = self.norm1(hidden_states)
+
+ # Self Attention
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ )
+ hidden_states = residual + hidden_states
+
+ # Fully Connected
+ residual = hidden_states
+ hidden_states = self.norm2(hidden_states)
+ hidden_states = self.ffn_layer(hidden_states)
+ hidden_states = residual + hidden_states
+
+ if not output_attentions:
+ self_attn_weights = None
+
+ if not use_cache:
+ present_key_value = None
+ return hidden_states, self_attn_weights, present_key_value
+
+
+class SundialPreTrainedModel(PreTrainedModel):
+ config_class = SundialConfig
+ base_model_prefix = "model"
+ supports_gradient_checkpointing = True
+ _no_split_modules = ["SundialDecoderLayer"]
+ _skip_keys_device_placement = "past_key_values"
+ _supports_flash_attn_2 = True
+ _supports_sdpa = False
+ _supports_cache_class = True
+
+ def _init_weights(self, module):
+ std = self.config.initializer_range
+ if isinstance(module, torch.nn.Linear):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, torch.nn.Embedding):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+
+
+class SundialModel(SundialPreTrainedModel):
+ def __init__(self, config: SundialConfig):
+ super().__init__(config)
+ self.embed_layer = SundialPatchEmbedding(config)
+ self.layers = nn.ModuleList(
+ [SundialDecoderLayer(config, layer_idx)
+ for layer_idx in range(config.num_hidden_layers)]
+ )
+ self.norm = torch.nn.LayerNorm(config.hidden_size)
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ input_ids: torch.FloatTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, MoeModelOutputWithPast]:
+        # input_ids is the input of time series, its shape is [batch_size, seq_len]
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # retrieve input_ids and inputs_embeds
+ if input_ids is not None and inputs_embeds is not None:
+            raise ValueError(
+                "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
+ elif input_ids is not None:
+ batch_size, seq_length = input_ids.shape
+ elif inputs_embeds is not None:
+ batch_size, seq_length, _ = inputs_embeds.shape
+ else:
+            raise ValueError(
+                "You have to specify either decoder_input_ids or decoder_inputs_embeds")
+
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_layer(input_ids)
+ seq_length = inputs_embeds.shape[1]
+
+ if self.gradient_checkpointing and self.training:
+ if use_cache:
+ use_cache = False
+
+ past_key_values_length = 0
+
+ if use_cache:
+ use_legacy_cache = not isinstance(past_key_values, Cache)
+ if use_legacy_cache:
+ past_key_values = DynamicCache.from_legacy_cache(
+ past_key_values)
+ past_key_values_length = past_key_values.get_usable_length(
+ seq_length)
+
+ if position_ids is None:
+            device = input_ids.device if input_ids is not None else inputs_embeds.device
+            position_ids = torch.arange(
+                past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
+            )
+ # position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
+ position_ids = position_ids.view(-1, seq_length)
+ else:
+ position_ids = position_ids.view(-1, seq_length).long()
+
+ # 4d mask is passed through the layers
+ attention_mask = _prepare_4d_causal_attention_mask(
+ attention_mask,
+ (batch_size, seq_length),
+ inputs_embeds,
+ past_key_values_length,
+ sliding_window=None,
+ )
+
+ hidden_states = inputs_embeds
+
+ # decoder layers
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attns = () if output_attentions else None
+ next_decoder_cache = None
+
+ for decoder_layer in self.layers:
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ if self.gradient_checkpointing and self.training:
+ layer_outputs = self._gradient_checkpointing_func(
+ decoder_layer.__call__,
+ hidden_states,
+ attention_mask,
+ position_ids,
+ past_key_values,
+ output_attentions,
+ use_cache,
+ )
+ else:
+ layer_outputs = decoder_layer(
+ hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_values,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ )
+
+ hidden_states = layer_outputs[0]
+
+ if output_attentions:
+ all_self_attns += (layer_outputs[1],)
+
+ if use_cache:
+ next_decoder_cache = layer_outputs[2]
+
+ hidden_states = self.norm(hidden_states)
+ # add hidden states from the last decoder layer
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ next_cache = None
+ if use_cache:
+ next_cache = next_decoder_cache.to_legacy_cache(
+ ) if use_legacy_cache else next_decoder_cache
+
+ if not return_dict:
+ return tuple(
+ v
+ for v in [hidden_states, next_cache, all_hidden_states,
all_self_attns]
+ if v is not None
+ )
+ return MoeModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ past_key_values=next_cache,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attns,
+ )
+
+
+class SundialForPrediction(SundialPreTrainedModel, TSGenerationMixin):
+ def __init__(self, config: SundialConfig):
+ super().__init__(config)
+ self.config = config
+ self.model = SundialModel(self.config)
+        self.flow_loss = FlowLoss(self.config.output_token_lens[-1], self.config.hidden_size,
+                                  self.config.flow_loss_depth, self.config.hidden_size, self.config.num_sampling_steps)
+ # TODO: Unify data loader
+ if not os.path.exists(config.ckpt_path):
+ os.mkdir(config.ckpt_path)
+ weights_path = os.path.join(config.ckpt_path, "model.safetensors")
+ if not os.path.exists(weights_path):
+            logger.info(f"Weight not found at {weights_path}, downloading from HuggingFace...")
+            repo_id = "thuml/sundial-base-128m"
+            try:
+                hf_hub_download(repo_id=repo_id, filename="model.safetensors", local_dir=config.ckpt_path)
+                logger.info(f"Got weight to {weights_path}")
+            except Exception as e:
+                logger.error(f"Failed to download weight to {weights_path} due to {e}")
+                raise e
+ state_dict = load_safetensors(weights_path)
+ self.load_state_dict(state_dict, strict=True)
+
+ def set_decoder(self, decoder):
+ self.model = decoder
+
+ def get_decoder(self):
+ return self.model
+
+ def forward(
+ self,
+ input_ids: torch.FloatTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.FloatTensor] = None,
+ loss_masks: Optional[torch.FloatTensor] = None,
+ mask_y: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ max_output_length: Optional[int] = None,
+ revin: Optional[bool] = False,
+ num_samples: Optional[int] = 1,
+ ) -> Union[Tuple, MoeCausalLMOutputWithPast]:
+
+ output_attentions = output_attentions if output_attentions is not None
else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else
self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else
self.config.use_return_dict
+
+ if revin:
+ means = input_ids.mean(1, keepdim=True).detach()
+ stdev = input_ids.std(dim=1, keepdim=True, unbiased=False).detach()
+ stdev = torch.where(stdev > 1e-2, stdev, torch.tensor(1.0,
device=input_ids.device))
+ input_ids = (input_ids - means) / stdev
+ outputs = self.model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+        hidden_states = outputs[0] if not return_dict else outputs.last_hidden_state
+ predictions = None
+
+ loss = None
+ if labels is not None:
+ if revin:
+ labels = (labels - means) / stdev
+ output_token_len = self.config.output_token_lens[-1]
+ seq_len = hidden_states.shape[1] * self.config.input_token_len
+ labels = labels[:, :seq_len -
+ self.config.input_token_len + output_token_len]
+ shift_labels = labels.unfold(
+ dimension=-1, size=output_token_len,
step=self.config.input_token_len)
+
+ bsz, L, _ = shift_labels.shape
+ shift_labels = shift_labels.reshape(
+ bsz * L, -1).repeat(self.config.diffusion_batch_mul, 1)
+ hidden_states = hidden_states.reshape(
+ bsz * L, -1).repeat(self.config.diffusion_batch_mul, 1)
+ loss_masks = loss_masks.reshape(
+ bsz * L).repeat(self.config.diffusion_batch_mul)
+ mask_y = mask_y.repeat(L * self.config.diffusion_batch_mul, 1)
+
+ loss = self.flow_loss(shift_labels, hidden_states, loss_masks,
mask_y)
+ else:
+ if max_output_length is None:
+ output_token_len = self.config.output_token_lens[0]
+ max_output_length = output_token_len
+ else:
+ output_token_len = self.config.output_token_lens[0]
+ for h in self.config.output_token_lens[1:]:
+ if h > max_output_length:
+ break
+ else:
+ output_token_len = h
+
+ bsz = hidden_states.shape[0]
+ hidden_states = hidden_states[:, -1, :]
+ predictions = self.flow_loss.sample(hidden_states, num_samples)
+ if output_token_len > max_output_length:
+ predictions = predictions[:, :, :max_output_length]
+ if revin:
+ predictions = predictions * stdev + means
+ if not return_dict:
+ output = (predictions,) + outputs[1:]
+ return (loss) + output if loss is not None else output
+
+ return MoeCausalLMOutputWithPast(
+ loss=loss,
+ logits=predictions,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+ def prepare_inputs_for_generation(
+ self, input_ids, past_key_values=None, attention_mask=None,
inputs_embeds=None, revin=False, num_samples=1, **kwargs
+ ):
+ # Omit tokens covered by past_key_values
+ if past_key_values is not None:
+ if isinstance(past_key_values, Cache):
+ cache_length = past_key_values.get_seq_length()
+ if isinstance(past_key_values, DynamicCache):
+ past_length = past_key_values.seen_tokens
+ else:
+ past_length = cache_length
+
+ max_cache_length = past_key_values.get_max_length()
+ else:
+ cache_length = past_length = past_key_values[0][0].shape[2]
+ max_cache_length = None
+
+ # Keep only the unprocessed tokens:
+ # 1 - If the length of the attention_mask exceeds the length of
input_ids, then we are in a setting where
+ # some of the inputs are exclusively passed as part of the cache
(e.g. when passing input_embeds as
+ # input)
+ if attention_mask is not None and attention_mask.shape[1] >
(input_ids.shape[1] // self.config.input_token_len):
+ input_ids = input_ids[:, -
+ (attention_mask.shape[1] - past_length)
* self.config.input_token_len:]
+ # 2 - If the past_length is smaller than input_ids', then
input_ids holds all input tokens. We can discard
+ # input_ids based on the past_length.
+ elif past_length < (input_ids.shape[1] //
self.config.input_token_len):
+ input_ids = input_ids[:, past_length *
+ self.config.input_token_len:]
+ # 3 - Otherwise (past_length >= (input_ids.shape[1] //
self.config.input_token_len)), let's assume input_ids only has unprocessed
tokens.
+
+ # If we are about to go beyond the maximum cache length, we need
to crop the input attention mask.
+ if (
+ max_cache_length is not None
+ and attention_mask is not None
+ and cache_length + (input_ids.shape[1] //
self.config.input_token_len) > max_cache_length
+ ):
+ attention_mask = attention_mask[:, -max_cache_length:]
+
+ position_ids = kwargs.get("position_ids", None)
+ if attention_mask is not None and position_ids is None:
+ # create position_ids on the fly for batch generation
+ position_ids = attention_mask.long().cumsum(-1) - 1
+ position_ids.masked_fill_(attention_mask == 0, 1)
+ if past_key_values:
+ position_ids = position_ids[:, -
+ (input_ids.shape[1] //
self.config.input_token_len):]
+
+ # if `inputs_embeds` are passed, we only want to use them in the 1st
generation step
+ if inputs_embeds is not None and past_key_values is None:
+ model_inputs = {"inputs_embeds": inputs_embeds}
+ else:
+ model_inputs = {"input_ids": input_ids}
+
+ model_inputs.update(
+ {
+ "position_ids": position_ids,
+ "past_key_values": past_key_values,
+ "use_cache": kwargs.get("use_cache"),
+ "attention_mask": attention_mask,
+ "revin": revin,
+ "num_samples": num_samples,
+ }
+ )
+ return model_inputs
\ No newline at end of file
diff --git a/iotdb-core/ainode/ainode/core/model/sundial/ts_generation_mixin.py
b/iotdb-core/ainode/ainode/core/model/sundial/ts_generation_mixin.py
new file mode 100644
index 00000000000..723b4a1332a
--- /dev/null
+++ b/iotdb-core/ainode/ainode/core/model/sundial/ts_generation_mixin.py
@@ -0,0 +1,306 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+import warnings
+from typing import Any, Dict, List, Optional, Union, Callable
+import torch
+from transformers import GenerationMixin, LogitsProcessorList,
StoppingCriteriaList
+from transformers.generation import validate_stopping_criteria,
EosTokenCriteria
+from transformers.generation.utils import GenerateNonBeamOutput,
GenerateEncoderDecoderOutput, \
+ GenerateDecoderOnlyOutput, GenerationConfig, GenerateOutput
+from transformers.utils import ModelOutput
+
+
+class TSGenerationMixin(GenerationMixin):
+ @torch.no_grad()
+ def generate(
+ self,
+ inputs: Optional[torch.Tensor] = None,
+ generation_config: Optional[GenerationConfig] = None,
+ logits_processor: Optional[LogitsProcessorList] = None,
+ stopping_criteria: Optional[StoppingCriteriaList] = None,
+ prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor],
List[int]]] = None,
+ synced_gpus: Optional[bool] = None,
+ assistant_model: Optional["PreTrainedModel"] = None,
+ streamer: Optional["BaseStreamer"] = None,
+ negative_prompt_ids: Optional[torch.Tensor] = None,
+ negative_prompt_attention_mask: Optional[torch.Tensor] = None,
+ revin: Optional[bool] = True,
+ num_samples: Optional[int] = 1,
+ **kwargs,
+ ) -> Union[GenerateOutput, torch.LongTensor]:
+ if len(inputs.shape) != 2:
+ raise ValueError('Input shape must be: [batch_size, seq_len]')
+ if revin:
+ means = inputs.mean(dim=-1, keepdim=True)
+ stdev = inputs.std(dim=-1, keepdim=True, unbiased=False) + 1e-5
+ inputs = (inputs - means) / stdev
+ outputs = super().generate(inputs=inputs,
generation_config=generation_config,
+ logits_processor=logits_processor,
stopping_criteria=stopping_criteria,
+
prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, synced_gpus=synced_gpus,
+ assistant_model=assistant_model,
streamer=streamer,
+ negative_prompt_ids=negative_prompt_ids,
+
negative_prompt_attention_mask=negative_prompt_attention_mask,
+ num_samples=num_samples, **kwargs)
+ if revin:
+ stdev = stdev.unsqueeze(1).repeat(1, num_samples, 1)
+ means = means.unsqueeze(1).repeat(1, num_samples, 1)
+ outputs = (outputs * stdev) + means
+ return outputs
+
+ def _greedy_search(
+ self,
+ input_ids: torch.Tensor,
+ logits_processor: Optional[LogitsProcessorList] = None,
+ stopping_criteria: Optional[StoppingCriteriaList] = None,
+ max_length: Optional[int] = None,
+ pad_token_id: Optional[int] = None,
+ eos_token_id: Optional[Union[int, List[int]]] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ output_scores: Optional[bool] = None,
+ output_logits: Optional[bool] = None,
+ return_dict_in_generate: Optional[bool] = None,
+ synced_gpus: bool = False,
+ streamer: Optional["BaseStreamer"] = None,
+ **model_kwargs,
+ ) -> Union[GenerateNonBeamOutput, torch.Tensor]:
+ input_ids = input_ids.to(self.device)
+ batch_size, cur_len = input_ids.shape
+ # init values
+ logits_processor = logits_processor if logits_processor is not None
else LogitsProcessorList()
+ stopping_criteria = stopping_criteria if stopping_criteria is not None
else StoppingCriteriaList()
+ if max_length is not None:
+ warnings.warn(
+ "`max_length` is deprecated in this function, use"
+ "
`stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])`
instead.",
+ UserWarning,
+ )
+ stopping_criteria = validate_stopping_criteria(
+ stopping_criteria, max_length)
+ pad_token_id = pad_token_id if pad_token_id is not None else
self.generation_config.pad_token_id
+ if eos_token_id is not None:
+ stopping_criteria.append(
+ EosTokenCriteria(eos_token_id=eos_token_id))
+ else:
+ # remove when the method is totally private
+ # need to get `eos_token_id` and add stopping criteria, so that
generation does not go forever
+ eos_token_id = [
+ criteria.eos_token_id.tolist() for criteria in
stopping_criteria if hasattr(criteria, "eos_token_id")
+ ]
+ eos_token_id = eos_token_id[0] if eos_token_id else None
+ if eos_token_id is None and self.generation_config.eos_token_id is
not None:
+ eos_token_id = self.generation_config.eos_token_id
+ stopping_criteria.append(
+ EosTokenCriteria(eos_token_id=eos_token_id))
+
+ if isinstance(eos_token_id, int):
+ eos_token_id = [eos_token_id]
+ output_scores = output_scores if output_scores is not None else
self.generation_config.output_scores
+ output_attentions = (
+ output_attentions if output_attentions is not None else
self.generation_config.output_attentions
+ )
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else
self.generation_config.output_hidden_states
+ )
+ return_dict_in_generate = (
+ return_dict_in_generate
+ if return_dict_in_generate is not None
+ else self.generation_config.return_dict_in_generate
+ )
+
+ # init attention / hidden states / scores tuples
+ raw_logits = () if (return_dict_in_generate and output_logits) else
None
+ scores = () if (return_dict_in_generate and output_scores) else None
+ decoder_attentions = () if (return_dict_in_generate and
output_attentions) else None
+ cross_attentions = () if (return_dict_in_generate and
output_attentions) else None
+ decoder_hidden_states = () if (
+ return_dict_in_generate and output_hidden_states) else None
+
+ # if model is an encoder-decoder, retrieve encoder attention weights
and hidden states
+ if return_dict_in_generate and self.config.is_encoder_decoder:
+ encoder_attentions = model_kwargs["encoder_outputs"].get(
+ "attentions") if output_attentions else None
+ encoder_hidden_states = (
+ model_kwargs["encoder_outputs"].get(
+ "hidden_states") if output_hidden_states else None
+ )
+
+ # keep track of which sequences are already finished
+ if "inputs_embeds" in model_kwargs:
+ cur_len = model_kwargs["inputs_embeds"].shape[1]
+ this_peer_finished = False
+ unfinished_sequences = torch.ones(
+ batch_size, dtype=torch.long, device=input_ids.device)
+ model_kwargs["cache_position"] = torch.arange(
+ cur_len, device=input_ids.device)
+ true_seq_len = (cur_len + self.config.input_token_len - 1) //
self.config.input_token_len
+ model_kwargs["attention_mask"] = model_kwargs["attention_mask"][:,
-true_seq_len:]
+ max_length = stopping_criteria.max_length
+ generate_results = None
+ while self._has_unfinished_sequences(this_peer_finished, synced_gpus,
device=input_ids.device):
+ # prepare model inputs
+ model_inputs = self.prepare_inputs_for_generation(
+ input_ids, **model_kwargs)
+
+ input_length = input_ids.shape[1]
+
+ # forward pass to get next token
+ outputs = self(
+ **model_inputs,
+ return_dict=True,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ max_output_length=max_length - input_length,
+ )
+
+ if synced_gpus and this_peer_finished:
+ continue # don't waste resources running the code we don't
need
+ next_token_logits = outputs.logits
+
+ # pre-process distribution
+ next_tokens_scores = logits_processor(input_ids, next_token_logits)
+
+ # Store scores, attentions and hidden_states when required
+ if return_dict_in_generate:
+ if output_scores:
+ scores += (next_tokens_scores,)
+ if output_logits:
+ raw_logits += (next_token_logits,)
+ if output_attentions:
+ decoder_attentions += (
+ (outputs.decoder_attentions,) if
self.config.is_encoder_decoder else (
+ outputs.attentions,)
+ )
+ if self.config.is_encoder_decoder:
+ cross_attentions += (outputs.cross_attentions,)
+
+ if output_hidden_states:
+ decoder_hidden_states += (
+ (outputs.decoder_hidden_states,)
+ if self.config.is_encoder_decoder
+ else (outputs.hidden_states,)
+ )
+
+ # argmax
+ # next_tokens = torch.argmax(next_tokens_scores, dim=-1)
+ next_tokens = next_tokens_scores
+
+ # finished sentences should have their next token be a padding
token
+ if eos_token_id is not None:
+ if pad_token_id is None:
+ raise ValueError(
+ "If `eos_token_id` is defined, make sure that
`pad_token_id` is defined.")
+ next_tokens = next_tokens * unfinished_sequences + \
+ pad_token_id * (1 - unfinished_sequences)
+
+ # update generated ids, model inputs, and length for next step
+ horizon_length = next_tokens.shape[-1] //
self.config.input_token_len
+
+ past_key_values = model_kwargs.get("past_key_values")
+ if past_key_values is None:
+ generate_results = next_tokens
+ else:
+ generate_results = torch.cat([generate_results, next_tokens],
dim=-1)
+ input_ids = torch.cat([input_ids, next_tokens.median(dim=1)[0]],
dim=-1)
+
+ if streamer is not None:
+ streamer.put(next_tokens.cpu())
+ model_kwargs = self._update_model_kwargs_for_generation(
+ outputs,
+ model_kwargs,
+ horizon_length=horizon_length,
+ is_encoder_decoder=self.config.is_encoder_decoder,
+ )
+ unfinished_sequences = unfinished_sequences & ~stopping_criteria(
+ input_ids, scores)
+ this_peer_finished = unfinished_sequences.max() == 0
+
+ if input_ids.shape[-1] > max_length:
+ input_ids = input_ids[:, :max_length]
+
+ if streamer is not None:
+ streamer.end()
+
+ if return_dict_in_generate:
+ if self.config.is_encoder_decoder:
+ return GenerateEncoderDecoderOutput(
+ sequences=input_ids,
+ scores=scores,
+ logits=raw_logits,
+ encoder_attentions=encoder_attentions,
+ encoder_hidden_states=encoder_hidden_states,
+ decoder_attentions=decoder_attentions,
+ cross_attentions=cross_attentions,
+ decoder_hidden_states=decoder_hidden_states,
+ past_key_values=model_kwargs.get("past_key_values"),
+ )
+ else:
+ return GenerateDecoderOnlyOutput(
+ sequences=input_ids,
+ scores=scores,
+ logits=raw_logits,
+ attentions=decoder_attentions,
+ hidden_states=decoder_hidden_states,
+ past_key_values=model_kwargs.get("past_key_values"),
+ )
+ else:
+ return generate_results[:, :, :(max_length - cur_len)]
+
+ def _update_model_kwargs_for_generation(
+ self,
+ outputs: ModelOutput,
+ model_kwargs: Dict[str, Any],
+ horizon_length: int = 1,
+ is_encoder_decoder: bool = False,
+ standardize_cache_format: bool = False,
+ ) -> Dict[str, Any]:
+ # update past_key_values
+ model_kwargs["past_key_values"] = self._extract_past_from_model_output(
+ outputs, standardize_cache_format=standardize_cache_format
+ )
+ if getattr(outputs, "state", None) is not None:
+ model_kwargs["state"] = outputs.state
+
+ # update token_type_ids with last value
+ if "token_type_ids" in model_kwargs:
+ token_type_ids = model_kwargs["token_type_ids"]
+ model_kwargs["token_type_ids"] = torch.cat(
+ [token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], dim=-1)
+
+ if not is_encoder_decoder:
+ # update attention mask
+ if "attention_mask" in model_kwargs:
+ attention_mask = model_kwargs["attention_mask"]
+ model_kwargs["attention_mask"] = torch.cat(
+ [attention_mask,
attention_mask.new_ones((attention_mask.shape[0], horizon_length))], dim=-1
+ )
+ else:
+ # update decoder attention mask
+ if "decoder_attention_mask" in model_kwargs:
+ decoder_attention_mask = model_kwargs["decoder_attention_mask"]
+ model_kwargs["decoder_attention_mask"] = torch.cat(
+ [decoder_attention_mask, decoder_attention_mask.new_ones(
+ (decoder_attention_mask.shape[0], horizon_length))],
+ dim=-1,
+ )
+
+ if "cache_position" in model_kwargs and model_kwargs["cache_position"]
is not None:
+ model_kwargs["cache_position"] =
model_kwargs["cache_position"][-1:] + horizon_length
+
+ return model_kwargs
\ No newline at end of file
diff --git a/iotdb-core/ainode/pyproject.toml b/iotdb-core/ainode/pyproject.toml
index 70557d4f2c0..eaa418b623f 100644
--- a/iotdb-core/ainode/pyproject.toml
+++ b/iotdb-core/ainode/pyproject.toml
@@ -63,6 +63,7 @@ apache-iotdb = "2.0.4.dev0"
einops = "^0.8.1"
safetensors = "^0.5.1"
huggingface_hub = "^0.30.1"
+transformers = "==4.40.1"
[tool.poetry.scripts]
ainode = "ainode.core.script:main"
diff --git
a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/persistence/ModelInfo.java
b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/persistence/ModelInfo.java
index 44093d10053..8f00fde8773 100644
---
a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/persistence/ModelInfo.java
+++
b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/persistence/ModelInfo.java
@@ -83,6 +83,7 @@ public class ModelInfo implements SnapshotProcessor {
builtInForecastModel.add("_STLForecaster");
builtInForecastModel.add("_HoltWinters");
builtInForecastModel.add("_ExponentialSmoothing");
+ builtInForecastModel.add("_sundial");
builtInAnomalyDetectionModel.add("_GaussianHMM");
builtInAnomalyDetectionModel.add("_GMMHMM");
builtInAnomalyDetectionModel.add("_Stray");