This is an automated email from the ASF dual-hosted git repository.

yongzao pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/iotdb.git


The following commit(s) were added to refs/heads/master by this push:
     new e8af596c640 [AINode] Upgrade torch version (#17323)
e8af596c640 is described below

commit e8af596c640f53e098d4421601a6e4fbb21d3122
Author: Yongzao <[email protected]>
AuthorDate: Fri Mar 20 16:38:38 2026 +0800

    [AINode] Upgrade torch version (#17323)
---
 .../iotdb/it/env/cluster/node/AINodeWrapper.java   | 50 ++++++++++++++++++++--
 .../ainode/core/device/backend/cuda_backend.py     | 20 +++++++++
 iotdb-core/ainode/iotdb/ainode/core/script.py      | 27 ++++++------
 iotdb-core/ainode/pyproject.toml                   |  2 +-
 4 files changed, 83 insertions(+), 16 deletions(-)

diff --git a/integration-test/src/main/java/org/apache/iotdb/it/env/cluster/node/AINodeWrapper.java b/integration-test/src/main/java/org/apache/iotdb/it/env/cluster/node/AINodeWrapper.java
index 15c2e4761dd..d452e19f381 100644
--- a/integration-test/src/main/java/org/apache/iotdb/it/env/cluster/node/AINodeWrapper.java
+++ b/integration-test/src/main/java/org/apache/iotdb/it/env/cluster/node/AINodeWrapper.java
@@ -38,6 +38,7 @@ import java.nio.file.StandardCopyOption;
 import java.util.ArrayList;
 import java.util.List;
 import java.util.Properties;
+import java.util.concurrent.TimeUnit;
 import java.util.stream.Stream;
 
 import static org.apache.iotdb.it.env.cluster.ClusterConstant.AI_NODE_NAME;
@@ -52,7 +53,8 @@ public class AINodeWrapper extends AbstractNodeWrapper {
   private final String seedConfigNode;
   private final int clusterIngressPort;
 
-  private static final String SCRIPT_FILE = "start-ainode.sh";
+  private static final String START_SCRIPT_FILE = "start-ainode.sh";
+  private static final String STOP_SCRIPT_FILE = "stop-ainode.sh";
 
   private static final String SHELL_COMMAND = "bash";
 
@@ -165,8 +167,8 @@ public class AINodeWrapper extends AbstractNodeWrapper {
       // start AINode
       List<String> startCommand = new ArrayList<>();
       startCommand.add(SHELL_COMMAND);
-      startCommand.add(filePrefix + File.separator + SCRIPT_PATH + 
File.separator + SCRIPT_FILE);
-      startCommand.add("-r");
+      startCommand.add(
+          filePrefix + File.separator + SCRIPT_PATH + File.separator + 
START_SCRIPT_FILE);
 
       ProcessBuilder processBuilder =
           new ProcessBuilder(startCommand)
@@ -179,6 +181,77 @@ public class AINodeWrapper extends AbstractNodeWrapper {
     }
   }
 
+  /** Seconds to wait for {@code stop-ainode.sh} before the node is killed forcibly. */
+  private static final long STOP_SCRIPT_TIMEOUT_SECONDS = 20;
+
+  /** Seconds to wait for the node process to exit after it has been killed forcibly. */
+  private static final long FORCE_KILL_TIMEOUT_SECONDS = 10;
+
+  /**
+   * Stops the AINode by running {@code stop-ainode.sh}. The script's output (stdout and stderr)
+   * is appended to this node's log file.
+   *
+   * <p>If the script does not finish within {@link #STOP_SCRIPT_TIMEOUT_SECONDS} seconds, the
+   * script and the node process ({@code instance}) are both destroyed forcibly. A non-zero exit
+   * code of the script is only logged. Does nothing if {@code instance} is null.
+   */
+  @Override
+  public void stop() {
+    if (this.instance == null) {
+      return;
+    }
+    try {
+      // stop AINode
+      File stdoutFile = new File(getLogPath());
+      String filePrefix = getNodePath();
+      List<String> stopCommand = new ArrayList<>();
+      stopCommand.add(SHELL_COMMAND);
+      stopCommand.add(
+          filePrefix + File.separator + SCRIPT_PATH + File.separator + STOP_SCRIPT_FILE);
+      // Do not call inheritIO() here: it resets stdout/stderr to INHERIT and would silently
+      // discard the redirection into the log file.
+      ProcessBuilder processBuilder =
+          new ProcessBuilder(stopCommand)
+              .redirectErrorStream(true)
+              .redirectOutput(ProcessBuilder.Redirect.appendTo(stdoutFile));
+      Process stopProcess = processBuilder.start();
+      if (stopProcess.waitFor(STOP_SCRIPT_TIMEOUT_SECONDS, TimeUnit.SECONDS)) {
+        // exitValue() may only be queried once the process has terminated.
+        int exitCode = stopProcess.exitValue();
+        if (exitCode != 0) {
+          logger.warn("Node {}'s stop script exited with code {}", getId(), exitCode);
+        }
+      } else {
+        logger.warn(
+            "Node {}'s stop script does not exit within {}s, killing it and the node",
+            getId(),
+            STOP_SCRIPT_TIMEOUT_SECONDS);
+        // Kill the hanging script too so that it does not outlive the test.
+        stopProcess.destroyForcibly();
+        boolean nodeExited =
+            this.instance.destroyForcibly().waitFor(FORCE_KILL_TIMEOUT_SECONDS, TimeUnit.SECONDS);
+        if (!nodeExited) {
+          logger.error("Cannot forcibly stop node {}", getId());
+        }
+      }
+    } catch (InterruptedException e) {
+      Thread.currentThread().interrupt();
+      logger.error("Waiting node to shutdown error.", e);
+    } catch (IOException e) {
+      logger.error("Waiting node to shutdown error.", e);
+    }
+    logger.info("In test {} {} stopped.", getTestLogDirName(), getId());
+  }
+
+  /**
+   * Delegates to {@link #stop()}, which already destroys the node forcibly when the stop script
+   * does not finish in time.
+   */
+  @Override
+  public void stopForcibly() {
+    this.stop();
+  }
+
   @Override
   public int getMetricPort() {
     // no metric currently
diff --git a/iotdb-core/ainode/iotdb/ainode/core/device/backend/cuda_backend.py b/iotdb-core/ainode/iotdb/ainode/core/device/backend/cuda_backend.py
index c7533cc4dd7..553101bb844 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/device/backend/cuda_backend.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/device/backend/cuda_backend.py
@@ -15,6 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 #
+import time
 
 import torch
 
@@ -24,6 +25,9 @@ from iotdb.ainode.core.device.backend.base import BackendAdapter, BackendType
 class CUDABackend(BackendAdapter):
     type = BackendType.CUDA
 
+    def __init__(self) -> None:  # NOTE(review): BackendAdapter.__init__ is not called; confirm none is needed
+        self._safe_cuda_init()  # probe CUDA availability (with retries) as soon as the backend is built
+
     def is_available(self) -> bool:
         return torch.cuda.is_available()
 
@@ -37,3 +41,19 @@ class CUDABackend(BackendAdapter):
 
     def set_device(self, index: int) -> None:
         torch.cuda.set_device(index)
+
+    def _safe_cuda_init(self) -> None:
+        # Safe CUDA initialization to avoid potential deadlocks
+        # This is a workaround for certain PyTorch versions where the first 
CUDA call can cause a long delay
+        # By calling a simple CUDA operation at startup, we can ensure that 
the CUDA context is initialized early
+        # and avoid unexpected delays during actual model loading or inference.
+        attempt_cnt = 3
+        for attempt in range(attempt_cnt):
+            try:
+                if self.is_available():
+                    return
+                raise RuntimeError("CUDA not available")
+            except Exception as e:
+                print(f"CUDA init attempt {attempt + 1} failed: {e}")
+                if attempt < attempt_cnt:
+                    time.sleep(1.5)
diff --git a/iotdb-core/ainode/iotdb/ainode/core/script.py b/iotdb-core/ainode/iotdb/ainode/core/script.py
index 38653d7ceab..86373a3e065 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/script.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/script.py
@@ -15,11 +15,26 @@
 # specific language governing permissions and limitations
 # under the License.
 #
+
 import multiprocessing
 import sys
 
+# PyInstaller multiprocessing support
+# freeze_support() is essential for PyInstaller frozen executables on all platforms
+# It detects if the current process is a multiprocessing child process
+# If it is, it executes the child process target function and exits
+# If it's not, it returns immediately and continues with main() execution
+# This prevents child processes from executing the main application logic
+if getattr(sys, "frozen", False):
+    # Run the stdlib freeze_support() before torch is imported; mp.freeze_support() runs below
+    multiprocessing.freeze_support()
+    multiprocessing.set_start_method("spawn", force=True)

 import torch.multiprocessing as mp
+
+mp.freeze_support()  # NOTE(review): module-level, so it runs on every import, frozen or not
+mp.set_start_method("spawn", force=True)  # force=True overrides any earlier choice, incl. the one above
+
 from iotdb.ainode.core.ai_node import AINode
 from iotdb.ainode.core.log import Logger
 
@@ -42,7 +57,6 @@ def main():
     command = arguments[1]
     if command == "start":
         try:
-            mp.set_start_method("spawn", force=True)
             logger.info(f"Current multiprocess start method: 
{mp.get_start_method()}")
             logger.info("IoTDB-AINode is starting...")
             ai_node = AINode()
@@ -55,15 +69,4 @@ def main():
 
 
 if __name__ == "__main__":
-    # PyInstaller multiprocessing support
-    # freeze_support() is essential for PyInstaller frozen executables on all 
platforms
-    # It detects if the current process is a multiprocessing child process
-    # If it is, it executes the child process target function and exits
-    # If it's not, it returns immediately and continues with main() execution
-    # This prevents child processes from executing the main application logic
-    if getattr(sys, "frozen", False):
-        # Call freeze_support() for both standard multiprocessing and 
torch.multiprocessing
-        multiprocessing.freeze_support()
-        mp.freeze_support()
-
     main()
diff --git a/iotdb-core/ainode/pyproject.toml b/iotdb-core/ainode/pyproject.toml
index d92a466daf7..9a142fe7259 100644
--- a/iotdb-core/ainode/pyproject.toml
+++ b/iotdb-core/ainode/pyproject.toml
@@ -79,7 +79,7 @@ exclude = [
 python = ">=3.11.0,<3.12.0"
 
 # ---- DL / HF stack ----
-torch = "^2.8.0,<2.9.0"
+torch = "^2.9.0,<2.10.0"
 torchmetrics = "^1.8.0"
 transformers = "==4.56.2"
 tokenizers = ">=0.22.0,<=0.23.0"

Reply via email to