This is an automated email from the ASF dual-hosted git repository.

ash pushed a commit to branch execution-time-code-in-task-sdk
in repository https://gitbox.apache.org/repos/asf/airflow.git
commit 36a56cce4528a649dcdbdb2e05518c1957f661f0
Author: Ash Berlin-Taylor <[email protected]>
AuthorDate: Thu Nov 7 22:40:04 2024 +0000

    Start building the replacement task runner for Task Execution SDK

    The eventual goal of this "airflow.sdk.execution_time" package is to replace LocalTaskJob and
    StandardTaskRunner, but at this stage it co-exists with the classes it will replace. As this PR is not a
    complete re-implementation of all the features that exist currently (no handling of task-level callbacks
    yet, no AirflowSkipException etc.) the current tests are skeletal at best. Once we get closer to feature
    parity (in future PRs) the tests will grow to match.

    This supervisor and task runner operate slightly differently from the current classes, in the following
    ways.

    **Logs from the subprocess are sent over a different channel from stdout/stderr**

    This makes the task supervisor a little bit more complex, as it now has to read stdout, stderr and a
    logs channel. The advantage of this approach is that it makes the logging setup in the task process
    itself markedly simpler -- all it has to do is write log output to the custom file handle as JSON and it
    will show up "natively" as logs. structlog has been chosen as the logging engine over stdlib's own
    logging as the ability to have structured fields in the logs is nice, and stdlib logging is configured
    to send its records to a structlog processor.

    **Direct database access is replaced with an HTTP API client**

    This is the crux of this feature and of AIP-72 in general -- tasks run via this runner can no longer
    access DB models or a DB session directly. This PR doesn't yet implement the code/shims to make
    `Connection.get_connection_from_secrets` use this client - that will be future work.

    Tasks don't speak directly to the API server for two main reasons:

    1. The supervisor process already needs to maintain an HTTP session in order to report the task as
       started, to heartbeat it, and to mark it as finished; and because of that
    2. we can reduce the number of active HTTP connections to 1 per task (instead of 2 per task).

    The other reason we have this interface is that DAG parsing code will very soon need to be updated to
    not have direct DB access either, and having this "in process" interface already in place means that we
    can support commands like `airflow dags reserialize` without having a running API server.

    The API client itself is not auto-generated: I tried a number of different client generators based on
    the OpenAPI spec and found them all lacking or buggy in different ways, and the HTTP client side itself
    is very simple. The only interesting/difficult bit is the generation of the datamodels from the OpenAPI
    spec, and for that datamodel-code-generator is used to emit msgspec Structs. msgspec was chosen over
    Pydantic as it is much lighter weight (and thus quicker), especially on the client side where we have
    next to no validation requirements for the response data. I admit that I have not benchmarked it
    specifically though.
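(Illustration for the reader, not part of the commit: the "JSON lines" framing plus msgspec tagged unions
that comms.py below is built on can be sketched stand-alone like this. The two Struct names mirror comms.py
but the fields are trimmed down, and `state` is a plain str here rather than the real TaskInstanceState.)

    from typing import Union

    import msgspec

    # Each message is one JSON document terminated by "\n". tag=True makes msgspec
    # embed a "type" field holding the class name, which the decoder uses to pick
    # the concrete message type back out of the union.
    class TaskState(msgspec.Struct, tag=True):
        state: str  # simplified; the real message carries a TaskInstanceState

    class GetConnection(msgspec.Struct, tag=True):
        id: str

    ToSupervisor = Union[TaskState, GetConnection]

    encoder = msgspec.json.Encoder()
    decoder = msgspec.json.Decoder(type=ToSupervisor)

    line = encoder.encode(GetConnection(id="my_db")) + b"\n"  # task -> supervisor
    msg = decoder.decode(line)  # trailing newline is tolerated by the JSON decoder
    assert isinstance(msg, GetConnection)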
---
 airflow/utils/net.py                               |   4 +-
 pyproject.toml                                     |   4 +
 task_sdk/pyproject.toml                            |  48 ++
 .../airflow/sdk/api/__init__.py}                   |  15 -
 task_sdk/src/airflow/sdk/api/client.py             | 215 ++++++++
 .../airflow/sdk/api/datamodels/__init__.py}        |  15 -
 .../src/airflow/sdk/api/datamodels/_generated.py   | 130 +++++
 .../airflow/sdk/api/datamodels/activities.py}      |  16 +-
 .../airflow/sdk/api/datamodels/ti.py}              |  19 +-
 .../airflow/sdk/execution_time/__init__.py}        |  16 +-
 task_sdk/src/airflow/sdk/execution_time/comms.py   | 103 ++++
 .../src/airflow/sdk/execution_time/supervisor.py   | 547 +++++++++++++++++++++
 .../src/airflow/sdk/execution_time/task_runner.py  | 195 ++++++++
 task_sdk/src/airflow/sdk/log.py                    | 377 ++++++++++++++
 task_sdk/src/airflow/sdk/types.py                  |   2 +-
 task_sdk/tests/conftest.py                         |  58 +++
 task_sdk/tests/defintions/test_baseoperator.py     |  19 +
 .../{conftest.py => execution_time/__init__.py}    |  15 -
 task_sdk/tests/execution_time/test_supervisor.py   | 146 ++++++
 task_sdk/tests/execution_time/test_task_runner.py  |  51 ++
 20 files changed, 1915 insertions(+), 80 deletions(-)

diff --git a/airflow/utils/net.py b/airflow/utils/net.py
index 992aee67e8..9fc79b3842 100644
--- a/airflow/utils/net.py
+++ b/airflow/utils/net.py
@@ -20,8 +20,6 @@ from __future__ import annotations
 import socket
 from functools import lru_cache
 
-from airflow.configuration import conf
-
 
 # patched version of socket.getfqdn() - see https://github.com/python/cpython/issues/49254
 @lru_cache(maxsize=None)
@@ -53,4 +51,6 @@ def get_host_ip_address():
 
 def get_hostname():
     """Fetch the hostname using the callable from config or use `airflow.utils.net.getfqdn` as a fallback."""
+    from airflow.configuration import conf
+
     return conf.getimport("core", "hostname_callable", fallback="airflow.utils.net.getfqdn")()
diff --git a/pyproject.toml b/pyproject.toml
index f008bfd810..23c67014f5 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -431,6 +431,10 @@ fixture-parentheses = false
 [tool.pytest.ini_options]
 addopts = [
     "--tb=short",
+    # The stdin/out redirecting tests in TaskSDK don't work with the default
+    # capture=fd, so swap it to sys. There's no easy way I've found to set this
+    # setting early enough with our workspaces folder
+    "--capture=sys",
     "-rasl",
     "--verbosity=2",
     # Disable `flaky` plugin for pytest. This plugin conflicts with `rerunfailures` because provide the same marker.
diff --git a/task_sdk/pyproject.toml b/task_sdk/pyproject.toml index 37ea2d300a..439f91124d 100644 --- a/task_sdk/pyproject.toml +++ b/task_sdk/pyproject.toml @@ -24,7 +24,11 @@ requires-python = ">=3.9, <3.13" dependencies = [ "attrs>=24.2.0", "google-re2>=1.1.20240702", + "httpx>=0.27.0", "methodtools>=0.4.7", + "msgspec>=0.18.6", + "psutil>=6.1.0", + "structlog>=24.4.0", ] [build-system] @@ -44,8 +48,21 @@ namespace-packages = ["src/airflow"] # Ignore Doc rules et al for anything outside of tests "!src/*" = ["D", "TID253", "S101", "TRY002"] +# Only have pytest rules in tests - https://github.com/astral-sh/ruff/issues/14205 +"!tests/*" = ["PT"] + "src/airflow/sdk/__init__.py" = ["TCH004"] +# msgspec needs types for annotations to be defined, even with future +# annotations, so disable the "type check only import" for these files +"src/airflow/sdk/api/datamodels/*.py" = ["TCH001"] + +# Only the public API should _require_ docstrings on classes +"!src/airflow/sdk/definitions/*" = ["D101"] + +# Generated file, be less strict +"src/airflow/sdk/*/_generated.py" = ["D"] + [tool.uv] dev-dependencies = [ "kgb>=7.1.1", @@ -54,6 +71,7 @@ dev-dependencies = [ "pytest>=8.3.3", ] + [tool.coverage.run] branch = true relative_files = true @@ -71,3 +89,33 @@ exclude_also = [ "@(typing(_extensions)?\\.)?overload", "if (typing(_extensions)?\\.)?TYPE_CHECKING:", ] + +[dependency-groups] +codegen = [ + "datamodel-code-generator[http]>=0.26.3", +] + +[tool.black] +# This is needed for datamodel-codegen to treat this as the "project" file + +# To use: +# +# uv run --group codegen --project apache-airflow-task-sdk --directory task_sdk datamodel-codegen +[tool.datamodel-codegen] +capitalise-enum-members=true # `State.RUNNING` not `State.running` +disable-timestamp=true +enable-version-header=true +enum-field-as-literal='one' # When a single enum member, make it output a `Literal["..."]` +input-file-type='openapi' +output-model-type='msgspec.Struct' +output-datetime-class='datetime' +target-python-version='3.9' +use-default=true +use-double-quotes=true +use-schema-description=true # Desc becomes class doc comment +use-standard-collections=true # list[] not List[] +use-subclass-enum=true # enum, not union of Literals +use-union-operator=true # 3.9+annotations, not `Union[]` + +url = 'http://0.0.0.0:9091/execution/openapi.json' +output = 'src/airflow/sdk/api/datamodels/_generated.py' diff --git a/task_sdk/tests/conftest.py b/task_sdk/src/airflow/sdk/api/__init__.py similarity index 69% copy from task_sdk/tests/conftest.py copy to task_sdk/src/airflow/sdk/api/__init__.py index ddc7c61656..13a83393a9 100644 --- a/task_sdk/tests/conftest.py +++ b/task_sdk/src/airflow/sdk/api/__init__.py @@ -14,18 +14,3 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from __future__ import annotations - -import os - -import pytest - -pytest_plugins = "tests_common.pytest_plugin" - -# Task SDK does not need access to the Airflow database -os.environ["_AIRFLOW_SKIP_DB_TESTS"] = "true" - - [email protected](tryfirst=True) -def pytest_configure(config: pytest.Config) -> None: - config.inicfg["airflow_deprecations_ignore"] = [] diff --git a/task_sdk/src/airflow/sdk/api/client.py b/task_sdk/src/airflow/sdk/api/client.py new file mode 100644 index 0000000000..cc26911d38 --- /dev/null +++ b/task_sdk/src/airflow/sdk/api/client.py @@ -0,0 +1,215 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+import uuid
+from typing import TYPE_CHECKING, Any
+
+import httpx
+import methodtools
+import msgspec
+import structlog
+from uuid6 import uuid7
+
+from airflow.sdk.api.datamodels._generated import (
+    ConnectionResponse,
+    State1 as TerminalState,
+    TaskInstanceState,
+    TIEnterRunningPayload,
+    TITerminalStatePayload,
+    ValidationError as RemoteValidationError,
+)
+from airflow.utils.net import get_hostname
+from airflow.utils.platform import getuser
+
+if TYPE_CHECKING:
+    from datetime import datetime
+
+
+log = structlog.get_logger(logger_name=__name__)
+
+__all__ = [
+    "Client",
+    "ConnectionOperations",
+    "ErrorBody",
+    "ServerResponseError",
+    "TaskInstanceOperations",
+]
+
+
+def get_json_error(response: httpx.Response):
+    """Raise a ServerResponseError if we can extract error info from the response."""
+    err = ServerResponseError.from_response(response)
+    if err:
+        log.warning("Server error", detail=err.detail)
+        raise err
+
+
+def raise_on_4xx_5xx(response: httpx.Response):
+    return get_json_error(response) or response.raise_for_status()
+
+
+# Py 3.11+ version
+def raise_on_4xx_5xx_with_note(response: httpx.Response):
+    try:
+        return get_json_error(response) or response.raise_for_status()
+    except httpx.HTTPStatusError as e:
+        if TYPE_CHECKING:
+            assert hasattr(e, "add_note")
+        e.add_note(
+            f"Correlation-id={response.headers.get('correlation-id', None) or response.request.headers.get('correlation-id', 'no-correlation-id')}"
+        )
+        raise
+
+
+if hasattr(BaseException, "add_note"):
+    # Py 3.11+
+    raise_on_4xx_5xx = raise_on_4xx_5xx_with_note
+
+
+def add_correlation_id(request: httpx.Request):
+    request.headers["correlation-id"] = str(uuid7())
+
+
+class TaskInstanceOperations:
+    __slots__ = ("client",)
+
+    def __init__(self, client: Client):
+        self.client = client
+
+    def start(self, id: uuid.UUID, pid: int, when: datetime):
+        """Tell the API server that this TI has started running."""
+        body = TIEnterRunningPayload(pid=pid, hostname=get_hostname(), unixname=getuser(), start_date=when)
+
+        self.client.patch(f"task_instance/{id}/state", content=self.client.encoder.encode(body))
+
+    def finish(self, id: uuid.UUID, state: TaskInstanceState, when: datetime):
+        """Tell the API server that this TI has reached a terminal state."""
+        body = TITerminalStatePayload(end_date=when, state=TerminalState(state))
+
+        self.client.patch(f"task_instance/{id}/state", content=self.client.encoder.encode(body))
+
+    def heartbeat(self, id: uuid.UUID):
+        self.client.put(f"task_instance/{id}/heartbeat")
+
+
+class ConnectionOperations:
+    __slots__ = ("client", "decoder")
+
+    def __init__(self, client: Client):
+        self.client = client
+        self.decoder: msgspec.json.Decoder[ConnectionResponse] = msgspec.json.Decoder(type=ConnectionResponse)
+
+    def get(self, id: str) -> ConnectionResponse:
"""Get a connection from the API server.""" + resp = self.client.get(f"connection/{id}") + return self.decoder.decode(resp.read()) + + +class BearerAuth(httpx.Auth): + def __init__(self, token: str): + self.token: str = token + + def auth_flow(self, request: httpx.Request): + if self.token: + request.headers["Authorization"] = "Bearer " + self.token + yield request + + +def noop_handler(request: httpx.Request) -> httpx.Response: + log.debug("Dry-run request", method=request.method, path=request.url.path) + return httpx.Response(200, json={"text": "Hello, world!"}) + + +class Client(httpx.Client): + encoder: msgspec.json.Encoder + + def __init__(self, *, base_url: str | None, dry_run: bool = False, token: str, **kwargs: Any): + if (not base_url) ^ dry_run: + raise ValueError(f"Can only specify one of {base_url=} or {dry_run=}") + auth = BearerAuth(token) + + self.encoder = msgspec.json.Encoder() + if dry_run: + # If dry run is requests, install a no op handler so that simple tasks can "heartbeat" using a + # real client, but just don't make any HTTP requests + kwargs["transport"] = httpx.MockTransport(noop_handler) + kwargs["base_url"] = "dry-run://server" + else: + kwargs["base_url"] = base_url + super().__init__( + auth=auth, + headers={"airflow-api-version": "2024-07-30"}, + event_hooks={"response": [raise_on_4xx_5xx], "request": [add_correlation_id]}, + **kwargs, + ) + + # We "group" or "namespace" operations by what they operate on, rather than a flat namespace with all + # methods on one object prefixed with the object type (`.task_instances.update` rather than + # `task_instance_update` etc.) + + @methodtools.lru_cache() # type: ignore[misc] + @property + def task_instances(self) -> TaskInstanceOperations: + """Operations related to TaskInstances.""" + return TaskInstanceOperations(self) + + @methodtools.lru_cache() # type: ignore[misc] + @property + def connections(self) -> ConnectionOperations: + """Operations related to TaskInstances.""" + return ConnectionOperations(self) + + +class ErrorBody(msgspec.Struct): + detail: list[RemoteValidationError] | dict[str, Any] + + def __repr__(self): + return repr(self.detail) + + +class ServerResponseError(httpx.HTTPStatusError): + def __init__(self, message: str, *, request: httpx.Request, response: httpx.Response): + super().__init__(message, request=request, response=response) + + detail: ErrorBody + + @classmethod + def from_response(cls, response: httpx.Response) -> ServerResponseError | None: + if response.is_success: + return None + # 4xx or 5xx error? + if 400 < (response.status_code // 100) >= 600: + return None + + if response.headers.get("content-type") != "application/json": + return None + + try: + err = msgspec.json.decode(response.read(), type=ErrorBody) + if isinstance(err.detail, list): + msg = "Remote server returned validation error" + else: + msg = err.detail.get("message", "") or "Un-parseable error" + except Exception: + err = msgspec.json.decode(response.content) + msg = "Server returned error" + + self = cls(msg, request=response.request, response=response) + self.detail = err + return self diff --git a/task_sdk/tests/conftest.py b/task_sdk/src/airflow/sdk/api/datamodels/__init__.py similarity index 69% copy from task_sdk/tests/conftest.py copy to task_sdk/src/airflow/sdk/api/datamodels/__init__.py index ddc7c61656..13a83393a9 100644 --- a/task_sdk/tests/conftest.py +++ b/task_sdk/src/airflow/sdk/api/datamodels/__init__.py @@ -14,18 +14,3 @@ # KIND, either express or implied. 
See the License for the # specific language governing permissions and limitations # under the License. -from __future__ import annotations - -import os - -import pytest - -pytest_plugins = "tests_common.pytest_plugin" - -# Task SDK does not need access to the Airflow database -os.environ["_AIRFLOW_SKIP_DB_TESTS"] = "true" - - [email protected](tryfirst=True) -def pytest_configure(config: pytest.Config) -> None: - config.inicfg["airflow_deprecations_ignore"] = [] diff --git a/task_sdk/src/airflow/sdk/api/datamodels/_generated.py b/task_sdk/src/airflow/sdk/api/datamodels/_generated.py new file mode 100644 index 0000000000..21ce287606 --- /dev/null +++ b/task_sdk/src/airflow/sdk/api/datamodels/_generated.py @@ -0,0 +1,130 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# generated by datamodel-codegen: +# filename: http://0.0.0.0:9091/execution/openapi.json +# version: 0.26.3 + +from __future__ import annotations + +from datetime import datetime +from enum import Enum +from typing import Literal + +from msgspec import Struct, field + + +class ConnectionResponse(Struct): + """ + Connection schema for responses with fields that are needed for Runtime. + """ + + conn_id: str + conn_type: str + host: str | None = None + schema_: str | None = field(name="schema", default=None) + login: str | None = None + password: str | None = None + port: int | None = None + extra: str | None = None + + +class TIEnterRunningPayload(Struct): + """ + Schema for updating TaskInstance to 'RUNNING' state with minimal required fields. + """ + + hostname: str + unixname: str + pid: int + start_date: datetime + state: Literal["running"] | None = "running" + + +class TIHeartbeatInfo(Struct): + """ + Schema for TaskInstance heartbeat endpoint. + """ + + hostname: str + pid: int + + +class State(Enum): + REMOVED = "removed" + SCHEDULED = "scheduled" + QUEUED = "queued" + RUNNING = "running" + RESTARTING = "restarting" + UP_FOR_RETRY = "up_for_retry" + UP_FOR_RESCHEDULE = "up_for_reschedule" + UPSTREAM_FAILED = "upstream_failed" + DEFERRED = "deferred" + + +class TITargetStatePayload(Struct): + """ + Schema for updating TaskInstance to a target state, excluding terminal and running states. + """ + + state: State + + +class State1(Enum): + FAILED = "failed" + SUCCESS = "success" + SKIPPED = "skipped" + + +class TITerminalStatePayload(Struct): + """ + Schema for updating TaskInstance to a terminal state (e.g., SUCCESS or FAILED). + """ + + state: State1 + end_date: datetime + + +class TaskInstanceState(str, Enum): + """ + All possible states that a Task Instance can be in. + + Note that None is also allowed, so always use this in a type hint with Optional. 
+ """ + + REMOVED = "removed" + SCHEDULED = "scheduled" + QUEUED = "queued" + RUNNING = "running" + SUCCESS = "success" + RESTARTING = "restarting" + FAILED = "failed" + UP_FOR_RETRY = "up_for_retry" + UP_FOR_RESCHEDULE = "up_for_reschedule" + UPSTREAM_FAILED = "upstream_failed" + SKIPPED = "skipped" + DEFERRED = "deferred" + + +class ValidationError(Struct): + loc: list[str | int] + msg: str + type: str + + +class HTTPValidationError(Struct): + detail: list[ValidationError] | None = None diff --git a/task_sdk/tests/conftest.py b/task_sdk/src/airflow/sdk/api/datamodels/activities.py similarity index 73% copy from task_sdk/tests/conftest.py copy to task_sdk/src/airflow/sdk/api/datamodels/activities.py index ddc7c61656..17b15d7c01 100644 --- a/task_sdk/tests/conftest.py +++ b/task_sdk/src/airflow/sdk/api/datamodels/activities.py @@ -14,18 +14,18 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. + from __future__ import annotations import os -import pytest - -pytest_plugins = "tests_common.pytest_plugin" +import msgspec -# Task SDK does not need access to the Airflow database -os.environ["_AIRFLOW_SKIP_DB_TESTS"] = "true" +from airflow.sdk.api.datamodels.ti import TaskInstance [email protected](tryfirst=True) -def pytest_configure(config: pytest.Config) -> None: - config.inicfg["airflow_deprecations_ignore"] = [] +class ExecuteTaskActivity(msgspec.Struct, tag="ExecuteTask", tag_field="kind"): + ti: TaskInstance + path: os.PathLike[str] + token: str + """The identity token for this workload""" diff --git a/task_sdk/tests/conftest.py b/task_sdk/src/airflow/sdk/api/datamodels/ti.py similarity index 72% copy from task_sdk/tests/conftest.py copy to task_sdk/src/airflow/sdk/api/datamodels/ti.py index ddc7c61656..72b00c7f34 100644 --- a/task_sdk/tests/conftest.py +++ b/task_sdk/src/airflow/sdk/api/datamodels/ti.py @@ -14,18 +14,19 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from __future__ import annotations -import os +from __future__ import annotations -import pytest +import uuid -pytest_plugins = "tests_common.pytest_plugin" +import msgspec -# Task SDK does not need access to the Airflow database -os.environ["_AIRFLOW_SKIP_DB_TESTS"] = "true" +class TaskInstance(msgspec.Struct, omit_defaults=True): + id: uuid.UUID [email protected](tryfirst=True) -def pytest_configure(config: pytest.Config) -> None: - config.inicfg["airflow_deprecations_ignore"] = [] + task_id: str + dag_id: str + run_id: str + try_number: int + map_index: int | None = None diff --git a/task_sdk/tests/conftest.py b/task_sdk/src/airflow/sdk/execution_time/__init__.py similarity index 69% copy from task_sdk/tests/conftest.py copy to task_sdk/src/airflow/sdk/execution_time/__init__.py index ddc7c61656..217e5db960 100644 --- a/task_sdk/tests/conftest.py +++ b/task_sdk/src/airflow/sdk/execution_time/__init__.py @@ -1,3 +1,4 @@ +# # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information @@ -14,18 +15,3 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
-from __future__ import annotations
-
-import os
-
-import pytest
-
-pytest_plugins = "tests_common.pytest_plugin"
-
-# Task SDK does not need access to the Airflow database
-os.environ["_AIRFLOW_SKIP_DB_TESTS"] = "true"
-
-
[email protected](tryfirst=True)
-def pytest_configure(config: pytest.Config) -> None:
-    config.inicfg["airflow_deprecations_ignore"] = []
diff --git a/task_sdk/src/airflow/sdk/execution_time/comms.py b/task_sdk/src/airflow/sdk/execution_time/comms.py
new file mode 100644
index 0000000000..26dbd8e87b
--- /dev/null
+++ b/task_sdk/src/airflow/sdk/execution_time/comms.py
@@ -0,0 +1,103 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+r"""
+Communication protocol between the Supervisor and the task process
+==================================================================
+
+* All communication is done over stdout/stdin in the form of "JSON lines" (each
+  message is a single JSON document terminated by a `\n` character)
+* Messages from the subprocess are all log messages and are sent directly to the log
+* No messages are sent to the task process except in response to a request. (This is because the task
+  process will be running the user's code, so we can't read from stdin until we enter our code, such as
+  when requesting an XCom value etc.)
+
+The reasons this communication protocol exists, rather than the task process speaking directly to the Task
+Execution API server, are:
+
+1. To reduce the number of concurrent HTTP connections on the API server.
+
+   The supervisor already has to speak to that to heartbeat the running Task, so having the task speak to
+   its parent process and having all API traffic go through that means that the number of HTTP connections
+   is "halved". (Not every task will make API calls, so it's not always halved, but it is reduced.)
+
+2. This means that the user Task code doesn't ever directly see the task identity JWT token.
+
+   This is a short-lived token tied to one specific task instance try, so it being leaked/exfiltrated is
+   not a large risk, but it's easy to not give it to the user code, so let's do that.
+"""  # noqa: D400, D205
+
+from __future__ import annotations
+
+from typing import Any, Union
+
+import msgspec
+
+from airflow.sdk.api.datamodels._generated import TaskInstanceState  # noqa: TCH001
+from airflow.sdk.api.datamodels.ti import TaskInstance  # noqa: TCH001
+
+
+class StartupDetails(msgspec.Struct, omit_defaults=True, tag=True):
+    ti: TaskInstance
+    file: str
+    requests_fd: int
+    """
+    The channel for the task to send requests over.
+ + Responses will come back on stdin + """ + + +class XComResponse(msgspec.Struct, omit_defaults=True, tag=True): + """Response to ReadXCom request.""" + + key: str + value: Any + + +class ConnectionResponse(msgspec.Struct, omit_defaults=True, tag="connection"): + conn: Any + + +ToTask = Union[StartupDetails, XComResponse, ConnectionResponse] + + +class TaskState(msgspec.Struct, omit_defaults=True, tag=True): + """ + Update a task's state. + + If a process exits without sending one of these the state will be derived from the exit code: + - 0 = SUCCESS + - anything else = FAILED + """ + + state: TaskInstanceState + + +class ReadXCom(msgspec.Struct, tag=True): + key: str + + +class GetConnection(msgspec.Struct, tag=True): + id: str + + +class GetVariable(msgspec.Struct, tag=True): + id: str + + +ToSupervisor = Union[TaskState, ReadXCom, GetConnection, GetVariable] diff --git a/task_sdk/src/airflow/sdk/execution_time/supervisor.py b/task_sdk/src/airflow/sdk/execution_time/supervisor.py new file mode 100644 index 0000000000..f3fad259db --- /dev/null +++ b/task_sdk/src/airflow/sdk/execution_time/supervisor.py @@ -0,0 +1,547 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Supervise and run Tasks in a subprocess.""" + +from __future__ import annotations + +import atexit +import io +import logging +import os +import selectors +import signal +import sys +import time +import weakref +from collections.abc import Generator +from contextlib import suppress +from datetime import datetime, timezone +from socket import socket, socketpair +from typing import TYPE_CHECKING, BinaryIO, Callable, ClassVar, Literal, NoReturn, cast, overload +from uuid import UUID + +import attrs +import httpx +import msgspec +import psutil +import structlog + +from airflow.sdk.api.client import Client +from airflow.sdk.api.datamodels._generated import TaskInstanceState +from airflow.sdk.execution_time.comms import ConnectionResponse, GetConnection, StartupDetails, ToSupervisor + +if TYPE_CHECKING: + from structlog.typing import FilteringBoundLogger + + from airflow.sdk.api.datamodels.activities import ExecuteTaskActivity + from airflow.sdk.api.datamodels.ti import TaskInstance + + +__all__ = ["WatchedSubprocess", "supervise"] + +log: FilteringBoundLogger = structlog.get_logger(logger_name="supervisor") + +# TODO: Pull this from config +SLOWEST_HEARTBEAT_INTERVAL: int = 30 +# Don't heartbeat more often than this +FASTEST_HEARTBEAT_INTERVAL: int = 5 + + +@overload +def mkpipe() -> tuple[socket, socket]: ... + + +@overload +def mkpipe(remote_read: Literal[True]) -> tuple[socket, BinaryIO]: ... + + +def mkpipe( + remote_read: bool = False, +) -> tuple[socket, socket | BinaryIO]: + """ + Create a pair of connected sockets. 
+
+    The inheritable flag will be set correctly so that the end destined for the subprocess is kept open but
+    the end for this process is closed automatically by the OS.
+    """
+    rsock, wsock = socketpair()
+    local, remote = (wsock, rsock) if remote_read else (rsock, wsock)
+
+    remote.set_inheritable(True)
+    local.setblocking(False)
+
+    io: BinaryIO | socket
+    if remote_read:
+        # If _we_ are writing, we don't want to buffer
+        io = cast(BinaryIO, local.makefile("wb", buffering=0))
+    else:
+        io = local
+
+    return remote, io
+
+
+def _subprocess_main():
+    from airflow.sdk.execution_time.task_runner import main
+
+    main()
+
+
+def _fork_main(
+    child_stdin: socket,
+    child_stdout: socket,
+    child_stderr: socket,
+    log_fd: int,
+    target: Callable[[], None],
+) -> NoReturn:
+    # TODO: Make this process a session leader
+
+    # Uninstall the rich etc. exception handler
+    sys.excepthook = sys.__excepthook__
+    signal.signal(signal.SIGINT, signal.SIG_DFL)
+    signal.signal(signal.SIGUSR2, signal.SIG_DFL)
+
+    if log_fd > 0:
+        # A channel that the task can send JSON-formatted logs over.
+        #
+        # JSON logs sent this way will be handled nicely
+        from airflow.sdk.log import configure_logging
+
+        log_io = os.fdopen(log_fd, "wb", buffering=0)
+        configure_logging(enable_pretty_log=False, output=log_io)
+
+    last_chance_stderr = sys.__stderr__ or sys.stderr
+
+    # Ensure that sys.stdout et al (and the underlying filehandles for C libraries etc) are connected to the
+    # pipes from the supervisor
+
+    for handle_name, sock, mode, close in (
+        ("stdin", child_stdin, "r", True),
+        ("stdout", child_stdout, "w", True),
+        ("stderr", child_stderr, "w", False),
+    ):
+        handle = getattr(sys, handle_name)
+        try:
+            fd = handle.fileno()
+            os.dup2(sock.fileno(), fd)
+            if close:
+                handle.close()
+        except io.UnsupportedOperation:
+            if "PYTEST_CURRENT_TEST" in os.environ:
+                # When we're running under pytest, the stdin is not a real filehandle with an fd, so we need
+                # to handle that differently
+                fd = sock.fileno()
+            else:
+                raise
+
+        setattr(sys, handle_name, os.fdopen(fd, mode))
+
+    def exit(n: int) -> NoReturn:
+        with suppress(ValueError, OSError):
+            sys.stdout.flush()
+        with suppress(ValueError, OSError):
+            sys.stderr.flush()
+        with suppress(ValueError, OSError):
+            last_chance_stderr.flush()
+        os._exit(n)
+
+    if hasattr(atexit, "_clear"):
+        # Since we're in a fork we want to try and clear any registered atexit handlers
+        atexit._clear()
+        base_exit = exit
+
+        def exit(n: int) -> NoReturn:
+            atexit._run_exitfuncs()
+            base_exit(n)
+
+    try:
+        target()
+        exit(0)
+    except SystemExit as e:
+        code = 1
+        if isinstance(e.code, int):
+            code = e.code
+        elif e.code:
+            print(e.code, file=sys.stderr)
+        exit(code)
+    except Exception:
+        # Last ditch log attempt
+        exc, v, tb = sys.exc_info()
+
+        import traceback
+
+        try:
+            last_chance_stderr.write("--- Last chance exception handler ---\n")
+            traceback.print_exception(exc, value=v, tb=tb, file=last_chance_stderr)
+            exit(99)
+        except Exception as e:
+            with suppress(Exception):
+                print(
+                    f"--- Last chance exception handler failed --- {repr(str(e))}\n", file=last_chance_stderr
+                )
+            exit(98)
+
+
[email protected]()
+class WatchedSubprocess:
+    ti_id: UUID
+    pid: int
+
+    stdin: BinaryIO
+    stdout: socket
+    stderr: socket
+
+    client: Client
+
+    _process: psutil.Process
+    _exit_code: int | None = None
+    _terminal_state: str | None = None
+
+    _last_heartbeat: float = 0
+
+    selector: selectors.BaseSelector = attrs.field(factory=selectors.DefaultSelector)
+
+    procs: ClassVar[weakref.WeakValueDictionary[int, WatchedSubprocess]] = weakref.WeakValueDictionary()
+
+    def __attrs_post_init__(self):
+        self.procs[self.pid] = self
+
+    @classmethod
+    def start(
+        cls,
+        path: str | os.PathLike[str],
+        ti: TaskInstance,
+        client: Client,
+        target: Callable[[], None] = _subprocess_main,
+    ) -> WatchedSubprocess:
+        """Fork and start a new subprocess to execute the given task."""
+        # Create socketpairs/"pipes" to connect to the stdin and stdout of the subprocess
+        child_stdin, feed_stdin = mkpipe(remote_read=True)
+        child_stdout, read_stdout = mkpipe()
+        child_stderr, read_stderr = mkpipe()
+
+        # Open these socketpairs before forking off the child, so that they are open when we fork.
+        child_comms, read_msgs = mkpipe()
+        child_logs, read_logs = mkpipe()
+
+        pid = os.fork()
+        if pid == 0:
+            # Parent ends of the sockets are closed by the OS as they are set as non-inheritable
+
+            # Run the child entrypoint
+            _fork_main(child_stdin, child_stdout, child_stderr, child_logs.fileno(), target)
+
+        proc = cls(
+            ti_id=ti.id,
+            pid=pid,
+            stdin=feed_stdin,
+            stdout=read_stdout,
+            stderr=read_stderr,
+            process=psutil.Process(pid),
+            client=client,
+        )
+
+        # We've forked, but the task won't start until we send it the StartupDetails message. But before we
+        # do that, we need to tell the server it's started (so it has the chance to tell us "no, stop!" for
+        # any reason)
+        try:
+            client.task_instances.start(ti.id, pid, datetime.now(tz=timezone.utc))
+            proc._last_heartbeat = time.monotonic()
+        except Exception:
+            # On any error kill that subprocess!
+            proc.kill(signal.SIGKILL)
+            raise
+
+        # TODO: Use logging providers to handle the chunked upload for us
+        task_logger: FilteringBoundLogger = structlog.get_logger(logger_name="task").bind()
+
+        cb = make_buffered_socket_reader(forward_to_log(task_logger.bind(chan="stdout"), level=logging.INFO))
+
+        proc.selector.register(read_stdout, selectors.EVENT_READ, cb)
+        cb = make_buffered_socket_reader(forward_to_log(task_logger.bind(chan="stderr"), level=logging.ERROR))
+        proc.selector.register(read_stderr, selectors.EVENT_READ, cb)
+
+        proc.selector.register(
+            read_logs,
+            selectors.EVENT_READ,
+            make_buffered_socket_reader(process_log_messages_from_subprocess(task_logger)),
+        )
+        proc.selector.register(
+            read_msgs,
+            selectors.EVENT_READ,
+            make_buffered_socket_reader(proc.handle_requests(log=log)),
+        )
+
+        # Tell the task process what it needs to do!
+        msg = StartupDetails(
+            ti=ti,
+            file=str(path),
+            requests_fd=child_comms.fileno(),
+        )
+
+        # Close the remaining parent-end of the sockets we've passed to the child via fork. We still have
+        # the other end of the pair open
+        child_stdout.close()
+        child_stdin.close()
+        child_comms.close()
+        child_logs.close()
+
+        # Send the message to tell the process what it needs to execute
+        log.debug("Sending", msg=msg)
+        feed_stdin.write(msgspec.json.encode(msg))
+        feed_stdin.write(b"\n")
+
+        return proc
+
+    def kill(self, signal: signal.Signals = signal.SIGINT):
+        if self._exit_code is not None:
+            return
+
+        with suppress(ProcessLookupError):
+            os.kill(self.pid, signal)
+
+    def wait(self) -> int:
+        if self._exit_code is not None:
+            return self._exit_code
+
+        # Until we have a selector for the process, don't poll for more than 10s, just in case it exits but
+        # doesn't produce any output
+        max_poll_interval = 10
+
+        try:
+            while self._exit_code is None or len(self.selector.get_map()):
+                last_heartbeat_ago = time.monotonic() - self._last_heartbeat
+                # Monitor the task to see if it's done. Wait in a syscall (`select`) for as long as
+                # possible so we notice the subprocess finishing as quickly as we can.
+                max_wait_time = max(
+                    0,  # Make sure this value is never negative
+                    min(
+                        # Ensure we heartbeat at most 75% of the way through the zombie threshold time
+                        SLOWEST_HEARTBEAT_INTERVAL * 0.75 - last_heartbeat_ago,
+                        max_poll_interval,
+                    ),
+                )
+                events = self.selector.select(timeout=max_wait_time)
+                for key, _ in events:
+                    socket_handler = key.data
+                    need_more = socket_handler(key.fileobj)
+
+                    if not need_more:
+                        self.selector.unregister(key.fileobj)
+                        key.fileobj.close()  # type: ignore[union-attr]
+
+                if self._exit_code is None:
+                    try:
+                        self._exit_code = self._process.wait(timeout=0)
+                        log.debug("Task process exited", exit_code=self._exit_code)
+                    except psutil.TimeoutExpired:
+                        pass
+
+                if last_heartbeat_ago < FASTEST_HEARTBEAT_INTERVAL:
+                    # Avoid heartbeating too frequently
+                    continue
+
+                try:
+                    self.client.task_instances.heartbeat(self.ti_id)
+                    self._last_heartbeat = time.monotonic()
+                except Exception:
+                    log.warning("Couldn't heartbeat", exc_info=True)
+                    # TODO: If we couldn't heartbeat for X times the interval, kill ourselves
+                    pass
+        finally:
+            self.selector.close()
+
+        self.client.task_instances.finish(
+            id=self.ti_id, state=self.final_state, when=datetime.now(tz=timezone.utc)
+        )
+        return self._exit_code
+
+    @property
+    def final_state(self):
+        """
+        The final state of the TaskInstance.
+
+        By default this will be derived from the exit code of the task
+        (0 = success, failed otherwise) but it can be changed by the subprocess
+        sending a TaskState message, as long as the process exits with 0.
+
+        Not valid before the process has finished.
+        """
+        if self._exit_code == 0:
+            return self._terminal_state or TaskInstanceState.SUCCESS
+        return TaskInstanceState.FAILED
+
+    def __rich_repr__(self):
+        yield "pid", self.pid
+        yield "exit_code", self._exit_code, None
+
+    __rich_repr__.angular = True  # type: ignore[attr-defined]
+
+    def __repr__(self) -> str:
+        rep = f"<WatchedSubprocess pid={self.pid}"
+        if self._exit_code is not None:
+            rep += f" exit_code={self._exit_code}"
+        return rep + " >"
+
+    def handle_requests(self, log: FilteringBoundLogger) -> Generator[None, bytes, None]:
+        decoder: msgspec.json.Decoder[ToSupervisor] = msgspec.json.Decoder(type=ToSupervisor)
+        encoder = msgspec.json.Encoder()
+        # Use a buffer to avoid small allocations
+        buffer = bytearray(64)
+        while True:
+            line = yield
+
+            try:
+                msg = decoder.decode(line)
+            except Exception:
+                log.exception("Unable to decode message", line=line)
+                continue
+
+            # if isinstance(msg, TaskState):
+            #     self._terminal_state = msg.state
+            # elif isinstance(msg, ReadXCom):
+            #     resp = XComResponse(key="secret", value=True)
+            #     encoder.encode_into(resp, buffer)
+            #     self.stdin.write(buffer + b"\n")
+            if isinstance(msg, GetConnection):
+                conn = self.client.connections.get(msg.id)
+                resp = ConnectionResponse(conn=conn)
+                encoder.encode_into(resp, buffer)
+            else:
+                log.error("Unhandled request", msg=msg)
+                continue
+
+            buffer.extend(b"\n")
+            self.stdin.write(buffer)
+
+            # Ensure the buffer doesn't grow and stay large if a large payload is used. This won't grow it
+            # larger than it is, but it will shrink it
+            if len(buffer) > 1024:
+                buffer = buffer[:1024]
+
+
+# Sockets, even with the `.makefile()` function, don't correctly do line buffering on reading. If a chunk
+# is read and it doesn't contain a newline character, `.readline()` will just return the chunk as is.
+#
+# This returns a cb suitable for attaching to a `selector` that reads into a buffer, and yields lines to a
+# (sync) generator
+def make_buffered_socket_reader(
+    gen: Generator[None, bytes, None], buffer_size: int = 4096
+) -> Callable[[socket], bool]:
+    buffer = bytearray()  # This will hold our accumulated binary data
+    read_buffer = bytearray(buffer_size)  # Temporary buffer for each read
+
+    # We need to start up the generator to get it to the point it's at waiting on the yield
+    next(gen)
+
+    def cb(sock: socket):
+        nonlocal buffer, read_buffer
+        # Read up to `buffer_size` bytes of data from the socket
+        n_received = sock.recv_into(read_buffer)
+
+        if not n_received:
+            # If no data is returned, the connection is closed. Return whatever is left in the buffer
+            if len(buffer):
+                gen.send(buffer)
+            # Tell loop to close this selector
+            return False
+
+        buffer.extend(read_buffer[:n_received])
+
+        # We could have read multiple lines in one go, yield them all
+        while (newline_pos := buffer.find(b"\n")) != -1:
+            if TYPE_CHECKING:
+                # We send in a memoryview, but pretend it's bytes, as Buffer is only in 3.12+
+                line = buffer[: newline_pos + 1]
+            else:
+                line = memoryview(buffer)[: newline_pos + 1]  # Include the newline character
+            gen.send(line)
+            buffer = buffer[newline_pos + 1 :]  # Update the buffer with remaining data
+
+        return True
+
+    return cb
+
+
+def process_log_messages_from_subprocess(log: FilteringBoundLogger) -> Generator[None, bytes, None]:
+    from structlog.stdlib import NAME_TO_LEVEL
+
+    while True:
+        # Generator receive syntax, values are "sent" in by the `make_buffered_socket_reader` and returned
+        # to the yield.
+        line = yield
+
+        try:
+            event = msgspec.json.decode(line)
+        except Exception:
+            log.exception("Malformed json log line", line=line)
+            continue
+
+        if ts := event.get("timestamp"):
+            # We use msgspec to decode the timestamp as it is orders of magnitude quicker than
+            # datetime.strptime
+            #
+            # We remove the timezone info here, as the json encoding has `+00:00`, and since the log came
+            # from a subprocess we know that the timezone of the log message is the same, so having some
+            # messages include tz (from subprocess) but others not (ones from supervisor process) is
+            # confusing.
+            event["timestamp"] = msgspec.json.decode(f'"{ts}"', type=datetime).replace(tzinfo=None)
+
+        if exc := event.pop("exception", None):
+            # TODO: convert the dict back to a pretty stack trace
+            event["error_detail"] = exc
+        log.log(NAME_TO_LEVEL[event.pop("level")], event.pop("event", None), **event)
+
+
+def forward_to_log(target_log: FilteringBoundLogger, level: int) -> Generator[None, bytes, None]:
+    while True:
+        buf = yield
+        line = bytes(buf)
+        # Strip off new line
+        line = line.rstrip()
+        try:
+            msg = line.decode("utf-8", errors="replace")
+            target_log.log(level, msg)
+        except UnicodeDecodeError:
+            msg = line.decode("ascii", errors="replace")
+            target_log.log(level, msg)
+
+
+def supervise(activity: ExecuteTaskActivity, server: str | None = None, dry_run: bool = False) -> int:
+    """
+    Run a single task execution to completion.
+
+    Returns the exit code of the process
+    """
+    # One or the other
+    if (not server) ^ dry_run:
+        raise ValueError(f"Can only specify one of {server=} or {dry_run=}")
+
+    if not activity.path:
+        raise ValueError("path field of activity missing")
+
+    limits = httpx.Limits(max_keepalive_connections=1, max_connections=10)
+    client = Client(base_url=server or "", limits=limits, dry_run=dry_run, token=activity.token)
+
+    start = time.monotonic()
+
+    process = WatchedSubprocess.start(activity.path, activity.ti, client=client)
+
+    exit_code = process.wait()
+    end = time.monotonic()
+    log.debug("Task finished", exit_code=exit_code, duration=end - start)
+    return exit_code
diff --git a/task_sdk/src/airflow/sdk/execution_time/task_runner.py b/task_sdk/src/airflow/sdk/execution_time/task_runner.py
new file mode 100644
index 0000000000..4026788733
--- /dev/null
+++ b/task_sdk/src/airflow/sdk/execution_time/task_runner.py
@@ -0,0 +1,195 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""The entrypoint for the actual task execution process."""
+
+from __future__ import annotations
+
+import os
+import sys
+from io import FileIO
+from typing import TYPE_CHECKING, TextIO
+
+import attrs
+import msgspec
+import structlog
+
+from airflow.sdk import BaseOperator
+from airflow.sdk.execution_time.comms import StartupDetails, TaskInstance, ToSupervisor, ToTask
+
+if TYPE_CHECKING:
+    from structlog.typing import FilteringBoundLogger as Logger
+
+
+class RuntimeTaskInstance(TaskInstance, kw_only=True):
+    task: BaseOperator
+
+
+def parse(what: StartupDetails) -> RuntimeTaskInstance:
+    # TODO: Task-SDK:
+    # Using DagBag here is about 98% wrong, but it'll do for now
+
+    from airflow.models.dagbag import DagBag
+
+    bag = DagBag(
+        dag_folder=what.file,
+        include_examples=False,
+        safe_mode=False,
+        load_op_links=False,
+    )
+    if TYPE_CHECKING:
+        assert what.ti.dag_id
+
+    dag = bag.dags[what.ti.dag_id]
+
+    # install_loader()
+
+    # TODO: Handle task not found
+    task = dag.task_dict[what.ti.task_id]
+    if not isinstance(task, BaseOperator):
+        raise TypeError(f"task is of the wrong type, got {type(task)}, wanted {BaseOperator}")
+    return RuntimeTaskInstance(**msgspec.structs.asdict(what.ti), task=task)
+
+
[email protected]()
+class CommsDecoder:
+    """Handle communication between the task in this process and the supervisor parent process."""
+
+    input: TextIO = sys.stdin
+
+    decoder: msgspec.json.Decoder[ToTask] = attrs.field(factory=lambda: msgspec.json.Decoder(type=ToTask))
+    encoder: msgspec.json.Encoder = attrs.field(factory=msgspec.json.Encoder)
+
+    request_socket: FileIO = attrs.field(init=False, default=None)
+
+    # Allocate a buffer for en/decoding. We use a small non-empty buffer to avoid reallocating for small messages.
+ buffer: bytearray = attrs.field(factory=lambda: bytearray(64)) + + def get_message(self) -> ToTask: + """ + Get a message from the parent. + + This will block until the message has been received. + """ + line = self.input.readline() + try: + msg = self.decoder.decode(line) + except Exception: + structlog.get_logger(logger_name="CommsDecoder").exception("Unable to decode message", line=line) + raise + + if isinstance(msg, StartupDetails): + # If we read a startup message, pull out the FDs we care about! + if msg.requests_fd > 0: + self.request_socket = os.fdopen(msg.requests_fd, "wb", buffering=0) + return msg + + def send_request(self, log: Logger, msg: ToSupervisor): + buffer = self.buffer + self.encoder.encode_into(msg, buffer) + buffer += b"\n" + + log.debug("Sending request", json=buffer) + self.request_socket.write(buffer) + + +# This global variable will be used by Connection/Variable classes etc to send requests to +SUPERVISOR_COMMS: CommsDecoder + +# State machine! +# 1. Start up (receive details from supervisor) +# 2. Execution (run task code, possibly send requests) +# 3. Shutdown and report status + + +def startup() -> tuple[RuntimeTaskInstance, Logger]: + msg = SUPERVISOR_COMMS.get_message() + + if isinstance(msg, StartupDetails): + log = structlog.get_logger(logger_name="task") + # TODO: set the "magic loop" context vars for parsing + ti = parse(msg) + log.debug("DAG file parsed", file=msg.file) + return ti, log + else: + raise RuntimeError(f"Unhandled startup message {type(msg)} {msg}") + + # TODO: Render fields here + + +def run(ti: RuntimeTaskInstance, log: Logger): + """Run the task in this process.""" + from airflow.exceptions import ( + AirflowException, + AirflowFailException, + AirflowRescheduleException, + AirflowSensorTimeout, + AirflowSkipException, + AirflowTaskTerminated, + AirflowTaskTimeout, + TaskDeferred, + ) + + if TYPE_CHECKING: + assert ti.task is not None + assert isinstance(ti.task, BaseOperator) + try: + # TODO: pre execute etc. + # TODO next_method to support resuming from deferred + # TODO: Get a real context object + ti.task.execute({"task_instance": ti}) # type: ignore[attr-defined] + except TaskDeferred: + ... + except AirflowSkipException: + ... + except AirflowRescheduleException: + ... + except (AirflowFailException, AirflowSensorTimeout): + # If AirflowFailException is raised, task should not retry. + ... + except (AirflowTaskTimeout, AirflowException, AirflowTaskTerminated): + ... + except SystemExit: + ... + except BaseException: + ... + + +def finalize(log: Logger): ... + + +def main(): + # TODO: add an exception here, it causes an oof of a stack trace! + + global SUPERVISOR_COMMS + SUPERVISOR_COMMS = CommsDecoder() + try: + ti, log = startup() + run(ti, log) + finalize(log) + except KeyboardInterrupt: + log = structlog.get_logger(logger_name="task") + log.exception("Ctrl-c hit") + exit(2) + except Exception: + log = structlog.get_logger(logger_name="task") + log.exception("Top level error") + exit(1) + + +if __name__ == "__main__": + main() diff --git a/task_sdk/src/airflow/sdk/log.py b/task_sdk/src/airflow/sdk/log.py new file mode 100644 index 0000000000..624e3f1676 --- /dev/null +++ b/task_sdk/src/airflow/sdk/log.py @@ -0,0 +1,377 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import itertools +import logging.config +import os +import sys +import warnings +from functools import cache +from typing import TYPE_CHECKING, Any, BinaryIO, Callable, Generic, TextIO, TypeVar + +import msgspec +import structlog + +if TYPE_CHECKING: + from structlog.typing import EventDict, ExcInfo, Processor + + +__all__ = [ + "configure_logging", + "reset_logging", +] + + +def exception_group_tracebacks(format_exception: Callable[[ExcInfo], list[dict[str, Any]]]) -> Processor: + # Make mypy happy + if not hasattr(__builtins__, "BaseExceptionGroup"): + T = TypeVar("T") + + class BaseExceptionGroup(Generic[T]): + exceptions: list[T] + + def _exception_group_tracebacks(logger: Any, method_name: Any, event_dict: EventDict) -> EventDict: + if exc_info := event_dict.get("exc_info", None): + group: BaseExceptionGroup[Exception] | None = None + if exc_info is True: + # `log.exception('mesg")` case + exc_info = sys.exc_info() + if exc_info[0] is None: + exc_info = None + + if ( + isinstance(exc_info, tuple) + and len(exc_info) == 3 + and isinstance(exc_info[1], BaseExceptionGroup) + ): + group = exc_info[1] + elif isinstance(exc_info, BaseExceptionGroup): + group = exc_info + + if group: + # Only remove it from event_dict if we handle it + del event_dict["exc_info"] + event_dict["exception"] = list( + itertools.chain.from_iterable( + format_exception((type(exc), exc, exc.__traceback__)) # type: ignore[attr-defined,arg-type] + for exc in (*group.exceptions, group) + ) + ) + + return event_dict + + return _exception_group_tracebacks + + +def logger_name(logger: Any, method_name: Any, event_dict: EventDict) -> EventDict: + if logger_name := event_dict.pop("logger_name", None): + event_dict.setdefault("logger", logger_name) + return event_dict + + +def redact_jwt(logger: Any, method_name: str, event_dict: EventDict) -> EventDict: + for k, v in event_dict.items(): + if isinstance(v, str) and v.startswith("eyJ"): + event_dict[k] = "eyJ***" + return event_dict + + +def drop_positional_args(logger: Any, method_name: Any, event_dict: EventDict) -> EventDict: + event_dict.pop("positional_args", None) + return event_dict + + +def json_processor(logger: Any, method_name: Any, event_dict: EventDict) -> str: + """Encode event into JSON format.""" + return msgspec.json.encode(event_dict).decode("ascii") + + +class StdBinaryStreamHandler(logging.StreamHandler): + """A logging.StreamHandler that sends logs as binary JSON over the given stream.""" + + stream: BinaryIO + + def __init__(self, stream: BinaryIO): + super().__init__(stream) + + def emit(self, record: logging.LogRecord): + try: + msg = self.format(record) + buffer = bytearray(msg, "ascii", "backslashreplace") + + buffer += b"\n" + + stream = self.stream + stream.write(buffer) + self.flush() + except RecursionError: # See issue 36272 + raise + except Exception: + self.handleError(record) + + +@cache +def logging_processors( + enable_pretty_log: 
bool, +): + if enable_pretty_log: + timestamper = structlog.processors.MaybeTimeStamper(fmt="%Y-%m-%d %H:%M:%S.%f") + else: + timestamper = structlog.processors.MaybeTimeStamper(fmt="iso") + + processors: list[structlog.typing.Processor] = [ + timestamper, + structlog.contextvars.merge_contextvars, + structlog.processors.add_log_level, + structlog.stdlib.PositionalArgumentsFormatter(), + logger_name, + redact_jwt, + structlog.processors.StackInfoRenderer(), + ] + + if enable_pretty_log: + # Imports to suppress showing code from these modules + import asyncio + import contextlib + + import click + import httpcore + import httpx + import typer + + rich_exc_formatter = structlog.dev.RichTracebackFormatter( + extra_lines=0, + max_frames=30, + indent_guides=False, + suppress=[asyncio, httpcore, httpx, contextlib, click, typer], + ) + my_styles = structlog.dev.ConsoleRenderer.get_default_level_styles() + my_styles["debug"] = structlog.dev.CYAN + + console = structlog.dev.ConsoleRenderer( + exception_formatter=rich_exc_formatter, level_styles=my_styles + ) + processors.append(console) + return processors, { + "timestamper": timestamper, + "console": console, + } + else: + # Imports to suppress showing code from these modules + import asyncio + import contextlib + + import click + import httpcore + import httpx + import typer + + dict_exc_formatter = structlog.tracebacks.ExceptionDictTransformer( + use_rich=False, show_locals=False, suppress=(click, typer) + ) + + dict_tracebacks = structlog.processors.ExceptionRenderer( + structlog.tracebacks.ExceptionDictTransformer( + use_rich=False, show_locals=False, suppress=(click, typer) + ) + ) + if hasattr(__builtins__, "BaseExceptionGroup"): + exc_group_processor = exception_group_tracebacks(dict_exc_formatter) + processors.append(exc_group_processor) + else: + exc_group_processor = None + + encoder = msgspec.json.Encoder() + + def json_dumps(msg, default): + return encoder.encode(msg) + + def json_processor(logger: Any, method_name: Any, event_dict: EventDict) -> str: + # import web_pdb + + # web_pdb.set_trace() + return encoder.encode(event_dict).decode("ascii") + + json = structlog.processors.JSONRenderer(serializer=json_dumps) + + processors.extend( + ( + dict_tracebacks, + structlog.processors.UnicodeDecoder(), + json, + ), + ) + + return processors, { + "timestamper": timestamper, + "exc_group_processor": exc_group_processor, + "dict_tracebacks": dict_tracebacks, + "json": json_processor, + } + + +@cache +def configure_logging( + enable_pretty_log: bool = True, + log_level: str = "DEBUG", + output: BinaryIO | None = None, + cache_logger_on_first_use: bool = True, +): + """Set up struct logging and stdlib logging config.""" + if enable_pretty_log and output is not None: + raise ValueError("output can only be set if enable_pretty_log is not") + + lvl = structlog.stdlib.NAME_TO_LEVEL[log_level.lower()] + + if enable_pretty_log: + formatter = "colored" + else: + formatter = "plain" + processors, named = logging_processors(enable_pretty_log) + timestamper = named["timestamper"] + + pre_chain: list[structlog.typing.Processor] = [ + # Add the log level and a timestamp to the event_dict if the log entry + # is not from structlog. 
+        structlog.stdlib.add_log_level,
+        structlog.stdlib.add_logger_name,
+        timestamper,
+    ]
+
+    # Don't cache the loggers during tests, as it makes it hard to capture them
+    if "PYTEST_CURRENT_TEST" in os.environ:
+        cache_logger_on_first_use = False
+
+    color_formatter: list[structlog.typing.Processor] = [
+        structlog.stdlib.ProcessorFormatter.remove_processors_meta,
+        drop_positional_args,
+    ]
+    std_lib_formatter: list[structlog.typing.Processor] = [
+        structlog.stdlib.ProcessorFormatter.remove_processors_meta,
+        drop_positional_args,
+    ]
+
+    wrapper_class = structlog.make_filtering_bound_logger(lvl)
+    if enable_pretty_log:
+        structlog.configure(
+            processors=processors,
+            cache_logger_on_first_use=cache_logger_on_first_use,
+            wrapper_class=wrapper_class,
+        )
+        color_formatter.append(named["console"])
+    else:
+        structlog.configure(
+            processors=processors,
+            cache_logger_on_first_use=cache_logger_on_first_use,
+            wrapper_class=wrapper_class,
+            logger_factory=structlog.BytesLoggerFactory(output),
+        )
+
+        if processor := named["exc_group_processor"]:
+            pre_chain.append(processor)
+        pre_chain.append(named["dict_tracebacks"])
+        color_formatter.append(named["json"])
+        std_lib_formatter.append(named["json"])
+
+    global _warnings_showwarning
+    _warnings_showwarning = warnings.showwarning
+    # Capture warnings and show them via structlog
+    warnings.showwarning = _showwarning
+
+    logging.config.dictConfig(
+        {
+            "version": 1,
+            "disable_existing_loggers": False,
+            "formatters": {
+                "plain": {
+                    "()": structlog.stdlib.ProcessorFormatter,
+                    "processors": std_lib_formatter,
+                    "foreign_pre_chain": pre_chain,
+                    "pass_foreign_args": True,
+                },
+                "colored": {
+                    "()": structlog.stdlib.ProcessorFormatter,
+                    "processors": color_formatter,
+                    "foreign_pre_chain": pre_chain,
+                    "pass_foreign_args": True,
+                },
+            },
+            "handlers": {
+                "default": {
+                    "level": log_level.upper(),
+                    "class": "logging.StreamHandler",
+                    "formatter": formatter,
+                },
+                "to_supervisor": {
+                    "level": log_level.upper(),
+                    "()": StdBinaryStreamHandler,
+                    "formatter": formatter,
+                    "stream": output,
+                },
+            },
+            "loggers": {
+                "": {
+                    "handlers": ["to_supervisor" if output else "default"],
+                    "level": log_level.upper(),
+                    "propagate": True,
+                },
+                # Some modules we _never_ want at debug level
+                "asyncio": {"level": "INFO"},
+                "alembic": {"level": "INFO"},
+                "httpcore": {"level": "INFO"},
+                "httpx": {"level": "WARN"},
+                "psycopg.pq": {"level": "INFO"},
+                "sqlalchemy.engine": {"level": "WARN"},
+            },
+        }
+    )
+
+
+def reset_logging():
+    global _warnings_showwarning
+    warnings.showwarning = _warnings_showwarning
+    configure_logging.cache_clear()
+
+
+_warnings_showwarning = None
+
+
+def _showwarning(
+    message: str | Warning,
+    category: type[Warning],
+    filename: str,
+    lineno: int,
+    file: TextIO | None = None,
+    line: str | None = None,
+):
+    """
+    Redirect warnings to structlog so they appear in task logs etc.
+
+    Implementation of showwarning that redirects to logging. If the file
+    parameter is not None, it delegates to the original warnings
+    implementation of showwarning. Otherwise, it logs the warning message and
+    its metadata to a structlog logger named "py.warnings" at warning level.
+ """ + if file is not None: + if _warnings_showwarning is not None: + _warnings_showwarning(message, category, filename, lineno, file, line) + else: + log = structlog.get_logger(logger_name="py.warnings") + log.warning(str(message), category=category.__name__, filename=filename, lineno=lineno) diff --git a/task_sdk/src/airflow/sdk/types.py b/task_sdk/src/airflow/sdk/types.py index 232d08e27f..ffde2170b1 100644 --- a/task_sdk/src/airflow/sdk/types.py +++ b/task_sdk/src/airflow/sdk/types.py @@ -57,7 +57,7 @@ if TYPE_CHECKING: Logger = logging.Logger else: - class Logger: ... # noqa: D101 + class Logger: ... def validate_instance_args(instance: DAGNode, expected_arg_types: dict[str, Any]) -> None: diff --git a/task_sdk/tests/conftest.py b/task_sdk/tests/conftest.py index ddc7c61656..d8a58219fd 100644 --- a/task_sdk/tests/conftest.py +++ b/task_sdk/tests/conftest.py @@ -17,6 +17,7 @@ from __future__ import annotations import os +from typing import TYPE_CHECKING, NoReturn import pytest @@ -25,7 +26,64 @@ pytest_plugins = "tests_common.pytest_plugin" # Task SDK does not need access to the Airflow database os.environ["_AIRFLOW_SKIP_DB_TESTS"] = "true" +if TYPE_CHECKING: + from structlog.typing import EventDict, WrappedLogger + + [email protected]() +def pytest_addhooks(pluginmanager: pytest.PytestPluginManager): + # Python 3.12 starts warning about mixing os.fork + Threads, and the pytest-rerunfailures plugin uses + # threads internally. Since this is new code, and it should be flake free, we disable the re-run failures + # plugin early (so that it doesn't run it's pytest_configure which is where the thread starts up if xdist + # is discovered). + pluginmanager.set_blocked("rerunfailures") + @pytest.hookimpl(tryfirst=True) def pytest_configure(config: pytest.Config) -> None: config.inicfg["airflow_deprecations_ignore"] = [] + + +class LogCapture: + # Like structlog.typing.LogCapture, but that doesn't add log_level in to the event dict + entries: list[EventDict] + + def __init__(self) -> None: + self.entries = [] + + def __call__(self, _: WrappedLogger, method_name: str, event_dict: EventDict) -> NoReturn: + from structlog.exceptions import DropEvent + + if "level" not in event_dict: + event_dict["_log_level"] = method_name + + self.entries.append(event_dict) + + raise DropEvent + + [email protected] +def captured_logs(): + import structlog + + from airflow.sdk.log import configure_logging, reset_logging + + # Use our real log config + reset_logging() + configure_logging(enable_pretty_log=False) + + # But we need to replace remove the last processor (the one that turns JSON into text, as we want the + # event dict for tests) + cur_processors = structlog.get_config()["processors"] + processors = cur_processors.copy() + proc = processors.pop() + assert isinstance( + proc, structlog.dev.ConsoleRenderer | structlog.processors.JSONRenderer + ), "Pre-condition" + try: + cap = LogCapture() + processors.append(cap) + structlog.configure(processors=processors) + yield cap.entries + finally: + structlog.configure(processors=cur_processors) diff --git a/task_sdk/tests/defintions/test_baseoperator.py b/task_sdk/tests/defintions/test_baseoperator.py index 427d1ee0e3..19035319cd 100644 --- a/task_sdk/tests/defintions/test_baseoperator.py +++ b/task_sdk/tests/defintions/test_baseoperator.py @@ -29,6 +29,25 @@ from airflow.task.priority_strategy import _DownstreamPriorityWeightStrategy, _U DEFAULT_DATE = datetime(2016, 1, 1, tzinfo=timezone.utc) [email protected](autouse=True, scope="module") +def 
+def _disable_ol_plugin():
+    # The OpenLineage plugin imports setproctitle, and that now causes (C) level thread calls, which on
+    # Py 3.12+ issue a warning when os.fork happens. So we disable this plugin for these tests.
+
+    # And we load plugins when setting the priority_weight field
+    import airflow.plugins_manager
+
+    old = airflow.plugins_manager.plugins
+
+    assert old is None, "Plugins already loaded, too late to stop them being loaded!"
+
+    airflow.plugins_manager.plugins = []
+
+    yield
+
+    airflow.plugins_manager.plugins = None
+
 
 # Essentially similar to airflow.models.baseoperator.BaseOperator
 class FakeOperator(metaclass=BaseOperatorMeta):
     def __init__(self, test_param, params=None, default_args=None):
diff --git a/task_sdk/tests/conftest.py b/task_sdk/tests/execution_time/__init__.py
similarity index 69%
copy from task_sdk/tests/conftest.py
copy to task_sdk/tests/execution_time/__init__.py
index ddc7c61656..13a83393a9 100644
--- a/task_sdk/tests/conftest.py
+++ b/task_sdk/tests/execution_time/__init__.py
@@ -14,18 +14,3 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
-from __future__ import annotations
-
-import os
-
-import pytest
-
-pytest_plugins = "tests_common.pytest_plugin"
-
-# Task SDK does not need access to the Airflow database
-os.environ["_AIRFLOW_SKIP_DB_TESTS"] = "true"
-
-
[email protected](tryfirst=True)
-def pytest_configure(config: pytest.Config) -> None:
-    config.inicfg["airflow_deprecations_ignore"] = []
diff --git a/task_sdk/tests/execution_time/test_supervisor.py b/task_sdk/tests/execution_time/test_supervisor.py
new file mode 100644
index 0000000000..1e372b0561
--- /dev/null
+++ b/task_sdk/tests/execution_time/test_supervisor.py
@@ -0,0 +1,146 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+import logging
+import os
+import signal
+import sys
+from unittest.mock import MagicMock
+
+import pytest
+import structlog
+import structlog.testing
+
+import airflow.sdk.api.client
+from airflow.sdk.api.datamodels.ti import TaskInstance
+from airflow.sdk.execution_time.supervisor import WatchedSubprocess
+from airflow.utils import timezone as tz
+
+
[email protected]
+def disable_capturing(capsys):
+    with capsys.disabled():
+        yield
+
+
+class TestWatchedSubprocess:
+    @pytest.mark.usefixtures("disable_capturing")
+    def test_reading_from_pipes(self, captured_logs, time_machine):
+        # Ignore anything lower than INFO for this test. The captured_logs fixture resets this for us
+        # afterwards
+        structlog.configure(wrapper_class=structlog.make_filtering_bound_logger(logging.INFO))
+
+        def subprocess_main():
+            # This is run in the subprocess!
+
+            # Flush calls are to ensure ordering of output for predictable tests
+            import logging
+            import warnings
+
+            print("I'm a short message")
+            sys.stdout.write("Message ")
+            sys.stdout.write("split across two writes\n")
+            sys.stdout.flush()
+
+            print("stderr message", file=sys.stderr)
+            sys.stderr.flush()
+
+            logging.getLogger("airflow.foobar").error("An error message")
+
+            warnings.warn("Warning should be captured too", stacklevel=1)
+
+        instant = tz.datetime(2024, 11, 7, 12, 34, 56, 78901)
+        time_machine.move_to(instant, tick=False)
+
+        proc = WatchedSubprocess.start(
+            path=os.devnull,
+            ti=TaskInstance(
+                id="a",
+                task_id="b",
+                dag_id="c",
+                run_id="d",
+                try_number=1,
+            ),
+            client=MagicMock(spec=airflow.sdk.api.client.Client),
+            target=subprocess_main,
+        )
+
+        rc = proc.wait()
+
+        assert rc == 0
+        assert captured_logs == [
+            {
+                "chan": "stdout",
+                "event": "I'm a short message",
+                "level": "info",
+                "logger": "task",
+                "timestamp": "2024-11-07T12:34:56.078901Z",
+            },
+            {
+                "chan": "stdout",
+                "event": "Message split across two writes",
+                "level": "info",
+                "logger": "task",
+                "timestamp": "2024-11-07T12:34:56.078901Z",
+            },
+            {
+                "chan": "stderr",
+                "event": "stderr message",
+                "level": "error",
+                "logger": "task",
+                "timestamp": "2024-11-07T12:34:56.078901Z",
+            },
+            {
+                "event": "An error message",
+                "level": "error",
+                "logger": "airflow.foobar",
+                "timestamp": instant.replace(tzinfo=None),
+            },
+            {
+                "category": "UserWarning",
+                "event": "Warning should be captured too",
+                "filename": __file__,
+                "level": "warning",
+                "lineno": 65,
+                "logger": "py.warnings",
+                "timestamp": instant.replace(tzinfo=None),
+            },
+        ]
+
+    @pytest.mark.usefixtures("disable_capturing")
+    def test_subprocess_sigkilled(self):
+        def subprocess_main():
+            # This is run in the subprocess!
+            os.kill(0, signal.SIGKILL)
+
+        proc = WatchedSubprocess.start(
+            path=os.devnull,
+            ti=TaskInstance(
+                id="a",
+                task_id="b",
+                dag_id="c",
+                run_id="d",
+                try_number=1,
+            ),
+            client=MagicMock(spec=airflow.sdk.api.client.Client),
+            target=subprocess_main,
+        )
+
+        rc = proc.wait()
+
+        assert rc == -9
diff --git a/task_sdk/tests/execution_time/test_task_runner.py b/task_sdk/tests/execution_time/test_task_runner.py
new file mode 100644
index 0000000000..3ef4d55656
--- /dev/null
+++ b/task_sdk/tests/execution_time/test_task_runner.py
@@ -0,0 +1,51 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+import uuid
+from socket import socketpair
+
+from airflow.sdk.execution_time.comms import StartupDetails
+from airflow.sdk.execution_time.task_runner import CommsDecoder
+
+
+class TestCommsDecoder:
+    """Test the communication between the subprocess and the "supervisor"."""
+
+    def test_recv_StartupDetails(self):
+        r, w = socketpair()
+
+        w.makefile("wb").write(
+            b'{"type":"StartupDetails", "ti": {'
+            b'"id": "4d828a62-a417-4936-a7a6-2b3fabacecab", "task_id": "a", "try_number": 1, "run_id": "b", "dag_id": "c" }, '
+            b'"file": "/dev/null", "requests_fd": 4'
+            b"}\n"
+        )
+
+        decoder = CommsDecoder(input=r.makefile("r"))
+
+        msg = decoder.get_message()
+        assert isinstance(msg, StartupDetails)
+        assert msg.ti.id == uuid.UUID("4d828a62-a417-4936-a7a6-2b3fabacecab")
+        assert msg.ti.task_id == "a"
+        assert msg.ti.dag_id == "c"
+        assert msg.file == "/dev/null"
+
+        # Since this was a StartupDetails message, the decoder should open the other socket
+        assert decoder.request_socket.writable()
+        assert decoder.request_socket.fileno() == 4
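
To make the two logging modes concrete, here is a minimal sketch of driving the new
airflow.sdk.log configuration by hand. It assumes the task_sdk package (plus structlog
and msgspec) is importable; the file name is arbitrary and stands in for the binary log
channel the supervisor hands to the task process:

    from __future__ import annotations

    import structlog

    from airflow.sdk.log import configure_logging, reset_logging

    # JSON mode: with output set, every structlog *and* stdlib log record is
    # rendered as one JSON line and written to the binary stream instead of
    # stdout (this is what the task process does with its log channel).
    with open("task.log.jsonl", "wb") as log_channel:
        configure_logging(enable_pretty_log=False, output=log_channel)

        log = structlog.get_logger(logger_name="task")
        # The redact_jwt processor rewrites any string value starting with
        # "eyJ" to "eyJ***" before serialization.
        log.info("ti started", try_number=1, token="eyJhbGciOiJIUzI1NiJ9.e30.x")

        # Restore the original warnings hook and clear the @cache so a later
        # call can reconfigure from scratch.
        reset_logging()

Calling configure_logging(enable_pretty_log=True) with no output instead installs the
ConsoleRenderer pipeline, which is what you want on an interactive terminal.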
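
The processors themselves are plain callables with the structlog processor signature
(logger, method_name, event_dict) -> event_dict, so they can be checked in isolation. A
quick self-contained check of two of the processors added in this commit:

    from airflow.sdk.log import logger_name, redact_jwt

    event = {"event": "connected", "logger_name": "task", "token": "eyJhbGciOiJIUzI1NiJ9.e30.x"}
    event = logger_name(None, "info", event)  # moves logger_name -> logger
    event = redact_jwt(None, "info", event)   # str values starting "eyJ" -> "eyJ***"
    assert event == {"event": "connected", "logger": "task", "token": "eyJ***"}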
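
As test_recv_StartupDetails shows from the reading side, the supervisor/task wire format
is newline-delimited JSON with a "type" field as the discriminator. A rough sketch of
the writing side follows; the plain dict and the json module here stand in for the real
msgspec models in comms.py, and this is not the actual supervisor code:

    import json
    from socket import socketpair

    supervisor_sock, task_sock = socketpair()

    frame = {
        "type": "StartupDetails",
        "ti": {"id": "4d828a62-a417-4936-a7a6-2b3fabacecab", "task_id": "a", "try_number": 1, "run_id": "b", "dag_id": "c"},
        "file": "/dev/null",
        "requests_fd": 4,
    }
    # One frame per line; closing the file object flushes the buffered write.
    with supervisor_sock.makefile("wb") as f:
        f.write(json.dumps(frame).encode() + b"\n")

    # The task process reads one line at a time and dispatches on "type",
    # which is what CommsDecoder.get_message() does (via msgspec) above.
    line = task_sock.makefile("r").readline()
    assert json.loads(line)["type"] == "StartupDetails"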
