This is an automated email from the ASF dual-hosted git repository.
weichenxu123 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new b9ac53d1fc84 [SPARK-41916][ML] Torch distributor: support multiple
torchrun processes per task if task.gpu.amount > 1
b9ac53d1fc84 is described below
commit b9ac53d1fc84b74cbd9c7ce38217b396484ae62f
Author: Weichen Xu <[email protected]>
AuthorDate: Wed Dec 17 19:21:01 2025 +0800
[SPARK-41916][ML] Torch distributor: support multiple torchrun processes
per task if task.gpu.amount > 1
### What changes were proposed in this pull request?
Torch distributor: support multiple torchrun processes per task if
task.gpu.amount > 1
### Why are the changes needed?
Torch distributor: support multiple torchrun processes per task if
task.gpu.amount > 1
This is a common case that we should support.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
Manually
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #53501 from WeichenXu123/SPARK-41916.
Authored-by: Weichen Xu <[email protected]>
Signed-off-by: Weichen Xu <[email protected]>
---
python/pyspark/ml/torch/distributor.py | 20 ++++++++++++++++----
python/pyspark/ml/torch/tests/test_distributor.py | 14 ++++++++++++++
2 files changed, 30 insertions(+), 4 deletions(-)
diff --git a/python/pyspark/ml/torch/distributor.py
b/python/pyspark/ml/torch/distributor.py
index ef86f38b716b..3689d079f12d 100644
--- a/python/pyspark/ml/torch/distributor.py
+++ b/python/pyspark/ml/torch/distributor.py
@@ -212,8 +212,15 @@ class Distributor:
task_gpu_amount = int(_get_conf(self.spark, key, "0"))
if task_gpu_amount < 1:
raise RuntimeError(f"'{key}' was unset, so gpu usage is unavailable.")
- # TODO(SPARK-41916): Address situation when spark.task.resource.gpu.amount > 1
- return math.ceil(self.num_processes / task_gpu_amount)
+
+ if task_gpu_amount > 1:
+ if not (self.num_processes % task_gpu_amount == 0):
+ raise RuntimeError(
+ f"TorchDistributor 'num_processes' value ({self.num_processes}) "
+ "must be a multiple of 'spark.task.resource.gpu.amount' "
+ f"({task_gpu_amount}) value."
+ )
+ return self.num_processes // task_gpu_amount
else:
key = "spark.driver.resource.gpu.amount"
if "gpu" not in _get_resources(self.spark):
@@ -421,14 +428,19 @@ class TorchDistributor(Distributor):
master_addr = os.environ["MASTER_ADDR"]
master_port = os.environ["MASTER_PORT"]
+
+ if cuda_visible_devices := os.environ.get("CUDA_VISIBLE_DEVICES"):
+ processes_per_node = len(cuda_visible_devices.split(","))
+ else:
+ processes_per_node = 1
node_rank = os.environ["RANK"]
+
torchrun_args = [
- f"--nnodes={num_processes}",
+ f"--nnodes={num_processes // processes_per_node}",
f"--node_rank={node_rank}",
f"--rdzv_endpoint={master_addr}:{master_port}",
"--rdzv_id=0",  # TODO: setup random ID that is gleaned from env variables
]
- processes_per_node = 1
return torchrun_args, processes_per_node
@staticmethod
diff --git a/python/pyspark/ml/torch/tests/test_distributor.py
b/python/pyspark/ml/torch/tests/test_distributor.py
index 4ef2f63153af..252b02028c6b 100644
--- a/python/pyspark/ml/torch/tests/test_distributor.py
+++ b/python/pyspark/ml/torch/tests/test_distributor.py
@@ -306,6 +306,20 @@ class TorchDistributorBaselineUnitTestsMixin:
)
self.delete_env_vars(input_env_vars)
+ @patch.dict(os.environ, {
+ "CUDA_VISIBLE_DEVICES": "0,1,2,3",
+ "MASTER_ADDR": "11.22.33.44",
+ "MASTER_PORT": "6677",
+ "RANK": "1",
+ })
+ def test_multi_gpu_node_get_torchrun_args(self):
+ torchrun_args, processes_per_node =
TorchDistributor._get_torchrun_args(False, 8)
+ self.assertEqual(
+ torchrun_args,
+ ['--nnodes=2', '--node_rank=1', '--rdzv_endpoint=11.22.33.44:6677', '--rdzv_id=0']
+ )
+ self.assertEqual(processes_per_node, 4)
+
@unittest.skipIf(not have_torch, torch_requirement_message)
class TorchDistributorBaselineUnitTests(TorchDistributorBaselineUnitTestsMixin, unittest.TestCase):
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]