This is an automated email from the ASF dual-hosted git repository.

yongzao pushed a commit to branch rc/2.0.6
in repository https://gitbox.apache.org/repos/asf/iotdb.git

commit 4e42ba6b52567a77e0ffd49dfc324837c2426493
Author: jtmer <[email protected]>
AuthorDate: Tue Sep 16 09:17:05 2025 +0800

    [AINode] Add a batcher for inference (#16411)
    
    (cherry picked from commit 773433176dc4f1fad831be76f3274edf66b0ea95)
---
 .../ainode/it/AINodeConcurrentInferenceIT.java     | 70 +++++++++++++++++-----
 .../iotdb/ainode/core/inference/pool_group.py      |  4 +-
 2 files changed, 57 insertions(+), 17 deletions(-)

diff --git a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeConcurrentInferenceIT.java b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeConcurrentInferenceIT.java
index b8ebfa9600b..c73fbeb2fbf 100644
--- a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeConcurrentInferenceIT.java
+++ b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeConcurrentInferenceIT.java
@@ -81,61 +81,103 @@ public class AINodeConcurrentInferenceIT {
 
   @Test
   public void concurrentCPUCallInferenceTest() throws SQLException, 
InterruptedException {
+    concurrentCPUCallInferenceTest("timer_xl");
+    concurrentCPUCallInferenceTest("sundial");
+  }
+
+  private void concurrentCPUCallInferenceTest(String modelId)
+      throws SQLException, InterruptedException {
     try (Connection connection = 
EnvFactory.getEnv().getConnection(BaseEnv.TREE_SQL_DIALECT);
         Statement statement = connection.createStatement()) {
-      statement.execute("LOAD MODEL sundial TO DEVICES \"cpu\"");
-      concurrentInference(statement, "CALL INFERENCE(sundial, \"SELECT s FROM root.AI\")", 4, 10);
+      final int threadCnt = 4;
+      final int loop = 10;
+      statement.execute(String.format("LOAD MODEL %s TO DEVICES \"cpu\"", 
modelId));
+      concurrentInference(
+          statement,
+          String.format("CALL INFERENCE(%s, \"SELECT s FROM root.AI\")", modelId),
+          threadCnt,
+          loop);
+      statement.execute(String.format("UNLOAD MODEL %s FROM DEVICES \"cpu\"", 
modelId));
     }
   }
 
   @Test
   public void concurrentGPUCallInferenceTest() throws SQLException, 
InterruptedException {
+    concurrentGPUCallInferenceTest("timer_xl");
+    concurrentGPUCallInferenceTest("sundial");
+  }
+
+  private void concurrentGPUCallInferenceTest(String modelId)
+      throws SQLException, InterruptedException {
     try (Connection connection = 
EnvFactory.getEnv().getConnection(BaseEnv.TREE_SQL_DIALECT);
         Statement statement = connection.createStatement()) {
-      statement.execute("LOAD MODEL sundial TO DEVICES \"0,1\"");
-      concurrentInference(statement, "CALL INFERENCE(sundial, \"SELECT s FROM root.AI\")", 10, 100);
+      final int threadCnt = 4;
+      final int loop = 10;
+      statement.execute(String.format("LOAD MODEL %s TO DEVICES \"0,1\"", 
modelId));
+      concurrentInference(
+          statement,
+          String.format("CALL INFERENCE(%s, \"SELECT s FROM root.AI\")", modelId),
+          threadCnt,
+          loop);
+      statement.execute(String.format("UNLOAD MODEL %s FROM DEVICES \"0,1\"", 
modelId));
     }
   }
 
   @Test
   public void concurrentCPUForecastTest() throws SQLException, 
InterruptedException {
-    try (Connection connection = 
EnvFactory.getEnv().getConnection(BaseEnv.TREE_SQL_DIALECT);
+    concurrentCPUForecastTest("timer_xl");
+    concurrentCPUForecastTest("sundial");
+  }
+
+  private void concurrentCPUForecastTest(String modelId) throws SQLException, 
InterruptedException {
+    try (Connection connection = 
EnvFactory.getEnv().getConnection(BaseEnv.TABLE_SQL_DIALECT);
         Statement statement = connection.createStatement()) {
       final int threadCnt = 4;
       final int loop = 10;
-      statement.execute("LOAD MODEL sundial TO DEVICES \"cpu\"");
+      statement.execute(String.format("LOAD MODEL %s TO DEVICES \"cpu\"", 
modelId));
       long startTime = System.currentTimeMillis();
       concurrentInference(
           statement,
-          "SELECT * FROM FORECAST(model_id=>'sundial', input=>(SELECT time,s FROM root.AI) ORDER BY time)",
+          String.format(
+              "SELECT * FROM FORECAST(model_id=>'%s', input=>(SELECT time,s FROM root.AI) ORDER BY time)",
+              modelId),
           threadCnt,
           loop);
       long endTime = System.currentTimeMillis();
       LOGGER.info(
           String.format(
-              "Timer-Sundial concurrent inference %d reqs (%d threads, %d loops) in CPU takes time: %dms",
-              threadCnt * loop, threadCnt, loop, endTime - startTime));
+              "Model %s concurrent inference %d reqs (%d threads, %d loops) in CPU takes time: %dms",
+              modelId, threadCnt * loop, threadCnt, loop, endTime - startTime));
+      statement.execute(String.format("UNLOAD MODEL %s FROM DEVICES \"cpu\"", 
modelId));
     }
   }
 
   @Test
   public void concurrentGPUForecastTest() throws SQLException, 
InterruptedException {
-    try (Connection connection = 
EnvFactory.getEnv().getConnection(BaseEnv.TREE_SQL_DIALECT);
+    concurrentGPUForecastTest("timer_xl");
+    concurrentGPUForecastTest("sundial");
+  }
+
+  public void concurrentGPUForecastTest(String modelId) throws SQLException, 
InterruptedException {
+    try (Connection connection = 
EnvFactory.getEnv().getConnection(BaseEnv.TABLE_SQL_DIALECT);
         Statement statement = connection.createStatement()) {
       final int threadCnt = 10;
       final int loop = 100;
-      statement.execute("LOAD MODEL sundial TO DEVICES \"0,1\"");
+      statement.execute(String.format("LOAD MODEL %s TO DEVICES \"0,1\"", 
modelId));
       long startTime = System.currentTimeMillis();
       concurrentInference(
           statement,
-          "SELECT * FROM FORECAST(model_id=>'sundial', input=>(SELECT time,s FROM root.AI) ORDER BY time)",
+          String.format(
+              "SELECT * FROM FORECAST(model_id=>'%s', input=>(SELECT time,s FROM root.AI) ORDER BY time)",
+              modelId),
           threadCnt,
           loop);
       long endTime = System.currentTimeMillis();
       LOGGER.info(
           String.format(
-              "Timer-Sundial concurrent inference %d reqs (%d threads, %d loops) in GPU takes time: %dms",
-              threadCnt * loop, threadCnt, loop, endTime - startTime));
+              "Model %s concurrent inference %d reqs (%d threads, %d loops) in GPU takes time: %dms",
+              modelId, threadCnt * loop, threadCnt, loop, endTime - startTime));
+      statement.execute(String.format("UNLOAD MODEL %s FROM DEVICES \"0,1\"", 
modelId));
     }
   }
 }
diff --git a/iotdb-core/ainode/iotdb/ainode/core/inference/pool_group.py b/iotdb-core/ainode/iotdb/ainode/core/inference/pool_group.py
index a7004549db0..96dce845585 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/inference/pool_group.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/inference/pool_group.py
@@ -19,9 +19,7 @@ from typing import Dict, Tuple
 
 import torch.multiprocessing as mp
 
-from iotdb.ainode.core.exception import (
-    InferenceModelInternalError,
-)
+from iotdb.ainode.core.exception import InferenceModelInternalError
 from iotdb.ainode.core.inference.dispatcher.basic_dispatcher import BasicDispatcher
 from iotdb.ainode.core.inference.inference_request import (
     InferenceRequest,

Reply via email to