This is an automated email from the ASF dual-hosted git repository.

yongzao pushed a commit to branch cp-ain-to-206
in repository https://gitbox.apache.org/repos/asf/iotdb.git

commit cfaa0b63d3cd2971532bc8d8c442775ceb211417
Author: Yongzao <[email protected]>
AuthorDate: Mon Sep 29 13:55:38 2025 +0800

    [AINode][Bug fix] Concurrent inference (#16518)
    
    * trigger CI
    
    * Bug fix for SHOW LOADED MODELS
    
    (cherry picked from commit b4dde12d4cf6fddc283d63ce2b82635e6f9510c0)
---
 .../ainode/it/AINodeConcurrentInferenceIT.java     | 84 ++++++++++++++++++----
 .../apache/iotdb/ainode/utils/AINodeTestUtils.java |  7 +-
 2 files changed, 75 insertions(+), 16 deletions(-)

diff --git 
a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeConcurrentInferenceIT.java
 
b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeConcurrentInferenceIT.java
index c73fbeb2fbf..b5b987594d9 100644
--- 
a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeConcurrentInferenceIT.java
+++ 
b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeConcurrentInferenceIT.java
@@ -22,15 +22,23 @@ package org.apache.iotdb.ainode.it;
 import org.apache.iotdb.it.env.EnvFactory;
 import org.apache.iotdb.itbase.env.BaseEnv;
 
+import com.google.common.collect.ImmutableMap;
+import com.google.common.collect.ImmutableSet;
 import org.junit.AfterClass;
+import org.junit.Assert;
 import org.junit.BeforeClass;
 import org.junit.Test;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 import java.sql.Connection;
+import java.sql.ResultSet;
 import java.sql.SQLException;
 import java.sql.Statement;
+import java.util.HashSet;
+import java.util.Map;
+import java.util.Set;
+import java.util.concurrent.TimeUnit;
 
 import static 
org.apache.iotdb.ainode.utils.AINodeTestUtils.concurrentInference;
 
@@ -38,6 +46,11 @@ public class AINodeConcurrentInferenceIT {
 
   private static final Logger LOGGER = 
LoggerFactory.getLogger(AINodeConcurrentInferenceIT.class);
 
+  private static final Map<String, String> MODEL_ID_TO_TYPE_MAP =
+      ImmutableMap.of(
+          "timer_xl", "Timer-XL",
+          "sundial", "Timer-Sundial");
+
   @BeforeClass
   public static void setUp() throws Exception {
     // Init 1C1D1A cluster environment
@@ -91,12 +104,17 @@ public class AINodeConcurrentInferenceIT {
         Statement statement = connection.createStatement()) {
       final int threadCnt = 4;
       final int loop = 10;
+      final int predictLength = 96;
       statement.execute(String.format("LOAD MODEL %s TO DEVICES \"cpu\"", 
modelId));
+      checkModelOnSpecifiedDevice(statement, 
MODEL_ID_TO_TYPE_MAP.get(modelId), "cpu");
       concurrentInference(
           statement,
-          String.format("CALL INFERENCE(%s, \"SELECT s FROM root.AI\")", 
modelId),
+          String.format(
+              "CALL INFERENCE(%s, \"SELECT s FROM root.AI\", 
predict_length=%d)",
+              modelId, predictLength),
           threadCnt,
-          loop);
+          loop,
+          predictLength);
       statement.execute(String.format("UNLOAD MODEL %s FROM DEVICES \"cpu\"", 
modelId));
     }
   }
@@ -111,14 +129,20 @@ public class AINodeConcurrentInferenceIT {
       throws SQLException, InterruptedException {
     try (Connection connection = 
EnvFactory.getEnv().getConnection(BaseEnv.TREE_SQL_DIALECT);
         Statement statement = connection.createStatement()) {
-      final int threadCnt = 4;
-      final int loop = 10;
-      statement.execute(String.format("LOAD MODEL %s TO DEVICES \"0,1\"", 
modelId));
+      final int threadCnt = 10;
+      final int loop = 100;
+      final int predictLength = 512;
+      final String devices = "0,1";
+      statement.execute(String.format("LOAD MODEL %s TO DEVICES \"%s\"", 
modelId, devices));
+      checkModelOnSpecifiedDevice(statement, 
MODEL_ID_TO_TYPE_MAP.get(modelId), devices);
       concurrentInference(
           statement,
-          String.format("CALL INFERENCE(%s, \"SELECT s FROM root.AI\")", 
modelId),
+          String.format(
+              "CALL INFERENCE(%s, \"SELECT s FROM root.AI\", 
predict_length=%d)",
+              modelId, predictLength),
           threadCnt,
-          loop);
+          loop,
+          predictLength);
       statement.execute(String.format("UNLOAD MODEL %s FROM DEVICES \"0,1\"", 
modelId));
     }
   }
@@ -134,15 +158,18 @@ public class AINodeConcurrentInferenceIT {
         Statement statement = connection.createStatement()) {
       final int threadCnt = 4;
       final int loop = 10;
+      final int predictLength = 96;
       statement.execute(String.format("LOAD MODEL %s TO DEVICES \"cpu\"", 
modelId));
+      checkModelOnSpecifiedDevice(statement, 
MODEL_ID_TO_TYPE_MAP.get(modelId), "cpu");
       long startTime = System.currentTimeMillis();
       concurrentInference(
           statement,
           String.format(
-              "SELECT * FROM FORECAST(model_id=>'%s', input=>(SELECT time,s 
FROM root.AI) ORDER BY time)",
-              modelId),
+              "SELECT * FROM FORECAST(model_id=>'%s', input=>(SELECT time,s 
FROM root.AI) ORDER BY time), predict_length=>%d",
+              modelId, predictLength),
           threadCnt,
-          loop);
+          loop,
+          predictLength);
       long endTime = System.currentTimeMillis();
       LOGGER.info(
           String.format(
@@ -163,15 +190,19 @@ public class AINodeConcurrentInferenceIT {
         Statement statement = connection.createStatement()) {
       final int threadCnt = 10;
       final int loop = 100;
-      statement.execute(String.format("LOAD MODEL %s TO DEVICES \"0,1\"", 
modelId));
+      final int predictLength = 512;
+      final String devices = "0,1";
+      statement.execute(String.format("LOAD MODEL %s TO DEVICES \"%s\"", 
modelId, devices));
+      checkModelOnSpecifiedDevice(statement, 
MODEL_ID_TO_TYPE_MAP.get(modelId), devices);
       long startTime = System.currentTimeMillis();
       concurrentInference(
           statement,
           String.format(
-              "SELECT * FROM FORECAST(model_id=>'%s', input=>(SELECT time,s 
FROM root.AI) ORDER BY time)",
-              modelId),
+              "SELECT * FROM FORECAST(model_id=>'%s', input=>(SELECT time,s 
FROM root.AI) ORDER BY time), predict_length=>%d",
+              modelId, predictLength),
           threadCnt,
-          loop);
+          loop,
+          predictLength);
       long endTime = System.currentTimeMillis();
       LOGGER.info(
           String.format(
@@ -180,4 +211,29 @@ public class AINodeConcurrentInferenceIT {
       statement.execute(String.format("UNLOAD MODEL %s FROM DEVICES \"0,1\"", 
modelId));
     }
   }
+
+  /**
+   * Polls {@code SHOW LOADED MODELS <device>} until every device in {@code device} reports the
+   * given model type, failing the test after 10 attempts spaced 3 seconds apart.
+   *
+   * @param statement statement used to issue the query
+   * @param modelType model type expected in result column 2 (e.g. a value of MODEL_ID_TO_TYPE_MAP)
+   * @param device comma-separated device list, e.g. "cpu" or "0,1"
+   * @throws SQLException if the query fails
+   * @throws InterruptedException if interrupted while sleeping between attempts
+   */
+  private void checkModelOnSpecifiedDevice(Statement statement, String modelType, String device)
+      throws SQLException, InterruptedException {
+    // A null type means the model id is missing from MODEL_ID_TO_TYPE_MAP; fail fast and clearly.
+    Assert.assertNotNull("No model type registered for check on device " + device, modelType);
+    // The set of devices to wait for does not change between attempts, so build it once.
+    Set<String> targetDevices = ImmutableSet.copyOf(device.split(","));
+    for (int retry = 0; retry < 10; retry++) {
+      Set<String> foundDevices = new HashSet<>();
+      try (final ResultSet resultSet =
+          statement.executeQuery(String.format("SHOW LOADED MODELS %s", device))) {
+        while (resultSet.next()) {
+          String deviceId = resultSet.getString(1);
+          String loadedModelType = resultSet.getString(2);
+          int count = resultSet.getInt(3);
+          // Only mark a device ready once it meets the instance threshold; otherwise keep
+          // polling instead of failing on the first (possibly still-loading) snapshot.
+          // NOTE(review): threshold "> 1" kept from the original; confirm whether a single
+          // loaded instance should also count as ready.
+          if (modelType.equals(loadedModelType) && targetDevices.contains(deviceId) && count > 1) {
+            foundDevices.add(deviceId);
+          }
+        }
+        if (foundDevices.containsAll(targetDevices)) {
+          return;
+        }
+      }
+      TimeUnit.SECONDS.sleep(3);
+    }
+    Assert.fail("Model " + modelType + " is not loaded on device " + device);
+  }
 }
diff --git 
a/integration-test/src/test/java/org/apache/iotdb/ainode/utils/AINodeTestUtils.java
 
b/integration-test/src/test/java/org/apache/iotdb/ainode/utils/AINodeTestUtils.java
index 31e498ae729..cbb0b03b229 100644
--- 
a/integration-test/src/test/java/org/apache/iotdb/ainode/utils/AINodeTestUtils.java
+++ 
b/integration-test/src/test/java/org/apache/iotdb/ainode/utils/AINodeTestUtils.java
@@ -60,7 +60,8 @@ public class AINodeTestUtils {
     }
   }
 
-  public static void concurrentInference(Statement statement, String sql, int 
threadCnt, int loop)
+  public static void concurrentInference(
+      Statement statement, String sql, int threadCnt, int loop, int 
expectedOutputLength)
       throws InterruptedException {
     Thread[] threads = new Thread[threadCnt];
     for (int i = 0; i < threadCnt; i++) {
@@ -70,9 +71,11 @@ public class AINodeTestUtils {
                 try {
                   for (int j = 0; j < loop; j++) {
                     try (ResultSet resultSet = statement.executeQuery(sql)) {
+                      int outputCnt = 0;
                       while (resultSet.next()) {
-                        // do nothing
+                        outputCnt++;
                       }
+                      assertEquals(expectedOutputLength, outputCnt);
                     } catch (SQLException e) {
                       fail(e.getMessage());
                     }

Reply via email to