This is an automated email from the ASF dual-hosted git repository.

weichenxu123 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new b9ac53d1fc84 [SPARK-41916][ML] Torch distributor: support multiple torchrun processes per task if task.gpu.amount > 1
b9ac53d1fc84 is described below

commit b9ac53d1fc84b74cbd9c7ce38217b396484ae62f
Author: Weichen Xu <[email protected]>
AuthorDate: Wed Dec 17 19:21:01 2025 +0800

    [SPARK-41916][ML] Torch distributor: support multiple torchrun processes per task if task.gpu.amount > 1
    
    ### What changes were proposed in this pull request?
    
    Torch distributor: support multiple torchrun processes per task if task.gpu.amount > 1
    
    ### Why are the changes needed?
    
    Torch distributor: support multiple torchrun processes per task if task.gpu.amount > 1.
    This is a common case that we should support.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    Manually
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #53501 from WeichenXu123/SPARK-41916.
    
    Authored-by: Weichen Xu <[email protected]>
    Signed-off-by: Weichen Xu <[email protected]>
---
 python/pyspark/ml/torch/distributor.py            | 20 ++++++++++++++++----
 python/pyspark/ml/torch/tests/test_distributor.py | 14 ++++++++++++++
 2 files changed, 30 insertions(+), 4 deletions(-)

diff --git a/python/pyspark/ml/torch/distributor.py b/python/pyspark/ml/torch/distributor.py
index ef86f38b716b..3689d079f12d 100644
--- a/python/pyspark/ml/torch/distributor.py
+++ b/python/pyspark/ml/torch/distributor.py
@@ -212,8 +212,15 @@ class Distributor:
                 task_gpu_amount = int(_get_conf(self.spark, key, "0"))
                 if task_gpu_amount < 1:
                     raise RuntimeError(f"'{key}' was unset, so gpu usage is unavailable.")
-                # TODO(SPARK-41916): Address situation when spark.task.resource.gpu.amount > 1
-                return math.ceil(self.num_processes / task_gpu_amount)
+
+                if task_gpu_amount > 1:
+                    if not (self.num_processes % task_gpu_amount == 0):
+                        raise RuntimeError(
+                            f"TorchDistributor 'num_processes' value ({self.num_processes}) "
+                            "must be a multiple of 'spark.task.resource.gpu.amount' "
+                            f"({task_gpu_amount}) value."
+                        )
+                return self.num_processes // task_gpu_amount
             else:
                 key = "spark.driver.resource.gpu.amount"
                 if "gpu" not in _get_resources(self.spark):
@@ -421,14 +428,19 @@ class TorchDistributor(Distributor):
 
         master_addr = os.environ["MASTER_ADDR"]
         master_port = os.environ["MASTER_PORT"]
+
+        if cuda_visible_devices := os.environ.get("CUDA_VISIBLE_DEVICES"):
+            processes_per_node = len(cuda_visible_devices.split(","))
+        else:
+            processes_per_node = 1
         node_rank = os.environ["RANK"]
+
         torchrun_args = [
-            f"--nnodes={num_processes}",
+            f"--nnodes={num_processes // processes_per_node}",
             f"--node_rank={node_rank}",
             f"--rdzv_endpoint={master_addr}:{master_port}",
             "--rdzv_id=0",  # TODO: setup random ID that is gleaned from env 
variables
         ]
-        processes_per_node = 1
         return torchrun_args, processes_per_node
 
     @staticmethod
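
For readers skimming the patch, here is a minimal standalone sketch (not part of the
commit; the helper names num_tasks_for and torchrun_layout are invented for
illustration, with the Spark plumbing stubbed out) of the arithmetic the two hunks
above implement:

from typing import Optional, Tuple

def num_tasks_for(num_processes: int, task_gpu_amount: int) -> int:
    """Spark tasks to request when each task claims task_gpu_amount GPUs."""
    if task_gpu_amount > 1 and num_processes % task_gpu_amount != 0:
        raise RuntimeError(
            f"'num_processes' ({num_processes}) must be a multiple of "
            f"'spark.task.resource.gpu.amount' ({task_gpu_amount})."
        )
    return num_processes // task_gpu_amount

def torchrun_layout(num_processes: int, cuda_visible_devices: Optional[str]) -> Tuple[int, int]:
    """(nnodes, nproc_per_node) that each Spark task passes to torchrun."""
    processes_per_node = len(cuda_visible_devices.split(",")) if cuda_visible_devices else 1
    return num_processes // processes_per_node, processes_per_node

# With num_processes=8 and spark.task.resource.gpu.amount=4, each task sees
# CUDA_VISIBLE_DEVICES="0,1,2,3", so 2 tasks each run 4 torchrun processes:
assert num_tasks_for(8, 4) == 2
assert torchrun_layout(8, "0,1,2,3") == (2, 4)
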
diff --git a/python/pyspark/ml/torch/tests/test_distributor.py b/python/pyspark/ml/torch/tests/test_distributor.py
index 4ef2f63153af..252b02028c6b 100644
--- a/python/pyspark/ml/torch/tests/test_distributor.py
+++ b/python/pyspark/ml/torch/tests/test_distributor.py
@@ -306,6 +306,20 @@ class TorchDistributorBaselineUnitTestsMixin:
         )
         self.delete_env_vars(input_env_vars)
 
+    @patch.dict(os.environ, {
+        "CUDA_VISIBLE_DEVICES": "0,1,2,3",
+        "MASTER_ADDR": "11.22.33.44",
+        "MASTER_PORT": "6677",
+        "RANK": "1",
+    })
+    def test_multi_gpu_node_get_torchrun_args(self):
+        torchrun_args, processes_per_node = TorchDistributor._get_torchrun_args(False, 8)
+        self.assertEqual(
+            torchrun_args,
+            ['--nnodes=2', '--node_rank=1', '--rdzv_endpoint=11.22.33.44:6677', '--rdzv_id=0']
+        )
+        self.assertEqual(processes_per_node, 4)
+
 
 @unittest.skipIf(not have_torch, torch_requirement_message)
class TorchDistributorBaselineUnitTests(TorchDistributorBaselineUnitTestsMixin, unittest.TestCase):

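As a usage illustration (a sketch only, assuming the public TorchDistributor API
and the standard Spark GPU scheduling configs; cluster-side GPU discovery setup
is omitted), a job configured as below would now run 2 Spark tasks, each
launching torchrun with 4 local processes:

# Sketch only, not part of the commit.
from pyspark.sql import SparkSession
from pyspark.ml.torch.distributor import TorchDistributor

spark = (
    SparkSession.builder
    .config("spark.executor.resource.gpu.amount", "4")
    .config("spark.task.resource.gpu.amount", "4")  # values > 1 now supported
    .getOrCreate()
)

def train_fn():
    # torchrun-style entry point; RANK / LOCAL_RANK / WORLD_SIZE are set for
    # each of the 8 processes (2 Spark tasks x 4 processes per task)
    ...

distributor = TorchDistributor(num_processes=8, local_mode=False, use_gpu=True)
distributor.run(train_fn)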

---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
