This is an automated email from the ASF dual-hosted git repository.

ycycse pushed a commit to branch timer_xl_inference
in repository https://gitbox.apache.org/repos/asf/iotdb.git

commit 4c3ceb3d932c3d5ba58d201ad701123a874a216a
Author: YangCaiyin <[email protected]>
AuthorDate: Fri May 9 19:42:04 2025 +0800

    support timer_xl in inference
---
 iotdb-core/ainode/ainode/TimerXL/models/timer_xl.py  | 20 +++++++++++++-------
 .../ainode/ainode/core/manager/inference_manager.py  | 20 ++++++++++++++------
 .../ainode/core/model/built_in_model_factory.py      |  9 +++++----
 .../iotdb/confignode/persistence/ModelInfo.java      |  1 +
 .../function/TableBuiltinTableFunction.java          |  9 +++++----
 5 files changed, 38 insertions(+), 21 deletions(-)

diff --git a/iotdb-core/ainode/ainode/TimerXL/models/timer_xl.py b/iotdb-core/ainode/ainode/TimerXL/models/timer_xl.py
index 1945fa25a2e..4e4d8588fd2 100644
--- a/iotdb-core/ainode/ainode/TimerXL/models/timer_xl.py
+++ b/iotdb-core/ainode/ainode/TimerXL/models/timer_xl.py
@@ -27,9 +27,12 @@ from ainode.TimerXL.models.configuration_timer import TimerxlConfig
 from ainode.core.util.masking import prepare_4d_causal_attention_mask
 from ainode.core.util.huggingface_cache import Cache, DynamicCache
 
-import safetensors
+from safetensors.torch import load_file as load_safetensors
 from huggingface_hub import hf_hub_download
 
+from ainode.core.log import Logger
+logger = Logger()
+
 @dataclass
 class Output:
     outputs: torch.Tensor
@@ -211,12 +214,15 @@ class Model(nn.Module):
                 state_dict = torch.load(config.ckpt_path)
             elif config.ckpt_path.endswith('.safetensors'):
                 if not os.path.exists(config.ckpt_path):
-                    print(f"[INFO] Checkpoint not found at {config.ckpt_path}, 
downloading from HuggingFace...")
+                    logger.info(f"Checkpoint not found at {config.ckpt_path}, 
downloading from HuggingFace...")
                     repo_id = "thuml/timer-base-84m"
-                    filename = os.path.basename(config.ckpt_path)  # eg: 
model.safetensors
-                    config.ckpt_path = hf_hub_download(repo_id=repo_id, 
filename=filename)
-                    print(f"[INFO] Downloaded checkpoint to 
{config.ckpt_path}")
-                state_dict = safetensors.torch.load_file(config.ckpt_path)
+                    try:
+                        config.ckpt_path = hf_hub_download(repo_id=repo_id, 
filename=os.path.basename(config.ckpt_path), 
local_dir=os.path.dirname(config.ckpt_path))
+                        logger.info(f"Got checkpoint to {config.ckpt_path}")
+                    except Exception as e:
+                        logger.error(f"Failed to download checkpoint to 
{config.ckpt_path} due to {e}")
+                        raise e
+                state_dict = load_safetensors(config.ckpt_path)
             else:
                 raise ValueError('unsupported model weight type')
             # If there is no key beginning with 'model.model' in state_dict, 
add a 'model.' before all keys. (The model code here has an additional layer of 
encapsulation compared to the code on huggingface.)
@@ -234,7 +240,7 @@ class Model(nn.Module):
         # change [L, C=1] to [batchsize=1, L]
         self.device = next(self.model.parameters()).device
         
-        x = torch.tensor(x.values, dtype=next(self.model.parameters()).dtype, 
device=self.device)
+        x = torch.tensor(x, dtype=next(self.model.parameters()).dtype, 
device=self.device)
         x = x.view(1, -1)
 
         preds = self.forward(x, max_new_tokens)
diff --git a/iotdb-core/ainode/ainode/core/manager/inference_manager.py b/iotdb-core/ainode/ainode/core/manager/inference_manager.py
index d16e521cde6..476b8d68b80 100644
--- a/iotdb-core/ainode/ainode/core/manager/inference_manager.py
+++ b/iotdb-core/ainode/ainode/core/manager/inference_manager.py
@@ -71,8 +71,13 @@ class InferenceManager:
                 options = req.options
                 options['predict_length'] = req.outputLength
                 model = _get_built_in_model(model_id, model_manager, options)
-                inference_result = 
convert_to_binary(_inference_with_built_in_model(
-                    model, data))
+                if model_id == '_timerxl':
+                    inference_result = _inference_with_timerxl(
+                        model, data, options.get("predict_length", 96))
+                else:
+                    inference_result =_inference_with_built_in_model(
+                        model, data)
+                inference_result = convert_to_binary(inference_result)
             else:
                 # user-registered models
                 model = _get_model(model_id, model_manager, req.options)
@@ -199,8 +204,8 @@ def _inference_with_built_in_model(model, full_data):
         will concatenate all the output DataFrames into a list.
     """
 
-    data, _, _, _ = full_data
-    output = model.inference(data)
+    _, data, _, _ = full_data
+    output = model.inference(data[0])
     # output: DataFrame, shape: (H', C')
     output = pd.DataFrame(output)
     return output
@@ -224,10 +229,13 @@ def _inference_with_timerxl(model, full_data, pred_len):
         will concatenate all the output DataFrames into a list.
     """
 
-    data, _, _, _ = full_data
+    _, data, _, _ = full_data
+    data = data[0]
+    if data.dtype.byteorder not in ('=', '|'):
+        data = data.byteswap().newbyteorder()
     output = model.inference(data, pred_len)
     # output: DataFrame, shape: (H', C')
-    output = pd.DataFrame(output)
+    output = pd.DataFrame(output[0])
     return output
 
 
diff --git a/iotdb-core/ainode/ainode/core/model/built_in_model_factory.py b/iotdb-core/ainode/ainode/core/model/built_in_model_factory.py
index dd8c9ee5308..c6d3edf4101 100644
--- a/iotdb-core/ainode/ainode/core/model/built_in_model_factory.py
+++ b/iotdb-core/ainode/ainode/core/model/built_in_model_factory.py
@@ -36,6 +36,7 @@ from ainode.core.log import Logger
 
 from ainode.TimerXL.models import timer_xl
 from ainode.TimerXL.models.configuration_timer import TimerxlConfig
+from config import AINodeDescriptor
 
 logger = Logger()
 
@@ -82,7 +83,7 @@ def fetch_built_in_model(model_id, inference_attributes):
     # validate the inference attributes
     for attribute_name in inference_attributes:
         if attribute_name not in attribute_map:
-            raise AttributeNotSupportError(model_id, attribute_name)
+            logger.warning(f"{attribute_name} is not supported in {model_id}")
 
     # parse the inference attributes, attributes is a Dict[str, Any]
     attributes = parse_attribute(inference_attributes, attribute_map)
@@ -398,9 +399,9 @@ timerxl_attribute_map = {
     ),
     AttributeName.TIMERXL_CKPT_PATH.value: StringAttribute(
         name=AttributeName.TIMERXL_CKPT_PATH.value,
-        default_value=os.path.join(os.path.dirname(os.path.abspath(__file__)), 
'weights', 'timerxl', 'model.safetensors'),
-        
value_choices=[os.path.join(os.path.dirname(os.path.abspath(__file__)), 
'weights', 'timerxl', 'model.safetensors'), ""],
-    ),
+        default_value=os.path.join(os.getcwd(), 
AINodeDescriptor().get_config().get_ain_models_dir(), 'weights', 
'timerxl','model.safetensors'),
+        value_choices=['']
+    )
 }
 
 # built-in sktime model attributes
diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/persistence/ModelInfo.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/persistence/ModelInfo.java
index 21b317d9854..e2beede330a 100644
--- a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/persistence/ModelInfo.java
+++ b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/persistence/ModelInfo.java
@@ -75,6 +75,7 @@ public class ModelInfo implements SnapshotProcessor {
   private static final Set<String> builtInAnomalyDetectionModel = new 
HashSet<>();
 
   static {
+    builtInForecastModel.add("_timerxl");
     builtInForecastModel.add("_ARIMA");
     builtInForecastModel.add("_NaiveForecaster");
     builtInForecastModel.add("_STLForecaster");
diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/function/TableBuiltinTableFunction.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/function/TableBuiltinTableFunction.java
index fda10eba8db..4a07f9a0c7b 100644
--- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/function/TableBuiltinTableFunction.java
+++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/function/TableBuiltinTableFunction.java
@@ -25,6 +25,7 @@ import org.apache.iotdb.commons.udf.builtin.relational.tvf.HOPTableFunction;
 import 
org.apache.iotdb.commons.udf.builtin.relational.tvf.SessionTableFunction;
 import org.apache.iotdb.commons.udf.builtin.relational.tvf.TumbleTableFunction;
 import 
org.apache.iotdb.commons.udf.builtin.relational.tvf.VariationTableFunction;
+import 
org.apache.iotdb.db.queryengine.plan.relational.function.tvf.ForecastTableFunction;
 import org.apache.iotdb.udf.api.relational.TableFunction;
 
 import java.util.Arrays;
@@ -38,8 +39,8 @@ public enum TableBuiltinTableFunction {
   CUMULATE("cumulate"),
   SESSION("session"),
   VARIATION("variation"),
-  CAPACITY("capacity");
-  //  FORECAST("forecast");
+  CAPACITY("capacity"),
+  FORECAST("forecast");
 
   private final String functionName;
 
@@ -79,8 +80,8 @@ public enum TableBuiltinTableFunction {
         return new VariationTableFunction();
       case "capacity":
         return new CapacityTableFunction();
-        //      case "forecast":
-        //        return new ForecastTableFunction();
+      case "forecast":
+        return new ForecastTableFunction();
       default:
         throw new UnsupportedOperationException("Unsupported table function: " 
+ functionName);
     }

Reply via email to