This is an automated email from the ASF dual-hosted git repository. ash pushed a commit to branch localexecutor-uses-task-sdk in repository https://gitbox.apache.org/repos/asf/airflow.git
commit c22e1ce1e3433bf9b8a458636361b66440f0e38b Author: Ash Berlin-Taylor <[email protected]> AuthorDate: Wed Nov 20 17:49:45 2024 +0000 WIP: LocalExecutor runs task using new Task SDK supervisor --- airflow/config_templates/config.yml | 2 +- airflow/executors/base_executor.py | 61 ++++++++++- airflow/executors/local_executor.py | 82 +++++++++------ airflow/jobs/scheduler_job_runner.py | 8 +- airflow/utils/helpers.py | 2 +- task_sdk/pyproject.toml | 1 + .../src/airflow/sdk/execution_time/supervisor.py | 112 +++++++++++++++++---- .../src/airflow/sdk/execution_time/task_runner.py | 12 ++- task_sdk/src/airflow/sdk/log.py | 18 ++-- 9 files changed, 234 insertions(+), 64 deletions(-) diff --git a/airflow/config_templates/config.yml b/airflow/config_templates/config.yml index 4a7fe5f1897..c7a4ce159cc 100644 --- a/airflow/config_templates/config.yml +++ b/airflow/config_templates/config.yml @@ -908,7 +908,7 @@ logging: is_template: true default: "dag_id={{ ti.dag_id }}/run_id={{ ti.run_id }}/task_id={{ ti.task_id }}/\ {%% if ti.map_index >= 0 %%}map_index={{ ti.map_index }}/{%% endif %%}\ - attempt={{ try_number }}.log" + attempt={{ ti.try_number }}.log" log_processor_filename_template: description: | Formatting for how airflow generates file names for log diff --git a/airflow/executors/base_executor.py b/airflow/executors/base_executor.py index d24fd4f1516..bab5642fa0b 100644 --- a/airflow/executors/base_executor.py +++ b/airflow/executors/base_executor.py @@ -19,13 +19,17 @@ from __future__ import annotations import logging +import os import sys from collections import defaultdict, deque from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Any, List, Optional, Sequence, Tuple +from enum import Enum, auto +from functools import cache +from typing import TYPE_CHECKING, Any, Callable, List, Literal, Optional, Sequence, Tuple, overload import pendulum from deprecated import deprecated +from pydantic import BaseModel from airflow.cli.cli_config import 
DefaultHelpParser from airflow.configuration import conf @@ -70,6 +74,9 @@ if TYPE_CHECKING: # Task tuple to send to be executed TaskTuple = Tuple[TaskInstanceKey, CommandType, Optional[str], Optional[Any]] +# This feels wrong importing from the fastapi app here!!!! +from airflow.api_fastapi.execution_api.datamodels.taskinstance import TaskInstance as SdkTI + log = logging.getLogger(__name__) @@ -106,6 +113,50 @@ class RunningRetryAttemptType: return True +class ExecuteTaskActivity(BaseModel): # noqa: D101 + ti: SdkTI + dag_path: os.PathLike[str] + token: str + """The identity token for this workload""" + + log_filename_suffix: str + + @classmethod + def from_ti(cls, ti: TaskInstance) -> ExecuteTaskActivity: + from pathlib import Path + + serde_ti = SdkTI.model_validate(ti, from_attributes=True) + path = Path(ti.dag_run.dag_model.relative_fileloc) + + if path and not path.is_absolute(): + path = "DAGS_FOLDER" / path + + # Quasi-breaking change for format strings. Lets just deprecate those + fname = cls.log_filename_template()(ti=ti, try_number=ti.try_number) + return ExecuteTaskActivity(ti=serde_ti, dag_path=path, token="", log_filename_suffix=fname) + + # Parse the string once, and cache the rendered template + @classmethod + @cache + def log_filename_template(cls) -> Callable[..., str]: + # TODO: We should probably use the version from ti.dag_run.log_template_id + + # For now we always return the active log template + template = conf.get("logging", "log_filename_template") + + if "{{" in template: + import jinja2 + + return jinja2.Template(template).render + else: + return template.format + + +class ExecuteActivityKind(Enum): # noqa: D101 + TASK_INSTANCE = auto() + DAG_CALLBACK = auto() + + class BaseExecutor(LoggingMixin): """ Base class to inherit for concrete executors such as Celery, Kubernetes, Local, Sequential, etc. 
@@ -169,6 +220,14 @@ class BaseExecutor(LoggingMixin): else: self.log.error("could not queue task %s", task_instance.key) + @overload + def queue_activity(self, *, kind: Literal[ExecuteActivityKind.TASK_INSTANCE], context: TaskInstance): ... + @overload + def queue_activity(self, *, kind: Literal[ExecuteActivityKind.DAG_CALLBACK], context: dict[str, Any]): ... + + def queue_activity(self, *, kind: ExecuteActivityKind, context: Any): + raise ValueError(f"Un-handled activity kind {kind!r} in {type(self).__name__}") + def queue_task_instance( self, task_instance: TaskInstance, diff --git a/airflow/executors/local_executor.py b/airflow/executors/local_executor.py index 3b8b52176db..c4cd260d336 100644 --- a/airflow/executors/local_executor.py +++ b/airflow/executors/local_executor.py @@ -32,14 +32,18 @@ import multiprocessing.sharedctypes import os import subprocess from multiprocessing import Queue, SimpleQueue -from typing import TYPE_CHECKING, Any, Optional, Tuple +from typing import TYPE_CHECKING, Any, Literal, Optional, Tuple, overload from setproctitle import setproctitle from airflow import settings -from airflow.executors.base_executor import PARALLELISM, BaseExecutor +from airflow.executors.base_executor import ( + PARALLELISM, + BaseExecutor, + ExecuteActivityKind, + ExecuteTaskActivity, +) from airflow.traces.tracer import add_span -from airflow.utils.dag_parsing_context import _airflow_parsing_context_manager from airflow.utils.state import TaskInstanceState if TYPE_CHECKING: @@ -55,26 +59,29 @@ if TYPE_CHECKING: def _run_worker( logger_name: str, - input: SimpleQueue[ExecutorWorkType], + input: SimpleQueue[ExecuteTaskActivity | None], output: Queue[TaskInstanceStateType], unread_messages: multiprocessing.sharedctypes.Synchronized[int], ): + import os import signal + os.environ["PUDB_TTY"] = "/dev/ttys006" + # Ignore ctrl-c in this process -- we don't want to kill _this_ one. 
we let tasks run to completion signal.signal(signal.SIGINT, signal.SIG_IGN) log = logging.getLogger(logger_name) + log.info("Worker starting up pid=%d", os.getpid()) # We know we've just started a new process, so lets disconnect from the metadata db now settings.engine.pool.dispose() settings.engine.dispose() - setproctitle("airflow worker -- LocalExecutor: <idle>") - while True: + setproctitle("airflow worker -- LocalExecutor: <idle>") try: - item = input.get() + activity = input.get() except EOFError: log.info( "Failed to read tasks from the task queue because the other " @@ -83,7 +90,7 @@ def _run_worker( ) break - if item is None: + if activity is None: # Received poison pill, no more tasks to run return @@ -91,33 +98,33 @@ def _run_worker( with unread_messages: unread_messages.value -= 1 - (key, command) = item try: - state = _execute_work(log, key, command) + state = _execute_work(log, activity) - output.put((key, state, None)) + # output.put((key, state, None)) + output.put((None, TaskInstanceState.SUCCESS, None)) except Exception as e: - output.put((key, TaskInstanceState.FAILED, e)) + log.exception("uhoh") + output.put((None, TaskInstanceState.FAILED, e)) -def _execute_work(log: logging.Logger, key: TaskInstanceKey, command: CommandType) -> TaskInstanceState: +def _execute_work(log: logging.Logger, activity: ExecuteTaskActivity) -> None: """ Execute command received and stores result state in queue. 
:param key: the key to identify the task instance :param command: the command to execute """ - setproctitle(f"airflow worker -- LocalExecutor: {command}") - dag_id, task_id = BaseExecutor.validate_airflow_tasks_run_command(command) - try: - with _airflow_parsing_context_manager(dag_id=dag_id, task_id=task_id): - if settings.EXECUTE_TASKS_NEW_PYTHON_INTERPRETER: - return _execute_work_in_subprocess(log, command) - else: - return _execute_work_in_fork(log, command) - finally: - # Remove the command since the worker is done executing the task - setproctitle("airflow worker -- LocalExecutor: <idle>") + from airflow.sdk.execution_time.supervisor import supervise + + setproctitle(f"airflow worker -- LocalExecutor: {activity.ti.id}") + supervise( + ti=activity.ti, + dag_path=activity.dag_path, + token=activity.token, + server="http://localhost:9091/execution/", + log_filename_suffix=activity.log_filename_suffix, + ) def _execute_work_in_subprocess(log: logging.Logger, command: CommandType) -> TaskInstanceState: @@ -180,7 +187,7 @@ class LocalExecutor(BaseExecutor): serve_logs: bool = True - activity_queue: SimpleQueue[ExecutorWorkType] + activity_queue: SimpleQueue[ExecuteTaskActivity | None] result_queue: SimpleQueue[TaskInstanceStateType] workers: dict[int, multiprocessing.Process] _unread_messages: multiprocessing.sharedctypes.Synchronized[int] @@ -216,9 +223,9 @@ class LocalExecutor(BaseExecutor): self.activity_queue.put((key, command)) with self._unread_messages: self._unread_messages.value += 1 - self._check_workers(can_start=True) + self._check_workers() - def _check_workers(self, can_start: bool = True): + def _check_workers(self): # Reap any dead workers to_remove = set() for pid, proc in self.workers.items(): @@ -276,7 +283,7 @@ class LocalExecutor(BaseExecutor): exc.add_note("(This stacktrace is incorrect -- the exception came from a subprocess)") raise exc - self.change_state(key, state) + # self.change_state(key, state) def end(self) -> None: """End the 
executor.""" @@ -306,3 +313,22 @@ class LocalExecutor(BaseExecutor): def terminate(self): """Terminate the executor is not doing anything.""" + + +class TooLocalTooAPIExecutor(LocalExecutor): + """TooLocalTooAPIExecutor implementation to run airflow commands.""" + + @overload + def queue_activity(self, *, kind: Literal[ExecuteActivityKind.TASK_INSTANCE], context: TaskInstance): ... + @overload + def queue_activity(self, *, kind: Literal[ExecuteActivityKind.DAG_CALLBACK], context: dict[str, Any]): ... + + def queue_activity(self, *, kind: ExecuteActivityKind, context: TaskInstance | dict[str, Any]): + if kind != ExecuteActivityKind.TASK_INSTANCE: + return super().queue_activity(kind, context) + + activity = ExecuteTaskActivity.from_ti(context) + self.activity_queue.put(activity) + with self._unread_messages: + self._unread_messages.value += 1 + self._check_workers() diff --git a/airflow/jobs/scheduler_job_runner.py b/airflow/jobs/scheduler_job_runner.py index 96a09eb99bb..63d02060105 100644 --- a/airflow/jobs/scheduler_job_runner.py +++ b/airflow/jobs/scheduler_job_runner.py @@ -43,6 +43,7 @@ from airflow.callbacks.callback_requests import DagCallbackRequest, TaskCallback from airflow.callbacks.pipe_callback_sink import PipeCallbackSink from airflow.configuration import conf from airflow.exceptions import RemovedInAirflow3Warning, UnknownExecutorException +from airflow.executors.base_executor import BaseExecutor, ExecuteActivityKind from airflow.executors.executor_loader import ExecutorLoader from airflow.jobs.base_job_runner import BaseJobRunner from airflow.jobs.job import Job, perform_heartbeat @@ -90,7 +91,6 @@ if TYPE_CHECKING: from sqlalchemy.orm import Query, Session from airflow.dag_processing.manager import DagFileProcessorAgent - from airflow.executors.base_executor import BaseExecutor from airflow.executors.executor_utils import ExecutorName from airflow.models.taskinstance import TaskInstanceKey from airflow.utils.sqlalchemy import ( @@ -645,6 +645,12 @@ 
class SchedulerJobRunner(BaseJobRunner, LoggingMixin): if ti.dag_run.state in State.finished_dr_states: ti.set_state(None, session=session) continue + + # Has a real queue_activity implemented + if executor.queue_activity is not BaseExecutor.queue_activity: + executor.queue_activity(kind=ExecuteActivityKind.TASK_INSTANCE, context=ti) + continue + command = ti.command_as_list( local=True, ) diff --git a/airflow/utils/helpers.py b/airflow/utils/helpers.py index 1a6b7396e32..73598123d4d 100644 --- a/airflow/utils/helpers.py +++ b/airflow/utils/helpers.py @@ -170,7 +170,7 @@ def as_flattened_list(iterable: Iterable[Iterable[T]]) -> list[T]: return [e for i in iterable for e in i] -def parse_template_string(template_string: str) -> tuple[str | None, jinja2.Template | None]: +def parse_template_string(template_string: str) -> tuple[str, None] | tuple[None, jinja2.Template]: """Parse Jinja template string.""" import jinja2 diff --git a/task_sdk/pyproject.toml b/task_sdk/pyproject.toml index 5da673a79bf..e338ecb9e0c 100644 --- a/task_sdk/pyproject.toml +++ b/task_sdk/pyproject.toml @@ -25,6 +25,7 @@ dependencies = [ "attrs>=24.2.0", "google-re2>=1.1.20240702", "httpx>=0.27.0", + "jinja2>=3.1.4", "methodtools>=0.4.7", "msgspec>=0.18.6", "psutil>=6.1.0", diff --git a/task_sdk/src/airflow/sdk/execution_time/supervisor.py b/task_sdk/src/airflow/sdk/execution_time/supervisor.py index 1d72f7be633..85eb86f43ff 100644 --- a/task_sdk/src/airflow/sdk/execution_time/supervisor.py +++ b/task_sdk/src/airflow/sdk/execution_time/supervisor.py @@ -31,8 +31,9 @@ import weakref from collections.abc import Generator from contextlib import suppress from datetime import datetime, timezone +from pathlib import Path from socket import socket, socketpair -from typing import TYPE_CHECKING, BinaryIO, Callable, ClassVar, Literal, NoReturn, cast, overload +from typing import TYPE_CHECKING, Any, BinaryIO, Callable, ClassVar, Literal, NoReturn, cast, overload from uuid import UUID import attrs @@ -52,9 
+53,8 @@ from airflow.sdk.execution_time.comms import ( ) if TYPE_CHECKING: - from structlog.typing import FilteringBoundLogger + from structlog.typing import FilteringBoundLogger, WrappedLogger - from airflow.sdk.api.datamodels.activities import ExecuteTaskActivity __all__ = ["WatchedSubprocess", "supervise"] @@ -140,8 +140,8 @@ def _reopen_std_io_handles(child_stdin, child_stdout, child_stderr): try: fd = handle.fileno() os.dup2(sock.fileno(), fd) - if close: - handle.close() + # We now have two open copies of the socket, we need to close one! + sock.close() except io.UnsupportedOperation: if "PYTEST_CURRENT_TEST" in os.environ: # When we're running under pytest, the stdin is not a real filehandle with an fd, so we need @@ -269,6 +269,7 @@ class WatchedSubprocess: ti: TaskInstance, client: Client, target: Callable[[], None] = _subprocess_main, + logger: FilteringBoundLogger | None = None, ) -> WatchedSubprocess: """Fork and start a new subprocess to execute the given task.""" # Create socketpairs/"pipes" to connect to the stdin and out from the subprocess @@ -284,6 +285,10 @@ class WatchedSubprocess: if pid == 0: # Parent ends of the sockets are closed by the OS as they are set as non-inheritable + del path + del client + del ti + # Run the child entrypoint _fork_main(child_stdin, child_stdout, child_stderr, child_logs.fileno(), target) @@ -308,17 +313,20 @@ class WatchedSubprocess: proc.kill(signal.SIGKILL) raise - proc._register_pipes(read_msgs, read_logs) + if logger is None: + logger = structlog.get_logger(logger_name="task").bind() + + proc._register_pipes(logger, read_msgs, read_logs) # Close the remaining parent-end of the sockets we've passed to the child via fork. We still have the # other end of the pair open - proc._close_unused_sockets(child_stdout, child_stdin, child_comms, child_logs) + proc._close_unused_sockets(child_stdout, child_stdin, child_stderr, child_comms, child_logs) # Tell the task process what it needs to do! 
proc._send_startup_message(ti, path, child_comms) return proc - def _register_pipes(self, read_msgs, read_logs): + def _register_pipes(self, logger: FilteringBoundLogger, read_msgs, read_logs): """Register handlers for subprocess communication channels.""" # self.selector is a way of registering a handler/callback to be called when the given IO channel has # activity to read on (https://www.man7.org/linux/man-pages/man2/select.2.html etc, but better @@ -326,23 +334,24 @@ class WatchedSubprocess: # needing full async, to read and process output from each socket as it is received. # TODO: Use logging providers to handle the chunked upload for us - logger: FilteringBoundLogger = structlog.get_logger(logger_name="task").bind() self.selector.register( - self.stdout, selectors.EVENT_READ, self._create_socket_handler(logger, "stdout") + self.stdout, selectors.EVENT_READ, (self._create_socket_handler(logger, "stdout"), "stdout") ) self.selector.register( self.stderr, selectors.EVENT_READ, - self._create_socket_handler(logger, "stderr", log_level=logging.ERROR), + (self._create_socket_handler(logger, "stderr", log_level=logging.ERROR), "stderr"), ) self.selector.register( read_logs, selectors.EVENT_READ, - make_buffered_socket_reader(process_log_messages_from_subprocess(logger)), + (make_buffered_socket_reader(process_log_messages_from_subprocess(logger)), "read_logs"), ) self.selector.register( - read_msgs, selectors.EVENT_READ, make_buffered_socket_reader(self.handle_requests(log)) + read_msgs, + selectors.EVENT_READ, + (make_buffered_socket_reader(self.handle_requests(log)), "read_msgs"), ) @staticmethod @@ -358,7 +367,7 @@ class WatchedSubprocess: def _send_startup_message(self, ti: TaskInstance, path: str | os.PathLike[str], child_comms: socket): """Send startup message to the subprocess.""" - msg = StartupDetails( + msg = StartupDetails.model_construct( ti=ti, file=str(path), requests_fd=child_comms.fileno(), @@ -422,7 +431,7 @@ class WatchedSubprocess: ) events = 
self.selector.select(timeout=max_wait_time) for key, _ in events: - socket_handler = key.data + (socket_handler, _) = key.data need_more = socket_handler(key.fileobj) if not need_more: @@ -606,25 +615,86 @@ def forward_to_log(target_log: FilteringBoundLogger, level: int) -> Generator[No target_log.log(level, msg) -def supervise(activity: ExecuteTaskActivity, server: str | None = None, dry_run: bool = False) -> int: +def _init_log_file(local_relative_path: str) -> Path: + """ + Create log directory and give it permissions that are configured. + + See Airflow's ``_prepare_log_folder`` helper for a more detailed explanation. + + :param local_relative_path: log file path relative to the configured base log folder + :return: absolute path of the created log file + """ + from airflow.configuration import conf + + new_file_permissions = int( + conf.get("logging", "file_task_handler_new_file_permissions", fallback="0o664"), 8 + ) + new_folder_permissions = int( + conf.get("logging", "file_task_handler_new_folder_permissions", fallback="0o775"), 8 + ) + + base_log_folder = conf.get("logging", "base_log_folder") + full_path = Path(base_log_folder, local_relative_path) + + directory = full_path.parent + for parent in reversed(directory.parents): + parent.mkdir(mode=new_folder_permissions, exist_ok=True) + directory.mkdir(mode=new_folder_permissions, exist_ok=True) + + try: + full_path.touch(new_file_permissions) + except OSError as e: + log.warning("OSError while changing ownership of the log file. ", e) + + return full_path + + +def supervise( + *, + ti: Any, + dag_path: str | os.PathLike[str], + token: str, + server: str | None = None, + dry_run: bool = False, + log_filename_suffix: str | None = None, +) -> int: """ Run a single task execution to completion. 
Returns the exit code of the process """ # One or the other - if (server == "") ^ dry_run: + if (not server) ^ dry_run: raise ValueError(f"Can only specify one of {server=} or {dry_run=}") - if not activity.path: - raise ValueError("path filed of activity missing") + if not dag_path: + raise ValueError("dag_path is required") + + if (str_path := os.fspath(dag_path)).startswith("DAGS_FOLDER/"): + from airflow.settings import DAGS_FOLDER + + dag_path = str_path.replace("DAGS_FOLDER/", DAGS_FOLDER + "/", 1) limits = httpx.Limits(max_keepalive_connections=1, max_connections=10) - client = Client(base_url=server or "", limits=limits, dry_run=dry_run, token=activity.token) + client = Client(base_url=server or "", limits=limits, dry_run=dry_run, token=token) start = time.monotonic() - process = WatchedSubprocess.start(activity.path, activity.ti, client=client) + logger: FilteringBoundLogger | None = None + if log_filename_suffix: + from airflow.sdk.log import logging_processors + + log_file = _init_log_file(log_filename_suffix) + + pretty_logs = False + if pretty_logs: + underlying_logger: WrappedLogger = structlog.WriteLogger(log_file.open("w", buffering=1)) + else: + underlying_logger = structlog.BytesLogger(log_file.open("wb")) + processors = logging_processors(enable_pretty_log=pretty_logs)[0] + logger = structlog.wrap_logger(underlying_logger, processors=processors, logger_name="task").bind() + + process = WatchedSubprocess.start(dag_path, ti, client=client, logger=logger) exit_code = process.wait() end = time.monotonic() diff --git a/task_sdk/src/airflow/sdk/execution_time/task_runner.py b/task_sdk/src/airflow/sdk/execution_time/task_runner.py index e00efe4597b..a0c12b79984 100644 --- a/task_sdk/src/airflow/sdk/execution_time/task_runner.py +++ b/task_sdk/src/airflow/sdk/execution_time/task_runner.py @@ -22,11 +22,11 @@ from __future__ import annotations import os import sys from io import FileIO -from typing import TYPE_CHECKING, TextIO +from typing import 
TYPE_CHECKING, Annotated, TextIO import attrs import structlog -from pydantic import ConfigDict, TypeAdapter +from pydantic import ConfigDict, SkipValidation, TypeAdapter from airflow.sdk.api.datamodels._generated import TaskInstance from airflow.sdk.definitions.baseoperator import BaseOperator @@ -37,9 +37,9 @@ if TYPE_CHECKING: class RuntimeTaskInstance(TaskInstance): - model_config = ConfigDict(arbitrary_types_allowed=True) + model_config = ConfigDict(arbitrary_types_allowed=True, defer_build=True) - task: BaseOperator + task: Annotated[BaseOperator, SkipValidation] def parse(what: StartupDetails) -> RuntimeTaskInstance: @@ -118,6 +118,10 @@ def startup() -> tuple[RuntimeTaskInstance, Logger]: msg = SUPERVISOR_COMMS.get_message() if isinstance(msg, StartupDetails): + from setproctitle import setproctitle + + setproctitle(f"airflow worker -- {msg.ti.id}") + log = structlog.get_logger(logger_name="task") # TODO: set the "magic loop" context vars for parsing ti = parse(msg) diff --git a/task_sdk/src/airflow/sdk/log.py b/task_sdk/src/airflow/sdk/log.py index f8e06eda4a6..b120f32a594 100644 --- a/task_sdk/src/airflow/sdk/log.py +++ b/task_sdk/src/airflow/sdk/log.py @@ -285,9 +285,11 @@ def configure_logging( std_lib_formatter.append(named["json"]) global _warnings_showwarning - _warnings_showwarning = warnings.showwarning - # Capture warnings and show them via structlog - warnings.showwarning = _showwarning + + if _warnings_showwarning is None: + _warnings_showwarning = warnings.showwarning + # Capture warnings and show them via structlog + warnings.showwarning = _showwarning logging.config.dictConfig( { @@ -327,11 +329,13 @@ def configure_logging( "propagate": True, }, # Some modules we _never_ want at debug level - "asyncio": {"level": "INFO"}, "alembic": {"level": "INFO"}, + "asyncio": {"level": "INFO"}, + "cron_descriptor.GetText": {"level": "INFO"}, "httpcore": {"level": "INFO"}, "httpx": {"level": "WARN"}, "psycopg.pq": {"level": "INFO"}, + # These ones are too 
chatty even at info "sqlalchemy.engine": {"level": "WARN"}, }, } @@ -344,17 +348,17 @@ def reset_logging(): configure_logging.cache_clear() -_warnings_showwarning = None +_warnings_showwarning: Any = None def _showwarning( - message: str | Warning, + message: Warning | str, category: type[Warning], filename: str, lineno: int, file: TextIO | None = None, line: str | None = None, -): +) -> Any: """ Redirects warnings to structlog so they appear in task logs etc.
