This is an automated email from the ASF dual-hosted git repository.

yongzao pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/iotdb.git


The following commit(s) were added to refs/heads/master by this push:
     new 5095862f75f [AINode]: Integrate toto as a builtin forecasting model 
(#17322)
5095862f75f is described below

commit 5095862f75f33a1f5c800666ceacd0c6a3b042e3
Author: Grace Li <[email protected]>
AuthorDate: Fri Mar 27 03:07:19 2026 -0400

    [AINode]: Integrate toto as a builtin forecasting model (#17322)
---
 NOTICE                                             |   9 +
 .../apache/iotdb/ainode/utils/AINodeTestUtils.java |   4 +-
 iotdb-core/ainode/build_binary.py                  |   8 +-
 .../ainode/iotdb/ainode/core/model/model_info.py   |  13 +
 .../iotdb/ainode/core/model/toto/__init__.py       |  17 +
 .../ainode/core/model/toto/configuration_toto.py   |  78 ++++
 .../iotdb/ainode/core/model/toto/data/__init__.py  |  20 +
 .../ainode/core/model/toto/data/util/__init__.py   |  20 +
 .../ainode/core/model/toto/data/util/dataset.py    | 127 ++++++
 .../ainode/core/model/toto/inference/__init__.py   |  20 +
 .../ainode/core/model/toto/inference/forecaster.py | 452 +++++++++++++++++++++
 .../iotdb/ainode/core/model/toto/model/__init__.py |  20 +
 .../ainode/core/model/toto/model/attention.py      | 276 +++++++++++++
 .../iotdb/ainode/core/model/toto/model/backbone.py | 258 ++++++++++++
 .../ainode/core/model/toto/model/distribution.py   | 112 +++++
 .../ainode/core/model/toto/model/embedding.py      |  83 ++++
 .../ainode/core/model/toto/model/feed_forward.py   |  35 ++
 .../iotdb/ainode/core/model/toto/model/fusion.py   |  58 +++
 .../iotdb/ainode/core/model/toto/model/rope.py     |  94 +++++
 .../iotdb/ainode/core/model/toto/model/scaler.py   | 328 +++++++++++++++
 .../iotdb/ainode/core/model/toto/model/toto.py     | 157 +++++++
 .../ainode/core/model/toto/model/transformer.py    | 318 +++++++++++++++
 .../iotdb/ainode/core/model/toto/model/util.py     | 251 ++++++++++++
 .../iotdb/ainode/core/model/toto/modeling_toto.py  | 167 ++++++++
 .../iotdb/ainode/core/model/toto/pipeline_toto.py  | 144 +++++++
 iotdb-core/ainode/pyproject.toml                   |   1 +
 26 files changed, 3066 insertions(+), 4 deletions(-)

diff --git a/NOTICE b/NOTICE
index fa52a36987f..429495c377b 100644
--- a/NOTICE
+++ b/NOTICE
@@ -17,6 +17,15 @@ grant the users the right to the use of patent under the 
requirement of Apache 2
 
 ============================================================================
 
+This product includes source code derived from the DataDog/toto project:
+
+  Toto – Time Series Optimized Transformer for Observability
+  Copyright 2025 Datadog, Inc.
+  Licensed under the Apache License, Version 2.0
+  https://github.com/DataDog/toto
+
+============================================================================
+
 Apache Commons Collections
 Copyright 2001-2019 The Apache Software Foundation
 
diff --git 
a/integration-test/src/test/java/org/apache/iotdb/ainode/utils/AINodeTestUtils.java
 
b/integration-test/src/test/java/org/apache/iotdb/ainode/utils/AINodeTestUtils.java
index e41d3d4e0f9..bf758a083d4 100644
--- 
a/integration-test/src/test/java/org/apache/iotdb/ainode/utils/AINodeTestUtils.java
+++ 
b/integration-test/src/test/java/org/apache/iotdb/ainode/utils/AINodeTestUtils.java
@@ -58,7 +58,9 @@ public class AINodeTestUtils {
               new AbstractMap.SimpleEntry<>(
                   "chronos2", new FakeModelInfo("chronos2", "t5", "builtin", 
"active")),
               new AbstractMap.SimpleEntry<>(
-                  "moirai2", new FakeModelInfo("moirai2", "moirai", "builtin", 
"active")))
+                  "moirai2", new FakeModelInfo("moirai2", "moirai", "builtin", 
"active")),
+              new AbstractMap.SimpleEntry<>(
+                  "toto", new FakeModelInfo("toto", "toto", "builtin", 
"active")))
           .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
 
   public static final Map<String, FakeModelInfo> BUILTIN_MODEL_MAP;
diff --git a/iotdb-core/ainode/build_binary.py 
b/iotdb-core/ainode/build_binary.py
index c943de41581..f3b7fa1cedf 100644
--- a/iotdb-core/ainode/build_binary.py
+++ b/iotdb-core/ainode/build_binary.py
@@ -423,7 +423,7 @@ def install_dependencies(venv_python, venv_dir, script_dir):
         [str(poetry_exe), "lock"],
         cwd=str(script_dir),
         env=venv_env,
-        check=True,
+        check=False,
         capture_output=True,
         text=True,
     )
@@ -431,6 +431,9 @@ def install_dependencies(venv_python, venv_dir, script_dir):
         print(result.stdout)
     if result.stderr:
         print(result.stderr)
+    if result.returncode != 0:
+        print(f"ERROR: poetry lock failed with exit code {result.returncode}")
+        sys.exit(1)
     verify_poetry_env()  # Verify after lock
 
     accelerator = detect_accelerator()
@@ -438,11 +441,10 @@ def install_dependencies(venv_python, venv_dir, 
script_dir):
 
     print("Running poetry install...")
     subprocess.run(
-        [str(poetry_exe), "lock"],
+        [str(poetry_exe), "install", "--no-root"],
         cwd=str(script_dir),
         env=venv_env,
         check=True,
-        capture_output=True,
         text=True,
     )
     verify_poetry_env()  # Verify before install
diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/model_info.py 
b/iotdb-core/ainode/iotdb/ainode/core/model/model_info.py
index da752cbd784..642986c42d2 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/model/model_info.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/model/model_info.py
@@ -160,4 +160,17 @@ BUILTIN_HF_TRANSFORMERS_MODEL_MAP = {
         },
         transformers_registered=True,
     ),
+    "toto": ModelInfo(
+        model_id="toto",
+        category=ModelCategory.BUILTIN,
+        state=ModelStates.INACTIVE,
+        model_type="toto",
+        pipeline_cls="pipeline_toto.TotoPipeline",
+        repo_id="Datadog/Toto-Open-Base-1.0",
+        auto_map={
+            "AutoConfig": "configuration_toto.TotoConfig",
+            "AutoModelForCausalLM": "modeling_toto.TotoForPrediction",
+        },
+        transformers_registered=True,
+    ),
 }
diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/toto/__init__.py 
b/iotdb-core/ainode/iotdb/ainode/core/model/toto/__init__.py
new file mode 100644
index 00000000000..2a1e720805f
--- /dev/null
+++ b/iotdb-core/ainode/iotdb/ainode/core/model/toto/__init__.py
@@ -0,0 +1,17 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
diff --git 
a/iotdb-core/ainode/iotdb/ainode/core/model/toto/configuration_toto.py 
b/iotdb-core/ainode/iotdb/ainode/core/model/toto/configuration_toto.py
new file mode 100644
index 00000000000..2a00fcc3be4
--- /dev/null
+++ b/iotdb-core/ainode/iotdb/ainode/core/model/toto/configuration_toto.py
@@ -0,0 +1,78 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+from typing import List, Optional
+
+from transformers import PretrainedConfig
+
+
+class TotoConfig(PretrainedConfig):
+    """
+    Configuration class for the Toto time series forecasting model.
+
+    Toto (Time Series Optimized Transformer for Observability) is a
+    foundation model for multivariate time series forecasting developed by
+    Datadog. It uses a decoder-only architecture with per-variate
+    patch-based causal scaling, proportional time-variate factorized
+    attention, and a Student-T mixture prediction head.
+
+    ``output_distribution_classes`` and ``output_distribution_kwargs`` fall
+    back to ``["student_t_mixture"]`` and ``{"k_components": 5}`` only when
+    they are omitted (``None``); an explicitly supplied value, even an
+    empty one, is stored as given.
+
+    Reference: https://github.com/DataDog/toto
+    """
+
+    model_type = "toto"
+
+    def __init__(
+        self,
+        patch_size: int = 32,
+        stride: int = 32,
+        embed_dim: int = 1024,
+        num_layers: int = 18,
+        num_heads: int = 16,
+        mlp_hidden_dim: int = 2816,
+        dropout: float = 0.0,
+        spacewise_every_n_layers: int = 3,
+        scaler_cls: str = "per_variate_causal",
+        output_distribution_classes: Optional[List[str]] = None,
+        output_distribution_kwargs: Optional[dict] = None,
+        spacewise_first: bool = True,
+        use_memory_efficient_attention: bool = True,
+        stabilize_with_global: bool = True,
+        scale_factor_exponent: float = 10.0,
+        **kwargs,
+    ):
+        self.patch_size = patch_size
+        self.stride = stride
+        self.embed_dim = embed_dim
+        self.num_layers = num_layers
+        self.num_heads = num_heads
+        self.mlp_hidden_dim = mlp_hidden_dim
+        self.dropout = dropout
+        self.spacewise_every_n_layers = spacewise_every_n_layers
+        self.scaler_cls = scaler_cls
+        # Fresh default objects per instance; "is None" (not "or") so an
+        # explicit empty value is not silently replaced by the defaults.
+        if output_distribution_classes is None:
+            output_distribution_classes = ["student_t_mixture"]
+        self.output_distribution_classes = output_distribution_classes
+        # k_components=5 is the default used by Datadog/Toto-Open-Base-1.0
+        if output_distribution_kwargs is None:
+            output_distribution_kwargs = {"k_components": 5}
+        self.output_distribution_kwargs = output_distribution_kwargs
+        self.spacewise_first = spacewise_first
+        self.use_memory_efficient_attention = use_memory_efficient_attention
+        self.stabilize_with_global = stabilize_with_global
+        self.scale_factor_exponent = scale_factor_exponent
+
+        super().__init__(**kwargs)
diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/toto/data/__init__.py 
b/iotdb-core/ainode/iotdb/ainode/core/model/toto/data/__init__.py
new file mode 100644
index 00000000000..ba26b1edd94
--- /dev/null
+++ b/iotdb-core/ainode/iotdb/ainode/core/model/toto/data/__init__.py
@@ -0,0 +1,20 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+# This file includes code derived from DataDog/toto
+# (https://github.com/DataDog/toto), licensed under the Apache-2.0 License.
+# Copyright 2025 Datadog, Inc.
diff --git 
a/iotdb-core/ainode/iotdb/ainode/core/model/toto/data/util/__init__.py 
b/iotdb-core/ainode/iotdb/ainode/core/model/toto/data/util/__init__.py
new file mode 100644
index 00000000000..ba26b1edd94
--- /dev/null
+++ b/iotdb-core/ainode/iotdb/ainode/core/model/toto/data/util/__init__.py
@@ -0,0 +1,20 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+# This file includes code derived from DataDog/toto
+# (https://github.com/DataDog/toto), licensed under the Apache-2.0 License.
+# Copyright 2025 Datadog, Inc.
diff --git 
a/iotdb-core/ainode/iotdb/ainode/core/model/toto/data/util/dataset.py 
b/iotdb-core/ainode/iotdb/ainode/core/model/toto/data/util/dataset.py
new file mode 100644
index 00000000000..6bccf35988c
--- /dev/null
+++ b/iotdb-core/ainode/iotdb/ainode/core/model/toto/data/util/dataset.py
@@ -0,0 +1,127 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+# This file includes code derived from DataDog/toto
+# (https://github.com/DataDog/toto), licensed under the Apache-2.0 License.
+# Copyright 2025 Datadog, Inc.
+
+from functools import reduce
+from typing import NamedTuple
+
+import numpy as np
+import torch
+import torch.utils.data
+from einops import repeat
+from jaxtyping import Bool, Float, Int, Shaped
+
+
+def pad_array(
+    values: Shaped[torch.Tensor, "*batch variates series_len"],  # noqa: F722
+    patch_stride: int,
+) -> Shaped[torch.Tensor, "*batch variates padded_length"]:  # noqa: F722
+    """
+    Left-pad ``values`` with zeros so its last dimension is a multiple of
+    ``patch_stride``.
+
+    Accepts a torch tensor or a NumPy array (converted with
+    ``torch.from_numpy``) with any number of leading dimensions, as the
+    ``*batch`` annotation advertises. The result keeps the input's dtype
+    and device, and the original values occupy its rightmost
+    ``series_len`` positions.
+
+    Raises:
+        ValueError: if ``values`` is zero-dimensional.
+    """
+    if isinstance(values, np.ndarray):
+        values = torch.from_numpy(values)
+    if values.ndim < 1:
+        raise ValueError(f"Unsupported number of dimensions: {values.ndim}")
+    series_len = values.shape[-1]
+    # Exact integer ceiling to the next multiple of patch_stride.
+    padded_length = -(-series_len // patch_stride) * patch_stride
+    padded_values = values.new_zeros((*values.shape[:-1], padded_length))
+    padded_values[..., -series_len:] = values
+
+    return padded_values
+
+
+def pad_id_mask(
+    id_mask: Int[torch.Tensor, "*batch variates series_len"],  # noqa: F722
+    patch_stride: int,
+) -> Int[torch.Tensor, "*batch variates padded_length"]:  # noqa: F722
+    """
+    Left-pad ``id_mask`` so its last dimension is a multiple of
+    ``patch_stride``.
+
+    The padding repeats each row's first (left-edge) id. Like
+    ``pad_array``, this accepts a torch tensor or a NumPy array with any
+    number of leading dimensions. When no padding is needed (including an
+    empty series), the input tensor is returned as is.
+
+    Raises:
+        ValueError: if ``id_mask`` is zero-dimensional.
+    """
+    if isinstance(id_mask, np.ndarray):
+        id_mask = torch.from_numpy(id_mask)
+    if id_mask.ndim < 1:
+        raise ValueError(f"Unsupported number of dimensions: {id_mask.ndim}")
+    series_len = id_mask.shape[-1]
+    padded_length = -(-series_len // patch_stride) * patch_stride
+    padding_amount = padded_length - series_len
+    if padding_amount == 0:
+        # Nothing to pad. This also avoids reading the left edge of an
+        # empty series, which would raise IndexError.
+        return id_mask
+    # Keep the trailing axis (size 1) so expand() can broadcast it.
+    left_edge = id_mask[..., :1]
+    padding = left_edge.expand(*id_mask.shape[:-1], padding_amount)
+    return torch.cat([padding, id_mask], dim=-1)
+
+
+class MaskedTimeseries(NamedTuple):
+    """
+    A (possibly batched) multivariate series together with its masks and
+    timing metadata, as consumed by ``TotoForecaster.forecast``.
+    """
+
+    series: Float[torch.Tensor, "*batch variates series_len"]  # noqa: F722
+    padding_mask: Bool[torch.Tensor, "*batch variates series_len"]  # noqa: F722
+    id_mask: Int[torch.Tensor, "*batch variates #series_len"]  # noqa: F722
+    timestamp_seconds: Int[torch.Tensor, "*batch variates series_len"]  # noqa: F722
+    time_interval_seconds: Int[torch.Tensor, "*batch variates"]  # noqa: F722
+    num_exogenous_variables: int = 0
+
+    def to(self, device: torch.device) -> "MaskedTimeseries":
+        """
+        Return a copy with every tensor field moved to ``device``.
+
+        Fields that are not tensors (``None``, plain numbers, and
+        ``num_exogenous_variables``) are carried over unchanged.
+        ``forecast`` treats ``id_mask`` as optional and passes
+        ``time_interval_seconds`` through ``torch.as_tensor``.
+        """
+
+        def _move(value):
+            # Only tensors have a device; leave anything else untouched.
+            if isinstance(value, torch.Tensor):
+                return value.to(device)
+            return value
+
+        return MaskedTimeseries(
+            series=_move(self.series),
+            padding_mask=_move(self.padding_mask),
+            id_mask=_move(self.id_mask),
+            timestamp_seconds=_move(self.timestamp_seconds),
+            time_interval_seconds=_move(self.time_interval_seconds),
+            num_exogenous_variables=self.num_exogenous_variables,
+        )
+
+
+def is_extreme_value(t: torch.Tensor) -> torch.Tensor:
+    """
+    Return a boolean mask of the entries of ``t`` that are ``inf``, ``NaN``,
+    or at least half the largest value representable in ``t.dtype`` in
+    magnitude.
+
+    Boolean tensors cannot hold such values, so they yield an all-``False``
+    mask.
+    """
+    if t.dtype == torch.bool:
+        # torch.iinfo rejects bool, and abs() is not defined for it.
+        return torch.zeros_like(t, dtype=torch.bool)
+    if torch.is_floating_point(t):
+        max_value = torch.finfo(t.dtype).max
+    else:
+        max_value = torch.iinfo(t.dtype).max
+    threshold = max_value / 2
+    # Test both tails instead of abs(): abs() of the minimum signed integer
+    # overflows back to itself (still negative), so it was never flagged.
+    too_large = (t >= threshold) | (t <= -threshold)
+    return torch.isinf(t) | torch.isnan(t) | too_large
+
+
+def replace_extreme_values(
+    t: torch.Tensor, replacement: float = 0.0
+) -> torch.Tensor:
+    """
+    Return a new tensor equal to ``t`` except that every entry flagged by
+    ``is_extreme_value`` is replaced by ``replacement``, cast to ``t``'s
+    dtype and device.
+    """
+    extreme = is_extreme_value(t)
+    fill = torch.tensor(replacement, dtype=t.dtype, device=t.device)
+    return torch.where(extreme, fill, t)
diff --git 
a/iotdb-core/ainode/iotdb/ainode/core/model/toto/inference/__init__.py 
b/iotdb-core/ainode/iotdb/ainode/core/model/toto/inference/__init__.py
new file mode 100644
index 00000000000..ba26b1edd94
--- /dev/null
+++ b/iotdb-core/ainode/iotdb/ainode/core/model/toto/inference/__init__.py
@@ -0,0 +1,20 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+# This file includes code derived from DataDog/toto
+# (https://github.com/DataDog/toto), licensed under the Apache-2.0 License.
+# Copyright 2025 Datadog, Inc.
diff --git 
a/iotdb-core/ainode/iotdb/ainode/core/model/toto/inference/forecaster.py 
b/iotdb-core/ainode/iotdb/ainode/core/model/toto/inference/forecaster.py
new file mode 100644
index 00000000000..2a9db2aa629
--- /dev/null
+++ b/iotdb-core/ainode/iotdb/ainode/core/model/toto/inference/forecaster.py
@@ -0,0 +1,452 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+# This file includes code derived from DataDog/toto
+# (https://github.com/DataDog/toto), licensed under the Apache-2.0 License.
+# Copyright 2025 Datadog, Inc.
+
+from dataclasses import dataclass
+from typing import cast
+
+import numpy as np
+import torch
+from einops import rearrange, repeat
+from jaxtyping import Bool, Float, Int
+from torch.distributions import Distribution, TransformedDistribution
+from torch.distributions.transforms import AffineTransform
+
+from ..data.util.dataset import (
+    MaskedTimeseries,
+    pad_array,
+    pad_id_mask,
+    replace_extreme_values,
+)
+from ..model.backbone import TotoBackbone
+
+
+class AffineTransformed(TransformedDistribution):
+    """
+    Distribution of ``loc + scale * X`` for ``X ~ base_distribution``.
+
+    This is a minimal stand-in for
+    ``gluonts.torch.distributions.AffineTransformed``, built from
+    ``TransformedDistribution`` and a single ``AffineTransform``.
+    """
+
+    def __init__(self, base_distribution, loc=0.0, scale=1.0):
+        affine = AffineTransform(loc=loc, scale=scale)
+        super().__init__(base_distribution, affine)
+
+    @property
+    def mean(self):
+        # The mean of an affine map of X is that affine map of E[X].
+        affine = self.transforms[0]
+        return affine.loc + affine.scale * self.base_dist.mean
+
+    # sample() is intentionally inherited, not overridden:
+    # TransformedDistribution.sample() draws through base_dist.sample()
+    # rather than rsample(), so it also works for non-reparameterizable
+    # bases such as MixtureSameFamily.
+
+
+@dataclass(frozen=True)
+class Forecast:
+    """
+    Result of ``TotoForecaster.forecast``.
+
+    ``mean`` is always set. ``samples`` holds the sampled paths along a
+    trailing ``samples`` axis, or is ``None`` when the forecast was decoded
+    from the distribution mean without sampling. ``quantile``, ``median``
+    and ``std`` require samples.
+    """
+
+    mean: Float[torch.Tensor, "batch variate future_time_steps"]
+    samples: (
+        Float[torch.Tensor, "batch variate future_time_steps samples"] | None
+    ) = None
+
+    def quantile(
+        self, q: float | torch.Tensor
+    ) -> Float[torch.Tensor, "batch variate future_time_steps"]:
+        """
+        Return the ``q``-quantile of ``samples`` over the sample axis.
+
+        ``q`` may be a Python number (integer levels such as ``0`` or ``1``
+        are accepted) or a tensor of levels, as for
+        ``torch.Tensor.quantile``.
+        """
+        assert (
+            self.samples is not None
+        ), "samples must be provided to compute quantiles"
+        assert isinstance(
+            q, (int, float, torch.Tensor)
+        ), "q must be a float or a tensor"
+        if not isinstance(q, torch.Tensor):
+            # Match the samples' dtype/device, which quantile() requires.
+            q = torch.tensor(
+                float(q), device=self.samples.device, dtype=self.samples.dtype
+            )
+        return self.samples.quantile(q, dim=-1)
+
+    @property
+    def median(self) -> Float[torch.Tensor, "batch variate future_time_steps"]:
+        """The 0.5-quantile of the samples."""
+        return self.quantile(0.5)
+
+    @property
+    def std(self) -> Float[torch.Tensor, "batch variate future_time_steps"]:
+        """Standard deviation of the samples over the sample axis."""
+        assert (
+            self.samples is not None
+        ), "samples must be provided to compute standard deviation"
+        return self.samples.std(dim=-1)
+
+
+class TotoForecaster:
+    """
+    A forecaster class for the Toto model that handles autoregressive decoding
+    for time series forecasting.
+    """
+
+    model: TotoBackbone
+
+    def __init__(self, model: TotoBackbone):
+        """Wrap ``model`` for inference, switching it to evaluation mode."""
+        model.eval()
+        self.model = model
+
+    def forecast(
+        self,
+        inputs: MaskedTimeseries,
+        prediction_length: int,
+        num_samples: int | None = None,
+        samples_per_batch: int = 10,
+        use_kv_cache: bool = True,
+        future_exogenous_variables: (
+            Float[torch.Tensor, "batch exogenous_variables future_time_steps"]
+            | None
+        ) = None,
+    ) -> Forecast:
+        """
+        Forecast ``prediction_length`` steps past the end of ``inputs``.
+
+        Args:
+            inputs: observed context. If ``series`` is 2-D (unbatched), it
+                is collated into a batch of one.
+            prediction_length: number of future steps to return.
+            num_samples: when given, draw this many sample paths with
+                ``generate_samples`` and report their mean. When ``None``,
+                decode with ``generate_mean`` and return no samples.
+            samples_per_batch: sample paths drawn per decoding pass
+                (``generate_samples`` requires it to divide
+                ``num_samples``).
+            use_kv_cache: forwarded to the decoding routine.
+            future_exogenous_variables: known future exogenous values. A
+                2-D tensor is given a leading batch dimension.
+
+        Returns:
+            A ``Forecast`` holding the mean and, if sampled, the samples.
+        """
+        if len(inputs.series.shape) == 2:
+            batch = cast(
+                MaskedTimeseries, torch.utils.data.default_collate([inputs])
+            )
+            # default_collate turns the int num_exogenous_variables field
+            # into a 1-element tensor; restore the plain int it is declared
+            # as.
+            batch = batch._replace(
+                num_exogenous_variables=inputs.num_exogenous_variables
+            )
+        else:
+            batch = inputs
+
+        if (
+            future_exogenous_variables is not None
+            and len(future_exogenous_variables.shape) == 2
+        ):
+            future_exogenous_variables = future_exogenous_variables.unsqueeze(0)
+
+        # Left-pad everything to a whole number of patches.
+        stride = self.model.patch_embed.stride
+        series = pad_array(batch.series, stride)
+        padding_mask = pad_array(batch.padding_mask, stride)
+        id_mask = batch.id_mask
+        if id_mask is not None:
+            id_mask = pad_id_mask(id_mask, stride)
+        timestamp_seconds = pad_array(batch.timestamp_seconds, stride)
+        time_interval_seconds: Int[torch.Tensor, "batch variate"] = (
+            torch.as_tensor(
+                batch.time_interval_seconds,
+                device=series.device,
+                dtype=torch.int,
+            )
+        )
+
+        if num_samples is not None:
+            samples = self.generate_samples(
+                inputs=series,
+                prediction_length=prediction_length,
+                num_samples=num_samples,
+                timestamp_seconds=timestamp_seconds,
+                time_interval_seconds=time_interval_seconds,
+                input_padding_mask=padding_mask,
+                id_mask=id_mask,
+                sampling_batch_size=samples_per_batch,
+                use_kv_cache=use_kv_cache,
+                future_exogenous_variables=future_exogenous_variables,
+                num_exogenous_variables=batch.num_exogenous_variables,
+            )
+            mean = samples.mean(dim=-1)
+        else:
+            mean = self.generate_mean(
+                inputs=series,
+                prediction_length=prediction_length,
+                timestamp_seconds=timestamp_seconds,
+                time_interval_seconds=time_interval_seconds,
+                input_padding_mask=padding_mask,
+                id_mask=id_mask,
+                use_kv_cache=use_kv_cache,
+                future_exogenous_variables=future_exogenous_variables,
+                num_exogenous_variables=batch.num_exogenous_variables,
+            )
+            samples = None
+
+        return Forecast(mean=mean, samples=samples)
+
+    def assert_ev_compatibility(
+        self,
+        inputs,
+        future_exogenous_variables,
+        prediction_length,
+        num_exogenous_variables,
+    ) -> None:
+        """
+        Check that future exogenous values line up with the request.
+
+        Asserts, in this order, that ``future_exogenous_variables`` spans
+        exactly ``prediction_length`` steps (last dim), has the same batch
+        size as ``inputs`` (first dim), and has ``num_exogenous_variables``
+        rows (second-to-last dim).
+        """
+        ev_shape = future_exogenous_variables.shape
+        assert ev_shape[-1] == prediction_length
+        assert ev_shape[0] == inputs.shape[0]
+        assert num_exogenous_variables == ev_shape[-2]
+
+    def round_ft_ev(self, future_exogenous_variables, T_rounded):
+        """
+        Right-pad future exogenous values with zeros along the time axis so
+        they span ``T_rounded`` steps. Dtype and device are preserved.
+        """
+        batch_size, num_ev, num_future = future_exogenous_variables.shape
+        zeros = future_exogenous_variables.new_zeros(
+            (batch_size, num_ev, T_rounded - num_future)
+        )
+        return torch.cat((future_exogenous_variables, zeros), dim=-1)
+
+    @torch.no_grad()
+    def generate_mean(
+        self,
+        inputs: Float[torch.Tensor, "batch variate time_steps"],
+        prediction_length: int,
+        timestamp_seconds: Int[torch.Tensor, "batch variate time_steps"],
+        time_interval_seconds: Int[torch.Tensor, "batch variate"],
+        input_padding_mask: (
+            Bool[torch.Tensor, "batch variate time_steps"] | None
+        ) = None,
+        id_mask: Float[torch.Tensor, "batch #variate time_steps"] | None = 
None,
+        use_kv_cache: bool = False,
+        future_exogenous_variables=None,
+        num_exogenous_variables: int = 0,
+    ) -> Float[torch.Tensor, "batch variate time_steps"]:
+        """
+        Autoregressively decode the predictive mean, one patch per step.
+
+        Each step runs ``self.model`` on the current context, takes the
+        mean of the affine-transformed output distribution at the last
+        ``patch_size`` positions (extreme values replaced by 0), and appends
+        it to the context. When ``future_exogenous_variables`` is given, its
+        values overwrite the last ``num_exogenous_variables`` variates of
+        each generated patch. Decoding covers ``prediction_length`` rounded
+        up to whole patches, but only the first ``prediction_length``
+        generated steps are returned.
+
+        Args:
+            inputs: context values; ``forecast`` left-pads them to a
+                multiple of the patch stride before calling this.
+            prediction_length: number of future steps to return.
+            timestamp_seconds: timestamps of ``inputs``.
+            time_interval_seconds: added to the last timestamp to produce
+                each following one.
+            input_padding_mask: when ``None``, an all-True mask is used.
+            id_mask: when ``None``, all zeros are used; generated steps
+                reuse the last id of each variate.
+            use_kv_cache: allocate a KV cache sized for the context plus
+                all decoded patches and pass it to every model call.
+            future_exogenous_variables: known future exogenous values,
+                checked by ``assert_ev_compatibility`` and zero-padded to
+                the rounded horizon.
+            num_exogenous_variables: number of trailing variates treated as
+                exogenous.
+
+        Returns:
+            Steps ``start_index`` to ``start_index + prediction_length`` of
+            the extended context, where ``start_index`` is the original
+            context length.
+        """
+        if input_padding_mask is None:
+            input_padding_mask = torch.ones_like(
+                inputs, dtype=torch.bool, device=inputs.device
+            )
+        if id_mask is None:
+            id_mask = torch.zeros_like(inputs, dtype=torch.int, 
device=inputs.device)
+
+        if future_exogenous_variables is not None:
+            self.assert_ev_compatibility(
+                inputs,
+                future_exogenous_variables,
+                prediction_length,
+                num_exogenous_variables,
+            )
+
+        # NOTE(review): generate_samples reads patch_embed.patch_size here
+        # instead of stride; the two agree when patch_size == stride
+        # (TotoConfig defaults both to 32).
+        patch_size = self.model.patch_embed.stride
+        rounded_steps = int(np.ceil(prediction_length / patch_size) * 
patch_size)
+        if rounded_steps > prediction_length and future_exogenous_variables is 
not None:
+            future_exogenous_variables = self.round_ft_ev(
+                future_exogenous_variables, rounded_steps
+            )
+        start_index = inputs.shape[-1]
+        end_index = start_index + prediction_length
+
+        # Masks appended alongside every generated patch: an all-True
+        # padding mask, and ids copied from the last context step.
+        dummy_padding = torch.ones(
+            (input_padding_mask.shape[0], input_padding_mask.shape[1], 
patch_size),
+            device=inputs.device,
+            dtype=torch.bool,
+        )
+        dummy_id_mask = repeat(
+            id_mask[:, :, -1:],
+            "batch variates 1 -> batch variates patch_size",
+            patch_size=patch_size,
+        )
+        if use_kv_cache:
+            kv_cache = self.model.allocate_kv_cache(
+                batch_size=inputs.shape[0],
+                num_variates=inputs.shape[1],
+                max_time_steps=inputs.shape[2] + rounded_steps,
+                device=inputs.device,
+                dtype=inputs.dtype,
+            )
+        else:
+            kv_cache = None
+
+        # Fixed at the original context length for every step; presumably
+        # limits the scaler's statistics to observed data — confirm in
+        # TotoBackbone.
+        scaling_prefix_length = inputs.shape[-1]
+
+        for idx in range(rounded_steps // patch_size):
+            base_distr, loc, scale = self.model(
+                inputs=inputs,
+                input_padding_mask=input_padding_mask,
+                id_mask=id_mask,
+                kv_cache=kv_cache,
+                scaling_prefix_length=scaling_prefix_length,
+                num_exogenous_variables=num_exogenous_variables,
+            )
+            distr = self.create_affine_transformed(base_distr, loc, scale)
+
+            samples = replace_extreme_values(distr.mean[:, :, -patch_size:])
+
+            # Known future exogenous values replace the model's prediction
+            # for the trailing exogenous variates.
+            if future_exogenous_variables is not None:
+                start, stop = idx * patch_size, (idx + 1) * patch_size
+                samples[:, -num_exogenous_variables:] = 
future_exogenous_variables[
+                    :, :, start:stop
+                ]
+
+            inputs = torch.cat([inputs, samples], dim=-1)
+            id_mask = torch.cat([id_mask, dummy_id_mask], dim=-1)
+            input_padding_mask = torch.cat([input_padding_mask, 
dummy_padding], dim=-1)
+            # NOTE(review): timestamp_seconds is extended here but is never
+            # passed to self.model nor returned, so this loop does not
+            # affect the result.
+            for _ in range(patch_size):
+                next_timestamp = timestamp_seconds[:, :, -1] + 
time_interval_seconds
+                timestamp_seconds = torch.cat(
+                    [timestamp_seconds, next_timestamp.unsqueeze(-1)], dim=-1
+                )
+
+        return inputs.detach()[:, :, start_index:end_index]
+
+    @torch.no_grad()
+    def generate_samples(
+        self,
+        inputs: Float[torch.Tensor, "batch variate time_steps"],
+        prediction_length: int,
+        num_samples: int,
+        timestamp_seconds: Int[torch.Tensor, "batch variate time_steps"],
+        time_interval_seconds: Int[torch.Tensor, "batch variate"],
+        input_padding_mask: (
+            Bool[torch.Tensor, "batch variate time_steps"] | None
+        ) = None,
+        id_mask: Float[torch.Tensor, "batch #variate time_steps"] | None = 
None,
+        sampling_batch_size: int = 10,
+        use_kv_cache: bool = False,
+        future_exogenous_variables=None,
+        num_exogenous_variables: int = 0,
+    ) -> Float[torch.Tensor, "batch variate time_steps samples"]:
+        if input_padding_mask is None:
+            input_padding_mask = torch.ones_like(
+                inputs, dtype=torch.bool, device=inputs.device
+            )
+        if id_mask is None:
+            id_mask = torch.zeros_like(inputs, dtype=torch.int, 
device=inputs.device)
+
+        if future_exogenous_variables is not None:
+            self.assert_ev_compatibility(
+                inputs,
+                future_exogenous_variables,
+                prediction_length,
+                num_exogenous_variables,
+            )
+
+        assert (
+            num_samples % sampling_batch_size == 0
+        ), "num_samples must be divisible by sampling_batch_size"
+        num_batches = num_samples // sampling_batch_size
+
+        patch_size = self.model.patch_embed.patch_size
+        rounded_steps = int(np.ceil(prediction_length / patch_size) * 
patch_size)
+        if rounded_steps > prediction_length and future_exogenous_variables is 
not None:
+            future_exogenous_variables = self.round_ft_ev(
+                future_exogenous_variables, rounded_steps
+            )
+        start_index = inputs.shape[-1]
+        end_index = start_index + prediction_length
+
+        dummy_padding = torch.ones(
+            (
+                input_padding_mask.shape[0] * sampling_batch_size,
+                input_padding_mask.shape[1],
+                patch_size,
+            ),
+            dtype=torch.bool,
+            device=inputs.device,
+        )
+        dummy_id_mask = repeat(
+            id_mask[:, :, -1:],
+            "batch variates 1 -> (sampling_batch_size batch) variates 
patch_size",
+            sampling_batch_size=sampling_batch_size,
+            patch_size=patch_size,
+        )
+        inputs = repeat(
+            inputs,
+            "batch variates seq_len -> (sampling_batch_size batch) variates 
seq_len",
+            sampling_batch_size=sampling_batch_size,
+        )
+        if future_exogenous_variables is not None:
+            future_exogenous_variables = repeat(
+                future_exogenous_variables,
+                "batch exogenous_variables future_time_steps -> 
(sampling_batch_size batch) exogenous_variables future_time_steps",
+                sampling_batch_size=sampling_batch_size,
+            )
+        input_padding_mask = repeat(
+            input_padding_mask,
+            "batch variates seq_len -> (sampling_batch_size batch) variates 
seq_len",
+            sampling_batch_size=sampling_batch_size,
+        )
+        id_mask = repeat(
+            id_mask,
+            "batch variates seq_len -> (sampling_batch_size batch) variates 
seq_len",
+            sampling_batch_size=sampling_batch_size,
+        )
+        timestamp_seconds = repeat(
+            timestamp_seconds,
+            "batch variates seq_len -> (sampling_batch_size batch) variates 
seq_len",
+            sampling_batch_size=sampling_batch_size,
+        )
+        time_interval_seconds = repeat(
+            time_interval_seconds,
+            "batch variates -> (sampling_batch_size batch) variates",
+            sampling_batch_size=sampling_batch_size,
+        )
+
+        all_samples = []
+        if use_kv_cache:
+            kv_cache = self.model.allocate_kv_cache(
+                batch_size=inputs.shape[0],
+                num_variates=inputs.shape[1],
+                max_time_steps=inputs.shape[2] + rounded_steps,
+                device=inputs.device,
+                dtype=inputs.dtype,
+            )
+        else:
+            kv_cache = None
+
+        scaling_prefix_length = inputs.shape[-1]
+
+        for _ in range(num_batches):
+            batch_inputs = torch.clone(inputs)
+            batch_input_padding_mask = torch.clone(input_padding_mask)
+            batch_id_mask = torch.clone(id_mask)
+            batch_timestamp_seconds = torch.clone(timestamp_seconds)
+
+            for idx in range(rounded_steps // patch_size):
+                base_distr, loc, scale = self.model(
+                    inputs=batch_inputs,
+                    input_padding_mask=batch_input_padding_mask,
+                    id_mask=batch_id_mask,
+                    kv_cache=kv_cache,
+                    scaling_prefix_length=scaling_prefix_length,
+                    num_exogenous_variables=num_exogenous_variables,
+                )
+                distr = self.create_affine_transformed(base_distr, loc, scale)
+
+                sample = distr.sample()
+                assert sample is not None
+
+                samples = replace_extreme_values(sample[:, :, -patch_size:])
+
+                if future_exogenous_variables is not None:
+                    start, stop = idx * patch_size, (idx + 1) * patch_size
+                    samples[:, -num_exogenous_variables:] = 
future_exogenous_variables[
+                        :, :, start:stop
+                    ]
+                batch_inputs = torch.cat([batch_inputs, samples], dim=-1)
+                batch_id_mask = torch.cat([batch_id_mask, dummy_id_mask], 
dim=-1)
+                batch_input_padding_mask = torch.cat(
+                    [batch_input_padding_mask, dummy_padding], dim=-1
+                )
+                for _ in range(patch_size):
+                    next_timestamp = (
+                        batch_timestamp_seconds[:, :, -1] + 
time_interval_seconds
+                    )
+                    batch_timestamp_seconds = torch.cat(
+                        [batch_timestamp_seconds, 
next_timestamp.unsqueeze(-1)], dim=-1
+                    )
+            all_samples.append(batch_inputs)
+            if kv_cache is not None:
+                kv_cache.reset()
+
+        outputs = torch.cat(all_samples, dim=0)
+        unfolded_outputs = rearrange(
+            outputs,
+            "(samples batch) variates seq_len -> batch variates seq_len 
samples",
+            samples=num_samples,
+        ).detach()
+
+        return unfolded_outputs[:, :, start_index:end_index, :]
+
+    @staticmethod
+    def create_affine_transformed(
+        base_distr: Distribution, loc: torch.Tensor, scale: torch.Tensor
+    ) -> Distribution:
+        base_shape = base_distr.mean.shape
+        base_time_dim = base_shape[-1]
+        loc_time_dim = loc.shape[-1]
+
+        if loc_time_dim == 1:
+            return AffineTransformed(base_distr, loc=loc, scale=scale)
+
+        return AffineTransformed(
+            base_distr,
+            loc=loc[:, :, -base_time_dim:],
+            scale=scale[:, :, -base_time_dim:],
+        )
diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/__init__.py 
b/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/__init__.py
new file mode 100644
index 00000000000..ba26b1edd94
--- /dev/null
+++ b/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/__init__.py
@@ -0,0 +1,20 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+# This file includes code derived from DataDog/toto
+# (https://github.com/DataDog/toto), licensed under the Apache-2.0 License.
+# Copyright 2025 Datadog, Inc.
diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/attention.py 
b/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/attention.py
new file mode 100644
index 00000000000..80f6d381ff2
--- /dev/null
+++ b/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/attention.py
@@ -0,0 +1,276 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+# This file includes code derived from DataDog/toto
+# (https://github.com/DataDog/toto), licensed under the Apache-2.0 License.
+# Copyright 2025 Datadog, Inc.
+
+import logging
+import warnings
+from enum import Enum
+from typing import TYPE_CHECKING, Optional, Union
+
+import torch
+from einops import rearrange
+from jaxtyping import Bool, Float, Int
+
+from .rope import TimeAwareRotaryEmbedding
+
+if TYPE_CHECKING:
+    from .util import KVCache
+
log = logging.getLogger(__name__)

# Optional dependency probe. xFormers supplies memory_efficient_attention and
# the LowerTriangularMask bias used by the memory-efficient attention paths.
# XFORMERS_AVAILABLE records whether the import succeeded.
# BaseMultiheadAttention rejects use_memory_efficient_attention=True when it
# is False.
try:
    from xformers.ops import LowerTriangularMask, memory_efficient_attention

    XFORMERS_AVAILABLE = True
    log.info("xFormers Memory-Efficient Attention available.")
except ImportError:
    # NOTE: ImportWarning is in Python's default "ignore" warning filters, so
    # this message is only shown when warnings are explicitly enabled
    # (e.g. `python -W default`).
    warnings.warn(
        "xFormers Memory-Efficient Attention not available. "
        "Falling back to native PyTorch scaled_dot_product_attention.",
        ImportWarning,
    )

    XFORMERS_AVAILABLE = False
+
+from torch.nn.functional import scaled_dot_product_attention
+
+
class AttentionAxis(Enum):
    """Which axis of a ``(batch, variate, seq_len, embed_dim)`` tensor is attended.

    ``TIME`` attends along ``seq_len`` independently for every variate.
    ``SPACE`` attends across variates independently at every position.
    """

    # Member values are kept explicit so they stay stable.
    TIME = 1
    SPACE = 2
+
+
class BaseMultiheadAttention(torch.nn.Module):
    """
    Multi-head self-attention along either the time axis or the variate axis.

    Concrete subclasses choose the axis by setting the class attribute
    ``attention_axis`` to ``AttentionAxis.TIME`` or ``AttentionAxis.SPACE``.
    Inputs are ``(batch, variate, seq_len, embed_dim)``. The axis that is not
    attended is folded into the batch dimension before attention and unfolded
    afterwards.

    With ``use_memory_efficient_attention`` the xFormers kernel is used and
    heads are laid out as ``(B, L, heads, head_dim)``. Otherwise PyTorch
    ``scaled_dot_product_attention`` is used with ``(B, heads, L, head_dim)``.
    This is why every reshaping helper branches on that flag.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        dropout: float,
        rotary_emb: Optional[TimeAwareRotaryEmbedding],
        use_memory_efficient_attention: bool,
    ):
        """
        Args:
            embed_dim: model width. Must be divisible by ``num_heads``.
            num_heads: number of attention heads.
            dropout: attention dropout probability, applied only in training.
            rotary_emb: optional rotary position embedding. Only applied for
                time-wise attention.
            use_memory_efficient_attention: use the xFormers kernel. Requires
                xFormers to be importable.

        Raises:
            AssertionError: if ``embed_dim`` is not divisible by ``num_heads``,
                or if the xFormers kernel is requested but unavailable.
            ValueError: if the subclass does not define a valid
                ``attention_axis``.
        """
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        assert (
            embed_dim % num_heads == 0
        ), "Embedding dimension must be divisible by number of heads."
        self.head_dim = embed_dim // num_heads
        self.rotary_emb = rotary_emb

        # One fused projection yields queries, keys and values.
        self.wQKV = torch.nn.Linear(embed_dim, embed_dim * 3)
        self.dropout = dropout
        self.use_memory_efficient_attention = use_memory_efficient_attention
        self.wO = torch.nn.Linear(embed_dim, embed_dim)

        assert not (
            not XFORMERS_AVAILABLE and self.use_memory_efficient_attention
        ), "XFORMERS_AVAILABLE is False, so use_memory_efficient_attention must be False"

        # attention_axis is a class attribute provided by the concrete subclass.
        if not hasattr(self, "attention_axis") or self.attention_axis not in (
            AttentionAxis.TIME,
            AttentionAxis.SPACE,
        ):
            raise ValueError(
                "Child class must define attention_axis as AttentionAxis.TIME or AttentionAxis.SPACE."
            )

    def rearrange_inputs(
        self, inputs: Float[torch.Tensor, "batch variate seq_len embed_dim"]
    ) -> Float[torch.Tensor, "... embed_dim"]:
        """Fold the non-attended axis into the batch dimension."""
        pattern = (
            "batch variate seq_len embed_dim -> (batch variate) seq_len embed_dim"
            if self.attention_axis == AttentionAxis.TIME
            else "batch variate seq_len embed_dim -> (batch seq_len) variate embed_dim"
        )
        return rearrange(inputs, pattern)

    def get_qkv(self, inputs: torch.Tensor) -> tuple[torch.Tensor, ...]:
        """Project folded inputs to per-head ``(q, k, v)`` in the kernel's layout."""
        if (
            self.attention_axis == AttentionAxis.TIME
            and self.use_memory_efficient_attention
        ):
            pattern = "batch_X_variate seq_len (qkv head_dim n_heads) -> qkv batch_X_variate seq_len n_heads head_dim"
        elif (
            self.attention_axis == AttentionAxis.TIME
            and not self.use_memory_efficient_attention
        ):
            pattern = "batch_X_variate seq_len (qkv head_dim n_heads) -> qkv batch_X_variate n_heads seq_len head_dim"
        elif (
            self.attention_axis == AttentionAxis.SPACE
            and self.use_memory_efficient_attention
        ):
            pattern = "batch_X_seq_len variate (qkv head_dim n_heads) -> qkv batch_X_seq_len variate n_heads head_dim"
        elif (
            self.attention_axis == AttentionAxis.SPACE
            and not self.use_memory_efficient_attention
        ):
            pattern = "batch_X_seq_len variate (qkv head_dim n_heads) -> qkv batch_X_seq_len n_heads variate head_dim"

        qkv = self.wQKV(inputs.contiguous())
        return rearrange(
            qkv, pattern, qkv=3, head_dim=self.head_dim, n_heads=self.num_heads
        ).unbind(dim=0)

    def positional_embedding(self, q, k, v, kv_cache, layer_idx):
        """Apply rotary embeddings and KV caching for time-wise attention.

        Returns ``(q, k, v, seq_pos_offset)``. ``seq_pos_offset`` is the cache
        length for ``layer_idx`` read before this call appends, or 0 when no
        time-wise cache is used. When caching, ``k``/``v`` are the full cached
        keys/values for the layer. ``k``/``v`` are cast to ``q``'s dtype.
        """
        uses_time_cache = (
            kv_cache is not None and self.attention_axis == AttentionAxis.TIME
        )
        # run_attention needs this offset to choose the causal mask and to
        # slice attention_mask, so read it whenever a time-wise cache is in
        # use, not only when a rotary embedding is configured. It must be
        # read before the append below.
        seq_pos_offset = kv_cache.seq_len(layer_idx) if uses_time_cache else 0

        if self.rotary_emb is not None and self.attention_axis == AttentionAxis.TIME:
            q, k = self.rotary_emb.rotate_queries_and_keys(
                q, k, seq_pos_offset=seq_pos_offset
            )

        if uses_time_cache:
            kv_cache.append(layer_idx, (k, v))
            k, v = kv_cache[layer_idx]

        q = q.contiguous()
        k = k.contiguous().to(q.dtype)
        v = v.contiguous().to(q.dtype)

        return q, k, v, seq_pos_offset

    def rearrange_output(
        self, output: torch.Tensor, batch: int, variate: int, seq_len: int
    ) -> Float[torch.Tensor, "batch variate seq_len embed_dim"]:
        """Unfold the batch dimension and merge heads back into ``embed_dim``."""
        if (
            self.attention_axis == AttentionAxis.TIME
            and self.use_memory_efficient_attention
        ):
            pattern = "(batch variate) seq_len n_heads head_dim -> batch variate seq_len (n_heads head_dim)"
        elif (
            self.attention_axis == AttentionAxis.TIME
            and not self.use_memory_efficient_attention
        ):
            pattern = "(batch variate) n_heads seq_len head_dim -> batch variate seq_len (n_heads head_dim)"
        elif (
            self.attention_axis == AttentionAxis.SPACE
            and self.use_memory_efficient_attention
        ):
            pattern = "(batch seq_len) variate n_heads head_dim -> batch variate seq_len (n_heads head_dim)"
        elif (
            self.attention_axis == AttentionAxis.SPACE
            and not self.use_memory_efficient_attention
        ):
            pattern = "(batch seq_len) n_heads variate head_dim -> batch variate seq_len (n_heads head_dim)"

        return rearrange(output, pattern, batch=batch, variate=variate, seq_len=seq_len)

    def run_attention(
        self, attention_mask, q, k, v, seq_pos_offset, dropout, seq_len, variate
    ):
        """Dispatch to the attention kernel for this axis and layout.

        Time-wise attention without an explicit mask is causal only when
        ``seq_pos_offset == 0``. On later cached decode steps, the new queries
        attend to every cached key. Both built-in causal masks are top-left
        aligned for non-square shapes, so they must not be used once the
        cache is non-empty. An explicit tensor mask is sliced to the current
        query rows and all key columns. Space-wise attention is never causal.
        """
        q_dim_start, q_dim_end = seq_pos_offset, seq_pos_offset + seq_len
        # Key/value length lives on axis 1 (xFormers) or axis 2 (SDPA).
        kv_dim_start, kv_dim_end = 0, (
            v.shape[1] if self.use_memory_efficient_attention else v.shape[2]
        )
        if (
            self.attention_axis == AttentionAxis.TIME
            and self.use_memory_efficient_attention
        ):
            attention_mask = (
                attention_mask[..., q_dim_start:q_dim_end, kv_dim_start:kv_dim_end]
                if torch.is_tensor(attention_mask)
                else LowerTriangularMask() if seq_pos_offset == 0 else None
            )
            return memory_efficient_attention(
                q, k, v, attn_bias=attention_mask, p=dropout
            )
        elif (
            self.attention_axis == AttentionAxis.TIME
            and not self.use_memory_efficient_attention
        ):
            attention_mask = (
                attention_mask[..., q_dim_start:q_dim_end, kv_dim_start:kv_dim_end]
                if torch.is_tensor(attention_mask)
                else None
            )
            return scaled_dot_product_attention(
                q,
                k,
                v,
                attn_mask=attention_mask,
                dropout_p=dropout,
                is_causal=(attention_mask is None and seq_pos_offset == 0),
            )
        elif (
            self.attention_axis == AttentionAxis.SPACE
            and self.use_memory_efficient_attention
        ):
            attention_mask = (
                attention_mask[..., kv_dim_start:kv_dim_end, kv_dim_start:kv_dim_end]
                if torch.is_tensor(attention_mask)
                else None
            )
            return memory_efficient_attention(
                q, k, v, attn_bias=attention_mask, p=dropout
            )
        elif (
            self.attention_axis == AttentionAxis.SPACE
            and not self.use_memory_efficient_attention
        ):
            attention_mask = (
                attention_mask[..., kv_dim_start:kv_dim_end, kv_dim_start:kv_dim_end]
                if torch.is_tensor(attention_mask)
                else None
            )
            return scaled_dot_product_attention(
                q, k, v, attn_mask=attention_mask, dropout_p=dropout, is_causal=False
            )

    def forward(
        self,
        layer_idx: int,
        inputs: Float[torch.Tensor, "batch variate seq_len embed_dim"],
        attention_mask: Optional[
            Union[
                Bool[torch.Tensor, "batch_X_variate n_heads seq_len seq_len"],
                Bool[torch.Tensor, "batch_X_seq_len n_heads variate variate"],
            ]
        ] = None,
        kv_cache: Optional["KVCache"] = None,
    ) -> Float[torch.Tensor, "batch variate seq_len embed_dim"]:
        """Attend along ``attention_axis`` and return a tensor shaped like ``inputs``.

        Args:
            layer_idx: index used to address this layer's slot in ``kv_cache``.
            inputs: ``(batch, variate, seq_len, embed_dim)`` activations.
            attention_mask: optional mask, sliced in ``run_attention``.
            kv_cache: optional cache. It is only used for time-wise attention.
        """
        batch_size, variate, seq_len, _ = inputs.shape
        dropout = self.dropout if self.training else 0.0

        rearranged_inputs = self.rearrange_inputs(inputs)
        q, k, v = self.get_qkv(rearranged_inputs)

        q, k, v, seq_pos_offset = self.positional_embedding(
            q, k, v, kv_cache, layer_idx
        )

        output = self.run_attention(
            attention_mask, q, k, v, seq_pos_offset, dropout, seq_len, variate
        )

        output = self.rearrange_output(output, batch_size, variate, seq_len)
        return self.wO(output)
+
+
class TimeWiseMultiheadAttention(BaseMultiheadAttention):
    """Attention along the sequence (time) axis, run separately for each variate."""

    attention_axis = AttentionAxis.TIME
+
+
class SpaceWiseMultiheadAttention(BaseMultiheadAttention):
    """Attention across variates, run separately at each sequence position."""

    attention_axis = AttentionAxis.SPACE
+
+
# Union of the two concrete attention layer types. The PEP 604 `X | Y` form is
# evaluated at import time, so it requires Python 3.10+.
MultiHeadAttention = TimeWiseMultiheadAttention | SpaceWiseMultiheadAttention
diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/backbone.py 
b/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/backbone.py
new file mode 100644
index 00000000000..84fa537e3fc
--- /dev/null
+++ b/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/backbone.py
@@ -0,0 +1,258 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+# This file includes code derived from DataDog/toto
+# (https://github.com/DataDog/toto), licensed under the Apache-2.0 License.
+# Copyright 2025 Datadog, Inc.
+
+from math import ceil
+from typing import NamedTuple, Optional, Type, cast
+
+import torch
+from einops import rearrange, repeat
+from jaxtyping import Bool, Float, Int
+
+from .distribution import DISTRIBUTION_CLASSES_LOOKUP, DistributionOutput
+from .embedding import PatchEmbedding
+from .fusion import Fusion
+from .scaler import scaler_types
+from .transformer import Transformer
+from .util import KVCache
+
+
class TotoOutput(NamedTuple):
    """
    Output of the Toto model.

    Attributes:
        distribution: output of the distribution head, before the ``loc`` /
            ``scale`` affine transform has been applied.
        loc: location parameters produced by the input scaler.
        scale: scale parameters produced by the input scaler.
    """

    distribution: torch.distributions.Distribution
    # Rank 3, matching TotoBackbone.backbone's declared return; the time axis
    # may be size 1 when the scaler produces a single value per series.
    loc: Float[torch.Tensor, "batch variate time_steps"]
    scale: Float[torch.Tensor, "batch variate time_steps"]
+
+
class TotoBackbone(torch.nn.Module):
    """
    Toto (Timeseries-Optimized Transformer for Observability) is a transformer-based model
    for multivariate time series forecasting.

    ``forward`` runs these steps in order:

    1. Scale the inputs.
    2. Embed patches of each variate (``PatchEmbedding`` with ``patch_size`` /
       ``stride``).
    3. Run the transformer over the patch tokens.
    4. Project every token back to ``patch_size`` per-step embeddings.
    5. Feed those embeddings to the output distribution head.

    It returns that distribution together with the scaler's ``loc`` and
    ``scale``.
    """

    def __init__(
        self,
        patch_size: int,
        stride: int,
        embed_dim: int,
        num_layers: int,
        num_heads: int,
        mlp_hidden_dim: int,
        dropout: float,
        spacewise_every_n_layers: int,
        scaler_cls: str,
        output_distribution_classes: list[str],
        spacewise_first: bool = True,
        output_distribution_kwargs: dict | None = None,
        use_memory_efficient_attention: bool = True,
        stabilize_with_global: bool = True,
        scale_factor_exponent: float = 10.0,
    ):
        """
        Args:
            scaler_cls: key into ``scaler_types``.
            output_distribution_classes: keys into
                ``DISTRIBUTION_CLASSES_LOOKUP``. Only the first entry is used
                to build the head.
            output_distribution_kwargs: extra keyword arguments for the head.
            stabilize_with_global, scale_factor_exponent: only forwarded to
                the causal patch scaler.
        """
        super().__init__()
        self.embed_dim = embed_dim
        # Variate-label support stays off until enable_variate_labels() is called.
        self.fusion: Optional[Fusion] = None
        self.num_prepended_tokens: int = 0
        self.target_variate_label: Optional[torch.nn.Parameter] = None
        self.exogenous_variate_label: Optional[torch.nn.Parameter] = None

        # The causal patch scaler takes patch/stabilization settings; every
        # other scaler type is constructed without arguments.
        if scaler_cls in (
            "<class 'model.scaler.CausalPatchStdMeanScaler'>",
            "per_variate_causal_patch",
        ):
            self.scaler = scaler_types[scaler_cls](
                patch_size=patch_size,
                stabilize_with_global=stabilize_with_global,
                scale_factor_exponent=scale_factor_exponent,
            )
        else:
            self.scaler = scaler_types[scaler_cls]()

        self.patch_embed = PatchEmbedding(patch_size, stride, embed_dim)
        self.dropout = dropout
        self.num_layers = num_layers
        self.use_memory_efficient_attention = use_memory_efficient_attention
        self.transformer = Transformer(
            embed_dim=embed_dim,
            num_heads=num_heads,
            num_layers=self.num_layers,
            mlp_hidden_dim=mlp_hidden_dim,
            dropout=dropout,
            spacewise_every_n_layers=spacewise_every_n_layers,
            spacewise_first=spacewise_first,
            use_memory_efficient_attention=self.use_memory_efficient_attention,
            fusion=self.fusion,
        )
        # Each patch token is expanded back into patch_size per-step embeddings.
        self.unembed = torch.nn.Linear(embed_dim, embed_dim * patch_size)

        output_distribution_classes_ = [
            DISTRIBUTION_CLASSES_LOOKUP[c] for c in output_distribution_classes
        ]
        # Only the first listed distribution class is instantiated.
        self.output_distribution = output_distribution_classes_[0](
            embed_dim, **(output_distribution_kwargs or {})
        )

    def allocate_kv_cache(
        self,
        batch_size: int,
        num_variates: int,
        max_time_steps: int,
        device: torch.device,
        dtype: torch.dtype,
    ) -> KVCache:
        """Allocate a KV cache able to hold ``max_time_steps`` worth of patch tokens.

        The capacity is ``ceil(max_time_steps / stride)`` tokens.
        NOTE(review): ``num_prepended_tokens`` is not added to the capacity.
        Confirm the sizing is sufficient when variate labels are enabled.
        """
        return KVCache(
            batch_size=batch_size,
            num_variates=num_variates,
            transformer_layers=list(self.transformer.layers),
            num_layers=self.num_layers,
            embed_dim=self.embed_dim,
            num_heads=cast(int, self.transformer.layers[0].num_heads),
            max_seq_len=ceil(max_time_steps / self.patch_embed.stride),
            device=device,
            dtype=dtype,
            use_memory_efficient_attention=self.use_memory_efficient_attention,
        )

    def backbone(
        self,
        inputs: Float[torch.Tensor, "batch variate time_steps"],
        input_padding_mask: Bool[torch.Tensor, "batch variate time_steps"],
        id_mask: Float[torch.Tensor, "batch #variate time_steps"],
        kv_cache: Optional[KVCache] = None,
        scaling_prefix_length: Optional[int] = None,
        num_exogenous_variables: int = 0,
    ) -> tuple[
        Float[torch.Tensor, "batch variates time_steps embed_dim"],
        Float[torch.Tensor, "batch variates time_steps"],
        Float[torch.Tensor, "batch variates time_steps"],
    ]:
        """Run scaler -> patch embedding -> transformer -> unembedding.

        Args:
            inputs: raw series values. They are always scaled in full.
            input_padding_mask: passed to the scaler as ``padding_mask``.
            id_mask: per-step ids, patch-embedded alongside the inputs.
            kv_cache: when given, only time steps not yet represented in the
                cache are embedded. After the first call this must be exactly
                one stride.
            scaling_prefix_length: forwarded to the scaler as ``prefix_length``.
            num_exogenous_variables: number of trailing variates labelled as
                exogenous when variate labels are enabled.

        Returns:
            ``(flattened, loc, scale)``. ``flattened`` holds the per-step
            embeddings for the distribution head; ``loc`` and ``scale`` come
            from the scaler.
        """
        scaled_inputs, loc, scale = self.scaler(
            inputs,
            weights=torch.ones_like(inputs, device=inputs.device),
            padding_mask=input_padding_mask,
            prefix_length=scaling_prefix_length,
        )

        if kv_cache is not None:
            kv_cache_len_tensor = kv_cache.current_len(0)
            kv_cache_len = (
                int(kv_cache_len_tensor)
                if isinstance(kv_cache_len_tensor, torch.Tensor)
                else kv_cache_len_tensor
            )
            # Time steps already represented in the cache. Prepended label
            # tokens occupy cache slots but correspond to no time step.
            prefix_len = max(
                0, self.patch_embed.stride * (kv_cache_len - self.num_prepended_tokens)
            )

            # Only embed the steps the cache has not seen yet.
            scaled_inputs = scaled_inputs[:, :, prefix_len:]

            assert (prefix_len == 0) or (
                scaled_inputs.shape[-1] == self.patch_embed.stride
            ), "Must decode one step at a time."

            input_padding_mask = input_padding_mask[:, :, prefix_len:]
            id_mask = id_mask[:, :, prefix_len:]

        embeddings, reduced_id_mask = self.patch_embed(scaled_inputs, id_mask)

        variate_label_embeds = self.build_variate_label_embeds(
            num_exogenous_variables, embeddings
        )

        original_seq_len = embeddings.shape[2]
        transformed = self.transformer(
            embeddings,
            reduced_id_mask,
            kv_cache,
            variate_label_embeds=variate_label_embeds,
        )
        # Drop any tokens the transformer prepended (e.g. variate labels).
        added_tokens = transformed.shape[2] - original_seq_len
        if added_tokens > 0:
            transformed = transformed[:, :, added_tokens:]

        # Expand each token into patch_size consecutive per-step embeddings.
        flattened: Float[torch.Tensor, "batch variates new_seq_len embed_dim"] = (
            rearrange(
                self.unembed(transformed),
                "batch variates seq_len (patch_size embed_dim) -> batch variates (seq_len patch_size) embed_dim",
                embed_dim=self.embed_dim,
            )
        )
        return flattened, loc, scale

    def forward(
        self,
        inputs: Float[torch.Tensor, "batch variate time_steps"],
        input_padding_mask: Bool[torch.Tensor, "batch variate time_steps"],
        id_mask: Float[torch.Tensor, "batch #variate time_steps"],
        kv_cache: Optional[KVCache] = None,
        scaling_prefix_length: Optional[int] = None,
        num_exogenous_variables: int = 0,
    ) -> TotoOutput:
        """Run ``backbone`` and wrap the head's distribution with ``loc``/``scale``."""
        flattened, loc, scale = self.backbone(
            inputs,
            input_padding_mask,
            id_mask,
            kv_cache,
            scaling_prefix_length,
            num_exogenous_variables,
        )

        return TotoOutput(self.output_distribution(flattened), loc, scale)

    @property
    def device(self):
        """Device of the first registered parameter."""
        return next(self.parameters()).device

    def enable_variate_labels(self) -> None:
        """Turn on learned target/exogenous variate labels.

        This creates a ``Fusion`` module and two freshly initialised
        ``embed_dim`` label parameters. It sets ``num_prepended_tokens`` to 1
        and hands the fusion module to the existing transformer.
        """
        self.fusion = Fusion()
        self.num_prepended_tokens = 1
        self.target_variate_label = torch.nn.Parameter(torch.randn(self.embed_dim))
        self.exogenous_variate_label = torch.nn.Parameter(torch.randn(self.embed_dim))
        if hasattr(self, "transformer") and self.transformer is not None:
            self.transformer.fusion = self.fusion

    def build_variate_label_embeds(
        self,
        num_exogenous_variables: int,
        embeddings: Float[torch.Tensor, "batch variate seq_len embed_dim"],
    ) -> Optional[Float[torch.Tensor, "batch variate 1 embed_dim"]]:
        """Return one label embedding per (batch, variate), or ``None`` if labels are off.

        The last ``num_exogenous_variables`` variates get the exogenous label.
        All other variates get the target label. Labels are cast to the
        device and dtype of ``embeddings``.
        """
        if self.fusion is None:
            return None

        assert self.target_variate_label is not None
        assert self.exogenous_variate_label is not None

        batch_size, num_variates, _, _ = embeddings.shape

        target_variate_label = repeat(
            self.target_variate_label, "d -> b v 1 d", b=batch_size, v=num_variates
        ).to(device=embeddings.device, dtype=embeddings.dtype)
        exogenous_variate_label = repeat(
            self.exogenous_variate_label, "d -> b v 1 d", b=batch_size, v=num_variates
        ).to(device=embeddings.device, dtype=embeddings.dtype)
        # True on the trailing exogenous variates; broadcasts over batch/embed.
        exog_mask = torch.zeros(
            1, num_variates, 1, 1, dtype=torch.bool, device=embeddings.device
        )
        if num_exogenous_variables > 0:
            exog_mask[:, -num_exogenous_variables:] = True
        return torch.where(exog_mask, exogenous_variate_label, target_variate_label)
diff --git 
a/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/distribution.py 
b/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/distribution.py
new file mode 100644
index 00000000000..f34bd4afdf0
--- /dev/null
+++ b/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/distribution.py
@@ -0,0 +1,112 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+# This file includes code derived from DataDog/toto
+# (https://github.com/DataDog/toto), licensed under the Apache-2.0 License.
+# Copyright 2025 Datadog, Inc.
+
+from abc import ABC
+
+import torch
+import torch.nn.functional as F
+from torch.distributions import TransformedDistribution
+from torch.distributions.transforms import AffineTransform
+
+
class AffineTransformed(TransformedDistribution):
    """
    Distribution of ``loc + scale * X`` where ``X ~ base_distribution``.

    A thin wrapper around TransformedDistribution with a single AffineTransform,
    replacing the gluonts.torch.distributions.AffineTransformed dependency.
    ``mean`` and ``variance`` are given in closed form, because plain
    TransformedDistribution implements neither. ``stddev`` follows from
    ``variance`` through the base Distribution implementation.
    ``sample()`` and ``log_prob()`` are inherited.
    """

    def __init__(self, base_distribution, loc=0.0, scale=1.0):
        super().__init__(base_distribution, AffineTransform(loc=loc, scale=scale))

    @property
    def mean(self):
        # mean(aX + b) = a * mean(X) + b
        loc = self.transforms[0].loc
        scale = self.transforms[0].scale
        return loc + scale * self.base_dist.mean

    @property
    def variance(self):
        # var(aX + b) = a^2 * var(X); the shift b does not contribute.
        scale = self.transforms[0].scale
        return scale**2 * self.base_dist.variance

    # Note: Do NOT override sample() here. TransformedDistribution.sample() correctly
    # calls base_dist.sample() (not rsample), which works for non-reparameterizable
    # distributions like MixtureSameFamily.
+
+
class DistributionOutput(ABC, torch.nn.Module):
    """Common base for modules that map embeddings to a torch distribution."""
+
+
class StudentTOutput(DistributionOutput):
    """Head producing one Student-t distribution per input position.

    Three ``embed_dim -> 1`` projections read the last axis of the input:
    degrees of freedom (offset so they exceed 2), location, and a positive
    scale. If both ``loc`` and scale ``scale`` are supplied to ``forward``,
    the result is wrapped as ``loc + scale * X``.
    """

    def __init__(self, embed_dim):
        super().__init__()
        self.embed_dim = embed_dim
        # Attribute names double as state-dict keys; they must stay unchanged.
        self.df = torch.nn.Linear(embed_dim, 1)
        self.loc_proj = torch.nn.Linear(embed_dim, 1)
        self.scale_proj = torch.nn.Linear(embed_dim, 1)

    def forward(self, inputs, loc=None, scale=None):
        floor = torch.finfo(inputs.dtype).eps

        def positive(raw):
            # softplus keeps values > 0; the clamp guards against underflow.
            return F.softplus(raw).clamp_min(floor)

        student_t = torch.distributions.StudentT(
            (2.0 + positive(self.df(inputs))).squeeze(-1),
            self.loc_proj(inputs).squeeze(-1),
            positive(self.scale_proj(inputs)).squeeze(-1),
            validate_args=False,
        )

        if loc is None or scale is None:
            return student_t
        return AffineTransformed(student_t, loc=loc, scale=scale)
+
+
class MixtureOfStudentTsOutput(DistributionOutput):
    """Head producing a ``k_components``-way mixture of Student-t distributions.

    Each projection maps the last (``embed_dim``) axis to one value per
    component: degrees of freedom (offset so they exceed 2), location,
    positive scale, and mixture weights (softmax-normalised). ``loc`` and
    ``scale`` are accepted for signature parity with ``StudentTOutput`` but
    are ignored.
    """

    def __init__(self, embed_dim, k_components):
        super().__init__()
        self.embed_dim = embed_dim
        self.k_components = k_components

        # Attribute names double as state-dict keys; they must stay unchanged.
        self.df = torch.nn.Linear(embed_dim, k_components)
        self.loc_proj = torch.nn.Linear(embed_dim, k_components)
        self.scale_proj = torch.nn.Linear(embed_dim, k_components)
        self.mixture_weights = torch.nn.Linear(embed_dim, k_components)

    def forward(self, inputs, loc=None, scale=None):
        floor = torch.finfo(inputs.dtype).eps
        degrees = 2.0 + F.softplus(self.df(inputs)).clamp_min(floor)
        centers = self.loc_proj(inputs)
        spreads = F.softplus(self.scale_proj(inputs)).clamp_min(floor)
        weights = F.softmax(self.mixture_weights(inputs), dim=-1)

        per_component = torch.distributions.StudentT(
            degrees, centers, spreads, validate_args=False
        )
        selector = torch.distributions.Categorical(probs=weights)
        return torch.distributions.MixtureSameFamily(selector, per_component)
+
+
# Identifier -> distribution head class. TotoBackbone resolves its
# ``output_distribution_classes`` entries through this table. The
# "<class '...'>" keys look like str() of the upstream classes (presumably as
# written in upstream Toto configs; confirm). The short names are
# convenience aliases.
DISTRIBUTION_CLASSES_LOOKUP = dict(
    [
        ("<class 'model.distribution.StudentTOutput'>", StudentTOutput),
        (
            "<class 'model.distribution.MixtureOfStudentTsOutput'>",
            MixtureOfStudentTsOutput,
        ),
        ("student_t", StudentTOutput),
        ("student_t_mixture", MixtureOfStudentTsOutput),
    ]
)
diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/embedding.py 
b/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/embedding.py
new file mode 100644
index 00000000000..fc7eadac9af
--- /dev/null
+++ b/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/embedding.py
@@ -0,0 +1,83 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+# This file includes code derived from DataDog/toto
+# (https://github.com/DataDog/toto), licensed under the Apache-2.0 License.
+# Copyright 2025 Datadog, Inc.
+
+from typing import Optional
+
+import torch
+from jaxtyping import Float, Int, Num
+
+
def patchify_id_mask(
    id_mask: Int[torch.Tensor, "batch variate time_steps"], patch_size: int
) -> Int[torch.Tensor, "batch variate seq_len"]:
    """
    Reduce a per-time-step id mask to one id per non-overlapping patch.

    The last axis is cut into consecutive windows of ``patch_size`` steps
    (``unfold`` with step == size, so a trailing remainder shorter than a
    patch is dropped) and each window is collapsed to its single id.

    Raises:
        AssertionError: if any window contains more than one distinct id.
    """
    windows = id_mask.unfold(dimension=-1, size=patch_size, step=patch_size)
    patch_ids = windows.min(-1).values
    # A window holds a single id exactly when its min equals its max.
    assert torch.equal(
        patch_ids, windows.max(-1).values
    ), "Patches cannot span multiple datasets"
    return patch_ids
+
+
class PatchEmbedding(torch.nn.Module):
    """
    Multivariate time series patch embedding.
    Patchifies each variate separately.

    The last (time) axis of ``x`` is cut into windows of ``patch_size`` steps
    taken every ``stride`` steps; each window is projected linearly to
    ``embed_dim``. The id mask is windowed the same way and reduced to one id
    per patch.
    """

    def __init__(self, patch_size: int, stride: int, embed_dim: int):
        super().__init__()
        self.patch_size = patch_size
        self.embed_dim = embed_dim
        self.stride = stride
        self.projection = torch.nn.Linear(self.patch_size, self.embed_dim)

    def _patchify(
        self, x: Num[torch.Tensor, "batch variate time_steps"]
    ) -> Num[torch.Tensor, "batch variate seq_len patch_size"]:
        """Window the last axis into ``patch_size`` slices every ``stride`` steps."""
        return x.unfold(dimension=-1, size=self.patch_size, step=self.stride)

    def forward(
        self,
        x: Float[torch.Tensor, "batch #variate time_steps"],
        id_mask: Float[torch.Tensor, "batch time_steps"],
    ) -> tuple[
        Float[torch.Tensor, "batch variate seq_len embed_dim"],
        Int[torch.Tensor, "batch seq_len"],
    ]:
        """
        Embed the patches of ``x`` and return one id per patch.

        NOTE(review): the id-mask annotations disagree with the local
        annotation below ("batch variate ..."); ``_patchify`` only requires
        time to be the last axis — TODO confirm the intended rank.

        Raises:
            AssertionError: if the series length is not a multiple of
                ``patch_size`` or a patch spans more than one id.
        """
        assert (
            x.shape[-1] % self.patch_size == 0
        ), f"Series length ({x.shape=}) must be divisible by ({self.patch_size=})"
        x_patched: Float[torch.Tensor, "batch variate seq_len patch_size"] = (
            self._patchify(x)
        )
        id_mask_patched: Int[torch.Tensor, "batch variate seq_len patch_size"] = (
            self._patchify(id_mask)
        )

        # Reduce once and reuse: a patch is valid only if min == max within it.
        patch_ids = id_mask_patched.min(-1).values
        assert torch.equal(
            patch_ids, id_mask_patched.max(-1).values
        ), "Patches cannot span multiple datasets"

        return self.projection(x_patched), patch_ids
diff --git 
a/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/feed_forward.py 
b/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/feed_forward.py
new file mode 100644
index 00000000000..024a8bed727
--- /dev/null
+++ b/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/feed_forward.py
@@ -0,0 +1,35 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+# This file includes code derived from DataDog/toto
+# (https://github.com/DataDog/toto), licensed under the Apache-2.0 License.
+# Copyright 2025 Datadog, Inc.
+
+import torch
+import torch.nn.functional as F
+
+
class SwiGLU(torch.nn.Module):
    """
    SwiGLU activation: https://arxiv.org/abs/2002.05202

    The last axis of ``x`` holds the gate and value halves concatenated, so it
    must be twice the desired output width (and therefore even). The *first*
    half is the gate, matching xFormers' ordering.
    """

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Return ``silu(gate) * value`` for the two halves of the last axis.

        Raises:
            ValueError: if the last dimension has odd size. Without this check,
                ``chunk`` yields unequal halves that either broadcast silently
                (e.g. width 3 -> 2 and 1) or fail with an opaque error.
        """
        width = x.shape[-1]
        if width % 2:
            raise ValueError(
                f"SwiGLU expects an even-sized last dimension, got {width}"
            )
        # Note this ordering is unusual, but is done so to match xFormers
        gate, value = x.chunk(2, dim=-1)
        return F.silu(gate) * value
diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/fusion.py 
b/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/fusion.py
new file mode 100644
index 00000000000..cfe364ac91e
--- /dev/null
+++ b/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/fusion.py
@@ -0,0 +1,58 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+# This file includes code derived from DataDog/toto
+# (https://github.com/DataDog/toto), licensed under the Apache-2.0 License.
+# Copyright 2025 Datadog, Inc.
+
+from typing import Optional
+
+import torch
+import torch.nn.functional as F
+from jaxtyping import Float
+
+
class Fusion(torch.nn.Module):
    """
    Prepends variate label embeddings to the input embeddings along the sequence dimension.

    The label embeddings are L2-normalised over their last axis, cast to the
    dtype/device of ``embeddings`` and concatenated in front of them on dim 2.
    Without label embeddings the input is returned as-is.
    """

    def forward(
        self,
        embeddings: Float[torch.Tensor, "batch variate seq_len embed_dim"],
        variate_label_embeds: Optional[
            Float[torch.Tensor, "batch variate 1 embed_dim"]
        ] = None,
    ) -> Float[torch.Tensor, "batch variate new_seq_len embed_dim"]:
        if variate_label_embeds is not None:
            label_tokens = F.normalize(variate_label_embeds, p=2, dim=-1).to(
                dtype=embeddings.dtype, device=embeddings.device, non_blocking=True
            )
            return torch.cat((label_tokens, embeddings), dim=2)
        return embeddings
diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/rope.py 
b/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/rope.py
new file mode 100644
index 00000000000..96e62517077
--- /dev/null
+++ b/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/rope.py
@@ -0,0 +1,94 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+# This file includes code derived from DataDog/toto
+# (https://github.com/DataDog/toto), licensed under the Apache-2.0 License.
+# Copyright 2025 Datadog, Inc.
+
+from typing import Optional
+
+import torch
+from einops import rearrange
+from jaxtyping import Int
+from rotary_embedding_torch import RotaryEmbedding, apply_rotary_emb
+from rotary_embedding_torch.rotary_embedding_torch import default
+
+
def exists(val):
    """Return True for any value except ``None`` (falsy values still count)."""
    if val is None:
        return False
    return True
+
+
class TimeAwareRotaryEmbedding(RotaryEmbedding):
    """
    A variant of the rotary position embedding that (optionally) uses the time index
    to compute the sinusoidal and cosine embeddings. Useful for time series data.

    ``rotate_queries_and_keys`` accepts explicit positions through ``seq_pos``
    (instead of the default ``0 .. seq_len - 1``) and only supports the xPos
    variant (``use_xpos`` must be enabled on the parent class).
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # If the parent stored `freqs` as a Parameter, re-register it as a
        # non-persistent buffer: it still follows .to()/.cuda(), but is no
        # longer returned by .parameters() nor written to state dicts.
        if hasattr(self, "freqs") and isinstance(self.freqs, torch.nn.Parameter):
            freqs_data = self.freqs.data
            self._parameters.pop("freqs")
            self.register_buffer("freqs", freqs_data, persistent=False)

    def rotate_queries_and_keys(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        seq_dim: Optional[int] = None,
        seq_pos: Optional[Int[torch.Tensor, "... seq_len"]] = None,
        seq_pos_offset: int = 0,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Apply rotary embeddings with xPos scaling to queries and keys.

        Args:
            q, k: query/key tensors whose sequence axis is ``seq_dim``.
            seq_dim: sequence axis; defaults to ``self.default_seq_dim``.
            seq_pos: optional explicit positions. When omitted, the parent's
                ``get_seq_pos`` provides positions for ``q.shape[seq_dim]`` steps.
            seq_pos_offset: constant added to every position.

        Returns:
            ``(rotated_q, rotated_k)`` cast back to the dtypes of ``q`` and
            ``k``. Queries are scaled by the xPos factor, keys by its inverse.
        """
        if seq_dim is None:
            seq_dim = self.default_seq_dim

        assert self.use_xpos
        device, dtype, seq_len = q.device, q.dtype, q.shape[seq_dim]

        # Only build default positions when the caller supplied none.
        if seq_pos is None:
            seq = self.get_seq_pos(seq_len, dtype=dtype, device=device)
        else:
            seq = seq_pos
        seq = seq + seq_pos_offset

        freqs = self.forward(seq)
        scale = self.get_scale(seq).to(dtype)

        if seq_dim == -3:
            # Layout (..., seq, heads, dim): insert and broadcast a heads axis.
            # NOTE(review): assumes freqs/scale are 2-D (seq_len, rot_dim) here,
            # i.e. seq_pos is 1-D — TODO confirm for batched positions.
            num_heads = q.shape[-2]
            freqs = freqs.unsqueeze(1).expand(-1, num_heads, -1)
            scale = scale.unsqueeze(1).expand(-1, num_heads, -1)

        rotated_q = apply_rotary_emb(freqs, q, scale=scale, seq_dim=seq_dim)
        rotated_k = apply_rotary_emb(freqs, k, scale=scale**-1, seq_dim=seq_dim)

        return rotated_q.type(q.dtype), rotated_k.type(k.dtype)

    def get_scale(
        self,
        t: torch.Tensor,
    ) -> torch.Tensor:
        """
        xPos scale factors for positions ``t``.

        Positions are centred on half of the largest position along the last
        axis, divided by ``self.scale_base`` and used as exponents of
        ``self.scale``; the result is duplicated along the last axis.
        """
        assert self.use_xpos

        # `//` binds tighter than `-`: this is t - (max // 2), not (t - max) // 2.
        half_max = t.max(-1).values.unsqueeze(-1) // 2
        power = (t - half_max) / self.scale_base

        scale = self.scale ** rearrange(power, "... n -> ... n 1")
        return torch.cat((scale, scale), dim=-1)
diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/scaler.py 
b/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/scaler.py
new file mode 100644
index 00000000000..e640e3ef3a2
--- /dev/null
+++ b/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/scaler.py
@@ -0,0 +1,328 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+# This file includes code derived from DataDog/toto
+# (https://github.com/DataDog/toto), licensed under the Apache-2.0 License.
+# Copyright 2025 Datadog, Inc.
+
+import warnings
+from typing import Tuple
+
+import torch
+from einops import reduce, repeat
+
+
class Scaler(torch.nn.Module):
    """
    Minimal stand-in for ``gluonts.torch.scaler.Scaler``.

    It adds nothing to ``torch.nn.Module`` and defines no ``__call__`` of its
    own; it only serves as the common base type of the scalers in this module,
    each of which overrides ``__call__(data, padding_mask, weights,
    prefix_length=None)`` and returns ``(scaled_data, loc, scale)``.
    """
+
+
class StdMeanScaler(Scaler):
    """
    Scales data to have zero mean and unit variance along a given dimension.

    Statistics are accumulated in float64 (float32 on devices that reject
    float64) over the positions where ``weights * padding_mask`` is non-zero.
    """

    def __init__(
        self,
        dim: int = -1,
        keepdim: bool = True,
        minimum_scale: float = 1e-3,
    ) -> None:
        super().__init__()
        self.dim = dim
        self.keepdim = keepdim
        self.minimum_scale = minimum_scale

    def __call__(
        self,
        data: torch.Tensor,
        padding_mask: torch.Tensor,
        weights: torch.Tensor,
        prefix_length: int | None = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Standardise ``data`` along ``self.dim``.

        Args:
            data: values to scale.
            padding_mask: multiplied into ``weights``; zero entries are ignored.
            weights: per-element weights, same shape as ``data``.
            prefix_length: if given, only the first ``prefix_length`` steps of
                the LAST axis contribute (applied on the last axis regardless
                of ``self.dim``).

        Returns:
            ``(scaled_data, loc, scale)``. ``scaled_data`` has the shape of
            ``data``. ``loc``/``scale`` keep the reduced axis (size 1) when
            ``keepdim`` is True, otherwise that axis is squeezed out.
            ``minimum_scale`` is added to the variance before the square root.
        """
        assert data.shape == weights.shape, "data and weights must have same shape"
        with torch.no_grad():
            if prefix_length is not None:
                prefix_mask = torch.zeros_like(weights)
                prefix_mask[..., :prefix_length] = 1.0
                weights = weights * prefix_mask

            weights = weights * padding_mask

            try:
                high_precision_data = data.to(torch.float64)
            except TypeError:
                warnings.warn(
                    f"Float64 is not supported by device {data.device}. "
                    "Using float32 instead for accumulating denominator in input scaler. "
                    "This may lead to overflow issues if the data contains extreme values.",
                    RuntimeWarning,
                )
                high_precision_data = data.to(torch.float32)

            # Always reduce with keepdim=True so the statistics broadcast
            # against `data` when centring; with keepdim=False the reduced
            # tensors would be mis-aligned. `self.keepdim` only shapes the output.
            denominator = (
                weights.sum(self.dim, keepdim=True)
                .clamp_min(1.0)
                .to(high_precision_data.dtype)
            )
            means = (high_precision_data * weights).sum(
                self.dim, keepdim=True
            ) / denominator
            means = torch.nan_to_num(means)

            variance = (((high_precision_data - means) * weights) ** 2).sum(
                self.dim, keepdim=True
            ) / denominator
            scale = torch.sqrt(variance + self.minimum_scale).to(data.dtype)
            loc = means.to(data.dtype)

            scaled_data = (data - loc) / scale
            if not self.keepdim:
                loc = loc.squeeze(self.dim)
                scale = scale.squeeze(self.dim)

            return scaled_data, loc, scale
+
+
+def compute_causal_statistics(
+    data: torch.Tensor,
+    weights: torch.Tensor,
+    padding_mask: torch.Tensor,
+    dim: int,
+    minimum_scale: float,
+    use_bessel_correction: bool = True,
+    stabilize_with_global: bool = False,
+    scale_factor_exponent: float = 10.0,
+    prefix_length: int | None = None,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    assert dim == -1, "compute_causal_statistics only supports dim=-1 (last 
dimension)"
+
+    with torch.no_grad():
+        weights = weights * padding_mask
+
+        try:
+            high_precision_data = data.to(torch.float64)
+            high_precision_weights = weights.to(torch.float64)
+        except TypeError:
+            warnings.warn(
+                f"Float64 is not supported by device {data.device}. "
+                "Using float32 instead for causal scaler calculations.",
+                RuntimeWarning,
+            )
+            high_precision_data = data.to(torch.float32)
+            high_precision_weights = weights.to(torch.float32)
+
+        prev_deterministic = torch.are_deterministic_algorithms_enabled()
+        if prev_deterministic and data.device.type == "cuda":
+            torch.use_deterministic_algorithms(False)
+
+        try:
+            weighted_data = high_precision_weights * high_precision_data
+
+            cum_weights = torch.cumsum(high_precision_weights, dim=dim)
+            cum_values = torch.cumsum(weighted_data, dim=dim)
+
+            denominator = cum_weights.clamp_min(1.0)
+            causal_means = cum_values / denominator
+
+            shifted_means = torch.zeros_like(causal_means)
+            shifted_means[..., 1:] = causal_means[..., :-1]
+
+            delta = high_precision_data - shifted_means
+            increment = (
+                delta * (high_precision_data - causal_means) * 
high_precision_weights
+            )
+            m_2 = torch.cumsum(increment, dim=dim)
+
+            if use_bessel_correction:
+                causal_variance = m_2 / torch.clamp(denominator - 1.0, min=1.0)
+            else:
+                causal_variance = m_2 / denominator
+
+            causal_scale = torch.sqrt(causal_variance + minimum_scale)
+
+            if stabilize_with_global:
+                if prefix_length is not None:
+                    prefix_mask = torch.zeros_like(weights)
+                    prefix_mask[..., :prefix_length] = 1.0
+                    weighted_data = weighted_data * prefix_mask
+                    weights = weights * prefix_mask
+                    padding_mask = padding_mask * prefix_mask
+
+                scale_factor_min = 10.0 ** (-scale_factor_exponent)
+                scale_factor_max = 10.0**scale_factor_exponent
+
+                global_denominator = (
+                    (weights * padding_mask).sum(dim, 
keepdim=True).clamp_min(1.0)
+                )
+                global_means = (weighted_data).sum(
+                    dim, keepdim=True
+                ) / global_denominator
+                global_means = torch.nan_to_num(global_means)
+
+                global_variance = (
+                    ((high_precision_data - global_means) * weights * 
padding_mask) ** 2
+                ).sum(dim, keepdim=True) / global_denominator
+                global_scale = torch.sqrt(global_variance + minimum_scale)
+
+                expanded_global_scale = global_scale.expand_as(causal_scale)
+                min_allowed_scale = expanded_global_scale * scale_factor_min
+                max_allowed_scale = expanded_global_scale * scale_factor_max
+
+                causal_scale = torch.clamp(
+                    causal_scale,
+                    min=torch.max(
+                        torch.tensor(minimum_scale, 
device=causal_scale.device),
+                        min_allowed_scale,
+                    ),
+                    max=max_allowed_scale,
+                )
+
+            causal_means = causal_means.to(data.dtype)
+            causal_scale = causal_scale.to(data.dtype)
+
+        finally:
+            if prev_deterministic and data.device.type == "cuda":
+                torch.use_deterministic_algorithms(True)
+
+        return causal_means, causal_scale
+
+
class CausalStdMeanScaler(Scaler):
    """
    Standardise each time step with running statistics along the last axis.

    Location and scale at every step come from ``compute_causal_statistics``,
    i.e. from cumulative sums over that step and earlier ones; when
    ``stabilize_with_global`` is set the scale is additionally clamped
    relative to a global (or prefix) scale.
    """

    def __init__(
        self,
        dim: int = -1,
        minimum_scale: float = 0.1,
        use_bessel_correction: bool = True,
        stabilize_with_global: bool = False,
        scale_factor_exponent: float = 10.0,
    ) -> None:
        super().__init__()
        assert dim == -1, "CausalStdMeanScaler only supports dim=-1 (last dimension)"
        self.dim = dim
        self.minimum_scale = minimum_scale
        self.use_bessel_correction = use_bessel_correction
        self.stabilize_with_global = stabilize_with_global
        self.scale_factor_exponent = scale_factor_exponent

    def __call__(
        self,
        data: torch.Tensor,
        padding_mask: torch.Tensor,
        weights: torch.Tensor,
        prefix_length: int | None = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return ``(scaled_data, causal_means, causal_scale)`` for 3-D ``data``."""
        assert data.shape == weights.shape, "data and weights must have same shape"
        assert data.dim() == 3, "Input data must have shape [batch, variates, time_steps]"

        running_loc, running_scale = compute_causal_statistics(
            data,
            weights,
            padding_mask,
            dim=self.dim,
            minimum_scale=self.minimum_scale,
            use_bessel_correction=self.use_bessel_correction,
            stabilize_with_global=self.stabilize_with_global,
            scale_factor_exponent=self.scale_factor_exponent,
            prefix_length=prefix_length,
        )

        return (data - running_loc) / running_scale, running_loc, running_scale
+
+
class CausalPatchStdMeanScaler(Scaler):
    """
    Patch-wise variant of ``CausalStdMeanScaler``.

    Running statistics are computed per time step. Every step of a patch is
    then normalised with the statistics taken at that patch's LAST step, so
    ``loc``/``scale`` are constant within each patch of ``patch_size`` steps.
    """

    def __init__(
        self,
        dim: int = -1,
        patch_size: int = 32,
        minimum_scale: float = 0.1,
        use_bessel_correction: bool = True,
        stabilize_with_global: bool = False,
        scale_factor_exponent: float = 10.0,
    ) -> None:
        super().__init__()
        assert (
            dim == -1
        ), "CausalPatchStdMeanScaler only supports dim=-1 (last dimension)"
        self.dim = dim
        self.patch_size = patch_size
        self.minimum_scale = minimum_scale
        self.use_bessel_correction = use_bessel_correction
        self.stabilize_with_global = stabilize_with_global
        self.scale_factor_exponent = scale_factor_exponent

    def __call__(
        self,
        data: torch.Tensor,
        padding_mask: torch.Tensor,
        weights: torch.Tensor,
        prefix_length: int | None = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return ``(scaled_data, patch_means, patch_scales)``, all shaped like ``data``."""
        assert data.shape == weights.shape, "data and weights must have same shape"
        assert data.dim() == 3, "Input data must have shape [batch, variates, time_steps]"

        with torch.no_grad():
            time_steps = data.shape[-1]
            assert (
                time_steps % self.patch_size == 0
            ), f"Time steps ({time_steps}) must be divisible by patch size ({self.patch_size})"

            running_means, running_scales = compute_causal_statistics(
                data,
                weights,
                padding_mask,
                dim=-1,
                minimum_scale=self.minimum_scale,
                use_bessel_correction=self.use_bessel_correction,
                stabilize_with_global=self.stabilize_with_global,
                scale_factor_exponent=self.scale_factor_exponent,
                prefix_length=prefix_length,
            )

            size = self.patch_size
            # Pick the statistic at the last step of each patch: (b, v, t) -> (b, v, t // size).
            last_means = running_means.unfold(-1, size, size)[..., -1]
            last_scales = running_scales.unfold(-1, size, size)[..., -1]

            # Broadcast each patch's statistic back over its steps: (b, v, p) -> (b, v, p * size).
            patch_means = last_means.repeat_interleave(size, dim=-1)
            patch_scales = last_scales.repeat_interleave(size, dim=-1)

            return (data - patch_means) / patch_scales, patch_means, patch_scales
+
+
+# for deserialization of SafeTensors checkpoints
scaler_types = {
    # Legacy keys: the str() of each scaler class under the upstream
    # ``model.scaler`` module path (presumably as stored in older configs).
    **{
        f"<class 'model.scaler.{scaler_cls.__name__}'>": scaler_cls
        for scaler_cls in (
            StdMeanScaler,
            CausalStdMeanScaler,
            CausalPatchStdMeanScaler,
        )
    },
    # Short aliases used in config.json
    "per_variate": StdMeanScaler,
    "per_variate_causal": CausalStdMeanScaler,
    "per_variate_causal_patch": CausalPatchStdMeanScaler,
}
diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/toto.py 
b/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/toto.py
new file mode 100644
index 00000000000..61595334171
--- /dev/null
+++ b/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/toto.py
@@ -0,0 +1,157 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+# This file includes code derived from DataDog/toto
+# (https://github.com/DataDog/toto), licensed under the Apache-2.0 License.
+# Copyright 2025 Datadog, Inc.
+
+import json
+import os
+import re
+from pathlib import Path
+from typing import Dict, Optional, Union
+
+import safetensors.torch as safetorch
+import torch
+
+from .attention import XFORMERS_AVAILABLE
+from .backbone import TotoBackbone
+from .transformer import XFORMERS_SWIGLU_AVAILABLE
+
+
+class Toto(torch.nn.Module):
+    """
+    PyTorch module for Toto (Timeseries-Optimized Transformer for 
Observability).
+    This class is used internally for checkpoint loading logic.
+    """
+
+    def __init__(
+        self,
+        patch_size: int,
+        stride: int,
+        embed_dim: int,
+        num_layers: int,
+        num_heads: int,
+        mlp_hidden_dim: int,
+        dropout: float,
+        spacewise_every_n_layers: int,
+        scaler_cls: str,
+        output_distribution_classes: list[str],
+        spacewise_first: bool = True,
+        output_distribution_kwargs: dict | None = None,
+        use_memory_efficient_attention: bool = True,
+        stabilize_with_global: bool = True,
+        scale_factor_exponent: float = 10.0,
+        **model_kwargs,
+    ):
+        super().__init__()
+        self.model = TotoBackbone(
+            patch_size=patch_size,
+            stride=stride,
+            embed_dim=embed_dim,
+            num_layers=num_layers,
+            num_heads=num_heads,
+            mlp_hidden_dim=mlp_hidden_dim,
+            dropout=dropout,
+            spacewise_every_n_layers=spacewise_every_n_layers,
+            scaler_cls=scaler_cls,
+            output_distribution_classes=output_distribution_classes,
+            spacewise_first=spacewise_first,
+            output_distribution_kwargs=output_distribution_kwargs,
+            use_memory_efficient_attention=use_memory_efficient_attention,
+            stabilize_with_global=stabilize_with_global,
+            scale_factor_exponent=scale_factor_exponent,
+        )
+
+    @classmethod
+    def load_from_checkpoint(
+        cls,
+        checkpoint_path,
+        map_location: str = "cpu",
+        strict=True,
+        **model_kwargs,
+    ):
+        if os.path.isdir(checkpoint_path):
+            safetensors_file = os.path.join(checkpoint_path, 
"model.safetensors")
+        else:
+            safetensors_file = checkpoint_path
+
+        if os.path.exists(safetensors_file):
+            model_state = safetorch.load_file(safetensors_file, 
device=map_location)
+        else:
+            raise FileNotFoundError(
+                f"Model checkpoint not found at: {safetensors_file}"
+            )
+
+        config_file = os.path.join(checkpoint_path, "config.json")
+        config = {}
+        if os.path.exists(config_file):
+            with open(config_file, "r") as f:
+                config = json.load(f)
+
+        config.update(model_kwargs)
+
+        remapped_state_dict = cls._map_state_dict_keys(
+            model_state,
+            XFORMERS_SWIGLU_AVAILABLE
+            and not config.get("pre_xformers_checkpoint", False),
+        )
+
+        if not XFORMERS_AVAILABLE and config.get(
+            "use_memory_efficient_attention", True
+        ):
+            config["use_memory_efficient_attention"] = False
+
+        instance = cls(**config)
+        instance.to(map_location)
+
+        filtered_remapped_state_dict = {
+            k: v
+            for k, v in remapped_state_dict.items()
+            if k in instance.state_dict() and not 
k.endswith("rotary_emb.freqs")
+        }
+
+        instance.load_state_dict(filtered_remapped_state_dict, strict=strict)
+        return instance
+
+    @staticmethod
+    def _map_state_dict_keys(state_dict, use_fused_swiglu):
+        if use_fused_swiglu:
+            remap_keys = {
+                "mlp.0.weight": "mlp.0.w12.weight",
+                "mlp.0.bias": "mlp.0.w12.bias",
+                "mlp.2.weight": "mlp.0.w3.weight",
+                "mlp.2.bias": "mlp.0.w3.bias",
+            }
+        else:
+            remap_keys = {
+                "mlp.0.w12.weight": "mlp.0.weight",
+                "mlp.0.w12.bias": "mlp.0.bias",
+                "mlp.0.w3.weight": "mlp.2.weight",
+                "mlp.0.w3.bias": "mlp.2.bias",
+            }
+
+        def replace_key(text):
+            for pattern, replacement in remap_keys.items():
+                text = re.sub(pattern, replacement, text)
+            return text
+
+        return {replace_key(k): v for k, v in state_dict.items()}
+
+    @property
+    def device(self):
+        return next(self.model.parameters()).device
diff --git 
a/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/transformer.py 
b/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/transformer.py
new file mode 100644
index 00000000000..58220c30e62
--- /dev/null
+++ b/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/transformer.py
@@ -0,0 +1,318 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+# This file includes code derived from DataDog/toto
+# (https://github.com/DataDog/toto), licensed under the Apache-2.0 License.
+# Copyright 2025 Datadog, Inc.
+
+import warnings
+from typing import Literal, Optional, Union, cast
+
+import torch
+import torch.nn.functional as F
+from einops import rearrange
+from jaxtyping import Bool, Float, Int
+from rotary_embedding_torch import RotaryEmbedding
+
+from .attention import (
+    AttentionAxis,
+    MultiHeadAttention,
+    SpaceWiseMultiheadAttention,
+    TimeWiseMultiheadAttention,
+)
+from .feed_forward import SwiGLU
+from .fusion import Fusion
+from .rope import TimeAwareRotaryEmbedding
+from .util import KVCache, RMSNorm, make_batched_block_mask
+
+try:
+    from xformers.ops.swiglu_op import SwiGLU as SwiGLU_fused
+except ImportError:
+    # No fused kernel: TransformerLayer falls back to the SwiGLU module
+    # imported from .feed_forward.
+    warnings.warn(
+        "xFormers fused SwiGLU kernel not found. "
+        "Using native PyTorch implementation for feed-forward layers.",
+        ImportWarning,
+    )
+    XFORMERS_SWIGLU_AVAILABLE = False
+else:
+    XFORMERS_SWIGLU_AVAILABLE = True
+
+
+class TransformerLayer(torch.nn.Module):
+    """
+    Pre-norm transformer block: attention along one axis followed by a
+    SwiGLU feed-forward network, each wrapped in a residual connection.
+
+    ``AttentionAxis.TIME`` uses ``TimeWiseMultiheadAttention`` with the given
+    rotary embedding; ``AttentionAxis.SPACE`` uses
+    ``SpaceWiseMultiheadAttention`` without one.
+    """
+
+    embed_dim: int
+    num_heads: int
+    mlp_hidden_dim: int
+    dropout: float
+    attention_axis: AttentionAxis
+
+    def __init__(
+        self,
+        embed_dim: int,
+        num_heads: int,
+        mlp_hidden_dim: int,
+        dropout: float,
+        rotary_emb: Optional[RotaryEmbedding] = None,
+        attention_axis: AttentionAxis = AttentionAxis.TIME,
+        RMS_norm: bool = True,
+        use_memory_efficient_attention: bool = True,
+    ):
+        """
+        Args:
+            embed_dim: Model width.
+            num_heads: Number of attention heads.
+            mlp_hidden_dim: Hidden width of the SwiGLU feed-forward network.
+            dropout: Dropout probability for the attention module and the
+                feed-forward output.
+            rotary_emb: Rotary position embedding; only used by time-wise
+                attention.
+            attention_axis: Axis along which this layer attends.
+            RMS_norm: Use RMSNorm if true, otherwise LayerNorm.
+            use_memory_efficient_attention: Forwarded to the attention module.
+
+        Raises:
+            ValueError: If ``attention_axis`` is neither TIME nor SPACE.
+        """
+        super().__init__()
+        self.embed_dim = embed_dim
+        self.num_heads = num_heads
+        self.mlp_hidden_dim = mlp_hidden_dim
+        self.dropout = dropout
+        self.attention_axis = attention_axis
+
+        # NOTE: submodule names (norm1, norm2, attention, mlp) are part of the
+        # checkpoint key layout; do not rename them.
+        if RMS_norm:
+            self.norm1: Union[RMSNorm, torch.nn.LayerNorm] = RMSNorm(embed_dim)
+            self.norm2: Union[RMSNorm, torch.nn.LayerNorm] = RMSNorm(embed_dim)
+        else:
+            self.norm1 = torch.nn.LayerNorm(embed_dim)
+            self.norm2 = torch.nn.LayerNorm(embed_dim)
+
+        self.attention: MultiHeadAttention
+
+        if attention_axis == AttentionAxis.TIME:
+            self.attention = TimeWiseMultiheadAttention(
+                embed_dim=embed_dim,
+                num_heads=num_heads,
+                dropout=dropout,
+                rotary_emb=rotary_emb,
+                use_memory_efficient_attention=use_memory_efficient_attention,
+            )
+        elif attention_axis == AttentionAxis.SPACE:
+            self.attention = SpaceWiseMultiheadAttention(
+                embed_dim=embed_dim,
+                num_heads=num_heads,
+                dropout=dropout,
+                rotary_emb=None,
+                use_memory_efficient_attention=use_memory_efficient_attention,
+            )
+        else:
+            raise ValueError("Invalid attention axis")
+
+        if XFORMERS_SWIGLU_AVAILABLE:
+            # Fused kernel holds both projections in mlp.0 (checkpoint keys
+            # mlp.0.w12.* / mlp.0.w3.*).
+            self.mlp = torch.nn.Sequential(
+                SwiGLU_fused(in_features=embed_dim, hidden_features=mlp_hidden_dim),
+                torch.nn.Dropout(dropout),
+            )
+        else:
+            # Unfused layout: mlp.0 projects to 2 * hidden for the gate, mlp.2
+            # projects back to embed_dim.
+            self.mlp = torch.nn.Sequential(
+                torch.nn.Linear(embed_dim, 2 * mlp_hidden_dim),
+                SwiGLU(),
+                torch.nn.Linear(mlp_hidden_dim, embed_dim),
+                torch.nn.Dropout(dropout),
+            )
+
+    def forward(
+        self,
+        layer_idx: int,
+        inputs: Float[torch.Tensor, "batch variate seq_len embed_dim"],
+        attention_mask: Optional[
+            Union[
+                Bool[torch.Tensor, "batch seq_len variate variate"],
+                Bool[torch.Tensor, "batch #variate seq_len seq_len"],
+            ]
+        ] = None,
+        kv_cache: Optional[KVCache] = None,
+    ) -> Float[torch.Tensor, "batch variate seq_len embed_dim"]:
+        """
+        Compute ``h = x + attn(norm1(x))`` and return ``h + mlp(norm2(h))``.
+
+        ``layer_idx``, ``attention_mask`` and ``kv_cache`` are passed
+        straight through to the attention module.
+        """
+        pre_norm_1 = self.norm1(inputs)
+        hidden_state = (
+            inputs
+            + self.attention(
+                layer_idx, pre_norm_1, attention_mask, kv_cache
+            ).contiguous()
+        )
+
+        pre_norm_2 = self.norm2(hidden_state)
+        return hidden_state + self.mlp(pre_norm_2)
+
+
+class Transformer(torch.nn.Module):
+    """
+    Stack of ``TransformerLayer``s interleaving time-wise and space-wise
+    attention, with an optional ``Fusion`` step applied to the inputs first.
+
+    A single ``TimeAwareRotaryEmbedding`` instance is passed to every layer.
+    """
+
+    def __init__(
+        self,
+        num_layers: int,
+        embed_dim: int,
+        num_heads: int,
+        mlp_hidden_dim: int,
+        dropout: float,
+        spacewise_every_n_layers: int,
+        spacewise_first: bool,
+        use_memory_efficient_attention: bool = True,
+        *,
+        fusion: Optional[Fusion] = None,
+    ):
+        """
+        Args:
+            num_layers: Total number of transformer layers.
+            embed_dim: Model width; must be divisible by ``num_heads``.
+            num_heads: Number of attention heads per layer.
+            mlp_hidden_dim: Hidden width of each feed-forward network.
+            dropout: Dropout probability passed to every layer.
+            spacewise_every_n_layers: Period of space-wise layers, or -1 for
+                time-wise attention only (see ``_get_layer_types``).
+            spacewise_first: Put the space-wise layer first in each period.
+            use_memory_efficient_attention: Forwarded to every layer; also
+                sets the rotary embedding's ``seq_before_head_dim``.
+            fusion: Optional module applied to the inputs together with
+                variate label embeddings before the first layer.
+        """
+        super().__init__()
+
+        assert (
+            embed_dim % num_heads == 0
+        ), "Embedding dimension must be divisible by number of heads."
+
+        self.rotary_emb = TimeAwareRotaryEmbedding(
+            embed_dim // num_heads,
+            use_xpos=True,
+            cache_if_possible=True,
+            seq_before_head_dim=use_memory_efficient_attention,
+        )
+        attention_axes = self._get_layer_types(
+            num_layers, spacewise_every_n_layers, spacewise_first
+        )
+
+        self.use_memory_efficient_attention = use_memory_efficient_attention
+        self.fusion = fusion
+
+        self.layers = torch.nn.ModuleList(
+            [
+                TransformerLayer(
+                    embed_dim=embed_dim,
+                    num_heads=num_heads,
+                    mlp_hidden_dim=mlp_hidden_dim,
+                    dropout=dropout,
+                    rotary_emb=self.rotary_emb,
+                    attention_axis=attention_axes[i],
+                    use_memory_efficient_attention=self.use_memory_efficient_attention,
+                )
+                for i in range(num_layers)
+            ]
+        )
+
+    def _get_mask(
+        self,
+        num_heads: int,
+        dtype: torch.dtype,
+        id_mask: Optional[torch.Tensor] = None,
+    ) -> Union[
+        Bool[torch.Tensor, "batch num_heads seq_len seq_len"],
+        Float[torch.Tensor, "batch num_heads seq_len seq_len"],
+        Bool[torch.Tensor, "batch num_heads variate variate"],
+        Float[torch.Tensor, "batch num_heads variate variate"],
+    ]:
+        """
+        Build the additive attention bias used by space-wise layers.
+
+        At each time step, variates with equal ``id_mask`` values may attend
+        to each other (bias 0). All other pairs get ``-inf``. With
+        memory-efficient attention, the variate dimensions are first padded
+        to a multiple of 8, and the padded entries are also set to ``-inf``.
+
+        Returns:
+            Tensor of shape ``(batch * seq_len, num_heads, variate, variate)``
+            (variate possibly padded) in ``dtype``.
+
+        Raises:
+            ValueError: If ``id_mask`` is None.
+        """
+        if id_mask is None:
+            raise ValueError("id_mask must be provided for spacewise masks.")
+
+        mask = make_batched_block_mask(id_mask.transpose(-1, -2))
+
+        if self.use_memory_efficient_attention:
+            mask = self._pad_to_multiple(mask)
+        mask = (
+            mask.float()
+            .masked_fill(~mask, float("-inf"))
+            .masked_fill(mask, 0.0)
+            .to(dtype)
+        )
+
+        mask = rearrange(
+            mask,
+            "batch seq_len variate1 variate2 -> (batch seq_len) 1 variate1 variate2",
+        )
+        return mask.expand(-1, num_heads, -1, -1).contiguous()
+
+    def _pad_to_multiple(
+        self,
+        tensor: torch.Tensor,
+        multiple: int = 8,
+        causal: bool = False,
+    ) -> torch.Tensor:
+        """
+        Pad the last two dimensions of ``tensor`` up to a multiple of ``multiple``.
+
+        The non-causal path zero-pads on the right/bottom. NOTE: the causal
+        path builds a 2-D lower-triangular mask and so only handles 2-D
+        inputs; ``_get_mask`` uses the non-causal path.
+        """
+        pad_amount = (multiple - tensor.shape[-1] % multiple) % multiple
+        if pad_amount > 0:
+            new_size = tensor.shape[-1] + pad_amount
+            if causal:
+                full_mask = torch.tril(
+                    torch.ones(
+                        (new_size, new_size), dtype=tensor.dtype, device=tensor.device
+                    )
+                )
+                full_mask[: tensor.shape[-1], : tensor.shape[-1]] = tensor
+                tensor = full_mask
+            else:
+                tensor = F.pad(tensor, (0, pad_amount, 0, pad_amount))
+        return tensor
+
+    def _get_layer_types(
+        self,
+        num_layers: int,
+        spacewise_every_n_layers: int,
+        spacewise_first: bool,
+    ) -> list[AttentionAxis]:
+        """
+        Return the attention axis for each layer.
+
+        Layers are grouped into periods of ``spacewise_every_n_layers``, each
+        containing exactly one SPACE layer (first or last in the period). For
+        example, 6 layers with period 3 and ``spacewise_first=False`` gives
+        ``[TIME, TIME, SPACE, TIME, TIME, SPACE]``. A period of -1 yields
+        TIME for every layer.
+        """
+        if spacewise_every_n_layers == -1:
+            return [AttentionAxis.TIME] * num_layers
+        assert num_layers % spacewise_every_n_layers == 0
+
+        block = [AttentionAxis.TIME] * (spacewise_every_n_layers - 1)
+
+        if spacewise_first:
+            block = [AttentionAxis.SPACE] + block
+        else:
+            block = block + [AttentionAxis.SPACE]
+
+        return block * (num_layers // spacewise_every_n_layers)
+
+    def forward(
+        self,
+        inputs: Float[torch.Tensor, "batch variate seq_len embed_dim"],
+        id_mask: Float[torch.Tensor, "batch #variate seq_len"],
+        kv_cache: Optional[KVCache] = None,
+        variate_label_embeds: Optional[
+            Float[torch.Tensor, "batch variate 1 embed_dim"]
+        ] = None,
+    ) -> Float[torch.Tensor, "batch variate seq_len embed_dim"]:
+        """
+        Run the layer stack.
+
+        Time-wise layers get no explicit mask. Space-wise layers get the
+        bias from ``_get_mask``. ``kv_cache`` is passed to every layer.
+        """
+        if self.fusion is not None and variate_label_embeds is not None:
+            # Fuse only when there is no cache or it is still empty (first step).
+            should_apply_fusion = kv_cache is None or kv_cache.current_len(0) == 0
+            if should_apply_fusion:
+                inputs = self.fusion(inputs, variate_label_embeds=variate_label_embeds)
+
+        seq_len = inputs.shape[2]
+
+        # If inputs are now longer than id_mask (e.g. after fusion), left-pad
+        # id_mask by repeating its first time step so the lengths agree.
+        if id_mask is not None and id_mask.shape[-1] != seq_len:
+            added = int(seq_len - id_mask.shape[-1])
+            if added > 0:
+                pad_slice = id_mask[..., :1]
+                id_mask = torch.cat([pad_slice.expand(-1, -1, added), id_mask], dim=-1)
+
+        num_heads: int = cast(int, self.layers[0].num_heads)
+
+        timewise_attention_mask = None
+
+        spacewise_attention_mask = self._get_mask(
+            num_heads=num_heads,
+            dtype=inputs.dtype,
+            id_mask=id_mask,
+        )
+
+        for layer_idx, layer in enumerate(self.layers):
+            inputs = layer(
+                layer_idx,
+                inputs,
+                (
+                    timewise_attention_mask
+                    if layer.attention_axis == AttentionAxis.TIME
+                    else spacewise_attention_mask
+                ),
+                kv_cache,
+            )
+        return inputs
diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/util.py b/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/util.py
new file mode 100644
index 00000000000..d913329e7e8
--- /dev/null
+++ b/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/util.py
@@ -0,0 +1,251 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+# This file includes code derived from DataDog/toto
+# (https://github.com/DataDog/toto), licensed under the Apache-2.0 License.
+# Copyright 2025 Datadog, Inc.
+
+import warnings
+from dataclasses import dataclass, field
+from typing import TYPE_CHECKING, List, Optional, TypeAlias, Union
+
+import torch
+from einops import rearrange
+from jaxtyping import Float, Int
+
+from .attention import TimeWiseMultiheadAttention
+
+if TYPE_CHECKING:
+    from .transformer import TransformerLayer
+
+try:
+    from xformers import _is_triton_available
+    from xformers.ops.rmsnorm import rms_norm, rms_norm_add
+except ImportError:
+    warnings.warn(
+        "xFormers fused RMSNorm implementation not available. Will not use "
+        "optimized kernel for inference.",
+        ImportWarning,
+    )
+
+    def _is_triton_available():
+        # Stand-in so RMSNorm can query Triton availability unconditionally.
+        return False
+
+    XFORMERS_RMSNORM_AVAILABLE = False
+else:
+    XFORMERS_RMSNORM_AVAILABLE = True
+
+
+class RMSNorm(torch.nn.Module):
+    """
+    Root-mean-square normalisation over the last dimension, with an optional
+    learnable ``scale``.
+
+    Uses the xFormers fused kernels when the module is in eval mode (or its
+    scale is frozen) and xFormers plus Triton are available. Otherwise it
+    falls back to a pure-PyTorch implementation.
+    """
+
+    def __init__(self, dim: int, include_weight: bool = True, eps: float = 1e-8):
+        super().__init__()
+        self.eps = eps
+        if include_weight:
+            self.scale: Optional[torch.nn.Parameter] = torch.nn.Parameter(
+                torch.ones(dim)
+            )
+        else:
+            self.scale = None
+
+    def _can_use_fused_kernel(self) -> bool:
+        """True when the xFormers RMSNorm kernels may be used for this call."""
+        frozen = (not self.training) or (
+            self.scale is not None and not self.scale.requires_grad
+        )
+        return frozen and XFORMERS_RMSNORM_AVAILABLE and _is_triton_available()
+
+    def forward(self, x: torch.Tensor):
+        """Return ``x / sqrt(mean(x**2, -1) + eps)``, multiplied by ``scale`` if set."""
+        if self._can_use_fused_kernel():
+            return rms_norm(x, self.scale, self.eps)
+
+        x_normed = x / torch.sqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps)
+        return x_normed if self.scale is None else x_normed * self.scale
+
+    def increment_and_forward_(self, x: torch.Tensor, y: torch.Tensor):
+        """
+        Normalise ``x + y``.
+
+        The fused path delegates to xFormers ``rms_norm_add``
+        (NOTE(review): presumably updates ``x`` in place — confirm). The
+        fallback computes ``self.forward(x + y)`` and leaves ``x`` unchanged.
+        """
+        # Also check availability: rms_norm_add is only defined when the
+        # xFormers import succeeded.
+        if self._can_use_fused_kernel():
+            return rms_norm_add(x, y, self.scale, self.eps)
+        return self.forward(x + y)
+
+
+def make_batched_block_mask(t: torch.Tensor) -> torch.Tensor:
+    """
+    Build a pairwise equality mask over the last dimension of ``t``.
+
+    For input shape ``(..., d)`` the result has shape ``(..., d, d)``, and
+    entry ``[..., i, j]`` is True iff ``t[..., i] == t[..., j]``.
+    """
+    row = t.unsqueeze(-2)
+    column = row.transpose(-1, -2)
+    return row == column
+
+
+# Shape aliases for cached attention keys/values. NOTE(review): these spell
+# the (num_heads, seq_len) ordering; KVCache also stores
+# (batch*variates, seq_len, num_heads, head_dim) when
+# use_memory_efficient_attention is set.
+K: TypeAlias = Float[
+    torch.Tensor, "batch_size_X_num_variates num_heads seq_len head_dim"
+]
+V: TypeAlias = Float[
+    torch.Tensor, "batch_size_X_num_variates num_heads seq_len head_dim"
+]
+# A (keys, values) pair as stored/returned by KVCache.
+KV: TypeAlias = tuple[K, V]
+
+
+@dataclass
+class KVCache:
+    """
+    Key/Value cache for storing intermediate attention values during multistep inference.
+    Only stores KV cache for timewise layers, skipping spacewise layers.
+
+    Per-slot buffer layout depends on ``use_memory_efficient_attention``:
+    ``(batch_size * num_variates, max_seq_len, num_heads, head_dim)`` when
+    true, ``(batch_size * num_variates, num_heads, max_seq_len, head_dim)``
+    otherwise. Callers address the cache by transformer ``layer_idx``.
+    Indices of non-time-wise layers map to slot 0.
+    """
+
+    batch_size: int
+    num_variates: int
+    transformer_layers: List["TransformerLayer"]
+    num_layers: int
+    embed_dim: int
+    num_heads: int
+    max_seq_len: int
+    device: torch.device = torch.device("cpu")
+    dtype: torch.dtype = torch.float32
+    use_memory_efficient_attention: bool = True
+
+    _keys: Union[
+        Float[
+            torch.Tensor,
+            "time_layer_count batch_size_X_num_variates max_seq_len num_heads head_dim",
+        ],
+        Float[
+            torch.Tensor,
+            "time_layer_count batch_size_X_num_variates num_heads max_seq_len head_dim",
+        ],
+    ] = field(init=False)
+
+    _values: Union[
+        Float[
+            torch.Tensor,
+            "time_layer_count batch_size_X_num_variates max_seq_len num_heads head_dim",
+        ],
+        Float[
+            torch.Tensor,
+            "time_layer_count batch_size_X_num_variates num_heads max_seq_len head_dim",
+        ],
+    ] = field(init=False)
+
+    _current_idx: Int[torch.Tensor, "time_layer_count"] = field(init=False)
+    _layer_cache_map: Int[torch.Tensor, "num_layers"] = field(init=False)
+
+    def __post_init__(self):
+        """Allocate zeroed buffers and the layer -> cache-slot map."""
+        assert (
+            self.embed_dim % self.num_heads == 0
+        ), "embed_dim must be divisible by num_heads"
+        head_dim = self.embed_dim // self.num_heads
+
+        # Layer indices whose attention is time-wise; only these own a slot.
+        time_layer_indices = [
+            i
+            for i in range(self.num_layers)
+            if isinstance(
+                self.transformer_layers[i].attention, TimeWiseMultiheadAttention
+            )
+        ]
+
+        # Always allocate at least one slot.
+        time_layer_count = max(1, len(time_layer_indices))
+        if self.use_memory_efficient_attention:
+            shape = (
+                time_layer_count,
+                self.batch_size * self.num_variates,
+                self.max_seq_len,
+                self.num_heads,
+                head_dim,
+            )
+        else:
+            shape = (
+                time_layer_count,
+                self.batch_size * self.num_variates,
+                self.num_heads,
+                self.max_seq_len,
+                head_dim,
+            )
+        self._keys = torch.zeros(shape, device=self.device, dtype=self.dtype)
+        self._values = torch.zeros_like(self._keys)
+        # Number of filled sequence positions, per slot.
+        self._current_idx = torch.zeros(
+            time_layer_count, device=self.device, dtype=torch.int
+        )
+        # layer_idx -> slot; entries for non-time-wise layers stay 0.
+        self._layer_cache_map = torch.zeros(
+            (self.num_layers,), dtype=torch.int, device=self.device
+        )
+        for cache_idx, layer_idx in enumerate(time_layer_indices):
+            self._layer_cache_map[layer_idx] = cache_idx
+
+    def __getitem__(self, layer_idx: int) -> KV:
+        """Return views of the filled (keys, values) prefix for ``layer_idx``."""
+        cache_idx = int(self._layer_cache_map[layer_idx].item())
+        end_idx = self.current_len(cache_idx)
+
+        if self.use_memory_efficient_attention:
+            return (
+                self._keys[cache_idx, :, :end_idx, :, :],
+                self._values[cache_idx, :, :end_idx, :, :],
+            )
+        return (
+            self._keys[cache_idx, :, :, :end_idx, :],
+            self._values[cache_idx, :, :, :end_idx, :],
+        )
+
+    def current_len(self, cache_idx: int) -> int:
+        """Number of positions filled in slot ``cache_idx``."""
+        return (
+            int(self._current_idx[cache_idx].item())
+            if self._current_idx.numel() > 0
+            else 0
+        )
+
+    def seq_len(self, layer_idx: int) -> int:
+        """Number of positions filled for transformer layer ``layer_idx``."""
+        cache_idx = int(self._layer_cache_map[layer_idx].item())
+        return self.current_len(cache_idx)
+
+    def append(self, layer_idx: int, kv: KV):
+        """
+        Write ``kv`` after the filled prefix of ``layer_idx``'s slot and advance it.
+
+        ``kv`` must use the same layout as the buffers (see class docstring).
+        """
+        cache_idx = int(self._layer_cache_map[layer_idx].item())
+        keys, values = kv
+
+        assert keys.shape == values.shape, "keys and values must have the same shape"
+        assert (
+            keys.shape[0] == self.batch_size * self.num_variates
+        ), "keys and values must have batch_size * num_variates as their first dimension"
+
+        if self.use_memory_efficient_attention:
+            assert keys.shape[2] == self.num_heads
+            new_len = keys.shape[1]
+        else:
+            assert keys.shape[1] == self.num_heads
+            new_len = keys.shape[2]
+        assert keys.shape[3] == self.embed_dim // self.num_heads
+
+        # Plain ints (as in __getitem__) rather than 0-d tensors for slicing.
+        start_idx = self.current_len(cache_idx)
+        end_idx = start_idx + new_len
+        assert (
+            end_idx <= self.max_seq_len
+        ), f"max_seq_len exceeded {end_idx} > {self.max_seq_len}, keys.shape: {keys.shape}"
+
+        if self.use_memory_efficient_attention:
+            self._keys[cache_idx, :, start_idx:end_idx, :, :] = keys
+            self._values[cache_idx, :, start_idx:end_idx, :, :] = values
+        else:
+            self._keys[cache_idx, :, :, start_idx:end_idx, :] = keys
+            self._values[cache_idx, :, :, start_idx:end_idx, :] = values
+
+        self._current_idx[cache_idx] = end_idx
+
+    def reset(self):
+        """Zero all buffers and fill counters in place."""
+        self._keys.zero_()
+        self._values.zero_()
+        self._current_idx.zero_()
diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/toto/modeling_toto.py b/iotdb-core/ainode/iotdb/ainode/core/model/toto/modeling_toto.py
new file mode 100644
index 00000000000..08fda1c3c72
--- /dev/null
+++ b/iotdb-core/ainode/iotdb/ainode/core/model/toto/modeling_toto.py
@@ -0,0 +1,167 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+import json
+import os
+
+import safetensors.torch as safetorch
+from transformers import PreTrainedModel
+
+from iotdb.ainode.core.log import Logger
+
+from .configuration_toto import TotoConfig
+from .model.attention import XFORMERS_AVAILABLE
+from .model.backbone import TotoBackbone
+from .model.toto import Toto
+from .model.transformer import XFORMERS_SWIGLU_AVAILABLE
+
+# Module-level AINode logger used by TotoForPrediction.from_pretrained.
+logger = Logger()
+
+
+class TotoPreTrainedModel(PreTrainedModel):
+    """
+    Shared base for Toto models built on the ``transformers`` PreTrainedModel API.
+
+    Binds ``TotoConfig`` as the configuration class and disables gradient
+    checkpointing. ``base_model_prefix`` is ``"model"``, matching the
+    attribute that TotoForPrediction stores its backbone in.
+    """
+
+    base_model_prefix = "model"
+    config_class = TotoConfig
+    supports_gradient_checkpointing = False
+
+    def _init_weights(self, module):
+        """No-op: every weight is loaded from the pretrained checkpoint."""
+        return None
+
+
+class TotoForPrediction(TotoPreTrainedModel):
+    """
+    Toto (Timeseries-Optimized Transformer for Observability) model for time series prediction.
+
+    Integrates the Toto backbone with AINode's model loading mechanism using the
+    transformers PreTrainedModel interface. Weights are loaded directly from the
+    Datadog/Toto-Open-Base-1.0 safetensors checkpoint.
+
+    The backbone is stored as ``self.model`` so that safetensors key prefixes
+    (``model.*``) map directly to parameters. The only renaming is the SwiGLU
+    key remapping done in ``from_pretrained``.
+
+    Reference: https://huggingface.co/Datadog/Toto-Open-Base-1.0
+    """
+
+    def __init__(self, config: TotoConfig):
+        super().__init__(config)
+        # Backbone stored as self.model so safetensors keys (model.*) match directly.
+        self.model = TotoBackbone(
+            patch_size=config.patch_size,
+            stride=config.stride,
+            embed_dim=config.embed_dim,
+            num_layers=config.num_layers,
+            num_heads=config.num_heads,
+            mlp_hidden_dim=config.mlp_hidden_dim,
+            dropout=config.dropout,
+            spacewise_every_n_layers=config.spacewise_every_n_layers,
+            scaler_cls=config.scaler_cls,
+            output_distribution_classes=config.output_distribution_classes,
+            output_distribution_kwargs=config.output_distribution_kwargs,
+            spacewise_first=config.spacewise_first,
+            use_memory_efficient_attention=config.use_memory_efficient_attention,
+            stabilize_with_global=config.stabilize_with_global,
+            scale_factor_exponent=config.scale_factor_exponent,
+        )
+        self.post_init()
+
+    @classmethod
+    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
+        """
+        Load TotoForPrediction from a local directory containing ``config.json``
+        and ``model.safetensors``.
+
+        This override is required because:
+        1. The safetensors file uses legacy SwiGLU key names that need remapping.
+        2. The config uses class-path strings for ``scaler_cls`` and
+           ``output_distribution_classes`` that must not be filtered out.
+
+        Args:
+            pretrained_model_name_or_path (str): Path to a local directory.
+            **kwargs: Extra key/value pairs merged into the config before construction.
+
+        Returns:
+            TotoForPrediction: Fully initialised and weight-loaded model in eval mode.
+
+        Raises:
+            ValueError: If the path is not a local directory.
+            FileNotFoundError: If ``model.safetensors`` is missing.
+        """
+        if not os.path.isdir(pretrained_model_name_or_path):
+            raise ValueError(
+                f"pretrained_model_name_or_path must be a local directory, "
+                f"got: {pretrained_model_name_or_path}"
+            )
+        config_file = os.path.join(pretrained_model_name_or_path, "config.json")
+        safetensors_file = os.path.join(
+            pretrained_model_name_or_path, "model.safetensors"
+        )
+
+        # Load config.
+        config_dict: dict = {}
+        if os.path.exists(config_file):
+            with open(config_file, "r", encoding="utf-8") as f:
+                config_dict = json.load(f)
+        else:
+            logger.warning(
+                f"config.json not found in {pretrained_model_name_or_path}; "
+                f"falling back to TotoConfig defaults."
+            )
+        config_dict.update(kwargs)
+
+        # Disable xFormers memory-efficient attention if the library is absent.
+        if not XFORMERS_AVAILABLE and config_dict.get(
+            "use_memory_efficient_attention", True
+        ):
+            config_dict["use_memory_efficient_attention"] = False
+
+        config = TotoConfig(**config_dict)
+
+        # Instantiate model.
+        instance = cls(config)
+
+        # Load safetensors weights.
+        if not os.path.exists(safetensors_file):
+            raise FileNotFoundError(
+                f"Model checkpoint not found at: {safetensors_file}"
+            )
+
+        state_dict = safetorch.load_file(safetensors_file, device="cpu")
+
+        # Remap SwiGLU weight names if the fused xFormers kernel is available.
+        use_fused_swiglu = XFORMERS_SWIGLU_AVAILABLE and not config_dict.get(
+            "pre_xformers_checkpoint", False
+        )
+        state_dict = Toto._map_state_dict_keys(state_dict, use_fused_swiglu)
+
+        # Filter to keys that exist in the model, skipping cached rotary buffers.
+        model_state = instance.state_dict()
+        filtered_state_dict = {
+            k: v
+            for k, v in state_dict.items()
+            if k in model_state and not k.endswith("rotary_emb.freqs")
+        }
+
+        load_result = instance.load_state_dict(filtered_state_dict, strict=False)
+        # strict=False is needed for the skipped rotary buffers, but any other
+        # missing weight would otherwise go unnoticed and keep its init value.
+        missing = [
+            k for k in load_result.missing_keys if not k.endswith("rotary_emb.freqs")
+        ]
+        if missing:
+            logger.warning(
+                f"Toto checkpoint {safetensors_file} is missing {len(missing)} "
+                f"parameter(s), e.g. {missing[:5]}"
+            )
+        instance.eval()
+
+        logger.info(f"Loaded Toto model from {pretrained_model_name_or_path}")
+        return instance
+
+    @property
+    def backbone(self):
+        """The underlying ``TotoBackbone`` used for inference."""
+        return self.model
+
+    @property
+    def device(self):
+        """Device on which model parameters reside."""
+        return next(self.parameters()).device
diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/toto/pipeline_toto.py b/iotdb-core/ainode/iotdb/ainode/core/model/toto/pipeline_toto.py
new file mode 100644
index 00000000000..c6778a5e90b
--- /dev/null
+++ b/iotdb-core/ainode/iotdb/ainode/core/model/toto/pipeline_toto.py
@@ -0,0 +1,144 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+import torch
+
+from iotdb.ainode.core.inference.pipeline.basic_pipeline import 
ForecastPipeline
+from iotdb.ainode.core.log import Logger
+from iotdb.ainode.core.model.toto.data.util.dataset import MaskedTimeseries
+from iotdb.ainode.core.model.toto.inference.forecaster import TotoForecaster
+
+# Module-level AINode logger used by TotoPipeline.
+logger = Logger()
+
+
+class TotoPipeline(ForecastPipeline):
+    """
+    Inference pipeline for the Toto time series foundation model.
+
+    Converts raw input tensors into ``MaskedTimeseries`` objects and delegates
+    autoregressive decoding to ``TotoForecaster``. The forecaster is created
+    lazily on the first call to ``forecast()``, so pipeline construction
+    does not require a live model (useful during import / registration time).
+    """
+
+    def __init__(self, model_info, **model_kwargs):
+        super().__init__(model_info, **model_kwargs)
+        # Forecaster is created lazily to avoid issues at construction time.
+        self._forecaster: TotoForecaster | None = None
+
+    def _get_forecaster(self) -> TotoForecaster:
+        """Return the cached forecaster, creating it on first call."""
+        if self._forecaster is None:
+            self._forecaster = TotoForecaster(self.model.backbone)
+        return self._forecaster
+
+    def _preprocess(self, inputs, **infer_kwargs):
+        """
+        Preprocess input data for Toto.
+
+        Converts each input dict into a ``MaskedTimeseries`` named-tuple that
+        the ``TotoForecaster`` expects. NaNs are marked as padding and replaced
+        by 0. Every variate gets id 0, integer timestamps ``0..len-1`` and a
+        1-second interval.
+
+        Parameters
+        ----------
+        inputs : list of dict
+            A list of dictionaries containing input data. Each dictionary contains:
+            - 'targets': A tensor (1D or 2D) of shape (input_length,) or (target_count, input_length).
+
+        infer_kwargs: Additional keyword arguments for inference, such as:
+            - `output_length`(int): Prediction length.
+
+        Returns
+        -------
+        list of MaskedTimeseries
+            Processed inputs compatible with Toto's forecaster.
+        """
+        processed_inputs = []
+        for item in inputs:
+            targets = item["targets"]
+            if targets.ndim == 1:
+                targets = targets.unsqueeze(0)
+            # Promote integer/bool series to float32 so NaN handling and the
+            # model operate on a floating tensor; float inputs are unchanged.
+            if not targets.is_floating_point():
+                targets = targets.float()
+
+            n_variates, series_len = targets.shape
+            device = targets.device
+
+            if "past_covariates" in item or "future_covariates" in item:
+                logger.warning(
+                    "TotoPipeline does not support covariates; they will be ignored."
+                )
+
+            padding_mask = ~torch.isnan(targets)
+            targets = targets.nan_to_num(0.0)
+
+            id_mask = torch.zeros(
+                n_variates, series_len, dtype=torch.long, device=device
+            )
+            timestamp_seconds = (
+                torch.arange(series_len, dtype=torch.long, device=device)
+                .unsqueeze(0)
+                .expand(n_variates, series_len)
+            )
+            time_interval_seconds = torch.ones(
+                n_variates, dtype=torch.long, device=device
+            )
+
+            processed_inputs.append(
+                MaskedTimeseries(
+                    series=targets,
+                    padding_mask=padding_mask,
+                    id_mask=id_mask,
+                    timestamp_seconds=timestamp_seconds,
+                    time_interval_seconds=time_interval_seconds,
+                )
+            )
+
+        return processed_inputs
+
+    def forecast(self, inputs, **infer_kwargs) -> list[torch.Tensor]:
+        """
+        Forecast each preprocessed ``MaskedTimeseries``.
+
+        infer_kwargs:
+            output_length (int, default 96): number of future steps.
+            num_samples (default None) and samples_per_batch (default 10):
+                forwarded to ``TotoForecaster.forecast``.
+
+        Returns one tensor per input: the forecast mean. If the mean has a
+        leading dimension of size 1 (3-D result), that dimension is removed.
+        """
+        output_length = infer_kwargs.get("output_length", 96)
+        num_samples = infer_kwargs.get("num_samples", None)
+        samples_per_batch = infer_kwargs.get("samples_per_batch", 10)
+
+        forecaster = self._get_forecaster()
+        # Resolve the target device once instead of once per tensor.
+        device = self.model.device
+
+        outputs = []
+        for masked_ts in inputs:
+            masked_ts = masked_ts._replace(
+                series=masked_ts.series.to(device),
+                padding_mask=masked_ts.padding_mask.to(device),
+                id_mask=masked_ts.id_mask.to(device),
+                timestamp_seconds=masked_ts.timestamp_seconds.to(device),
+                time_interval_seconds=masked_ts.time_interval_seconds.to(device),
+            )
+            result = forecaster.forecast(
+                masked_ts,
+                prediction_length=output_length,
+                num_samples=num_samples,
+                samples_per_batch=samples_per_batch,
+            )
+            mean = result.mean
+            # Remove batch dimension if present (batch=1 squeeze).
+            if mean.ndim == 3 and mean.shape[0] == 1:
+                mean = mean.squeeze(0)
+            outputs.append(mean)
+        return outputs
+
+    def _postprocess(self, outputs, **infer_kwargs) -> list[torch.Tensor]:
+        """Identity: ``forecast`` already returns the final tensors."""
+        return outputs
diff --git a/iotdb-core/ainode/pyproject.toml b/iotdb-core/ainode/pyproject.toml
index 9a142fe7259..0dc630fb0ff 100644
--- a/iotdb-core/ainode/pyproject.toml
+++ b/iotdb-core/ainode/pyproject.toml
@@ -117,6 +117,7 @@ setuptools = ">=75.3.0"
 joblib = ">=1.4.2"
 urllib3 = "2.6.3"
 jaxtyping = ">=0.2.24"
+rotary-embedding-torch = ">=0.8.0"
 
 [tool.poetry.scripts]
 ainode = "iotdb.ainode.core.script:main"

Reply via email to