This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new e38a1b74356 [SPARK-41593][PYTHON][ML] Adding logging from executors
e38a1b74356 is described below

commit e38a1b7435649649a7de0dadd135c9f78e6b8099
Author: Rithwik Ediga Lakhamsani <[email protected]>
AuthorDate: Sat Jan 21 16:18:20 2023 +0900

    [SPARK-41593][PYTHON][ML] Adding logging from executors
    
    Check the last commit of this PR for the actual relevant diff, since this PR is batched in with other PRs.
    
    ### What changes were proposed in this pull request?
    
    I added support for logging from the executor nodes by using `socketserver` to create a server on the driver node and a client on each partition that sends logs to the server.
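
    For reference, a minimal sketch of the wire format used here (the helper names below are illustrative, not part of this change): each log line is framed as a 4-byte big-endian length prefix followed by the utf-8 payload.

    ```python
    import socket
    from struct import pack, unpack
    from typing import BinaryIO

    def send_log_line(sock: socket.socket, message: str) -> None:
        # Frame the message: 4-byte big-endian length, then the utf-8 bytes.
        payload = message.encode("utf-8")
        sock.sendall(pack(">i", len(payload)) + payload)

    def recv_log_line(rfile: BinaryIO) -> str:
        # rfile is a buffered file-like reader over the socket, as provided
        # by socketserver.StreamRequestHandler.
        (length,) = unpack(">i", rfile.read(4))
        return rfile.read(length).decode("utf-8")
    ```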
    
    ### Why are the changes needed?
    
    We want to make it easier for users to see logs from the executor nodes on the driver node.
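
    As a hypothetical usage sketch (`train.py` is a placeholder torchrun-style training script; `TorchDistributor` and its arguments are from this module), output printed by the training processes on the executors is now streamed back and shown on the driver:

    ```python
    from pyspark.ml.torch.distributor import TorchDistributor

    # Anything the training script prints on an executor is forwarded to the
    # driver by the LogStreamingClient added in this PR.
    TorchDistributor(num_processes=2, local_mode=False, use_gpu=False).run("train.py")
    ```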
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    A couple of unit tests were created.
    
    Closes #39299 from rithwik-db/add-logging.
    
    Authored-by: Rithwik Ediga Lakhamsani <[email protected]>
    Signed-off-by: Hyukjin Kwon <[email protected]>
---
 dev/sparktestsupport/modules.py                    |   1 +
 python/pyspark/ml/torch/distributor.py             |  81 +++++++--
 python/pyspark/ml/torch/log_communication.py       | 201 +++++++++++++++++++++
 python/pyspark/ml/torch/tests/test_distributor.py  |   5 +-
 .../ml/torch/tests/test_log_communication.py       | 176 ++++++++++++++++++
 5 files changed, 449 insertions(+), 15 deletions(-)

diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py
index ae41cae4696..634be286065 100644
--- a/dev/sparktestsupport/modules.py
+++ b/dev/sparktestsupport/modules.py
@@ -635,6 +635,7 @@ pyspark_ml = Module(
         "pyspark.ml.tests.test_util",
         "pyspark.ml.tests.test_wrapper",
         "pyspark.ml.torch.tests.test_distributor",
+        "pyspark.ml.torch.tests.test_log_communication",
     ],
     excluded_python_implementations=[
         "PyPy"  # Skip these tests under PyPy since they require numpy and it 
isn't available there
diff --git a/python/pyspark/ml/torch/distributor.py b/python/pyspark/ml/torch/distributor.py
index 51a1203cbd5..804e0d4a27a 100644
--- a/python/pyspark/ml/torch/distributor.py
+++ b/python/pyspark/ml/torch/distributor.py
@@ -16,6 +16,7 @@
 #
 
 import collections
+import logging
 import math
 import os
 import random
@@ -24,9 +25,13 @@ import sys
 import subprocess
 import time
 from typing import Union, Callable, List, Dict, Optional, Any
-import warnings
 
 from pyspark.sql import SparkSession
+from pyspark.ml.torch.log_communication import (  # type: ignore
+    get_driver_host,
+    LogStreamingClient,
+    LogStreamingServer,
+)
 from pyspark.context import SparkContext
 from pyspark.taskcontext import BarrierTaskContext
 
@@ -72,6 +77,19 @@ def get_conf_boolean(sc: SparkContext, key: str, default_value: str) -> bool:
     )
 
 
+def get_logger(name: str) -> logging.Logger:
+    """
+    Gets a logger by name, or creates and configures it for the first time.
+    """
+    logger = logging.getLogger(name)
+    logger.setLevel(logging.INFO)
+    # If this logger or the root logger already has handlers, skip configuration
+    if not logger.handlers and not logging.getLogger().handlers:
+        handler = logging.StreamHandler(sys.stderr)
+        logger.addHandler(handler)
+    return logger
+
+
 def get_gpus_owned(context: Union[SparkContext, BarrierTaskContext]) -> List[str]:
     """Gets the number of GPUs that Spark scheduled to the calling task.
 
@@ -122,6 +140,7 @@ class Distributor:
         local_mode: bool = True,
         use_gpu: bool = True,
     ):
+        self.logger = get_logger(self.__class__.__name__)
         self.num_processes = num_processes
         self.local_mode = local_mode
         self.use_gpu = use_gpu
@@ -134,7 +153,7 @@ class Distributor:
 
     def _create_input_params(self) -> Dict[str, Any]:
         input_params = self.__dict__.copy()
-        for unneeded_param in ["spark", "sc", "ssl_conf"]:
+        for unneeded_param in ["spark", "sc", "ssl_conf", "logger"]:
             del input_params[unneeded_param]
         return input_params
 
@@ -169,11 +188,10 @@ class Distributor:
                 if num_available_gpus == 0:
                     raise RuntimeError("GPU resources were not configured 
properly on the driver.")
                 if self.num_processes > num_available_gpus:
-                    warnings.warn(
+                    self.logger.warning(
                         f"'num_processes' cannot be set to a value greater 
than the number of "
                         f"available GPUs on the driver, which is 
{num_available_gpus}. "
                         f"'num_processes' was reset to be equal to the number 
of available GPUs.",
-                        RuntimeWarning,
                     )
                     self.num_processes = num_available_gpus
         return self.num_processes
@@ -201,7 +219,7 @@ class Distributor:
         if is_ssl_enabled:
             name = self.__class__.__name__
             if ignore_ssl:
-                warnings.warn(
+                self.logger.warning(
                     f"""
                     This cluster has TLS encryption enabled;
                     however, {name} does not
@@ -348,7 +366,10 @@ class TorchDistributor(Distributor):
 
     @staticmethod
     def _execute_command(
-        cmd: List[str], _prctl: bool = True, redirect_to_stdout: bool = True
+        cmd: List[str],
+        _prctl: bool = True,
+        redirect_to_stdout: bool = True,
+        log_streaming_client: Optional[LogStreamingClient] = None,
     ) -> None:
         _TAIL_LINES_TO_KEEP = 100
 
@@ -367,6 +388,8 @@ class TorchDistributor(Distributor):
                 tail.append(decoded)
                 if redirect_to_stdout:
                     sys.stdout.write(decoded)
+                if log_streaming_client:
+                    log_streaming_client.send(decoded.rstrip())
             task.wait()
         finally:
             if task.poll() is None:
@@ -404,7 +427,10 @@ class TorchDistributor(Distributor):
                 selected_gpus = [str(e) for e in random.sample(gpus_owned, self.num_processes)]
                 os.environ[CUDA_VISIBLE_DEVICES] = ",".join(selected_gpus)
 
+            self.logger.info(f"Started local training with 
{self.num_processes} processes")
             output = framework_wrapper_fn(self.input_params, train_object, 
*args)  # type: ignore
+            self.logger.info(f"Finished local training with 
{self.num_processes} processes")
+
         finally:
             if cuda_state_was_set:
                 os.environ[CUDA_VISIBLE_DEVICES] = old_cuda_visible_devices
@@ -438,6 +464,8 @@ class TorchDistributor(Distributor):
         num_processes = self.num_processes
         use_gpu = self.use_gpu
         input_params = self.input_params
+        driver_address = self.driver_address
+        log_streaming_server_port = self.log_streaming_server_port
 
         # Spark task program
         def wrapped_train_fn(_):  # type: ignore[no-untyped-def]
@@ -486,7 +514,15 @@ class TorchDistributor(Distributor):
                 os.environ[CUDA_VISIBLE_DEVICES] = ""
             set_torch_config(context)
 
-            output = framework_wrapper_fn(input_params, train_object, *args)
+            log_streaming_client = LogStreamingClient(driver_address, log_streaming_server_port)
+            input_params["log_streaming_client"] = log_streaming_client
+            try:
+                output = framework_wrapper_fn(input_params, train_object, *args)
+            finally:
+                try:
+                    LogStreamingClient._destroy()
+                except BaseException:
+                    pass
 
             if context.partitionId() == 0:
                 yield output
@@ -501,15 +537,31 @@ class TorchDistributor(Distributor):
     ) -> Optional[Any]:
         if not framework_wrapper_fn:
             raise RuntimeError("Unknown combination of parameters")
+
+        log_streaming_server = LogStreamingServer()
+        self.driver_address = get_driver_host(self.sc)
+        log_streaming_server.start(spark_host_address=self.driver_address)
+        time.sleep(1)  # wait for the server to start
+        self.log_streaming_server_port = log_streaming_server.port
+
         spark_task_function = self._get_spark_task_function(
             framework_wrapper_fn, train_object, *args
         )
         self._check_encryption()
-        result = (
-            self.sc.parallelize(range(self.num_tasks), self.num_tasks)
-            .barrier()
-            .mapPartitions(spark_task_function)
-            .collect()[0]
+        self.logger.info(
+            f"Started distributed training with {self.num_processes} executor 
proceses"
+        )
+        try:
+            result = (
+                self.sc.parallelize(range(self.num_tasks), self.num_tasks)
+                .barrier()
+                .mapPartitions(spark_task_function)
+                .collect()[0]
+            )
+        finally:
+            log_streaming_server.shutdown()
+        self.logger.info(
+            f"Finished distributed training with {self.num_processes} executor 
proceses"
         )
         return result
 
@@ -517,10 +569,13 @@ class TorchDistributor(Distributor):
     def _run_training_on_pytorch_file(
         input_params: Dict[str, Any], train_path: str, *args: Any
     ) -> None:
+        log_streaming_client = input_params.get("log_streaming_client", None)
         training_command = TorchDistributor._create_torchrun_command(
             input_params, train_path, *args
         )
-        TorchDistributor._execute_command(training_command)
+        TorchDistributor._execute_command(
+            training_command, log_streaming_client=log_streaming_client
+        )
 
     def run(self, train_object: Union[Callable, str], *args: Any) -> Optional[Any]:
         """Runs distributed training.
diff --git a/python/pyspark/ml/torch/log_communication.py b/python/pyspark/ml/torch/log_communication.py
new file mode 100644
index 00000000000..e269aed42d1
--- /dev/null
+++ b/python/pyspark/ml/torch/log_communication.py
@@ -0,0 +1,201 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# type: ignore
+
+from contextlib import closing
+import time
+import socket
+import socketserver
+from struct import pack, unpack
+import sys
+import threading
+import traceback
+from typing import Optional, Generator
+import warnings
+from pyspark.context import SparkContext
+
+# Messages are framed with a 4-byte big-endian length prefix rather than a separator
+# byte, because the utf-8 encoded payload itself may contain newline bytes.
+_SERVER_POLL_INTERVAL = 0.1
+_TRUNCATE_MSG_LEN = 4000
+
+
+def get_driver_host(sc: SparkContext) -> Optional[str]:
+    return sc.getConf().get("spark.driver.host")
+
+
+_log_print_lock = threading.Lock()  # pylint: disable=invalid-name
+
+
+def _get_log_print_lock() -> threading.Lock:
+    return _log_print_lock
+
+
+class WriteLogToStdout(socketserver.StreamRequestHandler):
+    def _read_bline(self) -> Generator[bytes, None, None]:
+        while self.server.is_active:
+            packed_number_bytes = self.rfile.read(4)
+            if not packed_number_bytes:
+                time.sleep(_SERVER_POLL_INTERVAL)
+                continue
+            number_bytes = unpack(">i", packed_number_bytes)[0]
+            message = self.rfile.read(number_bytes)
+            yield message
+
+    def handle(self) -> None:
+        self.request.setblocking(0)  # non-blocking mode
+        for bline in self._read_bline():
+            with _get_log_print_lock():
+                sys.stderr.write(bline.decode("utf-8") + "\n")
+                sys.stderr.flush()
+
+
+# The server side of log streaming, which runs on the driver
+class LogStreamingServer:
+    def __init__(self) -> None:
+        self.server = None
+        self.serve_thread = None
+        self.port = None
+
+    @staticmethod
+    def _get_free_port(spark_host_address: str = "") -> int:
+        with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as tcp:
+            tcp.bind((spark_host_address, 0))
+            _, port = tcp.getsockname()
+        return port
+
+    def start(self, spark_host_address: str = "") -> None:
+        if self.server:
+            raise RuntimeError("Cannot start the server twice.")
+
+        def serve_task(port: int) -> None:
+            with socketserver.ThreadingTCPServer(("0.0.0.0", port), WriteLogToStdout) as server:
+                self.server = server
+                server.is_active = True
+                server.serve_forever(poll_interval=_SERVER_POLL_INTERVAL)
+
+        self.port = LogStreamingServer._get_free_port(spark_host_address)
+        self.serve_thread = threading.Thread(target=serve_task, args=(self.port,))
+        self.serve_thread.daemon = True
+        self.serve_thread.start()
+
+    def shutdown(self) -> None:
+        if self.server:
+            # Sleep to ensure all logs have been received and printed.
+            time.sleep(_SERVER_POLL_INTERVAL * 2)
+            # Flush stdout before closing to ensure all buffered output is printed.
+            sys.stdout.flush()
+            self.server.is_active = False
+            self.server.shutdown()
+            self.serve_thread.join()
+            self.server = None
+            self.serve_thread = None
+
+
+class LogStreamingClientBase:
+    @staticmethod
+    def _maybe_truncate_msg(message: str) -> str:
+        if len(message) > _TRUNCATE_MSG_LEN:
+            message = message[:_TRUNCATE_MSG_LEN]
+            return message + "...(truncated)"
+        else:
+            return message
+
+    def send(self, message: str) -> None:
+        pass
+
+    def close(self) -> None:
+        pass
+
+
+class LogStreamingClient(LogStreamingClientBase):
+    """
+    A client that streams log messages to :class:`LogStreamingServer`.
+    In case of failures, the client will skip messages instead of raising an error.
+    """
+
+    _log_callback_client = None
+    _server_address = None
+    _singleton_lock = threading.Lock()
+
+    @staticmethod
+    def _init(address: str, port: int) -> None:
+        LogStreamingClient._server_address = (address, port)
+
+    @staticmethod
+    def _destroy() -> None:
+        LogStreamingClient._server_address = None
+        if LogStreamingClient._log_callback_client is not None:
+            LogStreamingClient._log_callback_client.close()
+
+    def __init__(self, address: str, port: int, timeout: int = 10):
+        """
+        Creates a client for the logging server; the connection is established
+        lazily on the first send. This client is best effort: if connecting or
+        sending a message fails, the client is marked as failed and stops
+        trying to send messages.
+
+        :param address: Address where the service is running.
+        :param port: Port where the service is listening for new connections.
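+        :param timeout: Socket timeout in seconds for the connection.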
+        """
+        self.address = address
+        self.port = port
+        self.timeout = timeout
+        self.sock = None
+        self.failed = True
+        self._lock = threading.RLock()
+
+    def _fail(self, error_msg: str) -> None:
+        self.failed = True
+        warnings.warn(f"{error_msg}: {traceback.format_exc()}\n")
+
+    def _connect(self) -> None:
+        try:
+            sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+            sock.settimeout(self.timeout)
+            sock.connect((self.address, self.port))
+            sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
+            self.sock = sock
+            self.failed = False
+        except (OSError, IOError):  # pylint: disable=broad-except
+            self._fail("Error connecting log streaming server")
+
+    def send(self, message: str) -> None:
+        """
+        Sends a message.
+        """
+        with self._lock:
+            if self.sock is None:
+                self._connect()
+            if not self.failed:
+                try:
+                    message = LogStreamingClientBase._maybe_truncate_msg(message)
+                    # TODO:
+                    #  1) addressing issue: idle TCP connection might get disconnected by
+                    #     cloud provider
+                    #  2) sendall may block when server is busy handling data.
+                    binary_message = message.encode("utf-8")
+                    packed_number_bytes = pack(">i", len(binary_message))
+                    self.sock.sendall(packed_number_bytes + binary_message)
+                except Exception:  # pylint: disable=broad-except
+                    self._fail("Error sending logs to driver, stopping log 
streaming")
+
+    def close(self) -> None:
+        """
+        Closes the connection.
+        """
+        if self.sock:
+            self.sock.close()
+            self.sock = None
diff --git a/python/pyspark/ml/torch/tests/test_distributor.py b/python/pyspark/ml/torch/tests/test_distributor.py
index 607cc7cd1ad..619e733c0bb 100644
--- a/python/pyspark/ml/torch/tests/test_distributor.py
+++ b/python/pyspark/ml/torch/tests/test_distributor.py
@@ -248,9 +248,10 @@ class TorchDistributorLocalUnitTests(unittest.TestCase):
 
         for num_processes in fails:
             with self.subTest():
-                with self.assertWarns(RuntimeWarning):
+                with self.assertLogs("TorchDistributor", level="WARNING") as log:
                     distributor = TorchDistributor(num_processes, True, True)
-                    distributor.num_processes = 3
+                    self.assertEqual(len(log.records), 1)
+                    self.assertEqual(distributor.num_processes, 3)
 
     def test_get_gpus_owned_local(self) -> None:
         addresses = ["0", "1", "2"]
diff --git a/python/pyspark/ml/torch/tests/test_log_communication.py b/python/pyspark/ml/torch/tests/test_log_communication.py
new file mode 100644
index 00000000000..0c937926480
--- /dev/null
+++ b/python/pyspark/ml/torch/tests/test_log_communication.py
@@ -0,0 +1,176 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from __future__ import absolute_import, division, print_function
+
+import contextlib
+from io import StringIO
+import sys
+import time
+from typing import Any, Callable
+import unittest
+
+import pyspark.ml.torch.log_communication
+from pyspark.ml.torch.log_communication import (  # type: ignore
+    LogStreamingServer,
+    LogStreamingClient,
+    LogStreamingClientBase,
+    _SERVER_POLL_INTERVAL,
+)
+
+
+@contextlib.contextmanager
+def patch_stderr() -> StringIO:
+    """patch stdout and give an output"""
+    sys_stderr = sys.stderr
+    io_out = StringIO()
+    sys.stderr = io_out
+    try:
+        yield io_out
+    finally:
+        sys.stderr = sys_stderr
+
+
+class LogStreamingServiceTestCase(unittest.TestCase):
+    def setUp(self) -> None:
+        self.default_truncate_msg_len = (
+            pyspark.ml.torch.log_communication._TRUNCATE_MSG_LEN  # type: ignore
+        )
+        pyspark.ml.torch.log_communication._TRUNCATE_MSG_LEN = 10  # type: ignore
+
+    def tearDown(self) -> None:
+        pyspark.ml.torch.log_communication._TRUNCATE_MSG_LEN = (  # type: ignore
+            self.default_truncate_msg_len
+        )
+
+    def test_basic(self) -> None:
+        server = LogStreamingServer()
+        server.start()
+        time.sleep(1)
+        client = LogStreamingClient("localhost", server.port)
+        with patch_stderr() as output:
+            client.send("msg 001")
+            client.send("msg 002")
+            time.sleep(_SERVER_POLL_INTERVAL + 1)
+            output = output.getvalue()
+            self.assertIn("msg 001\nmsg 002\n", output)
+        client.close()
+        server.shutdown()
+
+    def test_truncate_message(self) -> None:
+        msg1 = "abc"
+        assert LogStreamingClientBase._maybe_truncate_msg(msg1) == msg1
+        msg2 = "abcdefghijkl"
+        assert LogStreamingClientBase._maybe_truncate_msg(msg2) == "abcdefghij...(truncated)"
+
+    def test_multiple_clients(self) -> None:
+        server = LogStreamingServer()
+        server.start()
+        time.sleep(1)
+        client1 = LogStreamingClient("localhost", server.port)
+        client2 = LogStreamingClient("localhost", server.port)
+        with patch_stderr() as output:
+            client1.send("c1 msg1")
+            time.sleep(_SERVER_POLL_INTERVAL + 1)
+            client2.send("c2 msg1")
+            time.sleep(_SERVER_POLL_INTERVAL + 1)
+            client1.send("c1 msg2")
+            time.sleep(_SERVER_POLL_INTERVAL + 1)
+            client2.send("c2 msg2")
+            time.sleep(_SERVER_POLL_INTERVAL + 1)
+            output = output.getvalue()
+            self.assertIn("c1 msg1\nc2 msg1\nc1 msg2\nc2 msg2\n", output)
+        client1.close()
+        client2.close()
+        server.shutdown()
+
+    def test_client_should_fail_gracefully(self) -> None:
+        server = LogStreamingServer()
+        server.start()
+        time.sleep(1)
+        client = LogStreamingClient("localhost", server.port)
+        client.send("msg 001")
+        server.shutdown()
+        for i in range(5):
+            client.send("msg 002")
+            time.sleep(_SERVER_POLL_INTERVAL + 1)
+        self.assertTrue(client.failed)
+        client.close()
+
+    def test_client_send_intermittently(self) -> None:
+        server = LogStreamingServer()
+        server.start()
+        time.sleep(1)
+        client = LogStreamingClient("localhost", server.port)
+        with patch_stderr() as output:
+            client._connect()
+            # the client sends the first part of a message
+            client.send("msg part1")
+            time.sleep(_SERVER_POLL_INTERVAL + 1)
+            # then sends the second part after a delay
+            client.send(" msg part2")
+            time.sleep(_SERVER_POLL_INTERVAL + 1)
+            output = output.getvalue()
+            self.assertIn("msg part1\n msg part2\n", output)
+        client.close()
+        server.shutdown()
+
+    @staticmethod
+    def test_server_shutdown() -> None:
+        def run_test(client_ops: Callable) -> None:
+            server = LogStreamingServer()
+            server.start()
+            time.sleep(1)
+            client = LogStreamingClient("localhost", server.port)
+            client_ops(client)
+            server.shutdown()
+            client.close()
+
+        def client_ops_close(client: Any) -> None:
+            client.close()
+
+        def client_ops_send_half_msg(client: Any) -> None:
+            # Test that the server can exit after receiving only an incomplete message.
+            client._connect()
+            client.sock.sendall(b"msg part1 ")
+            time.sleep(_SERVER_POLL_INTERVAL + 1)
+
+        def client_ops_send_a_msg(client: Any) -> None:
+            client.send("msg1")
+            time.sleep(_SERVER_POLL_INTERVAL + 1)
+
+        def client_ops_send_a_msg_and_close(client: Any) -> None:
+            client.send("msg1")
+            client.close()
+            time.sleep(_SERVER_POLL_INTERVAL + 1)
+
+        run_test(client_ops_close)
+        run_test(client_ops_send_half_msg)
+        run_test(client_ops_send_a_msg)
+        run_test(client_ops_send_a_msg_and_close)
+
+
+if __name__ == "__main__":
+    from pyspark.ml.torch.tests.test_log_communication import *  # noqa: F401,F403  # type: ignore
+
+    try:
+        import xmlrunner  # type: ignore
+
+        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
+    except ImportError:
+        testRunner = None
+    unittest.main(testRunner=testRunner, verbosity=2)


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
