This is an automated email from the ASF dual-hosted git repository.

yongzao pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/iotdb.git


The following commit(s) were added to refs/heads/master by this push:
     new 7c7b2a6c091 [AINode] More strict concurrent inference IT (#16898)
7c7b2a6c091 is described below

commit 7c7b2a6c091f6a2f9fabb6b4fbfe2e737202eaec
Author: Yongzao <[email protected]>
AuthorDate: Fri Dec 12 16:43:52 2025 +0800

    [AINode] More strict concurrent inference IT (#16898)
---
 .../iotdb/ainode/it/AINodeConcurrentForecastIT.java     |  2 +-
 .../org/apache/iotdb/ainode/utils/AINodeTestUtils.java  | 17 ++++++++++++++++-
 .../ainode/core/inference/inference_request_pool.py     |  4 +++-
 3 files changed, 20 insertions(+), 3 deletions(-)

diff --git a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeConcurrentForecastIT.java b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeConcurrentForecastIT.java
index 64029c1e34b..844ec1d8223 100644
--- a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeConcurrentForecastIT.java
+++ b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeConcurrentForecastIT.java
@@ -49,7 +49,7 @@ public class AINodeConcurrentForecastIT {
   private static final Logger LOGGER = LoggerFactory.getLogger(AINodeConcurrentForecastIT.class);
 
   private static final String FORECAST_TABLE_FUNCTION_SQL_TEMPLATE =
-      "SELECT * FROM FORECAST(model_id=>'%s', input=>(SELECT time,s FROM root.AI) ORDER BY time, forecast_length=>%d)";
+      "SELECT * FROM FORECAST(model_id=>'%s', input=>(SELECT time,s FROM root.AI) ORDER BY time, output_length=>%d)";
 
   @BeforeClass
   public static void setUp() throws Exception {
diff --git a/integration-test/src/test/java/org/apache/iotdb/ainode/utils/AINodeTestUtils.java b/integration-test/src/test/java/org/apache/iotdb/ainode/utils/AINodeTestUtils.java
index 0de90c42925..1d21a4d90f0 100644
--- a/integration-test/src/test/java/org/apache/iotdb/ainode/utils/AINodeTestUtils.java
+++ b/integration-test/src/test/java/org/apache/iotdb/ainode/utils/AINodeTestUtils.java
@@ -20,6 +20,7 @@
 package org.apache.iotdb.ainode.utils;
 
 import com.google.common.collect.ImmutableSet;
+import org.junit.Assert;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -34,6 +35,7 @@ import java.util.Map;
 import java.util.Objects;
 import java.util.Set;
 import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicBoolean;
 import java.util.stream.Collectors;
 import java.util.stream.Stream;
 
@@ -101,6 +103,7 @@ public class AINodeTestUtils {
   public static void concurrentInference(
       Statement statement, String sql, int threadCnt, int loop, int expectedOutputLength)
       throws InterruptedException {
+    AtomicBoolean allPass = new AtomicBoolean(true);
     Thread[] threads = new Thread[threadCnt];
     for (int i = 0; i < threadCnt; i++) {
       threads[i] =
@@ -113,12 +116,23 @@ public class AINodeTestUtils {
                       while (resultSet.next()) {
                         outputCnt++;
                       }
-                      assertEquals(expectedOutputLength, outputCnt);
+                      if (expectedOutputLength != outputCnt) {
+                        allPass.set(false);
+                        fail(
+                            "Output count mismatch for SQL: "
+                                + sql
+                                + ". Expected: "
+                                + expectedOutputLength
+                                + ", but got: "
+                                + outputCnt);
+                      }
                     } catch (SQLException e) {
+                      allPass.set(false);
                       fail(e.getMessage());
                     }
                   }
                 } catch (Exception e) {
+                  allPass.set(false);
                   fail(e.getMessage());
                 }
               });
@@ -130,6 +144,7 @@ public class AINodeTestUtils {
         fail("Thread timeout after 10 minutes");
       }
     }
+    Assert.assertTrue(allPass.get());
   }
 
   public static void checkModelOnSpecifiedDevice(Statement statement, String modelId, String device)
diff --git a/iotdb-core/ainode/iotdb/ainode/core/inference/inference_request_pool.py b/iotdb-core/ainode/iotdb/ainode/core/inference/inference_request_pool.py
index a6c415a6c84..c31bcd3d762 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/inference/inference_request_pool.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/inference/inference_request_pool.py
@@ -120,7 +120,9 @@ class InferenceRequestPool(mp.Process):
         grouped_requests = list(grouped_requests.values())
 
         for requests in grouped_requests:
-            batch_inputs = self._batcher.batch_request(requests).to(self.device)
+            batch_inputs = self._batcher.batch_request(requests).to(
+                "cpu"
+            )  # The input data should first load to CPU in current version
             if isinstance(self._inference_pipeline, ForecastPipeline):
                 batch_output = self._inference_pipeline.forecast(
                     batch_inputs,

Reply via email to