This is an automated email from the ASF dual-hosted git repository. yongzao pushed a commit to branch create-model-CI in repository https://gitbox.apache.org/repos/asf/iotdb.git
commit 9868b417676e4344c4cf0b9944a41e835379c6ee Author: Yongzao <[email protected]> AuthorDate: Sun Dec 21 10:07:21 2025 +0800 finish --- .../iotdb/it/env/cluster/node/AINodeWrapper.java | 2 +- .../iotdb/ainode/it/AINodeCallInferenceIT.java | 24 ++--------- .../apache/iotdb/ainode/it/AINodeForecastIT.java | 16 ++------ .../iotdb/ainode/it/AINodeModelManageIT.java | 48 ++++++++++++++-------- .../apache/iotdb/ainode/utils/AINodeTestUtils.java | 44 ++++++++++++++++++++ .../ainode/iotdb/ainode/core/model/model_info.py | 12 +++--- .../iotdb/ainode/core/model/model_storage.py | 17 ++++---- .../thrift-ainode/src/main/thrift/ainode.thrift | 4 +- 8 files changed, 99 insertions(+), 68 deletions(-) diff --git a/integration-test/src/main/java/org/apache/iotdb/it/env/cluster/node/AINodeWrapper.java b/integration-test/src/main/java/org/apache/iotdb/it/env/cluster/node/AINodeWrapper.java index 34fd7e85240..15c2e4761dd 100644 --- a/integration-test/src/main/java/org/apache/iotdb/it/env/cluster/node/AINodeWrapper.java +++ b/integration-test/src/main/java/org/apache/iotdb/it/env/cluster/node/AINodeWrapper.java @@ -60,7 +60,7 @@ public class AINodeWrapper extends AbstractNodeWrapper { public static final String CONFIG_PATH = "conf"; public static final String SCRIPT_PATH = "sbin"; public static final String BUILT_IN_MODEL_PATH = "data/ainode/models/builtin"; - public static final String CACHE_BUILT_IN_MODEL_PATH = "/data/ainode/models/weights"; + public static final String CACHE_BUILT_IN_MODEL_PATH = "/data/ainode/models"; private void replaceAttribute(String[] keys, String[] values, String filePath) { Properties props = new Properties(); diff --git a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeCallInferenceIT.java b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeCallInferenceIT.java index 44e280eca16..3131e398059 100644 --- a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeCallInferenceIT.java +++ b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeCallInferenceIT.java @@ -40,21 +40,12 @@ import java.sql.Statement; import static org.apache.iotdb.ainode.utils.AINodeTestUtils.BUILTIN_MODEL_MAP; import static org.apache.iotdb.ainode.utils.AINodeTestUtils.checkHeader; -import static org.apache.iotdb.db.it.utils.TestUtils.prepareData; +import static org.apache.iotdb.ainode.utils.AINodeTestUtils.prepareDataInTree; @RunWith(IoTDBTestRunner.class) @Category({AIClusterIT.class}) public class AINodeCallInferenceIT { - private static final String[] WRITE_SQL_IN_TREE = - new String[] { - "CREATE DATABASE root.AI", - "CREATE TIMESERIES root.AI.s0 WITH DATATYPE=FLOAT, ENCODING=RLE", - "CREATE TIMESERIES root.AI.s1 WITH DATATYPE=DOUBLE, ENCODING=RLE", - "CREATE TIMESERIES root.AI.s2 WITH DATATYPE=INT32, ENCODING=RLE", - "CREATE TIMESERIES root.AI.s3 WITH DATATYPE=INT64, ENCODING=RLE", - }; - private static final String CALL_INFERENCE_SQL_TEMPLATE = "CALL INFERENCE(%s, \"SELECT s%d FROM root.AI LIMIT %d\", generateTime=true, outputLength=%d)"; private static final int DEFAULT_INPUT_LENGTH = 256; @@ -64,16 +55,7 @@ public class AINodeCallInferenceIT { public static void setUp() throws Exception { // Init 1C1D1A cluster environment EnvFactory.getEnv().initClusterEnvironment(1, 1); - prepareData(WRITE_SQL_IN_TREE); - try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TREE_SQL_DIALECT); - Statement statement = connection.createStatement()) { - for (int i = 0; i < 2880; i++) { - statement.execute( - String.format( - "INSERT INTO root.AI(timestamp,s0,s1,s2,s3) VALUES(%d,%f,%f,%d,%d)", - i, (float) i, (double) i, i, i)); - } - } + prepareDataInTree(); } @AfterClass @@ -91,7 +73,7 @@ public class AINodeCallInferenceIT { } } - public void callInferenceTest(Statement statement, AINodeTestUtils.FakeModelInfo modelInfo) + public static void callInferenceTest(Statement statement, AINodeTestUtils.FakeModelInfo modelInfo) throws SQLException { // Invoke call inference for specified models, there should exist result. for (int i = 0; i < 4; i++) { diff --git a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeForecastIT.java b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeForecastIT.java index bb0de13ed49..c2114ac9499 100644 --- a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeForecastIT.java +++ b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeForecastIT.java @@ -39,6 +39,7 @@ import java.sql.Statement; import static org.apache.iotdb.ainode.utils.AINodeTestUtils.BUILTIN_MODEL_MAP; import static org.apache.iotdb.ainode.utils.AINodeTestUtils.errorTest; +import static org.apache.iotdb.ainode.utils.AINodeTestUtils.prepareDataInTable; @RunWith(IoTDBTestRunner.class) @Category({AIClusterIT.class}) @@ -58,18 +59,7 @@ public class AINodeForecastIT { public static void setUp() throws Exception { // Init 1C1D1A cluster environment EnvFactory.getEnv().initClusterEnvironment(1, 1); - try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TABLE_SQL_DIALECT); - Statement statement = connection.createStatement()) { - statement.execute("CREATE DATABASE db"); - statement.execute( - "CREATE TABLE db.AI (s0 FLOAT FIELD, s1 DOUBLE FIELD, s2 INT32 FIELD, s3 INT64 FIELD)"); - for (int i = 0; i < 5760; i++) { - statement.execute( - String.format( - "INSERT INTO db.AI(time,s0,s1,s2,s3) VALUES(%d,%f,%f,%d,%d)", - i, (float) i, (double) i, i, i)); - } - } + prepareDataInTable(); } @AfterClass @@ -87,7 +77,7 @@ public class AINodeForecastIT { } } - public void forecastTableFunctionTest( + public static void forecastTableFunctionTest( Statement statement, AINodeTestUtils.FakeModelInfo modelInfo) throws SQLException { // Invoke forecast table function for specified models, there should exist result. for (int i = 0; i < 4; i++) { diff --git a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeModelManageIT.java b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeModelManageIT.java index 3315617e7fd..6b2cffd0636 100644 --- a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeModelManageIT.java +++ b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeModelManageIT.java @@ -39,8 +39,12 @@ import java.sql.SQLException; import java.sql.Statement; import java.util.concurrent.TimeUnit; +import static org.apache.iotdb.ainode.it.AINodeCallInferenceIT.callInferenceTest; +import static org.apache.iotdb.ainode.it.AINodeForecastIT.forecastTableFunctionTest; import static org.apache.iotdb.ainode.utils.AINodeTestUtils.checkHeader; import static org.apache.iotdb.ainode.utils.AINodeTestUtils.errorTest; +import static org.apache.iotdb.ainode.utils.AINodeTestUtils.prepareDataInTable; +import static org.apache.iotdb.ainode.utils.AINodeTestUtils.prepareDataInTree; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertTrue; @@ -54,6 +58,8 @@ public class AINodeModelManageIT { public static void setUp() throws Exception { // Init 1C1D1A cluster environment EnvFactory.getEnv().initClusterEnvironment(1, 1); + prepareDataInTree(); + prepareDataInTable(); } @AfterClass @@ -61,47 +67,51 @@ public class AINodeModelManageIT { EnvFactory.getEnv().cleanClusterEnvironment(); } - // @Test + @Test public void userDefinedModelManagementTestInTree() throws SQLException, InterruptedException { try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TREE_SQL_DIALECT); Statement statement = connection.createStatement()) { - userDefinedModelManagementTest(statement); + registerUserDefinedModel(statement); + callInferenceTest( + statement, new FakeModelInfo("user_chronos", "custom_t5", "user_defined", "active")); + dropUserDefinedModel(statement); } } - // @Test + @Test public void userDefinedModelManagementTestInTable() throws SQLException, InterruptedException { try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TABLE_SQL_DIALECT); Statement statement = connection.createStatement()) { - userDefinedModelManagementTest(statement); + registerUserDefinedModel(statement); + forecastTableFunctionTest( + statement, new FakeModelInfo("user_chronos", "custom_t5", "user_defined", "active")); + dropUserDefinedModel(statement); } } - private void userDefinedModelManagementTest(Statement statement) + private void registerUserDefinedModel(Statement statement) throws SQLException, InterruptedException { final String alterConfigSQL = "set configuration \"trusted_uri_pattern\"='.*'"; - final String registerSql = "create model operationTest using uri \"" + "\""; - final String showSql = "SHOW MODELS operationTest"; - final String dropSql = "DROP MODEL operationTest"; - + final String registerSql = "create model user_chronos using uri \"file:///data/chronos2\""; + final String showSql = "SHOW MODELS user_chronos"; statement.execute(alterConfigSQL); statement.execute(registerSql); boolean loading = true; - int count = 0; for (int retryCnt = 0; retryCnt < 100; retryCnt++) { try (ResultSet resultSet = statement.executeQuery(showSql)) { ResultSetMetaData resultSetMetaData = resultSet.getMetaData(); checkHeader(resultSetMetaData, "ModelId,ModelType,Category,State"); while (resultSet.next()) { String modelId = resultSet.getString(1); + String modelType = resultSet.getString(2); String category = resultSet.getString(3); String state = resultSet.getString(4); - assertEquals("operationTest", modelId); - assertEquals("USER-DEFINED", category); - if (state.equals("ACTIVE")) { + assertEquals("user_chronos", modelId); + assertEquals("user_defined", category); + assertEquals("custom_t5", modelType); + if (state.equals("active")) { loading = false; - count++; - } else if (state.equals("LOADING")) { + } else if (state.equals("loading")) { break; } else { fail("Unexpected status of model: " + state); @@ -114,12 +124,16 @@ public class AINodeModelManageIT { TimeUnit.SECONDS.sleep(1); } assertFalse(loading); - assertEquals(1, count); + } + + private void dropUserDefinedModel(Statement statement) throws SQLException { + final String showSql = "SHOW MODELS user_chronos"; + final String dropSql = "DROP MODEL user_chronos"; statement.execute(dropSql); try (ResultSet resultSet = statement.executeQuery(showSql)) { ResultSetMetaData resultSetMetaData = resultSet.getMetaData(); checkHeader(resultSetMetaData, "ModelId,ModelType,Category,State"); - count = 0; + int count = 0; while (resultSet.next()) { count++; } diff --git a/integration-test/src/test/java/org/apache/iotdb/ainode/utils/AINodeTestUtils.java b/integration-test/src/test/java/org/apache/iotdb/ainode/utils/AINodeTestUtils.java index 35fb51598b7..d620efacc26 100644 --- a/integration-test/src/test/java/org/apache/iotdb/ainode/utils/AINodeTestUtils.java +++ b/integration-test/src/test/java/org/apache/iotdb/ainode/utils/AINodeTestUtils.java @@ -19,11 +19,15 @@ package org.apache.iotdb.ainode.utils; +import org.apache.iotdb.it.env.EnvFactory; +import org.apache.iotdb.itbase.env.BaseEnv; + import com.google.common.collect.ImmutableSet; import org.junit.Assert; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import java.sql.Connection; import java.sql.ResultSet; import java.sql.ResultSetMetaData; import java.sql.SQLException; @@ -39,6 +43,7 @@ import java.util.concurrent.atomic.AtomicBoolean; import java.util.stream.Collectors; import java.util.stream.Stream; +import static org.apache.iotdb.db.it.utils.TestUtils.prepareData; import static org.junit.Assert.assertEquals; import static org.junit.Assert.fail; @@ -206,6 +211,45 @@ public class AINodeTestUtils { fail("Model " + modelId + " is still loaded on device " + device); } + private static final String[] WRITE_SQL_IN_TREE = + new String[] { + "CREATE DATABASE root.AI", + "CREATE TIMESERIES root.AI.s0 WITH DATATYPE=FLOAT, ENCODING=RLE", + "CREATE TIMESERIES root.AI.s1 WITH DATATYPE=DOUBLE, ENCODING=RLE", + "CREATE TIMESERIES root.AI.s2 WITH DATATYPE=INT32, ENCODING=RLE", + "CREATE TIMESERIES root.AI.s3 WITH DATATYPE=INT64, ENCODING=RLE", + }; + + /** Prepare root.AI(s0 FLOAT, s1 DOUBLE, s2 INT32, s3 INT64) with 5760 rows of data in tree. */ + public static void prepareDataInTree() throws SQLException { + prepareData(WRITE_SQL_IN_TREE); + try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TREE_SQL_DIALECT); + Statement statement = connection.createStatement()) { + for (int i = 0; i < 5760; i++) { + statement.execute( + String.format( + "INSERT INTO root.AI(timestamp,s0,s1,s2,s3) VALUES(%d,%f,%f,%d,%d)", + i, (float) i, (double) i, i, i)); + } + } + } + + /** Prepare db.AI(s0 FLOAT, s1 DOUBLE, s2 INT32, s3 INT64) with 5760 rows of data in table. */ + public static void prepareDataInTable() throws SQLException { + try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TABLE_SQL_DIALECT); + Statement statement = connection.createStatement()) { + statement.execute("CREATE DATABASE db"); + statement.execute( + "CREATE TABLE db.AI (s0 FLOAT FIELD, s1 DOUBLE FIELD, s2 INT32 FIELD, s3 INT64 FIELD)"); + for (int i = 0; i < 5760; i++) { + statement.execute( + String.format( + "INSERT INTO db.AI(time,s0,s1,s2,s3) VALUES(%d,%f,%f,%d,%d)", + i, (float) i, (double) i, i, i)); + } + } + } + public static class FakeModelInfo { private final String modelId; diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/model_info.py b/iotdb-core/ainode/iotdb/ainode/core/model/model_info.py index bcb4a5e2056..d0da371bfd5 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/model/model_info.py +++ b/iotdb-core/ainode/iotdb/ainode/core/model/model_info.py @@ -31,7 +31,7 @@ class ModelInfo: pipeline_cls: str = "", repo_id: str = "", auto_map: Optional[Dict] = None, - _transformers_registered: bool = False, + transformers_registered: bool = False, ): self.model_id = model_id self.model_type = model_type @@ -40,7 +40,9 @@ class ModelInfo: self.pipeline_cls = pipeline_cls self.repo_id = repo_id self.auto_map = auto_map # If exists, indicates it's a Transformers model - self._transformers_registered = _transformers_registered # Internal flag: whether registered to Transformers + self.transformers_registered = ( + transformers_registered # Internal flag: whether registered to Transformers + ) def __repr__(self): return ( @@ -116,7 +118,7 @@ BUILTIN_HF_TRANSFORMERS_MODEL_MAP = { "AutoConfig": "configuration_timer.TimerConfig", "AutoModelForCausalLM": "modeling_timer.TimerForPrediction", }, - _transformers_registered=True, + transformers_registered=True, ), "sundial": ModelInfo( model_id="sundial", @@ -129,7 +131,7 @@ BUILTIN_HF_TRANSFORMERS_MODEL_MAP = { "AutoConfig": "configuration_sundial.SundialConfig", "AutoModelForCausalLM": "modeling_sundial.SundialForPrediction", }, - _transformers_registered=True, + transformers_registered=True, ), "chronos2": ModelInfo( model_id="chronos2", @@ -139,7 +141,7 @@ BUILTIN_HF_TRANSFORMERS_MODEL_MAP = { pipeline_cls="pipeline_chronos2.Chronos2Pipeline", repo_id="amazon/chronos-2", auto_map={ - "AutoConfig": "config.Chronos2ForecastingConfig", + "AutoConfig": "config.Chronos2CoreConfig", "AutoModelForCausalLM": "model.Chronos2Model", }, ), diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/model_storage.py b/iotdb-core/ainode/iotdb/ainode/core/model/model_storage.py index ee09cfd75bb..910a0620fac 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/model/model_storage.py +++ b/iotdb-core/ainode/iotdb/ainode/core/model/model_storage.py @@ -236,7 +236,7 @@ class ModelStorage: state=ModelStates.ACTIVE, pipeline_cls=pipeline_cls, auto_map=auto_map, - _transformers_registered=False, # Lazy registration + transformers_registered=False, # Lazy registration ) self._models[ModelCategory.USER_DEFINED.value][model_id] = model_info @@ -287,7 +287,7 @@ class ModelStorage: state=ModelStates.ACTIVE, pipeline_cls=pipeline_cls, auto_map=auto_map, - _transformers_registered=False, # Register later + transformers_registered=False, # Register later ) self._models[ModelCategory.USER_DEFINED.value][model_id] = model_info @@ -296,7 +296,7 @@ class ModelStorage: success = self._register_transformers_model(model_info) if success: with self._lock_pool.get_lock(model_id).write_lock(): - model_info._transformers_registered = True + model_info.transformers_registered = True else: with self._lock_pool.get_lock(model_id).write_lock(): model_info.state = ModelStates.INACTIVE @@ -352,7 +352,7 @@ class ModelStorage: f"Registered other type model: {model_info.model_id} ({model_info.model_type})" ) - def ensure_transformers_registered(self, model_id: str) -> ModelInfo: + def ensure_transformers_registered(self, model_id: str) -> ModelInfo | None: """ Ensure Transformers model is registered (called for lazy registration) This method uses locks to ensure thread safety. All check logic is within lock protection. @@ -369,11 +369,10 @@ class ModelStorage: break if not model_info: - logger.warning(f"Model {model_id} does not exist, cannot register") return None # If already registered, return directly - if model_info._transformers_registered: + if model_info.transformers_registered: return model_info # If no auto_map, not a Transformers model, mark as registered (avoid duplicate checks) @@ -381,14 +380,14 @@ class ModelStorage: not model_info.auto_map or model_id in BUILTIN_HF_TRANSFORMERS_MODEL_MAP.keys() ): - model_info._transformers_registered = True + model_info.transformers_registered = True return model_info # Execute registration (under lock protection) try: success = self._register_transformers_model(model_info) if success: - model_info._transformers_registered = True + model_info.transformers_registered = True logger.info( f"Model {model_id} successfully registered to Transformers" ) @@ -401,7 +400,7 @@ class ModelStorage: except Exception as e: # Ensure state consistency in exception cases model_info.state = ModelStates.INACTIVE - model_info._transformers_registered = False + model_info.transformers_registered = False logger.error( f"Exception occurred while registering model {model_id} to Transformers: {e}" ) diff --git a/iotdb-protocol/thrift-ainode/src/main/thrift/ainode.thrift b/iotdb-protocol/thrift-ainode/src/main/thrift/ainode.thrift index cda356a948e..ea32f01b6e2 100644 --- a/iotdb-protocol/thrift-ainode/src/main/thrift/ainode.thrift +++ b/iotdb-protocol/thrift-ainode/src/main/thrift/ainode.thrift @@ -40,8 +40,8 @@ struct TAIHeartbeatResp { } struct TRegisterModelReq { - 1: required string uri - 2: required string modelId + 1: required string modelId + 2: required string uri } struct TConfigs {
