This is an automated email from the ASF dual-hosted git repository. ash pushed a commit to branch execution-time-code-in-task-sdk in repository https://gitbox.apache.org/repos/asf/airflow.git
commit fa4d908515ac90f7a65bb00acc2606413eeceb3a Author: Ash Berlin-Taylor <[email protected]> AuthorDate: Fri Nov 8 15:08:23 2024 +0000 WIP --- airflow/utils/net.py | 4 +- task_sdk/pyproject.toml | 12 +- task_sdk/src/airflow/sdk/api/client.py | 143 +++++++++++++++++++++ .../src/airflow/sdk/api/datamodels/activities.py | 31 +++++ task_sdk/src/airflow/sdk/api/datamodels/dagrun.py | 25 ++++ task_sdk/src/airflow/sdk/api/datamodels/ti.py | 33 +++++ task_sdk/src/airflow/sdk/execution_time/comms.py | 26 +--- .../src/airflow/sdk/execution_time/supervisor.py | 108 +++++++++++++--- .../src/airflow/sdk/execution_time/task_runner.py | 6 - task_sdk/src/airflow/sdk/log.py | 7 +- task_sdk/src/airflow/sdk/types.py | 2 +- task_sdk/tests/defintions/test_baseoperator.py | 19 +++ task_sdk/tests/execution_time/test_supervisor.py | 8 +- task_sdk/tests/execution_time/test_task_runner.py | 17 ++- 14 files changed, 374 insertions(+), 67 deletions(-) diff --git a/airflow/utils/net.py b/airflow/utils/net.py index 992aee67e80..9fc79b3842c 100644 --- a/airflow/utils/net.py +++ b/airflow/utils/net.py @@ -20,8 +20,6 @@ from __future__ import annotations import socket from functools import lru_cache -from airflow.configuration import conf - # patched version of socket.getfqdn() - see https://github.com/python/cpython/issues/49254 @lru_cache(maxsize=None) @@ -53,4 +51,6 @@ def get_host_ip_address(): def get_hostname(): """Fetch the hostname using the callable from config or use `airflow.utils.net.getfqdn` as a fallback.""" + from airflow.configuration import conf + return conf.getimport("core", "hostname_callable", fallback="airflow.utils.net.getfqdn")() diff --git a/task_sdk/pyproject.toml b/task_sdk/pyproject.toml index 5dd4abb3f30..f2aa138c14a 100644 --- a/task_sdk/pyproject.toml +++ b/task_sdk/pyproject.toml @@ -24,6 +24,7 @@ requires-python = ">=3.9, <3.13" dependencies = [ "attrs>=24.2.0", "google-re2>=1.1.20240702", + "httpx>=0.27.0", "methodtools>=0.4.7", "msgspec>=0.18.6", 
"psutil>=6.1.0", @@ -47,10 +48,17 @@ namespace-packages = ["src/airflow"] # Ignore Doc rules et al for anything outside of tests "!src/*" = ["D", "TID253", "S101", "TRY002"] +# Only have pytest rules in tests - https://github.com/astral-sh/ruff/issues/14205 +"!tests/*" = ["PT"] "src/airflow/sdk/__init__.py" = ["TCH004"] -# This is not part of the public API, so disable some of the doc requirements -"src/airflow/sdk/execution_time/*" = ["D101"] + +# msgspec needs types for annotations to be defined, even with future +# annotations, so disable the "type check only import" for these files +"src/airflow/sdk/api/datamodels/*.py" = ["TCH001"] + +# Only the public API should _require_ docstrings on classes +"!src/airflow/sdk/definitions/*" = ["D101"] # Generated file, be less strict "src/airflow/sdk/*/_generated.py" = ["D"] diff --git a/task_sdk/src/airflow/sdk/api/client.py b/task_sdk/src/airflow/sdk/api/client.py new file mode 100644 index 00000000000..00868d050a5 --- /dev/null +++ b/task_sdk/src/airflow/sdk/api/client.py @@ -0,0 +1,143 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +from __future__ import annotations + +import uuid +from typing import TYPE_CHECKING, Any + +import httpx +import methodtools +import msgspec +import structlog +from uuid6 import uuid7 + +from airflow.sdk.api.datamodels._generated import ( + State1 as TerminalState, + TaskInstanceState, + TIEnterRunningPayload, + TITerminalStatePayload, +) +from airflow.utils.net import get_hostname +from airflow.utils.platform import getuser + +if TYPE_CHECKING: + from datetime import datetime + +log = structlog.get_logger(logger_name=__name__) + +__all__ = [ + "Client", + "TaskInstanceOperations", +] + + +def raise_on_4xx_5xx(response: httpx.Response): + return response.raise_for_status() + + +# Py 3.11+ version +def raise_on_4xx_5xx_with_note(response: httpx.Response): + try: + return response.raise_for_status() + except httpx.HTTPStatusError as e: + if TYPE_CHECKING: + assert hasattr(e, "add_note") + e.add_note( + f"Correlation-id={response.headers.get('correlation-id', None) or response.request.headers.get('correlation-id', 'no-correlction-id')}" + ) + raise + + +if hasattr(BaseException, "add_note"): + # Py 3.11+ + raise_on_4xx_5xx = raise_on_4xx_5xx_with_note + + +def add_correlation_id(request: httpx.Request): + request.headers["correlation-id"] = str(uuid7()) + + +class TaskInstanceOperations: + __slots__ = ("client",) + + def __init__(self, client: Client): + self.client = client + + def start(self, id: uuid.UUID, pid: int, when: datetime): + """Tell the API server that this TI has started running.""" + body = TIEnterRunningPayload(pid=pid, hostname=get_hostname(), unixname=getuser(), start_date=when) + + self.client.patch(f"task_instance/{id}/state", content=self.client.encoder.encode(body)) + + def finish(self, id: uuid.UUID, state: TaskInstanceState, when: datetime): + """Tell the API server that this TI has reached a terminal state.""" + body = TITerminalStatePayload(end_date=when, state=TerminalState(state)) + + self.client.patch(f"task_instance/{id}/state", 
content=self.client.encoder.encode(body)) + + def heartbeat(self, id: uuid.UUID): + self.client.put(f"task_instance/{id}/heartbeat") + + +class BearerAuth(httpx.Auth): + def __init__(self, token: str): + self.token: str = token + + def auth_flow(self, request: httpx.Request): + if self.token: + request.headers["Authorization"] = "Bearer " + self.token + yield request + + +def noop_handler(request: httpx.Request) -> httpx.Response: + log.debug("Dry-run request", method=request.method, path=request.url.path) + return httpx.Response(200, json={"text": "Hello, world!"}) + + +class Client(httpx.Client): + encoder: msgspec.json.Encoder + + def __init__(self, *, base_url: str | None, dry_run: bool = False, token: str, **kwargs: Any): + if (not base_url) ^ dry_run: + raise ValueError(f"Can only specify one of {base_url=} or {dry_run=}") + auth = BearerAuth(token) + + self.encoder = msgspec.json.Encoder() + if dry_run: + # If dry run is requested, install a no-op handler so that simple tasks can "heartbeat" using a + # real client, but just don't make any HTTP requests + kwargs["transport"] = httpx.MockTransport(noop_handler) + kwargs["base_url"] = "dry-run://server" + else: + kwargs["base_url"] = base_url + super().__init__( + auth=auth, + headers={"airflow-api-version": "2024-07-30"}, + event_hooks={"response": [raise_on_4xx_5xx], "request": [add_correlation_id]}, + **kwargs, + ) + + # We "group" or "namespace" operations by what they operate on, rather than a flat namespace with all + # methods on one object prefixed with the object type (`.task_instances.update` rather than + # `task_instance_update` etc.)
+ + @methodtools.lru_cache() # type: ignore[misc] + @property + def task_instances(self) -> TaskInstanceOperations: + """Operations related to TaskInstances.""" + return TaskInstanceOperations(self) diff --git a/task_sdk/src/airflow/sdk/api/datamodels/activities.py b/task_sdk/src/airflow/sdk/api/datamodels/activities.py new file mode 100644 index 00000000000..17b15d7c017 --- /dev/null +++ b/task_sdk/src/airflow/sdk/api/datamodels/activities.py @@ -0,0 +1,31 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +import os + +import msgspec + +from airflow.sdk.api.datamodels.ti import TaskInstance + + +class ExecuteTaskActivity(msgspec.Struct, tag="ExecuteTask", tag_field="kind"): + ti: TaskInstance + path: os.PathLike[str] + token: str + """The identity token for this workload""" diff --git a/task_sdk/src/airflow/sdk/api/datamodels/dagrun.py b/task_sdk/src/airflow/sdk/api/datamodels/dagrun.py new file mode 100644 index 00000000000..224006774df --- /dev/null +++ b/task_sdk/src/airflow/sdk/api/datamodels/dagrun.py @@ -0,0 +1,25 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +import msgspec + + +class DagRun(msgspec.Struct, omit_defaults=True): + run_id: str + dag_id: str diff --git a/task_sdk/src/airflow/sdk/api/datamodels/ti.py b/task_sdk/src/airflow/sdk/api/datamodels/ti.py new file mode 100644 index 00000000000..6c83b5b7081 --- /dev/null +++ b/task_sdk/src/airflow/sdk/api/datamodels/ti.py @@ -0,0 +1,33 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +from __future__ import annotations + +import uuid + +import msgspec + +from airflow.sdk.api.datamodels.dagrun import DagRun + + +class TaskInstance(msgspec.Struct, omit_defaults=True): + id: uuid.UUID + + task_id: str + run: DagRun + try_number: int + map_index: int | None = None diff --git a/task_sdk/src/airflow/sdk/execution_time/comms.py b/task_sdk/src/airflow/sdk/execution_time/comms.py index 7bc6d12314e..41108638e2c 100644 --- a/task_sdk/src/airflow/sdk/execution_time/comms.py +++ b/task_sdk/src/airflow/sdk/execution_time/comms.py @@ -48,31 +48,7 @@ from typing import Any, Union import msgspec from airflow.sdk.api.datamodels._generated import TaskInstanceState # noqa: TCH001 - - -class ExecuteTaskActivity(msgspec.Struct): - """Information needed to start a task on a worker.""" - - ti: TaskInstance - token: str - path: str | None = None - - -# Temporary: These will next two live in a generated client soon. -class DagRun(msgspec.Struct): - dag_id: str - run_id: str - data_interval_end: str | None = None - data_interval_start: str | None = None - - -class TaskInstance(msgspec.Struct): - id: str - task_id: str - try_number: int - map_index: int | None = None - is_eligible_to_retry: bool = False - run: DagRun | None = None +from airflow.sdk.api.datamodels.ti import TaskInstance # noqa: TCH001 class StartupDetails(msgspec.Struct, omit_defaults=True, tag=True): diff --git a/task_sdk/src/airflow/sdk/execution_time/supervisor.py b/task_sdk/src/airflow/sdk/execution_time/supervisor.py index c5b1bb6b969..64ebcf9f59c 100644 --- a/task_sdk/src/airflow/sdk/execution_time/supervisor.py +++ b/task_sdk/src/airflow/sdk/execution_time/supervisor.py @@ -30,21 +30,26 @@ import time import weakref from collections.abc import Generator from contextlib import suppress -from datetime import datetime +from datetime import datetime, timezone from socket import socket, socketpair -from typing import TYPE_CHECKING, Any, BinaryIO, Callable, ClassVar, Literal, NoReturn, cast, overload 
+from typing import TYPE_CHECKING, BinaryIO, Callable, ClassVar, Literal, NoReturn, cast, overload +from uuid import UUID import attrs +import httpx import msgspec import psutil import structlog +from airflow.sdk.api.client import Client +from airflow.sdk.api.datamodels._generated import TaskInstanceState from airflow.sdk.execution_time.comms import StartupDetails, ToSupervisor if TYPE_CHECKING: from structlog.typing import FilteringBoundLogger - from airflow.sdk.execution_time.comms import ExecuteTaskActivity + from airflow.sdk.api.datamodels.activities import ExecuteTaskActivity + from airflow.sdk.api.datamodels.ti import TaskInstance __all__ = ["WatchedSubprocess"] @@ -98,6 +103,8 @@ def _fork_main( log_fd: int, target: Callable[[], None], ) -> NoReturn: + # TODO: Make this process a session leader + # Uninstall the rich etc. exception handler sys.excepthook = sys.__excepthook__ signal.signal(signal.SIGINT, signal.SIG_DFL) @@ -186,16 +193,21 @@ def _fork_main( @attrs.define() class WatchedSubprocess: + ti_id: UUID pid: int stdin: BinaryIO stdout: socket stderr: socket + client: Client + _process: psutil.Process _exit_code: int | None = None _terminal_state: str | None = None + _last_heartbeat: float = 0 + selector: selectors.BaseSelector = attrs.field(factory=selectors.DefaultSelector) procs: ClassVar[weakref.WeakValueDictionary[int, WatchedSubprocess]] = weakref.WeakValueDictionary() @@ -205,7 +217,11 @@ class WatchedSubprocess: @classmethod def start( - cls, path: str | os.PathLike[str], ti: Any, target: Callable[[], None] = _subprocess_main + cls, + path: str | os.PathLike[str], + ti: TaskInstance, + client: Client, + target: Callable[[], None] = _subprocess_main, ) -> WatchedSubprocess: """Fork and start a new subprocess to execute the given task.""" # Create socketpairs/"pipes" to connect to the stdin and out from the subprocess @@ -225,13 +241,26 @@ class WatchedSubprocess: _fork_main(child_stdin, child_stdout, child_stderr, child_logs.fileno(), target) 
proc = cls( + ti_id=ti.id, pid=pid, stdin=feed_stdin, stdout=read_stdout, stderr=read_stderr, process=psutil.Process(pid), + client=client, ) + + # We've forked, but the task won't start until we send it the StartupDetails message. But before we do + # that, we need to tell the server it's started (so it has the chance to tell us "no, stop!" for any + # reason) + try: + client.task_instances.start(ti.id, pid, datetime.now(tz=timezone.utc)) + proc._last_heartbeat = time.monotonic() + except Exception: + # On any error kill that subprocess! + proc.kill(signal.SIGKILL) + raise + + # TODO: Use logging providers to handle the chunked upload for us task_logger: FilteringBoundLogger = structlog.get_logger(logger_name="task").bind() @@ -270,7 +299,6 @@ class WatchedSubprocess: log.debug("Sending", msg=msg) feed_stdin.write(msgspec.json.encode(msg)) feed_stdin.write(b"\n") - # feed_stdin.flush() return proc @@ -285,24 +313,56 @@ class WatchedSubprocess: if self._exit_code is not None: return self._exit_code + # TODO: Pull this from config + heartbeat_rate = 30 + + # Until we have a selector for the process, don't poll for more than 10s, just in case it exits but + # doesn't produce any output + max_poll_interval = 10 + try: while self._exit_code is None or len(self.selector.get_map()): - events = self.selector.select(timeout=10.0) + # Monitor the task to see if it's done. Wait in a syscall (`select`) for as long as possible + # so we notice the subprocess finishing as quick as we can.
+ max_wait_time = max( + 0, # Make sure this value is never negative, + min( + # Ensure we heartbeat _at most_ 75% through the time the zombie threshold time + heartbeat_rate - (time.monotonic() - self._last_heartbeat) * 0.75, + max_poll_interval, + ), + ) + events = self.selector.select(timeout=max_wait_time) for key, _ in events: callback = key.data - open = callback(key.fileobj) + need_more = callback(key.fileobj) - if not open: - log.debug("Remote end closed, closing", fileobj=key.fileobj) + if not need_more: self.selector.unregister(key.fileobj) key.fileobj.close() # type: ignore[union-attr] - # TODO: Send heartbeat here + + if self._exit_code is None: + try: + self._exit_code = self._process.wait(timeout=0) + log.debug("Task process exited", exit_code=self._exit_code) + except psutil.TimeoutExpired: + pass + try: - self._exit_code = self._process.wait(timeout=0.1) - except psutil.TimeoutExpired: + # TODO: Currently this will heartbeat _every_ time we read any log message. That is way + # too frequent! + self.client.task_instances.heartbeat(self.ti_id) + self._last_heartbeat = time.monotonic() + except Exception: + log.warning("Couldn't heartbeat", exc_info=True) + # TODO: If we couldn't heartbeat for X times the interval, kill ourselves pass finally: self.selector.close() + + self.client.task_instances.finish( + id=self.ti_id, state=self.final_state, when=datetime.now(tz=timezone.utc) + ) return self._exit_code @property @@ -316,10 +376,9 @@ class WatchedSubprocess: Not valid before the process has finished. 
""" - # TODO: state enums if self._exit_code == 0: - return self._terminal_state if self._terminal_state is not None else "success" - return "failed" + return self._terminal_state or TaskInstanceState.SUCCESS + return TaskInstanceState.FAILED def __rich_repr__(self): yield "pid", self.pid @@ -424,10 +483,14 @@ def process_log_messages_from_subprocess(log: FilteringBoundLogger) -> Generator continue if ts := event.get("timestamp"): - # We use msgspec to decode the json as it does it orders of magnitude quicker than - # datetime.strptime does - # TODO: don't hard-code the time format here - event["timestamp"] = msgspec.json.decode(f'"{ts}"', type=datetime) + # We use msgspec to decode the timestamp as it does it orders of magnitude quicker than + # datetime.strptime cn + # + # We remove the timezone info here, as the json encoding has `+00:00`, and since the log came + # from a subprocess we know that the timezone of the log message is the same, so having some + # messages include tz (from subprocess) but others not (ones from supervisor process) is + # confusing. 
+ event["timestamp"] = msgspec.json.decode(f'"{ts}"', type=datetime).replace(tzinfo=None) if exc := event.pop("exception", None): # TODO: convert the dict back to a pretty stack trace @@ -464,11 +527,14 @@ def supervise(activity: ExecuteTaskActivity, server: str | None = None, dry_run: if not activity.path: raise ValueError("path filed of activity missing") + limits = httpx.Limits(max_keepalive_connections=1, max_connections=10) + client = Client(base_url=server or "", limits=limits, dry_run=dry_run, token=activity.token) + start = time.monotonic() - process = WatchedSubprocess.start(activity.path, activity.ti) + process = WatchedSubprocess.start(activity.path, activity.ti, client=client) exit_code = process.wait() end = time.monotonic() - log.debug("Process exited", exit_code=exit_code, duration=end - start) + log.debug("Task finished", exit_code=exit_code, duration=end - start) return exit_code diff --git a/task_sdk/src/airflow/sdk/execution_time/task_runner.py b/task_sdk/src/airflow/sdk/execution_time/task_runner.py index 11601a98c7a..e1fcc18d98e 100644 --- a/task_sdk/src/airflow/sdk/execution_time/task_runner.py +++ b/task_sdk/src/airflow/sdk/execution_time/task_runner.py @@ -30,7 +30,6 @@ import structlog from airflow.sdk import BaseOperator from airflow.sdk.execution_time.comms import StartupDetails, TaskInstance, ToSupervisor, ToTask -from airflow.sdk.log import configure_logging if TYPE_CHECKING: from structlog.typing import FilteringBoundLogger as Logger @@ -175,11 +174,6 @@ def finalize(log: Logger): ... def main(): - # Configure logs to be JSON, so that we can pass it to the parent process - # Don't cache this log though! - - configure_logging(enable_pretty_log=False) - # TODO: add an exception here, it causes an oof of a stack trace! 
global SUPERVISOR_COMMS diff --git a/task_sdk/src/airflow/sdk/log.py b/task_sdk/src/airflow/sdk/log.py index fc609011626..624e3f16768 100644 --- a/task_sdk/src/airflow/sdk/log.py +++ b/task_sdk/src/airflow/sdk/log.py @@ -161,7 +161,12 @@ def logging_processors( indent_guides=False, suppress=[asyncio, httpcore, httpx, contextlib, click, typer], ) - console = structlog.dev.ConsoleRenderer(exception_formatter=rich_exc_formatter) + my_styles = structlog.dev.ConsoleRenderer.get_default_level_styles() + my_styles["debug"] = structlog.dev.CYAN + + console = structlog.dev.ConsoleRenderer( + exception_formatter=rich_exc_formatter, level_styles=my_styles + ) processors.append(console) return processors, { "timestamper": timestamper, diff --git a/task_sdk/src/airflow/sdk/types.py b/task_sdk/src/airflow/sdk/types.py index 232d08e27f9..ffde2170b17 100644 --- a/task_sdk/src/airflow/sdk/types.py +++ b/task_sdk/src/airflow/sdk/types.py @@ -57,7 +57,7 @@ if TYPE_CHECKING: Logger = logging.Logger else: - class Logger: ... # noqa: D101 + class Logger: ... def validate_instance_args(instance: DAGNode, expected_arg_types: dict[str, Any]) -> None: diff --git a/task_sdk/tests/defintions/test_baseoperator.py b/task_sdk/tests/defintions/test_baseoperator.py index 427d1ee0e3e..19035319cdc 100644 --- a/task_sdk/tests/defintions/test_baseoperator.py +++ b/task_sdk/tests/defintions/test_baseoperator.py @@ -29,6 +29,25 @@ from airflow.task.priority_strategy import _DownstreamPriorityWeightStrategy, _U DEFAULT_DATE = datetime(2016, 1, 1, tzinfo=timezone.utc) [email protected](autouse=True, scope="module") +def _disable_ol_plugin(): + # The OpenLineage plugin imports setproctitle, and that now causes (C) level thread calls, which on Py + # 3.12+ issues a warning when os.fork happens. 
So for this plugin we disable it + + And we load plugins when setting the priority_weight field + + old = airflow.plugins_manager.plugins + + assert old is None, "Plugins already loaded, too late to stop them being loaded!" + + airflow.plugins_manager.plugins = [] + + yield + + airflow.plugins_manager.plugins = None + + + # Essentially similar to airflow.models.baseoperator.BaseOperator class FakeOperator(metaclass=BaseOperatorMeta): def __init__(self, test_param, params=None, default_args=None): diff --git a/task_sdk/tests/execution_time/test_supervisor.py b/task_sdk/tests/execution_time/test_supervisor.py index 8770cd9827f..b86f5f1a593 100644 --- a/task_sdk/tests/execution_time/test_supervisor.py +++ b/task_sdk/tests/execution_time/test_supervisor.py @@ -39,7 +39,6 @@ def subprocess_main(): import logging logging.getLogger("airflow.foobar").error("An error message") - ... @pytest.fixture @@ -86,5 +85,10 @@ class TestWatchedSubprocess: "logger": "task", "timestamp": "2024-11-07T12:34:56.078901Z", }, - {"event": "An error message", "level": "error", "logger": "airflow.foobar", "timestamp": instant}, + { + "event": "An error message", + "level": "error", + "logger": "airflow.foobar", + "timestamp": instant.replace(tzinfo=None), + }, ] diff --git a/task_sdk/tests/execution_time/test_task_runner.py b/task_sdk/tests/execution_time/test_task_runner.py index e1fd338b5e5..fec2f339dcb 100644 --- a/task_sdk/tests/execution_time/test_task_runner.py +++ b/task_sdk/tests/execution_time/test_task_runner.py @@ -17,14 +17,12 @@ from __future__ import annotations +import uuid from socket import socketpair -from typing import TYPE_CHECKING +from airflow.sdk.execution_time.comms import StartupDetails from airflow.sdk.execution_time.task_runner import CommsDecoder -if TYPE_CHECKING: - from airflow.sdk.execution_time.comms import StartupDetails - class TestCommsDecoder: """Test the communication between the subprocess and the "supervisor".""" @@ -33,15
+31,20 @@ class TestCommsDecoder: r, w = socketpair() w.makefile("wb").write( - b'{"type":"StartupDetails", "ti": {"id": "a", "task_id": "b", "try_number": 1}, ' + b'{"type":"StartupDetails", "ti": {' + b'"id": "4d828a62-a417-4936-a7a6-2b3fabacecab", "task_id": "a", "try_number": 1, "run":' + b'{"run_id": "b", "dag_id": "c"} }, ' b'"file": "/dev/null", "requests_fd": 4' b"}\n" ) decoder = CommsDecoder(input=r.makefile("r")) - msg: StartupDetails = decoder.get_message() - assert msg.ti.task_id == "b" + msg = decoder.get_message() + assert isinstance(msg, StartupDetails) + assert msg.ti.id == uuid.UUID("4d828a62-a417-4936-a7a6-2b3fabacecab") + assert msg.ti.task_id == "a" + assert msg.ti.run.dag_id == "c" assert msg.file == "/dev/null" # Since this was a StartupDetails message, the decoder should open the other socket
