This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 6b800ba8461 [SPARK-41591][PYTHON][ML] Training PyTorch Files on Single 
Node Multi GPU
6b800ba8461 is described below

commit 6b800ba8461935a205d8c15eba2ff11f141dea47
Author: Rithwik Ediga Lakhamsani <rithwik.ed...@databricks.com>
AuthorDate: Thu Jan 12 08:42:01 2023 +0900

    [SPARK-41591][PYTHON][ML] Training PyTorch Files on Single Node Multi GPU
    
    ### What changes were proposed in this pull request?
    
    This builds on https://github.com/apache/spark/pull/39146 to add support for
    single-node training from PyTorch files. Users follow the second workflow in the
    [design document](https://docs.google.com/document/d/1QPO1Ly8WteL6aIPvVcR7Xne9qVtJiB3fdrRn7NwBcpA/edit#heading=h.8yvw9xq428fh)
    to run training. The change adds some new utility functions and builds on top of
    existing ones.
    
    ### Why are the changes needed?
    
    See the [main ticket](https://issues.apache.org/jira/browse/SPARK-41589) for more details.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    Unit tests were added; integration tests will be added in a later PR
    (https://issues.apache.org/jira/browse/SPARK-41777).
    
    Closes #39188 from rithwik-db/pytorch-file-local-training.
    
    Authored-by: Rithwik Ediga Lakhamsani <rithwik.ed...@databricks.com>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 python/pyspark/ml/torch/distributor.py             | 186 ++++++++++++++++++++-
 python/pyspark/ml/torch/tests/test_distributor.py  | 147 +++++++++++++++-
 .../pyspark/ml/torch/torch_run_process_wrapper.py  |  83 +++++++++
 3 files changed, 412 insertions(+), 4 deletions(-)

diff --git a/python/pyspark/ml/torch/distributor.py 
b/python/pyspark/ml/torch/distributor.py
index 2a4027cbb25..80d5ad31c3c 100644
--- a/python/pyspark/ml/torch/distributor.py
+++ b/python/pyspark/ml/torch/distributor.py
@@ -15,7 +15,16 @@
 # limitations under the License.
 #
 
+import collections
+import ctypes
 import math
+import os
+import random
+import re
+import signal
+import sys
+import subprocess
+import time
 from typing import Union, Callable, Optional, Any
 import warnings
 
@@ -34,8 +43,8 @@ def get_conf_boolean(sc: SparkContext, key: str, 
default_value: str) -> bool:
 
     Parameters
     ----------
-    sc : SparkContext
-        The SparkContext for the distributor.
+    sc : :class:`SparkContext`
+        The :class:`SparkContext` for the distributor.
     key : str
         string for conf name
     default_value : str
@@ -64,6 +73,42 @@ def get_conf_boolean(sc: SparkContext, key: str, 
default_value: str) -> bool:
     )
 
 
+def get_gpus_owned(sc: SparkContext) -> "list[str]":
+    """Gets the addresses of the GPUs that Spark scheduled to the calling task.
+
+    Parameters
+    ----------
+    sc : :class:`SparkContext`
+        The :class:`SparkContext` that has GPUs available.
+
+    Returns
+    -------
+    list
+        The GPU addresses owned by the caller. If ``CUDA_VISIBLE_DEVICES`` is set,
+        each Spark GPU address is used as an index into that comma-separated list
+        and the corresponding entries are returned instead.
+
+    Raises
+    ------
+    ValueError
+        Raised if an address is not a non-negative integer without zero padding,
+        or if an address is not a valid index into ``CUDA_VISIBLE_DEVICES``.
+    """
+    CUDA_VISIBLE_DEVICES = "CUDA_VISIBLE_DEVICES"
+    # Checked with fullmatch: the previous ungrouped "^[1-9][0-9]*|0$" only anchored
+    # each alternative on one side and accepted values such as "1a" or "12x".
+    pattern = re.compile("[1-9][0-9]*|0")
+    addresses = sc.resources["gpu"].addresses
+    if any(not pattern.fullmatch(address) for address in addresses):
+        raise ValueError(
+            f"Found GPU addresses {addresses} which "
+            "are not all in the correct format "
+            "for CUDA_VISIBLE_DEVICES, which requires "
+            "integers with no zero padding."
+        )
+    if CUDA_VISIBLE_DEVICES in os.environ:
+        gpu_indices = list(map(int, addresses))
+        gpu_list = os.environ[CUDA_VISIBLE_DEVICES].split(",")
+        try:
+            return [gpu_list[i] for i in gpu_indices]
+        except IndexError as e:
+            raise ValueError(
+                f"Found GPU addresses {addresses} which are not all valid indices into "
+                f"{CUDA_VISIBLE_DEVICES}={os.environ[CUDA_VISIBLE_DEVICES]}."
+            ) from e
+    return addresses
+
+
 class Distributor:
     """
     The parent class for TorchDistributor. This class shouldn't be 
instantiated directly.
@@ -85,6 +130,12 @@ class Distributor:
         self.num_tasks = self._get_num_tasks()
         self.ssl_conf = None
 
+    def _create_input_params(self) -> "dict[str, Any]":
+        """Returns a copy of the instance attributes without ``spark``, ``sc`` and ``ssl_conf``.
+
+        The result is stored as ``input_params`` and handed to the framework wrapper
+        functions. Attributes that are absent are simply not copied.
+        """
+        excluded = ("spark", "sc", "ssl_conf")
+        return {name: value for name, value in self.__dict__.items() if name not in excluded}
+
     def _get_num_tasks(self) -> int:
         """
         Returns the number of Spark tasks to use for distributed training
@@ -261,6 +312,130 @@ class TorchDistributor(Distributor):
         super().__init__(num_processes, local_mode, use_gpu)
         self.ssl_conf = "pytorch.spark.distributor.ignoreSsl"  # type: ignore
         self._validate_input_params()
+        self.input_params = self._create_input_params()
+
+    @staticmethod
+    def _create_torchrun_command(
+        input_params: "dict[str, Any]", path_to_train_file: str, *args: Any
+    ) -> "list[str]":
+        """Builds the command that launches torchrun on a PyTorch training file.
+
+        The command runs ``pyspark.ml.torch.torch_run_process_wrapper``, which
+        forwards its arguments to ``torch.distributed.run``.
+
+        Parameters
+        ----------
+        input_params : dict
+            Distributor parameters; ``local_mode`` and ``num_processes`` are read.
+        path_to_train_file : str
+            Path of the PyTorch training script.
+        *args
+            Arguments for the training script; each one is converted with ``str``.
+
+        Returns
+        -------
+        list
+            The command as a list of strings, suitable for :class:`subprocess.Popen`.
+
+        Raises
+        ------
+        NotImplementedError
+            If ``local_mode`` is false; distributed training is not handled yet.
+        """
+        local_mode = input_params["local_mode"]
+        num_processes = input_params["num_processes"]
+
+        if not local_mode:
+            # TODO(SPARK-41592): Handle distributed training
+            raise NotImplementedError(
+                "Training from a PyTorch file is only supported in local mode."
+            )
+
+        torchrun_args = ["--standalone", "--nnodes=1"]
+        processes_per_node = num_processes
+        args_string = [str(arg) for arg in args]  # converting all args to strings
+
+        return (
+            # The wrapper lives at pyspark/ml/torch/torch_run_process_wrapper.py.
+            # pyspark.ml.torch.distributor is a module, not a package, so the wrapper
+            # cannot be addressed as a submodule of it.
+            [sys.executable, "-m", "pyspark.ml.torch.torch_run_process_wrapper"]
+            + torchrun_args
+            + [f"--nproc_per_node={processes_per_node}"]
+            + [path_to_train_file, *args_string]
+        )
+
+    @staticmethod
+    def _execute_command(
+        cmd: "list[str]", _prctl: bool = True, redirect_to_stdout: bool = True
+    ) -> None:
+        """Runs ``cmd`` in a subprocess and waits for it to finish.
+
+        The child's stdout and stderr are merged and read line by line. Each line is
+        echoed to ``sys.stdout`` when ``redirect_to_stdout`` is true, and the last
+        100 lines are kept for the error message.
+
+        Parameters
+        ----------
+        cmd : list
+            The command and its arguments, passed to :class:`subprocess.Popen`.
+        _prctl : bool
+            Whether the child should receive SIGTERM (via ``prctl``) when its parent dies.
+        redirect_to_stdout : bool
+            Whether to echo the child's output to ``sys.stdout``.
+
+        Raises
+        ------
+        RuntimeError
+            If the command exits with a non-zero return code. The message contains the
+            command and the (tail of the) captured output.
+        """
+        _TAIL_LINES_TO_KEEP = 100
+        # PR_SET_PDEATHSIG from <sys/prctl.h>: the signal delivered when the parent dies.
+        _PR_SET_PDEATHSIG = 1
+
+        def sigterm_on_parent_death() -> None:
+            """
+            Uses prctl to automatically send SIGTERM to the command process when its parent is dead.
+            This handles the case when the parent is a PySpark worker process.
+            If a user cancels the PySpark job, the worker process gets killed, regardless of
+            PySpark daemon and worker reuse settings.
+            """
+            if _prctl:
+                try:
+                    libc = ctypes.CDLL("libc.so.6")
+                    # Set the parent process death signal of the command process to SIGTERM.
+                    libc.prctl(_PR_SET_PDEATHSIG, signal.SIGTERM)
+                except OSError:
+                    # libc.so.6 could not be loaded; this safeguard is best effort, so
+                    # the command still runs without it.
+                    pass
+
+        task = subprocess.Popen(
+            cmd,
+            stdout=subprocess.PIPE,
+            stderr=subprocess.STDOUT,
+            stdin=subprocess.PIPE,
+            env=os.environ,
+            preexec_fn=sigterm_on_parent_death,
+        )
+        # The child receives no input; closing stdin gives it EOF right away.
+        task.stdin.close()  # type: ignore
+        tail: collections.deque = collections.deque(maxlen=_TAIL_LINES_TO_KEEP)
+        try:
+            for line in task.stdout:  # type: ignore
+                # Undecodable bytes must not abort streaming (which would kill the child).
+                decoded = line.decode(errors="replace")
+                tail.append(decoded)
+                if redirect_to_stdout:
+                    sys.stdout.write(decoded)
+            task.wait()
+        finally:
+            # The child can still be running here only if streaming was interrupted.
+            if task.poll() is None:
+                try:
+                    task.terminate()  # SIGTERM
+                    time.sleep(0.5)
+                    if task.poll() is None:
+                        task.kill()  # SIGKILL
+                except OSError:
+                    pass
+        if task.returncode != os.EX_OK:
+            if len(tail) == _TAIL_LINES_TO_KEEP:
+                last_n_msg = f"last {_TAIL_LINES_TO_KEEP} lines of the task output are"
+            else:
+                last_n_msg = "task output is"
+            task_output = "".join(tail)
+            raise RuntimeError(
+                f"Command {cmd} failed with return code {task.returncode}. "
+                f"The {last_n_msg} included below: {task_output}"
+            )
+
+    def _run_local_training(
+        self,
+        framework_wrapper_fn: Optional[Callable],
+        train_object: Union[Callable, str],
+        *args: Any,
+    ) -> Optional[Any]:
+        """Runs ``framework_wrapper_fn`` on this node, optionally pinning GPUs.
+
+        When ``use_gpu`` is set, ``num_processes`` of the GPUs owned by this task are
+        selected and exposed through ``CUDA_VISIBLE_DEVICES`` for the duration of the
+        call. The variable's previous state (set or unset) is restored afterwards.
+
+        Parameters
+        ----------
+        framework_wrapper_fn : callable
+            Called as ``framework_wrapper_fn(self.input_params, train_object, *args)``.
+        train_object : callable or str
+            The training function or the path of the training file.
+        *args
+            Extra arguments forwarded to ``framework_wrapper_fn``.
+
+        Returns
+        -------
+        object
+            The value returned by ``framework_wrapper_fn``.
+
+        Raises
+        ------
+        ValueError
+            If more processes were requested than GPUs are available.
+        """
+        CUDA_VISIBLE_DEVICES = "CUDA_VISIBLE_DEVICES"
+        cuda_state_was_set = CUDA_VISIBLE_DEVICES in os.environ
+        old_cuda_visible_devices = os.environ.get(CUDA_VISIBLE_DEVICES, "")
+        try:
+            if self.use_gpu:
+                gpus_owned = get_gpus_owned(self.sc)
+
+                if self.num_processes > len(gpus_owned):
+                    raise ValueError(
+                        f"{self.num_processes} processes were requested "
+                        "for local training with GPU training but only "
+                        f"{len(gpus_owned)} GPUs were available."
+                    )
+                # A private generator seeded with hash(train_object) picks the same GPUs
+                # as seeding the module-level RNG did, without clobbering the global
+                # `random` state that user code may rely on.
+                rng = random.Random(hash(train_object))
+                selected_gpus = [str(e) for e in rng.sample(gpus_owned, self.num_processes)]
+                os.environ[CUDA_VISIBLE_DEVICES] = ",".join(selected_gpus)
+
+            output = framework_wrapper_fn(self.input_params, train_object, *args)  # type: ignore
+        finally:
+            if cuda_state_was_set:
+                os.environ[CUDA_VISIBLE_DEVICES] = old_cuda_visible_devices
+            else:
+                os.environ.pop(CUDA_VISIBLE_DEVICES, None)
+
+        return output
+
+    @staticmethod
+    def _run_training_on_pytorch_file(
+        input_params: "dict[str, Any]", train_path: str, *args: Any
+    ) -> None:
+        """Runs the PyTorch training script at ``train_path`` through torchrun.
+
+        Builds the command with ``_create_torchrun_command`` and runs it with
+        ``_execute_command``, which raises ``RuntimeError`` on a non-zero exit code.
+        """
+        training_command = TorchDistributor._create_torchrun_command(
+            input_params, train_path, *args
+        )
+        TorchDistributor._execute_command(training_command)
 
     def run(self, train_object: Union[Callable, str], *args: Any) -> 
Optional[Any]:
         """Runs distributed training.
@@ -278,4 +453,9 @@ class TorchDistributor(Distributor):
             Returns the output of train_object called with args if 
train_object is a
             Callable with an expected output.
         """
-        pass
+        framework_wrapper_fn = None
+        if isinstance(train_object, str):
+            framework_wrapper_fn = 
TorchDistributor._run_training_on_pytorch_file
+        if self.local_mode:
+            output = self._run_local_training(framework_wrapper_fn, 
train_object, *args)
+        return output
diff --git a/python/pyspark/ml/torch/tests/test_distributor.py 
b/python/pyspark/ml/torch/tests/test_distributor.py
index e84505f92fe..4b24eff8742 100644
--- a/python/pyspark/ml/torch/tests/test_distributor.py
+++ b/python/pyspark/ml/torch/tests/test_distributor.py
@@ -15,17 +15,38 @@
 # limitations under the License.
 #
 
+import contextlib
 import os
+from six import StringIO  # type: ignore
 import stat
+import subprocess
+import sys
+import time
 import tempfile
+import threading
+from typing import Callable
 import unittest
+from unittest.mock import patch
 
 from pyspark import SparkConf, SparkContext
-from pyspark.ml.torch.distributor import TorchDistributor
+from pyspark.ml.torch.distributor import TorchDistributor, get_gpus_owned
+from pyspark.ml.torch.torch_run_process_wrapper import clean_and_terminate, 
check_parent_alive
 from pyspark.sql import SparkSession
 from pyspark.testing.utils import SPARK_HOME
 
 
+@contextlib.contextmanager
+def patch_stdout() -> StringIO:
+    """Temporarily swap ``sys.stdout`` for an in-memory buffer and yield that buffer.
+
+    The previous ``sys.stdout`` is put back when the block exits, even on error.
+    """
+    captured = StringIO()
+    with contextlib.redirect_stdout(captured):
+        yield captured
+
+
 class TorchDistributorBaselineUnitTests(unittest.TestCase):
     def setUp(self) -> None:
         conf = SparkConf()
@@ -35,6 +56,14 @@ class TorchDistributorBaselineUnitTests(unittest.TestCase):
     def tearDown(self) -> None:
         self.spark.stop()
 
+    def setup_env_vars(self, input_map: "dict[str, str]") -> None:
+        """Sets every key/value pair of ``input_map`` in ``os.environ``."""
+        os.environ.update(input_map)
+
+    def delete_env_vars(self, input_map: "dict[str, str]") -> None:
+        """Removes every key of ``input_map`` from ``os.environ`` (KeyError if one is unset)."""
+        for key in input_map:
+            del os.environ[key]
+
     def test_validate_correct_inputs(self) -> None:
         inputs = [
             (1, True, False),
@@ -90,6 +119,55 @@ class TorchDistributorBaselineUnitTests(unittest.TestCase):
                 with self.assertRaisesRegex(RuntimeError, "unset"):
                     TorchDistributor(num_processes, False, True)
 
+    def test_execute_command(self) -> None:
+        """Test that run command runs the process and logs are written correctly"""
+
+        with patch_stdout() as output:
+            stdout_command = ["echo", "hello_stdout"]
+            TorchDistributor._execute_command(stdout_command)
+            self.assertIn(
+                "hello_stdout", output.getvalue().strip(), "hello_stdout should print to stdout"
+            )
+
+        with patch_stdout() as output:
+            stderr_command = ["bash", "-c", "echo hello_stderr >&2"]
+            TorchDistributor._execute_command(stderr_command)
+            self.assertIn(
+                "hello_stderr", output.getvalue().strip(), "hello_stderr should print to stdout"
+            )
+
+        # include command in the exception message
+        # (assertRaisesRegex replaces the deprecated assertRaisesRegexp alias, which
+        # was removed in Python 3.12)
+        with self.assertRaisesRegex(RuntimeError, "exit 1"):
+            error_command = ["bash", "-c", "exit 1"]
+            TorchDistributor._execute_command(error_command)
+
+        # include the command's output in the exception message: "abcdef" only appears
+        # in what bash prints, not in the repr of the command list
+        with self.assertRaisesRegex(RuntimeError, "abcdef"):
+            error_command = ["bash", "-c", "'abc''def'"]
+            TorchDistributor._execute_command(error_command)
+
+    def test_create_torchrun_command(self) -> None:
+        train_path = "train.py"
+        args_string = ["1", "3"]
+        local_mode_input_params = {"num_processes": 4, "local_mode": True}
+
+        expected_local_mode_output = [
+            sys.executable,
+            "-m",
+            # The command must name the wrapper module that really exists, i.e. the
+            # module clean_and_terminate is imported from in this test file.
+            clean_and_terminate.__module__,
+            "--standalone",
+            "--nnodes=1",
+            "--nproc_per_node=4",
+            "train.py",
+            "1",
+            "3",
+        ]
+        self.assertEqual(
+            TorchDistributor._create_torchrun_command(
+                local_mode_input_params, train_path, *args_string
+            ),
+            expected_local_mode_output,
+        )
+
 
 class TorchDistributorLocalUnitTests(unittest.TestCase):
     def setUp(self) -> None:
@@ -118,6 +196,14 @@ class TorchDistributorLocalUnitTests(unittest.TestCase):
         os.unlink(self.tempFile.name)
         self.spark.stop()
 
+    def setup_env_vars(self, input_map: "dict[str, str]") -> None:
+        """Sets every key/value pair of ``input_map`` in ``os.environ``."""
+        os.environ.update(input_map)
+
+    def delete_env_vars(self, input_map: "dict[str, str]") -> None:
+        """Removes every key of ``input_map`` from ``os.environ`` (KeyError if one is unset)."""
+        for key in input_map:
+            del os.environ[key]
+
     def test_get_num_tasks_locally(self) -> None:
         succeeds = [1, 2]
         fails = [4, 8]
@@ -133,6 +219,42 @@ class TorchDistributorLocalUnitTests(unittest.TestCase):
                     distributor = TorchDistributor(num_processes, True, True)
                     distributor.num_processes = 3
 
+    def test_get_gpus_owned_local(self) -> None:
+        addresses = ["0", "1", "2"]
+        # patch.dict restores os.environ on exit, even when an assertion fails, so
+        # CUDA_VISIBLE_DEVICES cannot leak into other tests.
+        with patch.dict(os.environ):
+            # Without the variable, the Spark-assigned addresses are returned as-is.
+            os.environ.pop("CUDA_VISIBLE_DEVICES", None)
+            self.assertEqual(get_gpus_owned(self.sc), addresses)
+
+        with patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "3,4,5"}):
+            self.assertEqual(get_gpus_owned(self.sc), ["3", "4", "5"])
+
+    def test_local_training_succeeds(self) -> None:
+        CUDA_VISIBLE_DEVICES = "CUDA_VISIBLE_DEVICES"
+        # NOTE(review): the expected GPU orders depend on hash("train.py"), which is only
+        # stable across runs when PYTHONHASHSEED is fixed -- confirm the test runner sets it.
+        inputs = [
+            ("0,1,2", 1, True, "1"),
+            ("0,1,2", 3, True, "1,2,0"),
+            ("0,1,2", 2, False, "0,1,2"),
+            (None, 3, False, "NONE"),
+        ]
+
+        for i, (cuda_env_var, num_processes, use_gpu, expected) in enumerate(inputs):
+            # patch.dict restores os.environ when the block exits, even if the subtest
+            # fails, so one case cannot leak CUDA_VISIBLE_DEVICES into the next.
+            with self.subTest(f"subtest: {i + 1}"), patch.dict(os.environ):
+                if cuda_env_var:
+                    os.environ[CUDA_VISIBLE_DEVICES] = cuda_env_var
+                else:
+                    # The "NONE" expectation requires the variable to be unset.
+                    os.environ.pop(CUDA_VISIBLE_DEVICES, None)
+
+                dist = TorchDistributor(num_processes, True, use_gpu)
+                dist._run_training_on_pytorch_file = lambda *args: os.environ.get(  # type: ignore
+                    CUDA_VISIBLE_DEVICES, "NONE"
+                )
+                self.assertEqual(
+                    expected,
+                    dist._run_local_training(dist._run_training_on_pytorch_file, "train.py"),
+                )
+
 
 class TorchDistributorDistributedUnitTests(unittest.TestCase):
     def setUp(self) -> None:
@@ -178,6 +300,29 @@ class 
TorchDistributorDistributedUnitTests(unittest.TestCase):
         self.spark.sparkContext._conf.set("spark.task.resource.gpu.amount", 
"1")
 
 
+class TorchWrapperUnitTests(unittest.TestCase):
+    def test_clean_and_terminate(self) -> None:
+        def kill_task(task: "subprocess.Popen") -> None:
+            time.sleep(1)
+            clean_and_terminate(task)
+
+        # The code is passed to -c unquoted: wrapping it in extra quotes would make the
+        # child evaluate a string literal and exit at once, so the test would pass even
+        # if clean_and_terminate did nothing.
+        command = [sys.executable, "-c", "import time; time.sleep(20)"]
+        task = subprocess.Popen(command)
+        try:
+            t = threading.Thread(target=kill_task, args=(task,))
+            t.start()
+            t.join(timeout=10)
+            # The child would otherwise sleep for 20s, so having a return code now means
+            # clean_and_terminate stopped it. It is stopped by a signal, so the code is
+            # not expected to be 0.
+            self.assertIsNotNone(task.poll())
+        finally:
+            if task.poll() is None:
+                task.kill()
+            task.wait()
+
+    @patch("pyspark.ml.torch.torch_run_process_wrapper.clean_and_terminate")
+    def test_check_parent_alive(self, mock_clean_and_terminate: Callable) -> None:
+        command = [sys.executable, "-c", "import time; time.sleep(2)"]
+        task = subprocess.Popen(command)
+        try:
+            t = threading.Thread(target=check_parent_alive, args=(task,), daemon=True)
+            t.start()
+            time.sleep(2)
+            # This test process's parent is still alive, so nothing may be cleaned up.
+            self.assertEqual(mock_clean_and_terminate.call_count, 0)  # type: ignore[attr-defined]
+        finally:
+            # Reap the short-lived child so it does not outlive the test.
+            task.wait()
+
+
 if __name__ == "__main__":
     from pyspark.ml.torch.tests.test_distributor import *  # noqa: F401,F403 
type: ignore
 
diff --git a/python/pyspark/ml/torch/torch_run_process_wrapper.py 
b/python/pyspark/ml/torch/torch_run_process_wrapper.py
new file mode 100644
index 00000000000..67ec492329d
--- /dev/null
+++ b/python/pyspark/ml/torch/torch_run_process_wrapper.py
@@ -0,0 +1,83 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import os
+import signal
+import subprocess
+import sys
+import threading
+import time
+from typing import Any
+
+
+def clean_and_terminate(task: "subprocess.Popen") -> None:
+    """Stops ``task``: terminate first, then kill if it is still running after 0.5s.
+
+    The process is waited on before returning, so its exit status is collected and
+    ``task.returncode`` is set.
+    """
+    task.terminate()
+    try:
+        # Returns as soon as the process exits instead of always sleeping 0.5s.
+        task.wait(timeout=0.5)
+    except subprocess.TimeoutExpired:
+        task.kill()
+        task.wait()
+    # TODO(SPARK-41775): Cleanup temp files
+
+
+def check_parent_alive(task: "subprocess.Popen") -> None:
+    """Every 0.5s, checks whether this process's parent id changed; once it has,
+    stops ``task`` with ``clean_and_terminate`` and returns.
+
+    A different parent id indicates that the original parent process has gone away.
+    """
+    parent_pid = os.getppid()
+    while os.getppid() == parent_pid:
+        time.sleep(0.5)
+    clean_and_terminate(task)
+
+
+if __name__ == "__main__":
+    # This is a wrapper around torch.distributed.run: it forwards its arguments to
+    # torch.distributed.run and kills that child process if the parent process fails,
+    # crashes, or exits. Its exit code is torchrun's exit code.
+
+    args = sys.argv[1:]
+
+    cmd = [sys.executable, "-m", "torch.distributed.run", *args]
+    task = subprocess.Popen(
+        cmd,
+        stdout=subprocess.PIPE,
+        stderr=subprocess.STDOUT,
+        stdin=subprocess.PIPE,
+        env=os.environ,
+    )
+    t = threading.Thread(target=check_parent_alive, args=(task,), daemon=True)
+
+    def sigterm_handler(*_: Any) -> None:
+        clean_and_terminate(task)
+        os._exit(0)
+
+    signal.signal(signal.SIGTERM, sigterm_handler)
+
+    t.start()
+    task.stdin.close()  # type: ignore[union-attr]
+    try:
+        for line in task.stdout:  # type: ignore[union-attr]
+            # Undecodable bytes must not abort forwarding torchrun's output.
+            decoded = line.decode(errors="replace")
+            # Flush so the reading parent receives output line by line.
+            print(decoded.rstrip(), flush=True)
+        task.wait()
+    finally:
+        if task.poll() is None:
+            try:
+                task.terminate()
+                time.sleep(0.5)
+                if task.poll() is None:
+                    task.kill()
+            except OSError:
+                pass
+    # Propagate torchrun's exit status. Exiting 0 unconditionally would make the
+    # caller's return-code check treat a failed training run as a success.
+    sys.exit(task.returncode)


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to