This is an automated email from the ASF dual-hosted git repository.
ycycse pushed a commit to branch ycy/AINodeTraining
in repository https://gitbox.apache.org/repos/asf/iotdb.git
The following commit(s) were added to refs/heads/ycy/AINodeTraining by this push:
new 8404a7c5c56 restore model part
8404a7c5c56 is described below
commit 8404a7c5c564be46f1f493746087d9ff2fc2c03d
Author: YangCaiyin <[email protected]>
AuthorDate: Wed Mar 5 19:00:05 2025 +0800
restore model part
---
iotdb-core/ainode/ainode/core/model/__init__.py | 17 +
.../ainode/core/model/built_in_model_factory.py | 924 +++++++++++++++++++++
.../ainode/ainode/core/model/model_factory.py | 235 ++++++
.../ainode/ainode/core/model/model_storage.py | 123 +++
4 files changed, 1299 insertions(+)
diff --git a/iotdb-core/ainode/ainode/core/model/__init__.py b/iotdb-core/ainode/ainode/core/model/__init__.py
new file mode 100644
index 00000000000..2a1e720805f
--- /dev/null
+++ b/iotdb-core/ainode/ainode/core/model/__init__.py
@@ -0,0 +1,17 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
diff --git a/iotdb-core/ainode/ainode/core/model/built_in_model_factory.py b/iotdb-core/ainode/ainode/core/model/built_in_model_factory.py
new file mode 100644
index 00000000000..4524272dfaf
--- /dev/null
+++ b/iotdb-core/ainode/ainode/core/model/built_in_model_factory.py
@@ -0,0 +1,924 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+from abc import abstractmethod
+from typing import List, Dict
+
+import numpy as np
+from sklearn.preprocessing import MinMaxScaler
+from sktime.annotation.hmm_learn import GaussianHMM, GMMHMM
+from sktime.annotation.stray import STRAY
+from sktime.forecasting.arima import ARIMA
+from sktime.forecasting.exp_smoothing import ExponentialSmoothing
+from sktime.forecasting.naive import NaiveForecaster
+from sktime.forecasting.trend import STLForecaster
+
+from ainode.core.constant import AttributeName, BuiltInModelType
+from ainode.core.exception import InferenceModelInternalError,
AttributeNotSupportError
+from ainode.core.exception import WrongAttributeTypeError,
NumericalRangeException, StringRangeException, \
+ ListRangeException, BuiltInModelNotSupportError
+from ainode.core.log import Logger
+
+logger = Logger()
+
+
def get_model_attributes(model_id: str):
    """Return the attribute map registered for a built-in model.

    Args:
        model_id: identifier of the built-in model; it is compared against
            the ``value`` of each ``BuiltInModelType`` member.
    Returns:
        The module-level attribute map (attribute name -> ``Attribute``)
        belonging to the matching model.
    Raises:
        BuiltInModelNotSupportError: if ``model_id`` matches no known model.
    """
    # Checked in order; the first id equal to model_id wins.
    known_models = (
        (BuiltInModelType.ARIMA.value, arima_attribute_map),
        (BuiltInModelType.NAIVE_FORECASTER.value, naive_forecaster_attribute_map),
        (BuiltInModelType.EXPONENTIAL_SMOOTHING.value, exponential_smoothing_attribute_map),
        (BuiltInModelType.STL_FORECASTER.value, stl_forecaster_attribute_map),
        (BuiltInModelType.GMM_HMM.value, gmmhmm_attribute_map),
        (BuiltInModelType.GAUSSIAN_HMM.value, gaussian_hmm_attribute_map),
        (BuiltInModelType.STRAY.value, stray_attribute_map),
    )
    for known_id, known_map in known_models:
        if model_id == known_id:
            return known_map
    raise BuiltInModelNotSupportError(model_id)
+
+
def fetch_built_in_model(model_id, inference_attributes):
    """Build a ready-to-use built-in model from user-supplied attributes.

    Built-in models need no user registration: the attribute map for
    ``model_id`` is looked up, the supplied attributes are checked and parsed
    against it, and the matching model wrapper is constructed.

    Args:
        model_id: the unique id of the built-in model.
        inference_attributes: mapping of attribute name to its raw string
            value. Attributes that are not supplied fall back to the default
            declared in the model's attribute map.
    Returns:
        The constructed built-in model wrapper (only the model is returned).
    Raises:
        BuiltInModelNotSupportError: if ``model_id`` is not a built-in model.
        AttributeNotSupportError: if an attribute name is unknown to the model.
    """
    attribute_map = get_model_attributes(model_id)

    # Reject the request as soon as one unknown attribute name is seen.
    for supplied_name in inference_attributes:
        if supplied_name not in attribute_map:
            raise AttributeNotSupportError(model_id, supplied_name)

    # Dict[str, Any]: user values parsed and validated, defaults filled in.
    attributes = parse_attribute(inference_attributes, attribute_map)

    # Checked in order; the first id equal to model_id decides the wrapper.
    model_classes = (
        (BuiltInModelType.ARIMA.value, ArimaModel),
        (BuiltInModelType.EXPONENTIAL_SMOOTHING.value, ExponentialSmoothingModel),
        (BuiltInModelType.NAIVE_FORECASTER.value, NaiveForecasterModel),
        (BuiltInModelType.STL_FORECASTER.value, STLForecasterModel),
        (BuiltInModelType.GMM_HMM.value, GMMHMMModel),
        (BuiltInModelType.GAUSSIAN_HMM.value, GaussianHmmModel),
        (BuiltInModelType.STRAY.value, STRAYModel),
    )
    for known_id, model_class in model_classes:
        if model_id == known_id:
            return model_class(attributes)
    raise BuiltInModelNotSupportError(model_id)
+
+
class Attribute:
    """Base description of one configurable attribute of a built-in model.

    Subclasses supply the default value, the validation rule and the parser
    that turns the user's string into a typed value. The class does not use
    ``ABCMeta``, so ``@abstractmethod`` does not block instantiation here; a
    method that is not overridden simply raises ``NotImplementedError``.
    """

    def __init__(self, name: str):
        """
        Args:
            name: the name of the attribute
        """
        self._name = name

    @abstractmethod
    def get_default_value(self):
        """Return the value used when the user does not set the attribute."""
        raise NotImplementedError

    @abstractmethod
    def validate_value(self, value):
        """Check an already parsed value against the attribute's constraints."""
        raise NotImplementedError

    @abstractmethod
    def parse(self, string_value: str):
        """Convert the raw string supplied by the user into a typed value."""
        raise NotImplementedError
+
+
class IntAttribute(Attribute):
    """Integer attribute restricted to an inclusive ``[low, high]`` range."""

    def __init__(self, name: str,
                 default_value: int,
                 default_low: int,
                 default_high: int,
                 ):
        """
        Args:
            name: the name of the attribute
            default_value: value used when the user does not set the attribute
            default_low: smallest accepted value (inclusive)
            default_high: largest accepted value (inclusive)
        """
        super(IntAttribute, self).__init__(name)
        self.__default_value = default_value
        self.__default_low = default_low
        self.__default_high = default_high

    def get_default_value(self):
        """Return the configured default value."""
        return self.__default_value

    def validate_value(self, value):
        """Return True if ``value`` lies in the inclusive range.

        Raises:
            NumericalRangeException: if ``value`` is outside the range.
        """
        if self.__default_low <= value <= self.__default_high:
            return True
        raise NumericalRangeException(self._name, value, self.__default_low, self.__default_high)

    def parse(self, string_value: str):
        """Convert the raw string to ``int``.

        Raises:
            WrongAttributeTypeError: if the text is not a valid integer.
        """
        try:
            return int(string_value)
        # Only conversion failures are translated; a bare ``except`` would
        # also swallow KeyboardInterrupt and SystemExit.
        except (TypeError, ValueError):
            raise WrongAttributeTypeError(self._name, "int")
+
+
class FloatAttribute(Attribute):
    """Floating-point attribute restricted to an inclusive ``[low, high]`` range."""

    def __init__(self, name: str,
                 default_value: float,
                 default_low: float,
                 default_high: float,
                 ):
        """
        Args:
            name: the name of the attribute
            default_value: value used when the user does not set the attribute
            default_low: smallest accepted value (inclusive)
            default_high: largest accepted value (inclusive)
        """
        super(FloatAttribute, self).__init__(name)
        self.__default_value = default_value
        self.__default_low = default_low
        self.__default_high = default_high

    def get_default_value(self):
        """Return the configured default value."""
        return self.__default_value

    def validate_value(self, value):
        """Return True if ``value`` lies in the inclusive range.

        Raises:
            NumericalRangeException: if ``value`` is outside the range.
        """
        if self.__default_low <= value <= self.__default_high:
            return True
        raise NumericalRangeException(self._name, value, self.__default_low, self.__default_high)

    def parse(self, string_value: str):
        """Convert the raw string to ``float``.

        Raises:
            WrongAttributeTypeError: if the text is not a valid float.
        """
        try:
            return float(string_value)
        # Only conversion failures are translated; a bare ``except`` would
        # also swallow KeyboardInterrupt and SystemExit.
        except (TypeError, ValueError):
            raise WrongAttributeTypeError(self._name, "float")
+
+
class StringAttribute(Attribute):
    """String attribute whose value must be one of a fixed set of choices."""

    def __init__(self, name: str, default_value: str, value_choices: List[str]):
        """
        Args:
            name: the name of the attribute
            default_value: value used when the user does not set the attribute
            value_choices: the accepted values
        """
        super().__init__(name)
        self.__default_value = default_value
        self.__value_choices = value_choices

    def get_default_value(self):
        """Return the configured default value."""
        return self.__default_value

    def validate_value(self, value):
        """Return True if ``value`` is an accepted choice.

        Raises:
            StringRangeException: if ``value`` is not among the choices.
        """
        if value not in self.__value_choices:
            raise StringRangeException(self._name, value, self.__value_choices)
        return True

    def parse(self, string_value: str):
        """Return the raw string unchanged; no conversion is needed."""
        return string_value
+
+
class BooleanAttribute(Attribute):
    """Boolean attribute written by the user as "true" or "false"."""

    def __init__(self, name: str, default_value: bool):
        """
        Args:
            name: the name of the attribute
            default_value: value used when the user does not set the attribute
        """
        super().__init__(name)
        self.__default_value = default_value

    def get_default_value(self):
        """Return the configured default value."""
        return self.__default_value

    def validate_value(self, value):
        """Return True if ``value`` is a ``bool``.

        Raises:
            WrongAttributeTypeError: if ``value`` is not a ``bool``.
        """
        if not isinstance(value, bool):
            raise WrongAttributeTypeError(self._name, "bool")
        return True

    def parse(self, string_value: str):
        """Map "true"/"false" (any letter case) to the matching ``bool``.

        Raises:
            WrongAttributeTypeError: for any other text.
        """
        literals = {"true": True, "false": False}
        lowered = string_value.lower()
        if lowered not in literals:
            raise WrongAttributeTypeError(self._name, "bool")
        return literals[lowered]
+
+
class ListAttribute(Attribute):
    """List attribute whose elements all share one element type."""

    def __init__(self, name: str, default_value: List, value_type):
        """
        Args:
            name: the name of the attribute
            default_value: value used when the user does not set the attribute
            value_type: the type of the elements in the list, e.g. int,
                float, str
        """
        super(ListAttribute, self).__init__(name)
        self.__default_value = default_value
        self.__value_type = value_type
        self.__type_to_str = {str: "str", int: "int", float: "float"}

    def __element_type_name(self):
        """Return a printable name for the element type, even if it is unmapped."""
        fallback = getattr(self.__value_type, "__name__", str(self.__value_type))
        return self.__type_to_str.get(self.__value_type, fallback)

    def get_default_value(self):
        """Return the configured default value."""
        return self.__default_value

    def validate_value(self, value):
        """Return True if ``value`` is a list whose items have the element type.

        Raises:
            WrongAttributeTypeError: if ``value`` is not a list or an item
                has the wrong type.
        """
        if not isinstance(value, list):
            raise WrongAttributeTypeError(self._name, "list")
        for value_item in value:
            if not isinstance(value_item, self.__value_type):
                # Report the type by name, as the other call sites do.
                raise WrongAttributeTypeError(self._name, self.__element_type_name())
        return True

    def parse(self, string_value: str):
        """Parse a Python list literal such as ``"[1, 2, 3]"``.

        Raises:
            WrongAttributeTypeError: if the text is not a list literal.
            ListRangeException: if an item cannot be converted to the
                element type.
        """
        # SECURITY: the text comes from the user's request, so it is parsed
        # with ast.literal_eval instead of eval. Only literals are accepted;
        # arbitrary expressions are rejected instead of being executed.
        import ast
        try:
            list_value = ast.literal_eval(string_value)
        except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
            raise WrongAttributeTypeError(self._name, "list")
        if not isinstance(list_value, list):
            raise WrongAttributeTypeError(self._name, "list")
        try:
            return [self.__value_type(item) for item in list_value]
        except (TypeError, ValueError):
            raise ListRangeException(self._name, list_value, self.__element_type_name())
+
+
class TupleAttribute(Attribute):
    """Tuple attribute whose elements all share one element type."""

    def __init__(self, name: str, default_value: tuple, value_type):
        """
        Args:
            name: the name of the attribute
            default_value: value used when the user does not set the attribute
            value_type: the type of the elements in the tuple, e.g. int,
                float, str
        """
        super(TupleAttribute, self).__init__(name)
        self.__default_value = default_value
        self.__value_type = value_type
        self.__type_to_str = {str: "str", int: "int", float: "float"}

    def __element_type_name(self):
        """Return a printable name for the element type, even if it is unmapped."""
        fallback = getattr(self.__value_type, "__name__", str(self.__value_type))
        return self.__type_to_str.get(self.__value_type, fallback)

    def get_default_value(self):
        """Return the configured default value."""
        return self.__default_value

    def validate_value(self, value):
        """Return True if ``value`` is a tuple whose items have the element type.

        Raises:
            WrongAttributeTypeError: if ``value`` is not a tuple or an item
                has the wrong type.
        """
        if not isinstance(value, tuple):
            raise WrongAttributeTypeError(self._name, "tuple")
        for value_item in value:
            if not isinstance(value_item, self.__value_type):
                # Report the type by name, as the other call sites do.
                raise WrongAttributeTypeError(self._name, self.__element_type_name())
        return True

    def parse(self, string_value: str):
        """Parse a Python tuple literal such as ``"(1, 0, 0)"``.

        Raises:
            WrongAttributeTypeError: if the text is not a tuple literal.
            ListRangeException: if an item cannot be converted to the
                element type.
        """
        # SECURITY: the text comes from the user's request, so it is parsed
        # with ast.literal_eval instead of eval. Only literals are accepted;
        # arbitrary expressions are rejected instead of being executed.
        import ast
        try:
            tuple_value = ast.literal_eval(string_value)
        except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
            raise WrongAttributeTypeError(self._name, "tuple")
        if not isinstance(tuple_value, tuple):
            raise WrongAttributeTypeError(self._name, "tuple")
        try:
            return tuple(self.__value_type(item) for item in tuple_value)
        except (TypeError, ValueError):
            raise ListRangeException(self._name, list(tuple_value), self.__element_type_name())
+
+
def parse_attribute(input_attributes: Dict[str, str], attribute_map: Dict[str, Attribute]):
    """Turn raw attribute strings into typed, validated values.

    Args:
        input_attributes: a dict of attributes, where the key is the attribute
            name and the value is the string value given by the user
        attribute_map: a dict of hyperparameters, where the key is the
            attribute name and the value is the ``Attribute`` object
    Returns:
        a dict with one entry per name in ``attribute_map``: the parsed user
        value when supplied, otherwise the attribute's default value
    """
    parsed = {}
    for name, attribute in attribute_map.items():
        if name not in input_attributes:
            # Not supplied by the user: fall back to the declared default.
            try:
                parsed[name] = attribute.get_default_value()
            except NotImplementedError as e:
                logger.error(f"attribute {name} is not implemented.")
                raise e
            continue
        # Supplied by the user: parse, then validate before accepting it.
        value = attribute.parse(input_attributes[name])
        attribute.validate_value(value)
        parsed[name] = value
    return parsed
+
+
+# built-in sktime model attributes
+# NaiveForecaster
# Attributes accepted by the NaiveForecaster built-in model, keyed by name.
naive_forecaster_attribute_map = {
    AttributeName.PREDICT_LENGTH.value: IntAttribute(
        name=AttributeName.PREDICT_LENGTH.value, default_value=1, default_low=1, default_high=5000,
    ),
    AttributeName.STRATEGY.value: StringAttribute(
        name=AttributeName.STRATEGY.value, default_value="last", value_choices=["last", "mean"],
    ),
    AttributeName.SP.value: IntAttribute(
        name=AttributeName.SP.value, default_value=1, default_low=1, default_high=5000,
    ),
}
+# ExponentialSmoothing
# Attributes accepted by the ExponentialSmoothing built-in model, keyed by name.
exponential_smoothing_attribute_map = {
    AttributeName.PREDICT_LENGTH.value: IntAttribute(
        name=AttributeName.PREDICT_LENGTH.value, default_value=1, default_low=1, default_high=5000,
    ),
    AttributeName.DAMPED_TREND.value: BooleanAttribute(
        name=AttributeName.DAMPED_TREND.value, default_value=False,
    ),
    AttributeName.INITIALIZATION_METHOD.value: StringAttribute(
        name=AttributeName.INITIALIZATION_METHOD.value,
        default_value="estimated",
        value_choices=["estimated", "heuristic", "legacy-heuristic", "known"],
    ),
    AttributeName.OPTIMIZED.value: BooleanAttribute(
        name=AttributeName.OPTIMIZED.value, default_value=True,
    ),
    AttributeName.REMOVE_BIAS.value: BooleanAttribute(
        name=AttributeName.REMOVE_BIAS.value, default_value=False,
    ),
    AttributeName.USE_BRUTE.value: BooleanAttribute(
        name=AttributeName.USE_BRUTE.value, default_value=False,
    ),
}
+# Arima
# Attributes accepted by the ARIMA built-in model, keyed by name.
arima_attribute_map = {
    AttributeName.PREDICT_LENGTH.value: IntAttribute(
        name=AttributeName.PREDICT_LENGTH.value, default_value=1, default_low=1, default_high=5000,
    ),
    AttributeName.ORDER.value: TupleAttribute(
        name=AttributeName.ORDER.value, default_value=(1, 0, 0), value_type=int,
    ),
    AttributeName.SEASONAL_ORDER.value: TupleAttribute(
        name=AttributeName.SEASONAL_ORDER.value, default_value=(0, 0, 0, 0), value_type=int,
    ),
    AttributeName.METHOD.value: StringAttribute(
        name=AttributeName.METHOD.value,
        default_value="lbfgs",
        value_choices=["lbfgs", "bfgs", "newton", "nm", "cg", "ncg", "powell"],
    ),
    AttributeName.MAXITER.value: IntAttribute(
        name=AttributeName.MAXITER.value, default_value=1, default_low=1, default_high=5000,
    ),
    AttributeName.SUPPRESS_WARNINGS.value: BooleanAttribute(
        name=AttributeName.SUPPRESS_WARNINGS.value, default_value=True,
    ),
    AttributeName.OUT_OF_SAMPLE_SIZE.value: IntAttribute(
        name=AttributeName.OUT_OF_SAMPLE_SIZE.value, default_value=0, default_low=0, default_high=5000,
    ),
    AttributeName.SCORING.value: StringAttribute(
        name=AttributeName.SCORING.value,
        default_value="mse",
        value_choices=["mse", "mae", "rmse", "mape", "smape", "rmsle", "r2"],
    ),
    AttributeName.WITH_INTERCEPT.value: BooleanAttribute(
        name=AttributeName.WITH_INTERCEPT.value, default_value=True,
    ),
    AttributeName.TIME_VARYING_REGRESSION.value: BooleanAttribute(
        name=AttributeName.TIME_VARYING_REGRESSION.value, default_value=False,
    ),
    AttributeName.ENFORCE_STATIONARITY.value: BooleanAttribute(
        name=AttributeName.ENFORCE_STATIONARITY.value, default_value=True,
    ),
    AttributeName.ENFORCE_INVERTIBILITY.value: BooleanAttribute(
        name=AttributeName.ENFORCE_INVERTIBILITY.value, default_value=True,
    ),
    AttributeName.SIMPLE_DIFFERENCING.value: BooleanAttribute(
        name=AttributeName.SIMPLE_DIFFERENCING.value, default_value=False,
    ),
    AttributeName.MEASUREMENT_ERROR.value: BooleanAttribute(
        name=AttributeName.MEASUREMENT_ERROR.value, default_value=False,
    ),
    AttributeName.MLE_REGRESSION.value: BooleanAttribute(
        name=AttributeName.MLE_REGRESSION.value, default_value=True,
    ),
    AttributeName.HAMILTON_REPRESENTATION.value: BooleanAttribute(
        name=AttributeName.HAMILTON_REPRESENTATION.value, default_value=False,
    ),
    AttributeName.CONCENTRATE_SCALE.value: BooleanAttribute(
        name=AttributeName.CONCENTRATE_SCALE.value, default_value=False,
    ),
}
+# STLForecaster
# Attributes accepted by the STLForecaster built-in model, keyed by name.
stl_forecaster_attribute_map = {
    AttributeName.PREDICT_LENGTH.value: IntAttribute(
        name=AttributeName.PREDICT_LENGTH.value, default_value=1, default_low=1, default_high=5000,
    ),
    AttributeName.SP.value: IntAttribute(
        name=AttributeName.SP.value, default_value=2, default_low=1, default_high=5000,
    ),
    AttributeName.SEASONAL.value: IntAttribute(
        name=AttributeName.SEASONAL.value, default_value=7, default_low=1, default_high=5000,
    ),
    AttributeName.SEASONAL_DEG.value: IntAttribute(
        name=AttributeName.SEASONAL_DEG.value, default_value=1, default_low=0, default_high=5000,
    ),
    AttributeName.TREND_DEG.value: IntAttribute(
        name=AttributeName.TREND_DEG.value, default_value=1, default_low=0, default_high=5000,
    ),
    AttributeName.LOW_PASS_DEG.value: IntAttribute(
        name=AttributeName.LOW_PASS_DEG.value, default_value=1, default_low=0, default_high=5000,
    ),
    AttributeName.SEASONAL_JUMP.value: IntAttribute(
        name=AttributeName.SEASONAL_JUMP.value, default_value=1, default_low=0, default_high=5000,
    ),
    AttributeName.TREND_JUMP.value: IntAttribute(
        name=AttributeName.TREND_JUMP.value, default_value=1, default_low=0, default_high=5000,
    ),
    # NOTE(review): STLForecasterModel reads the key 'low_pass_jump'; confirm
    # that AttributeName.LOSS_PASS_JUMP.value is that exact string.
    AttributeName.LOSS_PASS_JUMP.value: IntAttribute(
        name=AttributeName.LOSS_PASS_JUMP.value, default_value=1, default_low=0, default_high=5000,
    ),
}
+
+# GAUSSIAN_HMM
# Attributes accepted by the GaussianHMM built-in model, keyed by name.
gaussian_hmm_attribute_map = {
    AttributeName.N_COMPONENTS.value: IntAttribute(
        name=AttributeName.N_COMPONENTS.value, default_value=1, default_low=1, default_high=5000,
    ),
    AttributeName.COVARIANCE_TYPE.value: StringAttribute(
        name=AttributeName.COVARIANCE_TYPE.value,
        default_value="diag",
        value_choices=["spherical", "diag", "full", "tied"],
    ),
    AttributeName.MIN_COVAR.value: FloatAttribute(
        name=AttributeName.MIN_COVAR.value, default_value=1e-3, default_low=-1e10, default_high=1e10,
    ),
    AttributeName.STARTPROB_PRIOR.value: FloatAttribute(
        name=AttributeName.STARTPROB_PRIOR.value, default_value=1.0, default_low=-1e10, default_high=1e10,
    ),
    AttributeName.TRANSMAT_PRIOR.value: FloatAttribute(
        name=AttributeName.TRANSMAT_PRIOR.value, default_value=1.0, default_low=-1e10, default_high=1e10,
    ),
    AttributeName.MEANS_PRIOR.value: FloatAttribute(
        name=AttributeName.MEANS_PRIOR.value, default_value=0.0, default_low=-1e10, default_high=1e10,
    ),
    AttributeName.MEANS_WEIGHT.value: FloatAttribute(
        name=AttributeName.MEANS_WEIGHT.value, default_value=0.0, default_low=-1e10, default_high=1e10,
    ),
    AttributeName.COVARS_PRIOR.value: FloatAttribute(
        name=AttributeName.COVARS_PRIOR.value, default_value=1e-2, default_low=-1e10, default_high=1e10,
    ),
    AttributeName.COVARS_WEIGHT.value: FloatAttribute(
        name=AttributeName.COVARS_WEIGHT.value, default_value=1.0, default_low=-1e10, default_high=1e10,
    ),
    AttributeName.ALGORITHM.value: StringAttribute(
        name=AttributeName.ALGORITHM.value, default_value="viterbi", value_choices=["viterbi", "map"],
    ),
    AttributeName.N_ITER.value: IntAttribute(
        name=AttributeName.N_ITER.value, default_value=10, default_low=1, default_high=5000,
    ),
    AttributeName.TOL.value: FloatAttribute(
        name=AttributeName.TOL.value, default_value=1e-2, default_low=-1e10, default_high=1e10,
    ),
    AttributeName.PARAMS.value: StringAttribute(
        name=AttributeName.PARAMS.value, default_value="stmc", value_choices=["stmc", "stm"],
    ),
    AttributeName.INIT_PARAMS.value: StringAttribute(
        name=AttributeName.INIT_PARAMS.value, default_value="stmc", value_choices=["stmc", "stm"],
    ),
    AttributeName.IMPLEMENTATION.value: StringAttribute(
        name=AttributeName.IMPLEMENTATION.value, default_value="log", value_choices=["log", "scaling"],
    ),
}
+
+# GMMHMM
# Attributes accepted by the GMMHMM built-in model, keyed by name.
gmmhmm_attribute_map = {
    AttributeName.N_COMPONENTS.value: IntAttribute(
        name=AttributeName.N_COMPONENTS.value, default_value=1, default_low=1, default_high=5000,
    ),
    AttributeName.N_MIX.value: IntAttribute(
        name=AttributeName.N_MIX.value, default_value=1, default_low=1, default_high=5000,
    ),
    AttributeName.MIN_COVAR.value: FloatAttribute(
        name=AttributeName.MIN_COVAR.value, default_value=1e-3, default_low=-1e10, default_high=1e10,
    ),
    AttributeName.STARTPROB_PRIOR.value: FloatAttribute(
        name=AttributeName.STARTPROB_PRIOR.value, default_value=1.0, default_low=-1e10, default_high=1e10,
    ),
    AttributeName.TRANSMAT_PRIOR.value: FloatAttribute(
        name=AttributeName.TRANSMAT_PRIOR.value, default_value=1.0, default_low=-1e10, default_high=1e10,
    ),
    AttributeName.WEIGHTS_PRIOR.value: FloatAttribute(
        name=AttributeName.WEIGHTS_PRIOR.value, default_value=1.0, default_low=-1e10, default_high=1e10,
    ),
    AttributeName.MEANS_PRIOR.value: FloatAttribute(
        name=AttributeName.MEANS_PRIOR.value, default_value=0.0, default_low=-1e10, default_high=1e10,
    ),
    AttributeName.MEANS_WEIGHT.value: FloatAttribute(
        name=AttributeName.MEANS_WEIGHT.value, default_value=0.0, default_low=-1e10, default_high=1e10,
    ),
    AttributeName.ALGORITHM.value: StringAttribute(
        name=AttributeName.ALGORITHM.value, default_value="viterbi", value_choices=["viterbi", "map"],
    ),
    # "spherical" was misspelled "sperical", which rejected the valid name
    # and accepted the invalid one; the list now matches the GaussianHMM map.
    AttributeName.COVARIANCE_TYPE.value: StringAttribute(
        name=AttributeName.COVARIANCE_TYPE.value,
        default_value="diag",
        value_choices=["spherical", "diag", "full", "tied"],
    ),
    AttributeName.N_ITER.value: IntAttribute(
        name=AttributeName.N_ITER.value, default_value=10, default_low=1, default_high=5000,
    ),
    AttributeName.TOL.value: FloatAttribute(
        name=AttributeName.TOL.value, default_value=1e-2, default_low=-1e10, default_high=1e10,
    ),
    # Choices: every non-empty selection of the letters s, t, m, c, w, in that order.
    AttributeName.INIT_PARAMS.value: StringAttribute(
        name=AttributeName.INIT_PARAMS.value,
        default_value="stmcw",
        value_choices=["s", "t", "m", "c", "w", "st", "sm", "sc", "sw", "tm", "tc", "tw", "mc", "mw", "cw",
                       "stm", "stc", "stw", "smc", "smw", "scw", "tmc", "tmw", "tcw", "mcw",
                       "stmc", "stmw", "stcw", "smcw", "tmcw", "stmcw"],
    ),
    AttributeName.PARAMS.value: StringAttribute(
        name=AttributeName.PARAMS.value,
        default_value="stmcw",
        value_choices=["s", "t", "m", "c", "w", "st", "sm", "sc", "sw", "tm", "tc", "tw", "mc", "mw", "cw",
                       "stm", "stc", "stw", "smc", "smw", "scw", "tmc", "tmw", "tcw", "mcw",
                       "stmc", "stmw", "stcw", "smcw", "tmcw", "stmcw"],
    ),
    AttributeName.IMPLEMENTATION.value: StringAttribute(
        name=AttributeName.IMPLEMENTATION.value, default_value="log", value_choices=["log", "scaling"],
    ),
}
+
+# STRAY
# Attributes accepted by the STRAY built-in model, keyed by name.
stray_attribute_map = {
    AttributeName.ALPHA.value: FloatAttribute(
        name=AttributeName.ALPHA.value, default_value=0.01, default_low=-1e10, default_high=1e10,
    ),
    AttributeName.K.value: IntAttribute(
        name=AttributeName.K.value, default_value=10, default_low=1, default_high=5000,
    ),
    AttributeName.KNN_ALGORITHM.value: StringAttribute(
        name=AttributeName.KNN_ALGORITHM.value,
        default_value="brute",
        value_choices=["brute", "kd_tree", "ball_tree", "auto"],
    ),
    AttributeName.P.value: FloatAttribute(
        name=AttributeName.P.value, default_value=0.5, default_low=-1e10, default_high=1e10,
    ),
    AttributeName.SIZE_THRESHOLD.value: IntAttribute(
        name=AttributeName.SIZE_THRESHOLD.value, default_value=50, default_low=1, default_high=5000,
    ),
    AttributeName.OUTLIER_TAIL.value: StringAttribute(
        name=AttributeName.OUTLIER_TAIL.value, default_value="max", value_choices=["min", "max"],
    ),
}
+
+
class BuiltInModel:
    """Common base of the built-in model wrappers.

    It stores the parsed attribute dict and a slot for the wrapped estimator,
    which subclasses fill in from their own constructor.
    """

    def __init__(self, attributes):
        """
        Args:
            attributes: dict of parsed attribute values, keyed by name
        """
        self._attributes = attributes
        # Subclasses replace this with the configured estimator.
        self._model = None

    @abstractmethod
    def inference(self, data):
        """Run the model on ``data``; subclasses must override this."""
        raise NotImplementedError
+
+
class ArimaModel(BuiltInModel):
    """Built-in forecaster that wraps sktime's ``ARIMA``."""

    def __init__(self, attributes):
        """Create the ARIMA estimator from the parsed attribute dict.

        Args:
            attributes: dict of parsed attribute values; every key listed
                below must be present.
        """
        super().__init__(attributes)
        # Each key is forwarded unchanged as the ARIMA keyword of the same name.
        forwarded_keys = (
            'order',
            'seasonal_order',
            'method',
            'suppress_warnings',
            'maxiter',
            'out_of_sample_size',
            'scoring',
            'with_intercept',
            'time_varying_regression',
            'enforce_stationarity',
            'enforce_invertibility',
            'simple_differencing',
            'measurement_error',
            'mle_regression',
            'hamilton_representation',
            'concentrate_scale',
        )
        self._model = ARIMA(**{key: attributes[key] for key in forwarded_keys})

    def inference(self, data):
        """Fit on ``data`` and forecast ``predict_length`` points.

        Returns:
            The forecast as a float64 numpy array.
        Raises:
            InferenceModelInternalError: wrapping any error raised while
                fitting or predicting.
        """
        try:
            steps = self._attributes['predict_length']
            self._model.fit(data)
            # NOTE(review): the horizon starts at 0, not 1 - confirm that
            # including offset 0 is intended.
            forecast = self._model.predict(fh=range(steps))
            return np.array(forecast, dtype=np.float64)
        except Exception as e:
            raise InferenceModelInternalError(str(e))
+
+
class ExponentialSmoothingModel(BuiltInModel):
    """Built-in forecaster that wraps sktime's ``ExponentialSmoothing``."""

    def __init__(self, attributes):
        """Create the estimator from the parsed attribute dict.

        Args:
            attributes: dict of parsed attribute values; every key listed
                below must be present.
        """
        super().__init__(attributes)
        # Each key is forwarded unchanged as the estimator keyword of the same name.
        forwarded_keys = (
            'damped_trend',
            'initialization_method',
            'optimized',
            'remove_bias',
            'use_brute',
        )
        self._model = ExponentialSmoothing(**{key: attributes[key] for key in forwarded_keys})

    def inference(self, data):
        """Fit on ``data`` and forecast ``predict_length`` points.

        Returns:
            The forecast as a float64 numpy array.
        Raises:
            InferenceModelInternalError: wrapping any error raised while
                fitting or predicting.
        """
        try:
            steps = self._attributes['predict_length']
            self._model.fit(data)
            # NOTE(review): the horizon starts at 0, not 1 - confirm that
            # including offset 0 is intended.
            forecast = self._model.predict(fh=range(steps))
            return np.array(forecast, dtype=np.float64)
        except Exception as e:
            raise InferenceModelInternalError(str(e))
+
+
class NaiveForecasterModel(BuiltInModel):
    """Built-in forecaster that wraps sktime's ``NaiveForecaster``."""

    def __init__(self, attributes):
        """Create the estimator from the parsed attribute dict.

        Args:
            attributes: dict of parsed attribute values; must contain
                'strategy' and 'sp'.
        """
        super().__init__(attributes)
        # Each key is forwarded unchanged as the estimator keyword of the same name.
        forwarded_keys = ('strategy', 'sp')
        self._model = NaiveForecaster(**{key: attributes[key] for key in forwarded_keys})

    def inference(self, data):
        """Fit on ``data`` and forecast ``predict_length`` points.

        Returns:
            The forecast as a float64 numpy array.
        Raises:
            InferenceModelInternalError: wrapping any error raised while
                fitting or predicting.
        """
        try:
            steps = self._attributes['predict_length']
            self._model.fit(data)
            # NOTE(review): the horizon starts at 0, not 1 - confirm that
            # including offset 0 is intended.
            forecast = self._model.predict(fh=range(steps))
            return np.array(forecast, dtype=np.float64)
        except Exception as e:
            raise InferenceModelInternalError(str(e))
+
+
class STLForecasterModel(BuiltInModel):
    """Built-in forecaster that wraps sktime's ``STLForecaster``."""

    def __init__(self, attributes):
        """Create the estimator from the parsed attribute dict.

        Args:
            attributes: dict of parsed attribute values; every key listed
                below must be present.
        """
        super().__init__(attributes)
        # Each key is forwarded unchanged as the estimator keyword of the same name.
        forwarded_keys = (
            'sp',
            'seasonal',
            'seasonal_deg',
            'trend_deg',
            'low_pass_deg',
            'seasonal_jump',
            'trend_jump',
            'low_pass_jump',
        )
        self._model = STLForecaster(**{key: attributes[key] for key in forwarded_keys})

    def inference(self, data):
        """Fit on ``data`` and forecast ``predict_length`` points.

        Returns:
            The forecast as a float64 numpy array.
        Raises:
            InferenceModelInternalError: wrapping any error raised while
                fitting or predicting.
        """
        try:
            steps = self._attributes['predict_length']
            self._model.fit(data)
            # NOTE(review): the horizon starts at 0, not 1 - confirm that
            # including offset 0 is intended.
            forecast = self._model.predict(fh=range(steps))
            return np.array(forecast, dtype=np.float64)
        except Exception as e:
            raise InferenceModelInternalError(str(e))
+
+
class GMMHMMModel(BuiltInModel):
    """Built-in model that wraps sktime's ``GMMHMM`` annotator."""

    def __init__(self, attributes):
        """Create the estimator from the parsed attribute dict.

        Args:
            attributes: dict of parsed attribute values; every key listed
                below must be present.
        """
        super().__init__(attributes)
        # Each key is forwarded unchanged as the estimator keyword of the same name.
        forwarded_keys = (
            'n_components',
            'n_mix',
            'min_covar',
            'startprob_prior',
            'transmat_prior',
            'means_prior',
            'means_weight',
            'weights_prior',
            'algorithm',
            'covariance_type',
            'n_iter',
            'tol',
            'params',
            'init_params',
            'implementation',
        )
        self._model = GMMHMM(**{key: attributes[key] for key in forwarded_keys})

    def inference(self, data):
        """Fit on ``data`` and predict on the same ``data``.

        Returns:
            The prediction as an int32 numpy array.
        Raises:
            InferenceModelInternalError: wrapping any error raised while
                fitting or predicting.
        """
        try:
            self._model.fit(data)
            labels = self._model.predict(data)
            return np.array(labels, dtype=np.int32)
        except Exception as e:
            raise InferenceModelInternalError(str(e))
+
+
class GaussianHmmModel(BuiltInModel):
    """Built-in model that wraps sktime's ``GaussianHMM`` annotator."""

    def __init__(self, attributes):
        """Create the estimator from the parsed attribute dict.

        Args:
            attributes: dict of parsed attribute values; every key listed
                below must be present.
        """
        super().__init__(attributes)
        # Each key is forwarded unchanged as the estimator keyword of the same name.
        forwarded_keys = (
            'n_components',
            'covariance_type',
            'min_covar',
            'startprob_prior',
            'transmat_prior',
            'means_prior',
            'means_weight',
            'covars_prior',
            'covars_weight',
            'algorithm',
            'n_iter',
            'tol',
            'params',
            'init_params',
            'implementation',
        )
        self._model = GaussianHMM(**{key: attributes[key] for key in forwarded_keys})

    def inference(self, data):
        """Fit on ``data`` and predict on the same ``data``.

        Returns:
            The prediction as an int32 numpy array.
        Raises:
            InferenceModelInternalError: wrapping any error raised while
                fitting or predicting.
        """
        try:
            self._model.fit(data)
            labels = self._model.predict(data)
            return np.array(labels, dtype=np.int32)
        except Exception as e:
            raise InferenceModelInternalError(str(e))
+
+
class STRAYModel(BuiltInModel):
    """Built-in model that wraps sktime's ``STRAY`` annotator."""

    def __init__(self, attributes):
        """Create the estimator from the parsed attribute dict.

        Args:
            attributes: dict of parsed attribute values; every key listed
                below must be present.
        """
        super().__init__(attributes)
        # Each key is forwarded unchanged as the estimator keyword of the same name.
        forwarded_keys = (
            'alpha',
            'k',
            'knn_algorithm',
            'p',
            'size_threshold',
            'outlier_tail',
        )
        self._model = STRAY(**{key: attributes[key] for key in forwarded_keys})

    def inference(self, data):
        """Min-max scale ``data``, then run STRAY's ``fit_transform`` on it.

        Returns:
            The result of ``fit_transform`` cast to an int32 numpy array.
        Raises:
            InferenceModelInternalError: wrapping any error raised while
                scaling or transforming.
        """
        try:
            scaled = MinMaxScaler().fit_transform(data)
            flags = self._model.fit_transform(scaled)
            # The estimator output is cast to int32 before being returned.
            return np.array(flags, dtype=np.int32)
        except Exception as e:
            raise InferenceModelInternalError(str(e))
diff --git a/iotdb-core/ainode/ainode/core/model/model_factory.py b/iotdb-core/ainode/ainode/core/model/model_factory.py
new file mode 100644
index 00000000000..1700dd28eb6
--- /dev/null
+++ b/iotdb-core/ainode/ainode/core/model/model_factory.py
@@ -0,0 +1,235 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+import os
+import shutil
+from urllib.parse import urlparse, urljoin
+
+import yaml
+from requests import Session
+from requests.adapters import HTTPAdapter
+
+from ainode.core.constant import DEFAULT_RECONNECT_TIMES,
DEFAULT_RECONNECT_TIMEOUT, DEFAULT_CHUNK_SIZE, \
+ DEFAULT_CONFIG_FILE_NAME, DEFAULT_MODEL_FILE_NAME
+from ainode.core.exception import InvalidUriError, BadConfigValueError
+from ainode.core.log import Logger
+from ainode.core.util.serde import get_data_type_byte_from_str
+from ainode.thrift.ainode.ttypes import TConfigs
+
+HTTP_PREFIX = "http://"
+HTTPS_PREFIX = "https://"
+
+logger = Logger()
+
+
+def _parse_uri(uri):
+ """
+ Args:
+ uri (str): uri to parse
+ Returns:
+ is_network_path (bool): True if the url is a network path, False
otherwise
+ parsed_uri (str): parsed uri to get related file
+ """
+
+ parse_result = urlparse(uri)
+ is_network_path = parse_result.scheme in ('http', 'https')
+ if is_network_path:
+ return True, uri
+
+ # handle file:// in uri
+ if parse_result.scheme == 'file':
+ uri = uri[7:]
+
+ # handle ~ in uri
+ uri = os.path.expanduser(uri)
+ return False, uri
+
+
def _download_file(url: str, storage_path: str) -> None:
    """
    Stream the content at ``url`` into the file at ``storage_path``.

    Args:
        url: url of file to download
        storage_path: path to save the file
    Returns:
        None
    Raises:
        requests.HTTPError: if the server answers with an error status
            (raised by ``response.raise_for_status()``)
    """
    logger.debug(f"download file from {url} to {storage_path}")

    # Both the session and the streamed response hold pooled connections;
    # the context managers release them even when the status check or the
    # file write fails.
    with Session() as session:
        adapter = HTTPAdapter(max_retries=DEFAULT_RECONNECT_TIMES)
        session.mount(HTTP_PREFIX, adapter)
        session.mount(HTTPS_PREFIX, adapter)

        with session.get(url, timeout=DEFAULT_RECONNECT_TIMEOUT, stream=True) as response:
            response.raise_for_status()

            with open(storage_path, 'wb') as file:
                for chunk in response.iter_content(chunk_size=DEFAULT_CHUNK_SIZE):
                    # skip empty keep-alive chunks
                    if chunk:
                        file.write(chunk)

    logger.debug(f"download file from {url} to {storage_path} success")
+
+
def _register_model_from_network(uri: str, model_storage_path: str,
                                 config_storage_path: str) -> [TConfigs, str]:
    """
    Download config.yaml and model.pt from a network directory.

    The config file is fetched and validated first; the model file is only
    downloaded once the config has been parsed successfully.

    Args:
        uri: network dir path of model to register, where model.pt and config.yaml are required,
            e.g. https://huggingface.co/user/modelname/resolve/main/
        model_storage_path: path to save model.pt
        config_storage_path: path to save config.yaml
    Returns:
        configs: TConfigs
        attributes: str
    """
    # urljoin replaces the last path segment unless the base ends with "/"
    base_url = uri if uri.endswith("/") else uri + "/"
    remote_model = urljoin(base_url, DEFAULT_MODEL_FILE_NAME)
    remote_config = urljoin(base_url, DEFAULT_CONFIG_FILE_NAME)

    # fetch the config and validate it before touching the model file
    _download_file(remote_config, config_storage_path)
    with open(config_storage_path, 'r', encoding='utf-8') as config_file:
        parsed_yaml = yaml.safe_load(config_file)
    configs, attributes = _parse_inference_config(parsed_yaml)

    # config.yaml is valid, so the model file is worth downloading
    _download_file(remote_model, model_storage_path)
    return configs, attributes
+
+
def _register_model_from_local(uri: str, model_storage_path: str,
                               config_storage_path: str) -> [TConfigs, str]:
    """
    Copy config.yaml and model.pt from a local directory.

    The config file is copied and validated first; the model file is only
    copied once the config has been parsed successfully.

    Args:
        uri: local dir path of model to register, where model.pt and config.yaml are required,
            e.g. /Users/admin/Desktop/model
        model_storage_path: path to save model.pt
        config_storage_path: path to save config.yaml
    Returns:
        configs: TConfigs
        attributes: str
    Raises:
        InvalidUriError: if either file is missing under ``uri``
    """
    source_model = os.path.join(uri, DEFAULT_MODEL_FILE_NAME)
    source_config = os.path.join(uri, DEFAULT_CONFIG_FILE_NAME)

    # both files are required; bail out before copying anything
    if not (os.path.exists(source_model) and os.path.exists(source_config)):
        raise InvalidUriError(uri)

    # copy config.yaml
    logger.debug(f"copy file from {source_config} to {config_storage_path}")
    shutil.copy(source_config, config_storage_path)
    logger.debug(f"copy file from {source_config} to {config_storage_path} success")

    # read and parse config dict from config.yaml
    with open(config_storage_path, 'r', encoding='utf-8') as config_file:
        parsed_yaml = yaml.safe_load(config_file)
    configs, attributes = _parse_inference_config(parsed_yaml)

    # config.yaml is valid, so copy the model file as well
    logger.debug(f"copy file from {source_model} to {model_storage_path}")
    shutil.copy(source_model, model_storage_path)
    logger.debug(f"copy file from {source_model} to {model_storage_path} success")

    return configs, attributes
+
+
def _parse_inference_config(config_dict):
    """
    Validate the content of config.yaml and convert it to thrift configs.

    Args:
        config_dict: dict
            - configs: dict
                - input_shape (list<i32>): input shape of the model and needs to be
                    two-dimensional array like [96, 2]
                - output_shape (list<i32>): output shape of the model and needs to be
                    two-dimensional array like [96, 2]
                - input_type (list<str>): input type of the model and each element needs to be in
                    ['bool', 'int32', 'int64', 'float32', 'float64', 'text'], default float32
                - output_type (list<str>): output type of the model and each element needs to be in
                    ['bool', 'int32', 'int64', 'float32', 'float64', 'text'], default float32
            - attributes: dict
    Returns:
        configs: TConfigs
        attributes: str
    Raises:
        BadConfigValueError: if a shape is not a list of two positive integers, or a
            type list is not a list whose length equals the second shape element
        KeyError: if 'configs', 'input_shape' or 'output_shape' is missing
    """
    configs = config_dict['configs']
    shape_keys = ('input_shape', 'output_shape')
    type_to_shape = (('input_type', 'input_shape'), ('output_type', 'output_shape'))

    # check if input_shape and output_shape are two-dimensional array
    for shape_key in shape_keys:
        shape = configs[shape_key]
        if not (isinstance(shape, list) and len(shape) == 2):
            raise BadConfigValueError(shape_key, shape,
                                      f'{shape_key} should be a two-dimensional array.')

    # check if input_shape and output_shape are positive integer
    for shape_key in shape_keys:
        shape = configs[shape_key]
        # bool is a subclass of int, so YAML `true` would otherwise pass as 1
        is_positive_int = all(
            isinstance(dim, int) and not isinstance(dim, bool) and dim > 0
            for dim in shape
        )
        if not is_positive_int:
            raise BadConfigValueError(shape_key, shape,
                                      f'element in {shape_key} should be positive integer.')

    # check if input_type and output_type are one-dimensional array with right length
    for type_key, shape_key in type_to_shape:
        if type_key not in configs:
            continue
        types = configs[type_key]
        if not (isinstance(types, list) and len(types) == configs[shape_key][1]):
            raise BadConfigValueError(
                type_key, types,
                f'{type_key} should be a one-dimensional array and length of it '
                f'should be equal to {shape_key}[1].')

    # parse input_type and output_type to byte, one entry per column;
    # columns default to float32 when the type list is omitted
    parsed_types = []
    for type_key, shape_key in type_to_shape:
        if type_key in configs:
            parsed_types.append([get_data_type_byte_from_str(x) for x in configs[type_key]])
        else:
            parsed_types.append([get_data_type_byte_from_str('float32')] * configs[shape_key][1])
    input_type, output_type = parsed_types

    # parse attributes
    attributes = ""
    if 'attributes' in config_dict:
        attributes = str(config_dict['attributes'])

    return TConfigs(configs['input_shape'], configs['output_shape'], input_type, output_type), attributes
+
+
def fetch_model_by_uri(uri: str, model_storage_path: str, config_storage_path: str):
    """
    Fetch model.pt and config.yaml from a network or local directory.

    Args:
        uri: network dir path or local dir path of the model
        model_storage_path: path to save model.pt
        config_storage_path: path to save config.yaml
    Returns:
        whatever the selected register function returns (configs, attributes)
    """
    is_network_path, resolved_uri = _parse_uri(uri)

    # pick the handler once, then make a single call
    register = _register_model_from_network if is_network_path else _register_model_from_local
    return register(resolved_uri, model_storage_path, config_storage_path)
diff --git a/iotdb-core/ainode/ainode/core/model/model_storage.py
b/iotdb-core/ainode/ainode/core/model/model_storage.py
new file mode 100644
index 00000000000..a14e1ef4cc7
--- /dev/null
+++ b/iotdb-core/ainode/ainode/core/model/model_storage.py
@@ -0,0 +1,123 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+import os
+import shutil
+from collections.abc import Callable
+
+import torch
+import torch._dynamo
+from pylru import lrucache
+
+from ainode.core.config import AINodeDescriptor
+from ainode.core.constant import (DEFAULT_MODEL_FILE_NAME,
+ DEFAULT_CONFIG_FILE_NAME)
+from ainode.core.exception import ModelNotExistError
+from ainode.core.log import Logger
+from ainode.core.model.model_factory import fetch_model_by_uri
+from ainode.core.util.lock import ModelLockPool
+from ainode.core.manager.training_manager import get_args
+from ainode.core.model.TimerXL.models.timer_xl import Model
+
+logger = Logger()
+
+
class ModelStorage(object):
    """Model repository on disk with an in-memory LRU cache of loaded models.

    Every model lives in ``<model_dir>/<model_id>/``. A loaded model is cached
    under the key ``<model_dir>/<model_id>/<DEFAULT_MODEL_FILE_NAME>``,
    whichever file its weights were actually read from.
    """

    def __init__(self):
        self._model_dir = os.path.join(os.getcwd(), AINodeDescriptor().get_config().get_ain_models_dir())
        try:
            # exist_ok avoids the check-then-create race of exists() + makedirs()
            os.makedirs(self._model_dir, exist_ok=True)
        except PermissionError as e:
            logger.error(e)
            raise
        self._lock_pool = ModelLockPool()
        self._model_cache = lrucache(AINodeDescriptor().get_config().get_ain_model_storage_cache_size())

    def register_model(self, model_id: str, uri: str):
        """
        Args:
            model_id: id of model to register
            uri: network dir path or local dir path of model to register, where model.pt and
                config.yaml are required,
                e.g. https://huggingface.co/user/modelname/resolve/main/ or /Users/admin/Desktop/model
        Returns:
            configs: TConfigs
            attributes: str
        """
        storage_path = os.path.join(self._model_dir, f'{model_id}')
        # create storage dir if not exist
        os.makedirs(storage_path, exist_ok=True)
        model_storage_path = os.path.join(storage_path, DEFAULT_MODEL_FILE_NAME)
        config_storage_path = os.path.join(storage_path, DEFAULT_CONFIG_FILE_NAME)
        return fetch_model_by_uri(uri, model_storage_path, config_storage_path)

    def load_model(self, model_id: str, acceleration: bool) -> Callable:
        """
        Return the model for ``model_id``, loading and caching it on first use.

        Args:
            model_id: id of model to load
            acceleration: if True, try to wrap the model with ``torch.compile``;
                a compile failure is logged and the uncompiled model is used
        Returns:
            model: for ids containing 'timer', a ``Model`` restored from
                checkpoint.pth; otherwise the module loaded by ``torch.jit.load``,
                compiled when acceleration succeeded
        Raises:
            ModelNotExistError: if the model file to load does not exist
        """
        ain_models_dir = os.path.join(self._model_dir, f'{model_id}')
        model_path = os.path.join(ain_models_dir, DEFAULT_MODEL_FILE_NAME)
        # NOTE(review): the cache is written while holding only the read lock;
        # confirm that concurrent loads of the same model are acceptable.
        with self._lock_pool.get_lock(model_id).read_lock():
            if model_path in self._model_cache:
                model = self._model_cache[model_path]
                already_compiled = isinstance(model, torch._dynamo.eval_frame.OptimizedModule)
                if acceleration and not already_compiled:
                    model = self._compile_or_fallback(model)
                    self._model_cache[model_path] = model
                return model

            # todo: use modelType instead
            if 'timer' in model_id:
                model_file_path = os.path.join(ain_models_dir, 'checkpoint.pth')
                if not os.path.exists(model_file_path):
                    raise ModelNotExistError(model_file_path)
                # NOTE(security): torch.load unpickles the file; only load
                # checkpoints from trusted sources.
                state_dict = torch.load(model_file_path)
                model = Model(get_args())
                model.load_state_dict(state_dict)
            elif not os.path.exists(model_path):
                raise ModelNotExistError(model_path)
            else:
                model = torch.jit.load(model_path)
                if acceleration:
                    model = self._compile_or_fallback(model)
            self._model_cache[model_path] = model
            return model

    @staticmethod
    def _compile_or_fallback(model):
        """Return ``torch.compile(model)``, or ``model`` itself if compiling raises."""
        try:
            return torch.compile(model)
        except Exception as e:
            logger.warning(f"acceleration failed, fallback to normal mode: {str(e)}")
            return model

    def delete_model(self, model_id: str) -> None:
        """
        Remove the model directory and evict the model from the cache.

        Args:
            model_id: id of model to delete
        Returns:
            None
        """
        storage_path = os.path.join(self._model_dir, f'{model_id}')
        with self._lock_pool.get_lock(model_id).write_lock():
            # load_model always caches under the DEFAULT_MODEL_FILE_NAME path,
            # even when that file is not in the directory (e.g. 'timer' models
            # read checkpoint.pth), so evict that key explicitly.
            self._remove_from_cache(os.path.join(storage_path, DEFAULT_MODEL_FILE_NAME))
            if os.path.exists(storage_path):
                for file_name in os.listdir(storage_path):
                    self._remove_from_cache(os.path.join(storage_path, file_name))
                shutil.rmtree(storage_path)

    def _remove_from_cache(self, file_path: str) -> None:
        """Drop ``file_path`` from the model cache if it is present."""
        if file_path in self._model_cache:
            del self._model_cache[file_path]