This is an automated email from the ASF dual-hosted git repository.
yongzao pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/iotdb.git
The following commit(s) were added to refs/heads/master by this push:
new 52002e86d2d [AINode] Append model management IT (#16938)
52002e86d2d is described below
commit 52002e86d2df556d90c88f15ab84325dc81b45ee
Author: Yongzao <[email protected]>
AuthorDate: Tue Dec 23 12:24:54 2025 +0800
[AINode] Append model management IT (#16938)
---
.../iotdb/it/env/cluster/node/AINodeWrapper.java | 2 +-
.../iotdb/ainode/it/AINodeCallInferenceIT.java | 24 +------
.../apache/iotdb/ainode/it/AINodeForecastIT.java | 16 +----
.../iotdb/ainode/it/AINodeModelManageIT.java | 58 +++++++++++-----
.../apache/iotdb/ainode/utils/AINodeTestUtils.java | 44 ++++++++++++
.../iotdb/ainode/core/manager/model_manager.py | 6 +-
.../ainode/iotdb/ainode/core/model/model_info.py | 12 ++--
.../iotdb/ainode/core/model/model_storage.py | 80 ++++++++++++----------
.../thrift-ainode/src/main/thrift/ainode.thrift | 4 +-
9 files changed, 150 insertions(+), 96 deletions(-)
diff --git
a/integration-test/src/main/java/org/apache/iotdb/it/env/cluster/node/AINodeWrapper.java
b/integration-test/src/main/java/org/apache/iotdb/it/env/cluster/node/AINodeWrapper.java
index 34fd7e85240..15c2e4761dd 100644
---
a/integration-test/src/main/java/org/apache/iotdb/it/env/cluster/node/AINodeWrapper.java
+++
b/integration-test/src/main/java/org/apache/iotdb/it/env/cluster/node/AINodeWrapper.java
@@ -60,7 +60,7 @@ public class AINodeWrapper extends AbstractNodeWrapper {
public static final String CONFIG_PATH = "conf";
public static final String SCRIPT_PATH = "sbin";
public static final String BUILT_IN_MODEL_PATH =
"data/ainode/models/builtin";
- public static final String CACHE_BUILT_IN_MODEL_PATH =
"/data/ainode/models/weights";
+ public static final String CACHE_BUILT_IN_MODEL_PATH = "/data/ainode/models";
private void replaceAttribute(String[] keys, String[] values, String
filePath) {
Properties props = new Properties();
diff --git
a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeCallInferenceIT.java
b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeCallInferenceIT.java
index 44e280eca16..3131e398059 100644
---
a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeCallInferenceIT.java
+++
b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeCallInferenceIT.java
@@ -40,21 +40,12 @@ import java.sql.Statement;
import static org.apache.iotdb.ainode.utils.AINodeTestUtils.BUILTIN_MODEL_MAP;
import static org.apache.iotdb.ainode.utils.AINodeTestUtils.checkHeader;
-import static org.apache.iotdb.db.it.utils.TestUtils.prepareData;
+import static org.apache.iotdb.ainode.utils.AINodeTestUtils.prepareDataInTree;
@RunWith(IoTDBTestRunner.class)
@Category({AIClusterIT.class})
public class AINodeCallInferenceIT {
- private static final String[] WRITE_SQL_IN_TREE =
- new String[] {
- "CREATE DATABASE root.AI",
- "CREATE TIMESERIES root.AI.s0 WITH DATATYPE=FLOAT, ENCODING=RLE",
- "CREATE TIMESERIES root.AI.s1 WITH DATATYPE=DOUBLE, ENCODING=RLE",
- "CREATE TIMESERIES root.AI.s2 WITH DATATYPE=INT32, ENCODING=RLE",
- "CREATE TIMESERIES root.AI.s3 WITH DATATYPE=INT64, ENCODING=RLE",
- };
-
private static final String CALL_INFERENCE_SQL_TEMPLATE =
"CALL INFERENCE(%s, \"SELECT s%d FROM root.AI LIMIT %d\",
generateTime=true, outputLength=%d)";
private static final int DEFAULT_INPUT_LENGTH = 256;
@@ -64,16 +55,7 @@ public class AINodeCallInferenceIT {
public static void setUp() throws Exception {
// Init 1C1D1A cluster environment
EnvFactory.getEnv().initClusterEnvironment(1, 1);
- prepareData(WRITE_SQL_IN_TREE);
- try (Connection connection =
EnvFactory.getEnv().getConnection(BaseEnv.TREE_SQL_DIALECT);
- Statement statement = connection.createStatement()) {
- for (int i = 0; i < 2880; i++) {
- statement.execute(
- String.format(
- "INSERT INTO root.AI(timestamp,s0,s1,s2,s3)
VALUES(%d,%f,%f,%d,%d)",
- i, (float) i, (double) i, i, i));
- }
- }
+ prepareDataInTree();
}
@AfterClass
@@ -91,7 +73,7 @@ public class AINodeCallInferenceIT {
}
}
- public void callInferenceTest(Statement statement,
AINodeTestUtils.FakeModelInfo modelInfo)
+ public static void callInferenceTest(Statement statement,
AINodeTestUtils.FakeModelInfo modelInfo)
throws SQLException {
// Invoke call inference for specified models, there should exist result.
for (int i = 0; i < 4; i++) {
diff --git
a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeForecastIT.java
b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeForecastIT.java
index bb0de13ed49..c2114ac9499 100644
---
a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeForecastIT.java
+++
b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeForecastIT.java
@@ -39,6 +39,7 @@ import java.sql.Statement;
import static org.apache.iotdb.ainode.utils.AINodeTestUtils.BUILTIN_MODEL_MAP;
import static org.apache.iotdb.ainode.utils.AINodeTestUtils.errorTest;
+import static org.apache.iotdb.ainode.utils.AINodeTestUtils.prepareDataInTable;
@RunWith(IoTDBTestRunner.class)
@Category({AIClusterIT.class})
@@ -58,18 +59,7 @@ public class AINodeForecastIT {
public static void setUp() throws Exception {
// Init 1C1D1A cluster environment
EnvFactory.getEnv().initClusterEnvironment(1, 1);
- try (Connection connection =
EnvFactory.getEnv().getConnection(BaseEnv.TABLE_SQL_DIALECT);
- Statement statement = connection.createStatement()) {
- statement.execute("CREATE DATABASE db");
- statement.execute(
- "CREATE TABLE db.AI (s0 FLOAT FIELD, s1 DOUBLE FIELD, s2 INT32
FIELD, s3 INT64 FIELD)");
- for (int i = 0; i < 5760; i++) {
- statement.execute(
- String.format(
- "INSERT INTO db.AI(time,s0,s1,s2,s3) VALUES(%d,%f,%f,%d,%d)",
- i, (float) i, (double) i, i, i));
- }
- }
+ prepareDataInTable();
}
@AfterClass
@@ -87,7 +77,7 @@ public class AINodeForecastIT {
}
}
- public void forecastTableFunctionTest(
+ public static void forecastTableFunctionTest(
Statement statement, AINodeTestUtils.FakeModelInfo modelInfo) throws
SQLException {
// Invoke forecast table function for specified models, there should exist
result.
for (int i = 0; i < 4; i++) {
diff --git
a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeModelManageIT.java
b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeModelManageIT.java
index 3315617e7fd..8ece0ba7523 100644
---
a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeModelManageIT.java
+++
b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeModelManageIT.java
@@ -39,8 +39,12 @@ import java.sql.SQLException;
import java.sql.Statement;
import java.util.concurrent.TimeUnit;
+import static
org.apache.iotdb.ainode.it.AINodeCallInferenceIT.callInferenceTest;
+import static
org.apache.iotdb.ainode.it.AINodeForecastIT.forecastTableFunctionTest;
import static org.apache.iotdb.ainode.utils.AINodeTestUtils.checkHeader;
import static org.apache.iotdb.ainode.utils.AINodeTestUtils.errorTest;
+import static org.apache.iotdb.ainode.utils.AINodeTestUtils.prepareDataInTable;
+import static org.apache.iotdb.ainode.utils.AINodeTestUtils.prepareDataInTree;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
@@ -54,6 +58,8 @@ public class AINodeModelManageIT {
public static void setUp() throws Exception {
// Init 1C1D1A cluster environment
EnvFactory.getEnv().initClusterEnvironment(1, 1);
+ prepareDataInTree();
+ prepareDataInTable();
}
@AfterClass
@@ -61,47 +67,61 @@ public class AINodeModelManageIT {
EnvFactory.getEnv().cleanClusterEnvironment();
}
- // @Test
+ @Test
public void userDefinedModelManagementTestInTree() throws SQLException,
InterruptedException {
try (Connection connection =
EnvFactory.getEnv().getConnection(BaseEnv.TREE_SQL_DIALECT);
Statement statement = connection.createStatement()) {
- userDefinedModelManagementTest(statement);
+ registerUserDefinedModel(statement);
+ callInferenceTest(
+ statement, new FakeModelInfo("user_chronos", "custom_t5",
"user_defined", "active"));
+ dropUserDefinedModel(statement);
+ errorTest(
+ statement,
+ "create model origin_chronos using uri
\"file:///data/chronos2_origin\"",
+ "1505: 't5' is already used by a Transformers config, pick another
name.");
+ statement.execute("drop model origin_chronos");
}
}
- // @Test
+ @Test
public void userDefinedModelManagementTestInTable() throws SQLException,
InterruptedException {
try (Connection connection =
EnvFactory.getEnv().getConnection(BaseEnv.TABLE_SQL_DIALECT);
Statement statement = connection.createStatement()) {
- userDefinedModelManagementTest(statement);
+ registerUserDefinedModel(statement);
+ forecastTableFunctionTest(
+ statement, new FakeModelInfo("user_chronos", "custom_t5",
"user_defined", "active"));
+ dropUserDefinedModel(statement);
+ errorTest(
+ statement,
+ "create model origin_chronos using uri
\"file:///data/chronos2_origin\"",
+ "1505: 't5' is already used by a Transformers config, pick another
name.");
+ statement.execute("drop model origin_chronos");
}
}
- private void userDefinedModelManagementTest(Statement statement)
+ private void registerUserDefinedModel(Statement statement)
throws SQLException, InterruptedException {
final String alterConfigSQL = "set configuration
\"trusted_uri_pattern\"='.*'";
- final String registerSql = "create model operationTest using uri \"" +
"\"";
- final String showSql = "SHOW MODELS operationTest";
- final String dropSql = "DROP MODEL operationTest";
-
+ final String registerSql = "create model user_chronos using uri
\"file:///data/chronos2\"";
+ final String showSql = "SHOW MODELS user_chronos";
statement.execute(alterConfigSQL);
statement.execute(registerSql);
boolean loading = true;
- int count = 0;
for (int retryCnt = 0; retryCnt < 100; retryCnt++) {
try (ResultSet resultSet = statement.executeQuery(showSql)) {
ResultSetMetaData resultSetMetaData = resultSet.getMetaData();
checkHeader(resultSetMetaData, "ModelId,ModelType,Category,State");
while (resultSet.next()) {
String modelId = resultSet.getString(1);
+ String modelType = resultSet.getString(2);
String category = resultSet.getString(3);
String state = resultSet.getString(4);
- assertEquals("operationTest", modelId);
- assertEquals("USER-DEFINED", category);
- if (state.equals("ACTIVE")) {
+ assertEquals("user_chronos", modelId);
+ assertEquals("custom_t5", modelType);
+ assertEquals("user_defined", category);
+ if (state.equals("active")) {
loading = false;
- count++;
- } else if (state.equals("LOADING")) {
+ } else if (state.equals("loading")) {
break;
} else {
fail("Unexpected status of model: " + state);
@@ -114,12 +134,16 @@ public class AINodeModelManageIT {
TimeUnit.SECONDS.sleep(1);
}
assertFalse(loading);
- assertEquals(1, count);
+ }
+
+ private void dropUserDefinedModel(Statement statement) throws SQLException {
+ final String showSql = "SHOW MODELS user_chronos";
+ final String dropSql = "DROP MODEL user_chronos";
statement.execute(dropSql);
try (ResultSet resultSet = statement.executeQuery(showSql)) {
ResultSetMetaData resultSetMetaData = resultSet.getMetaData();
checkHeader(resultSetMetaData, "ModelId,ModelType,Category,State");
- count = 0;
+ int count = 0;
while (resultSet.next()) {
count++;
}
diff --git
a/integration-test/src/test/java/org/apache/iotdb/ainode/utils/AINodeTestUtils.java
b/integration-test/src/test/java/org/apache/iotdb/ainode/utils/AINodeTestUtils.java
index 35fb51598b7..d620efacc26 100644
---
a/integration-test/src/test/java/org/apache/iotdb/ainode/utils/AINodeTestUtils.java
+++
b/integration-test/src/test/java/org/apache/iotdb/ainode/utils/AINodeTestUtils.java
@@ -19,11 +19,15 @@
package org.apache.iotdb.ainode.utils;
+import org.apache.iotdb.it.env.EnvFactory;
+import org.apache.iotdb.itbase.env.BaseEnv;
+
import com.google.common.collect.ImmutableSet;
import org.junit.Assert;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
+import java.sql.Connection;
import java.sql.ResultSet;
import java.sql.ResultSetMetaData;
import java.sql.SQLException;
@@ -39,6 +43,7 @@ import java.util.concurrent.atomic.AtomicBoolean;
import java.util.stream.Collectors;
import java.util.stream.Stream;
+import static org.apache.iotdb.db.it.utils.TestUtils.prepareData;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.fail;
@@ -206,6 +211,45 @@ public class AINodeTestUtils {
fail("Model " + modelId + " is still loaded on device " + device);
}
+ private static final String[] WRITE_SQL_IN_TREE =
+ new String[] {
+ "CREATE DATABASE root.AI",
+ "CREATE TIMESERIES root.AI.s0 WITH DATATYPE=FLOAT, ENCODING=RLE",
+ "CREATE TIMESERIES root.AI.s1 WITH DATATYPE=DOUBLE, ENCODING=RLE",
+ "CREATE TIMESERIES root.AI.s2 WITH DATATYPE=INT32, ENCODING=RLE",
+ "CREATE TIMESERIES root.AI.s3 WITH DATATYPE=INT64, ENCODING=RLE",
+ };
+
+ /** Prepare root.AI(s0 FLOAT, s1 DOUBLE, s2 INT32, s3 INT64) with 5760 rows
of data in tree. */
+ public static void prepareDataInTree() throws SQLException {
+ prepareData(WRITE_SQL_IN_TREE);
+ try (Connection connection =
EnvFactory.getEnv().getConnection(BaseEnv.TREE_SQL_DIALECT);
+ Statement statement = connection.createStatement()) {
+ for (int i = 0; i < 5760; i++) {
+ statement.execute(
+ String.format(
+ "INSERT INTO root.AI(timestamp,s0,s1,s2,s3)
VALUES(%d,%f,%f,%d,%d)",
+ i, (float) i, (double) i, i, i));
+ }
+ }
+ }
+
+ /** Prepare db.AI(s0 FLOAT, s1 DOUBLE, s2 INT32, s3 INT64) with 5760 rows of
data in table. */
+ public static void prepareDataInTable() throws SQLException {
+ try (Connection connection =
EnvFactory.getEnv().getConnection(BaseEnv.TABLE_SQL_DIALECT);
+ Statement statement = connection.createStatement()) {
+ statement.execute("CREATE DATABASE db");
+ statement.execute(
+ "CREATE TABLE db.AI (s0 FLOAT FIELD, s1 DOUBLE FIELD, s2 INT32
FIELD, s3 INT64 FIELD)");
+ for (int i = 0; i < 5760; i++) {
+ statement.execute(
+ String.format(
+ "INSERT INTO db.AI(time,s0,s1,s2,s3) VALUES(%d,%f,%f,%d,%d)",
+ i, (float) i, (double) i, i, i));
+ }
+ }
+ }
+
public static class FakeModelInfo {
private final String modelId;
diff --git a/iotdb-core/ainode/iotdb/ainode/core/manager/model_manager.py
b/iotdb-core/ainode/iotdb/ainode/core/manager/model_manager.py
index ef0846c3d78..ff4226e734f 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/manager/model_manager.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/manager/model_manager.py
@@ -61,9 +61,13 @@ class ModelManager:
return TRegisterModelResp(
get_status(TSStatusCode.CREATE_MODEL_ERROR, str(e))
)
+ except Exception as e:
+ # Catch-all for other exceptions (mainly from transformers
implementation)
+ return TRegisterModelResp(
+ get_status(TSStatusCode.CREATE_MODEL_ERROR, str(e))
+ )
def show_models(self, req: TShowModelsReq) -> TShowModelsResp:
- self._refresh()
return self._model_storage.show_models(req)
def delete_model(self, req: TDeleteModelReq) -> TSStatus:
diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/model_info.py
b/iotdb-core/ainode/iotdb/ainode/core/model/model_info.py
index bcb4a5e2056..d0da371bfd5 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/model/model_info.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/model/model_info.py
@@ -31,7 +31,7 @@ class ModelInfo:
pipeline_cls: str = "",
repo_id: str = "",
auto_map: Optional[Dict] = None,
- _transformers_registered: bool = False,
+ transformers_registered: bool = False,
):
self.model_id = model_id
self.model_type = model_type
@@ -40,7 +40,9 @@ class ModelInfo:
self.pipeline_cls = pipeline_cls
self.repo_id = repo_id
self.auto_map = auto_map # If exists, indicates it's a Transformers
model
- self._transformers_registered = _transformers_registered # Internal
flag: whether registered to Transformers
+ self.transformers_registered = (
+ transformers_registered # Internal flag: whether registered to
Transformers
+ )
def __repr__(self):
return (
@@ -116,7 +118,7 @@ BUILTIN_HF_TRANSFORMERS_MODEL_MAP = {
"AutoConfig": "configuration_timer.TimerConfig",
"AutoModelForCausalLM": "modeling_timer.TimerForPrediction",
},
- _transformers_registered=True,
+ transformers_registered=True,
),
"sundial": ModelInfo(
model_id="sundial",
@@ -129,7 +131,7 @@ BUILTIN_HF_TRANSFORMERS_MODEL_MAP = {
"AutoConfig": "configuration_sundial.SundialConfig",
"AutoModelForCausalLM": "modeling_sundial.SundialForPrediction",
},
- _transformers_registered=True,
+ transformers_registered=True,
),
"chronos2": ModelInfo(
model_id="chronos2",
@@ -139,7 +141,7 @@ BUILTIN_HF_TRANSFORMERS_MODEL_MAP = {
pipeline_cls="pipeline_chronos2.Chronos2Pipeline",
repo_id="amazon/chronos-2",
auto_map={
- "AutoConfig": "config.Chronos2ForecastingConfig",
+ "AutoConfig": "config.Chronos2CoreConfig",
"AutoModelForCausalLM": "model.Chronos2Model",
},
),
diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/model_storage.py
b/iotdb-core/ainode/iotdb/ainode/core/model/model_storage.py
index ee09cfd75bb..2cfb07fb56a 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/model/model_storage.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/model/model_storage.py
@@ -227,17 +227,22 @@ class ModelStorage:
model_type = config.get("model_type", "")
auto_map = config.get("auto_map", None)
pipeline_cls = config.get("pipeline_cls", "")
-
+ model_info = ModelInfo(
+ model_id=model_id,
+ model_type=model_type,
+ category=ModelCategory.USER_DEFINED,
+ state=ModelStates.ACTIVE,
+ pipeline_cls=pipeline_cls,
+ auto_map=auto_map,
+ transformers_registered=False, # Lazy registration
+ )
+ with self._lock_pool.get_lock(model_id).write_lock():
+ self._models[ModelCategory.USER_DEFINED.value][model_id] =
model_info
+ if self.ensure_transformers_registered(model_id) is None:
+ model_info.state = ModelStates.INACTIVE
+ else:
+ model_info.transformers_registered = True
with self._lock_pool.get_lock(model_id).write_lock():
- model_info = ModelInfo(
- model_id=model_id,
- model_type=model_type,
- category=ModelCategory.USER_DEFINED,
- state=ModelStates.ACTIVE,
- pipeline_cls=pipeline_cls,
- auto_map=auto_map,
- _transformers_registered=False, # Lazy registration
- )
self._models[ModelCategory.USER_DEFINED.value][model_id] =
model_info
# ==================== Registration Methods ====================
@@ -254,6 +259,7 @@ class ModelStorage:
Raises:
ModelExistedException: If the model_id already exists.
InvalidModelUriException: If the URI format is invalid.
+ Exception: For other errors during transformers model registration.
"""
if self.is_model_registered(model_id):
@@ -287,29 +293,34 @@ class ModelStorage:
state=ModelStates.ACTIVE,
pipeline_cls=pipeline_cls,
auto_map=auto_map,
- _transformers_registered=False, # Register later
+ transformers_registered=False, # Register later
)
self._models[ModelCategory.USER_DEFINED.value][model_id] =
model_info
- if auto_map:
- # Transformers model: immediately register to Transformers
autoloading mechanism
- success = self._register_transformers_model(model_info)
- if success:
- with self._lock_pool.get_lock(model_id).write_lock():
- model_info._transformers_registered = True
- else:
- with self._lock_pool.get_lock(model_id).write_lock():
+ if auto_map:
+ # Transformers model: immediately register to Transformers
autoloading mechanism
+ try:
+ if self._register_transformers_model(model_info):
+ model_info.transformers_registered = True
+ except Exception as e:
model_info.state = ModelStates.INACTIVE
- logger.error(f"Failed to register Transformers model
{model_id}")
- else:
- # Other type models: only log
- self._register_other_model(model_info)
+ logger.error(
+ f"Failed to register Transformers model {model_id},
because {e}"
+ )
+ raise e
+ else:
+ # Other type models: only log
+ self._register_other_model(model_info)
logger.info(f"Successfully registered model {model_id} from URI:
{uri}")
- def _register_transformers_model(self, model_info: ModelInfo):
+ def _register_transformers_model(self, model_info: ModelInfo) -> bool:
"""
Register Transformers model to autoloading mechanism (internal method)
+ Returns:
+ True if registration is successful
+ Raises:
+ Exception: Transformers internal exception if registration fails
"""
auto_map = model_info.auto_map
if not auto_map:
@@ -344,7 +355,7 @@ class ModelStorage:
logger.warning(
f"Failed to register Transformers model {model_info.model_id}:
{e}. Model may still work via auto_map, but ensure module path is correct."
)
- return False
+ raise e
def _register_other_model(self, model_info: ModelInfo):
"""Register other type models (non-Transformers models)"""
@@ -352,12 +363,11 @@ class ModelStorage:
f"Registered other type model: {model_info.model_id}
({model_info.model_type})"
)
- def ensure_transformers_registered(self, model_id: str) -> ModelInfo:
+ def ensure_transformers_registered(self, model_id: str) -> ModelInfo |
None:
"""
- Ensure Transformers model is registered (called for lazy registration)
- This method uses locks to ensure thread safety. All check logic is
within lock protection.
+ Ensure Transformers model is registered.
Returns:
- str: If None, registration failed, otherwise returns model path
+ ModelInfo | None: None if registration failed, otherwise returns
the corresponding ModelInfo
"""
# Use lock to protect entire check-execute process
with self._lock_pool.get_lock(model_id).write_lock():
@@ -369,11 +379,10 @@ class ModelStorage:
break
if not model_info:
- logger.warning(f"Model {model_id} does not exist, cannot
register")
return None
# If already registered, return directly
- if model_info._transformers_registered:
+ if model_info.transformers_registered:
return model_info
# If no auto_map, not a Transformers model, mark as registered
(avoid duplicate checks)
@@ -381,14 +390,13 @@ class ModelStorage:
not model_info.auto_map
or model_id in BUILTIN_HF_TRANSFORMERS_MODEL_MAP.keys()
):
- model_info._transformers_registered = True
+ model_info.transformers_registered = True
return model_info
# Execute registration (under lock protection)
try:
- success = self._register_transformers_model(model_info)
- if success:
- model_info._transformers_registered = True
+ if self._register_transformers_model(model_info):
+ model_info.transformers_registered = True
logger.info(
f"Model {model_id} successfully registered to
Transformers"
)
@@ -401,7 +409,7 @@ class ModelStorage:
except Exception as e:
# Ensure state consistency in exception cases
model_info.state = ModelStates.INACTIVE
- model_info._transformers_registered = False
+ model_info.transformers_registered = False
logger.error(
f"Exception occurred while registering model {model_id} to
Transformers: {e}"
)
diff --git a/iotdb-protocol/thrift-ainode/src/main/thrift/ainode.thrift
b/iotdb-protocol/thrift-ainode/src/main/thrift/ainode.thrift
index cda356a948e..ea32f01b6e2 100644
--- a/iotdb-protocol/thrift-ainode/src/main/thrift/ainode.thrift
+++ b/iotdb-protocol/thrift-ainode/src/main/thrift/ainode.thrift
@@ -40,8 +40,8 @@ struct TAIHeartbeatResp {
}
struct TRegisterModelReq {
- 1: required string uri
- 2: required string modelId
+ 1: required string modelId
+ 2: required string uri
}
struct TConfigs {