This is an automated email from the ASF dual-hosted git repository. yongzao pushed a commit to branch rc/2.0.6 in repository https://gitbox.apache.org/repos/asf/iotdb.git
commit 4e42ba6b52567a77e0ffd49dfc324837c2426493 Author: jtmer <[email protected]> AuthorDate: Tue Sep 16 09:17:05 2025 +0800 [AINode] Add a batcher for inference (#16411) (cherry picked from commit 773433176dc4f1fad831be76f3274edf66b0ea95) --- .../ainode/it/AINodeConcurrentInferenceIT.java | 70 +++++++++++++++++----- .../iotdb/ainode/core/inference/pool_group.py | 4 +- 2 files changed, 57 insertions(+), 17 deletions(-) diff --git a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeConcurrentInferenceIT.java b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeConcurrentInferenceIT.java index b8ebfa9600b..c73fbeb2fbf 100644 --- a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeConcurrentInferenceIT.java +++ b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeConcurrentInferenceIT.java @@ -81,61 +81,103 @@ public class AINodeConcurrentInferenceIT { @Test public void concurrentCPUCallInferenceTest() throws SQLException, InterruptedException { + concurrentCPUCallInferenceTest("timer_xl"); + concurrentCPUCallInferenceTest("sundial"); + } + + private void concurrentCPUCallInferenceTest(String modelId) + throws SQLException, InterruptedException { try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TREE_SQL_DIALECT); Statement statement = connection.createStatement()) { - statement.execute("LOAD MODEL sundial TO DEVICES \"cpu\""); - concurrentInference(statement, "CALL INFERENCE(sundial, \"SELECT s FROM root.AI\")", 4, 10); + final int threadCnt = 4; + final int loop = 10; + statement.execute(String.format("LOAD MODEL %s TO DEVICES \"cpu\"", modelId)); + concurrentInference( + statement, + String.format("CALL INFERENCE(%s, \"SELECT s FROM root.AI\")", modelId), + threadCnt, + loop); + statement.execute(String.format("UNLOAD MODEL %s FROM DEVICES \"cpu\"", modelId)); } } @Test public void concurrentGPUCallInferenceTest() throws SQLException, InterruptedException { + concurrentGPUCallInferenceTest("timer_xl"); + concurrentGPUCallInferenceTest("sundial"); + } + + private void concurrentGPUCallInferenceTest(String modelId) + throws SQLException, InterruptedException { try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TREE_SQL_DIALECT); Statement statement = connection.createStatement()) { - statement.execute("LOAD MODEL sundial TO DEVICES \"0,1\""); - concurrentInference(statement, "CALL INFERENCE(sundial, \"SELECT s FROM root.AI\")", 10, 100); + final int threadCnt = 4; + final int loop = 10; + statement.execute(String.format("LOAD MODEL %s TO DEVICES \"0,1\"", modelId)); + concurrentInference( + statement, + String.format("CALL INFERENCE(%s, \"SELECT s FROM root.AI\")", modelId), + threadCnt, + loop); + statement.execute(String.format("UNLOAD MODEL %s FROM DEVICES \"0,1\"", modelId)); } } @Test public void concurrentCPUForecastTest() throws SQLException, InterruptedException { - try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TREE_SQL_DIALECT); + concurrentCPUForecastTest("timer_xl"); + concurrentCPUForecastTest("sundial"); + } + + private void concurrentCPUForecastTest(String modelId) throws SQLException, InterruptedException { + try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TABLE_SQL_DIALECT); Statement statement = connection.createStatement()) { final int threadCnt = 4; final int loop = 10; - statement.execute("LOAD MODEL sundial TO DEVICES \"cpu\""); + statement.execute(String.format("LOAD MODEL %s TO DEVICES \"cpu\"", modelId)); long startTime = System.currentTimeMillis(); concurrentInference( statement, - "SELECT * FROM FORECAST(model_id=>'sundial', input=>(SELECT time,s FROM root.AI) ORDER BY time)", + String.format( + "SELECT * FROM FORECAST(model_id=>'%s', input=>(SELECT time,s FROM root.AI) ORDER BY time)", + modelId), threadCnt, loop); long endTime = System.currentTimeMillis(); LOGGER.info( String.format( - "Timer-Sundial concurrent inference %d reqs (%d threads, %d loops) in CPU takes time: %dms", - threadCnt * loop, threadCnt, loop, endTime - startTime)); + "Model %s concurrent inference %d reqs (%d threads, %d loops) in CPU takes time: %dms", + modelId, threadCnt * loop, threadCnt, loop, endTime - startTime)); + statement.execute(String.format("UNLOAD MODEL %s FROM DEVICES \"cpu\"", modelId)); } } @Test public void concurrentGPUForecastTest() throws SQLException, InterruptedException { - try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TREE_SQL_DIALECT); + concurrentGPUForecastTest("timer_xl"); + concurrentGPUForecastTest("sundial"); + } + + public void concurrentGPUForecastTest(String modelId) throws SQLException, InterruptedException { + try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TABLE_SQL_DIALECT); Statement statement = connection.createStatement()) { final int threadCnt = 10; final int loop = 100; - statement.execute("LOAD MODEL sundial TO DEVICES \"0,1\""); + statement.execute(String.format("LOAD MODEL %s TO DEVICES \"0,1\"", modelId)); long startTime = System.currentTimeMillis(); concurrentInference( statement, - "SELECT * FROM FORECAST(model_id=>'sundial', input=>(SELECT time,s FROM root.AI) ORDER BY time)", + String.format( + "SELECT * FROM FORECAST(model_id=>'%s', input=>(SELECT time,s FROM root.AI) ORDER BY time)", + modelId), threadCnt, loop); long endTime = System.currentTimeMillis(); LOGGER.info( String.format( - "Timer-Sundial concurrent inference %d reqs (%d threads, %d loops) in GPU takes time: %dms", - threadCnt * loop, threadCnt, loop, endTime - startTime)); + "Model %s concurrent inference %d reqs (%d threads, %d loops) in GPU takes time: %dms", + modelId, threadCnt * loop, threadCnt, loop, endTime - startTime)); + statement.execute(String.format("UNLOAD MODEL %s FROM DEVICES \"0,1\"", modelId)); } } } diff --git a/iotdb-core/ainode/iotdb/ainode/core/inference/pool_group.py b/iotdb-core/ainode/iotdb/ainode/core/inference/pool_group.py index a7004549db0..96dce845585 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/inference/pool_group.py +++ b/iotdb-core/ainode/iotdb/ainode/core/inference/pool_group.py @@ -19,9 +19,7 @@ from typing import Dict, Tuple import torch.multiprocessing as mp -from iotdb.ainode.core.exception import ( - InferenceModelInternalError, -) +from iotdb.ainode.core.exception import InferenceModelInternalError from iotdb.ainode.core.inference.dispatcher.basic_dispatcher import BasicDispatcher from iotdb.ainode.core.inference.inference_request import ( InferenceRequest,
