This is an automated email from the ASF dual-hosted git repository.

ycycse pushed a commit to branch ycy/AINodeTraining
in repository https://gitbox.apache.org/repos/asf/iotdb.git


The following commit(s) were added to refs/heads/ycy/AINodeTraining by this 
push:
     new 8404a7c5c56 restore model part
8404a7c5c56 is described below

commit 8404a7c5c564be46f1f493746087d9ff2fc2c03d
Author: YangCaiyin <[email protected]>
AuthorDate: Wed Mar 5 19:00:05 2025 +0800

    restore model part
---
 iotdb-core/ainode/ainode/core/model/__init__.py    |  17 +
 .../ainode/core/model/built_in_model_factory.py    | 924 +++++++++++++++++++++
 .../ainode/ainode/core/model/model_factory.py      | 235 ++++++
 .../ainode/ainode/core/model/model_storage.py      | 123 +++
 4 files changed, 1299 insertions(+)

diff --git a/iotdb-core/ainode/ainode/core/model/__init__.py 
b/iotdb-core/ainode/ainode/core/model/__init__.py
new file mode 100644
index 00000000000..2a1e720805f
--- /dev/null
+++ b/iotdb-core/ainode/ainode/core/model/__init__.py
@@ -0,0 +1,17 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
diff --git a/iotdb-core/ainode/ainode/core/model/built_in_model_factory.py 
b/iotdb-core/ainode/ainode/core/model/built_in_model_factory.py
new file mode 100644
index 00000000000..4524272dfaf
--- /dev/null
+++ b/iotdb-core/ainode/ainode/core/model/built_in_model_factory.py
@@ -0,0 +1,924 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+from abc import abstractmethod
+from typing import List, Dict
+
+import numpy as np
+from sklearn.preprocessing import MinMaxScaler
+from sktime.annotation.hmm_learn import GaussianHMM, GMMHMM
+from sktime.annotation.stray import STRAY
+from sktime.forecasting.arima import ARIMA
+from sktime.forecasting.exp_smoothing import ExponentialSmoothing
+from sktime.forecasting.naive import NaiveForecaster
+from sktime.forecasting.trend import STLForecaster
+
+from ainode.core.constant import AttributeName, BuiltInModelType
+from ainode.core.exception import InferenceModelInternalError, 
AttributeNotSupportError
+from ainode.core.exception import WrongAttributeTypeError, 
NumericalRangeException, StringRangeException, \
+    ListRangeException, BuiltInModelNotSupportError
+from ainode.core.log import Logger
+
+logger = Logger()
+
+
def get_model_attributes(model_id: str):
    """
    Look up the attribute map (attribute name -> Attribute descriptor) for a
    built-in model.

    Args:
        model_id: the unique id of the built-in model
    Returns:
        the attribute map declared for the model
    Raises:
        BuiltInModelNotSupportError: if the id does not name a built-in model
    """
    # Dispatch table instead of an if/elif ladder.
    model_id_to_attribute_map = {
        BuiltInModelType.ARIMA.value: arima_attribute_map,
        BuiltInModelType.NAIVE_FORECASTER.value: naive_forecaster_attribute_map,
        BuiltInModelType.EXPONENTIAL_SMOOTHING.value: exponential_smoothing_attribute_map,
        BuiltInModelType.STL_FORECASTER.value: stl_forecaster_attribute_map,
        BuiltInModelType.GMM_HMM.value: gmmhmm_attribute_map,
        BuiltInModelType.GAUSSIAN_HMM.value: gaussian_hmm_attribute_map,
        BuiltInModelType.STRAY.value: stray_attribute_map,
    }
    try:
        return model_id_to_attribute_map[model_id]
    except KeyError:
        raise BuiltInModelNotSupportError(model_id) from None
+
+
def fetch_built_in_model(model_id, inference_attributes):
    """
    Create a built-in model instance (no user registration required).

    Args:
        model_id: the unique id of the built-in model
        inference_attributes: raw attribute strings from the request; any
            parameter not supplied falls back to its declared default value.
    Returns:
        the constructed built-in model wrapper
    Raises:
        AttributeNotSupportError: if an attribute is unknown for this model
        BuiltInModelNotSupportError: if the model id is not a built-in model
    """
    attribute_map = get_model_attributes(model_id)

    # Reject any attribute the model does not declare.
    for attribute_name in inference_attributes:
        if attribute_name not in attribute_map:
            raise AttributeNotSupportError(model_id, attribute_name)

    # Convert the raw strings into validated, typed values.
    attributes = parse_attribute(inference_attributes, attribute_map)

    # Dispatch table mapping model id to its wrapper class.
    model_classes = {
        BuiltInModelType.ARIMA.value: ArimaModel,
        BuiltInModelType.EXPONENTIAL_SMOOTHING.value: ExponentialSmoothingModel,
        BuiltInModelType.NAIVE_FORECASTER.value: NaiveForecasterModel,
        BuiltInModelType.STL_FORECASTER.value: STLForecasterModel,
        BuiltInModelType.GMM_HMM.value: GMMHMMModel,
        BuiltInModelType.GAUSSIAN_HMM.value: GaussianHmmModel,
        BuiltInModelType.STRAY.value: STRAYModel,
    }
    model_class = model_classes.get(model_id)
    if model_class is None:
        raise BuiltInModelNotSupportError(model_id)
    return model_class(attributes)
+
+
class Attribute(object):
    """Abstract descriptor for one hyperparameter of a built-in model.

    Subclasses implement parsing from the raw request string, range/type
    validation, and a default value used when the caller omits the attribute.
    """

    def __init__(self, name: str):
        """
        Args:
            name: the name of the attribute
        """
        # Name under which the attribute appears in the inference request.
        self._name = name

    @abstractmethod
    def get_default_value(self):
        """Return the value used when the caller does not supply this attribute."""
        raise NotImplementedError

    @abstractmethod
    def validate_value(self, value):
        """Return True if *value* is acceptable, otherwise raise an exception."""
        raise NotImplementedError

    @abstractmethod
    def parse(self, string_value: str):
        """Convert the raw string from the request into a typed value."""
        raise NotImplementedError
+
+
class IntAttribute(Attribute):
    """Integer hyperparameter constrained to a closed range [low, high]."""

    def __init__(self, name: str,
                 default_value: int,
                 default_low: int,
                 default_high: int,
                 ):
        super(IntAttribute, self).__init__(name)
        self.__default_value = default_value
        self.__default_low = default_low    # inclusive lower bound
        self.__default_high = default_high  # inclusive upper bound

    def get_default_value(self):
        return self.__default_value

    def validate_value(self, value):
        """Return True if value is within [low, high], else raise NumericalRangeException."""
        if self.__default_low <= value <= self.__default_high:
            return True
        raise NumericalRangeException(self._name, value, self.__default_low, self.__default_high)

    def parse(self, string_value: str):
        """Convert the raw string to int; raise WrongAttributeTypeError on failure."""
        # Catch only conversion failures: the original bare `except` would also
        # swallow KeyboardInterrupt/SystemExit.
        try:
            return int(string_value)
        except (TypeError, ValueError):
            raise WrongAttributeTypeError(self._name, "int") from None
+
+
class FloatAttribute(Attribute):
    """Float hyperparameter constrained to a closed range [low, high]."""

    def __init__(self, name: str,
                 default_value: float,
                 default_low: float,
                 default_high: float,
                 ):
        super(FloatAttribute, self).__init__(name)
        self.__default_value = default_value
        self.__default_low = default_low    # inclusive lower bound
        self.__default_high = default_high  # inclusive upper bound

    def get_default_value(self):
        return self.__default_value

    def validate_value(self, value):
        """Return True if value is within [low, high], else raise NumericalRangeException."""
        if self.__default_low <= value <= self.__default_high:
            return True
        raise NumericalRangeException(self._name, value, self.__default_low, self.__default_high)

    def parse(self, string_value: str):
        """Convert the raw string to float; raise WrongAttributeTypeError on failure."""
        # Catch only conversion failures: the original bare `except` would also
        # swallow KeyboardInterrupt/SystemExit.
        try:
            return float(string_value)
        except (TypeError, ValueError):
            raise WrongAttributeTypeError(self._name, "float") from None
+
+
class StringAttribute(Attribute):
    """String hyperparameter restricted to a fixed set of allowed choices."""

    def __init__(self, name: str, default_value: str, value_choices: List[str]):
        super(StringAttribute, self).__init__(name)
        self.__default_value = default_value
        self.__value_choices = value_choices

    def get_default_value(self):
        return self.__default_value

    def validate_value(self, value):
        """Return True if value is one of the allowed choices, else raise."""
        if value not in self.__value_choices:
            raise StringRangeException(self._name, value, self.__value_choices)
        return True

    def parse(self, string_value: str):
        # Strings need no conversion.
        return string_value
+
+
class BooleanAttribute(Attribute):
    """Boolean hyperparameter parsed from the strings "true"/"false" (any case)."""

    def __init__(self, name: str, default_value: bool):
        super(BooleanAttribute, self).__init__(name)
        self.__default_value = default_value

    def get_default_value(self):
        return self.__default_value

    def validate_value(self, value):
        """Return True if value is a bool, else raise WrongAttributeTypeError."""
        if not isinstance(value, bool):
            raise WrongAttributeTypeError(self._name, "bool")
        return True

    def parse(self, string_value: str):
        """Accept only (case-insensitive) "true" or "false"."""
        lowered = string_value.lower()
        if lowered == "true":
            return True
        if lowered == "false":
            return False
        raise WrongAttributeTypeError(self._name, "bool")
+
+
class ListAttribute(Attribute):
    """List hyperparameter whose elements are coerced to a single type."""

    def __init__(self, name: str, default_value: List, value_type):
        """
        value_type is the type of the elements in the list, e.g. int, float, str
        """
        super(ListAttribute, self).__init__(name)
        self.__default_value = default_value
        self.__value_type = value_type
        # Used only to render a readable type name in error messages.
        self.__type_to_str = {str: "str", int: "int", float: "float"}

    def get_default_value(self):
        return self.__default_value

    def validate_value(self, value):
        """Return True if value is a list of value_type elements, else raise."""
        if not isinstance(value, list):
            raise WrongAttributeTypeError(self._name, "list")
        for value_item in value:
            if not isinstance(value_item, self.__value_type):
                raise WrongAttributeTypeError(self._name, self.__value_type)
        return True

    def parse(self, string_value: str):
        """Parse a literal such as "[1, 2, 3]" into a typed list."""
        import ast
        # SECURITY: ast.literal_eval only evaluates Python literals; the
        # original eval() would execute arbitrary code taken from the
        # (untrusted) request attributes.
        try:
            list_value = ast.literal_eval(string_value)
        except (ValueError, SyntaxError):
            raise WrongAttributeTypeError(self._name, "list") from None
        if not isinstance(list_value, list):
            raise WrongAttributeTypeError(self._name, "list")
        for i in range(len(list_value)):
            try:
                list_value[i] = self.__value_type(list_value[i])
            except (TypeError, ValueError):
                raise ListRangeException(self._name, list_value, self.__type_to_str[self.__value_type]) from None
        return list_value
+
+
class TupleAttribute(Attribute):
    """Tuple hyperparameter whose elements are coerced to a single type."""

    def __init__(self, name: str, default_value: tuple, value_type):
        """
        value_type is the type of the elements in the tuple, e.g. int, float, str
        """
        super(TupleAttribute, self).__init__(name)
        self.__default_value = default_value
        self.__value_type = value_type
        # Used only to render a readable type name in error messages.
        self.__type_to_str = {str: "str", int: "int", float: "float"}

    def get_default_value(self):
        return self.__default_value

    def validate_value(self, value):
        """Return True if value is a tuple of value_type elements, else raise."""
        if not isinstance(value, tuple):
            raise WrongAttributeTypeError(self._name, "tuple")
        for value_item in value:
            if not isinstance(value_item, self.__value_type):
                raise WrongAttributeTypeError(self._name, self.__value_type)
        return True

    def parse(self, string_value: str):
        """Parse a literal such as "(1, 0, 0)" into a typed tuple."""
        import ast
        # SECURITY: ast.literal_eval only evaluates Python literals; the
        # original eval() would execute arbitrary code taken from the
        # (untrusted) request attributes.
        try:
            tuple_value = ast.literal_eval(string_value)
        except (ValueError, SyntaxError):
            raise WrongAttributeTypeError(self._name, "tuple") from None
        if not isinstance(tuple_value, tuple):
            raise WrongAttributeTypeError(self._name, "tuple")
        list_value = list(tuple_value)
        for i in range(len(list_value)):
            try:
                list_value[i] = self.__value_type(list_value[i])
            except (TypeError, ValueError):
                raise ListRangeException(self._name, list_value, self.__type_to_str[self.__value_type]) from None
        return tuple(list_value)
+
+
def parse_attribute(input_attributes: Dict[str, str], attribute_map: Dict[str, Attribute]):
    """
    Resolve every declared attribute to a concrete, validated value.

    Args:
        input_attributes: raw string values keyed by attribute name, as
            supplied by the caller
        attribute_map: Attribute descriptors keyed by attribute name
    Returns:
        a dict mapping every attribute name in attribute_map to either the
        parsed user value or the declared default
    """
    attributes = {}
    for attribute_name, attribute in attribute_map.items():
        if attribute_name in input_attributes:
            # User supplied a raw string: convert it, then range/type-check it.
            parsed_value = attribute.parse(input_attributes[attribute_name])
            attribute.validate_value(parsed_value)
            attributes[attribute_name] = parsed_value
        else:
            # Attribute not supplied: fall back to the declared default.
            try:
                attributes[attribute_name] = attribute.get_default_value()
            except NotImplementedError as e:
                logger.error(f"attribute {attribute_name} is not implemented.")
                raise e
    return attributes
+
+
# built-in sktime model attributes
# NaiveForecaster: forecast horizon length, strategy and seasonal periodicity.
naive_forecaster_attribute_map = {
    AttributeName.PREDICT_LENGTH.value: IntAttribute(
        name=AttributeName.PREDICT_LENGTH.value,
        default_value=1,
        default_low=1,
        default_high=5000
    ),
    AttributeName.STRATEGY.value: StringAttribute(
        name=AttributeName.STRATEGY.value,
        default_value="last",
        value_choices=["last", "mean"],
    ),
    AttributeName.SP.value: IntAttribute(
        name=AttributeName.SP.value,
        default_value=1,
        default_low=1,
        default_high=5000
    ),
}
# ExponentialSmoothing: flags forwarded verbatim to sktime's ExponentialSmoothing.
exponential_smoothing_attribute_map = {
    AttributeName.PREDICT_LENGTH.value: IntAttribute(
        name=AttributeName.PREDICT_LENGTH.value,
        default_value=1,
        default_low=1,
        default_high=5000
    ),
    AttributeName.DAMPED_TREND.value: BooleanAttribute(
        name=AttributeName.DAMPED_TREND.value,
        default_value=False,
    ),
    AttributeName.INITIALIZATION_METHOD.value: StringAttribute(
        name=AttributeName.INITIALIZATION_METHOD.value,
        default_value="estimated",
        value_choices=["estimated", "heuristic", "legacy-heuristic", "known"],
    ),
    AttributeName.OPTIMIZED.value: BooleanAttribute(
        name=AttributeName.OPTIMIZED.value,
        default_value=True,
    ),
    AttributeName.REMOVE_BIAS.value: BooleanAttribute(
        name=AttributeName.REMOVE_BIAS.value,
        default_value=False,
    ),
    AttributeName.USE_BRUTE.value: BooleanAttribute(
        name=AttributeName.USE_BRUTE.value,
        default_value=False,
    )
}
# Arima: (p, d, q) order plus optimizer/state-space flags for sktime's ARIMA.
arima_attribute_map = {
    AttributeName.PREDICT_LENGTH.value: IntAttribute(
        name=AttributeName.PREDICT_LENGTH.value,
        default_value=1,
        default_low=1,
        default_high=5000
    ),
    AttributeName.ORDER.value: TupleAttribute(
        name=AttributeName.ORDER.value,
        default_value=(1, 0, 0),
        value_type=int
    ),
    AttributeName.SEASONAL_ORDER.value: TupleAttribute(
        name=AttributeName.SEASONAL_ORDER.value,
        default_value=(0, 0, 0, 0),
        value_type=int
    ),
    AttributeName.METHOD.value: StringAttribute(
        name=AttributeName.METHOD.value,
        default_value="lbfgs",
        value_choices=["lbfgs", "bfgs", "newton", "nm", "cg", "ncg", "powell"],
    ),
    # NOTE(review): a default of 1 optimizer iteration is unusually low
    # (statsmodels/pmdarima default to 50) — confirm this is intentional.
    AttributeName.MAXITER.value: IntAttribute(
        name=AttributeName.MAXITER.value,
        default_value=1,
        default_low=1,
        default_high=5000
    ),
    AttributeName.SUPPRESS_WARNINGS.value: BooleanAttribute(
        name=AttributeName.SUPPRESS_WARNINGS.value,
        default_value=True,
    ),
    AttributeName.OUT_OF_SAMPLE_SIZE.value: IntAttribute(
        name=AttributeName.OUT_OF_SAMPLE_SIZE.value,
        default_value=0,
        default_low=0,
        default_high=5000
    ),
    AttributeName.SCORING.value: StringAttribute(
        name=AttributeName.SCORING.value,
        default_value="mse",
        value_choices=["mse", "mae", "rmse", "mape", "smape", "rmsle", "r2"],
    ),
    AttributeName.WITH_INTERCEPT.value: BooleanAttribute(
        name=AttributeName.WITH_INTERCEPT.value,
        default_value=True,
    ),
    AttributeName.TIME_VARYING_REGRESSION.value: BooleanAttribute(
        name=AttributeName.TIME_VARYING_REGRESSION.value,
        default_value=False,
    ),
    AttributeName.ENFORCE_STATIONARITY.value: BooleanAttribute(
        name=AttributeName.ENFORCE_STATIONARITY.value,
        default_value=True,
    ),
    AttributeName.ENFORCE_INVERTIBILITY.value: BooleanAttribute(
        name=AttributeName.ENFORCE_INVERTIBILITY.value,
        default_value=True,
    ),
    AttributeName.SIMPLE_DIFFERENCING.value: BooleanAttribute(
        name=AttributeName.SIMPLE_DIFFERENCING.value,
        default_value=False,
    ),
    AttributeName.MEASUREMENT_ERROR.value: BooleanAttribute(
        name=AttributeName.MEASUREMENT_ERROR.value,
        default_value=False,
    ),
    AttributeName.MLE_REGRESSION.value: BooleanAttribute(
        name=AttributeName.MLE_REGRESSION.value,
        default_value=True,
    ),
    AttributeName.HAMILTON_REPRESENTATION.value: BooleanAttribute(
        name=AttributeName.HAMILTON_REPRESENTATION.value,
        default_value=False,
    ),
    AttributeName.CONCENTRATE_SCALE.value: BooleanAttribute(
        name=AttributeName.CONCENTRATE_SCALE.value,
        default_value=False,
    )
}
# STLForecaster: seasonal-trend decomposition (LOESS) parameters.
stl_forecaster_attribute_map = {
    AttributeName.PREDICT_LENGTH.value: IntAttribute(
        name=AttributeName.PREDICT_LENGTH.value,
        default_value=1,
        default_low=1,
        default_high=5000
    ),
    AttributeName.SP.value: IntAttribute(
        name=AttributeName.SP.value,
        default_value=2,
        default_low=1,
        default_high=5000
    ),
    AttributeName.SEASONAL.value: IntAttribute(
        name=AttributeName.SEASONAL.value,
        default_value=7,
        default_low=1,
        default_high=5000
    ),
    AttributeName.SEASONAL_DEG.value: IntAttribute(
        name=AttributeName.SEASONAL_DEG.value,
        default_value=1,
        default_low=0,
        default_high=5000
    ),
    AttributeName.TREND_DEG.value: IntAttribute(
        name=AttributeName.TREND_DEG.value,
        default_value=1,
        default_low=0,
        default_high=5000
    ),
    AttributeName.LOW_PASS_DEG.value: IntAttribute(
        name=AttributeName.LOW_PASS_DEG.value,
        default_value=1,
        default_low=0,
        default_high=5000
    ),
    AttributeName.SEASONAL_JUMP.value: IntAttribute(
        name=AttributeName.SEASONAL_JUMP.value,
        default_value=1,
        default_low=0,
        default_high=5000
    ),
    AttributeName.TREND_JUMP.value: IntAttribute(
        name=AttributeName.TREND_JUMP.value,
        default_value=1,
        default_low=0,
        default_high=5000
    ),
    # NOTE(review): the constant is spelled LOSS_PASS_JUMP but STLForecasterModel
    # reads attributes['low_pass_jump'] — verify AttributeName.LOSS_PASS_JUMP.value
    # is actually "low_pass_jump", otherwise model construction raises KeyError.
    AttributeName.LOSS_PASS_JUMP.value: IntAttribute(
        name=AttributeName.LOSS_PASS_JUMP.value,
        default_value=1,
        default_low=0,
        default_high=5000
    ),
}
+
# GAUSSIAN_HMM: parameters forwarded verbatim to sktime's GaussianHMM wrapper
# (hmmlearn-backed annotator).
gaussian_hmm_attribute_map = {
    AttributeName.N_COMPONENTS.value: IntAttribute(
        name=AttributeName.N_COMPONENTS.value,
        default_value=1,
        default_low=1,
        default_high=5000
    ),
    AttributeName.COVARIANCE_TYPE.value: StringAttribute(
        name=AttributeName.COVARIANCE_TYPE.value,
        default_value="diag",
        value_choices=["spherical", "diag", "full", "tied"],
    ),
    AttributeName.MIN_COVAR.value: FloatAttribute(
        name=AttributeName.MIN_COVAR.value,
        default_value=1e-3,
        default_low=-1e10,
        default_high=1e10,
    ),
    AttributeName.STARTPROB_PRIOR.value: FloatAttribute(
        name=AttributeName.STARTPROB_PRIOR.value,
        default_value=1.0,
        default_low=-1e10,
        default_high=1e10,
    ),
    AttributeName.TRANSMAT_PRIOR.value: FloatAttribute(
        name=AttributeName.TRANSMAT_PRIOR.value,
        default_value=1.0,
        default_low=-1e10,
        default_high=1e10,
    ),
    AttributeName.MEANS_PRIOR.value: FloatAttribute(
        name=AttributeName.MEANS_PRIOR.value,
        default_value=0.0,
        default_low=-1e10,
        default_high=1e10,
    ),
    AttributeName.MEANS_WEIGHT.value: FloatAttribute(
        name=AttributeName.MEANS_WEIGHT.value,
        default_value=0.0,
        default_low=-1e10,
        default_high=1e10,
    ),
    AttributeName.COVARS_PRIOR.value: FloatAttribute(
        name=AttributeName.COVARS_PRIOR.value,
        default_value=1e-2,
        default_low=-1e10,
        default_high=1e10,
    ),
    AttributeName.COVARS_WEIGHT.value: FloatAttribute(
        name=AttributeName.COVARS_WEIGHT.value,
        default_value=1.0,
        default_low=-1e10,
        default_high=1e10,
    ),
    AttributeName.ALGORITHM.value: StringAttribute(
        name=AttributeName.ALGORITHM.value,
        default_value="viterbi",
        value_choices=["viterbi", "map"],
    ),
    AttributeName.N_ITER.value: IntAttribute(
        name=AttributeName.N_ITER.value,
        default_value=10,
        default_low=1,
        default_high=5000
    ),
    AttributeName.TOL.value: FloatAttribute(
        name=AttributeName.TOL.value,
        default_value=1e-2,
        default_low=-1e10,
        default_high=1e10,
    ),
    # NOTE(review): hmmlearn accepts any subset of the letters "stmc" here;
    # only "stmc" and "stm" are allowed by this map — confirm the restriction
    # is intentional (the GMMHMM map below allows all subsets).
    AttributeName.PARAMS.value: StringAttribute(
        name=AttributeName.PARAMS.value,
        default_value="stmc",
        value_choices=["stmc", "stm"],
    ),
    AttributeName.INIT_PARAMS.value: StringAttribute(
        name=AttributeName.INIT_PARAMS.value,
        default_value="stmc",
        value_choices=["stmc", "stm"],
    ),
    AttributeName.IMPLEMENTATION.value: StringAttribute(
        name=AttributeName.IMPLEMENTATION.value,
        default_value="log",
        value_choices=["log", "scaling"],
    )
}
+
# GMMHMM: parameters forwarded verbatim to sktime's GMMHMM wrapper
# (hmmlearn-backed annotator).
gmmhmm_attribute_map = {
    AttributeName.N_COMPONENTS.value: IntAttribute(
        name=AttributeName.N_COMPONENTS.value,
        default_value=1,
        default_low=1,
        default_high=5000
    ),
    AttributeName.N_MIX.value: IntAttribute(
        name=AttributeName.N_MIX.value,
        default_value=1,
        default_low=1,
        default_high=5000
    ),
    AttributeName.MIN_COVAR.value: FloatAttribute(
        name=AttributeName.MIN_COVAR.value,
        default_value=1e-3,
        default_low=-1e10,
        default_high=1e10,
    ),
    AttributeName.STARTPROB_PRIOR.value: FloatAttribute(
        name=AttributeName.STARTPROB_PRIOR.value,
        default_value=1.0,
        default_low=-1e10,
        default_high=1e10,
    ),
    AttributeName.TRANSMAT_PRIOR.value: FloatAttribute(
        name=AttributeName.TRANSMAT_PRIOR.value,
        default_value=1.0,
        default_low=-1e10,
        default_high=1e10,
    ),
    AttributeName.WEIGHTS_PRIOR.value: FloatAttribute(
        name=AttributeName.WEIGHTS_PRIOR.value,
        default_value=1.0,
        default_low=-1e10,
        default_high=1e10,
    ),
    AttributeName.MEANS_PRIOR.value: FloatAttribute(
        name=AttributeName.MEANS_PRIOR.value,
        default_value=0.0,
        default_low=-1e10,
        default_high=1e10,
    ),
    AttributeName.MEANS_WEIGHT.value: FloatAttribute(
        name=AttributeName.MEANS_WEIGHT.value,
        default_value=0.0,
        default_low=-1e10,
        default_high=1e10,
    ),
    AttributeName.ALGORITHM.value: StringAttribute(
        name=AttributeName.ALGORITHM.value,
        default_value="viterbi",
        value_choices=["viterbi", "map"],
    ),
    # Fixed typo: "sperical" -> "spherical". The misspelled choice could never
    # be accepted by hmmlearn, and the valid value "spherical" was rejected by
    # this map (the GaussianHMM map above spells it correctly).
    AttributeName.COVARIANCE_TYPE.value: StringAttribute(
        name=AttributeName.COVARIANCE_TYPE.value,
        default_value="diag",
        value_choices=["spherical", "diag", "full", "tied"],
    ),
    AttributeName.N_ITER.value: IntAttribute(
        name=AttributeName.N_ITER.value,
        default_value=10,
        default_low=1,
        default_high=5000
    ),
    AttributeName.TOL.value: FloatAttribute(
        name=AttributeName.TOL.value,
        default_value=1e-2,
        default_low=-1e10,
        default_high=1e10,
    ),
    # Any non-empty subset of the letters s/t/m/c/w, in canonical order.
    AttributeName.INIT_PARAMS.value: StringAttribute(
        name=AttributeName.INIT_PARAMS.value,
        default_value="stmcw",
        value_choices=["s", "t", "m", "c", "w", "st", "sm", "sc", "sw", "tm",
                       "tc", "tw", "mc", "mw", "cw", "stm",
                       "stc", "stw", "smc", "smw", "scw", "tmc", "tmw", "tcw",
                       "mcw", "stmc", "stmw", "stcw", "smcw",
                       "tmcw", "stmcw"]
    ),
    AttributeName.PARAMS.value: StringAttribute(
        name=AttributeName.PARAMS.value,
        default_value="stmcw",
        value_choices=["s", "t", "m", "c", "w", "st", "sm", "sc", "sw", "tm",
                       "tc", "tw", "mc", "mw", "cw", "stm",
                       "stc", "stw", "smc", "smw", "scw", "tmc", "tmw", "tcw",
                       "mcw", "stmc", "stmw", "stcw", "smcw",
                       "tmcw", "stmcw"]
    ),
    AttributeName.IMPLEMENTATION.value: StringAttribute(
        name=AttributeName.IMPLEMENTATION.value,
        default_value="log",
        value_choices=["log", "scaling"],
    )
}
+
# STRAY: parameters for sktime's STRAY anomaly-detection annotator
# (kNN-based outlier detection).
stray_attribute_map = {
    AttributeName.ALPHA.value: FloatAttribute(
        name=AttributeName.ALPHA.value,
        default_value=0.01,
        default_low=-1e10,
        default_high=1e10,
    ),
    AttributeName.K.value: IntAttribute(
        name=AttributeName.K.value,
        default_value=10,
        default_low=1,
        default_high=5000
    ),
    AttributeName.KNN_ALGORITHM.value: StringAttribute(
        name=AttributeName.KNN_ALGORITHM.value,
        default_value="brute",
        value_choices=["brute", "kd_tree", "ball_tree", "auto"],
    ),
    AttributeName.P.value: FloatAttribute(
        name=AttributeName.P.value,
        default_value=0.5,
        default_low=-1e10,
        default_high=1e10,
    ),
    AttributeName.SIZE_THRESHOLD.value: IntAttribute(
        name=AttributeName.SIZE_THRESHOLD.value,
        default_value=50,
        default_low=1,
        default_high=5000
    ),
    AttributeName.OUTLIER_TAIL.value: StringAttribute(
        name=AttributeName.OUTLIER_TAIL.value,
        default_value="max",
        value_choices=["min", "max"],
    ),
}
+
+
class BuiltInModel(object):
    """Base class for built-in model wrappers.

    Holds the parsed attribute dict; concrete subclasses construct the
    underlying estimator in __init__ and implement inference().
    """

    def __init__(self, attributes):
        # Attribute dict produced by parse_attribute().
        self._attributes = attributes
        # Concrete subclasses assign the underlying estimator here.
        self._model = None

    @abstractmethod
    def inference(self, data):
        """Fit the wrapped model on *data* and return its predictions."""
        raise NotImplementedError
+
+
class ArimaModel(BuiltInModel):
    """Wrapper around sktime's ARIMA forecaster, configured from the parsed
    arima_attribute_map values."""

    def __init__(self, attributes):
        super(ArimaModel, self).__init__(attributes)
        # All constructor arguments come straight from the validated attributes.
        self._model = ARIMA(
            order=attributes['order'],
            seasonal_order=attributes['seasonal_order'],
            method=attributes['method'],
            suppress_warnings=attributes['suppress_warnings'],
            maxiter=attributes['maxiter'],
            out_of_sample_size=attributes['out_of_sample_size'],
            scoring=attributes['scoring'],
            with_intercept=attributes['with_intercept'],
            time_varying_regression=attributes['time_varying_regression'],
            enforce_stationarity=attributes['enforce_stationarity'],
            enforce_invertibility=attributes['enforce_invertibility'],
            simple_differencing=attributes['simple_differencing'],
            measurement_error=attributes['measurement_error'],
            mle_regression=attributes['mle_regression'],
            hamilton_representation=attributes['hamilton_representation'],
            concentrate_scale=attributes['concentrate_scale']
        )

    def inference(self, data):
        """Fit on *data*, forecast 'predict_length' steps, return a float64 array.

        Any failure from fit/predict is re-raised as InferenceModelInternalError.
        """
        try:
            predict_length = self._attributes['predict_length']
            self._model.fit(data)
            # NOTE(review): fh=range(predict_length) yields horizons 0..n-1;
            # sktime relative horizons conventionally start at 1 — confirm the
            # off-by-zero horizon is intended.
            output = self._model.predict(fh=range(predict_length))
            output = np.array(output, dtype=np.float64)
            return output
        except Exception as e:
            raise InferenceModelInternalError(str(e))
+
+
class ExponentialSmoothingModel(BuiltInModel):
    """Wrapper around sktime's ExponentialSmoothing forecaster, configured from
    the parsed exponential_smoothing_attribute_map values."""

    def __init__(self, attributes):
        super(ExponentialSmoothingModel, self).__init__(attributes)
        self._model = ExponentialSmoothing(
            damped_trend=attributes['damped_trend'],
            initialization_method=attributes['initialization_method'],
            optimized=attributes['optimized'],
            remove_bias=attributes['remove_bias'],
            use_brute=attributes['use_brute']
        )

    def inference(self, data):
        """Fit on *data*, forecast 'predict_length' steps, return a float64 array.

        Any failure from fit/predict is re-raised as InferenceModelInternalError.
        """
        try:
            predict_length = self._attributes['predict_length']
            self._model.fit(data)
            output = self._model.predict(fh=range(predict_length))
            output = np.array(output, dtype=np.float64)
            return output
        except Exception as e:
            raise InferenceModelInternalError(str(e))
+
+
+class NaiveForecasterModel(BuiltInModel):
+    def __init__(self, attributes):
+        super(NaiveForecasterModel, self).__init__(attributes)
+        self._model = NaiveForecaster(
+            strategy=attributes['strategy'],
+            sp=attributes['sp']
+        )
+
+    def inference(self, data):
+        try:
+            predict_length = self._attributes['predict_length']
+            self._model.fit(data)
+            output = self._model.predict(fh=range(predict_length))
+            output = np.array(output, dtype=np.float64)
+            return output
+        except Exception as e:
+            raise InferenceModelInternalError(str(e))
+
+
class STLForecasterModel(BuiltInModel):
    """Wrapper around sktime's STLForecaster, configured from the parsed
    stl_forecaster_attribute_map values."""

    def __init__(self, attributes):
        super(STLForecasterModel, self).__init__(attributes)
        # NOTE(review): this reads 'low_pass_jump' while the attribute map key
        # is built from AttributeName.LOSS_PASS_JUMP — confirm that constant's
        # value is "low_pass_jump", otherwise this raises KeyError.
        self._model = STLForecaster(
            sp=attributes['sp'],
            seasonal=attributes['seasonal'],
            seasonal_deg=attributes['seasonal_deg'],
            trend_deg=attributes['trend_deg'],
            low_pass_deg=attributes['low_pass_deg'],
            seasonal_jump=attributes['seasonal_jump'],
            trend_jump=attributes['trend_jump'],
            low_pass_jump=attributes['low_pass_jump']
        )

    def inference(self, data):
        """Fit on *data*, forecast 'predict_length' steps, return a float64 array.

        Any failure from fit/predict is re-raised as InferenceModelInternalError.
        """
        try:
            predict_length = self._attributes['predict_length']
            self._model.fit(data)
            output = self._model.predict(fh=range(predict_length))
            output = np.array(output, dtype=np.float64)
            return output
        except Exception as e:
            raise InferenceModelInternalError(str(e))
+
+
class GMMHMMModel(BuiltInModel):
    """Built-in wrapper around hmmlearn's GMMHMM for hidden-state labelling."""

    # Keys forwarded verbatim from the model attributes to GMMHMM.
    _HYPER_PARAMS = (
        'n_components', 'n_mix', 'min_covar', 'startprob_prior',
        'transmat_prior', 'means_prior', 'means_weight', 'weights_prior',
        'algorithm', 'covariance_type', 'n_iter', 'tol', 'params',
        'init_params', 'implementation',
    )

    def __init__(self, attributes):
        super().__init__(attributes)
        self._model = GMMHMM(**{key: attributes[key] for key in self._HYPER_PARAMS})

    def inference(self, data):
        """Fit the HMM on ``data`` and return the predicted state sequence as int32."""
        try:
            self._model.fit(data)
            states = self._model.predict(data)
            return np.array(states, dtype=np.int32)
        except Exception as e:
            raise InferenceModelInternalError(str(e))
+
+
class GaussianHmmModel(BuiltInModel):
    """Built-in wrapper around hmmlearn's GaussianHMM for hidden-state labelling."""

    # Keys forwarded verbatim from the model attributes to GaussianHMM.
    _HYPER_PARAMS = (
        'n_components', 'covariance_type', 'min_covar', 'startprob_prior',
        'transmat_prior', 'means_prior', 'means_weight', 'covars_prior',
        'covars_weight', 'algorithm', 'n_iter', 'tol', 'params',
        'init_params', 'implementation',
    )

    def __init__(self, attributes):
        super().__init__(attributes)
        self._model = GaussianHMM(**{key: attributes[key] for key in self._HYPER_PARAMS})

    def inference(self, data):
        """Fit the HMM on ``data`` and return the predicted state sequence as int32."""
        try:
            self._model.fit(data)
            states = self._model.predict(data)
            return np.array(states, dtype=np.int32)
        except Exception as e:
            raise InferenceModelInternalError(str(e))
+
+
class STRAYModel(BuiltInModel):
    """Built-in wrapper around sktime's STRAY anomaly detector."""

    def __init__(self, attributes):
        super().__init__(attributes)
        self._model = STRAY(
            alpha=attributes['alpha'],
            k=attributes['k'],
            knn_algorithm=attributes['knn_algorithm'],
            p=attributes['p'],
            size_threshold=attributes['size_threshold'],
            outlier_tail=attributes['outlier_tail'],
        )

    def inference(self, data):
        """Min-max scale ``data``, run STRAY, and return outlier flags as int32."""
        try:
            scaled = MinMaxScaler().fit_transform(data)
            flags = self._model.fit_transform(scaled)
            # convert the detector's output to integer flags for the caller
            return np.array(flags, dtype=np.int32)
        except Exception as e:
            raise InferenceModelInternalError(str(e))
diff --git a/iotdb-core/ainode/ainode/core/model/model_factory.py 
b/iotdb-core/ainode/ainode/core/model/model_factory.py
new file mode 100644
index 00000000000..1700dd28eb6
--- /dev/null
+++ b/iotdb-core/ainode/ainode/core/model/model_factory.py
@@ -0,0 +1,235 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+import os
+import shutil
+from urllib.parse import urlparse, urljoin
+
+import yaml
+from requests import Session
+from requests.adapters import HTTPAdapter
+
+from ainode.core.constant import DEFAULT_RECONNECT_TIMES, 
DEFAULT_RECONNECT_TIMEOUT, DEFAULT_CHUNK_SIZE, \
+    DEFAULT_CONFIG_FILE_NAME, DEFAULT_MODEL_FILE_NAME
+from ainode.core.exception import InvalidUriError, BadConfigValueError
+from ainode.core.log import Logger
+from ainode.core.util.serde import get_data_type_byte_from_str
+from ainode.thrift.ainode.ttypes import TConfigs
+
+HTTP_PREFIX = "http://";
+HTTPS_PREFIX = "https://";
+
+logger = Logger()
+
+
+def _parse_uri(uri):
+    """
+    Args:
+        uri (str): uri to parse
+    Returns:
+        is_network_path (bool): True if the url is a network path, False 
otherwise
+        parsed_uri (str): parsed uri to get related file
+    """
+
+    parse_result = urlparse(uri)
+    is_network_path = parse_result.scheme in ('http', 'https')
+    if is_network_path:
+        return True, uri
+
+    # handle file:// in uri
+    if parse_result.scheme == 'file':
+        uri = uri[7:]
+
+    # handle ~ in uri
+    uri = os.path.expanduser(uri)
+    return False, uri
+
+
def _download_file(url: str, storage_path: str) -> None:
    """
    Download ``url`` to ``storage_path`` in chunks.

    The mounted HTTPAdapter retries the connection up to
    DEFAULT_RECONNECT_TIMES; a non-2xx response raises requests.HTTPError.

    Args:
        url: url of file to download
        storage_path: path to save the file
    Returns:
        None
    """
    logger.debug(f"download file from {url} to {storage_path}")

    # Context-manage the session and the streamed response so the underlying
    # sockets are always released, even when an exception is raised mid-download
    # (the original code leaked both).
    with Session() as session:
        adapter = HTTPAdapter(max_retries=DEFAULT_RECONNECT_TIMES)
        session.mount(HTTP_PREFIX, adapter)
        session.mount(HTTPS_PREFIX, adapter)

        with session.get(url, timeout=DEFAULT_RECONNECT_TIMEOUT, stream=True) as response:
            response.raise_for_status()
            # stream=True + iter_content keeps memory bounded for large model files
            with open(storage_path, 'wb') as file:
                for chunk in response.iter_content(chunk_size=DEFAULT_CHUNK_SIZE):
                    if chunk:
                        file.write(chunk)

    logger.debug(f"download file from {url} to {storage_path} success")
+
+
def _register_model_from_network(uri: str, model_storage_path: str,
                                 config_storage_path: str) -> [TConfigs, str]:
    """
    Fetch and register a model from a remote directory.

    Args:
        uri: network dir path of model to register, where model.pt and config.yaml are required,
            e.g. https://huggingface.co/user/modelname/resolve/main/
        model_storage_path: path to save model.pt
        config_storage_path: path to save config.yaml
    Returns:
        configs: TConfigs
        attributes: str
    """
    # ensure the directory uri ends with '/' so urljoin appends the file names
    if not uri.endswith("/"):
        uri += "/"
    model_url = urljoin(uri, DEFAULT_MODEL_FILE_NAME)
    config_url = urljoin(uri, DEFAULT_CONFIG_FILE_NAME)

    # fetch and validate the config first, so an invalid config fails fast
    _download_file(config_url, config_storage_path)
    with open(config_storage_path, 'r', encoding='utf-8') as file:
        configs, attributes = _parse_inference_config(yaml.safe_load(file))

    # config is valid -> download the (typically much larger) model file
    _download_file(model_url, model_storage_path)
    return configs, attributes
+
+
def _register_model_from_local(uri: str, model_storage_path: str,
                               config_storage_path: str) -> [TConfigs, str]:
    """
    Copy and register a model from a local directory.

    Args:
        uri: local dir path of model to register, where model.pt and config.yaml are required,
            e.g. /Users/admin/Desktop/model
        model_storage_path: path to save model.pt
        config_storage_path: path to save config.yaml
    Returns:
        configs: TConfigs
        attributes: str
    Raises:
        InvalidUriError: if model.pt or config.yaml is missing under uri
    """
    # concat uri to get complete paths
    target_model_path = os.path.join(uri, DEFAULT_MODEL_FILE_NAME)
    target_config_path = os.path.join(uri, DEFAULT_CONFIG_FILE_NAME)

    # guard clause: both files must exist before anything is copied.
    # (The original `elif not a or not b` branch was just the negation of the
    # `if`, i.e. a plain else; the None pre-initialisation was dead code.)
    if not (os.path.exists(target_model_path) and os.path.exists(target_config_path)):
        raise InvalidUriError(uri)

    # copy config.yaml
    logger.debug(f"copy file from {target_config_path} to {config_storage_path}")
    shutil.copy(target_config_path, config_storage_path)
    logger.debug(f"copy file from {target_config_path} to {config_storage_path} success")

    # read and parse config dict from config.yaml; an invalid config raises
    # here, before the (larger) model file is copied
    with open(config_storage_path, 'r', encoding='utf-8') as file:
        config_dict = yaml.safe_load(file)
    configs, attributes = _parse_inference_config(config_dict)

    # if config.yaml is correct, copy model file
    logger.debug(f"copy file from {target_model_path} to {model_storage_path}")
    shutil.copy(target_model_path, model_storage_path)
    logger.debug(f"copy file from {target_model_path} to {model_storage_path} success")

    return configs, attributes
+
+
def _validate_shape(configs, name):
    """Ensure configs[name] is a two-element list of positive ints, else raise."""
    shape = configs[name]
    if not (isinstance(shape, list) and len(shape) == 2):
        raise BadConfigValueError(name, shape,
                                  f'{name} should be a two-dimensional array.')
    if not all(isinstance(dim, int) and dim > 0 for dim in shape):
        raise BadConfigValueError(name, shape,
                                  f'element in {name} should be positive integer.')


def _parse_type_list(configs, type_key, shape_key):
    """Validate configs[type_key] (if present) and convert each entry to its
    type byte; missing type lists default to float32 for every column."""
    width = configs[shape_key][1]
    if type_key not in configs:
        return [get_data_type_byte_from_str('float32')] * width
    type_list = configs[type_key]
    if not (isinstance(type_list, list) and len(type_list) == width):
        raise BadConfigValueError(
            type_key, type_list,
            f'{type_key} should be a one-dimensional array and length of it should be equal to {shape_key}[1].')
    return [get_data_type_byte_from_str(t) for t in type_list]


def _parse_inference_config(config_dict):
    """
    Validate a model's inference configuration and convert it to thrift types.

    Args:
        config_dict: dict
            - configs: dict
                - input_shape (list<i32>): input shape of the model and needs to be two-dimensional array like [96, 2]
                - output_shape (list<i32>): output shape of the model and needs to be two-dimensional array like [96, 2]
                - input_type (list<str>): input type of the model and each element needs to be in ['bool', 'int32', 'int64', 'float32', 'float64', 'text'], default float32
                - output_type (list<str>): output type of the model and each element needs to be in ['bool', 'int32', 'int64', 'float32', 'float64', 'text'], default float32
            - attributes: dict (optional); stringified into the returned attributes
    Returns:
        configs: TConfigs
        attributes: str
    Raises:
        BadConfigValueError: if any shape or type entry is malformed
    """
    # NOTE: the original docstring claimed a float64 default, but the code has
    # always defaulted missing type lists to float32 -- docstring fixed to match.
    configs = config_dict['configs']

    # shapes must be [rows, cols] pairs of positive integers
    _validate_shape(configs, 'input_shape')
    _validate_shape(configs, 'output_shape')

    # per-column type bytes (default float32 when unspecified)
    input_type = _parse_type_list(configs, 'input_type', 'input_shape')
    output_type = _parse_type_list(configs, 'output_type', 'output_shape')

    # parse attributes
    attributes = str(config_dict['attributes']) if 'attributes' in config_dict else ""

    return TConfigs(configs['input_shape'], configs['output_shape'], input_type, output_type), attributes
+
+
def fetch_model_by_uri(uri: str, model_storage_path: str, config_storage_path: str):
    """Register a model from either a network url or a local directory.

    Dispatches on the uri scheme: http/https uris are downloaded, anything
    else is treated as a local path and copied.
    """
    is_network_path, parsed_uri = _parse_uri(uri)
    register = _register_model_from_network if is_network_path else _register_model_from_local
    return register(parsed_uri, model_storage_path, config_storage_path)
diff --git a/iotdb-core/ainode/ainode/core/model/model_storage.py 
b/iotdb-core/ainode/ainode/core/model/model_storage.py
new file mode 100644
index 00000000000..a14e1ef4cc7
--- /dev/null
+++ b/iotdb-core/ainode/ainode/core/model/model_storage.py
@@ -0,0 +1,123 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+import os
+import shutil
+from collections.abc import Callable
+
+import torch
+import torch._dynamo
+from pylru import lrucache
+
+from ainode.core.config import AINodeDescriptor
+from ainode.core.constant import (DEFAULT_MODEL_FILE_NAME,
+                                  DEFAULT_CONFIG_FILE_NAME)
+from ainode.core.exception import ModelNotExistError
+from ainode.core.log import Logger
+from ainode.core.model.model_factory import fetch_model_by_uri
+from ainode.core.util.lock import ModelLockPool
+from ainode.core.manager.training_manager import get_args
+from ainode.core.model.TimerXL.models.timer_xl import Model
+
+logger = Logger()
+
+
class ModelStorage(object):
    """Persists registered model files on disk and caches loaded models.

    On-disk layout: <cwd>/<ain_models_dir>/<model_id>/{model.pt, config.yaml}.
    Loaded models are held in an LRU cache keyed by model file path; a
    per-model-id lock pool guards concurrent load/delete.
    """

    def __init__(self):
        # Root directory for all model files; created eagerly so later calls
        # can assume it exists.
        self._model_dir = os.path.join(os.getcwd(), AINodeDescriptor().get_config().get_ain_models_dir())
        if not os.path.exists(self._model_dir):
            try:
                os.makedirs(self._model_dir)
            except PermissionError as e:
                # Storage is unusable without this directory: log and re-raise.
                logger.error(e)
                raise e
        self._lock_pool = ModelLockPool()
        # LRU cache of loaded models, keyed by the model file path.
        self._model_cache = lrucache(AINodeDescriptor().get_config().get_ain_model_storage_cache_size())

    def register_model(self, model_id: str, uri: str):
        """
        Fetch a model's files into this node's storage directory.

        Args:
            model_id: id of model to register
            uri: network dir path or local dir path of model to register, where model.pt and config.yaml are required,
                e.g. https://huggingface.co/user/modelname/resolve/main/ or /Users/admin/Desktop/model
        Returns:
            configs: TConfigs
            attributes: str
        """
        storage_path = os.path.join(self._model_dir, f'{model_id}')
        # create storage dir if not exist
        if not os.path.exists(storage_path):
            os.makedirs(storage_path)
        model_storage_path = os.path.join(storage_path, DEFAULT_MODEL_FILE_NAME)
        config_storage_path = os.path.join(storage_path, DEFAULT_CONFIG_FILE_NAME)
        return fetch_model_by_uri(uri, model_storage_path, config_storage_path)

    def load_model(self, model_id: str, acceleration: bool) -> Callable:
        """
        Load (and cache) the model registered under ``model_id``.

        Args:
            model_id: id of the model to load
            acceleration: if True, try to speed up the model with torch.compile
        Returns:
            model: a ScriptModule contains model architecture and parameters, which can be deployed cross-platform
        Raises:
            ModelNotExistError: if no model file exists for ``model_id``
        """
        ain_models_dir = os.path.join(self._model_dir, f'{model_id}')
        model_path = os.path.join(ain_models_dir, DEFAULT_MODEL_FILE_NAME)
        # NOTE(review): the cache is mutated below while only a *read* lock is
        # held -- confirm ModelLockPool tolerates concurrent cache writes.
        with self._lock_pool.get_lock(model_id).read_lock():
            if model_path in self._model_cache:
                model = self._model_cache[model_path]
                # An already torch.compile'd model (OptimizedModule) is returned
                # as-is; otherwise compile on demand when acceleration is asked.
                if isinstance(model, torch._dynamo.eval_frame.OptimizedModule) or not acceleration:
                    return model
                else:
                    model = torch.compile(model)
                    self._model_cache[model_path] = model
                    return model
            else:
                # todo: use modelType instead
                if 'timer' in model_id:
                    # Timer-XL models ship as a raw state dict (checkpoint.pth)
                    # rather than a TorchScript model.pt.
                    # NOTE(review): torch.load without map_location assumes the
                    # checkpoint's device is available here -- confirm.
                    model_file_path = os.path.join(ain_models_dir, 'checkpoint.pth')
                    state_dict = torch.load(model_file_path)
                    model = Model(get_args())
                    model.load_state_dict(state_dict)
                    self._model_cache[model_path] = model
                    return model
                elif not os.path.exists(model_path):
                    raise ModelNotExistError(model_path)
                else:
                    model = torch.jit.load(model_path)
                    if acceleration:
                        # best-effort: keep the uncompiled model on failure
                        try:
                            model = torch.compile(model)
                        except Exception as e:
                            logger.warning(f"acceleration failed, fallback to normal mode: {str(e)}")
                    self._model_cache[model_path] = model
                    return model

    def delete_model(self, model_id: str) -> None:
        """
        Remove a model's files from disk and evict them from the cache.

        Args:
            model_id: id of model to delete
        Returns:
            None
        """
        storage_path = os.path.join(self._model_dir, f'{model_id}')
        with self._lock_pool.get_lock(model_id).write_lock():
            if os.path.exists(storage_path):
                # evict every cached entry under this model's directory first
                for file_name in os.listdir(storage_path):
                    self._remove_from_cache(os.path.join(storage_path, file_name))
                shutil.rmtree(storage_path)

    def _remove_from_cache(self, file_path: str) -> None:
        # Drop a single cache entry if present (no-op otherwise).
        if file_path in self._model_cache:
            del self._model_cache[file_path]

Reply via email to