This is an automated email from the ASF dual-hosted git repository.

ash pushed a commit to branch localexecutor-uses-task-sdk
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit c22e1ce1e3433bf9b8a458636361b66440f0e38b
Author: Ash Berlin-Taylor <[email protected]>
AuthorDate: Wed Nov 20 17:49:45 2024 +0000

    WIP: LocalExecutor runs task using new Task SDK supervisor
---
 airflow/config_templates/config.yml                |   2 +-
 airflow/executors/base_executor.py                 |  61 ++++++++++-
 airflow/executors/local_executor.py                |  82 +++++++++------
 airflow/jobs/scheduler_job_runner.py               |   8 +-
 airflow/utils/helpers.py                           |   2 +-
 task_sdk/pyproject.toml                            |   1 +
 .../src/airflow/sdk/execution_time/supervisor.py   | 112 +++++++++++++++++----
 .../src/airflow/sdk/execution_time/task_runner.py  |  12 ++-
 task_sdk/src/airflow/sdk/log.py                    |  18 ++--
 9 files changed, 234 insertions(+), 64 deletions(-)

diff --git a/airflow/config_templates/config.yml 
b/airflow/config_templates/config.yml
index 4a7fe5f1897..c7a4ce159cc 100644
--- a/airflow/config_templates/config.yml
+++ b/airflow/config_templates/config.yml
@@ -908,7 +908,7 @@ logging:
       is_template: true
       default: "dag_id={{ ti.dag_id }}/run_id={{ ti.run_id }}/task_id={{ 
ti.task_id }}/\
                {%% if ti.map_index >= 0 %%}map_index={{ ti.map_index }}/{%% 
endif %%}\
-               attempt={{ try_number }}.log"
+               attempt={{ ti.try_number }}.log"
     log_processor_filename_template:
       description: |
         Formatting for how airflow generates file names for log
diff --git a/airflow/executors/base_executor.py 
b/airflow/executors/base_executor.py
index d24fd4f1516..bab5642fa0b 100644
--- a/airflow/executors/base_executor.py
+++ b/airflow/executors/base_executor.py
@@ -19,13 +19,17 @@
 from __future__ import annotations
 
 import logging
+import os
 import sys
 from collections import defaultdict, deque
 from dataclasses import dataclass, field
-from typing import TYPE_CHECKING, Any, List, Optional, Sequence, Tuple
+from enum import Enum, auto
+from functools import cache
+from typing import TYPE_CHECKING, Any, Callable, List, Literal, Optional, 
Sequence, Tuple, overload
 
 import pendulum
 from deprecated import deprecated
+from pydantic import BaseModel
 
 from airflow.cli.cli_config import DefaultHelpParser
 from airflow.configuration import conf
@@ -70,6 +74,9 @@ if TYPE_CHECKING:
     # Task tuple to send to be executed
     TaskTuple = Tuple[TaskInstanceKey, CommandType, Optional[str], 
Optional[Any]]
 
+# This feels wrong importing from the fastapi app here!!!!
+from airflow.api_fastapi.execution_api.datamodels.taskinstance import 
TaskInstance as SdkTI
+
 log = logging.getLogger(__name__)
 
 
@@ -106,6 +113,50 @@ class RunningRetryAttemptType:
         return True
 
 
+class ExecuteTaskActivity(BaseModel):  # noqa: D101
+    ti: SdkTI
+    dag_path: os.PathLike[str]
+    token: str
+    """The identity token for this workload"""
+
+    log_filename_suffix: str
+
+    @classmethod
+    def from_ti(cls, ti: TaskInstance) -> ExecuteTaskActivity:
+        from pathlib import Path
+
+        serde_ti = SdkTI.model_validate(ti, from_attributes=True)
+        path = Path(ti.dag_run.dag_model.relative_fileloc)
+
+        if path and not path.is_absolute():
+            path = "DAGS_FOLDER" / path
+
+        # Quasi-breaking change for format strings. Let's just deprecate those
+        fname = cls.log_filename_template()(ti=ti, try_number=ti.try_number)
+        return ExecuteTaskActivity(ti=serde_ti, dag_path=path, token="", 
log_filename_suffix=fname)
+
+    # Parse the string once, and cache the rendered template
+    @classmethod
+    @cache
+    def log_filename_template(cls) -> Callable[..., str]:
+        # TODO: We should probably use the version from 
ti.dag_run.log_template_id
+
+        # For now we always return the active log template
+        template = conf.get("logging", "log_filename_template")
+
+        if "{{" in template:
+            import jinja2
+
+            return jinja2.Template(template).render
+        else:
+            return template.format
+
+
+class ExecuteActivityKind(Enum):  # noqa: D101
+    TASK_INSTANCE = auto()
+    DAG_CALLBACK = auto()
+
+
 class BaseExecutor(LoggingMixin):
     """
     Base class to inherit for concrete executors such as Celery, Kubernetes, 
Local, Sequential, etc.
@@ -169,6 +220,14 @@ class BaseExecutor(LoggingMixin):
         else:
             self.log.error("could not queue task %s", task_instance.key)
 
+    @overload
+    def queue_activity(self, *, kind: 
Literal[ExecuteActivityKind.TASK_INSTANCE], context: TaskInstance): ...
+    @overload
+    def queue_activity(self, *, kind: 
Literal[ExecuteActivityKind.DAG_CALLBACK], context: dict[str, Any]): ...
+
+    def queue_activity(self, *, kind: ExecuteActivityKind, context: Any):
+        raise ValueError(f"Un-handled activity kind {kind!r} in 
{type(self).__name__}")
+
     def queue_task_instance(
         self,
         task_instance: TaskInstance,
diff --git a/airflow/executors/local_executor.py 
b/airflow/executors/local_executor.py
index 3b8b52176db..c4cd260d336 100644
--- a/airflow/executors/local_executor.py
+++ b/airflow/executors/local_executor.py
@@ -32,14 +32,18 @@ import multiprocessing.sharedctypes
 import os
 import subprocess
 from multiprocessing import Queue, SimpleQueue
-from typing import TYPE_CHECKING, Any, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Literal, Optional, Tuple, overload
 
 from setproctitle import setproctitle
 
 from airflow import settings
-from airflow.executors.base_executor import PARALLELISM, BaseExecutor
+from airflow.executors.base_executor import (
+    PARALLELISM,
+    BaseExecutor,
+    ExecuteActivityKind,
+    ExecuteTaskActivity,
+)
 from airflow.traces.tracer import add_span
-from airflow.utils.dag_parsing_context import _airflow_parsing_context_manager
 from airflow.utils.state import TaskInstanceState
 
 if TYPE_CHECKING:
@@ -55,26 +59,29 @@ if TYPE_CHECKING:
 
 def _run_worker(
     logger_name: str,
-    input: SimpleQueue[ExecutorWorkType],
+    input: SimpleQueue[ExecuteTaskActivity | None],
     output: Queue[TaskInstanceStateType],
     unread_messages: multiprocessing.sharedctypes.Synchronized[int],
 ):
+    import os
     import signal
 
+    os.environ["PUDB_TTY"] = "/dev/ttys006"
+
     # Ignore ctrl-c in this process -- we don't want to kill _this_ one. we 
let tasks run to completion
     signal.signal(signal.SIGINT, signal.SIG_IGN)
 
     log = logging.getLogger(logger_name)
+    log.info("Worker starting up pid=%d", os.getpid())
 
     # We know we've just started a new process, so lets disconnect from the 
metadata db now
     settings.engine.pool.dispose()
     settings.engine.dispose()
 
-    setproctitle("airflow worker -- LocalExecutor: <idle>")
-
     while True:
+        setproctitle("airflow worker -- LocalExecutor: <idle>")
         try:
-            item = input.get()
+            activity = input.get()
         except EOFError:
             log.info(
                 "Failed to read tasks from the task queue because the other "
@@ -83,7 +90,7 @@ def _run_worker(
             )
             break
 
-        if item is None:
+        if activity is None:
             # Received poison pill, no more tasks to run
             return
 
@@ -91,33 +98,33 @@ def _run_worker(
         with unread_messages:
             unread_messages.value -= 1
 
-        (key, command) = item
         try:
-            state = _execute_work(log, key, command)
+            state = _execute_work(log, activity)
 
-            output.put((key, state, None))
+            # output.put((key, state, None))
+            output.put((None, TaskInstanceState.SUCCESS, None))
         except Exception as e:
-            output.put((key, TaskInstanceState.FAILED, e))
+            log.exception("uhoh")
+            output.put((None, TaskInstanceState.FAILED, e))
 
 
-def _execute_work(log: logging.Logger, key: TaskInstanceKey, command: 
CommandType) -> TaskInstanceState:
+def _execute_work(log: logging.Logger, activity: ExecuteTaskActivity) -> None:
     """
     Execute command received and stores result state in queue.
 
     :param key: the key to identify the task instance
     :param command: the command to execute
     """
-    setproctitle(f"airflow worker -- LocalExecutor: {command}")
-    dag_id, task_id = BaseExecutor.validate_airflow_tasks_run_command(command)
-    try:
-        with _airflow_parsing_context_manager(dag_id=dag_id, task_id=task_id):
-            if settings.EXECUTE_TASKS_NEW_PYTHON_INTERPRETER:
-                return _execute_work_in_subprocess(log, command)
-            else:
-                return _execute_work_in_fork(log, command)
-    finally:
-        # Remove the command since the worker is done executing the task
-        setproctitle("airflow worker -- LocalExecutor: <idle>")
+    from airflow.sdk.execution_time.supervisor import supervise
+
+    setproctitle(f"airflow worker -- LocalExecutor: {activity.ti.id}")
+    supervise(
+        ti=activity.ti,
+        dag_path=activity.dag_path,
+        token=activity.token,
+        server="http://localhost:9091/execution/",
+        log_filename_suffix=activity.log_filename_suffix,
+    )
 
 
 def _execute_work_in_subprocess(log: logging.Logger, command: CommandType) -> 
TaskInstanceState:
@@ -180,7 +187,7 @@ class LocalExecutor(BaseExecutor):
 
     serve_logs: bool = True
 
-    activity_queue: SimpleQueue[ExecutorWorkType]
+    activity_queue: SimpleQueue[ExecuteTaskActivity | None]
     result_queue: SimpleQueue[TaskInstanceStateType]
     workers: dict[int, multiprocessing.Process]
     _unread_messages: multiprocessing.sharedctypes.Synchronized[int]
@@ -216,9 +223,9 @@ class LocalExecutor(BaseExecutor):
         self.activity_queue.put((key, command))
         with self._unread_messages:
             self._unread_messages.value += 1
-        self._check_workers(can_start=True)
+        self._check_workers()
 
-    def _check_workers(self, can_start: bool = True):
+    def _check_workers(self):
         # Reap any dead workers
         to_remove = set()
         for pid, proc in self.workers.items():
@@ -276,7 +283,7 @@ class LocalExecutor(BaseExecutor):
                     exc.add_note("(This stacktrace is incorrect -- the 
exception came from a subprocess)")
                 raise exc
 
-            self.change_state(key, state)
+            # self.change_state(key, state)
 
     def end(self) -> None:
         """End the executor."""
@@ -306,3 +313,22 @@ class LocalExecutor(BaseExecutor):
 
     def terminate(self):
         """Terminate the executor is not doing anything."""
+
+
+class TooLocalTooAPIExecutor(LocalExecutor):
+    """TooLocalTooAPIExecutor implementation to run airflow commands."""
+
+    @overload
+    def queue_activity(self, *, kind: 
Literal[ExecuteActivityKind.TASK_INSTANCE], context: TaskInstance): ...
+    @overload
+    def queue_activity(self, *, kind: 
Literal[ExecuteActivityKind.DAG_CALLBACK], context: dict[str, Any]): ...
+
+    def queue_activity(self, *, kind: ExecuteActivityKind, context: 
TaskInstance | dict[str, Any]):
+        if kind != ExecuteActivityKind.TASK_INSTANCE:
+            return super().queue_activity(kind, context)
+
+        activity = ExecuteTaskActivity.from_ti(context)
+        self.activity_queue.put(activity)
+        with self._unread_messages:
+            self._unread_messages.value += 1
+        self._check_workers()
diff --git a/airflow/jobs/scheduler_job_runner.py 
b/airflow/jobs/scheduler_job_runner.py
index 96a09eb99bb..63d02060105 100644
--- a/airflow/jobs/scheduler_job_runner.py
+++ b/airflow/jobs/scheduler_job_runner.py
@@ -43,6 +43,7 @@ from airflow.callbacks.callback_requests import 
DagCallbackRequest, TaskCallback
 from airflow.callbacks.pipe_callback_sink import PipeCallbackSink
 from airflow.configuration import conf
 from airflow.exceptions import RemovedInAirflow3Warning, 
UnknownExecutorException
+from airflow.executors.base_executor import BaseExecutor, ExecuteActivityKind
 from airflow.executors.executor_loader import ExecutorLoader
 from airflow.jobs.base_job_runner import BaseJobRunner
 from airflow.jobs.job import Job, perform_heartbeat
@@ -90,7 +91,6 @@ if TYPE_CHECKING:
     from sqlalchemy.orm import Query, Session
 
     from airflow.dag_processing.manager import DagFileProcessorAgent
-    from airflow.executors.base_executor import BaseExecutor
     from airflow.executors.executor_utils import ExecutorName
     from airflow.models.taskinstance import TaskInstanceKey
     from airflow.utils.sqlalchemy import (
@@ -645,6 +645,12 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
             if ti.dag_run.state in State.finished_dr_states:
                 ti.set_state(None, session=session)
                 continue
+
+            # Has a real queue_activity implemented
+            if executor.queue_activity is not BaseExecutor.queue_activity:
+                
executor.queue_activity(kind=ExecuteActivityKind.TASK_INSTANCE, context=ti)
+                continue
+
             command = ti.command_as_list(
                 local=True,
             )
diff --git a/airflow/utils/helpers.py b/airflow/utils/helpers.py
index 1a6b7396e32..73598123d4d 100644
--- a/airflow/utils/helpers.py
+++ b/airflow/utils/helpers.py
@@ -170,7 +170,7 @@ def as_flattened_list(iterable: Iterable[Iterable[T]]) -> 
list[T]:
     return [e for i in iterable for e in i]
 
 
-def parse_template_string(template_string: str) -> tuple[str | None, 
jinja2.Template | None]:
+def parse_template_string(template_string: str) -> tuple[str, None] | 
tuple[None, jinja2.Template]:
     """Parse Jinja template string."""
     import jinja2
 
diff --git a/task_sdk/pyproject.toml b/task_sdk/pyproject.toml
index 5da673a79bf..e338ecb9e0c 100644
--- a/task_sdk/pyproject.toml
+++ b/task_sdk/pyproject.toml
@@ -25,6 +25,7 @@ dependencies = [
     "attrs>=24.2.0",
     "google-re2>=1.1.20240702",
     "httpx>=0.27.0",
+    "jinja2>=3.1.4",
     "methodtools>=0.4.7",
     "msgspec>=0.18.6",
     "psutil>=6.1.0",
diff --git a/task_sdk/src/airflow/sdk/execution_time/supervisor.py 
b/task_sdk/src/airflow/sdk/execution_time/supervisor.py
index 1d72f7be633..85eb86f43ff 100644
--- a/task_sdk/src/airflow/sdk/execution_time/supervisor.py
+++ b/task_sdk/src/airflow/sdk/execution_time/supervisor.py
@@ -31,8 +31,9 @@ import weakref
 from collections.abc import Generator
 from contextlib import suppress
 from datetime import datetime, timezone
+from pathlib import Path
 from socket import socket, socketpair
-from typing import TYPE_CHECKING, BinaryIO, Callable, ClassVar, Literal, 
NoReturn, cast, overload
+from typing import TYPE_CHECKING, Any, BinaryIO, Callable, ClassVar, Literal, 
NoReturn, cast, overload
 from uuid import UUID
 
 import attrs
@@ -52,9 +53,8 @@ from airflow.sdk.execution_time.comms import (
 )
 
 if TYPE_CHECKING:
-    from structlog.typing import FilteringBoundLogger
+    from structlog.typing import FilteringBoundLogger, WrappedLogger
 
-    from airflow.sdk.api.datamodels.activities import ExecuteTaskActivity
 
 __all__ = ["WatchedSubprocess", "supervise"]
 
@@ -140,8 +140,8 @@ def _reopen_std_io_handles(child_stdin, child_stdout, 
child_stderr):
         try:
             fd = handle.fileno()
             os.dup2(sock.fileno(), fd)
-            if close:
-                handle.close()
+            # We now have two open copies of the socket, we need to close one!
+            sock.close()
         except io.UnsupportedOperation:
             if "PYTEST_CURRENT_TEST" in os.environ:
                 # When we're running under pytest, the stdin is not a real 
filehandle with an fd, so we need
@@ -269,6 +269,7 @@ class WatchedSubprocess:
         ti: TaskInstance,
         client: Client,
         target: Callable[[], None] = _subprocess_main,
+        logger: FilteringBoundLogger | None = None,
     ) -> WatchedSubprocess:
         """Fork and start a new subprocess to execute the given task."""
         # Create socketpairs/"pipes" to connect to the stdin and out from the 
subprocess
@@ -284,6 +285,10 @@ class WatchedSubprocess:
         if pid == 0:
             # Parent ends of the sockets are closed by the OS as they are set 
as non-inheritable
 
+            del path
+            del client
+            del ti
+
             # Run the child entrypoint
             _fork_main(child_stdin, child_stdout, child_stderr, 
child_logs.fileno(), target)
 
@@ -308,17 +313,20 @@ class WatchedSubprocess:
             proc.kill(signal.SIGKILL)
             raise
 
-        proc._register_pipes(read_msgs, read_logs)
+        if logger is None:
+            logger = structlog.get_logger(logger_name="task").bind()
+
+        proc._register_pipes(logger, read_msgs, read_logs)
 
         # Close the remaining parent-end of the sockets we've passed to the 
child via fork. We still have the
         # other end of the pair open
-        proc._close_unused_sockets(child_stdout, child_stdin, child_comms, 
child_logs)
+        proc._close_unused_sockets(child_stdout, child_stdin, child_stderr, 
child_comms, child_logs)
 
         # Tell the task process what it needs to do!
         proc._send_startup_message(ti, path, child_comms)
         return proc
 
-    def _register_pipes(self, read_msgs, read_logs):
+    def _register_pipes(self, logger: FilteringBoundLogger, read_msgs, 
read_logs):
         """Register handlers for subprocess communication channels."""
         # self.selector is a way of registering a handler/callback to be 
called when the given IO channel has
         # activity to read on 
(https://www.man7.org/linux/man-pages/man2/select.2.html etc, but better
@@ -326,23 +334,24 @@ class WatchedSubprocess:
         # needing full async, to read and process output from each socket as 
it is received.
 
         # TODO: Use logging providers to handle the chunked upload for us
-        logger: FilteringBoundLogger = 
structlog.get_logger(logger_name="task").bind()
 
         self.selector.register(
-            self.stdout, selectors.EVENT_READ, 
self._create_socket_handler(logger, "stdout")
+            self.stdout, selectors.EVENT_READ, 
(self._create_socket_handler(logger, "stdout"), "stdout")
         )
         self.selector.register(
             self.stderr,
             selectors.EVENT_READ,
-            self._create_socket_handler(logger, "stderr", 
log_level=logging.ERROR),
+            (self._create_socket_handler(logger, "stderr", 
log_level=logging.ERROR), "stderr"),
         )
         self.selector.register(
             read_logs,
             selectors.EVENT_READ,
-            
make_buffered_socket_reader(process_log_messages_from_subprocess(logger)),
+            
(make_buffered_socket_reader(process_log_messages_from_subprocess(logger)), 
"read_logs"),
         )
         self.selector.register(
-            read_msgs, selectors.EVENT_READ, 
make_buffered_socket_reader(self.handle_requests(log))
+            read_msgs,
+            selectors.EVENT_READ,
+            (make_buffered_socket_reader(self.handle_requests(log)), 
"read_msgs"),
         )
 
     @staticmethod
@@ -358,7 +367,7 @@ class WatchedSubprocess:
 
     def _send_startup_message(self, ti: TaskInstance, path: str | 
os.PathLike[str], child_comms: socket):
         """Send startup message to the subprocess."""
-        msg = StartupDetails(
+        msg = StartupDetails.model_construct(
             ti=ti,
             file=str(path),
             requests_fd=child_comms.fileno(),
@@ -422,7 +431,7 @@ class WatchedSubprocess:
             )
             events = self.selector.select(timeout=max_wait_time)
             for key, _ in events:
-                socket_handler = key.data
+                (socket_handler, _) = key.data
                 need_more = socket_handler(key.fileobj)
 
                 if not need_more:
@@ -606,25 +615,86 @@ def forward_to_log(target_log: FilteringBoundLogger, 
level: int) -> Generator[No
             target_log.log(level, msg)
 
 
-def supervise(activity: ExecuteTaskActivity, server: str | None = None, 
dry_run: bool = False) -> int:
+def _init_log_file(local_relative_path: str) -> Path:
+    """
+    Create log directory and give it permissions that are configured.
+
+    See above _prepare_log_folder method for more detailed explanation.
+
+    :param ti: task instance object
+    :return: relative log path of the given task instance
+    """
+    from airflow.configuration import conf
+
+    new_file_permissions = int(
+        conf.get("logging", "file_task_handler_new_file_permissions", 
fallback="0o664"), 8
+    )
+    new_folder_permissions = int(
+        conf.get("logging", "file_task_handler_new_folder_permissions", 
fallback="0o775"), 8
+    )
+
+    base_log_folder = conf.get("logging", "base_log_folder")
+    full_path = Path(base_log_folder, local_relative_path)
+
+    directory = full_path.parent
+    for parent in reversed(directory.parents):
+        parent.mkdir(mode=new_folder_permissions, exist_ok=True)
+    directory.mkdir(mode=new_folder_permissions, exist_ok=True)
+
+    try:
+        full_path.touch(new_file_permissions)
+    except OSError as e:
+        log.warning("OSError while changing ownership of the log file. ", e)
+
+    return full_path
+
+
+def supervise(
+    *,
+    ti: Any,
+    dag_path: str | os.PathLike[str],
+    token: str,
+    server: str | None = None,
+    dry_run: bool = False,
+    log_filename_suffix: str | None = None,
+) -> int:
     """
     Run a single task execution to completion.
 
     Returns the exit code of the process
     """
     # One or the other
-    if (server == "") ^ dry_run:
+    if (not server) ^ dry_run:
         raise ValueError(f"Can only specify one of {server=} or {dry_run=}")
 
-    if not activity.path:
-        raise ValueError("path filed of activity missing")
+    if not dag_path:
+        raise ValueError("dag_path is required")
+
+    if (str_path := os.fspath(dag_path)).startswith("DAGS_FOLDER/"):
+        from airflow.settings import DAGS_FOLDER
+
+        dag_path = str_path.replace("DAGS_FOLDER/", DAGS_FOLDER + "/", 1)
 
     limits = httpx.Limits(max_keepalive_connections=1, max_connections=10)
-    client = Client(base_url=server or "", limits=limits, dry_run=dry_run, 
token=activity.token)
+    client = Client(base_url=server or "", limits=limits, dry_run=dry_run, 
token=token)
 
     start = time.monotonic()
 
-    process = WatchedSubprocess.start(activity.path, activity.ti, 
client=client)
+    logger: FilteringBoundLogger | None = None
+    if log_filename_suffix:
+        from airflow.sdk.log import logging_processors
+
+        log_file = _init_log_file(log_filename_suffix)
+
+        pretty_logs = False
+        if pretty_logs:
+            underlying_logger: WrappedLogger = 
structlog.WriteLogger(log_file.open("w", buffering=1))
+        else:
+            underlying_logger = structlog.BytesLogger(log_file.open("wb"))
+        processors = logging_processors(enable_pretty_log=pretty_logs)[0]
+        logger = structlog.wrap_logger(underlying_logger, 
processors=processors, logger_name="task").bind()
+
+    process = WatchedSubprocess.start(dag_path, ti, client=client, 
logger=logger)
 
     exit_code = process.wait()
     end = time.monotonic()
diff --git a/task_sdk/src/airflow/sdk/execution_time/task_runner.py 
b/task_sdk/src/airflow/sdk/execution_time/task_runner.py
index e00efe4597b..a0c12b79984 100644
--- a/task_sdk/src/airflow/sdk/execution_time/task_runner.py
+++ b/task_sdk/src/airflow/sdk/execution_time/task_runner.py
@@ -22,11 +22,11 @@ from __future__ import annotations
 import os
 import sys
 from io import FileIO
-from typing import TYPE_CHECKING, TextIO
+from typing import TYPE_CHECKING, Annotated, TextIO
 
 import attrs
 import structlog
-from pydantic import ConfigDict, TypeAdapter
+from pydantic import ConfigDict, SkipValidation, TypeAdapter
 
 from airflow.sdk.api.datamodels._generated import TaskInstance
 from airflow.sdk.definitions.baseoperator import BaseOperator
@@ -37,9 +37,9 @@ if TYPE_CHECKING:
 
 
 class RuntimeTaskInstance(TaskInstance):
-    model_config = ConfigDict(arbitrary_types_allowed=True)
+    model_config = ConfigDict(arbitrary_types_allowed=True, defer_build=True)
 
-    task: BaseOperator
+    task: Annotated[BaseOperator, SkipValidation]
 
 
 def parse(what: StartupDetails) -> RuntimeTaskInstance:
@@ -118,6 +118,10 @@ def startup() -> tuple[RuntimeTaskInstance, Logger]:
     msg = SUPERVISOR_COMMS.get_message()
 
     if isinstance(msg, StartupDetails):
+        from setproctitle import setproctitle
+
+        setproctitle(f"airflow worker -- {msg.ti.id}")
+
         log = structlog.get_logger(logger_name="task")
         # TODO: set the "magic loop" context vars for parsing
         ti = parse(msg)
diff --git a/task_sdk/src/airflow/sdk/log.py b/task_sdk/src/airflow/sdk/log.py
index f8e06eda4a6..b120f32a594 100644
--- a/task_sdk/src/airflow/sdk/log.py
+++ b/task_sdk/src/airflow/sdk/log.py
@@ -285,9 +285,11 @@ def configure_logging(
         std_lib_formatter.append(named["json"])
 
     global _warnings_showwarning
-    _warnings_showwarning = warnings.showwarning
-    # Capture warnings and show them via structlog
-    warnings.showwarning = _showwarning
+
+    if _warnings_showwarning is None:
+        _warnings_showwarning = warnings.showwarning
+        # Capture warnings and show them via structlog
+        warnings.showwarning = _showwarning
 
     logging.config.dictConfig(
         {
@@ -327,11 +329,13 @@ def configure_logging(
                     "propagate": True,
                 },
                 # Some modules we _never_ want at debug level
-                "asyncio": {"level": "INFO"},
                 "alembic": {"level": "INFO"},
+                "asyncio": {"level": "INFO"},
+                "cron_descriptor.GetText": {"level": "INFO"},
                 "httpcore": {"level": "INFO"},
                 "httpx": {"level": "WARN"},
                 "psycopg.pq": {"level": "INFO"},
+                # These ones are too chatty even at info
                 "sqlalchemy.engine": {"level": "WARN"},
             },
         }
@@ -344,17 +348,17 @@ def reset_logging():
     configure_logging.cache_clear()
 
 
-_warnings_showwarning = None
+_warnings_showwarning: Any = None
 
 
 def _showwarning(
-    message: str | Warning,
+    message: Warning | str,
     category: type[Warning],
     filename: str,
     lineno: int,
     file: TextIO | None = None,
     line: str | None = None,
-):
+) -> Any:
     """
     Redirects warnings to structlog so they appear in task logs etc.
 

Reply via email to