This is an automated email from the ASF dual-hosted git repository.

yongzao pushed a commit to branch ain-bug-fix
in repository https://gitbox.apache.org/repos/asf/iotdb.git


The following commit(s) were added to refs/heads/ain-bug-fix by this push:
     new 90ac5618711 finish
90ac5618711 is described below

commit 90ac56187113927cfe290aa5799d66fa92e84700
Author: Yongzao <[email protected]>
AuthorDate: Sat Jun 28 02:57:10 2025 +0800

    finish
---
 .../ainode/ainode/core/manager/model_manager.py    | 13 ++++++++
 .../ainode/ainode/core/model/model_storage.py      | 38 ++++++++++++++++++++--
 .../iotdb/confignode/persistence/ModelInfo.java    |  9 ++---
 .../operator/process/ai/InferenceOperator.java     |  3 +-
 .../queryengine/plan/analyze/AnalyzeVisitor.java   |  7 ----
 .../iotdb/commons/model/ModelInformation.java      |  2 +-
 6 files changed, 53 insertions(+), 19 deletions(-)

diff --git a/iotdb-core/ainode/ainode/core/manager/model_manager.py b/iotdb-core/ainode/ainode/core/manager/model_manager.py
index 9b965f0c711..46177378035 100644
--- a/iotdb-core/ainode/ainode/core/manager/model_manager.py
+++ b/iotdb-core/ainode/ainode/core/manager/model_manager.py
@@ -26,6 +26,7 @@ from ainode.core.exception import (
     InvalidUriError,
 )
 from ainode.core.log import Logger
+from ainode.core.model.model_info import BuiltInModelType, ModelInfo, ModelStates
 from ainode.core.model.model_storage import ModelStorage
 from ainode.core.util.status import get_status
 from ainode.thrift.ainode.ttypes import (
@@ -140,3 +141,15 @@ class ModelManager:
 
     def show_models(self) -> TShowModelsResp:
         return self.model_storage.show_models()
+
+    def register_built_in_model(self, model_info: ModelInfo):
+        self.model_storage.register_built_in_model(model_info)
+
+    def update_model_state(self, model_id: str, state: ModelStates):
+        self.model_storage.update_model_state(model_id, state)
+
+    def get_built_in_model_type(self, model_id: str) -> BuiltInModelType:
+        """
+        Get the type of the model with the given model_id.
+        """
+        return self.model_storage.get_built_in_model_type(model_id.lower())
diff --git a/iotdb-core/ainode/ainode/core/model/model_storage.py b/iotdb-core/ainode/ainode/core/model/model_storage.py
index 2682a139e6c..92f2528845c 100644
--- a/iotdb-core/ainode/ainode/core/model/model_storage.py
+++ b/iotdb-core/ainode/ainode/core/model/model_storage.py
@@ -207,6 +207,9 @@ class ModelStorage(object):
         with self._lock_pool.get_lock(model_id).write_lock():
             if os.path.exists(storage_path):
                 shutil.rmtree(storage_path)
+            if model_id in self._model_info_map:
+                del self._model_info_map[model_id]
+                logger.info(f"Model {model_id} deleted successfully.")
 
     def _is_built_in(self, model_id: str) -> bool:
         """
@@ -218,9 +221,9 @@ class ModelStorage(object):
         Returns:
             bool: True if the model is built-in, False otherwise.
         """
-        return (
-            model_id in self._model_info_map
-            and self._model_info_map[model_id].category == ModelCategory.BUILT_IN
+        return model_id in self._model_info_map and (
+                self._model_info_map[model_id].category == ModelCategory.BUILT_IN
+                or self._model_info_map[model_id].category == ModelCategory.FINE_TUNED
         )
 
     def load_model(self, model_id: str, acceleration: bool) -> Callable:
@@ -291,3 +294,32 @@ class ModelStorage(object):
                 for model_id, model_info in self._model_info_map.items()
             ),
         )
+
+    def register_built_in_model(self, model_info: ModelInfo):
+        with self._lock_pool.get_lock(model_info.model_id).write_lock():
+            self._model_info_map[model_info.model_id] = model_info
+
+    def update_model_state(self, model_id: str, state: ModelStates):
+        with self._lock_pool.get_lock(model_id).write_lock():
+            if model_id in self._model_info_map:
+                self._model_info_map[model_id].state = state
+            else:
+                raise ValueError(f"Model {model_id} does not exist.")
+
+    def get_built_in_model_type(self, model_id: str) -> BuiltInModelType:
+        """
+        Get the type of the model with the given model_id.
+
+        Args:
+            model_id (str): The ID of the model.
+
+        Returns:
+            str: The type of the model.
+        """
+        with self._lock_pool.get_lock(model_id).read_lock():
+            if model_id in self._model_info_map:
+                return get_built_in_model_type(
+                    self._model_info_map[model_id].model_type
+                )
+            else:
+                raise ValueError(f"Model {model_id} does not exist.")
diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/persistence/ModelInfo.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/persistence/ModelInfo.java
index e96f6fb7bf6..7f0eb6b4e88 100644
--- a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/persistence/ModelInfo.java
+++ b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/persistence/ModelInfo.java
@@ -121,13 +121,8 @@ public class ModelInfo implements SnapshotProcessor {
     try {
       acquireModelTableWriteLock();
       String modelName = plan.getModelName();
-      if (modelTable.containsModel(modelName)) {
-        return new TSStatus(TSStatusCode.MODEL_EXIST_ERROR.getStatusCode())
-            .setMessage(String.format("model [%s] has already been created.", modelName));
-      } else {
-        modelTable.addModel(new ModelInformation(modelName, ModelStatus.LOADING));
-        return new TSStatus(TSStatusCode.SUCCESS_STATUS.getStatusCode());
-      }
+      modelTable.addModel(new ModelInformation(modelName, ModelStatus.LOADING));
+      return new TSStatus(TSStatusCode.SUCCESS_STATUS.getStatusCode());
     } catch (Exception e) {
       final String errorMessage =
           String.format(
diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/process/ai/InferenceOperator.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/process/ai/InferenceOperator.java
index a384be3ad24..fd51ced46e8 100644
--- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/process/ai/InferenceOperator.java
+++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/process/ai/InferenceOperator.java
@@ -254,7 +254,8 @@ public class InferenceOperator implements ProcessOperator {
   }
 
   private TsBlock preProcess(TsBlock inputTsBlock) {
-    boolean notBuiltIn = !modelInferenceDescriptor.getModelInformation().isBuiltIn();
+    //    boolean notBuiltIn = !modelInferenceDescriptor.getModelInformation().isBuiltIn();
+    boolean notBuiltIn = false;
     if (windowType == null || windowType == InferenceWindowType.HEAD) {
       if (notBuiltIn
           && totalRow != modelInferenceDescriptor.getModelInformation().getInputShape()[0]) {
diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/analyze/AnalyzeVisitor.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/analyze/AnalyzeVisitor.java
index b71a9770126..a1033250f68 100644
--- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/analyze/AnalyzeVisitor.java
+++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/analyze/AnalyzeVisitor.java
@@ -481,13 +481,6 @@ public class AnalyzeVisitor extends StatementVisitor<Analysis, MPPQueryContext>
     if (modelInformation.isBuiltIn()) {
       return;
     }
-
-    if (modelInformation.getInputShape()[0] != windowSize) {
-      throw new SemanticException(
-          String.format(
-              "Window output %d is not equal to input size of model %d",
-              windowSize, modelInformation.getInputShape()[0]));
-    }
   }
 
   private ISchemaTree analyzeSchema(
diff --git a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/model/ModelInformation.java b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/model/ModelInformation.java
index 9e84c92a311..3fa10768543 100644
--- a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/model/ModelInformation.java
+++ b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/model/ModelInformation.java
@@ -84,7 +84,7 @@ public class ModelInformation {
   }
 
   public ModelInformation(String modelName, ModelStatus status) {
-    this.modelType = ModelType.USER_DEFINED;
+    this.modelType = ModelType.BUILT_IN_FORECAST;
     this.modelName = modelName;
     this.inputShape = new int[0];
     this.outputShape = new int[0];

Reply via email to