allisonwang-db commented on code in PR #52689:
URL: https://github.com/apache/spark/pull/52689#discussion_r2453353572


##########
core/src/main/scala/org/apache/spark/storage/BlockId.scala:
##########
@@ -188,30 +189,45 @@ object LogBlockType extends Enumeration {
  *                    and log management.
  * @param executorId the ID of the executor that produced this log block.
  */
-abstract sealed class LogBlockId(
-    val lastLogTime: Long,
-    val executorId: String) extends BlockId {
+abstract sealed class LogBlockId extends BlockId {
+  def lastLogTime: Long
+  def executorId: String
   def logBlockType: LogBlockType
 }
 
 object LogBlockId {
   def empty(logBlockType: LogBlockType): LogBlockId = {
     logBlockType match {
       case LogBlockType.TEST => TestLogBlockId(0L, "")
+      case LogBlockType.PYTHON_WORKER => PythonWorkerLogBlockId(0L, "", "", "")
       case _ => throw new SparkException(s"Unsupported log block type: 
$logBlockType")
     }
   }
 }
 
 // Used for test purpose only.
-case class TestLogBlockId(override val lastLogTime: Long, override val 
executorId: String)
-  extends LogBlockId(lastLogTime, executorId) {
+case class TestLogBlockId(lastLogTime: Long, executorId: String)
+  extends LogBlockId {
   override def name: String =
     "test_log_" + lastLogTime + "_" + executorId
 
   override def logBlockType: LogBlockType = LogBlockType.TEST
 }
 
+@DeveloperApi
+case class PythonWorkerLogBlockId(
+    lastLogTime: Long,
+    executorId: String,
+    sessionId: String,
+    workerId: String)

Review Comment:
   Can we add some comments to explain these input values?



##########
python/pyspark/logger/worker_io.py:
##########
@@ -0,0 +1,293 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from contextlib import contextmanager
+import inspect
+import io
+import logging
+import os
+import sys
+import time
+from typing import BinaryIO, Callable, Generator, Iterable, Iterator, 
Optional, TextIO, Union
+from types import FrameType, TracebackType
+
+from pyspark.logger.logger import JSONFormatter
+
+
+class DelegatingTextIOWrapper(TextIO):
+    """A TextIO that delegates all operations to another TextIO object."""
+
+    def __init__(self, delegate: TextIO):
+        self._delegate = delegate
+
+    # Required TextIO properties
+    @property
+    def encoding(self) -> str:
+        return self._delegate.encoding
+
+    @property
+    def errors(self) -> Optional[str]:
+        return self._delegate.errors
+
+    @property
+    def newlines(self) -> Optional[Union[str, tuple[str, ...]]]:
+        return self._delegate.newlines
+
+    @property
+    def buffer(self) -> BinaryIO:
+        return self._delegate.buffer
+
+    @property
+    def mode(self) -> str:
+        return self._delegate.mode
+
+    @property
+    def name(self) -> str:
+        return self._delegate.name
+
+    @property
+    def line_buffering(self) -> int:
+        return self._delegate.line_buffering
+
+    @property
+    def closed(self) -> bool:
+        return self._delegate.closed
+
+    # Iterator protocol
+    def __iter__(self) -> Iterator[str]:
+        return iter(self._delegate)
+
+    def __next__(self) -> str:
+        return next(self._delegate)
+
+    # Context manager protocol
+    def __enter__(self) -> TextIO:
+        return self._delegate.__enter__()
+
+    def __exit__(
+        self,
+        exc_type: Optional[type[BaseException]],
+        exc_val: Optional[BaseException],
+        exc_tb: Optional[TracebackType],
+    ) -> None:
+        return self._delegate.__exit__(exc_type, exc_val, exc_tb)
+
+    # Core I/O methods
+    def write(self, s: str) -> int:
+        return self._delegate.write(s)
+
+    def writelines(self, lines: Iterable[str]) -> None:
+        return self._delegate.writelines(lines)
+
+    def read(self, size: int = -1) -> str:
+        return self._delegate.read(size)
+
+    def readline(self, size: int = -1) -> str:
+        return self._delegate.readline(size)
+
+    def readlines(self, hint: int = -1) -> list[str]:
+        return self._delegate.readlines(hint)
+
+    # Stream control methods
+    def close(self) -> None:
+        return self._delegate.close()
+
+    def flush(self) -> None:
+        return self._delegate.flush()
+
+    def seek(self, offset: int, whence: int = io.SEEK_SET) -> int:
+        return self._delegate.seek(offset, whence)
+
+    def tell(self) -> int:
+        return self._delegate.tell()
+
+    def truncate(self, size: Optional[int] = None) -> int:
+        return self._delegate.truncate(size)
+
+    # Stream capability methods
+    def fileno(self) -> int:
+        return self._delegate.fileno()
+
+    def isatty(self) -> bool:
+        return self._delegate.isatty()
+
+    def readable(self) -> bool:
+        return self._delegate.readable()
+
+    def seekable(self) -> bool:
+        return self._delegate.seekable()
+
+    def writable(self) -> bool:
+        return self._delegate.writable()
+
+
+class JSONFormatterWithMarker(JSONFormatter):
+    default_microsec_format = "%s.%06d"
+
+    def __init__(self, marker: str, context_provider: Callable[[], dict[str, 
str]]):
+        super().__init__(ensure_ascii=True)
+        self._marker = marker
+        self._context_provider = context_provider
+
+    def format(self, record: logging.LogRecord) -> str:
+        context = self._context_provider()
+        if context:
+            context.update(record.__dict__.get("context", {}))
+            record.__dict__["context"] = context
+        return f"{self._marker}:{os.getpid()}:{super().format(record)}"
+
+    def formatTime(self, record: logging.LogRecord, datefmt: Optional[str] = 
None) -> str:
+        ct = self.converter(record.created)
+        if datefmt:
+            s = time.strftime(datefmt, ct)
+        else:
+            s = time.strftime(self.default_time_format, ct)
+            if self.default_microsec_format:
+                s = self.default_microsec_format % (
+                    s,
+                    int((record.created - int(record.created)) * 1000000),
+                )
+            elif self.default_msec_format:
+                s = self.default_msec_format % (s, record.msecs)
+        return s
+
+
+class JsonOutput(DelegatingTextIOWrapper):
+    def __init__(
+        self,
+        delegate: TextIO,
+        json_out: TextIO,
+        logger_name: str,
+        log_level: int,
+        marker: str,
+        context_provider: Callable[[], dict[str, str]],
+    ):
+        super().__init__(delegate)
+        self._json_out = json_out
+        self._logger_name = logger_name
+        self._log_level = log_level
+        self._formatter = JSONFormatterWithMarker(marker, context_provider)
+
+    def write(self, s: str) -> int:
+        if s.strip():
+            log_record = logging.LogRecord(
+                name=self._logger_name,
+                level=self._log_level,
+                pathname=None,  # type: ignore[arg-type]
+                lineno=None,  # type: ignore[arg-type]
+                msg=s.strip(),
+                args=None,
+                exc_info=None,
+                func=None,
+                sinfo=None,
+            )
+            self._json_out.write(f"{self._formatter.format(log_record)}\n")
+            self._json_out.flush()
+        return self._delegate.write(s)
+
+    def writelines(self, lines: Iterable[str]) -> None:
+        # Process each line through our JSON logging logic
+        for line in lines:
+            self.write(line)
+
+    def close(self) -> None:
+        pass
+
+
+def context_provider() -> dict[str, str]:
+    """
+    Provides context information for logging, including caller function name.
+    Finds the function name from the bottom of the stack, ignoring Python 
builtin
+    libraries and PySpark modules. Test packages are included.
+
+    Returns:
+        dict[str, str]: A dictionary containing context information including:
+            - func_name: Name of the function that initiated the logging
+            - class_name: Name of the class that initiated the logging if 
available
+    """
+
+    def is_pyspark_module(module_name: str) -> bool:
+        return module_name.startswith("pyspark.") and ".tests." not in 
module_name
+
+    bottom: Optional[FrameType] = None
+
+    # Get caller function information using inspect
+    try:
+        frame = inspect.currentframe()
+        is_in_pyspark_module = False
+
+        if frame:
+            while frame.f_back:
+                f_back = frame.f_back
+                module_name = f_back.f_globals.get("__name__", "")
+
+                if is_pyspark_module(module_name):
+                    if not is_in_pyspark_module:
+                        bottom = frame
+                        is_in_pyspark_module = True
+                else:
+                    is_in_pyspark_module = False
+
+                frame = f_back
+    except Exception:
+        # If anything goes wrong with introspection, don't fail the logging
+        # Just continue without caller information
+        pass
+
+    context = {}
+    if bottom:
+        context["func_name"] = bottom.f_code.co_name
+        if "self" in bottom.f_locals:

Review Comment:
   So this can also work with UDTFs and data sources (i.e., you can print out the UDTF/DataSource names).



##########
python/pyspark/logger/worker_io.py:
##########
@@ -0,0 +1,293 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from contextlib import contextmanager
+import inspect
+import io
+import logging
+import os
+import sys
+import time
+from typing import BinaryIO, Callable, Generator, Iterable, Iterator, 
Optional, TextIO, Union
+from types import FrameType, TracebackType
+
+from pyspark.logger.logger import JSONFormatter
+
+
+class DelegatingTextIOWrapper(TextIO):
+    """A TextIO that delegates all operations to another TextIO object."""
+
+    def __init__(self, delegate: TextIO):
+        self._delegate = delegate
+
+    # Required TextIO properties
+    @property
+    def encoding(self) -> str:
+        return self._delegate.encoding
+
+    @property
+    def errors(self) -> Optional[str]:
+        return self._delegate.errors
+
+    @property
+    def newlines(self) -> Optional[Union[str, tuple[str, ...]]]:
+        return self._delegate.newlines
+
+    @property
+    def buffer(self) -> BinaryIO:
+        return self._delegate.buffer
+
+    @property
+    def mode(self) -> str:
+        return self._delegate.mode
+
+    @property
+    def name(self) -> str:
+        return self._delegate.name
+
+    @property
+    def line_buffering(self) -> int:
+        return self._delegate.line_buffering
+
+    @property
+    def closed(self) -> bool:
+        return self._delegate.closed
+
+    # Iterator protocol
+    def __iter__(self) -> Iterator[str]:
+        return iter(self._delegate)
+
+    def __next__(self) -> str:
+        return next(self._delegate)
+
+    # Context manager protocol
+    def __enter__(self) -> TextIO:
+        return self._delegate.__enter__()
+
+    def __exit__(
+        self,
+        exc_type: Optional[type[BaseException]],
+        exc_val: Optional[BaseException],
+        exc_tb: Optional[TracebackType],
+    ) -> None:
+        return self._delegate.__exit__(exc_type, exc_val, exc_tb)
+
+    # Core I/O methods
+    def write(self, s: str) -> int:
+        return self._delegate.write(s)
+
+    def writelines(self, lines: Iterable[str]) -> None:
+        return self._delegate.writelines(lines)
+
+    def read(self, size: int = -1) -> str:
+        return self._delegate.read(size)
+
+    def readline(self, size: int = -1) -> str:
+        return self._delegate.readline(size)
+
+    def readlines(self, hint: int = -1) -> list[str]:
+        return self._delegate.readlines(hint)
+
+    # Stream control methods
+    def close(self) -> None:
+        return self._delegate.close()
+
+    def flush(self) -> None:
+        return self._delegate.flush()
+
+    def seek(self, offset: int, whence: int = io.SEEK_SET) -> int:
+        return self._delegate.seek(offset, whence)
+
+    def tell(self) -> int:
+        return self._delegate.tell()
+
+    def truncate(self, size: Optional[int] = None) -> int:
+        return self._delegate.truncate(size)
+
+    # Stream capability methods
+    def fileno(self) -> int:
+        return self._delegate.fileno()
+
+    def isatty(self) -> bool:
+        return self._delegate.isatty()
+
+    def readable(self) -> bool:
+        return self._delegate.readable()
+
+    def seekable(self) -> bool:
+        return self._delegate.seekable()
+
+    def writable(self) -> bool:
+        return self._delegate.writable()
+
+
+class JSONFormatterWithMarker(JSONFormatter):
+    default_microsec_format = "%s.%06d"
+
+    def __init__(self, marker: str, context_provider: Callable[[], dict[str, 
str]]):
+        super().__init__(ensure_ascii=True)
+        self._marker = marker
+        self._context_provider = context_provider
+
+    def format(self, record: logging.LogRecord) -> str:
+        context = self._context_provider()
+        if context:
+            context.update(record.__dict__.get("context", {}))
+            record.__dict__["context"] = context
+        return f"{self._marker}:{os.getpid()}:{super().format(record)}"

Review Comment:
   Can we cache the value of `os.getpid()`?



##########
core/src/main/scala/org/apache/spark/api/python/PythonWorkerLogCapture.scala:
##########
@@ -0,0 +1,193 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.api.python
+
+import java.io.{BufferedReader, InputStream, InputStreamReader}
+import java.nio.ByteBuffer
+import java.nio.charset.StandardCharsets
+import java.util.concurrent.ConcurrentHashMap
+import java.util.concurrent.atomic.AtomicLong
+
+import scala.jdk.CollectionConverters._
+
+import org.apache.spark.SparkEnv
+import org.apache.spark.internal.Logging
+import org.apache.spark.storage.{PythonWorkerLogBlockIdGenerator, 
PythonWorkerLogLine, RollingLogWriter}
+
+/**
+ * Manages Python UDF log capture and routing to per-worker log writers.
+ *
+ * This class handles the parsing of Python worker output streams and routes
+ * log messages to appropriate rolling log writers based on worker PIDs.
+ * Works for both daemon and non-daemon modes.
+ */
+private[python] class PythonWorkerLogCapture(
+    sessionId: String,
+    logMarker: String = "PYTHON_WORKER_LOGGING") extends Logging {
+
+  // Map to track per-worker log writers: workerId(PID) -> (writer, sequenceId)
+  private val workerLogWriters = new ConcurrentHashMap[String, 
(RollingLogWriter, AtomicLong)]()
+
+  /**
+   * Creates an InputStream wrapper that captures Python UDF logs from the 
given stream.
+   *
+   * @param inputStream The input stream to wrap (typically daemon stdout or 
worker stdout)
+   * @return A wrapped InputStream that captures and routes log messages
+   */
+  def wrapInputStream(inputStream: InputStream): InputStream = {
+    new CaptureWorkerLogsInputStream(inputStream)
+  }
+
+  /**
+   * Removes and closes the log writer for a specific worker.
+   *
+   * @param workerId The worker ID (typically PID as string)
+   */
+  def removeAndCloseWorkerLogWriter(workerId: String): Unit = {
+    Option(workerLogWriters.remove(workerId)).foreach { case (writer, _) =>
+      try {
+        writer.close()
+      } catch {
+        case e: Exception =>
+          logWarning(s"Failed to close log writer for worker $workerId", e)
+      }
+    }
+  }
+
+  /**
+   * Closes all active worker log writers.
+   */
+  def closeAllWriters(): Unit = {
+    workerLogWriters.values().asScala.foreach { case (writer, _) =>
+      try {
+        writer.close()
+      } catch {
+        case e: Exception =>
+          logWarning("Failed to close log writer", e)
+      }
+    }
+    workerLogWriters.clear()
+  }
+
+  /**
+   * Gets or creates a log writer for the specified worker.
+   *
+   * @param workerId Unique identifier for the worker (typically PID)
+   * @return Tuple of (RollingLogWriter, AtomicLong sequence counter)
+   */
+  private def getOrCreateLogWriter(workerId: String): (RollingLogWriter, 
AtomicLong) = {
+    workerLogWriters.computeIfAbsent(workerId, _ => {
+      val logWriter = SparkEnv.get.blockManager.getRollingLogWriter(
+        new PythonWorkerLogBlockIdGenerator(sessionId, workerId)
+      )
+      (logWriter, new AtomicLong())
+    })
+  }
+
+  /**
+   * Processes a log line from a Python worker.
+   *
+   * @param line The complete line containing the log marker and JSON
+   * @return The prefix (non-log content) that should be passed through
+   */
+  private def processLogLine(line: String): String = {
+    val markerIndex = line.indexOf(s"$logMarker:")
+    if (markerIndex >= 0) {
+      val prefix = line.substring(0, markerIndex)
+      val markerAndJson = line.substring(markerIndex)
+
+      // Parse: "PYTHON_UDF_LOGGING:12345:{json}"
+      val parts = markerAndJson.split(":", 3)
+      if (parts.length >= 3) {
+        val workerId = parts(1) // This is the PID from Python worker
+        val json = parts(2)
+
+        try {
+          if (json.isEmpty) {
+            removeAndCloseWorkerLogWriter(workerId)
+          } else {
+            val (writer, seqId) = getOrCreateLogWriter(workerId)
+            writer.writeLog(
+              PythonWorkerLogLine(System.currentTimeMillis(), 
seqId.getAndIncrement(), json)

Review Comment:
   Do we want to limit the number of lines written to the block manager?



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to