This is an automated email from the ASF dual-hosted git repository.

yongzao pushed a commit to branch create-model-CI
in repository https://gitbox.apache.org/repos/asf/iotdb.git

commit 9868b417676e4344c4cf0b9944a41e835379c6ee
Author: Yongzao <[email protected]>
AuthorDate: Sun Dec 21 10:07:21 2025 +0800

    finish
---
 .../iotdb/it/env/cluster/node/AINodeWrapper.java   |  2 +-
 .../iotdb/ainode/it/AINodeCallInferenceIT.java     | 24 ++---------
 .../apache/iotdb/ainode/it/AINodeForecastIT.java   | 16 ++------
 .../iotdb/ainode/it/AINodeModelManageIT.java       | 48 ++++++++++++++--------
 .../apache/iotdb/ainode/utils/AINodeTestUtils.java | 44 ++++++++++++++++++++
 .../ainode/iotdb/ainode/core/model/model_info.py   | 12 +++---
 .../iotdb/ainode/core/model/model_storage.py       | 17 ++++----
 .../thrift-ainode/src/main/thrift/ainode.thrift    |  4 +-
 8 files changed, 99 insertions(+), 68 deletions(-)

diff --git a/integration-test/src/main/java/org/apache/iotdb/it/env/cluster/node/AINodeWrapper.java b/integration-test/src/main/java/org/apache/iotdb/it/env/cluster/node/AINodeWrapper.java
index 34fd7e85240..15c2e4761dd 100644
--- a/integration-test/src/main/java/org/apache/iotdb/it/env/cluster/node/AINodeWrapper.java
+++ b/integration-test/src/main/java/org/apache/iotdb/it/env/cluster/node/AINodeWrapper.java
@@ -60,7 +60,7 @@ public class AINodeWrapper extends AbstractNodeWrapper {
   public static final String CONFIG_PATH = "conf";
   public static final String SCRIPT_PATH = "sbin";
   public static final String BUILT_IN_MODEL_PATH = 
"data/ainode/models/builtin";
-  public static final String CACHE_BUILT_IN_MODEL_PATH = 
"/data/ainode/models/weights";
+  public static final String CACHE_BUILT_IN_MODEL_PATH = "/data/ainode/models";
 
   private void replaceAttribute(String[] keys, String[] values, String 
filePath) {
     Properties props = new Properties();
diff --git a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeCallInferenceIT.java b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeCallInferenceIT.java
index 44e280eca16..3131e398059 100644
--- a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeCallInferenceIT.java
+++ b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeCallInferenceIT.java
@@ -40,21 +40,12 @@ import java.sql.Statement;
 
 import static org.apache.iotdb.ainode.utils.AINodeTestUtils.BUILTIN_MODEL_MAP;
 import static org.apache.iotdb.ainode.utils.AINodeTestUtils.checkHeader;
-import static org.apache.iotdb.db.it.utils.TestUtils.prepareData;
+import static org.apache.iotdb.ainode.utils.AINodeTestUtils.prepareDataInTree;
 
 @RunWith(IoTDBTestRunner.class)
 @Category({AIClusterIT.class})
 public class AINodeCallInferenceIT {
 
-  private static final String[] WRITE_SQL_IN_TREE =
-      new String[] {
-        "CREATE DATABASE root.AI",
-        "CREATE TIMESERIES root.AI.s0 WITH DATATYPE=FLOAT, ENCODING=RLE",
-        "CREATE TIMESERIES root.AI.s1 WITH DATATYPE=DOUBLE, ENCODING=RLE",
-        "CREATE TIMESERIES root.AI.s2 WITH DATATYPE=INT32, ENCODING=RLE",
-        "CREATE TIMESERIES root.AI.s3 WITH DATATYPE=INT64, ENCODING=RLE",
-      };
-
   private static final String CALL_INFERENCE_SQL_TEMPLATE =
       "CALL INFERENCE(%s, \"SELECT s%d FROM root.AI LIMIT %d\", 
generateTime=true, outputLength=%d)";
   private static final int DEFAULT_INPUT_LENGTH = 256;
@@ -64,16 +55,7 @@ public class AINodeCallInferenceIT {
   public static void setUp() throws Exception {
     // Init 1C1D1A cluster environment
     EnvFactory.getEnv().initClusterEnvironment(1, 1);
-    prepareData(WRITE_SQL_IN_TREE);
-    try (Connection connection = 
EnvFactory.getEnv().getConnection(BaseEnv.TREE_SQL_DIALECT);
-        Statement statement = connection.createStatement()) {
-      for (int i = 0; i < 2880; i++) {
-        statement.execute(
-            String.format(
-                "INSERT INTO root.AI(timestamp,s0,s1,s2,s3) 
VALUES(%d,%f,%f,%d,%d)",
-                i, (float) i, (double) i, i, i));
-      }
-    }
+    prepareDataInTree();
   }
 
   @AfterClass
@@ -91,7 +73,7 @@ public class AINodeCallInferenceIT {
     }
   }
 
-  public void callInferenceTest(Statement statement, 
AINodeTestUtils.FakeModelInfo modelInfo)
+  public static void callInferenceTest(Statement statement, 
AINodeTestUtils.FakeModelInfo modelInfo)
       throws SQLException {
     // Invoke call inference for specified models, there should exist result.
     for (int i = 0; i < 4; i++) {
diff --git a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeForecastIT.java b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeForecastIT.java
index bb0de13ed49..c2114ac9499 100644
--- a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeForecastIT.java
+++ b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeForecastIT.java
@@ -39,6 +39,7 @@ import java.sql.Statement;
 
 import static org.apache.iotdb.ainode.utils.AINodeTestUtils.BUILTIN_MODEL_MAP;
 import static org.apache.iotdb.ainode.utils.AINodeTestUtils.errorTest;
+import static org.apache.iotdb.ainode.utils.AINodeTestUtils.prepareDataInTable;
 
 @RunWith(IoTDBTestRunner.class)
 @Category({AIClusterIT.class})
@@ -58,18 +59,7 @@ public class AINodeForecastIT {
   public static void setUp() throws Exception {
     // Init 1C1D1A cluster environment
     EnvFactory.getEnv().initClusterEnvironment(1, 1);
-    try (Connection connection = 
EnvFactory.getEnv().getConnection(BaseEnv.TABLE_SQL_DIALECT);
-        Statement statement = connection.createStatement()) {
-      statement.execute("CREATE DATABASE db");
-      statement.execute(
-          "CREATE TABLE db.AI (s0 FLOAT FIELD, s1 DOUBLE FIELD, s2 INT32 
FIELD, s3 INT64 FIELD)");
-      for (int i = 0; i < 5760; i++) {
-        statement.execute(
-            String.format(
-                "INSERT INTO db.AI(time,s0,s1,s2,s3) VALUES(%d,%f,%f,%d,%d)",
-                i, (float) i, (double) i, i, i));
-      }
-    }
+    prepareDataInTable();
   }
 
   @AfterClass
@@ -87,7 +77,7 @@ public class AINodeForecastIT {
     }
   }
 
-  public void forecastTableFunctionTest(
+  public static void forecastTableFunctionTest(
       Statement statement, AINodeTestUtils.FakeModelInfo modelInfo) throws 
SQLException {
     // Invoke forecast table function for specified models, there should exist 
result.
     for (int i = 0; i < 4; i++) {
diff --git a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeModelManageIT.java b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeModelManageIT.java
index 3315617e7fd..6b2cffd0636 100644
--- a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeModelManageIT.java
+++ b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeModelManageIT.java
@@ -39,8 +39,12 @@ import java.sql.SQLException;
 import java.sql.Statement;
 import java.util.concurrent.TimeUnit;
 
+import static 
org.apache.iotdb.ainode.it.AINodeCallInferenceIT.callInferenceTest;
+import static 
org.apache.iotdb.ainode.it.AINodeForecastIT.forecastTableFunctionTest;
 import static org.apache.iotdb.ainode.utils.AINodeTestUtils.checkHeader;
 import static org.apache.iotdb.ainode.utils.AINodeTestUtils.errorTest;
+import static org.apache.iotdb.ainode.utils.AINodeTestUtils.prepareDataInTable;
+import static org.apache.iotdb.ainode.utils.AINodeTestUtils.prepareDataInTree;
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertFalse;
 import static org.junit.Assert.assertTrue;
@@ -54,6 +58,8 @@ public class AINodeModelManageIT {
   public static void setUp() throws Exception {
     // Init 1C1D1A cluster environment
     EnvFactory.getEnv().initClusterEnvironment(1, 1);
+    prepareDataInTree();
+    prepareDataInTable();
   }
 
   @AfterClass
@@ -61,47 +67,51 @@ public class AINodeModelManageIT {
     EnvFactory.getEnv().cleanClusterEnvironment();
   }
 
-  //  @Test
+  @Test
   public void userDefinedModelManagementTestInTree() throws SQLException, 
InterruptedException {
     try (Connection connection = 
EnvFactory.getEnv().getConnection(BaseEnv.TREE_SQL_DIALECT);
         Statement statement = connection.createStatement()) {
-      userDefinedModelManagementTest(statement);
+      registerUserDefinedModel(statement);
+      callInferenceTest(
+          statement, new FakeModelInfo("user_chronos", "custom_t5", 
"user_defined", "active"));
+      dropUserDefinedModel(statement);
     }
   }
 
-  //  @Test
+  @Test
   public void userDefinedModelManagementTestInTable() throws SQLException, 
InterruptedException {
     try (Connection connection = 
EnvFactory.getEnv().getConnection(BaseEnv.TABLE_SQL_DIALECT);
         Statement statement = connection.createStatement()) {
-      userDefinedModelManagementTest(statement);
+      registerUserDefinedModel(statement);
+      forecastTableFunctionTest(
+          statement, new FakeModelInfo("user_chronos", "custom_t5", 
"user_defined", "active"));
+      dropUserDefinedModel(statement);
     }
   }
 
-  private void userDefinedModelManagementTest(Statement statement)
+  private void registerUserDefinedModel(Statement statement)
       throws SQLException, InterruptedException {
     final String alterConfigSQL = "set configuration 
\"trusted_uri_pattern\"='.*'";
-    final String registerSql = "create model operationTest using uri \"" + 
"\"";
-    final String showSql = "SHOW MODELS operationTest";
-    final String dropSql = "DROP MODEL operationTest";
-
+    final String registerSql = "create model user_chronos using uri 
\"file:///data/chronos2\"";
+    final String showSql = "SHOW MODELS user_chronos";
     statement.execute(alterConfigSQL);
     statement.execute(registerSql);
     boolean loading = true;
-    int count = 0;
     for (int retryCnt = 0; retryCnt < 100; retryCnt++) {
       try (ResultSet resultSet = statement.executeQuery(showSql)) {
         ResultSetMetaData resultSetMetaData = resultSet.getMetaData();
         checkHeader(resultSetMetaData, "ModelId,ModelType,Category,State");
         while (resultSet.next()) {
           String modelId = resultSet.getString(1);
+          String modelType = resultSet.getString(2);
           String category = resultSet.getString(3);
           String state = resultSet.getString(4);
-          assertEquals("operationTest", modelId);
-          assertEquals("USER-DEFINED", category);
-          if (state.equals("ACTIVE")) {
+          assertEquals("user_chronos", modelId);
+          assertEquals("user_defined", category);
+          assertEquals("custom_t5", modelType);
+          if (state.equals("active")) {
             loading = false;
-            count++;
-          } else if (state.equals("LOADING")) {
+          } else if (state.equals("loading")) {
             break;
           } else {
             fail("Unexpected status of model: " + state);
@@ -114,12 +124,16 @@ public class AINodeModelManageIT {
       TimeUnit.SECONDS.sleep(1);
     }
     assertFalse(loading);
-    assertEquals(1, count);
+  }
+
+  private void dropUserDefinedModel(Statement statement) throws SQLException {
+    final String showSql = "SHOW MODELS user_chronos";
+    final String dropSql = "DROP MODEL user_chronos";
     statement.execute(dropSql);
     try (ResultSet resultSet = statement.executeQuery(showSql)) {
       ResultSetMetaData resultSetMetaData = resultSet.getMetaData();
       checkHeader(resultSetMetaData, "ModelId,ModelType,Category,State");
-      count = 0;
+      int count = 0;
       while (resultSet.next()) {
         count++;
       }
diff --git a/integration-test/src/test/java/org/apache/iotdb/ainode/utils/AINodeTestUtils.java b/integration-test/src/test/java/org/apache/iotdb/ainode/utils/AINodeTestUtils.java
index 35fb51598b7..d620efacc26 100644
--- a/integration-test/src/test/java/org/apache/iotdb/ainode/utils/AINodeTestUtils.java
+++ b/integration-test/src/test/java/org/apache/iotdb/ainode/utils/AINodeTestUtils.java
@@ -19,11 +19,15 @@
 
 package org.apache.iotdb.ainode.utils;
 
+import org.apache.iotdb.it.env.EnvFactory;
+import org.apache.iotdb.itbase.env.BaseEnv;
+
 import com.google.common.collect.ImmutableSet;
 import org.junit.Assert;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
+import java.sql.Connection;
 import java.sql.ResultSet;
 import java.sql.ResultSetMetaData;
 import java.sql.SQLException;
@@ -39,6 +43,7 @@ import java.util.concurrent.atomic.AtomicBoolean;
 import java.util.stream.Collectors;
 import java.util.stream.Stream;
 
+import static org.apache.iotdb.db.it.utils.TestUtils.prepareData;
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.fail;
 
@@ -206,6 +211,45 @@ public class AINodeTestUtils {
     fail("Model " + modelId + " is still loaded on device " + device);
   }
 
+  private static final String[] WRITE_SQL_IN_TREE =
+      new String[] {
+        "CREATE DATABASE root.AI",
+        "CREATE TIMESERIES root.AI.s0 WITH DATATYPE=FLOAT, ENCODING=RLE",
+        "CREATE TIMESERIES root.AI.s1 WITH DATATYPE=DOUBLE, ENCODING=RLE",
+        "CREATE TIMESERIES root.AI.s2 WITH DATATYPE=INT32, ENCODING=RLE",
+        "CREATE TIMESERIES root.AI.s3 WITH DATATYPE=INT64, ENCODING=RLE",
+      };
+
+  /** Prepare root.AI(s0 FLOAT, s1 DOUBLE, s2 INT32, s3 INT64) with 5760 rows 
of data in tree. */
+  public static void prepareDataInTree() throws SQLException {
+    prepareData(WRITE_SQL_IN_TREE);
+    try (Connection connection = 
EnvFactory.getEnv().getConnection(BaseEnv.TREE_SQL_DIALECT);
+        Statement statement = connection.createStatement()) {
+      for (int i = 0; i < 5760; i++) {
+        statement.execute(
+            String.format(
+                "INSERT INTO root.AI(timestamp,s0,s1,s2,s3) 
VALUES(%d,%f,%f,%d,%d)",
+                i, (float) i, (double) i, i, i));
+      }
+    }
+  }
+
+  /** Prepare db.AI(s0 FLOAT, s1 DOUBLE, s2 INT32, s3 INT64) with 5760 rows of 
data in table. */
+  public static void prepareDataInTable() throws SQLException {
+    try (Connection connection = 
EnvFactory.getEnv().getConnection(BaseEnv.TABLE_SQL_DIALECT);
+        Statement statement = connection.createStatement()) {
+      statement.execute("CREATE DATABASE db");
+      statement.execute(
+          "CREATE TABLE db.AI (s0 FLOAT FIELD, s1 DOUBLE FIELD, s2 INT32 
FIELD, s3 INT64 FIELD)");
+      for (int i = 0; i < 5760; i++) {
+        statement.execute(
+            String.format(
+                "INSERT INTO db.AI(time,s0,s1,s2,s3) VALUES(%d,%f,%f,%d,%d)",
+                i, (float) i, (double) i, i, i));
+      }
+    }
+  }
+
   public static class FakeModelInfo {
 
     private final String modelId;
diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/model_info.py b/iotdb-core/ainode/iotdb/ainode/core/model/model_info.py
index bcb4a5e2056..d0da371bfd5 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/model/model_info.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/model/model_info.py
@@ -31,7 +31,7 @@ class ModelInfo:
         pipeline_cls: str = "",
         repo_id: str = "",
         auto_map: Optional[Dict] = None,
-        _transformers_registered: bool = False,
+        transformers_registered: bool = False,
     ):
         self.model_id = model_id
         self.model_type = model_type
@@ -40,7 +40,9 @@ class ModelInfo:
         self.pipeline_cls = pipeline_cls
         self.repo_id = repo_id
         self.auto_map = auto_map  # If exists, indicates it's a Transformers 
model
-        self._transformers_registered = _transformers_registered  # Internal 
flag: whether registered to Transformers
+        self.transformers_registered = (
+            transformers_registered  # Internal flag: whether registered to 
Transformers
+        )
 
     def __repr__(self):
         return (
@@ -116,7 +118,7 @@ BUILTIN_HF_TRANSFORMERS_MODEL_MAP = {
             "AutoConfig": "configuration_timer.TimerConfig",
             "AutoModelForCausalLM": "modeling_timer.TimerForPrediction",
         },
-        _transformers_registered=True,
+        transformers_registered=True,
     ),
     "sundial": ModelInfo(
         model_id="sundial",
@@ -129,7 +131,7 @@ BUILTIN_HF_TRANSFORMERS_MODEL_MAP = {
             "AutoConfig": "configuration_sundial.SundialConfig",
             "AutoModelForCausalLM": "modeling_sundial.SundialForPrediction",
         },
-        _transformers_registered=True,
+        transformers_registered=True,
     ),
     "chronos2": ModelInfo(
         model_id="chronos2",
@@ -139,7 +141,7 @@ BUILTIN_HF_TRANSFORMERS_MODEL_MAP = {
         pipeline_cls="pipeline_chronos2.Chronos2Pipeline",
         repo_id="amazon/chronos-2",
         auto_map={
-            "AutoConfig": "config.Chronos2ForecastingConfig",
+            "AutoConfig": "config.Chronos2CoreConfig",
             "AutoModelForCausalLM": "model.Chronos2Model",
         },
     ),
diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/model_storage.py b/iotdb-core/ainode/iotdb/ainode/core/model/model_storage.py
index ee09cfd75bb..910a0620fac 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/model/model_storage.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/model/model_storage.py
@@ -236,7 +236,7 @@ class ModelStorage:
                 state=ModelStates.ACTIVE,
                 pipeline_cls=pipeline_cls,
                 auto_map=auto_map,
-                _transformers_registered=False,  # Lazy registration
+                transformers_registered=False,  # Lazy registration
             )
             self._models[ModelCategory.USER_DEFINED.value][model_id] = 
model_info
 
@@ -287,7 +287,7 @@ class ModelStorage:
                 state=ModelStates.ACTIVE,
                 pipeline_cls=pipeline_cls,
                 auto_map=auto_map,
-                _transformers_registered=False,  # Register later
+                transformers_registered=False,  # Register later
             )
             self._models[ModelCategory.USER_DEFINED.value][model_id] = 
model_info
 
@@ -296,7 +296,7 @@ class ModelStorage:
             success = self._register_transformers_model(model_info)
             if success:
                 with self._lock_pool.get_lock(model_id).write_lock():
-                    model_info._transformers_registered = True
+                    model_info.transformers_registered = True
             else:
                 with self._lock_pool.get_lock(model_id).write_lock():
                     model_info.state = ModelStates.INACTIVE
@@ -352,7 +352,7 @@ class ModelStorage:
             f"Registered other type model: {model_info.model_id} 
({model_info.model_type})"
         )
 
-    def ensure_transformers_registered(self, model_id: str) -> ModelInfo:
+    def ensure_transformers_registered(self, model_id: str) -> ModelInfo | 
None:
         """
         Ensure Transformers model is registered (called for lazy registration)
         This method uses locks to ensure thread safety. All check logic is 
within lock protection.
@@ -369,11 +369,10 @@ class ModelStorage:
                     break
 
             if not model_info:
-                logger.warning(f"Model {model_id} does not exist, cannot 
register")
                 return None
 
             # If already registered, return directly
-            if model_info._transformers_registered:
+            if model_info.transformers_registered:
                 return model_info
 
             # If no auto_map, not a Transformers model, mark as registered 
(avoid duplicate checks)
@@ -381,14 +380,14 @@ class ModelStorage:
                 not model_info.auto_map
                 or model_id in BUILTIN_HF_TRANSFORMERS_MODEL_MAP.keys()
             ):
-                model_info._transformers_registered = True
+                model_info.transformers_registered = True
                 return model_info
 
             # Execute registration (under lock protection)
             try:
                 success = self._register_transformers_model(model_info)
                 if success:
-                    model_info._transformers_registered = True
+                    model_info.transformers_registered = True
                     logger.info(
                         f"Model {model_id} successfully registered to 
Transformers"
                     )
@@ -401,7 +400,7 @@ class ModelStorage:
             except Exception as e:
                 # Ensure state consistency in exception cases
                 model_info.state = ModelStates.INACTIVE
-                model_info._transformers_registered = False
+                model_info.transformers_registered = False
                 logger.error(
                     f"Exception occurred while registering model {model_id} to 
Transformers: {e}"
                 )
diff --git a/iotdb-protocol/thrift-ainode/src/main/thrift/ainode.thrift b/iotdb-protocol/thrift-ainode/src/main/thrift/ainode.thrift
index cda356a948e..ea32f01b6e2 100644
--- a/iotdb-protocol/thrift-ainode/src/main/thrift/ainode.thrift
+++ b/iotdb-protocol/thrift-ainode/src/main/thrift/ainode.thrift
@@ -40,8 +40,8 @@ struct TAIHeartbeatResp {
 }
 
 struct TRegisterModelReq {
-  1: required string uri
-  2: required string modelId
+  1: required string modelId
+  2: required string uri
 }
 
 struct TConfigs {

Reply via email to