This is an automated email from the ASF dual-hosted git repository. yongzao pushed a commit to branch rc/2.0.6 in repository https://gitbox.apache.org/repos/asf/iotdb.git
commit c7a614c8ccfa235ad1785521a6480d131e0a9133 Author: Yongzao <[email protected]> AuthorDate: Mon Sep 29 13:55:38 2025 +0800 [AINode][Bug fix] Concurrent inference (#16518) * trigger CI * bug fix 4 show loaded models (cherry picked from commit b4dde12d4cf6fddc283d63ce2b82635e6f9510c0) --- .../ainode/it/AINodeConcurrentInferenceIT.java | 84 ++++++++++++++++++---- .../apache/iotdb/ainode/utils/AINodeTestUtils.java | 7 +- 2 files changed, 75 insertions(+), 16 deletions(-) diff --git a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeConcurrentInferenceIT.java b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeConcurrentInferenceIT.java index c73fbeb2fbf..b5b987594d9 100644 --- a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeConcurrentInferenceIT.java +++ b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeConcurrentInferenceIT.java @@ -22,15 +22,23 @@ package org.apache.iotdb.ainode.it; import org.apache.iotdb.it.env.EnvFactory; import org.apache.iotdb.itbase.env.BaseEnv; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; import org.junit.AfterClass; +import org.junit.Assert; import org.junit.BeforeClass; import org.junit.Test; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import java.sql.Connection; +import java.sql.ResultSet; import java.sql.SQLException; import java.sql.Statement; +import java.util.HashSet; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.TimeUnit; import static org.apache.iotdb.ainode.utils.AINodeTestUtils.concurrentInference; @@ -38,6 +46,11 @@ public class AINodeConcurrentInferenceIT { private static final Logger LOGGER = LoggerFactory.getLogger(AINodeConcurrentInferenceIT.class); + private static final Map<String, String> MODEL_ID_TO_TYPE_MAP = + ImmutableMap.of( + "timer_xl", "Timer-XL", + "sundial", "Timer-Sundial"); + @BeforeClass public static void setUp() throws Exception { // Init 1C1D1A cluster environment @@ -91,12 +104,17 @@ public class AINodeConcurrentInferenceIT { Statement statement = connection.createStatement()) { final int threadCnt = 4; final int loop = 10; + final int predictLength = 96; statement.execute(String.format("LOAD MODEL %s TO DEVICES \"cpu\"", modelId)); + checkModelOnSpecifiedDevice(statement, MODEL_ID_TO_TYPE_MAP.get(modelId), "cpu"); concurrentInference( statement, - String.format("CALL INFERENCE(%s, \"SELECT s FROM root.AI\")", modelId), + String.format( + "CALL INFERENCE(%s, \"SELECT s FROM root.AI\", predict_length=%d)", + modelId, predictLength), threadCnt, - loop); + loop, + predictLength); statement.execute(String.format("UNLOAD MODEL %s FROM DEVICES \"cpu\"", modelId)); } } @@ -111,14 +129,20 @@ public class AINodeConcurrentInferenceIT { throws SQLException, InterruptedException { try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TREE_SQL_DIALECT); Statement statement = connection.createStatement()) { - final int threadCnt = 4; - final int loop = 10; - statement.execute(String.format("LOAD MODEL %s TO DEVICES \"0,1\"", modelId)); + final int threadCnt = 10; + final int loop = 100; + final int predictLength = 512; + final String devices = "0,1"; + statement.execute(String.format("LOAD MODEL %s TO DEVICES \"%s\"", modelId, devices)); + checkModelOnSpecifiedDevice(statement, MODEL_ID_TO_TYPE_MAP.get(modelId), devices); concurrentInference( statement, - String.format("CALL INFERENCE(%s, \"SELECT s FROM root.AI\")", modelId), + String.format( + "CALL INFERENCE(%s, \"SELECT s FROM root.AI\", predict_length=%d)", + modelId, predictLength), threadCnt, - loop); + loop, + predictLength); statement.execute(String.format("UNLOAD MODEL %s FROM DEVICES \"0,1\"", modelId)); } } @@ -134,15 +158,18 @@ public class AINodeConcurrentInferenceIT { Statement statement = connection.createStatement()) { final int threadCnt = 4; final int loop = 10; + final int predictLength = 96; statement.execute(String.format("LOAD MODEL %s TO DEVICES \"cpu\"", modelId)); + checkModelOnSpecifiedDevice(statement, MODEL_ID_TO_TYPE_MAP.get(modelId), "cpu"); long startTime = System.currentTimeMillis(); concurrentInference( statement, String.format( - "SELECT * FROM FORECAST(model_id=>'%s', input=>(SELECT time,s FROM root.AI) ORDER BY time)", - modelId), + "SELECT * FROM FORECAST(model_id=>'%s', input=>(SELECT time,s FROM root.AI) ORDER BY time), predict_length=>%d", + modelId, predictLength), threadCnt, - loop); + loop, + predictLength); long endTime = System.currentTimeMillis(); LOGGER.info( String.format( @@ -163,15 +190,19 @@ public class AINodeConcurrentInferenceIT { Statement statement = connection.createStatement()) { final int threadCnt = 10; final int loop = 100; - statement.execute(String.format("LOAD MODEL %s TO DEVICES \"0,1\"", modelId)); + final int predictLength = 512; + final String devices = "0,1"; + statement.execute(String.format("LOAD MODEL %s TO DEVICES \"%s\"", modelId, devices)); + checkModelOnSpecifiedDevice(statement, MODEL_ID_TO_TYPE_MAP.get(modelId), devices); long startTime = System.currentTimeMillis(); concurrentInference( statement, String.format( - "SELECT * FROM FORECAST(model_id=>'%s', input=>(SELECT time,s FROM root.AI) ORDER BY time)", - modelId), + "SELECT * FROM FORECAST(model_id=>'%s', input=>(SELECT time,s FROM root.AI) ORDER BY time), predict_length=>%d", + modelId, predictLength), threadCnt, - loop); + loop, + predictLength); long endTime = System.currentTimeMillis(); LOGGER.info( String.format( @@ -180,4 +211,29 @@ public class AINodeConcurrentInferenceIT { statement.execute(String.format("UNLOAD MODEL %s FROM DEVICES \"0,1\"", modelId)); } } + + private void checkModelOnSpecifiedDevice(Statement statement, String modelType, String device) + throws SQLException, InterruptedException { + for (int retry = 0; retry < 10; retry++) { + Set<String> targetDevices = ImmutableSet.copyOf(device.split(",")); + Set<String> foundDevices = new HashSet<>(); + try (final ResultSet resultSet = + statement.executeQuery(String.format("SHOW LOADED MODELS %s", device))) { + while (resultSet.next()) { + String deviceId = resultSet.getString(1); + String loadedModelType = resultSet.getString(2); + int count = resultSet.getInt(3); + if (loadedModelType.equals(modelType) && targetDevices.contains(deviceId)) { + Assert.assertTrue(count > 1); + foundDevices.add(deviceId); + } + } + if (foundDevices.containsAll(targetDevices)) { + return; + } + } + TimeUnit.SECONDS.sleep(3); + } + Assert.fail("Model " + modelType + " is not loaded on device " + device); + } } diff --git a/integration-test/src/test/java/org/apache/iotdb/ainode/utils/AINodeTestUtils.java b/integration-test/src/test/java/org/apache/iotdb/ainode/utils/AINodeTestUtils.java index 31e498ae729..cbb0b03b229 100644 --- a/integration-test/src/test/java/org/apache/iotdb/ainode/utils/AINodeTestUtils.java +++ b/integration-test/src/test/java/org/apache/iotdb/ainode/utils/AINodeTestUtils.java @@ -60,7 +60,8 @@ public class AINodeTestUtils { } } - public static void concurrentInference(Statement statement, String sql, int threadCnt, int loop) + public static void concurrentInference( + Statement statement, String sql, int threadCnt, int loop, int expectedOutputLength) throws InterruptedException { Thread[] threads = new Thread[threadCnt]; for (int i = 0; i < threadCnt; i++) { @@ -70,9 +71,11 @@ public class AINodeTestUtils { try { for (int j = 0; j < loop; j++) { try (ResultSet resultSet = statement.executeQuery(sql)) { + int outputCnt = 0; while (resultSet.next()) { - // do nothing + outputCnt++; } + assertEquals(expectedOutputLength, outputCnt); } catch (SQLException e) { fail(e.getMessage()); }
