This is an automated email from the ASF dual-hosted git repository. jackietien pushed a commit to branch force_ci/object_type in repository https://gitbox.apache.org/repos/asf/iotdb.git
commit 191edf13094927c16f0f9216864e5f75d64781de Author: Yongzao <[email protected]> AuthorDate: Fri Dec 12 16:43:52 2025 +0800 [AINode] More strict concurrent inference IT (#16898) (cherry picked from commit 7c7b2a6c091f6a2f9fabb6b4fbfe2e737202eaec) --- .../iotdb/ainode/it/AINodeConcurrentForecastIT.java | 2 +- .../org/apache/iotdb/ainode/utils/AINodeTestUtils.java | 17 ++++++++++++++++- .../ainode/core/inference/inference_request_pool.py | 4 +++- 3 files changed, 20 insertions(+), 3 deletions(-) diff --git a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeConcurrentForecastIT.java b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeConcurrentForecastIT.java index 64029c1e34b..844ec1d8223 100644 --- a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeConcurrentForecastIT.java +++ b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeConcurrentForecastIT.java @@ -49,7 +49,7 @@ public class AINodeConcurrentForecastIT { private static final Logger LOGGER = LoggerFactory.getLogger(AINodeConcurrentForecastIT.class); private static final String FORECAST_TABLE_FUNCTION_SQL_TEMPLATE = - "SELECT * FROM FORECAST(model_id=>'%s', input=>(SELECT time,s FROM root.AI) ORDER BY time, forecast_length=>%d)"; + "SELECT * FROM FORECAST(model_id=>'%s', input=>(SELECT time,s FROM root.AI) ORDER BY time, output_length=>%d)"; @BeforeClass public static void setUp() throws Exception { diff --git a/integration-test/src/test/java/org/apache/iotdb/ainode/utils/AINodeTestUtils.java b/integration-test/src/test/java/org/apache/iotdb/ainode/utils/AINodeTestUtils.java index 0de90c42925..1d21a4d90f0 100644 --- a/integration-test/src/test/java/org/apache/iotdb/ainode/utils/AINodeTestUtils.java +++ b/integration-test/src/test/java/org/apache/iotdb/ainode/utils/AINodeTestUtils.java @@ -20,6 +20,7 @@ package org.apache.iotdb.ainode.utils; import com.google.common.collect.ImmutableSet; +import org.junit.Assert; import org.slf4j.Logger; import 
org.slf4j.LoggerFactory; @@ -34,6 +35,7 @@ import java.util.Map; import java.util.Objects; import java.util.Set; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; import java.util.stream.Collectors; import java.util.stream.Stream; @@ -101,6 +103,7 @@ public class AINodeTestUtils { public static void concurrentInference( Statement statement, String sql, int threadCnt, int loop, int expectedOutputLength) throws InterruptedException { + AtomicBoolean allPass = new AtomicBoolean(true); Thread[] threads = new Thread[threadCnt]; for (int i = 0; i < threadCnt; i++) { threads[i] = @@ -113,12 +116,23 @@ public class AINodeTestUtils { while (resultSet.next()) { outputCnt++; } - assertEquals(expectedOutputLength, outputCnt); + if (expectedOutputLength != outputCnt) { + allPass.set(false); + fail( + "Output count mismatch for SQL: " + + sql + + ". Expected: " + + expectedOutputLength + + ", but got: " + + outputCnt); + } } catch (SQLException e) { + allPass.set(false); fail(e.getMessage()); } } } catch (Exception e) { + allPass.set(false); fail(e.getMessage()); } }); @@ -130,6 +144,7 @@ public class AINodeTestUtils { fail("Thread timeout after 10 minutes"); } } + Assert.assertTrue(allPass.get()); } public static void checkModelOnSpecifiedDevice(Statement statement, String modelId, String device) diff --git a/iotdb-core/ainode/iotdb/ainode/core/inference/inference_request_pool.py b/iotdb-core/ainode/iotdb/ainode/core/inference/inference_request_pool.py index a6c415a6c84..c31bcd3d762 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/inference/inference_request_pool.py +++ b/iotdb-core/ainode/iotdb/ainode/core/inference/inference_request_pool.py @@ -120,7 +120,9 @@ class InferenceRequestPool(mp.Process): grouped_requests = list(grouped_requests.values()) for requests in grouped_requests: - batch_inputs = self._batcher.batch_request(requests).to(self.device) + batch_inputs = self._batcher.batch_request(requests).to( + "cpu" + ) # The input 
data should first be loaded to the CPU in the current version if isinstance(self._inference_pipeline, ForecastPipeline): batch_output = self._inference_pipeline.forecast( batch_inputs,
