This is an automated email from the ASF dual-hosted git repository.
CRZbulabula pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/iotdb.git
The following commit(s) were added to refs/heads/master by this push:
new e2322e9e9e5 [AINode] Remove Chronos2 DataLoader pin_memory option
(#17822)
e2322e9e9e5 is described below
commit e2322e9e9e5d0ef5a4b982f27fa682c066e94cca
Author: Yongzao <[email protected]>
AuthorDate: Wed Jun 3 16:32:20 2026 +0800
[AINode] Remove Chronos2 DataLoader pin_memory option (#17822)
---
.../iotdb/ainode/it/AINodeSharedClusterIT.java | 83 ++++++++++++++++++++++
.../request_scheduler/basic_request_scheduler.py | 19 +++--
.../core/model/chronos2/pipeline_chronos2.py | 1 -
3 files changed, 91 insertions(+), 12 deletions(-)
diff --git
a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeSharedClusterIT.java
b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeSharedClusterIT.java
index 4ea2b4af41a..cbd0f62f16e 100644
---
a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeSharedClusterIT.java
+++
b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeSharedClusterIT.java
@@ -50,6 +50,7 @@ import java.util.List;
import java.util.Set;
import java.util.concurrent.TimeUnit;
+import static org.apache.iotdb.ainode.utils.AINodeTestUtils.BUILTIN_LTSM_MAP;
import static org.apache.iotdb.ainode.utils.AINodeTestUtils.BUILTIN_MODEL_MAP;
import static org.apache.iotdb.ainode.utils.AINodeTestUtils.checkHeader;
import static
org.apache.iotdb.ainode.utils.AINodeTestUtils.checkModelNotOnSpecifiedDevice;
@@ -90,6 +91,10 @@ public class AINodeSharedClusterIT {
"CALL INFERENCE(%s, \"SELECT s%d FROM root.AI LIMIT 256\")";
private static final int DEFAULT_INPUT_LENGTH = 256;
private static final int DEFAULT_OUTPUT_LENGTH = 48;
+  // Input/output lengths passed to the CALL INFERENCE and FORECAST smoke queries issued
+  // against each freshly loaded large time-series model.
+ private static final int LOADED_MODEL_SMOKE_INPUT_LENGTH = 96;
+ private static final int LOADED_MODEL_SMOKE_OUTPUT_LENGTH = 1;
+  // Device specs passed verbatim to LOAD/UNLOAD MODEL ... DEVICES '<spec>'.
+  // NOTE(review): "0" presumably denotes accelerator index 0 and "cpu,0" both — confirm.
+ private static final List<String> LTSM_LOAD_DEVICE_COMBINATIONS =
+ Arrays.asList("cpu", "0", "cpu,0");
private static final String FORECAST_TABLE_FUNCTION_SQL_TEMPLATE =
"SELECT * FROM FORECAST("
@@ -438,6 +443,84 @@ public class AINodeSharedClusterIT {
// ========== Concurrent forecast tests ==========
+  /**
+   * For every model in BUILTIN_LTSM_MAP and every entry of LTSM_LOAD_DEVICE_COMBINATIONS, loads
+   * the model, runs the smoke queries and unloads it again via loadRunAndUnloadModelOnDevices.
+   */
+  @Test
+  public void largeTimeSeriesModelLoadInferenceAndForecastTest()
+      throws SQLException, InterruptedException {
+    try (Connection treeConn = EnvFactory.getEnv().getConnection(BaseEnv.TREE_SQL_DIALECT);
+        Statement treeStmt = treeConn.createStatement();
+        Connection tableConn = EnvFactory.getEnv().getConnection(BaseEnv.TABLE_SQL_DIALECT);
+        Statement tableStmt = tableConn.createStatement()) {
+      for (FakeModelInfo ltsm : BUILTIN_LTSM_MAP.values()) {
+        final String modelId = ltsm.getModelId();
+        for (String deviceSpec : LTSM_LOAD_DEVICE_COMBINATIONS) {
+          loadRunAndUnloadModelOnDevices(treeStmt, tableStmt, modelId, deviceSpec);
+        }
+      }
+    }
+  }
+
+  /**
+   * Loads {@code modelId} onto {@code devices}, checks placement, runs one CALL INFERENCE and one
+   * FORECAST smoke query, and unloads the model again whenever the LOAD statement was accepted.
+   *
+   * <p>If a smoke check fails and the cleanup (UNLOAD or its placement check) also fails, the
+   * smoke-check failure is rethrown with the cleanup failure attached as a suppressed exception,
+   * so the root cause is not masked by the exception thrown during cleanup.
+   *
+   * @param treeStatement statement on a tree-dialect connection (LOAD/UNLOAD, CALL INFERENCE)
+   * @param tableStatement statement on a table-dialect connection (FORECAST)
+   * @param modelId id of the model to load
+   * @param devices device spec placed inside {@code DEVICES '...'}
+   */
+  private void loadRunAndUnloadModelOnDevices(
+      Statement treeStatement, Statement tableStatement, String modelId, String devices)
+      throws SQLException, InterruptedException {
+    boolean loadSubmitted = false;
+    Throwable primaryFailure = null;
+    try {
+      treeStatement.execute(String.format("LOAD MODEL %s TO DEVICES '%s'", modelId, devices));
+      loadSubmitted = true;
+      checkModelOnSpecifiedDevice(treeStatement, modelId, devices);
+      assertLoadedModelCallInferenceSucceeds(treeStatement, modelId);
+      assertLoadedModelForecastSucceeds(tableStatement, modelId);
+    } catch (Throwable t) {
+      // Remember the real failure; `t` stays effectively final, so this is a precise rethrow.
+      primaryFailure = t;
+      throw t;
+    } finally {
+      if (loadSubmitted) {
+        try {
+          treeStatement.execute(
+              String.format("UNLOAD MODEL %s FROM DEVICES '%s'", modelId, devices));
+          checkModelNotOnSpecifiedDevice(treeStatement, modelId, devices);
+        } catch (Throwable cleanupFailure) {
+          if (primaryFailure == null) {
+            // The smoke checks passed, so the cleanup failure is the test failure.
+            throw cleanupFailure;
+          }
+          // Do not let the cleanup exception replace the original one.
+          primaryFailure.addSuppressed(cleanupFailure);
+        }
+      }
+    }
+  }
+
+ private void assertLoadedModelCallInferenceSucceeds(Statement statement,
String modelId)
+ throws SQLException {
+ String callInferenceSQL =
+ String.format(
+ CALL_INFERENCE_SQL_TEMPLATE,
+ modelId,
+ 0,
+ LOADED_MODEL_SMOKE_INPUT_LENGTH,
+ LOADED_MODEL_SMOKE_OUTPUT_LENGTH);
+ try (ResultSet resultSet = statement.executeQuery(callInferenceSQL)) {
+ ResultSetMetaData resultSetMetaData = resultSet.getMetaData();
+ checkHeader(resultSetMetaData, "Time,output");
+ Assert.assertEquals(Types.DOUBLE, resultSetMetaData.getColumnType(2));
+ int count = 0;
+ while (resultSet.next()) {
+ resultSet.getDouble("output");
+ count++;
+ }
+ Assert.assertEquals(LOADED_MODEL_SMOKE_OUTPUT_LENGTH, count);
+ }
+ }
+
+ private void assertLoadedModelForecastSucceeds(Statement statement, String
modelId)
+ throws SQLException {
+ String forecastTableFunctionSQL =
+ String.format(
+ FORECAST_TABLE_FUNCTION_SQL_TEMPLATE,
+ modelId,
+ 0,
+ 5760,
+ LOADED_MODEL_SMOKE_INPUT_LENGTH,
+ 5760,
+ LOADED_MODEL_SMOKE_OUTPUT_LENGTH,
+ 1,
+ "time");
+ try (ResultSet resultSet =
statement.executeQuery(forecastTableFunctionSQL)) {
+ int count = 0;
+ while (resultSet.next()) {
+ count++;
+ }
+ Assert.assertEquals(LOADED_MODEL_SMOKE_OUTPUT_LENGTH, count);
+ }
+ }
+
@Test
public void concurrentForecastTest() throws SQLException,
InterruptedException {
for (FakeModelInfo modelInfo : CONCURRENT_FORECAST_MODELS) {
diff --git
a/iotdb-core/ainode/iotdb/ainode/core/inference/request_scheduler/basic_request_scheduler.py
b/iotdb-core/ainode/iotdb/ainode/core/inference/request_scheduler/basic_request_scheduler.py
index ef5d37a18a1..2ee15fabca4 100644
---
a/iotdb-core/ainode/iotdb/ainode/core/inference/request_scheduler/basic_request_scheduler.py
+++
b/iotdb-core/ainode/iotdb/ainode/core/inference/request_scheduler/basic_request_scheduler.py
@@ -16,8 +16,6 @@
# under the License.
#
-import os
-
import psutil
import torch
@@ -53,23 +51,22 @@ class BasicRequestScheduler(AbstractRequestScheduler):
def memory_is_available(self):
+        """Return whether this pool's device currently has enough free memory.
+
+        Despite its name, ``self.max_memory_bytes`` is used as a required
+        free-memory floor: the result is ``available > self.max_memory_bytes``.
+        Unsupported device types skip the check and return True.
+        """
if "cuda" in self.device.type:
- used = torch.cuda.memory_allocated(self.device)
- reserved = torch.cuda.memory_reserved(self.device)
+            # Device-wide (free, total) bytes as reported by the CUDA driver.
+            # NOTE(review): blocks cached by this process's PyTorch allocator
+            # (reserved but not allocated) count as used here, so `available`
+            # may understate what this process can still allocate — confirm.
+ available, total = torch.cuda.mem_get_info(self.device)
elif "cpu" in self.device.type:
- process = psutil.Process(os.getpid())
- used = process.memory_info().rss
- reserved = used
+            # System-wide memory figures, not this process's RSS.
+ memory = psutil.virtual_memory()
+ available = memory.available
+ total = memory.total
else:
- used = 0
- reserved = 0
logger.warning(
f"[Inference] Unsupported device type: {self.device.type}.
Memory checks will not be performed."
)
+            # No way to measure memory for this device type; do not block scheduling.
+ return True
logger.debug(
f"[Inference][Device-{self.device}][Pool-{self.pool_id}] "
- f"Memory used: {used/1024**2:.2f} MB, Max memory:
{self.max_memory_bytes/1024**2:.2f} MB"
+ f"Memory available: {available/1024**2:.2f} MB, Total memory:
{total/1024**2:.2f} MB, "
+ f"Required free memory: {self.max_memory_bytes/1024**2:.2f} MB"
)
- return used < self.max_memory_bytes
+        # Strict comparison: exactly max_memory_bytes free counts as unavailable.
+ return available > self.max_memory_bytes
def schedule_activate(self) -> list:
requests = []
diff --git
a/iotdb-core/ainode/iotdb/ainode/core/model/chronos2/pipeline_chronos2.py
b/iotdb-core/ainode/iotdb/ainode/core/model/chronos2/pipeline_chronos2.py
index 01ff78ba48d..083125e4b99 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/model/chronos2/pipeline_chronos2.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/model/chronos2/pipeline_chronos2.py
@@ -309,7 +309,6 @@ class Chronos2Pipeline(ForecastPipeline):
test_loader = DataLoader(
test_dataset,
batch_size=None,
- pin_memory=True,
shuffle=False,
drop_last=False,
)