This is an automated email from the ASF dual-hosted git repository. ash pushed a commit to branch localexecutor-uses-task-sdk in repository https://gitbox.apache.org/repos/asf/airflow.git
commit 5a7aca4d69f6baa9e821bb30c2d795bbf04a9ca1 Author: Ash Berlin-Taylor <[email protected]> AuthorDate: Wed Nov 20 17:49:45 2024 +0000 Convert the LocalExecutor to run tasks using new Task SDK supervisor code This also lays the groundwork for a more general purpose "workload" execution system, making a single interface for executors to run tasks and callbacks. Also in this PR we set up the supervise function to send Task logs to a file, and handle the task log template rendering in the scheduler before queueing the workload. Additionally we don't pass the activity directly to `supervise()` but instead the properties/fields of it to reduce the coupling between SDK and Executor. (More separation will appear in PRs over the next few weeks.) The big change of note here is that rather than sending an airflow command line to execute (`["airflow", "tasks", "run", ...]`) and going back in via the CLI parser we go directly to a special purpose function. Much simpler. It doesn't remove any of the old behaviour (CeleryExecutor still uses LocalTaskJob via the CLI parser etc.), nor does anything currently send callback requests via this new workload mechanism. The `airflow.executors.workloads` module currently needs to be shared between the Scheduler (or more specifically the Executor) and the "worker" side of things. In the future these will be separate python dists and this module will need to live somewhere else. Right now we check if `executor.queue_workload` is different from the BaseExecutor version (which just raises an error right now) to see which executors support this new version. That check will be removed as soon as all the in-tree executors have been migrated. 
--- airflow/cli/commands/task_command.py | 32 +++-- airflow/config_templates/config.yml | 2 +- airflow/executors/base_executor.py | 7 ++ airflow/executors/local_executor.py | 140 ++++++--------------- airflow/executors/workloads.py | 98 +++++++++++++++ airflow/jobs/scheduler_job_runner.py | 11 +- airflow/utils/helpers.py | 25 +++- task_sdk/pyproject.toml | 1 + .../src/airflow/sdk/execution_time/supervisor.py | 71 ++++++++--- .../src/airflow/sdk/execution_time/task_runner.py | 4 + task_sdk/src/airflow/sdk/log.py | 81 ++++++++++-- task_sdk/tests/execution_time/test_supervisor.py | 19 ++- tests/executors/test_local_executor.py | 77 ++++++------ tests/jobs/test_scheduler_job.py | 6 + 14 files changed, 383 insertions(+), 191 deletions(-) diff --git a/airflow/cli/commands/task_command.py b/airflow/cli/commands/task_command.py index 55c1662d35d..3962592e9b2 100644 --- a/airflow/cli/commands/task_command.py +++ b/airflow/cli/commands/task_command.py @@ -300,6 +300,8 @@ def _run_task_by_executor(args, dag: DAG, ti: TaskInstance) -> None: This can result in the task being started by another host if the executor implementation does. 
""" + from airflow.executors.base_executor import BaseExecutor + if ti.executor: executor = ExecutorLoader.load_executor(ti.executor) else: @@ -307,17 +309,25 @@ def _run_task_by_executor(args, dag: DAG, ti: TaskInstance) -> None: executor.job_id = None executor.start() print("Sending to executor.") - executor.queue_task_instance( - ti, - mark_success=args.mark_success, - ignore_all_deps=args.ignore_all_dependencies, - ignore_depends_on_past=should_ignore_depends_on_past(args), - wait_for_past_depends_before_skipping=(args.depends_on_past == "wait"), - ignore_task_deps=args.ignore_dependencies, - ignore_ti_state=args.force, - pool=args.pool, - ) - executor.heartbeat() + + # TODO: Task-SDK: this is temporary while we migrate the other executors over + if executor.queue_workload.__func__ is not BaseExecutor.queue_workload: # type: ignore[attr-defined] + from airflow.executors import workloads + + workload = workloads.ExecuteTask.make(ti, dag_path=dag.relative_fileloc) + executor.queue_workload(workload) + else: + executor.queue_task_instance( + ti, + mark_success=args.mark_success, + ignore_all_deps=args.ignore_all_dependencies, + ignore_depends_on_past=should_ignore_depends_on_past(args), + wait_for_past_depends_before_skipping=(args.depends_on_past == "wait"), + ignore_task_deps=args.ignore_dependencies, + ignore_ti_state=args.force, + pool=args.pool, + ) + executor.heartbeat() executor.end() diff --git a/airflow/config_templates/config.yml b/airflow/config_templates/config.yml index 7c50f3b7cd3..5ca17e63ee2 100644 --- a/airflow/config_templates/config.yml +++ b/airflow/config_templates/config.yml @@ -908,7 +908,7 @@ logging: is_template: true default: "dag_id={{ ti.dag_id }}/run_id={{ ti.run_id }}/task_id={{ ti.task_id }}/\ {%% if ti.map_index >= 0 %%}map_index={{ ti.map_index }}/{%% endif %%}\ - attempt={{ try_number }}.log" + attempt={{ try_number|default(ti.try_number) }}.log" log_processor_filename_template: description: | Formatting for how airflow generates 
file names for log diff --git a/airflow/executors/base_executor.py b/airflow/executors/base_executor.py index 3e406e3d74d..b57f7458bbb 100644 --- a/airflow/executors/base_executor.py +++ b/airflow/executors/base_executor.py @@ -49,6 +49,7 @@ if TYPE_CHECKING: from airflow.callbacks.base_callback_sink import BaseCallbackSink from airflow.callbacks.callback_requests import CallbackRequest from airflow.cli.cli_config import GroupCommand + from airflow.executors import workloads from airflow.executors.executor_utils import ExecutorName from airflow.models.taskinstance import TaskInstance from airflow.models.taskinstancekey import TaskInstanceKey @@ -170,6 +171,9 @@ class BaseExecutor(LoggingMixin): else: self.log.error("could not queue task %s", task_instance.key) + def queue_workload(self, workload: workloads.All) -> None: + raise ValueError(f"Un-handled workload kind {type(workload).__name__!r} in {type(self).__name__}") + def queue_task_instance( self, task_instance: TaskInstance, @@ -409,6 +413,9 @@ class BaseExecutor(LoggingMixin): self.execute_async(key=key, command=command, queue=queue, executor_config=executor_config) self.running.add(key) + # TODO: This should not be using `TaskInstanceState` here, this is just "did the process complete, or did + # it die". It is possible for the task itself to finish with success, but the state of the task to be set + # to FAILED. By using TaskInstanceState enum here it confuses matters! 
def change_state( self, key: TaskInstanceKey, state: TaskInstanceState, info=None, remove_running=True ) -> None: diff --git a/airflow/executors/local_executor.py b/airflow/executors/local_executor.py index 02e201f7b5b..397901a8ff4 100644 --- a/airflow/executors/local_executor.py +++ b/airflow/executors/local_executor.py @@ -30,32 +30,24 @@ import logging import multiprocessing import multiprocessing.sharedctypes import os -import subprocess from multiprocessing import Queue, SimpleQueue -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Optional from setproctitle import setproctitle from airflow import settings from airflow.executors.base_executor import PARALLELISM, BaseExecutor -from airflow.traces.tracer import add_span -from airflow.utils.dag_parsing_context import _airflow_parsing_context_manager from airflow.utils.state import TaskInstanceState if TYPE_CHECKING: - from airflow.executors.base_executor import CommandType - from airflow.models.taskinstancekey import TaskInstanceKey + from airflow.executors import workloads - # This is a work to be executed by a worker. - # It can Key and Command - but it can also be None, None which is actually a - # "Poison Pill" - worker seeing Poison Pill should take the pill and ... die instantly. 
- ExecutorWorkType = Optional[tuple[TaskInstanceKey, CommandType]] - TaskInstanceStateType = tuple[TaskInstanceKey, TaskInstanceState, Optional[Exception]] + TaskInstanceStateType = tuple[workloads.TaskInstance, TaskInstanceState, Optional[Exception]] def _run_worker( logger_name: str, - input: SimpleQueue[ExecutorWorkType], + input: SimpleQueue[workloads.All | None], output: Queue[TaskInstanceStateType], unread_messages: multiprocessing.sharedctypes.Synchronized[int], ): @@ -65,16 +57,16 @@ def _run_worker( signal.signal(signal.SIGINT, signal.SIG_IGN) log = logging.getLogger(logger_name) + log.info("Worker starting up pid=%d", os.getpid()) # We know we've just started a new process, so lets disconnect from the metadata db now settings.engine.pool.dispose() settings.engine.dispose() - setproctitle("airflow worker -- LocalExecutor: <idle>") - while True: + setproctitle("airflow worker -- LocalExecutor: <idle>") try: - item = input.get() + workload = input.get() except EOFError: log.info( "Failed to read tasks from the task queue because the other " @@ -83,88 +75,49 @@ def _run_worker( ) break - if item is None: + if workload is None: # Received poison pill, no more tasks to run return # Decrement this as soon as we pick up a message off the queue with unread_messages: unread_messages.value -= 1 + key = None + if ti := getattr(workload, "ti", None): + key = ti.key + else: + raise TypeError(f"Don't know how to get ti key from {type(workload).__name__}") - (key, command) = item try: - state = _execute_work(log, key, command) + _execute_work(log, workload) - output.put((key, state, None)) + output.put((key, TaskInstanceState.SUCCESS, None)) except Exception as e: + log.exception("uhoh") output.put((key, TaskInstanceState.FAILED, e)) -def _execute_work(log: logging.Logger, key: TaskInstanceKey, command: CommandType) -> TaskInstanceState: +def _execute_work(log: logging.Logger, workload: workloads.ExecuteTask) -> None: """ Execute command received and stores result state 
in queue. :param key: the key to identify the task instance :param command: the command to execute """ - setproctitle(f"airflow worker -- LocalExecutor: {command}") - dag_id, task_id = BaseExecutor.validate_airflow_tasks_run_command(command) - try: - with _airflow_parsing_context_manager(dag_id=dag_id, task_id=task_id): - if settings.EXECUTE_TASKS_NEW_PYTHON_INTERPRETER: - return _execute_work_in_subprocess(log, command) - else: - return _execute_work_in_fork(log, command) - finally: - # Remove the command since the worker is done executing the task - setproctitle("airflow worker -- LocalExecutor: <idle>") - - -def _execute_work_in_subprocess(log: logging.Logger, command: CommandType) -> TaskInstanceState: - try: - subprocess.check_call(command, close_fds=True) - return TaskInstanceState.SUCCESS - except subprocess.CalledProcessError as e: - log.error("Failed to execute task %s.", e) - return TaskInstanceState.FAILED - - -def _execute_work_in_fork(log: logging.Logger, command: CommandType) -> TaskInstanceState: - pid = os.fork() - if pid: - # In parent, wait for the child - pid, ret = os.waitpid(pid, 0) - return TaskInstanceState.SUCCESS if ret == 0 else TaskInstanceState.FAILED - - from airflow.sentry import Sentry - - ret = 1 - try: - import signal - - from airflow.cli.cli_parser import get_parser - - signal.signal(signal.SIGINT, signal.SIG_IGN) - signal.signal(signal.SIGTERM, signal.SIG_DFL) - signal.signal(signal.SIGUSR2, signal.SIG_DFL) - - parser = get_parser() - # [1:] - remove "airflow" from the start of the command - args = parser.parse_args(command[1:]) - args.shut_down_logging = False - - setproctitle(f"airflow task supervisor: {command}") - - args.func(args) - ret = 0 - return TaskInstanceState.SUCCESS - except Exception as e: - log.exception("Failed to execute task %s.", e) - return TaskInstanceState.FAILED - finally: - Sentry.flush() - logging.shutdown() - os._exit(ret) + from airflow.configuration import conf + from 
airflow.sdk.execution_time.supervisor import supervise + + setproctitle(f"airflow worker -- LocalExecutor: {workload.ti.id}") + # This will return the exit code of the task process, but we don't care about that, just if the + # _supervisor_ had an error reporting the state back (which will result in an exception.) + supervise( + # This is the "wrong" ti type, but it duck types the same. TODO: Create a protocol for this. + ti=workload.ti, # type: ignore[arg-type] + dag_path=workload.dag_path, + token=workload.token, + server=conf.get("workers", "execution_api_server_url", fallback="http://localhost:9091/execution/"), + log_path=workload.log_path, + ) class LocalExecutor(BaseExecutor): @@ -180,7 +133,7 @@ class LocalExecutor(BaseExecutor): serve_logs: bool = True - activity_queue: SimpleQueue[ExecutorWorkType] + activity_queue: SimpleQueue[workloads.All | None] result_queue: SimpleQueue[TaskInstanceStateType] workers: dict[int, multiprocessing.Process] _unread_messages: multiprocessing.sharedctypes.Synchronized[int] @@ -203,22 +156,7 @@ class LocalExecutor(BaseExecutor): # (it looks like an int to python) self._unread_messages = multiprocessing.Value(ctypes.c_uint) # type: ignore[assignment] - @add_span - def execute_async( - self, - key: TaskInstanceKey, - command: CommandType, - queue: str | None = None, - executor_config: Any | None = None, - ) -> None: - """Execute asynchronously.""" - self.validate_airflow_tasks_run_command(command) - self.activity_queue.put((key, command)) - with self._unread_messages: - self._unread_messages.value += 1 - self._check_workers(can_start=True) - - def _check_workers(self, can_start: bool = True): + def _check_workers(self): # Reap any dead workers to_remove = set() for pid, proc in self.workers.items(): @@ -270,12 +208,6 @@ class LocalExecutor(BaseExecutor): while not self.result_queue.empty(): key, state, exc = self.result_queue.get() - if exc: - # TODO: This needs a better stacktrace, it appears from here - if hasattr(exc, 
"add_note"): - exc.add_note("(This stacktrace is incorrect -- the exception came from a subprocess)") - raise exc - self.change_state(key, state) def end(self) -> None: @@ -306,3 +238,9 @@ class LocalExecutor(BaseExecutor): def terminate(self): """Terminate the executor is not doing anything.""" + + def queue_workload(self, workload: workloads.All): + self.activity_queue.put(workload) + with self._unread_messages: + self._unread_messages.value += 1 + self._check_workers() diff --git a/airflow/executors/workloads.py b/airflow/executors/workloads.py new file mode 100644 index 00000000000..0adb54cd6da --- /dev/null +++ b/airflow/executors/workloads.py @@ -0,0 +1,98 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations + +import os +import uuid +from pathlib import Path +from typing import TYPE_CHECKING, Literal, Union + +from pydantic import BaseModel, Field + +if TYPE_CHECKING: + from airflow.models.taskinstance import TaskInstance as TIModel + from airflow.models.taskinstancekey import TaskInstanceKey + + +__all__ = [ + "All", + "ExecuteTask", +] + + +class BaseActivity(BaseModel): + token: str + """The identity token for this workload""" + + +class TaskInstance(BaseModel): + """Schema for TaskInstance with minimal required fields needed for Executors and Task SDK.""" + + id: uuid.UUID + + task_id: str + dag_id: str + run_id: str + try_number: int + map_index: int | None = None + + # TODO: Task-SDK: Can we replace TastInstanceKey with just the uuid across the codebase? + @property + def key(self) -> TaskInstanceKey: + from airflow.models.taskinstancekey import TaskInstanceKey + + return TaskInstanceKey( + dag_id=self.dag_id, + task_id=self.task_id, + run_id=self.run_id, + try_number=self.try_number, + map_index=-1 if self.map_index is None else self.map_index, + ) + + +class ExecuteTask(BaseActivity): + """Execute the given Task.""" + + ti: TaskInstance + """The TaskInstance to execute""" + dag_path: os.PathLike[str] + """The filepath where the DAG can be found (likely prefixed with `DAG_FOLDER/`)""" + + log_path: str | None + """The rendered relative log filename template the task logs should be written to""" + + kind: Literal["ExecuteTask"] = Field(init=False, default="ExecuteTask") + + @classmethod + def make(cls, ti: TIModel, dag_path: Path | None = None) -> ExecuteTask: + from pathlib import Path + + from airflow.utils.helpers import log_filename_template_renderer + + ser_ti = TaskInstance.model_validate(ti, from_attributes=True) + + dag_path = dag_path or Path(ti.dag_run.dag_model.relative_fileloc) + + if dag_path and not dag_path.is_absolute(): + # TODO: What about multiple dag sub folders + dag_path = "DAGS_FOLDER" / dag_path + + 
fname = log_filename_template_renderer()(ti=ti) + return cls(ti=ser_ti, dag_path=dag_path, token="", log_path=fname) + + +All = Union[ExecuteTask] diff --git a/airflow/jobs/scheduler_job_runner.py b/airflow/jobs/scheduler_job_runner.py index da50deb31c9..56a65009e2b 100644 --- a/airflow/jobs/scheduler_job_runner.py +++ b/airflow/jobs/scheduler_job_runner.py @@ -44,6 +44,8 @@ from airflow.callbacks.callback_requests import DagCallbackRequest, TaskCallback from airflow.callbacks.pipe_callback_sink import PipeCallbackSink from airflow.configuration import conf from airflow.exceptions import RemovedInAirflow3Warning +from airflow.executors import workloads +from airflow.executors.base_executor import BaseExecutor from airflow.executors.executor_loader import ExecutorLoader from airflow.jobs.base_job_runner import BaseJobRunner from airflow.jobs.job import Job, perform_heartbeat @@ -91,7 +93,6 @@ if TYPE_CHECKING: from sqlalchemy.orm import Query, Session from airflow.dag_processing.manager import DagFileProcessorAgent - from airflow.executors.base_executor import BaseExecutor from airflow.executors.executor_utils import ExecutorName from airflow.models.taskinstance import TaskInstanceKey from airflow.utils.sqlalchemy import ( @@ -654,6 +655,14 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin): if ti.dag_run.state in State.finished_dr_states: ti.set_state(None, session=session) continue + + # TODO: Task-SDK: This check is transitionary. Remove once all executors are ported over. 
+ # Has a real queue_activity implemented + if executor.queue_workload.__func__ is not BaseExecutor.queue_workload: # type: ignore[attr-defined] + workload = workloads.ExecuteTask.make(ti) + executor.queue_workload(workload) + continue + command = ti.command_as_list( local=True, ) diff --git a/airflow/utils/helpers.py b/airflow/utils/helpers.py index 4cfc62acdd2..30f8dde41af 100644 --- a/airflow/utils/helpers.py +++ b/airflow/utils/helpers.py @@ -23,7 +23,7 @@ import re import signal from collections.abc import Generator, Iterable, Mapping, MutableMapping from datetime import datetime -from functools import reduce +from functools import cache, reduce from typing import TYPE_CHECKING, Any, Callable, TypeVar, cast from lazy_object_proxy import Proxy @@ -171,7 +171,7 @@ def as_flattened_list(iterable: Iterable[Iterable[T]]) -> list[T]: return [e for i in iterable for e in i] -def parse_template_string(template_string: str) -> tuple[str | None, jinja2.Template | None]: +def parse_template_string(template_string: str) -> tuple[str, None] | tuple[None, jinja2.Template]: """Parse Jinja template string.""" import jinja2 @@ -181,6 +181,27 @@ def parse_template_string(template_string: str) -> tuple[str | None, jinja2.Temp return template_string, None +@cache +def log_filename_template_renderer() -> Callable[..., str]: + template = conf.get("logging", "log_filename_template") + + if "{{" in template: + import jinja2 + + return jinja2.Template(template).render + else: + + def f_str_format(ti: TaskInstance, try_number: int | None = None): + return template.format( + dag_id=ti.dag_id, + task_id=ti.task_id, + logical_date=ti.logical_date.isoformat(), + try_number=try_number or ti.try_number, + ) + + return f_str_format + + def render_log_filename(ti: TaskInstance, try_number, filename_template) -> str: """ Given task instance, try_number, filename_template, return the rendered log filename. 
diff --git a/task_sdk/pyproject.toml b/task_sdk/pyproject.toml index 170aff5ec24..77061faa37c 100644 --- a/task_sdk/pyproject.toml +++ b/task_sdk/pyproject.toml @@ -25,6 +25,7 @@ dependencies = [ "attrs>=24.2.0", "google-re2>=1.1.20240702", "httpx>=0.27.0", + "jinja2>=3.1.4", "methodtools>=0.4.7", "msgspec>=0.18.6", "psutil>=6.1.0", diff --git a/task_sdk/src/airflow/sdk/execution_time/supervisor.py b/task_sdk/src/airflow/sdk/execution_time/supervisor.py index 6dfbe415058..4058e46513a 100644 --- a/task_sdk/src/airflow/sdk/execution_time/supervisor.py +++ b/task_sdk/src/airflow/sdk/execution_time/supervisor.py @@ -54,9 +54,8 @@ from airflow.sdk.execution_time.comms import ( ) if TYPE_CHECKING: - from structlog.typing import FilteringBoundLogger + from structlog.typing import FilteringBoundLogger, WrappedLogger - from airflow.sdk.api.datamodels.activities import ExecuteTaskActivity __all__ = ["WatchedSubprocess", "supervise"] @@ -282,6 +281,7 @@ class WatchedSubprocess: ti: TaskInstance, client: Client, target: Callable[[], None] = _subprocess_main, + logger: FilteringBoundLogger | None = None, ) -> WatchedSubprocess: """Fork and start a new subprocess to execute the given task.""" # Create socketpairs/"pipes" to connect to the stdin and out from the subprocess @@ -297,6 +297,13 @@ class WatchedSubprocess: if pid == 0: # Parent ends of the sockets are closed by the OS as they are set as non-inheritable + # Python GC should delete these for us, but lets make double sure that we don't keep anything + # around in the forked processes, especially things that might involve open files or sockets! 
+ del path + del client + del ti + del logger + # Run the child entrypoint _fork_main(child_stdin, child_stdout, child_stderr, child_logs.fileno(), target) @@ -319,8 +326,13 @@ class WatchedSubprocess: proc.kill(signal.SIGKILL) raise + logger = logger or cast("FilteringBoundLogger", structlog.get_logger(logger_name="task").bind()) proc._register_pipe_readers( - stdout=read_stdout, stderr=read_stderr, requests=read_msgs, logs=read_logs + logger=logger, + stdout=read_stdout, + stderr=read_stderr, + requests=read_msgs, + logs=read_logs, ) # Close the remaining parent-end of the sockets we've passed to the child via fork. We still have the @@ -331,16 +343,15 @@ class WatchedSubprocess: proc._send_startup_message(ti, path, child_comms) return proc - def _register_pipe_readers(self, stdout: socket, stderr: socket, requests: socket, logs: socket): + def _register_pipe_readers( + self, logger: FilteringBoundLogger, stdout: socket, stderr: socket, requests: socket, logs: socket + ): """Register handlers for subprocess communication channels.""" # self.selector is a way of registering a handler/callback to be called when the given IO channel has # activity to read on (https://www.man7.org/linux/man-pages/man2/select.2.html etc, but better # alternatives are used automatically) -- this is a way of having "event-based" code, but without # needing full async, to read and process output from each socket as it is received. 
- # TODO: Use logging providers to handle the chunked upload for us - logger: FilteringBoundLogger = structlog.get_logger(logger_name="task").bind() - self.selector.register(stdout, selectors.EVENT_READ, self._create_socket_handler(logger, "stdout")) self.selector.register( stderr, @@ -369,7 +380,7 @@ class WatchedSubprocess: def _send_startup_message(self, ti: TaskInstance, path: str | os.PathLike[str], child_comms: socket): """Send startup message to the subprocess.""" - msg = StartupDetails( + msg = StartupDetails.model_construct( ti=ti, file=str(path), requests_fd=child_comms.fileno(), @@ -630,25 +641,57 @@ def forward_to_log(target_log: FilteringBoundLogger, level: int) -> Generator[No target_log.log(level, msg) -def supervise(activity: ExecuteTaskActivity, server: str | None = None, dry_run: bool = False) -> int: +def supervise( + *, + ti: TaskInstance, + dag_path: str | os.PathLike[str], + token: str, + server: str | None = None, + dry_run: bool = False, + log_path: str | None = None, +) -> int: """ Run a single task execution to completion. 
Returns the exit code of the process """ # One or the other - if (server == "") ^ dry_run: + if (not server) ^ dry_run: raise ValueError(f"Can only specify one of {server=} or {dry_run=}") - if not activity.path: - raise ValueError("path filed of activity missing") + if not dag_path: + raise ValueError("dag_path is required") + + if (str_path := os.fspath(dag_path)).startswith("DAGS_FOLDER/"): + from airflow.settings import DAGS_FOLDER + + dag_path = str_path.replace("DAGS_FOLDER/", DAGS_FOLDER + "/", 1) limits = httpx.Limits(max_keepalive_connections=1, max_connections=10) - client = Client(base_url=server or "", limits=limits, dry_run=dry_run, token=activity.token) + client = Client(base_url=server or "", limits=limits, dry_run=dry_run, token=token) start = time.monotonic() - process = WatchedSubprocess.start(activity.path, activity.ti, client=client) + # TODO: Use logging providers to handle the chunked upload for us etc. + logger: FilteringBoundLogger | None = None + if log_path: + # If we are told to write logs to a file, redirect the task logger to it. + from airflow.sdk.log import init_log_file, logging_processors + + try: + log_file = init_log_file(log_path) + except OSError as e: + log.warning("OSError while changing ownership of the log file. 
", e) + + pretty_logs = False + if pretty_logs: + underlying_logger: WrappedLogger = structlog.WriteLogger(log_file.open("w", buffering=1)) + else: + underlying_logger = structlog.BytesLogger(log_file.open("wb")) + processors = logging_processors(enable_pretty_log=pretty_logs)[0] + logger = structlog.wrap_logger(underlying_logger, processors=processors, logger_name="task").bind() + + process = WatchedSubprocess.start(dag_path, ti, client=client, logger=logger) exit_code = process.wait() end = time.monotonic() diff --git a/task_sdk/src/airflow/sdk/execution_time/task_runner.py b/task_sdk/src/airflow/sdk/execution_time/task_runner.py index 9abf3a4796b..c693a77eac2 100644 --- a/task_sdk/src/airflow/sdk/execution_time/task_runner.py +++ b/task_sdk/src/airflow/sdk/execution_time/task_runner.py @@ -127,6 +127,10 @@ def startup() -> tuple[RuntimeTaskInstance, Logger]: msg = SUPERVISOR_COMMS.get_message() if isinstance(msg, StartupDetails): + from setproctitle import setproctitle + + setproctitle(f"airflow worker -- {msg.ti.id}") + log = structlog.get_logger(logger_name="task") # TODO: set the "magic loop" context vars for parsing ti = parse(msg) diff --git a/task_sdk/src/airflow/sdk/log.py b/task_sdk/src/airflow/sdk/log.py index f8e06eda4a6..0873233f536 100644 --- a/task_sdk/src/airflow/sdk/log.py +++ b/task_sdk/src/airflow/sdk/log.py @@ -23,6 +23,7 @@ import os import sys import warnings from functools import cache +from pathlib import Path from typing import TYPE_CHECKING, Any, BinaryIO, Callable, Generic, TextIO, TypeVar import msgspec @@ -285,9 +286,11 @@ def configure_logging( std_lib_formatter.append(named["json"]) global _warnings_showwarning - _warnings_showwarning = warnings.showwarning - # Capture warnings and show them via structlog - warnings.showwarning = _showwarning + + if _warnings_showwarning is None: + _warnings_showwarning = warnings.showwarning + # Capture warnings and show them via structlog + warnings.showwarning = _showwarning 
logging.config.dictConfig( { @@ -327,11 +330,13 @@ def configure_logging( "propagate": True, }, # Some modules we _never_ want at debug level - "asyncio": {"level": "INFO"}, "alembic": {"level": "INFO"}, + "asyncio": {"level": "INFO"}, + "cron_descriptor.GetText": {"level": "INFO"}, "httpcore": {"level": "INFO"}, "httpx": {"level": "WARN"}, "psycopg.pq": {"level": "INFO"}, + # These ones are too chatty even at info "sqlalchemy.engine": {"level": "WARN"}, }, } @@ -344,17 +349,17 @@ def reset_logging(): configure_logging.cache_clear() -_warnings_showwarning = None +_warnings_showwarning: Any = None def _showwarning( - message: str | Warning, + message: Warning | str, category: type[Warning], filename: str, lineno: int, file: TextIO | None = None, line: str | None = None, -): +) -> Any: """ Redirects warnings to structlog so they appear in task logs etc. @@ -370,3 +375,65 @@ def _showwarning( else: log = structlog.get_logger(logger_name="py.warnings") log.warning(str(message), category=category.__name__, filename=filename, lineno=lineno) + + +def _prepare_log_folder(directory: Path, mode: int): + """ + Prepare the log folder and ensure its mode is as configured. + + To handle log writing when tasks are impersonated, the log files need to + be writable by the user that runs the Airflow command and the user + that is impersonated. This is mainly to handle corner cases with the + SubDagOperator. When the SubDagOperator is run, all of the operators + run under the impersonated user and create appropriate log files + as the impersonated user. However, if the user manually runs tasks + of the SubDagOperator through the UI, then the log files are created + by the user that runs the Airflow command. For example, the Airflow + run command may be run by the `airflow_sudoable` user, but the Airflow + tasks may be run by the `airflow` user. 
If the log files are not + writable by both users, then it's possible that re-running a task + via the UI (or vice versa) results in a permission error as the task + tries to write to a log file created by the other user. + + We leave it up to the user to manage their permissions by exposing configuration for both + new folders and new log files. Default is to make new log folders and files group-writeable + to handle most common impersonation use cases. The requirement in this case will be to make + sure that the same group is set as default group for both - impersonated user and main airflow + user. + """ + for parent in reversed(directory.parents): + parent.mkdir(mode=mode, exist_ok=True) + directory.mkdir(mode=mode, exist_ok=True) + + +def init_log_file(local_relative_path: str) -> Path: + """ + Ensure log file and parent directories are created. + + Any directories that are missing are created with the right permission bits. + + See above ``_prepare_log_folder`` method for more detailed explanation. + """ + # NOTE: This is duplicated from airflow.utils.log.file_task_handler:FileTaskHandler._init_file, but we + # want to remove that + from airflow.configuration import conf + + new_file_permissions = int( + conf.get("logging", "file_task_handler_new_file_permissions", fallback="0o664"), 8 + ) + new_folder_permissions = int( + conf.get("logging", "file_task_handler_new_folder_permissions", fallback="0o775"), 8 + ) + + base_log_folder = conf.get("logging", "base_log_folder") + full_path = Path(base_log_folder, local_relative_path) + + _prepare_log_folder(full_path.parent, new_folder_permissions) + + try: + full_path.touch(new_file_permissions) + except OSError as e: + log = structlog.get_logger() + log.warning("OSError while changing ownership of the log file. 
%s", e) + + return full_path diff --git a/task_sdk/tests/execution_time/test_supervisor.py b/task_sdk/tests/execution_time/test_supervisor.py index 2cb8c24af1b..3a50a73b3cf 100644 --- a/task_sdk/tests/execution_time/test_supervisor.py +++ b/task_sdk/tests/execution_time/test_supervisor.py @@ -36,7 +36,6 @@ from uuid6 import uuid7 from airflow.sdk.api import client as sdk_client from airflow.sdk.api.client import ServerResponseError from airflow.sdk.api.datamodels._generated import TaskInstance -from airflow.sdk.api.datamodels.activities import ExecuteTaskActivity from airflow.sdk.execution_time.comms import ( ConnectionResult, DeferTask, @@ -249,19 +248,15 @@ class TestWatchedSubprocess: time_machine.move_to(instant, tick=False) dagfile_path = test_dags_dir / "super_basic_run.py" - task_activity = ExecuteTaskActivity( - ti=TaskInstance( - id=uuid7(), - task_id="hello", - dag_id="super_basic_run", - run_id="c", - try_number=1, - ), - path=dagfile_path, - token="", + ti = TaskInstance( + id=uuid7(), + task_id="hello", + dag_id="super_basic_run", + run_id="c", + try_number=1, ) # Assert Exit Code is 0 - assert supervise(activity=task_activity, server="", dry_run=True) == 0 + assert supervise(ti=ti, dag_path=dagfile_path, token="", server="", dry_run=True) == 0 # We should have a log from the task! assert { diff --git a/tests/executors/test_local_executor.py b/tests/executors/test_local_executor.py index 2545ceb7e70..a64f87bf1ae 100644 --- a/tests/executors/test_local_executor.py +++ b/tests/executors/test_local_executor.py @@ -17,17 +17,15 @@ # under the License. 
from __future__ import annotations -import datetime import multiprocessing import os -import subprocess from unittest import mock import pytest from kgb import spy_on +from uuid6 import uuid7 -from airflow import settings -from airflow.exceptions import AirflowException +from airflow.executors import workloads from airflow.executors.local_executor import LocalExecutor from airflow.utils.state import State @@ -52,43 +50,43 @@ class TestLocalExecutor: def test_serve_logs_default_value(self): assert LocalExecutor.serve_logs - @mock.patch("airflow.executors.local_executor.subprocess.check_call") - @mock.patch("airflow.cli.commands.task_command.task_run") - def _test_execute(self, mock_run, mock_check_call, parallelism=1): - success_command = ["airflow", "tasks", "run", "success", "some_parameter", "2020-10-07"] - fail_command = ["airflow", "tasks", "run", "failure", "task_id", "2020-10-07"] + @mock.patch("airflow.sdk.execution_time.supervisor.supervise") + def _test_execute(self, mock_supervise, parallelism=1): + success_tis = [ + workloads.TaskInstance( + id=uuid7(), + task_id=f"success_{i}", + dag_id="mydag", + run_id="run1", + try_number=1, + state="queued", + ) + for i in range(self.TEST_SUCCESS_COMMANDS) + ] + fail_ti = success_tis[0].model_copy(update={"id": uuid7(), "task_id": "failure"}) # We just mock both styles here, only one will be hit though - def fake_execute_command(command, close_fds=True): - if command != success_command: - raise subprocess.CalledProcessError(returncode=1, cmd=command) - else: - return 0 - - def fake_task_run(args): - if args.dag_id != "success": - raise AirflowException("Simulate failed task") + def fake_supervise(ti, **kwargs): + if ti.id == fail_ti.id: + raise RuntimeError("fake failure") + return 0 - mock_check_call.side_effect = fake_execute_command - mock_run.side_effect = fake_task_run + mock_supervise.side_effect = fake_supervise executor = LocalExecutor(parallelism=parallelism) executor.start() - success_key = "success {}" 
assert executor.result_queue.empty() with spy_on(executor._spawn_worker) as spawn_worker: - run_id = "manual_" + datetime.datetime.now().isoformat() - for i in range(self.TEST_SUCCESS_COMMANDS): - key_id, command = success_key.format(i), success_command - key = key_id, "fake_ti", run_id, 0 - executor.running.add(key) - executor.execute_async(key=key, command=command) + for ti in success_tis: + executor.queue_workload( + workloads.ExecuteTask(token="", ti=ti, dag_path="some/path", log_path=None) + ) - fail_key = "fail", "fake_ti", run_id, 0 - executor.running.add(fail_key) - executor.execute_async(key=fail_key, command=fail_command) + executor.queue_workload( + workloads.ExecuteTask(token="", ti=fail_ti, dag_path="some/path", log_path=None) + ) executor.end() @@ -100,24 +98,19 @@ class TestLocalExecutor: assert len(executor.running) == 0 assert executor._unread_messages.value == 0 - for i in range(self.TEST_SUCCESS_COMMANDS): - key_id = success_key.format(i) - key = key_id, "fake_ti", run_id, 0 - assert executor.event_buffer[key][0] == State.SUCCESS - assert executor.event_buffer[fail_key][0] == State.FAILED + for ti in success_tis: + assert executor.event_buffer[ti.key][0] == State.SUCCESS + assert executor.event_buffer[fail_ti.key][0] == State.FAILED @skip_spawn_mp_start @pytest.mark.parametrize( - ("parallelism", "fork_or_subproc"), + ("parallelism",), [ - pytest.param(0, True, id="unlimited_subprocess"), - pytest.param(2, True, id="limited_subprocess"), - pytest.param(0, False, id="unlimited_fork"), - pytest.param(2, False, id="limited_fork"), + pytest.param(0, id="unlimited"), + pytest.param(2, id="limited"), ], ) - def test_execution(self, parallelism: int, fork_or_subproc: bool, monkeypatch: pytest.MonkeyPatch): - monkeypatch.setattr(settings, "EXECUTE_TASKS_NEW_PYTHON_INTERPRETER", fork_or_subproc) + def test_execution(self, parallelism: int): self._test_execute(parallelism=parallelism) @mock.patch("airflow.executors.local_executor.LocalExecutor.sync") diff 
--git a/tests/jobs/test_scheduler_job.py b/tests/jobs/test_scheduler_job.py index 3b2d1a34dd6..b4039c7bc51 100644 --- a/tests/jobs/test_scheduler_job.py +++ b/tests/jobs/test_scheduler_job.py @@ -181,6 +181,12 @@ class TestSchedulerJob: default_executor.name = ExecutorName(alias="default_exec", module_path="default.exec.module.path") second_executor = mock.MagicMock(name="SeconadaryExecutor", slots_available=8, slots_occupied=0) second_executor.name = ExecutorName(alias="secondary_exec", module_path="secondary.exec.module.path") + + # TODO: Task-SDK Make it look like a bound method. Needed until we remove the old queue_command + # interface from executors + default_executor.queue_workload.__func__ = BaseExecutor.queue_workload + second_executor.queue_workload.__func__ = BaseExecutor.queue_workload + with mock.patch("airflow.jobs.job.Job.executors", new_callable=PropertyMock) as executors_mock: executors_mock.return_value = [default_executor, second_executor] yield [default_executor, second_executor]
