This is an automated email from the ASF dual-hosted git repository.

yongzao pushed a commit to branch model-management
in repository https://gitbox.apache.org/repos/asf/iotdb.git

commit 22a8589d7e5c8c6d7322684265dd96bac5ba4fea
Author: RkGrit <[email protected]>
AuthorDate: Mon Nov 10 17:20:30 2025 +0800

    refactor_built_in_models
---
 .../iotdb/ainode/core/model/model_factory.py       |  60 ++
 .../iotdb/ainode/core/model/model_storage.py       |   7 +-
 .../iotdb/ainode/core/model/sktime/__init__.py     |  17 +
 .../configuration_sktime.py}                       | 732 ++++++---------------
 .../ainode/core/model/sktime/modeling_sktime.py    | 261 ++++++++
 .../core/model/sundial/configuration_sundial.py    |   2 +
 6 files changed, 526 insertions(+), 553 deletions(-)

diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/model_factory.py b/iotdb-core/ainode/iotdb/ainode/core/model/model_factory.py
index 26d863156f3..ceedf11b4e3 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/model/model_factory.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/model/model_factory.py
@@ -21,6 +21,7 @@ import shutil
 from urllib.parse import urljoin
 
 import yaml
+from huggingface_hub import hf_hub_download
 
 from iotdb.ainode.core.constant import (
     MODEL_CONFIG_FILE_IN_YAML,
@@ -34,12 +35,71 @@ from iotdb.ainode.core.model.uri_utils import (
     download_file,
     download_snapshot_from_hf,
 )
+from iotdb.ainode.core.model.model_enums import BuiltInModelType
 from iotdb.ainode.core.util.serde import get_data_type_byte_from_str
 from iotdb.thrift.ainode.ttypes import TConfigs
+from iotdb.ainode.core.model.model_info import TIMER_REPO_ID
+from iotdb.ainode.core.constant import (
+    MODEL_CONFIG_FILE_IN_JSON,
+    MODEL_WEIGHTS_FILE_IN_SAFETENSORS,
+)
 
 logger = Logger()
 
 
+def _download_file_from_hf_if_necessary(local_dir: str, repo_id: str) -> bool:
+    weights_path = os.path.join(local_dir, MODEL_WEIGHTS_FILE_IN_SAFETENSORS)
+    config_path = os.path.join(local_dir, MODEL_CONFIG_FILE_IN_JSON)
+    if not os.path.exists(weights_path):
+        logger.info(
+            f"Model weights file not found at {weights_path}, downloading from HuggingFace..."
+        )
+        try:
+            hf_hub_download(
+                repo_id=repo_id,
+                filename=MODEL_WEIGHTS_FILE_IN_SAFETENSORS,
+                local_dir=local_dir,
+            )
+            logger.info(f"Downloaded model weights file to {weights_path}")
+        except Exception as e:
+            logger.error(
+                f"Failed to download model weights file to {local_dir} due to {e}"
+            )
+            return False
+    if not os.path.exists(config_path):
+        logger.info(
+            f"Model config file not found at {config_path}, downloading from HuggingFace..."
+        )
+        try:
+            hf_hub_download(
+                repo_id=repo_id,
+                filename=MODEL_CONFIG_FILE_IN_JSON,
+                local_dir=local_dir,
+            )
+            logger.info(f"Downloaded model config file to {config_path}")
+        except Exception as e:
+            logger.error(
+                f"Failed to download model config file to {local_dir} due to {e}"
+            )
+            return False
+    return True
+
+
+def download_built_in_ltsm_from_hf_if_necessary(
+        model_type: BuiltInModelType, local_dir: str
+) -> bool:
+    """
+    Download the built-in LTSM from the HuggingFace repository when necessary.
+
+    Returns:
+        bool: True if the model already exists or was downloaded successfully, False otherwise.
+    """
+    repo_id = TIMER_REPO_ID[model_type]
+    if not _download_file_from_hf_if_necessary(local_dir, repo_id):
+        return False
+    return True
+
+
 def fetch_model_by_uri(
     uri_type: UriType, uri: str, storage_path: str, model_file_type: ModelFileType
 ):
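
For context, a minimal sketch of how the relocated download helper might be called (illustrative only; the local directory below is a placeholder, not a path taken from this commit):

from iotdb.ainode.core.model.model_enums import BuiltInModelType
from iotdb.ainode.core.model.model_factory import download_built_in_ltsm_from_hf_if_necessary

# Fetches the weights/config files (MODEL_WEIGHTS_FILE_IN_SAFETENSORS /
# MODEL_CONFIG_FILE_IN_JSON) from the model's HuggingFace repo only when they
# are missing from local_dir; returns False if any download fails.
ok = download_built_in_ltsm_from_hf_if_necessary(
    model_type=BuiltInModelType.TIMER_XL,
    local_dir="/path/to/ainode/models/weights/timerxl",  # placeholder path
)
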
diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/model_storage.py b/iotdb-core/ainode/iotdb/ainode/core/model/model_storage.py
index e346f569102..1c63b56e519 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/model/model_storage.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/model/model_storage.py
@@ -38,17 +38,14 @@ from iotdb.ainode.core.exception import (
     UnsupportedError,
 )
 from iotdb.ainode.core.log import Logger
-from iotdb.ainode.core.model.built_in_model_factory import (
-    download_built_in_ltsm_from_hf_if_necessary,
-    fetch_built_in_model,
-)
+from iotdb.ainode.core.model.sktime.modeling_sktime import fetch_built_in_model
 from iotdb.ainode.core.model.model_enums import (
     BuiltInModelType,
     ModelCategory,
     ModelFileType,
     ModelStates,
 )
-from iotdb.ainode.core.model.model_factory import fetch_model_by_uri
+from iotdb.ainode.core.model.model_factory import fetch_model_by_uri, download_built_in_ltsm_from_hf_if_necessary
 from iotdb.ainode.core.model.model_info import (
     BUILT_IN_LTSM_MAP,
     BUILT_IN_MACHINE_LEARNING_MODEL_MAP,
diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/sktime/__init__.py b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/__init__.py
new file mode 100644
index 00000000000..2a1e720805f
--- /dev/null
+++ b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/__init__.py
@@ -0,0 +1,17 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/built_in_model_factory.py b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/configuration_sktime.py
similarity index 56%
rename from iotdb-core/ainode/iotdb/ainode/core/model/built_in_model_factory.py
rename to iotdb-core/ainode/iotdb/ainode/core/model/sktime/configuration_sktime.py
index 3b55142350b..18fea61b6ff 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/model/built_in_model_factory.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/configuration_sktime.py
@@ -15,29 +15,12 @@
 # specific language governing permissions and limitations
 # under the License.
 #
-import os
+
 from abc import abstractmethod
 from typing import Callable, Dict, List
-
-import numpy as np
-from huggingface_hub import hf_hub_download
-from sklearn.preprocessing import MinMaxScaler
-from sktime.detection.hmm_learn import GMMHMM, GaussianHMM
-from sktime.detection.stray import STRAY
-from sktime.forecasting.arima import ARIMA
-from sktime.forecasting.exp_smoothing import ExponentialSmoothing
-from sktime.forecasting.naive import NaiveForecaster
-from sktime.forecasting.trend import STLForecaster
-
-from iotdb.ainode.core.config import AINodeDescriptor
-from iotdb.ainode.core.constant import (
-    MODEL_CONFIG_FILE_IN_JSON,
-    MODEL_WEIGHTS_FILE_IN_SAFETENSORS,
-    AttributeName,
-)
+from enum import Enum
 from iotdb.ainode.core.exception import (
     BuiltInModelNotSupportError,
-    InferenceModelInternalError,
     ListRangeException,
     NumericalRangeException,
     StringRangeException,
@@ -45,134 +28,119 @@ from iotdb.ainode.core.exception import (
 )
 from iotdb.ainode.core.log import Logger
 from iotdb.ainode.core.model.model_enums import BuiltInModelType
-from iotdb.ainode.core.model.model_info import TIMER_REPO_ID
-from iotdb.ainode.core.model.sundial import modeling_sundial
-from iotdb.ainode.core.model.timerxl import modeling_timer
 
 logger = Logger()
 
 
-def _download_file_from_hf_if_necessary(local_dir: str, repo_id: str) -> bool:
-    weights_path = os.path.join(local_dir, MODEL_WEIGHTS_FILE_IN_SAFETENSORS)
-    config_path = os.path.join(local_dir, MODEL_CONFIG_FILE_IN_JSON)
-    if not os.path.exists(weights_path):
-        logger.info(
-            f"Model weights file not found at {weights_path}, downloading from HuggingFace..."
-        )
-        try:
-            hf_hub_download(
-                repo_id=repo_id,
-                filename=MODEL_WEIGHTS_FILE_IN_SAFETENSORS,
-                local_dir=local_dir,
-            )
-            logger.info(f"Got file to {weights_path}")
-        except Exception as e:
-            logger.error(
-                f"Failed to download model weights file to {local_dir} due to {e}"
-            )
-            return False
-    if not os.path.exists(config_path):
-        logger.info(
-            f"Model config file not found at {config_path}, downloading from HuggingFace..."
-        )
-        try:
-            hf_hub_download(
-                repo_id=repo_id,
-                filename=MODEL_CONFIG_FILE_IN_JSON,
-                local_dir=local_dir,
-            )
-            logger.info(f"Got file to {config_path}")
-        except Exception as e:
-            logger.error(
-                f"Failed to download model config file to {local_dir} due to {e}"
-            )
-            return False
-    return True
-
-
-def download_built_in_ltsm_from_hf_if_necessary(
-    model_type: BuiltInModelType, local_dir: str
-) -> bool:
-    """
-    Download the built-in ltsm from HuggingFace repository when necessary.
-
-    Return:
-        bool: True if the model is existed or downloaded successfully, False otherwise.
-    """
-    repo_id = TIMER_REPO_ID[model_type]
-    if not _download_file_from_hf_if_necessary(local_dir, repo_id):
-        return False
-    return True
-
-
-def get_model_attributes(model_type: BuiltInModelType):
-    if model_type == BuiltInModelType.ARIMA:
-        attribute_map = arima_attribute_map
-    elif model_type == BuiltInModelType.NAIVE_FORECASTER:
-        attribute_map = naive_forecaster_attribute_map
-    elif (
-        model_type == BuiltInModelType.EXPONENTIAL_SMOOTHING
-        or model_type == BuiltInModelType.HOLTWINTERS
-    ):
-        attribute_map = exponential_smoothing_attribute_map
-    elif model_type == BuiltInModelType.STL_FORECASTER:
-        attribute_map = stl_forecaster_attribute_map
-    elif model_type == BuiltInModelType.GMM_HMM:
-        attribute_map = gmmhmm_attribute_map
-    elif model_type == BuiltInModelType.GAUSSIAN_HMM:
-        attribute_map = gaussian_hmm_attribute_map
-    elif model_type == BuiltInModelType.STRAY:
-        attribute_map = stray_attribute_map
-    elif model_type == BuiltInModelType.TIMER_XL:
-        attribute_map = timerxl_attribute_map
-    elif model_type == BuiltInModelType.SUNDIAL:
-        attribute_map = sundial_attribute_map
-    else:
-        raise BuiltInModelNotSupportError(model_type.value)
-    return attribute_map
-
-
-def fetch_built_in_model(
-    model_type: BuiltInModelType, model_dir, inference_attrs: Dict[str, str]
-) -> Callable:
-    """
-    Fetch the built-in model according to its id and directory, not that this directory only contains model weights and config.
-    Args:
-        model_type: the type of the built-in model
-        model_dir: for huggingface models only, the directory where the model is stored
-    Returns:
-        model: the built-in model
-    """
-    default_attributes = get_model_attributes(model_type)
-    # parse the attributes from inference_attrs
-    attributes = parse_attribute(inference_attrs, default_attributes)
-
-    # build the built-in model
-    if model_type == BuiltInModelType.ARIMA:
-        model = ArimaModel(attributes)
-    elif (
-        model_type == BuiltInModelType.EXPONENTIAL_SMOOTHING
-        or model_type == BuiltInModelType.HOLTWINTERS
-    ):
-        model = ExponentialSmoothingModel(attributes)
-    elif model_type == BuiltInModelType.NAIVE_FORECASTER:
-        model = NaiveForecasterModel(attributes)
-    elif model_type == BuiltInModelType.STL_FORECASTER:
-        model = STLForecasterModel(attributes)
-    elif model_type == BuiltInModelType.GMM_HMM:
-        model = GMMHMMModel(attributes)
-    elif model_type == BuiltInModelType.GAUSSIAN_HMM:
-        model = GaussianHmmModel(attributes)
-    elif model_type == BuiltInModelType.STRAY:
-        model = STRAYModel(attributes)
-    elif model_type == BuiltInModelType.TIMER_XL:
-        model = modeling_timer.TimerForPrediction.from_pretrained(model_dir)
-    elif model_type == BuiltInModelType.SUNDIAL:
-        model = modeling_sundial.SundialForPrediction.from_pretrained(model_dir)
-    else:
-        raise BuiltInModelNotSupportError(model_type.value)
-
-    return model
+class AttributeName(Enum):
+    # forecast Attribute
+    PREDICT_LENGTH = "predict_length"
+
+    # NaiveForecaster
+    STRATEGY = "strategy"
+    SP = "sp"
+
+    # STLForecaster
+    # SP = 'sp'
+    SEASONAL = "seasonal"
+    SEASONAL_DEG = "seasonal_deg"
+    TREND_DEG = "trend_deg"
+    LOW_PASS_DEG = "low_pass_deg"
+    SEASONAL_JUMP = "seasonal_jump"
+    TREND_JUMP = "trend_jump"
+    LOSS_PASS_JUMP = "low_pass_jump"
+
+    # ExponentialSmoothing
+    DAMPED_TREND = "damped_trend"
+    INITIALIZATION_METHOD = "initialization_method"
+    OPTIMIZED = "optimized"
+    REMOVE_BIAS = "remove_bias"
+    USE_BRUTE = "use_brute"
+
+    # Arima
+    ORDER = "order"
+    SEASONAL_ORDER = "seasonal_order"
+    METHOD = "method"
+    MAXITER = "maxiter"
+    SUPPRESS_WARNINGS = "suppress_warnings"
+    OUT_OF_SAMPLE_SIZE = "out_of_sample_size"
+    SCORING = "scoring"
+    WITH_INTERCEPT = "with_intercept"
+    TIME_VARYING_REGRESSION = "time_varying_regression"
+    ENFORCE_STATIONARITY = "enforce_stationarity"
+    ENFORCE_INVERTIBILITY = "enforce_invertibility"
+    SIMPLE_DIFFERENCING = "simple_differencing"
+    MEASUREMENT_ERROR = "measurement_error"
+    MLE_REGRESSION = "mle_regression"
+    HAMILTON_REPRESENTATION = "hamilton_representation"
+    CONCENTRATE_SCALE = "concentrate_scale"
+
+    # GAUSSIAN_HMM
+    N_COMPONENTS = "n_components"
+    COVARIANCE_TYPE = "covariance_type"
+    MIN_COVAR = "min_covar"
+    STARTPROB_PRIOR = "startprob_prior"
+    TRANSMAT_PRIOR = "transmat_prior"
+    MEANS_PRIOR = "means_prior"
+    MEANS_WEIGHT = "means_weight"
+    COVARS_PRIOR = "covars_prior"
+    COVARS_WEIGHT = "covars_weight"
+    ALGORITHM = "algorithm"
+    N_ITER = "n_iter"
+    TOL = "tol"
+    PARAMS = "params"
+    INIT_PARAMS = "init_params"
+    IMPLEMENTATION = "implementation"
+
+    # GMMHMM
+    # N_COMPONENTS = "n_components"
+    N_MIX = "n_mix"
+    # MIN_COVAR = "min_covar"
+    # STARTPROB_PRIOR = "startprob_prior"
+    # TRANSMAT_PRIOR = "transmat_prior"
+    WEIGHTS_PRIOR = "weights_prior"
+
+    # MEANS_PRIOR = "means_prior"
+    # MEANS_WEIGHT = "means_weight"
+    # ALGORITHM = "algorithm"
+    # COVARIANCE_TYPE = "covariance_type"
+    # N_ITER = "n_iter"
+    # TOL = "tol"
+    # INIT_PARAMS = "init_params"
+    # PARAMS = "params"
+    # IMPLEMENTATION = "implementation"
+
+    # STRAY
+    ALPHA = "alpha"
+    K = "k"
+    KNN_ALGORITHM = "knn_algorithm"
+    P = "p"
+    SIZE_THRESHOLD = "size_threshold"
+    OUTLIER_TAIL = "outlier_tail"
+
+    # timerxl
+    INPUT_TOKEN_LEN = "input_token_len"
+    HIDDEN_SIZE = "hidden_size"
+    INTERMEDIATE_SIZE = "intermediate_size"
+    OUTPUT_TOKEN_LENS = "output_token_lens"
+    NUM_HIDDEN_LAYERS = "num_hidden_layers"
+    NUM_ATTENTION_HEADS = "num_attention_heads"
+    HIDDEN_ACT = "hidden_act"
+    USE_CACHE = "use_cache"
+    ROPE_THETA = "rope_theta"
+    ATTENTION_DROPOUT = "attention_dropout"
+    INITIALIZER_RANGE = "initializer_range"
+    MAX_POSITION_EMBEDDINGS = "max_position_embeddings"
+    CKPT_PATH = "ckpt_path"
+
+    # sundial
+    DROPOUT_RATE = "dropout_rate"
+    FLOW_LOSS_DEPTH = "flow_loss_depth"
+    NUM_SAMPLING_STEPS = "num_sampling_steps"
+    DIFFUSION_BATCH_MUL = "diffusion_batch_mul"
+
+    def name(self) -> str:
+        return self.value
 
 
 class Attribute(object):
@@ -198,11 +166,11 @@ class Attribute(object):
 
 class IntAttribute(Attribute):
     def __init__(
-        self,
-        name: str,
-        default_value: int,
-        default_low: int,
-        default_high: int,
+            self,
+            name: str,
+            default_value: int,
+            default_low: int,
+            default_high: int,
     ):
         super(IntAttribute, self).__init__(name)
         self.__default_value = default_value
@@ -229,11 +197,11 @@ class IntAttribute(Attribute):
 
 class FloatAttribute(Attribute):
     def __init__(
-        self,
-        name: str,
-        default_value: float,
-        default_low: float,
-        default_high: float,
+            self,
+            name: str,
+            default_value: float,
+            default_low: float,
+            default_high: float,
     ):
         super(FloatAttribute, self).__init__(name)
         self.__default_value = default_value
@@ -376,216 +344,8 @@ class TupleAttribute(Attribute):
         return tuple_value
 
 
-def parse_attribute(
-    input_attributes: Dict[str, str], attribute_map: Dict[str, Attribute]
-):
-    """
-    Args:
-        input_attributes: a dict of attributes, where the key is the attribute name, the value is the string value of
-            the attribute
-        attribute_map: a dict of hyperparameters, where the key is the attribute name, the value is the Attribute
-            object
-    Returns:
-        a dict of attributes, where the key is the attribute name, the value is the parsed value of the attribute
-    """
-    attributes = {}
-    for attribute_name in attribute_map:
-        # user specified the attribute
-        if attribute_name in input_attributes:
-            attribute = attribute_map[attribute_name]
-            value = attribute.parse(input_attributes[attribute_name])
-            attribute.validate_value(value)
-            attributes[attribute_name] = value
-        # user did not specify the attribute, use the default value
-        else:
-            try:
-                attributes[attribute_name] = attribute_map[
-                    attribute_name
-                ].get_default_value()
-            except NotImplementedError as e:
-                logger.error(f"attribute {attribute_name} is not implemented.")
-                raise e
-    return attributes
-
-
-sundial_attribute_map = {
-    AttributeName.INPUT_TOKEN_LEN.value: IntAttribute(
-        name=AttributeName.INPUT_TOKEN_LEN.value,
-        default_value=16,
-        default_low=1,
-        default_high=5000,
-    ),
-    AttributeName.HIDDEN_SIZE.value: IntAttribute(
-        name=AttributeName.HIDDEN_SIZE.value,
-        default_value=768,
-        default_low=1,
-        default_high=5000,
-    ),
-    AttributeName.INTERMEDIATE_SIZE.value: IntAttribute(
-        name=AttributeName.INTERMEDIATE_SIZE.value,
-        default_value=3072,
-        default_low=1,
-        default_high=5000,
-    ),
-    AttributeName.OUTPUT_TOKEN_LENS.value: ListAttribute(
-        name=AttributeName.OUTPUT_TOKEN_LENS.value, default_value=[720], value_type=int
-    ),
-    AttributeName.NUM_HIDDEN_LAYERS.value: IntAttribute(
-        name=AttributeName.NUM_HIDDEN_LAYERS.value,
-        default_value=12,
-        default_low=1,
-        default_high=16,
-    ),
-    AttributeName.NUM_ATTENTION_HEADS.value: IntAttribute(
-        name=AttributeName.NUM_ATTENTION_HEADS.value,
-        default_value=12,
-        default_low=1,
-        default_high=192,
-    ),
-    AttributeName.HIDDEN_ACT.value: StringAttribute(
-        name=AttributeName.HIDDEN_ACT.value,
-        default_value="silu",
-        value_choices=["relu", "gelu", "silu", "tanh"],
-    ),
-    AttributeName.USE_CACHE.value: BooleanAttribute(
-        name=AttributeName.USE_CACHE.value,
-        default_value=True,
-    ),
-    AttributeName.ROPE_THETA.value: IntAttribute(
-        name=AttributeName.ROPE_THETA.value,
-        default_value=10000,
-        default_low=1000,
-        default_high=50000,
-    ),
-    AttributeName.DROPOUT_RATE.value: FloatAttribute(
-        name=AttributeName.DROPOUT_RATE.value,
-        default_value=0.1,
-        default_low=0.0,
-        default_high=1.0,
-    ),
-    AttributeName.INITIALIZER_RANGE.value: FloatAttribute(
-        name=AttributeName.INITIALIZER_RANGE.value,
-        default_value=0.02,
-        default_low=0.0,
-        default_high=1.0,
-    ),
-    AttributeName.MAX_POSITION_EMBEDDINGS.value: IntAttribute(
-        name=AttributeName.MAX_POSITION_EMBEDDINGS.value,
-        default_value=10000,
-        default_low=1,
-        default_high=50000,
-    ),
-    AttributeName.FLOW_LOSS_DEPTH.value: IntAttribute(
-        name=AttributeName.FLOW_LOSS_DEPTH.value,
-        default_value=3,
-        default_low=1,
-        default_high=50,
-    ),
-    AttributeName.NUM_SAMPLING_STEPS.value: IntAttribute(
-        name=AttributeName.NUM_SAMPLING_STEPS.value,
-        default_value=50,
-        default_low=1,
-        default_high=5000,
-    ),
-    AttributeName.DIFFUSION_BATCH_MUL.value: IntAttribute(
-        name=AttributeName.DIFFUSION_BATCH_MUL.value,
-        default_value=4,
-        default_low=1,
-        default_high=5000,
-    ),
-    AttributeName.CKPT_PATH.value: StringAttribute(
-        name=AttributeName.CKPT_PATH.value,
-        default_value=os.path.join(
-            os.getcwd(),
-            AINodeDescriptor().get_config().get_ain_models_dir(),
-            "weights",
-            "sundial",
-        ),
-        value_choices=[""],
-    ),
-}
-
-timerxl_attribute_map = {
-    AttributeName.INPUT_TOKEN_LEN.value: IntAttribute(
-        name=AttributeName.INPUT_TOKEN_LEN.value,
-        default_value=96,
-        default_low=1,
-        default_high=5000,
-    ),
-    AttributeName.HIDDEN_SIZE.value: IntAttribute(
-        name=AttributeName.HIDDEN_SIZE.value,
-        default_value=1024,
-        default_low=1,
-        default_high=5000,
-    ),
-    AttributeName.INTERMEDIATE_SIZE.value: IntAttribute(
-        name=AttributeName.INTERMEDIATE_SIZE.value,
-        default_value=2048,
-        default_low=1,
-        default_high=5000,
-    ),
-    AttributeName.OUTPUT_TOKEN_LENS.value: ListAttribute(
-        name=AttributeName.OUTPUT_TOKEN_LENS.value, default_value=[96], value_type=int
-    ),
-    AttributeName.NUM_HIDDEN_LAYERS.value: IntAttribute(
-        name=AttributeName.NUM_HIDDEN_LAYERS.value,
-        default_value=8,
-        default_low=1,
-        default_high=16,
-    ),
-    AttributeName.NUM_ATTENTION_HEADS.value: IntAttribute(
-        name=AttributeName.NUM_ATTENTION_HEADS.value,
-        default_value=8,
-        default_low=1,
-        default_high=192,
-    ),
-    AttributeName.HIDDEN_ACT.value: StringAttribute(
-        name=AttributeName.HIDDEN_ACT.value,
-        default_value="silu",
-        value_choices=["relu", "gelu", "silu", "tanh"],
-    ),
-    AttributeName.USE_CACHE.value: BooleanAttribute(
-        name=AttributeName.USE_CACHE.value,
-        default_value=True,
-    ),
-    AttributeName.ROPE_THETA.value: IntAttribute(
-        name=AttributeName.ROPE_THETA.value,
-        default_value=10000,
-        default_low=1000,
-        default_high=50000,
-    ),
-    AttributeName.ATTENTION_DROPOUT.value: FloatAttribute(
-        name=AttributeName.ATTENTION_DROPOUT.value,
-        default_value=0.0,
-        default_low=0.0,
-        default_high=1.0,
-    ),
-    AttributeName.INITIALIZER_RANGE.value: FloatAttribute(
-        name=AttributeName.INITIALIZER_RANGE.value,
-        default_value=0.02,
-        default_low=0.0,
-        default_high=1.0,
-    ),
-    AttributeName.MAX_POSITION_EMBEDDINGS.value: IntAttribute(
-        name=AttributeName.MAX_POSITION_EMBEDDINGS.value,
-        default_value=10000,
-        default_low=1,
-        default_high=50000,
-    ),
-    AttributeName.CKPT_PATH.value: StringAttribute(
-        name=AttributeName.CKPT_PATH.value,
-        default_value=os.path.join(
-            os.getcwd(),
-            AINodeDescriptor().get_config().get_ain_models_dir(),
-            "weights",
-            "timerxl",
-            "model.safetensors",
-        ),
-        value_choices=[""],
-    ),
-}
-
 # built-in sktime model attributes
+
 # NaiveForecaster
 naive_forecaster_attribute_map = {
     AttributeName.PREDICT_LENGTH.value: IntAttribute(
@@ -603,6 +363,7 @@ naive_forecaster_attribute_map = {
         name=AttributeName.SP.value, default_value=1, default_low=1, default_high=5000
     ),
 }
+
 # ExponentialSmoothing
 exponential_smoothing_attribute_map = {
     AttributeName.PREDICT_LENGTH.value: IntAttribute(
@@ -633,6 +394,7 @@ exponential_smoothing_attribute_map = {
         default_value=False,
     ),
 }
+
 # Arima
 arima_attribute_map = {
     AttributeName.PREDICT_LENGTH.value: IntAttribute(
@@ -712,6 +474,7 @@ arima_attribute_map = {
         default_value=False,
     ),
 }
+
 # STLForecaster
 stl_forecaster_attribute_map = {
     AttributeName.PREDICT_LENGTH.value: IntAttribute(
@@ -1045,194 +808,67 @@ stray_attribute_map = {
 }
 
 
-class BuiltInModel(object):
-    def __init__(self, attributes):
-        self._attributes = attributes
-        self._model = None
-
-    @abstractmethod
-    def inference(self, data):
-        raise NotImplementedError
-
-
-class ArimaModel(BuiltInModel):
-    def __init__(self, attributes):
-        super(ArimaModel, self).__init__(attributes)
-        self._model = ARIMA(
-            order=attributes["order"],
-            seasonal_order=attributes["seasonal_order"],
-            method=attributes["method"],
-            suppress_warnings=attributes["suppress_warnings"],
-            maxiter=attributes["maxiter"],
-            out_of_sample_size=attributes["out_of_sample_size"],
-            scoring=attributes["scoring"],
-            with_intercept=attributes["with_intercept"],
-            time_varying_regression=attributes["time_varying_regression"],
-            enforce_stationarity=attributes["enforce_stationarity"],
-            enforce_invertibility=attributes["enforce_invertibility"],
-            simple_differencing=attributes["simple_differencing"],
-            measurement_error=attributes["measurement_error"],
-            mle_regression=attributes["mle_regression"],
-            hamilton_representation=attributes["hamilton_representation"],
-            concentrate_scale=attributes["concentrate_scale"],
-        )
-
-    def inference(self, data):
-        try:
-            predict_length = self._attributes["predict_length"]
-            self._model.fit(data)
-            output = self._model.predict(fh=range(predict_length))
-            output = np.array(output, dtype=np.float64)
-            return output
-        except Exception as e:
-            raise InferenceModelInternalError(str(e))
-
-
-class ExponentialSmoothingModel(BuiltInModel):
-    def __init__(self, attributes):
-        super(ExponentialSmoothingModel, self).__init__(attributes)
-        self._model = ExponentialSmoothing(
-            damped_trend=attributes["damped_trend"],
-            initialization_method=attributes["initialization_method"],
-            optimized=attributes["optimized"],
-            remove_bias=attributes["remove_bias"],
-            use_brute=attributes["use_brute"],
-        )
-
-    def inference(self, data):
-        try:
-            predict_length = self._attributes["predict_length"]
-            self._model.fit(data)
-            output = self._model.predict(fh=range(predict_length))
-            output = np.array(output, dtype=np.float64)
-            return output
-        except Exception as e:
-            raise InferenceModelInternalError(str(e))
-
-
-class NaiveForecasterModel(BuiltInModel):
-    def __init__(self, attributes):
-        super(NaiveForecasterModel, self).__init__(attributes)
-        self._model = NaiveForecaster(
-            strategy=attributes["strategy"], sp=attributes["sp"]
-        )
+def get_attributes(model_type: BuiltInModelType):
+    """
+    Get the attribute map of the built-in model.
 
-    def inference(self, data):
-        try:
-            predict_length = self._attributes["predict_length"]
-            self._model.fit(data)
-            output = self._model.predict(fh=range(predict_length))
-            output = np.array(output, dtype=np.float64)
-            return output
-        except Exception as e:
-            raise InferenceModelInternalError(str(e))
-
-
-class STLForecasterModel(BuiltInModel):
-    def __init__(self, attributes):
-        super(STLForecasterModel, self).__init__(attributes)
-        self._model = STLForecaster(
-            sp=attributes["sp"],
-            seasonal=attributes["seasonal"],
-            seasonal_deg=attributes["seasonal_deg"],
-            trend_deg=attributes["trend_deg"],
-            low_pass_deg=attributes["low_pass_deg"],
-            seasonal_jump=attributes["seasonal_jump"],
-            trend_jump=attributes["trend_jump"],
-            low_pass_jump=attributes["low_pass_jump"],
-        )
+    Args:
+        model_type: the type of the built-in model
 
-    def inference(self, data):
-        try:
-            predict_length = self._attributes["predict_length"]
-            self._model.fit(data)
-            output = self._model.predict(fh=range(predict_length))
-            output = np.array(output, dtype=np.float64)
-            return output
-        except Exception as e:
-            raise InferenceModelInternalError(str(e))
-
-
-class GMMHMMModel(BuiltInModel):
-    def __init__(self, attributes):
-        super(GMMHMMModel, self).__init__(attributes)
-        self._model = GMMHMM(
-            n_components=attributes["n_components"],
-            n_mix=attributes["n_mix"],
-            min_covar=attributes["min_covar"],
-            startprob_prior=attributes["startprob_prior"],
-            transmat_prior=attributes["transmat_prior"],
-            means_prior=attributes["means_prior"],
-            means_weight=attributes["means_weight"],
-            weights_prior=attributes["weights_prior"],
-            algorithm=attributes["algorithm"],
-            covariance_type=attributes["covariance_type"],
-            n_iter=attributes["n_iter"],
-            tol=attributes["tol"],
-            params=attributes["params"],
-            init_params=attributes["init_params"],
-            implementation=attributes["implementation"],
-        )
+    Returns:
+        the attribute map of the built-in model
 
-    def inference(self, data):
-        try:
-            self._model.fit(data)
-            output = self._model.predict(data)
-            output = np.array(output, dtype=np.int32)
-            return output
-        except Exception as e:
-            raise InferenceModelInternalError(str(e))
-
-
-class GaussianHmmModel(BuiltInModel):
-    def __init__(self, attributes):
-        super(GaussianHmmModel, self).__init__(attributes)
-        self._model = GaussianHMM(
-            n_components=attributes["n_components"],
-            covariance_type=attributes["covariance_type"],
-            min_covar=attributes["min_covar"],
-            startprob_prior=attributes["startprob_prior"],
-            transmat_prior=attributes["transmat_prior"],
-            means_prior=attributes["means_prior"],
-            means_weight=attributes["means_weight"],
-            covars_prior=attributes["covars_prior"],
-            covars_weight=attributes["covars_weight"],
-            algorithm=attributes["algorithm"],
-            n_iter=attributes["n_iter"],
-            tol=attributes["tol"],
-            params=attributes["params"],
-            init_params=attributes["init_params"],
-            implementation=attributes["implementation"],
-        )
+    """
+    if model_type == BuiltInModelType.ARIMA:
+        attribute_map = arima_attribute_map
+    elif model_type == BuiltInModelType.NAIVE_FORECASTER:
+        attribute_map = naive_forecaster_attribute_map
+    elif (
+            model_type == BuiltInModelType.EXPONENTIAL_SMOOTHING
+            or model_type == BuiltInModelType.HOLTWINTERS
+    ):
+        attribute_map = exponential_smoothing_attribute_map
+    elif model_type == BuiltInModelType.STL_FORECASTER:
+        attribute_map = stl_forecaster_attribute_map
+    elif model_type == BuiltInModelType.GMM_HMM:
+        attribute_map = gmmhmm_attribute_map
+    elif model_type == BuiltInModelType.GAUSSIAN_HMM:
+        attribute_map = gaussian_hmm_attribute_map
+    elif model_type == BuiltInModelType.STRAY:
+        attribute_map = stray_attribute_map
+    else:
+        raise BuiltInModelNotSupportError(model_type.value)
+    return attribute_map
 
-    def inference(self, data):
-        try:
-            self._model.fit(data)
-            output = self._model.predict(data)
-            output = np.array(output, dtype=np.int32)
-            return output
-        except Exception as e:
-            raise InferenceModelInternalError(str(e))
-
-
-class STRAYModel(BuiltInModel):
-    def __init__(self, attributes):
-        super(STRAYModel, self).__init__(attributes)
-        self._model = STRAY(
-            alpha=attributes["alpha"],
-            k=attributes["k"],
-            knn_algorithm=attributes["knn_algorithm"],
-            p=attributes["p"],
-            size_threshold=attributes["size_threshold"],
-            outlier_tail=attributes["outlier_tail"],
-        )
 
-    def inference(self, data):
-        try:
-            data = MinMaxScaler().fit_transform(data)
-            output = self._model.fit_transform(data)
-            # change the output to int
-            output = np.array(output, dtype=np.int32)
-            return output
-        except Exception as e:
-            raise InferenceModelInternalError(str(e))
+def update_attribute(
+        input_attributes: Dict[str, str], attribute_map: Dict[str, Attribute]
+):
+    """
+    Update the attributes of the built-in model using the input attributes.
+    Args:
+        input_attributes: a dict of attributes, where the key is the attribute name, the value is the string value of
+            the attribute
+        attribute_map: a dict of hyperparameters, where the key is the attribute name, the value is the Attribute
+            object
+    Returns:
+        a dict of attributes, where the key is the attribute name, the value is the parsed value of the attribute
+    """
+    attributes = {}
+    for attribute_name in attribute_map:
+        # user specified the attribute
+        if attribute_name in input_attributes:
+            attribute = attribute_map[attribute_name]
+            value = attribute.parse(input_attributes[attribute_name])
+            attribute.validate_value(value)
+            attributes[attribute_name] = value
+        # user did not specify the attribute, use the default value
+        else:
+            try:
+                attributes[attribute_name] = attribute_map[
+                    attribute_name
+                ].get_default_value()
+            except NotImplementedError as e:
+                logger.error(f"attribute {attribute_name} is not implemented.")
+                raise e
+    return attributes
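
A small sketch of how the configuration helpers above are meant to combine defaults with user input (illustrative only; the attribute values are made-up examples):

from iotdb.ainode.core.model.model_enums import BuiltInModelType
from iotdb.ainode.core.model.sktime.configuration_sktime import get_attributes, update_attribute

# Default hyperparameter map for the NaiveForecaster built-in.
defaults = get_attributes(BuiltInModelType.NAIVE_FORECASTER)

# User-supplied overrides arrive as strings and are parsed and range-checked
# against the Attribute definitions; unspecified keys fall back to defaults.
attrs = update_attribute({"predict_length": "48"}, defaults)
# attrs now maps every attribute name to a parsed value, e.g. attrs["predict_length"] == 48
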
diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/sktime/modeling_sktime.py b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/modeling_sktime.py
new file mode 100644
index 00000000000..7e8e41c4dcf
--- /dev/null
+++ b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/modeling_sktime.py
@@ -0,0 +1,261 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+from typing import Any, Dict
+from abc import abstractmethod
+import numpy as np
+from sklearn.preprocessing import MinMaxScaler
+from sktime.detection.hmm_learn import GMMHMM, GaussianHMM
+from sktime.detection.stray import STRAY
+from sktime.forecasting.arima import ARIMA
+from sktime.forecasting.exp_smoothing import ExponentialSmoothing
+from sktime.forecasting.naive import NaiveForecaster
+from sktime.forecasting.trend import STLForecaster
+
+from iotdb.ainode.core.model.sktime.configuration_sktime import get_attributes, update_attribute
+from iotdb.ainode.core.model.model_enums import BuiltInModelType
+from iotdb.ainode.core.exception import InferenceModelInternalError, BuiltInModelNotSupportError
+from iotdb.ainode.core.log import Logger
+
+logger = Logger()
+
+
+class BuiltInModel(object):
+    def __init__(self, attributes):
+        self._attributes = attributes
+        self._model = None
+
+    @abstractmethod
+    def inference(self, data):
+        raise NotImplementedError
+
+
+class ArimaModel(BuiltInModel):
+    def __init__(self, attributes):
+        super(ArimaModel, self).__init__(attributes)
+        self._model = ARIMA(
+            order=attributes["order"],
+            seasonal_order=attributes["seasonal_order"],
+            method=attributes["method"],
+            suppress_warnings=attributes["suppress_warnings"],
+            maxiter=attributes["maxiter"],
+            out_of_sample_size=attributes["out_of_sample_size"],
+            scoring=attributes["scoring"],
+            with_intercept=attributes["with_intercept"],
+            time_varying_regression=attributes["time_varying_regression"],
+            enforce_stationarity=attributes["enforce_stationarity"],
+            enforce_invertibility=attributes["enforce_invertibility"],
+            simple_differencing=attributes["simple_differencing"],
+            measurement_error=attributes["measurement_error"],
+            mle_regression=attributes["mle_regression"],
+            hamilton_representation=attributes["hamilton_representation"],
+            concentrate_scale=attributes["concentrate_scale"],
+        )
+
+    def inference(self, data):
+        try:
+            predict_length = self._attributes["predict_length"]
+            self._model.fit(data)
+            output = self._model.predict(fh=range(predict_length))
+            output = np.array(output, dtype=np.float64)
+            return output
+        except Exception as e:
+            raise InferenceModelInternalError(str(e))
+
+
+class ExponentialSmoothingModel(BuiltInModel):
+    def __init__(self, attributes):
+        super(ExponentialSmoothingModel, self).__init__(attributes)
+        self._model = ExponentialSmoothing(
+            damped_trend=attributes["damped_trend"],
+            initialization_method=attributes["initialization_method"],
+            optimized=attributes["optimized"],
+            remove_bias=attributes["remove_bias"],
+            use_brute=attributes["use_brute"],
+        )
+
+    def inference(self, data):
+        try:
+            predict_length = self._attributes["predict_length"]
+            self._model.fit(data)
+            output = self._model.predict(fh=range(predict_length))
+            output = np.array(output, dtype=np.float64)
+            return output
+        except Exception as e:
+            raise InferenceModelInternalError(str(e))
+
+
+class NaiveForecasterModel(BuiltInModel):
+    def __init__(self, attributes):
+        super(NaiveForecasterModel, self).__init__(attributes)
+        self._model = NaiveForecaster(
+            strategy=attributes["strategy"], sp=attributes["sp"]
+        )
+
+    def inference(self, data):
+        try:
+            predict_length = self._attributes["predict_length"]
+            self._model.fit(data)
+            output = self._model.predict(fh=range(predict_length))
+            output = np.array(output, dtype=np.float64)
+            return output
+        except Exception as e:
+            raise InferenceModelInternalError(str(e))
+
+
+class STLForecasterModel(BuiltInModel):
+    def __init__(self, attributes):
+        super(STLForecasterModel, self).__init__(attributes)
+        self._model = STLForecaster(
+            sp=attributes["sp"],
+            seasonal=attributes["seasonal"],
+            seasonal_deg=attributes["seasonal_deg"],
+            trend_deg=attributes["trend_deg"],
+            low_pass_deg=attributes["low_pass_deg"],
+            seasonal_jump=attributes["seasonal_jump"],
+            trend_jump=attributes["trend_jump"],
+            low_pass_jump=attributes["low_pass_jump"],
+        )
+
+    def inference(self, data):
+        try:
+            predict_length = self._attributes["predict_length"]
+            self._model.fit(data)
+            output = self._model.predict(fh=range(predict_length))
+            output = np.array(output, dtype=np.float64)
+            return output
+        except Exception as e:
+            raise InferenceModelInternalError(str(e))
+
+
+class GMMHMMModel(BuiltInModel):
+    def __init__(self, attributes):
+        super(GMMHMMModel, self).__init__(attributes)
+        self._model = GMMHMM(
+            n_components=attributes["n_components"],
+            n_mix=attributes["n_mix"],
+            min_covar=attributes["min_covar"],
+            startprob_prior=attributes["startprob_prior"],
+            transmat_prior=attributes["transmat_prior"],
+            means_prior=attributes["means_prior"],
+            means_weight=attributes["means_weight"],
+            weights_prior=attributes["weights_prior"],
+            algorithm=attributes["algorithm"],
+            covariance_type=attributes["covariance_type"],
+            n_iter=attributes["n_iter"],
+            tol=attributes["tol"],
+            params=attributes["params"],
+            init_params=attributes["init_params"],
+            implementation=attributes["implementation"],
+        )
+
+    def inference(self, data):
+        try:
+            self._model.fit(data)
+            output = self._model.predict(data)
+            output = np.array(output, dtype=np.int32)
+            return output
+        except Exception as e:
+            raise InferenceModelInternalError(str(e))
+
+
+class GaussianHmmModel(BuiltInModel):
+    def __init__(self, attributes):
+        super(GaussianHmmModel, self).__init__(attributes)
+        self._model = GaussianHMM(
+            n_components=attributes["n_components"],
+            covariance_type=attributes["covariance_type"],
+            min_covar=attributes["min_covar"],
+            startprob_prior=attributes["startprob_prior"],
+            transmat_prior=attributes["transmat_prior"],
+            means_prior=attributes["means_prior"],
+            means_weight=attributes["means_weight"],
+            covars_prior=attributes["covars_prior"],
+            covars_weight=attributes["covars_weight"],
+            algorithm=attributes["algorithm"],
+            n_iter=attributes["n_iter"],
+            tol=attributes["tol"],
+            params=attributes["params"],
+            init_params=attributes["init_params"],
+            implementation=attributes["implementation"],
+        )
+
+    def inference(self, data):
+        try:
+            self._model.fit(data)
+            output = self._model.predict(data)
+            output = np.array(output, dtype=np.int32)
+            return output
+        except Exception as e:
+            raise InferenceModelInternalError(str(e))
+
+
+class STRAYModel(BuiltInModel):
+    def __init__(self, attributes):
+        super(STRAYModel, self).__init__(attributes)
+        self._model = STRAY(
+            alpha=attributes["alpha"],
+            k=attributes["k"],
+            knn_algorithm=attributes["knn_algorithm"],
+            p=attributes["p"],
+            size_threshold=attributes["size_threshold"],
+            outlier_tail=attributes["outlier_tail"],
+        )
+
+    def inference(self, data):
+        try:
+            data = MinMaxScaler().fit_transform(data)
+            output = self._model.fit_transform(data)
+            # change the output to int
+            output = np.array(output, dtype=np.int32)
+            return output
+        except Exception as e:
+            raise InferenceModelInternalError(str(e))
+
+
+def fetch_built_in_model(
+        model_type: BuiltInModelType, inference_attrs: Dict[str, str]
+) -> Any:
+    default_attributes = get_attributes(model_type)
+    attributes = update_attribute(inference_attrs, default_attributes)
+
+    if model_type == BuiltInModelType.ARIMA:
+        model = ArimaModel(attributes)
+    elif (
+            model_type == BuiltInModelType.EXPONENTIAL_SMOOTHING
+            or model_type == BuiltInModelType.HOLTWINTERS
+    ):
+        model = ExponentialSmoothingModel(attributes)
+    elif model_type == BuiltInModelType.NAIVE_FORECASTER:
+        model = NaiveForecasterModel(attributes)
+    elif model_type == BuiltInModelType.STL_FORECASTER:
+        model = STLForecasterModel(attributes)
+    elif model_type == BuiltInModelType.GMM_HMM:
+        model = GMMHMMModel(attributes)
+    elif model_type == BuiltInModelType.GAUSSIAN_HMM:
+        model = GaussianHmmModel(attributes)
+    elif model_type == BuiltInModelType.STRAY:
+        model = STRAYModel(attributes)
+    # elif model_type == BuiltInModelType.TIMER_XL:
+    #     model = modeling_timer.TimerForPrediction.from_pretrained(model_dir)
+    # elif model_type == BuiltInModelType.SUNDIAL:
+    #     model = 
modeling_sundial.SundialForPrediction.from_pretrained(model_dir)
+    else:
+        raise BuiltInModelNotSupportError(model_type.value)
+
+    return model
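
And an end-to-end sketch for the sktime wrappers defined above (the input series is synthetic; sktime and its forecasting dependencies must be installed):

import numpy as np

from iotdb.ainode.core.model.model_enums import BuiltInModelType
from iotdb.ainode.core.model.sktime.modeling_sktime import fetch_built_in_model

# Build the wrapper from string-valued inference attributes; anything not
# specified keeps its default from configuration_sktime.
model = fetch_built_in_model(BuiltInModelType.EXPONENTIAL_SMOOTHING, {"predict_length": "12"})

# inference() fits the underlying sktime estimator on the series and returns a
# float64 array of predict_length forecast values (the detection wrappers
# return int32 labels instead).
history = np.sin(np.arange(96) / 8.0) + 1.5
forecast = model.inference(history)
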
diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/sundial/configuration_sundial.py b/iotdb-core/ainode/iotdb/ainode/core/model/sundial/configuration_sundial.py
index 21eefef2933..5b9eb7f1f6b 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/model/sundial/configuration_sundial.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/model/sundial/configuration_sundial.py
@@ -63,3 +63,5 @@ class SundialConfig(PretrainedConfig):
         super().__init__(
             **kwargs,
         )
+
+# TODO: Lacking checkpoint_path
\ No newline at end of file

