This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new e38a1b74356 [SPARK-41593][PYTHON][ML] Adding logging from executors
e38a1b74356 is described below
commit e38a1b7435649649a7de0dadd135c9f78e6b8099
Author: Rithwik Ediga Lakhamsani <[email protected]>
AuthorDate: Sat Jan 21 16:18:20 2023 +0900
[SPARK-41593][PYTHON][ML] Adding logging from executors
Check the last commit of this PR for the actual relevant diff since this PR
is batched in with other PRs.
### What changes were proposed in this pull request?
I added support for logging from the executor nodes by using `socketserver`
to create a server on the driver node and clients on each of the partitions
that sends logs to the server.
### Why are the changes needed?
We want to make it easier for users to see data from the executor nodes on
the driver nodes.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
A couple of unit tests were created.
Closes #39299 from rithwik-db/add-logging.
Authored-by: Rithwik Ediga Lakhamsani <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
---
dev/sparktestsupport/modules.py | 1 +
python/pyspark/ml/torch/distributor.py | 81 +++++++--
python/pyspark/ml/torch/log_communication.py | 201 +++++++++++++++++++++
python/pyspark/ml/torch/tests/test_distributor.py | 5 +-
.../ml/torch/tests/test_log_communication.py | 176 ++++++++++++++++++
5 files changed, 449 insertions(+), 15 deletions(-)
diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py
index ae41cae4696..634be286065 100644
--- a/dev/sparktestsupport/modules.py
+++ b/dev/sparktestsupport/modules.py
@@ -635,6 +635,7 @@ pyspark_ml = Module(
"pyspark.ml.tests.test_util",
"pyspark.ml.tests.test_wrapper",
"pyspark.ml.torch.tests.test_distributor",
+ "pyspark.ml.torch.tests.test_log_communication",
],
excluded_python_implementations=[
"PyPy" # Skip these tests under PyPy since they require numpy and it
isn't available there
diff --git a/python/pyspark/ml/torch/distributor.py
b/python/pyspark/ml/torch/distributor.py
index 51a1203cbd5..804e0d4a27a 100644
--- a/python/pyspark/ml/torch/distributor.py
+++ b/python/pyspark/ml/torch/distributor.py
@@ -16,6 +16,7 @@
#
import collections
+import logging
import math
import os
import random
@@ -24,9 +25,13 @@ import sys
import subprocess
import time
from typing import Union, Callable, List, Dict, Optional, Any
-import warnings
from pyspark.sql import SparkSession
+from pyspark.ml.torch.log_communication import ( # type: ignore
+ get_driver_host,
+ LogStreamingClient,
+ LogStreamingServer,
+)
from pyspark.context import SparkContext
from pyspark.taskcontext import BarrierTaskContext
@@ -72,6 +77,19 @@ def get_conf_boolean(sc: SparkContext, key: str,
default_value: str) -> bool:
)
+def get_logger(name: str) -> logging.Logger:
+ """
+ Gets a logger by name, or creates and configures it for the first time.
+ """
+ logger = logging.getLogger(name)
+ logger.setLevel(logging.INFO)
+ # If the logger is configured, skip the configure
+ if not logger.handlers and not logging.getLogger().handlers:
+ handler = logging.StreamHandler(sys.stderr)
+ logger.addHandler(handler)
+ return logger
+
+
def get_gpus_owned(context: Union[SparkContext, BarrierTaskContext]) ->
List[str]:
"""Gets the number of GPUs that Spark scheduled to the calling task.
@@ -122,6 +140,7 @@ class Distributor:
local_mode: bool = True,
use_gpu: bool = True,
):
+ self.logger = get_logger(self.__class__.__name__)
self.num_processes = num_processes
self.local_mode = local_mode
self.use_gpu = use_gpu
@@ -134,7 +153,7 @@ class Distributor:
def _create_input_params(self) -> Dict[str, Any]:
input_params = self.__dict__.copy()
- for unneeded_param in ["spark", "sc", "ssl_conf"]:
+ for unneeded_param in ["spark", "sc", "ssl_conf", "logger"]:
del input_params[unneeded_param]
return input_params
@@ -169,11 +188,10 @@ class Distributor:
if num_available_gpus == 0:
raise RuntimeError("GPU resources were not configured
properly on the driver.")
if self.num_processes > num_available_gpus:
- warnings.warn(
+ self.logger.warning(
f"'num_processes' cannot be set to a value greater
than the number of "
f"available GPUs on the driver, which is
{num_available_gpus}. "
f"'num_processes' was reset to be equal to the number
of available GPUs.",
- RuntimeWarning,
)
self.num_processes = num_available_gpus
return self.num_processes
@@ -201,7 +219,7 @@ class Distributor:
if is_ssl_enabled:
name = self.__class__.__name__
if ignore_ssl:
- warnings.warn(
+ self.logger.warning(
f"""
This cluster has TLS encryption enabled;
however, {name} does not
@@ -348,7 +366,10 @@ class TorchDistributor(Distributor):
@staticmethod
def _execute_command(
- cmd: List[str], _prctl: bool = True, redirect_to_stdout: bool = True
+ cmd: List[str],
+ _prctl: bool = True,
+ redirect_to_stdout: bool = True,
+ log_streaming_client: Optional[LogStreamingClient] = None,
) -> None:
_TAIL_LINES_TO_KEEP = 100
@@ -367,6 +388,8 @@ class TorchDistributor(Distributor):
tail.append(decoded)
if redirect_to_stdout:
sys.stdout.write(decoded)
+ if log_streaming_client:
+ log_streaming_client.send(decoded.rstrip())
task.wait()
finally:
if task.poll() is None:
@@ -404,7 +427,10 @@ class TorchDistributor(Distributor):
selected_gpus = [str(e) for e in random.sample(gpus_owned,
self.num_processes)]
os.environ[CUDA_VISIBLE_DEVICES] = ",".join(selected_gpus)
+ self.logger.info(f"Started local training with
{self.num_processes} processes")
output = framework_wrapper_fn(self.input_params, train_object,
*args) # type: ignore
+ self.logger.info(f"Finished local training with
{self.num_processes} processes")
+
finally:
if cuda_state_was_set:
os.environ[CUDA_VISIBLE_DEVICES] = old_cuda_visible_devices
@@ -438,6 +464,8 @@ class TorchDistributor(Distributor):
num_processes = self.num_processes
use_gpu = self.use_gpu
input_params = self.input_params
+ driver_address = self.driver_address
+ log_streaming_server_port = self.log_streaming_server_port
# Spark task program
def wrapped_train_fn(_): # type: ignore[no-untyped-def]
@@ -486,7 +514,15 @@ class TorchDistributor(Distributor):
os.environ[CUDA_VISIBLE_DEVICES] = ""
set_torch_config(context)
- output = framework_wrapper_fn(input_params, train_object, *args)
+ log_streaming_client = LogStreamingClient(driver_address,
log_streaming_server_port)
+ input_params["log_streaming_client"] = log_streaming_client
+ try:
+ output = framework_wrapper_fn(input_params, train_object,
*args)
+ finally:
+ try:
+ LogStreamingClient._destroy()
+ except BaseException:
+ pass
if context.partitionId() == 0:
yield output
@@ -501,15 +537,31 @@ class TorchDistributor(Distributor):
) -> Optional[Any]:
if not framework_wrapper_fn:
raise RuntimeError("Unknown combination of parameters")
+
+ log_streaming_server = LogStreamingServer()
+ self.driver_address = get_driver_host(self.sc)
+ log_streaming_server.start(spark_host_address=self.driver_address)
+ time.sleep(1) # wait for the server to start
+ self.log_streaming_server_port = log_streaming_server.port
+
spark_task_function = self._get_spark_task_function(
framework_wrapper_fn, train_object, *args
)
self._check_encryption()
- result = (
- self.sc.parallelize(range(self.num_tasks), self.num_tasks)
- .barrier()
- .mapPartitions(spark_task_function)
- .collect()[0]
+ self.logger.info(
+            f"Started distributed training with {self.num_processes} executor
processes"
+ )
+ try:
+ result = (
+ self.sc.parallelize(range(self.num_tasks), self.num_tasks)
+ .barrier()
+ .mapPartitions(spark_task_function)
+ .collect()[0]
+ )
+ finally:
+ log_streaming_server.shutdown()
+ self.logger.info(
+            f"Finished distributed training with {self.num_processes} executor
processes"
)
return result
@@ -517,10 +569,13 @@ class TorchDistributor(Distributor):
def _run_training_on_pytorch_file(
input_params: Dict[str, Any], train_path: str, *args: Any
) -> None:
+ log_streaming_client = input_params.get("log_streaming_client", None)
training_command = TorchDistributor._create_torchrun_command(
input_params, train_path, *args
)
- TorchDistributor._execute_command(training_command)
+ TorchDistributor._execute_command(
+ training_command, log_streaming_client=log_streaming_client
+ )
def run(self, train_object: Union[Callable, str], *args: Any) ->
Optional[Any]:
"""Runs distributed training.
diff --git a/python/pyspark/ml/torch/log_communication.py
b/python/pyspark/ml/torch/log_communication.py
new file mode 100644
index 00000000000..e269aed42d1
--- /dev/null
+++ b/python/pyspark/ml/torch/log_communication.py
@@ -0,0 +1,201 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# type: ignore
+
+from contextlib import closing
+import time
+import socket
+import socketserver
+from struct import pack, unpack
+import sys
+import threading
+import traceback
+from typing import Optional, Generator
+import warnings
+from pyspark.context import SparkContext
+
+# Use b'\x00' as separator instead of b'\n', because the bytes are encoded in
utf-8
+_SERVER_POLL_INTERVAL = 0.1
+_TRUNCATE_MSG_LEN = 4000
+
+
+def get_driver_host(sc: SparkContext) -> Optional[str]:
+ return sc.getConf().get("spark.driver.host")
+
+
+_log_print_lock = threading.Lock() # pylint: disable=invalid-name
+
+
+def _get_log_print_lock() -> threading.Lock:
+ return _log_print_lock
+
+
+class WriteLogToStdout(socketserver.StreamRequestHandler):
+ def _read_bline(self) -> Generator[bytes, None, None]:
+ while self.server.is_active:
+ packed_number_bytes = self.rfile.read(4)
+ if not packed_number_bytes:
+ time.sleep(_SERVER_POLL_INTERVAL)
+ continue
+ number_bytes = unpack(">i", packed_number_bytes)[0]
+ message = self.rfile.read(number_bytes)
+ yield message
+
+ def handle(self) -> None:
+ self.request.setblocking(0) # non-blocking mode
+ for bline in self._read_bline():
+ with _get_log_print_lock():
+ sys.stderr.write(bline.decode("utf-8") + "\n")
+ sys.stderr.flush()
+
+
+# What is run on the local driver
+class LogStreamingServer:
+ def __init__(self) -> None:
+ self.server = None
+ self.serve_thread = None
+ self.port = None
+
+ @staticmethod
+ def _get_free_port(spark_host_address: str = "") -> int:
+ with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as tcp:
+ tcp.bind((spark_host_address, 0))
+ _, port = tcp.getsockname()
+ return port
+
+ def start(self, spark_host_address: str = "") -> None:
+ if self.server:
+ raise RuntimeError("Cannot start the server twice.")
+
+ def serve_task(port: int) -> None:
+ with socketserver.ThreadingTCPServer(("0.0.0.0", port),
WriteLogToStdout) as server:
+ self.server = server
+ server.is_active = True
+ server.serve_forever(poll_interval=_SERVER_POLL_INTERVAL)
+
+ self.port = LogStreamingServer._get_free_port(spark_host_address)
+ self.serve_thread = threading.Thread(target=serve_task,
args=(self.port,))
+ self.serve_thread.setDaemon(True)
+ self.serve_thread.start()
+
+ def shutdown(self) -> None:
+ if self.server:
+            # Sleep to ensure all logs have been received and printed.
+ time.sleep(_SERVER_POLL_INTERVAL * 2)
+            # Before closing, flush to ensure the entire stdout buffer has been
printed.
+ sys.stdout.flush()
+ self.server.is_active = False
+ self.server.shutdown()
+ self.serve_thread.join()
+ self.server = None
+ self.serve_thread = None
+
+
+class LogStreamingClientBase:
+ @staticmethod
+ def _maybe_truncate_msg(message: str) -> str:
+ if len(message) > _TRUNCATE_MSG_LEN:
+ message = message[:_TRUNCATE_MSG_LEN]
+ return message + "...(truncated)"
+ else:
+ return message
+
+ def send(self, message: str) -> None:
+ pass
+
+ def close(self) -> None:
+ pass
+
+
+class LogStreamingClient(LogStreamingClientBase):
+ """
+ A client that streams log messages to :class:`LogStreamingServer`.
+ In case of failures, the client will skip messages instead of raising an
error.
+ """
+
+ _log_callback_client = None
+ _server_address = None
+ _singleton_lock = threading.Lock()
+
+ @staticmethod
+ def _init(address: str, port: int) -> None:
+ LogStreamingClient._server_address = (address, port)
+
+ @staticmethod
+ def _destroy() -> None:
+ LogStreamingClient._server_address = None
+ if LogStreamingClient._log_callback_client is not None:
+ LogStreamingClient._log_callback_client.close()
+
+ def __init__(self, address: str, port: int, timeout: int = 10):
+ """
+        Creates a connection to the logging server and authenticates. This
client is best effort,
+ if authentication or sending a message fails, the client will be
marked as not alive and
+ stop trying to send message.
+
+ :param address: Address where the service is running.
+ :param port: Port where the service is listening for new connections.
+ """
+ self.address = address
+ self.port = port
+ self.timeout = timeout
+ self.sock = None
+ self.failed = True
+ self._lock = threading.RLock()
+
+ def _fail(self, error_msg: str) -> None:
+ self.failed = True
+ warnings.warn(f"{error_msg}: {traceback.format_exc()}\n")
+
+ def _connect(self) -> None:
+ try:
+ sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ sock.settimeout(self.timeout)
+ sock.connect((self.address, self.port))
+ sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
+ self.sock = sock
+ self.failed = False
+ except (OSError, IOError): # pylint: disable=broad-except
+ self._fail("Error connecting log streaming server")
+
+ def send(self, message: str) -> None:
+ """
+ Sends a message.
+ """
+ with self._lock:
+ if self.sock is None:
+ self._connect()
+ if not self.failed:
+ try:
+ message =
LogStreamingClientBase._maybe_truncate_msg(message)
+ # TODO:
+ # 1) addressing issue: idle TCP connection might get
disconnected by
+ # cloud provider
+ # 2) sendall may block when server is busy handling data.
+ binary_message = message.encode("utf-8")
+ packed_number_bytes = pack(">i", len(binary_message))
+ self.sock.sendall(packed_number_bytes + binary_message)
+ except Exception: # pylint: disable=broad-except
+ self._fail("Error sending logs to driver, stopping log
streaming")
+
+ def close(self) -> None:
+ """
+ Closes the connection.
+ """
+ if self.sock:
+ self.sock.close()
+ self.sock = None
diff --git a/python/pyspark/ml/torch/tests/test_distributor.py
b/python/pyspark/ml/torch/tests/test_distributor.py
index 607cc7cd1ad..619e733c0bb 100644
--- a/python/pyspark/ml/torch/tests/test_distributor.py
+++ b/python/pyspark/ml/torch/tests/test_distributor.py
@@ -248,9 +248,10 @@ class TorchDistributorLocalUnitTests(unittest.TestCase):
for num_processes in fails:
with self.subTest():
- with self.assertWarns(RuntimeWarning):
+ with self.assertLogs("TorchDistributor", level="WARNING") as
log:
distributor = TorchDistributor(num_processes, True, True)
- distributor.num_processes = 3
+ self.assertEqual(len(log.records), 1)
+ self.assertEqual(distributor.num_processes, 3)
def test_get_gpus_owned_local(self) -> None:
addresses = ["0", "1", "2"]
diff --git a/python/pyspark/ml/torch/tests/test_log_communication.py
b/python/pyspark/ml/torch/tests/test_log_communication.py
new file mode 100644
index 00000000000..0c937926480
--- /dev/null
+++ b/python/pyspark/ml/torch/tests/test_log_communication.py
@@ -0,0 +1,176 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from __future__ import absolute_import, division, print_function
+
+import contextlib
+from six import StringIO # type: ignore
+import sys
+import time
+from typing import Any, Callable
+import unittest
+
+import pyspark.ml.torch.log_communication
+from pyspark.ml.torch.log_communication import ( # type: ignore
+ LogStreamingServer,
+ LogStreamingClient,
+ LogStreamingClientBase,
+ _SERVER_POLL_INTERVAL,
+)
+
+
[email protected]
+def patch_stderr() -> StringIO:
+    """patch stderr and give an output"""
+ sys_stderr = sys.stderr
+ io_out = StringIO()
+ sys.stderr = io_out
+ try:
+ yield io_out
+ finally:
+ sys.stderr = sys_stderr
+
+
+class LogStreamingServiceTestCase(unittest.TestCase):
+ def setUp(self) -> None:
+ self.default_truncate_msg_len = (
+ pyspark.ml.torch.log_communication._TRUNCATE_MSG_LEN # type:
ignore
+ )
+ pyspark.ml.torch.log_communication._TRUNCATE_MSG_LEN = 10 # type:
ignore
+
+ def tearDown(self) -> None:
+ pyspark.ml.torch.log_communication._TRUNCATE_MSG_LEN = ( # type:
ignore
+ self.default_truncate_msg_len
+ )
+
+ def basic_test(self) -> None:
+ server = LogStreamingServer()
+ server.start()
+ time.sleep(1)
+ client = LogStreamingClient("localhost", server.port)
+ with patch_stderr() as output:
+ client.send("msg 001")
+ client.send("msg 002")
+ time.sleep(_SERVER_POLL_INTERVAL + 1)
+ output = output.getvalue()
+ self.assertIn("msg 001\nmsg 002\n", output)
+ client.close()
+ server.shutdown()
+
+ def test_truncate_message(self) -> None:
+ msg1 = "abc"
+ assert LogStreamingClientBase._maybe_truncate_msg(msg1) == msg1
+ msg2 = "abcdefghijkl"
+ assert LogStreamingClientBase._maybe_truncate_msg(msg2) ==
"abcdefghij...(truncated)"
+
+ def test_multiple_clients(self) -> None:
+ server = LogStreamingServer()
+ server.start()
+ time.sleep(1)
+ client1 = LogStreamingClient("localhost", server.port)
+ client2 = LogStreamingClient("localhost", server.port)
+ with patch_stderr() as output:
+ client1.send("c1 msg1")
+ time.sleep(_SERVER_POLL_INTERVAL + 1)
+ client2.send("c2 msg1")
+ time.sleep(_SERVER_POLL_INTERVAL + 1)
+ client1.send("c1 msg2")
+ time.sleep(_SERVER_POLL_INTERVAL + 1)
+ client2.send("c2 msg2")
+ time.sleep(_SERVER_POLL_INTERVAL + 1)
+ output = output.getvalue()
+ self.assertIn("c1 msg1\nc2 msg1\nc1 msg2\nc2 msg2\n", output)
+ client1.close()
+ client2.close()
+ server.shutdown()
+
+ def test_client_should_fail_gracefully(self) -> None:
+ server = LogStreamingServer()
+ server.start()
+ time.sleep(1)
+ client = LogStreamingClient("localhost", server.port)
+ client.send("msg 001")
+ server.shutdown()
+ for i in range(5):
+ client.send("msg 002")
+ time.sleep(_SERVER_POLL_INTERVAL + 1)
+ self.assertTrue(client.failed)
+ client.close()
+
+ def test_client_send_intermittently(self) -> None:
+ server = LogStreamingServer()
+ server.start()
+ time.sleep(1)
+ client = LogStreamingClient("localhost", server.port)
+ with patch_stderr() as output:
+ client._connect()
+ # test client send half message first
+ client.send("msg part1")
+ time.sleep(_SERVER_POLL_INTERVAL + 1)
+ # test client send another half message
+ client.send(" msg part2")
+ time.sleep(_SERVER_POLL_INTERVAL + 1)
+ output = output.getvalue()
+ self.assertIn("msg part1\n msg part2\n", output)
+ client.close()
+ server.shutdown()
+
+ @staticmethod
+ def test_server_shutdown() -> None:
+ def run_test(client_ops: Callable) -> None:
+ server = LogStreamingServer()
+ server.start()
+ time.sleep(1)
+ client = LogStreamingClient("localhost", server.port)
+ client_ops(client)
+ server.shutdown()
+ client.close()
+
+ def client_ops_close(client: Any) -> None:
+ client.close()
+
+ def client_ops_send_half_msg(client: Any) -> None:
+ # Test server only recv incomplete message from client can exit.
+ client._connect()
+ client.sock.sendall(b"msg part1 ")
+ time.sleep(_SERVER_POLL_INTERVAL + 1)
+
+ def client_ops_send_a_msg(client: Any) -> None:
+ client.send("msg1")
+ time.sleep(_SERVER_POLL_INTERVAL + 1)
+
+ def client_ops_send_a_msg_and_close(client: Any) -> None:
+ client.send("msg1")
+ client.close()
+ time.sleep(_SERVER_POLL_INTERVAL + 1)
+
+ run_test(client_ops_close)
+ run_test(client_ops_send_half_msg)
+ run_test(client_ops_send_a_msg)
+ run_test(client_ops_send_a_msg_and_close)
+
+
+if __name__ == "__main__":
+ from pyspark.ml.torch.tests.test_log_communication import * # noqa:
F401,F403 type: ignore
+
+ try:
+ import xmlrunner # type: ignore
+
+ testRunner = xmlrunner.XMLTestRunner(output="target/test-reports",
verbosity=2)
+ except ImportError:
+ testRunner = None
+ unittest.main(testRunner=testRunner, verbosity=2)
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]