This is an automated email from the ASF dual-hosted git repository.
yongzao pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/iotdb.git
The following commit(s) were added to refs/heads/master by this push:
new a572c1c9da1 [AINode] Prevent auto_map gets covered and add model_list for AINodeConcurrentForecastIT (#16928)
a572c1c9da1 is described below
commit a572c1c9da1b2185b790aec2dc4c61e07a2f89da
Author: Gewu <[email protected]>
AuthorDate: Fri Dec 19 13:25:29 2025 +0800
[AINode] Prevent auto_map gets covered and add model_list for AINodeConcurrentForecastIT (#16928)
---
.../org/apache/iotdb/ainode/it/AINodeConcurrentForecastIT.java | 10 ++++++++--
iotdb-core/ainode/iotdb/ainode/core/model/model_info.py | 2 ++
iotdb-core/ainode/iotdb/ainode/core/model/model_storage.py | 9 ++++++---
3 files changed, 16 insertions(+), 5 deletions(-)
diff --git a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeConcurrentForecastIT.java b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeConcurrentForecastIT.java
index fe19f991e57..7b465d10051 100644
--- a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeConcurrentForecastIT.java
+++ b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeConcurrentForecastIT.java
@@ -36,8 +36,9 @@ import org.slf4j.LoggerFactory;
import java.sql.Connection;
import java.sql.SQLException;
import java.sql.Statement;
+import java.util.Arrays;
+import java.util.List;
-import static org.apache.iotdb.ainode.utils.AINodeTestUtils.BUILTIN_LTSM_MAP;
import static org.apache.iotdb.ainode.utils.AINodeTestUtils.checkModelNotOnSpecifiedDevice;
import static org.apache.iotdb.ainode.utils.AINodeTestUtils.checkModelOnSpecifiedDevice;
import static org.apache.iotdb.ainode.utils.AINodeTestUtils.concurrentInference;
@@ -48,6 +49,11 @@ public class AINodeConcurrentForecastIT {
private static final Logger LOGGER =
LoggerFactory.getLogger(AINodeConcurrentForecastIT.class);
+ private static final List<AINodeTestUtils.FakeModelInfo> MODEL_LIST =
+ Arrays.asList(
+ new AINodeTestUtils.FakeModelInfo("sundial", "sundial", "builtin", "active"),
+ new AINodeTestUtils.FakeModelInfo("timer_xl", "timer", "builtin", "active"));
+
private static final String FORECAST_TABLE_FUNCTION_SQL_TEMPLATE =
"SELECT * FROM FORECAST(model_id=>'%s', targets=>(SELECT time,s FROM root.AI) ORDER BY time, output_length=>%d)";
@@ -78,7 +84,7 @@ public class AINodeConcurrentForecastIT {
@Test
public void concurrentGPUForecastTest() throws SQLException, InterruptedException {
- for (AINodeTestUtils.FakeModelInfo modelInfo : BUILTIN_LTSM_MAP.values()) {
+ for (AINodeTestUtils.FakeModelInfo modelInfo : MODEL_LIST) {
concurrentGPUForecastTest(modelInfo);
}
}
diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/model_info.py b/iotdb-core/ainode/iotdb/ainode/core/model/model_info.py
index 697b3671275..5d86e7c588f 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/model/model_info.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/model/model_info.py
@@ -116,6 +116,7 @@ BUILTIN_HF_TRANSFORMERS_MODEL_MAP = {
"AutoConfig": "configuration_timer.TimerConfig",
"AutoModelForCausalLM": "modeling_timer.TimerForPrediction",
},
+ _transformers_registered=True,
),
"sundial": ModelInfo(
model_id="sundial",
@@ -128,5 +129,6 @@ BUILTIN_HF_TRANSFORMERS_MODEL_MAP = {
"AutoConfig": "configuration_sundial.SundialConfig",
"AutoModelForCausalLM": "modeling_sundial.SundialForPrediction",
},
+ _transformers_registered=True,
),
}
diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/model_storage.py b/iotdb-core/ainode/iotdb/ainode/core/model/model_storage.py
index a79371d2e79..ee09cfd75bb 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/model/model_storage.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/model/model_storage.py
@@ -196,9 +196,12 @@ class ModelStorage:
if os.path.exists(config_path):
with open(config_path, "r", encoding="utf-8") as f:
config = json.load(f)
- if model_info.model_type == "":
- model_info.model_type = config.get("model_type", "")
- model_info.auto_map = config.get("auto_map", None)
+ model_info.model_type = config.get(
+ "model_type", model_info.model_type
+ )
+ model_info.auto_map = config.get(
+ "auto_map", model_info.auto_map
+ )
logger.info(
f"Model {model_id} downloaded successfully and is ready to use."
)