Yicong-Huang commented on code in PR #52689: URL: https://github.com/apache/spark/pull/52689#discussion_r2450029019
########## core/src/main/scala/org/apache/spark/api/python/PythonWorkerLogCapture.scala: ########## @@ -0,0 +1,193 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.python + +import java.io.{BufferedReader, InputStream, InputStreamReader} +import java.nio.ByteBuffer +import java.nio.charset.StandardCharsets +import java.util.concurrent.ConcurrentHashMap +import java.util.concurrent.atomic.AtomicLong + +import scala.jdk.CollectionConverters._ + +import org.apache.spark.SparkEnv +import org.apache.spark.internal.Logging +import org.apache.spark.storage.{PythonWorkerLogBlockIdGenerator, PythonWorkerLogLine, RollingLogWriter} + +/** + * Manages Python UDF log capture and routing to per-worker log writers. + * + * This class handles the parsing of Python worker output streams and routes + * log messages to appropriate rolling log writers based on worker PIDs. + * Works for both daemon and non-daemon modes. 
+ */ +private[python] class PythonWorkerLogCapture( + sessionId: String, + logMarker: String = "PYTHON_WORKER_LOGGING") extends Logging { + + // Map to track per-worker log writers: workerId(PID) -> (writer, sequenceId) + private val workerLogWriters = new ConcurrentHashMap[String, (RollingLogWriter, AtomicLong)]() + + /** + * Creates an InputStream wrapper that captures Python UDF logs from the given stream. + * + * @param inputStream The input stream to wrap (typically daemon stdout or worker stdout) + * @return A wrapped InputStream that captures and routes log messages + */ + def wrapInputStream(inputStream: InputStream): InputStream = { + new CaptureWorkerLogsInputStream(inputStream) + } + + /** + * Removes and closes the log writer for a specific worker. + * + * @param workerId The worker ID (typically PID as string) + */ + def removeAndCloseWorkerLogWriter(workerId: String): Unit = { + Option(workerLogWriters.remove(workerId)).foreach { case (writer, _) => + try { + writer.close() + } catch { + case e: Exception => + logWarning(s"Failed to close log writer for worker $workerId", e) + } + } + } + + /** + * Closes all active worker log writers. + */ + def closeAllWriters(): Unit = { + workerLogWriters.values().asScala.foreach { case (writer, _) => + try { + writer.close() + } catch { + case e: Exception => + logWarning("Failed to close log writer", e) + } + } + workerLogWriters.clear() + } + + /** + * Gets or creates a log writer for the specified worker. + * + * @param workerId Unique identifier for the worker (typically PID) + * @return Tuple of (RollingLogWriter, AtomicLong sequence counter) + */ + private def getOrCreateLogWriter(workerId: String): (RollingLogWriter, AtomicLong) = { + workerLogWriters.computeIfAbsent(workerId, _ => { + val logWriter = SparkEnv.get.blockManager.getRollingLogWriter( + new PythonWorkerLogBlockIdGenerator(sessionId, workerId) + ) + (logWriter, new AtomicLong()) + }) + } + + /** + * Processes a log line from a Python worker. 
+ * + * @param line The complete line containing the log marker and JSON + * @return The prefix (non-log content) that should be passed through + */ + private def processLogLine(line: String): String = { + val markerIndex = line.indexOf(s"$logMarker:") + if (markerIndex >= 0) { + val prefix = line.substring(0, markerIndex) + val markerAndJson = line.substring(markerIndex) + Review Comment: does this work with multi-line log records? ########## python/pyspark/logger/logger.py: ########## @@ -291,7 +295,7 @@ def _log( msg=msg, args=args, exc_info=exc_info, - extra={"kwargs": kwargs}, + extra={"context": kwargs}, Review Comment: is this change intentional? ########## python/pyspark/logger/worker_io.py: ########## @@ -0,0 +1,293 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +from contextlib import contextmanager +import inspect +import io +import logging +import os +import sys +import time +from typing import BinaryIO, Callable, Generator, Iterable, Iterator, Optional, TextIO, Union +from types import FrameType, TracebackType + +from pyspark.logger.logger import JSONFormatter + + +class DelegatingTextIOWrapper(TextIO): + """A TextIO that delegates all operations to another TextIO object.""" + + def __init__(self, delegate: TextIO): + self._delegate = delegate + + # Required TextIO properties + @property + def encoding(self) -> str: + return self._delegate.encoding + + @property + def errors(self) -> Optional[str]: + return self._delegate.errors + + @property + def newlines(self) -> Optional[Union[str, tuple[str, ...]]]: + return self._delegate.newlines + + @property + def buffer(self) -> BinaryIO: + return self._delegate.buffer + + @property + def mode(self) -> str: + return self._delegate.mode + + @property + def name(self) -> str: + return self._delegate.name + + @property + def line_buffering(self) -> int: + return self._delegate.line_buffering + + @property + def closed(self) -> bool: + return self._delegate.closed + + # Iterator protocol + def __iter__(self) -> Iterator[str]: + return iter(self._delegate) + + def __next__(self) -> str: + return next(self._delegate) + + # Context manager protocol + def __enter__(self) -> TextIO: + return self._delegate.__enter__() + + def __exit__( + self, + exc_type: Optional[type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: + return self._delegate.__exit__(exc_type, exc_val, exc_tb) + + # Core I/O methods + def write(self, s: str) -> int: + return self._delegate.write(s) + + def writelines(self, lines: Iterable[str]) -> None: + return self._delegate.writelines(lines) + + def read(self, size: int = -1) -> str: + return self._delegate.read(size) + + def readline(self, size: int = -1) -> str: + return self._delegate.readline(size) + + def 
readlines(self, hint: int = -1) -> list[str]: + return self._delegate.readlines(hint) + + # Stream control methods + def close(self) -> None: + return self._delegate.close() + + def flush(self) -> None: + return self._delegate.flush() + + def seek(self, offset: int, whence: int = io.SEEK_SET) -> int: + return self._delegate.seek(offset, whence) + + def tell(self) -> int: + return self._delegate.tell() + + def truncate(self, size: Optional[int] = None) -> int: + return self._delegate.truncate(size) + + # Stream capability methods + def fileno(self) -> int: + return self._delegate.fileno() + + def isatty(self) -> bool: + return self._delegate.isatty() + + def readable(self) -> bool: + return self._delegate.readable() + + def seekable(self) -> bool: + return self._delegate.seekable() + + def writable(self) -> bool: + return self._delegate.writable() + + +class JSONFormatterWithMarker(JSONFormatter): + default_microsec_format = "%s.%06d" + + def __init__(self, marker: str, context_provider: Callable[[], dict[str, str]]): + super().__init__(ensure_ascii=True) + self._marker = marker + self._context_provider = context_provider + + def format(self, record: logging.LogRecord) -> str: + context = self._context_provider() + if context: + context.update(record.__dict__.get("context", {})) + record.__dict__["context"] = context + return f"{self._marker}:{os.getpid()}:{super().format(record)}" + + def formatTime(self, record: logging.LogRecord, datefmt: Optional[str] = None) -> str: + ct = self.converter(record.created) + if datefmt: + s = time.strftime(datefmt, ct) + else: + s = time.strftime(self.default_time_format, ct) + if self.default_microsec_format: + s = self.default_microsec_format % ( + s, + int((record.created - int(record.created)) * 1000000), + ) + elif self.default_msec_format: + s = self.default_msec_format % (s, record.msecs) + return s + + +class JsonOutput(DelegatingTextIOWrapper): + def __init__( + self, + delegate: TextIO, + json_out: TextIO, + logger_name: 
str, + log_level: int, + marker: str, + context_provider: Callable[[], dict[str, str]], + ): + super().__init__(delegate) + self._json_out = json_out + self._logger_name = logger_name + self._log_level = log_level + self._formatter = JSONFormatterWithMarker(marker, context_provider) + + def write(self, s: str) -> int: + if s.strip(): + log_record = logging.LogRecord( + name=self._logger_name, + level=self._log_level, + pathname=None, # type: ignore[arg-type] + lineno=None, # type: ignore[arg-type] + msg=s.strip(), + args=None, + exc_info=None, + func=None, + sinfo=None, + ) + self._json_out.write(f"{self._formatter.format(log_record)}\n") + self._json_out.flush() + return self._delegate.write(s) + + def writelines(self, lines: Iterable[str]) -> None: + # Process each line through our JSON logging logic + for line in lines: + self.write(line) + + def close(self) -> None: + pass + + +def context_provider() -> dict[str, str]: + """ + Provides context information for logging, including caller function name. + Finds the function name from the bottom of the stack, ignoring Python builtin + libraries and PySpark modules. Test packages are included. + + Returns: + dict[str, str]: A dictionary containing context information including: + - func_name: Name of the function that initiated the logging + - class_name: Name of the class that initiated the logging if available Review Comment: should we consider adding the module name to the context? ########## core/src/main/scala/org/apache/spark/api/python/PythonWorkerLogCapture.scala: ########## @@ -0,0 +1,193 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.python + +import java.io.{BufferedReader, InputStream, InputStreamReader} +import java.nio.ByteBuffer +import java.nio.charset.StandardCharsets +import java.util.concurrent.ConcurrentHashMap +import java.util.concurrent.atomic.AtomicLong + +import scala.jdk.CollectionConverters._ + +import org.apache.spark.SparkEnv +import org.apache.spark.internal.Logging +import org.apache.spark.storage.{PythonWorkerLogBlockIdGenerator, PythonWorkerLogLine, RollingLogWriter} + +/** + * Manages Python UDF log capture and routing to per-worker log writers. + * + * This class handles the parsing of Python worker output streams and routes + * log messages to appropriate rolling log writers based on worker PIDs. + * Works for both daemon and non-daemon modes. + */ +private[python] class PythonWorkerLogCapture( + sessionId: String, + logMarker: String = "PYTHON_WORKER_LOGGING") extends Logging { + + // Map to track per-worker log writers: workerId(PID) -> (writer, sequenceId) + private val workerLogWriters = new ConcurrentHashMap[String, (RollingLogWriter, AtomicLong)]() + + /** + * Creates an InputStream wrapper that captures Python UDF logs from the given stream. + * + * @param inputStream The input stream to wrap (typically daemon stdout or worker stdout) + * @return A wrapped InputStream that captures and routes log messages + */ + def wrapInputStream(inputStream: InputStream): InputStream = { + new CaptureWorkerLogsInputStream(inputStream) + } + + /** + * Removes and closes the log writer for a specific worker. 
+ * + * @param workerId The worker ID (typically PID as string) + */ + def removeAndCloseWorkerLogWriter(workerId: String): Unit = { + Option(workerLogWriters.remove(workerId)).foreach { case (writer, _) => + try { + writer.close() + } catch { + case e: Exception => + logWarning(s"Failed to close log writer for worker $workerId", e) + } + } + } + + /** + * Closes all active worker log writers. + */ + def closeAllWriters(): Unit = { + workerLogWriters.values().asScala.foreach { case (writer, _) => + try { + writer.close() + } catch { + case e: Exception => + logWarning("Failed to close log writer", e) + } + } + workerLogWriters.clear() + } + + /** + * Gets or creates a log writer for the specified worker. + * + * @param workerId Unique identifier for the worker (typically PID) + * @return Tuple of (RollingLogWriter, AtomicLong sequence counter) + */ + private def getOrCreateLogWriter(workerId: String): (RollingLogWriter, AtomicLong) = { + workerLogWriters.computeIfAbsent(workerId, _ => { + val logWriter = SparkEnv.get.blockManager.getRollingLogWriter( + new PythonWorkerLogBlockIdGenerator(sessionId, workerId) + ) + (logWriter, new AtomicLong()) + }) + } + + /** + * Processes a log line from a Python worker. 
+ * + * @param line The complete line containing the log marker and JSON + * @return The prefix (non-log content) that should be passed through + */ + private def processLogLine(line: String): String = { + val markerIndex = line.indexOf(s"$logMarker:") + if (markerIndex >= 0) { + val prefix = line.substring(0, markerIndex) + val markerAndJson = line.substring(markerIndex) + + // Parse: "PYTHON_UDF_LOGGING:12345:{json}" + val parts = markerAndJson.split(":", 3) + if (parts.length >= 3) { + val workerId = parts(1) // This is the PID from Python worker + val json = parts(2) + + try { + if (json.isEmpty) { + removeAndCloseWorkerLogWriter(workerId) + } else { + val (writer, seqId) = getOrCreateLogWriter(workerId) + writer.writeLog( + PythonWorkerLogLine(System.currentTimeMillis(), seqId.getAndIncrement(), json) Review Comment: I think it is better to use the original python log's generation time as the timestamp, here we are appending a new timestamp to it, that is not ground truth of when the log was generated right? ########## core/src/main/scala/org/apache/spark/api/python/PythonWorkerLogCapture.scala: ########## @@ -0,0 +1,193 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.api.python + +import java.io.{BufferedReader, InputStream, InputStreamReader} +import java.nio.ByteBuffer +import java.nio.charset.StandardCharsets +import java.util.concurrent.ConcurrentHashMap +import java.util.concurrent.atomic.AtomicLong + +import scala.jdk.CollectionConverters._ + +import org.apache.spark.SparkEnv +import org.apache.spark.internal.Logging +import org.apache.spark.storage.{PythonWorkerLogBlockIdGenerator, PythonWorkerLogLine, RollingLogWriter} + +/** + * Manages Python UDF log capture and routing to per-worker log writers. + * + * This class handles the parsing of Python worker output streams and routes + * log messages to appropriate rolling log writers based on worker PIDs. + * Works for both daemon and non-daemon modes. + */ +private[python] class PythonWorkerLogCapture( + sessionId: String, + logMarker: String = "PYTHON_WORKER_LOGGING") extends Logging { + + // Map to track per-worker log writers: workerId(PID) -> (writer, sequenceId) + private val workerLogWriters = new ConcurrentHashMap[String, (RollingLogWriter, AtomicLong)]() Review Comment: does each `PythonWorkerLogCapture` manage logs from one executor machine? I see you are using the PID as the key, so it does not work across multiple machines, right? ########## python/pyspark/logger/worker_io.py: ########## @@ -0,0 +1,293 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License.
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from contextlib import contextmanager +import inspect +import io +import logging +import os +import sys +import time +from typing import BinaryIO, Callable, Generator, Iterable, Iterator, Optional, TextIO, Union +from types import FrameType, TracebackType + +from pyspark.logger.logger import JSONFormatter + + +class DelegatingTextIOWrapper(TextIO): + """A TextIO that delegates all operations to another TextIO object.""" + + def __init__(self, delegate: TextIO): + self._delegate = delegate + + # Required TextIO properties + @property + def encoding(self) -> str: + return self._delegate.encoding + + @property + def errors(self) -> Optional[str]: + return self._delegate.errors + + @property + def newlines(self) -> Optional[Union[str, tuple[str, ...]]]: + return self._delegate.newlines + + @property + def buffer(self) -> BinaryIO: + return self._delegate.buffer + + @property + def mode(self) -> str: + return self._delegate.mode + + @property + def name(self) -> str: + return self._delegate.name + + @property + def line_buffering(self) -> int: + return self._delegate.line_buffering + + @property + def closed(self) -> bool: + return self._delegate.closed + + # Iterator protocol + def __iter__(self) -> Iterator[str]: + return iter(self._delegate) + + def __next__(self) -> str: + return next(self._delegate) + + # Context manager protocol + def __enter__(self) -> TextIO: + return self._delegate.__enter__() + + def __exit__( + self, + exc_type: Optional[type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: + 
return self._delegate.__exit__(exc_type, exc_val, exc_tb) + + # Core I/O methods + def write(self, s: str) -> int: + return self._delegate.write(s) + + def writelines(self, lines: Iterable[str]) -> None: + return self._delegate.writelines(lines) + + def read(self, size: int = -1) -> str: + return self._delegate.read(size) + + def readline(self, size: int = -1) -> str: + return self._delegate.readline(size) + + def readlines(self, hint: int = -1) -> list[str]: + return self._delegate.readlines(hint) + + # Stream control methods + def close(self) -> None: + return self._delegate.close() + + def flush(self) -> None: + return self._delegate.flush() + + def seek(self, offset: int, whence: int = io.SEEK_SET) -> int: + return self._delegate.seek(offset, whence) + + def tell(self) -> int: + return self._delegate.tell() + + def truncate(self, size: Optional[int] = None) -> int: + return self._delegate.truncate(size) + + # Stream capability methods + def fileno(self) -> int: + return self._delegate.fileno() + + def isatty(self) -> bool: + return self._delegate.isatty() + + def readable(self) -> bool: + return self._delegate.readable() + + def seekable(self) -> bool: + return self._delegate.seekable() + + def writable(self) -> bool: + return self._delegate.writable() + + +class JSONFormatterWithMarker(JSONFormatter): + default_microsec_format = "%s.%06d" + + def __init__(self, marker: str, context_provider: Callable[[], dict[str, str]]): + super().__init__(ensure_ascii=True) + self._marker = marker + self._context_provider = context_provider + + def format(self, record: logging.LogRecord) -> str: + context = self._context_provider() + if context: + context.update(record.__dict__.get("context", {})) + record.__dict__["context"] = context + return f"{self._marker}:{os.getpid()}:{super().format(record)}" + + def formatTime(self, record: logging.LogRecord, datefmt: Optional[str] = None) -> str: + ct = self.converter(record.created) + if datefmt: + s = time.strftime(datefmt, ct) + 
else: + s = time.strftime(self.default_time_format, ct) + if self.default_microsec_format: + s = self.default_microsec_format % ( + s, + int((record.created - int(record.created)) * 1000000), + ) + elif self.default_msec_format: + s = self.default_msec_format % (s, record.msecs) + return s + + +class JsonOutput(DelegatingTextIOWrapper): + def __init__( + self, + delegate: TextIO, + json_out: TextIO, + logger_name: str, + log_level: int, + marker: str, + context_provider: Callable[[], dict[str, str]], + ): + super().__init__(delegate) + self._json_out = json_out + self._logger_name = logger_name + self._log_level = log_level + self._formatter = JSONFormatterWithMarker(marker, context_provider) + + def write(self, s: str) -> int: + if s.strip(): + log_record = logging.LogRecord( + name=self._logger_name, + level=self._log_level, + pathname=None, # type: ignore[arg-type] + lineno=None, # type: ignore[arg-type] + msg=s.strip(), + args=None, + exc_info=None, + func=None, + sinfo=None, + ) + self._json_out.write(f"{self._formatter.format(log_record)}\n") + self._json_out.flush() + return self._delegate.write(s) + + def writelines(self, lines: Iterable[str]) -> None: + # Process each line through our JSON logging logic + for line in lines: + self.write(line) + + def close(self) -> None: + pass + + +def context_provider() -> dict[str, str]: + """ + Provides context information for logging, including caller function name. + Finds the function name from the bottom of the stack, ignoring Python builtin + libraries and PySpark modules. Test packages are included. + + Returns: + dict[str, str]: A dictionary containing context information including: + - func_name: Name of the function that initiated the logging + - class_name: Name of the class that initiated the logging if available + """ + + def is_pyspark_module(module_name: str) -> bool: + return module_name.startswith("pyspark.") and ".tests." 
not in module_name + + bottom: Optional[FrameType] = None + + # Get caller function information using inspect + try: + frame = inspect.currentframe() + is_in_pyspark_module = False + + if frame: + while frame.f_back: + f_back = frame.f_back + module_name = f_back.f_globals.get("__name__", "") + + if is_pyspark_module(module_name): + if not is_in_pyspark_module: + bottom = frame + is_in_pyspark_module = True + else: + is_in_pyspark_module = False Review Comment: if we ignore any frame in the stack that is within a pyspark module, then we could break after finding any pyspark module, right? If we ignore only the bottom frame that happens to be in a pyspark module, can this if/else check be moved out of the while loop? ########## python/pyspark/logger/worker_io.py: ########## @@ -0,0 +1,293 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +from contextlib import contextmanager +import inspect +import io +import logging +import os +import sys +import time +from typing import BinaryIO, Callable, Generator, Iterable, Iterator, Optional, TextIO, Union +from types import FrameType, TracebackType + +from pyspark.logger.logger import JSONFormatter + + +class DelegatingTextIOWrapper(TextIO): Review Comment: qq: why don't we base on TextIO directly?
########## python/pyspark/logger/worker_io.py: ########## @@ -0,0 +1,293 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from contextlib import contextmanager +import inspect +import io +import logging +import os +import sys +import time +from typing import BinaryIO, Callable, Generator, Iterable, Iterator, Optional, TextIO, Union +from types import FrameType, TracebackType + +from pyspark.logger.logger import JSONFormatter + + +class DelegatingTextIOWrapper(TextIO): + """A TextIO that delegates all operations to another TextIO object.""" + + def __init__(self, delegate: TextIO): + self._delegate = delegate + + # Required TextIO properties + @property + def encoding(self) -> str: + return self._delegate.encoding + + @property + def errors(self) -> Optional[str]: + return self._delegate.errors + + @property + def newlines(self) -> Optional[Union[str, tuple[str, ...]]]: + return self._delegate.newlines + + @property + def buffer(self) -> BinaryIO: + return self._delegate.buffer + + @property + def mode(self) -> str: + return self._delegate.mode + + @property + def name(self) -> str: + return self._delegate.name + + @property + def line_buffering(self) -> int: + return self._delegate.line_buffering + + @property + def 
closed(self) -> bool: + return self._delegate.closed + + # Iterator protocol + def __iter__(self) -> Iterator[str]: + return iter(self._delegate) + + def __next__(self) -> str: + return next(self._delegate) + + # Context manager protocol + def __enter__(self) -> TextIO: + return self._delegate.__enter__() + + def __exit__( + self, + exc_type: Optional[type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: + return self._delegate.__exit__(exc_type, exc_val, exc_tb) + + # Core I/O methods + def write(self, s: str) -> int: + return self._delegate.write(s) + + def writelines(self, lines: Iterable[str]) -> None: + return self._delegate.writelines(lines) + + def read(self, size: int = -1) -> str: + return self._delegate.read(size) + + def readline(self, size: int = -1) -> str: + return self._delegate.readline(size) + + def readlines(self, hint: int = -1) -> list[str]: + return self._delegate.readlines(hint) + + # Stream control methods + def close(self) -> None: + return self._delegate.close() + + def flush(self) -> None: + return self._delegate.flush() + + def seek(self, offset: int, whence: int = io.SEEK_SET) -> int: + return self._delegate.seek(offset, whence) + + def tell(self) -> int: + return self._delegate.tell() + + def truncate(self, size: Optional[int] = None) -> int: + return self._delegate.truncate(size) + + # Stream capability methods + def fileno(self) -> int: + return self._delegate.fileno() + + def isatty(self) -> bool: + return self._delegate.isatty() + + def readable(self) -> bool: + return self._delegate.readable() + + def seekable(self) -> bool: + return self._delegate.seekable() + + def writable(self) -> bool: + return self._delegate.writable() + + +class JSONFormatterWithMarker(JSONFormatter): + default_microsec_format = "%s.%06d" + + def __init__(self, marker: str, context_provider: Callable[[], dict[str, str]]): + super().__init__(ensure_ascii=True) + self._marker = marker + self._context_provider 
= context_provider + + def format(self, record: logging.LogRecord) -> str: + context = self._context_provider() + if context: + context.update(record.__dict__.get("context", {})) + record.__dict__["context"] = context + return f"{self._marker}:{os.getpid()}:{super().format(record)}" + + def formatTime(self, record: logging.LogRecord, datefmt: Optional[str] = None) -> str: + ct = self.converter(record.created) + if datefmt: + s = time.strftime(datefmt, ct) + else: + s = time.strftime(self.default_time_format, ct) + if self.default_microsec_format: + s = self.default_microsec_format % ( + s, + int((record.created - int(record.created)) * 1000000), + ) + elif self.default_msec_format: + s = self.default_msec_format % (s, record.msecs) + return s + + +class JsonOutput(DelegatingTextIOWrapper): + def __init__( + self, + delegate: TextIO, + json_out: TextIO, + logger_name: str, + log_level: int, + marker: str, + context_provider: Callable[[], dict[str, str]], + ): + super().__init__(delegate) + self._json_out = json_out + self._logger_name = logger_name + self._log_level = log_level + self._formatter = JSONFormatterWithMarker(marker, context_provider) + + def write(self, s: str) -> int: + if s.strip(): + log_record = logging.LogRecord( + name=self._logger_name, + level=self._log_level, + pathname=None, # type: ignore[arg-type] + lineno=None, # type: ignore[arg-type] + msg=s.strip(), + args=None, + exc_info=None, + func=None, + sinfo=None, + ) + self._json_out.write(f"{self._formatter.format(log_record)}\n") + self._json_out.flush() + return self._delegate.write(s) + + def writelines(self, lines: Iterable[str]) -> None: + # Process each line through our JSON logging logic + for line in lines: + self.write(line) + + def close(self) -> None: + pass Review Comment: is this intentional? it will cause the super's close not being called -- This is an automated message from the Apache Git Service. 
To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected] --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
