This is an automated email from the ASF dual-hosted git repository.
yongzao pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/iotdb.git
The following commit(s) were added to refs/heads/master by this push:
new 7c7b2a6c091 [AINode] More strict concurrent inference IT (#16898)
7c7b2a6c091 is described below
commit 7c7b2a6c091f6a2f9fabb6b4fbfe2e737202eaec
Author: Yongzao <[email protected]>
AuthorDate: Fri Dec 12 16:43:52 2025 +0800
[AINode] More strict concurrent inference IT (#16898)
---
.../iotdb/ainode/it/AINodeConcurrentForecastIT.java | 2 +-
.../org/apache/iotdb/ainode/utils/AINodeTestUtils.java | 17 ++++++++++++++++-
.../ainode/core/inference/inference_request_pool.py | 4 +++-
3 files changed, 20 insertions(+), 3 deletions(-)
diff --git a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeConcurrentForecastIT.java b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeConcurrentForecastIT.java
index 64029c1e34b..844ec1d8223 100644
--- a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeConcurrentForecastIT.java
+++ b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeConcurrentForecastIT.java
@@ -49,7 +49,7 @@ public class AINodeConcurrentForecastIT {
private static final Logger LOGGER =
LoggerFactory.getLogger(AINodeConcurrentForecastIT.class);
private static final String FORECAST_TABLE_FUNCTION_SQL_TEMPLATE =
- "SELECT * FROM FORECAST(model_id=>'%s', input=>(SELECT time,s FROM
root.AI) ORDER BY time, forecast_length=>%d)";
+ "SELECT * FROM FORECAST(model_id=>'%s', input=>(SELECT time,s FROM
root.AI) ORDER BY time, output_length=>%d)";
@BeforeClass
public static void setUp() throws Exception {
diff --git a/integration-test/src/test/java/org/apache/iotdb/ainode/utils/AINodeTestUtils.java b/integration-test/src/test/java/org/apache/iotdb/ainode/utils/AINodeTestUtils.java
index 0de90c42925..1d21a4d90f0 100644
--- a/integration-test/src/test/java/org/apache/iotdb/ainode/utils/AINodeTestUtils.java
+++ b/integration-test/src/test/java/org/apache/iotdb/ainode/utils/AINodeTestUtils.java
@@ -20,6 +20,7 @@
package org.apache.iotdb.ainode.utils;
import com.google.common.collect.ImmutableSet;
+import org.junit.Assert;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -34,6 +35,7 @@ import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicBoolean;
import java.util.stream.Collectors;
import java.util.stream.Stream;
@@ -101,6 +103,7 @@ public class AINodeTestUtils {
public static void concurrentInference(
Statement statement, String sql, int threadCnt, int loop, int expectedOutputLength)
throws InterruptedException {
+ AtomicBoolean allPass = new AtomicBoolean(true);
Thread[] threads = new Thread[threadCnt];
for (int i = 0; i < threadCnt; i++) {
threads[i] =
@@ -113,12 +116,23 @@ public class AINodeTestUtils {
while (resultSet.next()) {
outputCnt++;
}
- assertEquals(expectedOutputLength, outputCnt);
+ if (expectedOutputLength != outputCnt) {
+ allPass.set(false);
+ fail(
+ "Output count mismatch for SQL: "
+ + sql
+ + ". Expected: "
+ + expectedOutputLength
+ + ", but got: "
+ + outputCnt);
+ }
} catch (SQLException e) {
+ allPass.set(false);
fail(e.getMessage());
}
}
} catch (Exception e) {
+ allPass.set(false);
fail(e.getMessage());
}
});
@@ -130,6 +144,7 @@ public class AINodeTestUtils {
fail("Thread timeout after 10 minutes");
}
}
+ Assert.assertTrue(allPass.get());
}
public static void checkModelOnSpecifiedDevice(Statement statement, String modelId, String device)
diff --git a/iotdb-core/ainode/iotdb/ainode/core/inference/inference_request_pool.py b/iotdb-core/ainode/iotdb/ainode/core/inference/inference_request_pool.py
index a6c415a6c84..c31bcd3d762 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/inference/inference_request_pool.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/inference/inference_request_pool.py
@@ -120,7 +120,9 @@ class InferenceRequestPool(mp.Process):
grouped_requests = list(grouped_requests.values())
for requests in grouped_requests:
- batch_inputs = self._batcher.batch_request(requests).to(self.device)
+ batch_inputs = self._batcher.batch_request(requests).to(
+ "cpu"
+ ) # The input data should first load to CPU in current version
if isinstance(self._inference_pipeline, ForecastPipeline):
batch_output = self._inference_pipeline.forecast(
batch_inputs,