This is an automated email from the ASF dual-hosted git repository. ycycse pushed a commit to branch timer_xl_inference in repository https://gitbox.apache.org/repos/asf/iotdb.git
commit 4c3ceb3d932c3d5ba58d201ad701123a874a216a Author: YangCaiyin <[email protected]> AuthorDate: Fri May 9 19:42:04 2025 +0800 support timer_xl in inference --- iotdb-core/ainode/ainode/TimerXL/models/timer_xl.py | 20 +++++++++++++------- .../ainode/ainode/core/manager/inference_manager.py | 20 ++++++++++++++------ .../ainode/core/model/built_in_model_factory.py | 9 +++++---- .../iotdb/confignode/persistence/ModelInfo.java | 1 + .../function/TableBuiltinTableFunction.java | 9 +++++---- 5 files changed, 38 insertions(+), 21 deletions(-) diff --git a/iotdb-core/ainode/ainode/TimerXL/models/timer_xl.py b/iotdb-core/ainode/ainode/TimerXL/models/timer_xl.py index 1945fa25a2e..4e4d8588fd2 100644 --- a/iotdb-core/ainode/ainode/TimerXL/models/timer_xl.py +++ b/iotdb-core/ainode/ainode/TimerXL/models/timer_xl.py @@ -27,9 +27,12 @@ from ainode.TimerXL.models.configuration_timer import TimerxlConfig from ainode.core.util.masking import prepare_4d_causal_attention_mask from ainode.core.util.huggingface_cache import Cache, DynamicCache -import safetensors +from safetensors.torch import load_file as load_safetensors from huggingface_hub import hf_hub_download +from ainode.core.log import Logger +logger = Logger() + @dataclass class Output: outputs: torch.Tensor @@ -211,12 +214,15 @@ class Model(nn.Module): state_dict = torch.load(config.ckpt_path) elif config.ckpt_path.endswith('.safetensors'): if not os.path.exists(config.ckpt_path): - print(f"[INFO] Checkpoint not found at {config.ckpt_path}, downloading from HuggingFace...") + logger.info(f"Checkpoint not found at {config.ckpt_path}, downloading from HuggingFace...") repo_id = "thuml/timer-base-84m" - filename = os.path.basename(config.ckpt_path) # eg: model.safetensors - config.ckpt_path = hf_hub_download(repo_id=repo_id, filename=filename) - print(f"[INFO] Downloaded checkpoint to {config.ckpt_path}") - state_dict = safetensors.torch.load_file(config.ckpt_path) + try: + config.ckpt_path = hf_hub_download(repo_id=repo_id, filename=os.path.basename(config.ckpt_path), local_dir=os.path.dirname(config.ckpt_path)) + logger.info(f"Got checkpoint to {config.ckpt_path}") + except Exception as e: + logger.error(f"Failed to download checkpoint to {config.ckpt_path} due to {e}") + raise e + state_dict = load_safetensors(config.ckpt_path) else: raise ValueError('unsupported model weight type') # If there is no key beginning with 'model.model' in state_dict, add a 'model.' before all keys. (The model code here has an additional layer of encapsulation compared to the code on huggingface.) @@ -234,7 +240,7 @@ class Model(nn.Module): # change [L, C=1] to [batchsize=1, L] self.device = next(self.model.parameters()).device - x = torch.tensor(x.values, dtype=next(self.model.parameters()).dtype, device=self.device) + x = torch.tensor(x, dtype=next(self.model.parameters()).dtype, device=self.device) x = x.view(1, -1) preds = self.forward(x, max_new_tokens) diff --git a/iotdb-core/ainode/ainode/core/manager/inference_manager.py b/iotdb-core/ainode/ainode/core/manager/inference_manager.py index d16e521cde6..476b8d68b80 100644 --- a/iotdb-core/ainode/ainode/core/manager/inference_manager.py +++ b/iotdb-core/ainode/ainode/core/manager/inference_manager.py @@ -71,8 +71,13 @@ class InferenceManager: options = req.options options['predict_length'] = req.outputLength model = _get_built_in_model(model_id, model_manager, options) - inference_result = convert_to_binary(_inference_with_built_in_model( - model, data)) + if model_id == '_timerxl': + inference_result = _inference_with_timerxl( + model, data, options.get("predict_length", 96)) + else: + inference_result =_inference_with_built_in_model( + model, data) + inference_result = convert_to_binary(inference_result) else: # user-registered models model = _get_model(model_id, model_manager, req.options) @@ -199,8 +204,8 @@ def _inference_with_built_in_model(model, full_data): will concatenate all the output DataFrames into a list. """ - data, _, _, _ = full_data - output = model.inference(data) + _, data, _, _ = full_data + output = model.inference(data[0]) # output: DataFrame, shape: (H', C') output = pd.DataFrame(output) return output @@ -224,10 +229,13 @@ def _inference_with_timerxl(model, full_data, pred_len): will concatenate all the output DataFrames into a list. """ - data, _, _, _ = full_data + _, data, _, _ = full_data + data = data[0] + if data.dtype.byteorder not in ('=', '|'): + data = data.byteswap().newbyteorder() output = model.inference(data, pred_len) # output: DataFrame, shape: (H', C') - output = pd.DataFrame(output) + output = pd.DataFrame(output[0]) return output diff --git a/iotdb-core/ainode/ainode/core/model/built_in_model_factory.py b/iotdb-core/ainode/ainode/core/model/built_in_model_factory.py index dd8c9ee5308..c6d3edf4101 100644 --- a/iotdb-core/ainode/ainode/core/model/built_in_model_factory.py +++ b/iotdb-core/ainode/ainode/core/model/built_in_model_factory.py @@ -36,6 +36,7 @@ from ainode.core.log import Logger from ainode.TimerXL.models import timer_xl from ainode.TimerXL.models.configuration_timer import TimerxlConfig +from config import AINodeDescriptor logger = Logger() @@ -82,7 +83,7 @@ def fetch_built_in_model(model_id, inference_attributes): # validate the inference attributes for attribute_name in inference_attributes: if attribute_name not in attribute_map: - raise AttributeNotSupportError(model_id, attribute_name) + logger.warning(f"{attribute_name} is not supported in {model_id}") # parse the inference attributes, attributes is a Dict[str, Any] attributes = parse_attribute(inference_attributes, attribute_map) @@ -398,9 +399,9 @@ timerxl_attribute_map = { ), AttributeName.TIMERXL_CKPT_PATH.value: StringAttribute( name=AttributeName.TIMERXL_CKPT_PATH.value, - default_value=os.path.join(os.path.dirname(os.path.abspath(__file__)), 'weights', 'timerxl', 'model.safetensors'), - value_choices=[os.path.join(os.path.dirname(os.path.abspath(__file__)), 'weights', 'timerxl', 'model.safetensors'), ""], - ), + default_value=os.path.join(os.getcwd(), AINodeDescriptor().get_config().get_ain_models_dir(), 'weights', 'timerxl','model.safetensors'), + value_choices=[''] + ) } # built-in sktime model attributes diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/persistence/ModelInfo.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/persistence/ModelInfo.java index 21b317d9854..e2beede330a 100644 --- a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/persistence/ModelInfo.java +++ b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/persistence/ModelInfo.java @@ -75,6 +75,7 @@ public class ModelInfo implements SnapshotProcessor { private static final Set<String> builtInAnomalyDetectionModel = new HashSet<>(); static { + builtInForecastModel.add("_timerxl"); builtInForecastModel.add("_ARIMA"); builtInForecastModel.add("_NaiveForecaster"); builtInForecastModel.add("_STLForecaster"); diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/function/TableBuiltinTableFunction.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/function/TableBuiltinTableFunction.java index fda10eba8db..4a07f9a0c7b 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/function/TableBuiltinTableFunction.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/function/TableBuiltinTableFunction.java @@ -25,6 +25,7 @@ import org.apache.iotdb.commons.udf.builtin.relational.tvf.HOPTableFunction; import org.apache.iotdb.commons.udf.builtin.relational.tvf.SessionTableFunction; import org.apache.iotdb.commons.udf.builtin.relational.tvf.TumbleTableFunction; import org.apache.iotdb.commons.udf.builtin.relational.tvf.VariationTableFunction; +import org.apache.iotdb.db.queryengine.plan.relational.function.tvf.ForecastTableFunction; import org.apache.iotdb.udf.api.relational.TableFunction; import java.util.Arrays; @@ -38,8 +39,8 @@ public enum TableBuiltinTableFunction { CUMULATE("cumulate"), SESSION("session"), VARIATION("variation"), - CAPACITY("capacity"); - // FORECAST("forecast"); + CAPACITY("capacity"), + FORECAST("forecast"); private final String functionName; @@ -79,8 +80,8 @@ public enum TableBuiltinTableFunction { return new VariationTableFunction(); case "capacity": return new CapacityTableFunction(); - // case "forecast": - // return new ForecastTableFunction(); + case "forecast": + return new ForecastTableFunction(); default: throw new UnsupportedOperationException("Unsupported table function: " + functionName); }
