This is an automated email from the ASF dual-hosted git repository.

JackieTien97 pushed a commit to branch rc/2.0.10 in repository https://gitbox.apache.org/repos/asf/iotdb.git
commit 8da40c6bc97da1ebe41a0d7a3ce3b9a90ede90c4 Author: Yongzao <[email protected]> AuthorDate: Wed Jun 3 16:32:20 2026 +0800 [AINode] Remove Chronos2 DataLoader pin_memory option (#17822) --- .../iotdb/ainode/it/AINodeSharedClusterIT.java | 83 ++++++++++++++++++++++ .../request_scheduler/basic_request_scheduler.py | 19 +++-- .../core/model/chronos2/pipeline_chronos2.py | 1 - 3 files changed, 91 insertions(+), 12 deletions(-) diff --git a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeSharedClusterIT.java b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeSharedClusterIT.java index 4ea2b4af41a..cbd0f62f16e 100644 --- a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeSharedClusterIT.java +++ b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeSharedClusterIT.java @@ -50,6 +50,7 @@ import java.util.List; import java.util.Set; import java.util.concurrent.TimeUnit; +import static org.apache.iotdb.ainode.utils.AINodeTestUtils.BUILTIN_LTSM_MAP; import static org.apache.iotdb.ainode.utils.AINodeTestUtils.BUILTIN_MODEL_MAP; import static org.apache.iotdb.ainode.utils.AINodeTestUtils.checkHeader; import static org.apache.iotdb.ainode.utils.AINodeTestUtils.checkModelNotOnSpecifiedDevice; @@ -90,6 +91,10 @@ public class AINodeSharedClusterIT { "CALL INFERENCE(%s, \"SELECT s%d FROM root.AI LIMIT 256\")"; private static final int DEFAULT_INPUT_LENGTH = 256; private static final int DEFAULT_OUTPUT_LENGTH = 48; + private static final int LOADED_MODEL_SMOKE_INPUT_LENGTH = 96; + private static final int LOADED_MODEL_SMOKE_OUTPUT_LENGTH = 1; + private static final List<String> LTSM_LOAD_DEVICE_COMBINATIONS = + Arrays.asList("cpu", "0", "cpu,0"); private static final String FORECAST_TABLE_FUNCTION_SQL_TEMPLATE = "SELECT * FROM FORECAST(" @@ -438,6 +443,84 @@ public class AINodeSharedClusterIT { // ========== Concurrent forecast tests ========== + @Test + public void largeTimeSeriesModelLoadInferenceAndForecastTest() + 
throws SQLException, InterruptedException { + try (Connection treeConnection = EnvFactory.getEnv().getConnection(BaseEnv.TREE_SQL_DIALECT); + Statement treeStatement = treeConnection.createStatement(); + Connection tableConnection = EnvFactory.getEnv().getConnection(BaseEnv.TABLE_SQL_DIALECT); + Statement tableStatement = tableConnection.createStatement()) { + for (FakeModelInfo modelInfo : BUILTIN_LTSM_MAP.values()) { + for (String devices : LTSM_LOAD_DEVICE_COMBINATIONS) { + loadRunAndUnloadModelOnDevices( + treeStatement, tableStatement, modelInfo.getModelId(), devices); + } + } + } + } + + private void loadRunAndUnloadModelOnDevices( + Statement treeStatement, Statement tableStatement, String modelId, String devices) + throws SQLException, InterruptedException { + boolean loadSubmitted = false; + try { + treeStatement.execute(String.format("LOAD MODEL %s TO DEVICES '%s'", modelId, devices)); + loadSubmitted = true; + checkModelOnSpecifiedDevice(treeStatement, modelId, devices); + assertLoadedModelCallInferenceSucceeds(treeStatement, modelId); + assertLoadedModelForecastSucceeds(tableStatement, modelId); + } finally { + if (loadSubmitted) { + treeStatement.execute(String.format("UNLOAD MODEL %s FROM DEVICES '%s'", modelId, devices)); + checkModelNotOnSpecifiedDevice(treeStatement, modelId, devices); + } + } + } + + private void assertLoadedModelCallInferenceSucceeds(Statement statement, String modelId) + throws SQLException { + String callInferenceSQL = + String.format( + CALL_INFERENCE_SQL_TEMPLATE, + modelId, + 0, + LOADED_MODEL_SMOKE_INPUT_LENGTH, + LOADED_MODEL_SMOKE_OUTPUT_LENGTH); + try (ResultSet resultSet = statement.executeQuery(callInferenceSQL)) { + ResultSetMetaData resultSetMetaData = resultSet.getMetaData(); + checkHeader(resultSetMetaData, "Time,output"); + Assert.assertEquals(Types.DOUBLE, resultSetMetaData.getColumnType(2)); + int count = 0; + while (resultSet.next()) { + resultSet.getDouble("output"); + count++; + } + 
Assert.assertEquals(LOADED_MODEL_SMOKE_OUTPUT_LENGTH, count); + } + } + + private void assertLoadedModelForecastSucceeds(Statement statement, String modelId) + throws SQLException { + String forecastTableFunctionSQL = + String.format( + FORECAST_TABLE_FUNCTION_SQL_TEMPLATE, + modelId, + 0, + 5760, + LOADED_MODEL_SMOKE_INPUT_LENGTH, + 5760, + LOADED_MODEL_SMOKE_OUTPUT_LENGTH, + 1, + "time"); + try (ResultSet resultSet = statement.executeQuery(forecastTableFunctionSQL)) { + int count = 0; + while (resultSet.next()) { + count++; + } + Assert.assertEquals(LOADED_MODEL_SMOKE_OUTPUT_LENGTH, count); + } + } + @Test public void concurrentForecastTest() throws SQLException, InterruptedException { for (FakeModelInfo modelInfo : CONCURRENT_FORECAST_MODELS) { diff --git a/iotdb-core/ainode/iotdb/ainode/core/inference/request_scheduler/basic_request_scheduler.py b/iotdb-core/ainode/iotdb/ainode/core/inference/request_scheduler/basic_request_scheduler.py index ef5d37a18a1..2ee15fabca4 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/inference/request_scheduler/basic_request_scheduler.py +++ b/iotdb-core/ainode/iotdb/ainode/core/inference/request_scheduler/basic_request_scheduler.py @@ -16,8 +16,6 @@ # under the License. # -import os - import psutil import torch @@ -53,23 +51,22 @@ class BasicRequestScheduler(AbstractRequestScheduler): def memory_is_available(self): if "cuda" in self.device.type: - used = torch.cuda.memory_allocated(self.device) - reserved = torch.cuda.memory_reserved(self.device) + available, total = torch.cuda.mem_get_info(self.device) elif "cpu" in self.device.type: - process = psutil.Process(os.getpid()) - used = process.memory_info().rss - reserved = used + memory = psutil.virtual_memory() + available = memory.available + total = memory.total else: - used = 0 - reserved = 0 logger.warning( f"[Inference] Unsupported device type: {self.device.type}. Memory checks will not be performed." 
) + return True logger.debug( f"[Inference][Device-{self.device}][Pool-{self.pool_id}] " - f"Memory used: {used/1024**2:.2f} MB, Max memory: {self.max_memory_bytes/1024**2:.2f} MB" + f"Memory available: {available/1024**2:.2f} MB, Total memory: {total/1024**2:.2f} MB, " + f"Required free memory: {self.max_memory_bytes/1024**2:.2f} MB" ) - return used < self.max_memory_bytes + return available > self.max_memory_bytes def schedule_activate(self) -> list: requests = [] diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/chronos2/pipeline_chronos2.py b/iotdb-core/ainode/iotdb/ainode/core/model/chronos2/pipeline_chronos2.py index 01ff78ba48d..083125e4b99 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/model/chronos2/pipeline_chronos2.py +++ b/iotdb-core/ainode/iotdb/ainode/core/model/chronos2/pipeline_chronos2.py @@ -309,7 +309,6 @@ class Chronos2Pipeline(ForecastPipeline): test_loader = DataLoader( test_dataset, batch_size=None, - pin_memory=True, shuffle=False, drop_last=False, )
