This is an automated email from the ASF dual-hosted git repository.

ash pushed a commit to branch execution-time-code-in-task-sdk
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit fa4d908515ac90f7a65bb00acc2606413eeceb3a
Author: Ash Berlin-Taylor <[email protected]>
AuthorDate: Fri Nov 8 15:08:23 2024 +0000

    WIP
---
 airflow/utils/net.py                               |   4 +-
 task_sdk/pyproject.toml                            |  12 +-
 task_sdk/src/airflow/sdk/api/client.py             | 143 +++++++++++++++++++++
 .../src/airflow/sdk/api/datamodels/activities.py   |  31 +++++
 task_sdk/src/airflow/sdk/api/datamodels/dagrun.py  |  25 ++++
 task_sdk/src/airflow/sdk/api/datamodels/ti.py      |  33 +++++
 task_sdk/src/airflow/sdk/execution_time/comms.py   |  26 +---
 .../src/airflow/sdk/execution_time/supervisor.py   | 108 +++++++++++++---
 .../src/airflow/sdk/execution_time/task_runner.py  |   6 -
 task_sdk/src/airflow/sdk/log.py                    |   7 +-
 task_sdk/src/airflow/sdk/types.py                  |   2 +-
 task_sdk/tests/defintions/test_baseoperator.py     |  19 +++
 task_sdk/tests/execution_time/test_supervisor.py   |   8 +-
 task_sdk/tests/execution_time/test_task_runner.py  |  17 ++-
 14 files changed, 374 insertions(+), 67 deletions(-)

diff --git a/airflow/utils/net.py b/airflow/utils/net.py
index 992aee67e80..9fc79b3842c 100644
--- a/airflow/utils/net.py
+++ b/airflow/utils/net.py
@@ -20,8 +20,6 @@ from __future__ import annotations
 import socket
 from functools import lru_cache
 
-from airflow.configuration import conf
-
 
 # patched version of socket.getfqdn() - see 
https://github.com/python/cpython/issues/49254
 @lru_cache(maxsize=None)
@@ -53,4 +51,6 @@ def get_host_ip_address():
 
 def get_hostname():
     """Fetch the hostname using the callable from config or use 
`airflow.utils.net.getfqdn` as a fallback."""
+    from airflow.configuration import conf
+
     return conf.getimport("core", "hostname_callable", 
fallback="airflow.utils.net.getfqdn")()
diff --git a/task_sdk/pyproject.toml b/task_sdk/pyproject.toml
index 5dd4abb3f30..f2aa138c14a 100644
--- a/task_sdk/pyproject.toml
+++ b/task_sdk/pyproject.toml
@@ -24,6 +24,7 @@ requires-python = ">=3.9, <3.13"
 dependencies = [
     "attrs>=24.2.0",
     "google-re2>=1.1.20240702",
+    "httpx>=0.27.0",
     "methodtools>=0.4.7",
     "msgspec>=0.18.6",
     "psutil>=6.1.0",
@@ -47,10 +48,17 @@ namespace-packages = ["src/airflow"]
 # Ignore Doc rules et al for anything outside of tests
 "!src/*" = ["D", "TID253", "S101", "TRY002"]
 
+# Only have pytest rules in tests - 
https://github.com/astral-sh/ruff/issues/14205
+"!tests/*" = ["PT"]
 
 "src/airflow/sdk/__init__.py" = ["TCH004"]
-# This is not part of the public API, so disable some of the doc requirements
-"src/airflow/sdk/execution_time/*" = ["D101"]
+
+# msgspec needs types for annotations to be defined, even with future
+# annotations, so disable the "type check only import" for these files
+"src/airflow/sdk/api/datamodels/*.py" = ["TCH001"]
+
+# Only the public API should _require_ docstrings on classes
+"!src/airflow/sdk/definitions/*" = ["D101"]
 
 # Generated file, be less strict
 "src/airflow/sdk/*/_generated.py" = ["D"]
diff --git a/task_sdk/src/airflow/sdk/api/client.py 
b/task_sdk/src/airflow/sdk/api/client.py
new file mode 100644
index 00000000000..00868d050a5
--- /dev/null
+++ b/task_sdk/src/airflow/sdk/api/client.py
@@ -0,0 +1,143 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+import uuid
+from typing import TYPE_CHECKING, Any
+
+import httpx
+import methodtools
+import msgspec
+import structlog
+from uuid6 import uuid7
+
+from airflow.sdk.api.datamodels._generated import (
+    State1 as TerminalState,
+    TaskInstanceState,
+    TIEnterRunningPayload,
+    TITerminalStatePayload,
+)
+from airflow.utils.net import get_hostname
+from airflow.utils.platform import getuser
+
+if TYPE_CHECKING:
+    from datetime import datetime
+
+log = structlog.get_logger(logger_name=__name__)
+
+__all__ = [
+    "Client",
+    "TaskInstanceOperations",
+]
+
+
+def raise_on_4xx_5xx(response: httpx.Response):
+    return response.raise_for_status()
+
+
+# Py 3.11+ version
+def raise_on_4xx_5xx_with_note(response: httpx.Response):
+    try:
+        return response.raise_for_status()
+    except httpx.HTTPStatusError as e:
+        if TYPE_CHECKING:
+            assert hasattr(e, "add_note")
+        e.add_note(
+            f"Correlation-id={response.headers.get('correlation-id', None) or 
response.request.headers.get('correlation-id', 'no-correlation-id')}"
+        )
+        raise
+
+
+if hasattr(BaseException, "add_note"):
+    # Py 3.11+
+    raise_on_4xx_5xx = raise_on_4xx_5xx_with_note
+
+
+def add_correlation_id(request: httpx.Request):
+    request.headers["correlation-id"] = str(uuid7())
+
+
+class TaskInstanceOperations:
+    __slots__ = ("client",)
+
+    def __init__(self, client: Client):
+        self.client = client
+
+    def start(self, id: uuid.UUID, pid: int, when: datetime):
+        """Tell the API server that this TI has started running."""
+        body = TIEnterRunningPayload(pid=pid, hostname=get_hostname(), 
unixname=getuser(), start_date=when)
+
+        self.client.patch(f"task_instance/{id}/state", 
content=self.client.encoder.encode(body))
+
+    def finish(self, id: uuid.UUID, state: TaskInstanceState, when: datetime):
+        """Tell the API server that this TI has reached a terminal state."""
+        body = TITerminalStatePayload(end_date=when, 
state=TerminalState(state))
+
+        self.client.patch(f"task_instance/{id}/state", 
content=self.client.encoder.encode(body))
+
+    def heartbeat(self, id: uuid.UUID):
+        self.client.put(f"task_instance/{id}/heartbeat")
+
+
+class BearerAuth(httpx.Auth):
+    def __init__(self, token: str):
+        self.token: str = token
+
+    def auth_flow(self, request: httpx.Request):
+        if self.token:
+            request.headers["Authorization"] = "Bearer " + self.token
+        yield request
+
+
+def noop_handler(request: httpx.Request) -> httpx.Response:
+    log.debug("Dry-run request", method=request.method, path=request.url.path)
+    return httpx.Response(200, json={"text": "Hello, world!"})
+
+
+class Client(httpx.Client):
+    encoder: msgspec.json.Encoder
+
+    def __init__(self, *, base_url: str | None, dry_run: bool = False, token: 
str, **kwargs: Any):
+        if (not base_url) ^ dry_run:
+            raise ValueError(f"Can only specify one of {base_url=} or 
{dry_run=}")
+        auth = BearerAuth(token)
+
+        self.encoder = msgspec.json.Encoder()
+        if dry_run:
+            # If dry run is requested, install a no op handler so that simple 
tasks can "heartbeat" using a
+            # real client, but just don't make any HTTP requests
+            kwargs["transport"] = httpx.MockTransport(noop_handler)
+            kwargs["base_url"] = "dry-run://server"
+        else:
+            kwargs["base_url"] = base_url
+        super().__init__(
+            auth=auth,
+            headers={"airflow-api-version": "2024-07-30"},
+            event_hooks={"response": [raise_on_4xx_5xx], "request": 
[add_correlation_id]},
+            **kwargs,
+        )
+
+    # We "group" or "namespace" operations by what they operate on, rather 
than a flat namespace with all
+    # methods on one object prefixed with the object type 
(`.task_instances.update` rather than
+    # `task_instance_update` etc.)
+
+    @methodtools.lru_cache()  # type: ignore[misc]
+    @property
+    def task_instances(self) -> TaskInstanceOperations:
+        """Operations related to TaskInstances."""
+        return TaskInstanceOperations(self)
diff --git a/task_sdk/src/airflow/sdk/api/datamodels/activities.py 
b/task_sdk/src/airflow/sdk/api/datamodels/activities.py
new file mode 100644
index 00000000000..17b15d7c017
--- /dev/null
+++ b/task_sdk/src/airflow/sdk/api/datamodels/activities.py
@@ -0,0 +1,31 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+import os
+
+import msgspec
+
+from airflow.sdk.api.datamodels.ti import TaskInstance
+
+
+class ExecuteTaskActivity(msgspec.Struct, tag="ExecuteTask", tag_field="kind"):
+    ti: TaskInstance
+    path: os.PathLike[str]
+    token: str
+    """The identity token for this workload"""
diff --git a/task_sdk/src/airflow/sdk/api/datamodels/dagrun.py 
b/task_sdk/src/airflow/sdk/api/datamodels/dagrun.py
new file mode 100644
index 00000000000..224006774df
--- /dev/null
+++ b/task_sdk/src/airflow/sdk/api/datamodels/dagrun.py
@@ -0,0 +1,25 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+import msgspec
+
+
+class DagRun(msgspec.Struct, omit_defaults=True):
+    run_id: str
+    dag_id: str
diff --git a/task_sdk/src/airflow/sdk/api/datamodels/ti.py 
b/task_sdk/src/airflow/sdk/api/datamodels/ti.py
new file mode 100644
index 00000000000..6c83b5b7081
--- /dev/null
+++ b/task_sdk/src/airflow/sdk/api/datamodels/ti.py
@@ -0,0 +1,33 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+import uuid
+
+import msgspec
+
+from airflow.sdk.api.datamodels.dagrun import DagRun
+
+
+class TaskInstance(msgspec.Struct, omit_defaults=True):
+    id: uuid.UUID
+
+    task_id: str
+    run: DagRun
+    try_number: int
+    map_index: int | None = None
diff --git a/task_sdk/src/airflow/sdk/execution_time/comms.py 
b/task_sdk/src/airflow/sdk/execution_time/comms.py
index 7bc6d12314e..41108638e2c 100644
--- a/task_sdk/src/airflow/sdk/execution_time/comms.py
+++ b/task_sdk/src/airflow/sdk/execution_time/comms.py
@@ -48,31 +48,7 @@ from typing import Any, Union
 import msgspec
 
 from airflow.sdk.api.datamodels._generated import TaskInstanceState  # noqa: 
TCH001
-
-
-class ExecuteTaskActivity(msgspec.Struct):
-    """Information needed to start a task on a worker."""
-
-    ti: TaskInstance
-    token: str
-    path: str | None = None
-
-
-# Temporary: These will next two live in a generated client soon.
-class DagRun(msgspec.Struct):
-    dag_id: str
-    run_id: str
-    data_interval_end: str | None = None
-    data_interval_start: str | None = None
-
-
-class TaskInstance(msgspec.Struct):
-    id: str
-    task_id: str
-    try_number: int
-    map_index: int | None = None
-    is_eligible_to_retry: bool = False
-    run: DagRun | None = None
+from airflow.sdk.api.datamodels.ti import TaskInstance  # noqa: TCH001
 
 
 class StartupDetails(msgspec.Struct, omit_defaults=True, tag=True):
diff --git a/task_sdk/src/airflow/sdk/execution_time/supervisor.py 
b/task_sdk/src/airflow/sdk/execution_time/supervisor.py
index c5b1bb6b969..64ebcf9f59c 100644
--- a/task_sdk/src/airflow/sdk/execution_time/supervisor.py
+++ b/task_sdk/src/airflow/sdk/execution_time/supervisor.py
@@ -30,21 +30,26 @@ import time
 import weakref
 from collections.abc import Generator
 from contextlib import suppress
-from datetime import datetime
+from datetime import datetime, timezone
 from socket import socket, socketpair
-from typing import TYPE_CHECKING, Any, BinaryIO, Callable, ClassVar, Literal, 
NoReturn, cast, overload
+from typing import TYPE_CHECKING, BinaryIO, Callable, ClassVar, Literal, 
NoReturn, cast, overload
+from uuid import UUID
 
 import attrs
+import httpx
 import msgspec
 import psutil
 import structlog
 
+from airflow.sdk.api.client import Client
+from airflow.sdk.api.datamodels._generated import TaskInstanceState
 from airflow.sdk.execution_time.comms import StartupDetails, ToSupervisor
 
 if TYPE_CHECKING:
     from structlog.typing import FilteringBoundLogger
 
-    from airflow.sdk.execution_time.comms import ExecuteTaskActivity
+    from airflow.sdk.api.datamodels.activities import ExecuteTaskActivity
+    from airflow.sdk.api.datamodels.ti import TaskInstance
 
 
 __all__ = ["WatchedSubprocess"]
@@ -98,6 +103,8 @@ def _fork_main(
     log_fd: int,
     target: Callable[[], None],
 ) -> NoReturn:
+    # TODO: Make this process a session leader
+
     # Uninstall the rich etc. exception handler
     sys.excepthook = sys.__excepthook__
     signal.signal(signal.SIGINT, signal.SIG_DFL)
@@ -186,16 +193,21 @@ def _fork_main(
 
 @attrs.define()
 class WatchedSubprocess:
+    ti_id: UUID
     pid: int
 
     stdin: BinaryIO
     stdout: socket
     stderr: socket
 
+    client: Client
+
     _process: psutil.Process
     _exit_code: int | None = None
     _terminal_state: str | None = None
 
+    _last_heartbeat: float = 0
+
     selector: selectors.BaseSelector = 
attrs.field(factory=selectors.DefaultSelector)
 
     procs: ClassVar[weakref.WeakValueDictionary[int, WatchedSubprocess]] = 
weakref.WeakValueDictionary()
@@ -205,7 +217,11 @@ class WatchedSubprocess:
 
     @classmethod
     def start(
-        cls, path: str | os.PathLike[str], ti: Any, target: Callable[[], None] 
= _subprocess_main
+        cls,
+        path: str | os.PathLike[str],
+        ti: TaskInstance,
+        client: Client,
+        target: Callable[[], None] = _subprocess_main,
     ) -> WatchedSubprocess:
         """Fork and start a new subprocess to execute the given task."""
         # Create socketpairs/"pipes" to connect to the stdin and out from the 
subprocess
@@ -225,13 +241,26 @@ class WatchedSubprocess:
             _fork_main(child_stdin, child_stdout, child_stderr, 
child_logs.fileno(), target)
 
         proc = cls(
+            ti_id=ti.id,
             pid=pid,
             stdin=feed_stdin,
             stdout=read_stdout,
             stderr=read_stderr,
             process=psutil.Process(pid),
+            client=client,
         )
 
+        # We've forked, but the task won't start until we send it the 
StartupDetails message. But before we do
+        # that, we need to tell the server it's started (so it has the chance 
to tell us "no, stop!" for any
+        # reason)
+        try:
+            client.task_instances.start(ti.id, pid, 
datetime.now(tz=timezone.utc))
+            proc._last_heartbeat = time.monotonic()
+        except Exception:
+            # On any error kill that subprocess!
+            proc.kill(signal.SIGKILL)
+            raise
+
         # TODO: Use logging providers to handle the chunked upload for us
         task_logger: FilteringBoundLogger = 
structlog.get_logger(logger_name="task").bind()
 
@@ -270,7 +299,6 @@ class WatchedSubprocess:
         log.debug("Sending", msg=msg)
         feed_stdin.write(msgspec.json.encode(msg))
         feed_stdin.write(b"\n")
-        # feed_stdin.flush()
 
         return proc
 
@@ -285,24 +313,56 @@ class WatchedSubprocess:
         if self._exit_code is not None:
             return self._exit_code
 
+        # TODO: Pull this from config
+        heartbeat_rate = 30
+
+        # Until we have a selector for the process, don't poll for more than 
10s, just in case it exits but
+        # doesn't produce any output
+        max_poll_interval = 10
+
         try:
             while self._exit_code is None or len(self.selector.get_map()):
-                events = self.selector.select(timeout=10.0)
+                # Monitor the task to see if it's done. Wait in a syscall 
(`select`) for as long as possible
+                # so we notice the subprocess finishing as quick as we can.
+                max_wait_time = max(
+                    0,  # Make sure this value is never negative,
+                    min(
+                        # Ensure we heartbeat _at most_ 75% of the way through 
the zombie threshold time
+                        heartbeat_rate - (time.monotonic() - 
self._last_heartbeat) * 0.75,
+                        max_poll_interval,
+                    ),
+                )
+                events = self.selector.select(timeout=max_wait_time)
                 for key, _ in events:
                     callback = key.data
-                    open = callback(key.fileobj)
+                    need_more = callback(key.fileobj)
 
-                    if not open:
-                        log.debug("Remote end closed, closing", 
fileobj=key.fileobj)
+                    if not need_more:
                         self.selector.unregister(key.fileobj)
                         key.fileobj.close()  # type: ignore[union-attr]
-                # TODO: Send heartbeat here
+
+                if self._exit_code is None:
+                    try:
+                        self._exit_code = self._process.wait(timeout=0)
+                        log.debug("Task process exited", 
exit_code=self._exit_code)
+                    except psutil.TimeoutExpired:
+                        pass
+
                 try:
-                    self._exit_code = self._process.wait(timeout=0.1)
-                except psutil.TimeoutExpired:
+                    # TODO: Currently this will heartbeat _every_ time we read 
any log message. That is way
+                    # too frequent!
+                    self.client.task_instances.heartbeat(self.ti_id)
+                    self._last_heartbeat = time.monotonic()
+                except Exception:
+                    log.warning("Couldn't heartbeat", exc_info=True)
+                    # TODO: If we couldn't heartbeat for X times the interval, 
kill ourselves
                     pass
         finally:
             self.selector.close()
+
+        self.client.task_instances.finish(
+            id=self.ti_id, state=self.final_state, 
when=datetime.now(tz=timezone.utc)
+        )
         return self._exit_code
 
     @property
@@ -316,10 +376,9 @@ class WatchedSubprocess:
 
         Not valid before the process has finished.
         """
-        # TODO: state enums
         if self._exit_code == 0:
-            return self._terminal_state if self._terminal_state is not None 
else "success"
-        return "failed"
+            return self._terminal_state or TaskInstanceState.SUCCESS
+        return TaskInstanceState.FAILED
 
     def __rich_repr__(self):
         yield "pid", self.pid
@@ -424,10 +483,14 @@ def process_log_messages_from_subprocess(log: 
FilteringBoundLogger) -> Generator
             continue
 
         if ts := event.get("timestamp"):
-            # We use msgspec to decode the json as it does it orders of 
magnitude quicker than
-            # datetime.strptime does
-            # TODO: don't hard-code the time format here
-            event["timestamp"] = msgspec.json.decode(f'"{ts}"', type=datetime)
+            # We use msgspec to decode the timestamp as it does it orders of 
magnitude quicker than
+            # datetime.strptime does
+            #
+            # We remove the timezone info here, as the json encoding has 
`+00:00`, and since the log came
+            # from a subprocess we know that the timezone of the log message 
is the same, so having some
+            # messages include tz (from subprocess) but others not (ones from 
supervisor process) is
+            # confusing.
+            event["timestamp"] = msgspec.json.decode(f'"{ts}"', 
type=datetime).replace(tzinfo=None)
 
         if exc := event.pop("exception", None):
             # TODO: convert the dict back to a pretty stack trace
@@ -464,11 +527,14 @@ def supervise(activity: ExecuteTaskActivity, server: str 
| None = None, dry_run:
     if not activity.path:
         raise ValueError("path field of activity missing")
 
+    limits = httpx.Limits(max_keepalive_connections=1, max_connections=10)
+    client = Client(base_url=server or "", limits=limits, dry_run=dry_run, 
token=activity.token)
+
     start = time.monotonic()
 
-    process = WatchedSubprocess.start(activity.path, activity.ti)
+    process = WatchedSubprocess.start(activity.path, activity.ti, 
client=client)
 
     exit_code = process.wait()
     end = time.monotonic()
-    log.debug("Process exited", exit_code=exit_code, duration=end - start)
+    log.debug("Task finished", exit_code=exit_code, duration=end - start)
     return exit_code
diff --git a/task_sdk/src/airflow/sdk/execution_time/task_runner.py 
b/task_sdk/src/airflow/sdk/execution_time/task_runner.py
index 11601a98c7a..e1fcc18d98e 100644
--- a/task_sdk/src/airflow/sdk/execution_time/task_runner.py
+++ b/task_sdk/src/airflow/sdk/execution_time/task_runner.py
@@ -30,7 +30,6 @@ import structlog
 
 from airflow.sdk import BaseOperator
 from airflow.sdk.execution_time.comms import StartupDetails, TaskInstance, 
ToSupervisor, ToTask
-from airflow.sdk.log import configure_logging
 
 if TYPE_CHECKING:
     from structlog.typing import FilteringBoundLogger as Logger
@@ -175,11 +174,6 @@ def finalize(log: Logger): ...
 
 
 def main():
-    # Configure logs to be JSON, so that we can pass it to the parent process
-    # Don't cache this log though!
-
-    configure_logging(enable_pretty_log=False)
-
     # TODO: add an exception here, it causes an oof of a stack trace!
 
     global SUPERVISOR_COMMS
diff --git a/task_sdk/src/airflow/sdk/log.py b/task_sdk/src/airflow/sdk/log.py
index fc609011626..624e3f16768 100644
--- a/task_sdk/src/airflow/sdk/log.py
+++ b/task_sdk/src/airflow/sdk/log.py
@@ -161,7 +161,12 @@ def logging_processors(
             indent_guides=False,
             suppress=[asyncio, httpcore, httpx, contextlib, click, typer],
         )
-        console = 
structlog.dev.ConsoleRenderer(exception_formatter=rich_exc_formatter)
+        my_styles = structlog.dev.ConsoleRenderer.get_default_level_styles()
+        my_styles["debug"] = structlog.dev.CYAN
+
+        console = structlog.dev.ConsoleRenderer(
+            exception_formatter=rich_exc_formatter, level_styles=my_styles
+        )
         processors.append(console)
         return processors, {
             "timestamper": timestamper,
diff --git a/task_sdk/src/airflow/sdk/types.py 
b/task_sdk/src/airflow/sdk/types.py
index 232d08e27f9..ffde2170b17 100644
--- a/task_sdk/src/airflow/sdk/types.py
+++ b/task_sdk/src/airflow/sdk/types.py
@@ -57,7 +57,7 @@ if TYPE_CHECKING:
     Logger = logging.Logger
 else:
 
-    class Logger: ...  # noqa: D101
+    class Logger: ...
 
 
 def validate_instance_args(instance: DAGNode, expected_arg_types: dict[str, 
Any]) -> None:
diff --git a/task_sdk/tests/defintions/test_baseoperator.py 
b/task_sdk/tests/defintions/test_baseoperator.py
index 427d1ee0e3e..19035319cdc 100644
--- a/task_sdk/tests/defintions/test_baseoperator.py
+++ b/task_sdk/tests/defintions/test_baseoperator.py
@@ -29,6 +29,25 @@ from airflow.task.priority_strategy import 
_DownstreamPriorityWeightStrategy, _U
 DEFAULT_DATE = datetime(2016, 1, 1, tzinfo=timezone.utc)
 
 
[email protected](autouse=True, scope="module")
+def _disable_ol_plugin():
+    # The OpenLineage plugin imports setproctitle, and that now causes (C) 
level thread calls, which on Py
+    # 3.12+ issues a warning when os.fork happens. So for this plugin we 
disable it
+
+    # And we load plugins when setting the priority_weight field
+    import airflow.plugins_manager
+
+    old = airflow.plugins_manager.plugins
+
+    assert old is None, "Plugins already loaded, too late to stop them being 
loaded!"
+
+    airflow.plugins_manager.plugins = []
+
+    yield
+
+    airflow.plugins_manager.plugins = None
+
+
 # Essentially similar to airflow.models.baseoperator.BaseOperator
 class FakeOperator(metaclass=BaseOperatorMeta):
     def __init__(self, test_param, params=None, default_args=None):
diff --git a/task_sdk/tests/execution_time/test_supervisor.py 
b/task_sdk/tests/execution_time/test_supervisor.py
index 8770cd9827f..b86f5f1a593 100644
--- a/task_sdk/tests/execution_time/test_supervisor.py
+++ b/task_sdk/tests/execution_time/test_supervisor.py
@@ -39,7 +39,6 @@ def subprocess_main():
     import logging
 
     logging.getLogger("airflow.foobar").error("An error message")
-    ...
 
 
 @pytest.fixture
@@ -86,5 +85,10 @@ class TestWatchedSubprocess:
                 "logger": "task",
                 "timestamp": "2024-11-07T12:34:56.078901Z",
             },
-            {"event": "An error message", "level": "error", "logger": 
"airflow.foobar", "timestamp": instant},
+            {
+                "event": "An error message",
+                "level": "error",
+                "logger": "airflow.foobar",
+                "timestamp": instant.replace(tzinfo=None),
+            },
         ]
diff --git a/task_sdk/tests/execution_time/test_task_runner.py 
b/task_sdk/tests/execution_time/test_task_runner.py
index e1fd338b5e5..fec2f339dcb 100644
--- a/task_sdk/tests/execution_time/test_task_runner.py
+++ b/task_sdk/tests/execution_time/test_task_runner.py
@@ -17,14 +17,12 @@
 
 from __future__ import annotations
 
+import uuid
 from socket import socketpair
-from typing import TYPE_CHECKING
 
+from airflow.sdk.execution_time.comms import StartupDetails
 from airflow.sdk.execution_time.task_runner import CommsDecoder
 
-if TYPE_CHECKING:
-    from airflow.sdk.execution_time.comms import StartupDetails
-
 
 class TestCommsDecoder:
     """Test the communication between the subprocess and the "supervisor"."""
@@ -33,15 +31,20 @@ class TestCommsDecoder:
         r, w = socketpair()
 
         w.makefile("wb").write(
-            b'{"type":"StartupDetails", "ti": {"id": "a", "task_id": "b", 
"try_number": 1}, '
+            b'{"type":"StartupDetails", "ti": {'
+            b'"id": "4d828a62-a417-4936-a7a6-2b3fabacecab", "task_id": "a", 
"try_number": 1, "run":'
+            b'{"run_id": "b", "dag_id": "c"} }, '
             b'"file": "/dev/null", "requests_fd": 4'
             b"}\n"
         )
 
         decoder = CommsDecoder(input=r.makefile("r"))
 
-        msg: StartupDetails = decoder.get_message()
-        assert msg.ti.task_id == "b"
+        msg = decoder.get_message()
+        assert isinstance(msg, StartupDetails)
+        assert msg.ti.id == uuid.UUID("4d828a62-a417-4936-a7a6-2b3fabacecab")
+        assert msg.ti.task_id == "a"
+        assert msg.ti.run.dag_id == "c"
         assert msg.file == "/dev/null"
 
         # Since this was a StartupDetails message, the decoder should open the 
other socket

Reply via email to