ashb commented on code in PR #43893:
URL: https://github.com/apache/airflow/pull/43893#discussion_r1838721785


##########
task_sdk/src/airflow/sdk/log.py:
##########
@@ -0,0 +1,377 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import itertools
+import logging.config
+import os
+import sys
+import warnings
+from functools import cache
+from typing import TYPE_CHECKING, Any, BinaryIO, Callable, Generic, TextIO, 
TypeVar
+
+import msgspec
+import structlog
+
+if TYPE_CHECKING:
+    from structlog.typing import EventDict, ExcInfo, Processor
+
+
+__all__ = [
+    "configure_logging",
+    "reset_logging",
+]
+
+
def exception_group_tracebacks(
    format_exception: Callable[[ExcInfo], list[dict[str, Any]]],
) -> Processor:
    """Build a structlog processor that expands ``BaseExceptionGroup`` tracebacks.

    When the event's ``exc_info`` holds an exception group, each member
    exception (and the group itself) is run through ``format_exception`` and
    the concatenated fragments are stored under the ``exception`` key.
    Anything that is not an exception group is left untouched for later
    processors to render.

    :param format_exception: formatter applied to each ``(type, value, tb)``
        triple; returns a list of JSON-serializable traceback fragments.
    :return: a structlog processor callable.
    """
    # Typing shim for Python < 3.11, where BaseExceptionGroup is not a builtin.
    # NOTE: the previous `hasattr(__builtins__, "BaseExceptionGroup")` check was
    # unreliable -- `__builtins__` is a plain dict in imported modules (it is a
    # module object only in `__main__`), so the shim always shadowed the real
    # builtin and the group-handling branch below could never trigger.
    if sys.version_info < (3, 11):
        T = TypeVar("T")

        class BaseExceptionGroup(Generic[T]):
            exceptions: list[T]

    def _exception_group_tracebacks(logger: Any, method_name: Any, event_dict: EventDict) -> EventDict:
        if exc_info := event_dict.get("exc_info", None):
            group: BaseExceptionGroup[Exception] | None = None
            if exc_info is True:
                # `log.exception("msg")` case: resolve to the active exception.
                exc_info = sys.exc_info()
                if exc_info[0] is None:
                    exc_info = None

            if (
                isinstance(exc_info, tuple)
                and len(exc_info) == 3
                and isinstance(exc_info[1], BaseExceptionGroup)
            ):
                group = exc_info[1]
            elif isinstance(exc_info, BaseExceptionGroup):
                group = exc_info

            if group:
                # Only remove it from event_dict if we handle it here.
                del event_dict["exc_info"]
                event_dict["exception"] = list(
                    itertools.chain.from_iterable(
                        format_exception((type(exc), exc, exc.__traceback__))  # type: ignore[attr-defined,arg-type]
                        for exc in (*group.exceptions, group)
                    )
                )

        return event_dict

    return _exception_group_tracebacks
+
+
def logger_name(logger: Any, method_name: Any, event_dict: EventDict) -> EventDict:
    """Promote a ``logger_name`` entry to the standard ``logger`` key.

    An existing ``logger`` value wins; ``logger_name`` is always removed.
    """
    name = event_dict.pop("logger_name", None)
    if name:
        event_dict.setdefault("logger", name)
    return event_dict
+
+
def redact_jwt(logger: Any, method_name: str, event_dict: EventDict) -> EventDict:
    """Mask any string value that looks like a JWT.

    Every JWT starts with the base64 of ``{"`` -- i.e. ``eyJ`` -- so any
    string value with that prefix is replaced wholesale with ``eyJ***``.
    """
    for key, value in event_dict.items():
        looks_like_jwt = isinstance(value, str) and value.startswith("eyJ")
        if looks_like_jwt:
            # Keys are never added/removed here, so mutating during
            # iteration is safe.
            event_dict[key] = "eyJ***"
    return event_dict
+
+
def drop_positional_args(logger: Any, method_name: Any, event_dict: EventDict) -> EventDict:
    """Strip structlog's ``positional_args`` key from the event, if present."""
    if "positional_args" in event_dict:
        del event_dict["positional_args"]
    return event_dict
+
+
def json_processor(logger: Any, method_name: Any, event_dict: EventDict) -> str:
    """Encode the event dict into a JSON string (ASCII, via msgspec)."""
    encoded: bytes = msgspec.json.encode(event_dict)
    return encoded.decode("ascii")
+
+
class StdBinaryStreamHandler(logging.StreamHandler):
    """A logging.StreamHandler that sends logs as binary JSON over the given stream."""

    # Binary stream the encoded records are written to.
    stream: BinaryIO

    def __init__(self, stream: BinaryIO):
        super().__init__(stream)

    def emit(self, record: logging.LogRecord):
        """Format ``record``, encode it as ASCII bytes and write it newline-terminated."""
        try:
            # Non-ASCII characters are escaped rather than dropped.
            payload = bytearray(self.format(record), "ascii", "backslashreplace")
            payload += b"\n"
            self.stream.write(payload)
            self.flush()
        except RecursionError:
            # See issue 36272: never swallow recursion errors here.
            raise
        except Exception:
            self.handleError(record)
+
+
+@cache
+def logging_processors(
+    enable_pretty_log: bool,
+):
+    if enable_pretty_log:
+        timestamper = structlog.processors.MaybeTimeStamper(fmt="%Y-%m-%d 
%H:%M:%S.%f")
+    else:
+        timestamper = structlog.processors.MaybeTimeStamper(fmt="iso")
+
+    processors: list[structlog.typing.Processor] = [
+        timestamper,
+        structlog.contextvars.merge_contextvars,
+        structlog.processors.add_log_level,
+        structlog.stdlib.PositionalArgumentsFormatter(),
+        logger_name,
+        redact_jwt,
+        structlog.processors.StackInfoRenderer(),
+    ]
+
+    if enable_pretty_log:
+        # Imports to suppress showing code from these modules
+        import asyncio
+        import contextlib
+
+        import click
+        import httpcore
+        import httpx
+        import typer
+
+        rich_exc_formatter = structlog.dev.RichTracebackFormatter(
+            extra_lines=0,
+            max_frames=30,

Review Comment:
   Likely not! Picked largely at random. `extra_lines=0` was to make the stack
traces smaller (as otherwise they get huge fast); `max_frames=30` was an
arbitrary choice.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to