This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new 6b800ba8461 [SPARK-41591][PYTHON][ML] Training PyTorch Files on Single Node Multi GPU
6b800ba8461 is described below

commit 6b800ba8461935a205d8c15eba2ff11f141dea47
Author: Rithwik Ediga Lakhamsani <rithwik.ed...@databricks.com>
AuthorDate: Thu Jan 12 08:42:01 2023 +0900

    [SPARK-41591][PYTHON][ML] Training PyTorch Files on Single Node Multi GPU

    ### What changes were proposed in this pull request?

    This is an addition to https://github.com/apache/spark/pull/39146 that adds support for
    single-node training using PyTorch files. Users follow the second workflow in the
    [design document](https://docs.google.com/document/d/1QPO1Ly8WteL6aIPvVcR7Xne9qVtJiB3fdrRn7NwBcpA/edit#heading=h.8yvw9xq428fh)
    to run training. I added some new utility functions and built on top of the existing ones.

    ### Why are the changes needed?

    See the [main ticket](https://issues.apache.org/jira/browse/SPARK-41589) for more details.

    ### Does this PR introduce _any_ user-facing change?

    No.

    ### How was this patch tested?

    Unit tests were added; integration tests will be added in a later PR
    (https://issues.apache.org/jira/browse/SPARK-41777).

    Closes #39188 from rithwik-db/pytorch-file-local-training.

    Authored-by: Rithwik Ediga Lakhamsani <rithwik.ed...@databricks.com>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 python/pyspark/ml/torch/distributor.py            | 186 ++++++++++++++++++++-
 python/pyspark/ml/torch/tests/test_distributor.py | 147 +++++++++++++++-
 .../pyspark/ml/torch/torch_run_process_wrapper.py |  83 +++++++++
 3 files changed, 412 insertions(+), 4 deletions(-)

diff --git a/python/pyspark/ml/torch/distributor.py b/python/pyspark/ml/torch/distributor.py
index 2a4027cbb25..80d5ad31c3c 100644
--- a/python/pyspark/ml/torch/distributor.py
+++ b/python/pyspark/ml/torch/distributor.py
@@ -15,7 +15,16 @@
 # limitations under the License.
 #
 
+import collections
+import ctypes
 import math
+import os
+import random
+import re
+import signal
+import sys
+import subprocess
+import time
 from typing import Union, Callable, Optional, Any
 import warnings
 
@@ -34,8 +43,8 @@ def get_conf_boolean(sc: SparkContext, key: str, default_value: str) -> bool:
 
     Parameters
     ----------
-    sc : SparkContext
-        The SparkContext for the distributor.
+    sc : :class:`SparkContext`
+        The :class:`SparkContext` for the distributor.
     key : str
         string for conf name
     default_value : str
@@ -64,6 +73,42 @@ def get_conf_boolean(sc: SparkContext, key: str, default_value: str) -> bool:
     )
 
 
+def get_gpus_owned(sc: SparkContext) -> list[str]:
+    """Gets the number of GPUs that Spark scheduled to the calling task.
+
+    Parameters
+    ----------
+    sc : :class:`SparkContext`
+        The :class:`SparkContext` that has GPUs available.
+
+    Returns
+    -------
+    list
+        The correct mapping of addresses to workers.
+
+    Raises
+    ------
+    ValueError
+        Raised if the input addresses were not found.
+    """
+    CUDA_VISIBLE_DEVICES = "CUDA_VISIBLE_DEVICES"
+    pattern = re.compile("^[1-9][0-9]*|0$")
+    addresses = sc.resources["gpu"].addresses
+    if any(not pattern.match(address) for address in addresses):
+        raise ValueError(
+            f"Found GPU addresses {addresses} which "
+            "are not all in the correct format "
+            "for CUDA_VISIBLE_DEVICES, which requires "
+            "integers with no zero padding."
+        )
+    if CUDA_VISIBLE_DEVICES in os.environ:
+        gpu_indices = list(map(int, addresses))
+        gpu_list = os.environ[CUDA_VISIBLE_DEVICES].split(",")
+        gpu_owned = [gpu_list[i] for i in gpu_indices]
+        return gpu_owned
+    return addresses
+
+
 class Distributor:
     """
     The parent class for TorchDistributor. This class shouldn't be instantiated directly.
@@ -85,6 +130,12 @@ class Distributor:
         self.num_tasks = self._get_num_tasks()
         self.ssl_conf = None
 
+    def _create_input_params(self) -> dict[str, Any]:
+        input_params = self.__dict__.copy()
+        for unneeded_param in ["spark", "sc", "ssl_conf"]:
+            del input_params[unneeded_param]
+        return input_params
+
     def _get_num_tasks(self) -> int:
         """
         Returns the number of Spark tasks to use for distributed training
@@ -261,6 +312,130 @@ class TorchDistributor(Distributor):
         super().__init__(num_processes, local_mode, use_gpu)
         self.ssl_conf = "pytorch.spark.distributor.ignoreSsl"  # type: ignore
         self._validate_input_params()
+        self.input_params = self._create_input_params()
+
+    @staticmethod
+    def _create_torchrun_command(
+        input_params: dict[str, Any], path_to_train_file: str, *args: Any
+    ) -> list[str]:
+        local_mode = input_params["local_mode"]
+        num_processes = input_params["num_processes"]
+
+        if local_mode:
+            torchrun_args = ["--standalone", "--nnodes=1"]
+            processes_per_node = num_processes
+        else:
+            pass
+            # TODO(SPARK-41592): Handle distributed training
+
+        args_string = list(map(str, args))  # converting all args to strings
+
+        return (
+            [sys.executable, "-m", "pyspark.ml.torch.distributor.torch_run_process_wrapper"]
+            + torchrun_args
+            + [f"--nproc_per_node={processes_per_node}"]
+            + [path_to_train_file, *args_string]
+        )
+
+    @staticmethod
+    def _execute_command(
+        cmd: list[str], _prctl: bool = True, redirect_to_stdout: bool = True
+    ) -> None:
+        _TAIL_LINES_TO_KEEP = 100
+
+        def sigterm_on_parent_death() -> None:
+            """
+            Uses prctl to automatically send SIGTERM to the command process when its parent is
+            dead. This handles the case when the parent is a PySpark worker process.
+            If a user cancels the PySpark job, the worker process gets killed, regardless of
+            PySpark daemon and worker reuse settings.
+            """
+            if _prctl:
+                try:
+                    libc = ctypes.CDLL("libc.so.6")
+                    # Set the parent process death signal of the command process to SIGTERM.
+                    libc.prctl(1, signal.SIGTERM)
+                except OSError:
+                    pass
+
+        task = subprocess.Popen(
+            cmd,
+            stdout=subprocess.PIPE,
+            stderr=subprocess.STDOUT,
+            stdin=subprocess.PIPE,
+            env=os.environ,
+            preexec_fn=sigterm_on_parent_death,
+        )
+        task.stdin.close()  # type: ignore
+        tail: collections.deque = collections.deque(maxlen=_TAIL_LINES_TO_KEEP)
+        try:
+            for line in task.stdout:  # type: ignore
+                decoded = line.decode()
+                tail.append(decoded)
+                if redirect_to_stdout:
+                    sys.stdout.write(decoded)
+            task.wait()
+        finally:
+            if task.poll() is None:
+                try:
+                    task.terminate()  # SIGTERM
+                    time.sleep(0.5)
+                    if task.poll() is None:
+                        task.kill()  # SIGKILL
+                except OSError:
+                    pass
+        if task.returncode != os.EX_OK:
+            if len(tail) == _TAIL_LINES_TO_KEEP:
+                last_n_msg = f"last {_TAIL_LINES_TO_KEEP} lines of the task output are"
+            else:
+                last_n_msg = "task output is"
+            task_output = "".join(tail)
+            raise RuntimeError(
+                f"Command {cmd} failed with return code {task.returncode}."
+                f"The {last_n_msg} included below: {task_output}"
+            )
+
+    def _run_local_training(
+        self,
+        framework_wrapper_fn: Optional[Callable],
+        train_object: Union[Callable, str],
+        *args: Any,
+    ) -> Optional[Any]:
+        CUDA_VISIBLE_DEVICES = "CUDA_VISIBLE_DEVICES"
+        cuda_state_was_set = CUDA_VISIBLE_DEVICES in os.environ
+        old_cuda_visible_devices = os.environ.get(CUDA_VISIBLE_DEVICES, "")
+        try:
+            if self.use_gpu:
+                gpus_owned = get_gpus_owned(self.sc)
+
+                if self.num_processes > len(gpus_owned):
+                    raise ValueError(
+                        f"""{self.num_processes} processes were requested
+                        for local training with GPU training but only
+                        {len(gpus_owned)} GPUs were available."""
+                    )
+                random.seed(hash(train_object))
+                selected_gpus = [str(e) for e in random.sample(gpus_owned, self.num_processes)]
+                os.environ[CUDA_VISIBLE_DEVICES] = ",".join(selected_gpus)
+
+            output = framework_wrapper_fn(self.input_params, train_object, *args)  # type: ignore
+        finally:
+            if cuda_state_was_set:
+                os.environ[CUDA_VISIBLE_DEVICES] = old_cuda_visible_devices
+            else:
+                if CUDA_VISIBLE_DEVICES in os.environ:
+                    del os.environ[CUDA_VISIBLE_DEVICES]
+
+        return output
+
+    @staticmethod
+    def _run_training_on_pytorch_file(
+        input_params: dict[str, Any], train_path: str, *args: Any
+    ) -> None:
+        training_command = TorchDistributor._create_torchrun_command(
+            input_params, train_path, *args
+        )
+        TorchDistributor._execute_command(training_command)
 
     def run(self, train_object: Union[Callable, str], *args: Any) -> Optional[Any]:
         """Runs distributed training.
@@ -278,4 +453,9 @@ class TorchDistributor(Distributor):
             Returns the output of train_object called with args if train_object is a
             Callable with an expected output.
         """
-        pass
+        framework_wrapper_fn = None
+        if isinstance(train_object, str):
+            framework_wrapper_fn = TorchDistributor._run_training_on_pytorch_file
+        if self.local_mode:
+            output = self._run_local_training(framework_wrapper_fn, train_object, *args)
+        return output
diff --git a/python/pyspark/ml/torch/tests/test_distributor.py b/python/pyspark/ml/torch/tests/test_distributor.py
index e84505f92fe..4b24eff8742 100644
--- a/python/pyspark/ml/torch/tests/test_distributor.py
+++ b/python/pyspark/ml/torch/tests/test_distributor.py
@@ -15,17 +15,38 @@
 # limitations under the License.
 #
 
+import contextlib
 import os
+from six import StringIO  # type: ignore
 import stat
+import subprocess
+import sys
+import time
 import tempfile
+import threading
+from typing import Callable
 import unittest
+from unittest.mock import patch
 
 from pyspark import SparkConf, SparkContext
-from pyspark.ml.torch.distributor import TorchDistributor
+from pyspark.ml.torch.distributor import TorchDistributor, get_gpus_owned
+from pyspark.ml.torch.torch_run_process_wrapper import clean_and_terminate, check_parent_alive
 from pyspark.sql import SparkSession
 from pyspark.testing.utils import SPARK_HOME
 
 
+@contextlib.contextmanager
+def patch_stdout() -> StringIO:
+    """patch stdout and give an output"""
+    sys_stdout = sys.stdout
+    io_out = StringIO()
+    sys.stdout = io_out
+    try:
+        yield io_out
+    finally:
+        sys.stdout = sys_stdout
+
+
 class TorchDistributorBaselineUnitTests(unittest.TestCase):
     def setUp(self) -> None:
         conf = SparkConf()
@@ -35,6 +56,14 @@ class TorchDistributorBaselineUnitTests(unittest.TestCase):
     def tearDown(self) -> None:
         self.spark.stop()
 
+    def setup_env_vars(self, input_map: dict[str, str]) -> None:
+        for key, value in input_map.items():
+            os.environ[key] = value
+
+    def delete_env_vars(self, input_map: dict[str, str]) -> None:
+        for key in input_map.keys():
+            del os.environ[key]
+
     def test_validate_correct_inputs(self) -> None:
         inputs = [
             (1, True, False),
@@ -90,6 +119,55 @@ class TorchDistributorBaselineUnitTests(unittest.TestCase):
             with self.assertRaisesRegex(RuntimeError, "unset"):
                 TorchDistributor(num_processes, False, True)
 
+    def test_execute_command(self) -> None:
+        """Test that run command runs the process and logs are written correctly"""
+
+        with patch_stdout() as output:
+            stdout_command = ["echo", "hello_stdout"]
+            TorchDistributor._execute_command(stdout_command)
+            self.assertIn(
+                "hello_stdout", output.getvalue().strip(), "hello_stdout should print to stdout"
+            )
+
+        with patch_stdout() as output:
+            stderr_command = ["bash", "-c", "echo hello_stderr >&2"]
+            TorchDistributor._execute_command(stderr_command)
+            self.assertIn(
+                "hello_stderr", output.getvalue().strip(), "hello_stderr should print to stdout"
+            )
+
+        # include command in the exception message
+        with self.assertRaisesRegexp(RuntimeError, "exit 1"):  # pylint: disable=deprecated-method
+            error_command = ["bash", "-c", "exit 1"]
+            TorchDistributor._execute_command(error_command)
+
+        with self.assertRaisesRegexp(RuntimeError, "abcdef"):  # pylint: disable=deprecated-method
+            error_command = ["bash", "-c", "'abc''def'"]
+            TorchDistributor._execute_command(error_command)
+
+    def test_create_torchrun_command(self) -> None:
+        train_path = "train.py"
+        args_string = ["1", "3"]
+        local_mode_input_params = {"num_processes": 4, "local_mode": True}
+
+        expected_local_mode_output = [
+            sys.executable,
+            "-m",
+            "pyspark.ml.torch.distributor.torch_run_process_wrapper",
+            "--standalone",
+            "--nnodes=1",
+            "--nproc_per_node=4",
+            "train.py",
+            "1",
+            "3",
+        ]
+        self.assertEqual(
+            TorchDistributor._create_torchrun_command(
+                local_mode_input_params, train_path, *args_string
+            ),
+            expected_local_mode_output,
+        )
+
 
 class TorchDistributorLocalUnitTests(unittest.TestCase):
     def setUp(self) -> None:
@@ -118,6 +196,14 @@ class TorchDistributorLocalUnitTests(unittest.TestCase):
         os.unlink(self.tempFile.name)
         self.spark.stop()
 
+    def setup_env_vars(self, input_map: dict[str, str]) -> None:
+        for key, value in input_map.items():
+            os.environ[key] = value
+
+    def delete_env_vars(self, input_map: dict[str, str]) -> None:
input_map.keys(): + del os.environ[key] + def test_get_num_tasks_locally(self) -> None: succeeds = [1, 2] fails = [4, 8] @@ -133,6 +219,42 @@ class TorchDistributorLocalUnitTests(unittest.TestCase): distributor = TorchDistributor(num_processes, True, True) distributor.num_processes = 3 + def test_get_gpus_owned_local(self) -> None: + addresses = ["0", "1", "2"] + self.assertEqual(get_gpus_owned(self.sc), addresses) + + env_vars = {"CUDA_VISIBLE_DEVICES": "3,4,5"} + self.setup_env_vars(env_vars) + self.assertEqual(get_gpus_owned(self.sc), ["3", "4", "5"]) + self.delete_env_vars(env_vars) + + def test_local_training_succeeds(self) -> None: + CUDA_VISIBLE_DEVICES = "CUDA_VISIBLE_DEVICES" + inputs = [ + ("0,1,2", 1, True, "1"), + ("0,1,2", 3, True, "1,2,0"), + ("0,1,2", 2, False, "0,1,2"), + (None, 3, False, "NONE"), + ] + + for i, (cuda_env_var, num_processes, use_gpu, expected) in enumerate(inputs): + with self.subTest(f"subtest: {i + 1}"): + # setup + if cuda_env_var: + self.setup_env_vars({CUDA_VISIBLE_DEVICES: cuda_env_var}) + + dist = TorchDistributor(num_processes, True, use_gpu) + dist._run_training_on_pytorch_file = lambda *args: os.environ.get( # type: ignore + CUDA_VISIBLE_DEVICES, "NONE" + ) + self.assertEqual( + expected, + dist._run_local_training(dist._run_training_on_pytorch_file, "train.py"), + ) + # cleanup + if cuda_env_var: + self.delete_env_vars({CUDA_VISIBLE_DEVICES: cuda_env_var}) + class TorchDistributorDistributedUnitTests(unittest.TestCase): def setUp(self) -> None: @@ -178,6 +300,29 @@ class TorchDistributorDistributedUnitTests(unittest.TestCase): self.spark.sparkContext._conf.set("spark.task.resource.gpu.amount", "1") +class TorchWrapperUnitTests(unittest.TestCase): + def test_clean_and_terminate(self) -> None: + def kill_task(task: "subprocess.Popen") -> None: + time.sleep(1) + clean_and_terminate(task) + + command = [sys.executable, "-c", '"import time; time.sleep(20)"'] + task = subprocess.Popen(command) + t = threading.Thread(target=kill_task, args=(task,)) + t.start() + time.sleep(2) + self.assertEqual(task.poll(), 0) # implies task ended + + @patch("pyspark.ml.torch.torch_run_process_wrapper.clean_and_terminate") + def test_check_parent_alive(self, mock_clean_and_terminate: Callable) -> None: + command = [sys.executable, "-c", '"import time; time.sleep(2)"'] + task = subprocess.Popen(command) + t = threading.Thread(target=check_parent_alive, args=(task,), daemon=True) + t.start() + time.sleep(2) + self.assertEqual(mock_clean_and_terminate.call_count, 0) # type: ignore[attr-defined] + + if __name__ == "__main__": from pyspark.ml.torch.tests.test_distributor import * # noqa: F401,F403 type: ignore diff --git a/python/pyspark/ml/torch/torch_run_process_wrapper.py b/python/pyspark/ml/torch/torch_run_process_wrapper.py new file mode 100644 index 00000000000..67ec492329d --- /dev/null +++ b/python/pyspark/ml/torch/torch_run_process_wrapper.py @@ -0,0 +1,83 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import os
+import signal
+import subprocess
+import sys
+import threading
+import time
+from typing import Any
+
+
+def clean_and_terminate(task: "subprocess.Popen") -> None:
+    task.terminate()
+    time.sleep(0.5)
+    if task.poll() is None:
+        task.kill()
+    # TODO(SPARK-41775): Cleanup temp files
+
+
+def check_parent_alive(task: "subprocess.Popen") -> None:
+    orig_parent_id = os.getppid()
+    while True:
+        if os.getppid() != orig_parent_id:
+            clean_and_terminate(task)
+            break
+        time.sleep(0.5)
+
+
+if __name__ == "__main__":
+    """
+    This is a wrapper around torch.distributed.run and it kills the child process
+    if the parent process fails, crashes, or exits.
+    """
+
+    args = sys.argv[1:]
+
+    cmd = [sys.executable, "-m", "torch.distributed.run", *args]
+    task = subprocess.Popen(
+        cmd,
+        stdout=subprocess.PIPE,
+        stderr=subprocess.STDOUT,
+        stdin=subprocess.PIPE,
+        env=os.environ,
+    )
+    t = threading.Thread(target=check_parent_alive, args=(task,), daemon=True)
+
+    def sigterm_handler(*args: Any) -> None:
+        clean_and_terminate(task)
+        os._exit(0)
+
+    signal.signal(signal.SIGTERM, sigterm_handler)
+
+    t.start()
+    task.stdin.close()  # type: ignore[union-attr]
+    try:
+        for line in task.stdout:  # type: ignore[union-attr]
+            decoded = line.decode()
+            print(decoded.rstrip())
+        task.wait()
+    finally:
+        if task.poll() is None:
+            try:
+                task.terminate()
+                time.sleep(0.5)
+                if task.poll() is None:
+                    task.kill()
+            except OSError:
+                pass
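
For reference, a minimal sketch of how the single-node, multi-GPU workflow added here might be invoked. It is not part of this commit: the Spark session configuration, the `/path/to/train.py` script, and the `--lr 0.001` arguments are illustrative placeholders.

```python
# A minimal usage sketch, assuming a local-mode Spark application that already
# has GPU resources configured (spark.task.resource.gpu.amount, discovery
# script, etc.); the script path and CLI args below are placeholders.
from pyspark.sql import SparkSession
from pyspark.ml.torch.distributor import TorchDistributor

spark = SparkSession.builder.getOrCreate()

# Positional arguments mirror the unit tests above: num_processes, local_mode, use_gpu.
distributor = TorchDistributor(2, True, True)

# Passing a file path (rather than a Callable) routes run() through
# _run_training_on_pytorch_file, which launches the script through the
# torch_run_process_wrapper module with
#   --standalone --nnodes=1 --nproc_per_node=2 /path/to/train.py --lr 0.001
# Extra args are stringified and appended after the script path.
distributor.run("/path/to/train.py", "--lr", "0.001")
```

The wrapper module exists so that the launched torch.distributed.run child does not outlive its launcher: check_parent_alive polls os.getppid() and calls clean_and_terminate as soon as the parent process changes.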