This is an automated email from the ASF dual-hosted git repository.
weichenxu123 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 480b14f4b45 [SPARK-41593][FOLLOW-UP][ML] Torch distributor log
streaming server: Avoid duplicated log to stdout redirection
480b14f4b45 is described below
commit 480b14f4b45a4e01e7e4bdea82475d4c77a6b89f
Author: Weichen Xu <[email protected]>
AuthorDate: Thu Jun 1 16:26:49 2023 +0800
[SPARK-41593][FOLLOW-UP][ML] Torch distributor log streaming server: Avoid
duplicated log to stdout redirection
### What changes were proposed in this pull request?
Torch distributor log streaming server: Avoid duplicated log to stdout
redirection.
In some cases (typically Spark local mode), the remote tasks run on the
same node as the Spark driver.
In that case, both the torch process created by the Spark task and the
driver-side torch distributor log streaming server redirect logs to STDOUT,
so STDOUT prints duplicate logs. This PR fixes that case.
### Why are the changes needed?
Without this change, when the remote tasks run on the same node as the
Spark driver (typically in Spark local mode), every log line appears twice
on STDOUT: once from the torch process created by the Spark task and once
from the driver-side torch distributor log streaming server.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
Unit tests.
Closes #41404 from WeichenXu123/torch-distributor-avoid-dup-log.
Authored-by: Weichen Xu <[email protected]>
Signed-off-by: Weichen Xu <[email protected]>
---
python/pyspark/ml/torch/distributor.py | 17 ++++++++++++++++-
1 file changed, 16 insertions(+), 1 deletion(-)
diff --git a/python/pyspark/ml/torch/distributor.py b/python/pyspark/ml/torch/distributor.py
index 711f76db09b..2af92f8399f 100644
--- a/python/pyspark/ml/torch/distributor.py
+++ b/python/pyspark/ml/torch/distributor.py
@@ -462,7 +462,22 @@ class TorchDistributor(Distributor):
                 decoded = line.decode()
                 tail.append(decoded)
                 if redirect_to_stdout:
-                    sys.stdout.write(decoded)
+                    if (
+                        log_streaming_client
+                        and not log_streaming_client.failed
+                        and (
+                            log_streaming_client.sock.getsockname()[0]
+                            == log_streaming_client.sock.getpeername()[0]
+                        )
+                    ):
+                        # If log_streaming_client and log_stream_server are in the same
+                        # node (typical case is spark local mode),
+                        # server side will redirect the log to STDOUT,
+                        # to avoid STDOUT outputs duplication, skip redirecting
+                        # logs to STDOUT in client side.
+                        pass
+                    else:
+                        sys.stdout.write(decoded)
                 if log_streaming_client:
                     log_streaming_client.send(decoded.rstrip())
             task.wait()
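
The same-host check in the hunk above compares the local and peer IP addresses of the client's connected socket. The snippet below is a minimal standalone sketch of that comparison, not code from the patch: `is_same_host` and `demo` are hypothetical names, and a plain loopback TCP listener stands in for the driver-side log streaming server. It only uses `socket.getsockname()` and `socket.getpeername()`, the same calls the patch makes on `log_streaming_client.sock`.

```python
import socket


def is_same_host(sock: socket.socket) -> bool:
    """Return True when both ends of a connected TCP socket report the same IP.

    Hypothetical helper: it mirrors the comparison in the patch but is not
    part of pyspark.
    """
    return sock.getsockname()[0] == sock.getpeername()[0]


def demo() -> bool:
    # Stand-in for the driver-side log streaming server: listen on loopback
    # and let the OS pick a free port.
    server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    server.bind(("127.0.0.1", 0))
    server.listen(1)

    # Stand-in for the task-side client connecting to that server.
    client = socket.create_connection(server.getsockname())
    conn, _ = server.accept()
    try:
        # Both ends are 127.0.0.1 here, so the client would skip its own
        # STDOUT write and leave printing to the server side.
        return is_same_host(client)
    finally:
        conn.close()
        client.close()
        server.close()


if __name__ == "__main__":
    print(demo())
```

Note that the comparison looks at IP addresses only, so it presumably treats any two endpoints that report the same address as being on the same node; the ports are ignored because the client's port is ephemeral.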