This is an automated email from the ASF dual-hosted git repository.
ferruzzi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 9fa13c82ff7 Executor Synchronous callback workload (#61153)
9fa13c82ff7 is described below
commit 9fa13c82ff7cf6416a6237376eb7b0ff126b83b6
Author: D. Ferruzzi <[email protected]>
AuthorDate: Sat Feb 28 09:59:36 2026 -0800
Executor Synchronous callback workload (#61153)
* Synchronous callback support for BaseExecutor, LocalExecutor, and
CeleryExecutor
Add support for the Callback workload to be run in the executors. Other
executors will need to be updated before they can support the workload, but I
tried to make it as non-invasive as I could.
Co-authored-by: Sean Ghaeli <[email protected]>
Co-authored-by: Ramit Kataria <[email protected]>
---
.../src/airflow/executors/base_executor.py | 107 ++++++---
.../src/airflow/executors/local_executor.py | 90 +++++---
airflow-core/src/airflow/executors/workloads.py | 210 ------------------
.../src/airflow/executors/workloads/__init__.py | 35 +++
.../src/airflow/executors/workloads/base.py | 85 +++++++
.../src/airflow/executors/workloads/callback.py | 158 +++++++++++++
.../src/airflow/executors/workloads/task.py | 104 +++++++++
.../src/airflow/executors/workloads/trigger.py | 42 ++++
.../src/airflow/executors/workloads/types.py | 40 ++++
.../src/airflow/jobs/scheduler_job_runner.py | 247 +++++++++++++++------
.../src/airflow/jobs/triggerer_job_runner.py | 5 +-
airflow-core/src/airflow/models/callback.py | 42 ++--
airflow-core/src/airflow/models/deadline.py | 38 +++-
airflow-core/src/airflow/models/taskinstance.py | 11 +-
airflow-core/src/airflow/utils/state.py | 14 ++
.../tests/unit/executors/test_base_executor.py | 149 +++++++++++++
.../tests/unit/executors/test_local_executor.py | 43 +++-
airflow-core/tests/unit/jobs/test_scheduler_job.py | 76 +++++--
airflow-core/tests/unit/models/test_callback.py | 8 +-
docs/spelling_wordlist.txt | 1 +
.../providers/celery/executors/celery_executor.py | 55 ++---
.../celery/executors/celery_executor_utils.py | 55 +++--
.../celery/executors/celery_kubernetes_executor.py | 12 +-
.../integration/celery/test_celery_executor.py | 7 +-
.../executors/local_kubernetes_executor.py | 8 +-
.../edge3/worker_api/v2-edge-generated.yaml | 6 +-
task-sdk/src/airflow/sdk/definitions/deadline.py | 4 +-
.../tests/task_sdk/definitions/test_deadline.py | 23 +-
28 files changed, 1221 insertions(+), 454 deletions(-)
diff --git a/airflow-core/src/airflow/executors/base_executor.py
b/airflow-core/src/airflow/executors/base_executor.py
index 3bb8a70fa27..2997d55d8bb 100644
--- a/airflow-core/src/airflow/executors/base_executor.py
+++ b/airflow-core/src/airflow/executors/base_executor.py
@@ -32,7 +32,9 @@ from airflow.cli.cli_config import DefaultHelpParser
from airflow.configuration import conf
from airflow.executors import workloads
from airflow.executors.executor_loader import ExecutorLoader
+from airflow.executors.workloads.task import TaskInstanceDTO
from airflow.models import Log
+from airflow.models.callback import CallbackKey
from airflow.observability.metrics import stats_utils
from airflow.observability.trace import Trace
from airflow.utils.log.logging_mixin import LoggingMixin
@@ -52,6 +54,7 @@ if TYPE_CHECKING:
from airflow.callbacks.callback_requests import CallbackRequest
from airflow.cli.cli_config import GroupCommand
from airflow.executors.executor_utils import ExecutorName
+ from airflow.executors.workloads.types import WorkloadKey
from airflow.models.taskinstance import TaskInstance
from airflow.models.taskinstancekey import TaskInstanceKey
@@ -143,6 +146,7 @@ class BaseExecutor(LoggingMixin):
active_spans = ThreadSafeDict()
supports_ad_hoc_ti_run: bool = False
+ supports_callbacks: bool = False
supports_multi_team: bool = False
sentry_integration: str = ""
@@ -186,8 +190,9 @@ class BaseExecutor(LoggingMixin):
self.parallelism: int = parallelism
self.team_name: str | None = team_name
self.queued_tasks: dict[TaskInstanceKey, workloads.ExecuteTask] = {}
- self.running: set[TaskInstanceKey] = set()
- self.event_buffer: dict[TaskInstanceKey, EventBufferValueType] = {}
+ self.queued_callbacks: dict[str, workloads.ExecuteCallback] = {}
+ self.running: set[WorkloadKey] = set()
+ self.event_buffer: dict[WorkloadKey, EventBufferValueType] = {}
self._task_event_logs: deque[Log] = deque()
self.conf = ExecutorConf(team_name)
@@ -203,7 +208,7 @@ class BaseExecutor(LoggingMixin):
:meta private:
"""
- self.attempts: dict[TaskInstanceKey, RunningRetryAttemptType] =
defaultdict(RunningRetryAttemptType)
+ self.attempts: dict[WorkloadKey, RunningRetryAttemptType] =
defaultdict(RunningRetryAttemptType)
def __repr__(self):
_repr = f"{self.__class__.__name__}(parallelism={self.parallelism}"
@@ -224,10 +229,47 @@ class BaseExecutor(LoggingMixin):
self._task_event_logs.append(Log(event=event, task_instance=ti_key,
extra=extra))
def queue_workload(self, workload: workloads.All, session: Session) ->
None:
- if not isinstance(workload, workloads.ExecuteTask):
- raise ValueError(f"Un-handled workload kind
{type(workload).__name__!r} in {type(self).__name__}")
- ti = workload.ti
- self.queued_tasks[ti.key] = workload
+ if isinstance(workload, workloads.ExecuteTask):
+ ti = workload.ti
+ self.queued_tasks[ti.key] = workload
+ elif isinstance(workload, workloads.ExecuteCallback):
+ if not self.supports_callbacks:
+ raise NotImplementedError(
+ f"{type(self).__name__} does not support ExecuteCallback
workloads. "
+ f"Set supports_callbacks = True and implement callback
handling in _process_workloads(). "
+ f"See LocalExecutor or CeleryExecutor for reference
implementation."
+ )
+ self.queued_callbacks[workload.callback.id] = workload
+ else:
+ raise ValueError(
+ f"Un-handled workload type {type(workload).__name__!r} in
{type(self).__name__}. "
+ f"Workload must be one of: ExecuteTask, ExecuteCallback."
+ )
+
+ def _get_workloads_to_schedule(self, open_slots: int) ->
list[tuple[WorkloadKey, workloads.All]]:
+ """
+ Select and return the next batch of workloads to schedule, respecting
priority policy.
+
+ Priority Policy: Callbacks are scheduled before tasks (callbacks
complete existing work).
+ Callbacks are processed in FIFO order. Tasks are sorted by
priority_weight (higher priority first).
+
+ :param open_slots: Number of available execution slots
+ """
+ workloads_to_schedule: list[tuple[WorkloadKey, workloads.All]] = []
+
+ if self.queued_callbacks:
+ for key, workload in self.queued_callbacks.items():
+ if len(workloads_to_schedule) >= open_slots:
+ break
+ workloads_to_schedule.append((key, workload))
+
+ if open_slots > len(workloads_to_schedule) and self.queued_tasks:
+ for task_key, task_workload in
self.order_queued_tasks_by_priority():
+ if len(workloads_to_schedule) >= open_slots:
+ break
+ workloads_to_schedule.append((task_key, task_workload))
+
+ return workloads_to_schedule
def _process_workloads(self, workloads: Sequence[workloads.All]) -> None:
"""
@@ -266,10 +308,10 @@ class BaseExecutor(LoggingMixin):
"""Heartbeat sent to trigger new jobs."""
open_slots = self.parallelism - len(self.running)
- num_running_tasks = len(self.running)
- num_queued_tasks = len(self.queued_tasks)
+ num_running_workloads = len(self.running)
+ num_queued_workloads = len(self.queued_tasks) +
len(self.queued_callbacks)
- self._emit_metrics(open_slots, num_running_tasks, num_queued_tasks)
+ self._emit_metrics(open_slots, num_running_workloads,
num_queued_workloads)
self.trigger_tasks(open_slots)
# Calling child class sync method
@@ -350,16 +392,16 @@ class BaseExecutor(LoggingMixin):
def trigger_tasks(self, open_slots: int) -> None:
"""
- Initiate async execution of the queued tasks, up to the number of
available slots.
+ Initiate async execution of queued workloads (tasks and callbacks), up
to the number of available slots.
+
+ Callbacks are prioritized over tasks to complete existing work before
starting new work.
:param open_slots: Number of open slots
"""
- sorted_queue = self.order_queued_tasks_by_priority()
+ workloads_to_schedule = self._get_workloads_to_schedule(open_slots)
workload_list = []
- for _ in range(min((open_slots, len(self.queued_tasks)))):
- key, item = sorted_queue.pop()
-
+ for key, workload in workloads_to_schedule:
# If a task makes it here but is still understood by the executor
# to be running, it generally means that the task has been killed
# externally and not yet been marked as failed.
@@ -373,12 +415,12 @@ class BaseExecutor(LoggingMixin):
if key in self.attempts:
del self.attempts[key]
- if isinstance(item, workloads.ExecuteTask) and hasattr(item, "ti"):
- ti = item.ti
+ if isinstance(workload, workloads.ExecuteTask) and
hasattr(workload, "ti"):
+ ti = workload.ti
# If it's None, then the span for the current id hasn't been
started.
if self.active_spans is not None and
self.active_spans.get("ti:" + str(ti.id)) is None:
- if isinstance(ti, workloads.TaskInstance):
+ if isinstance(ti, TaskInstanceDTO):
parent_context =
Trace.extract(ti.parent_context_carrier)
else:
parent_context =
Trace.extract(ti.dag_run.context_carrier)
@@ -397,7 +439,8 @@ class BaseExecutor(LoggingMixin):
carrier = Trace.inject()
ti.context_carrier = carrier
- workload_list.append(item)
+ workload_list.append(workload)
+
if workload_list:
self._process_workloads(workload_list)
@@ -459,24 +502,25 @@ class BaseExecutor(LoggingMixin):
"""
self.change_state(key, TaskInstanceState.RUNNING, info,
remove_running=False)
- def get_event_buffer(self, dag_ids=None) -> dict[TaskInstanceKey,
EventBufferValueType]:
+ def get_event_buffer(self, dag_ids=None) -> dict[WorkloadKey,
EventBufferValueType]:
"""
Return and flush the event buffer.
In case dag_ids is specified it will only return and flush events
for the given dag_ids. Otherwise, it returns and flushes all events.
+ Note: Callback events (with string keys) are always included
regardless of dag_ids filter.
:param dag_ids: the dag_ids to return events for; returns all if given
``None``.
:return: a dict of events
"""
- cleared_events: dict[TaskInstanceKey, EventBufferValueType] = {}
+ cleared_events: dict[WorkloadKey, EventBufferValueType] = {}
if dag_ids is None:
cleared_events = self.event_buffer
self.event_buffer = {}
else:
- for ti_key in list(self.event_buffer.keys()):
- if ti_key.dag_id in dag_ids:
- cleared_events[ti_key] = self.event_buffer.pop(ti_key)
+ for key in list(self.event_buffer.keys()):
+ if isinstance(key, CallbackKey) or key.dag_id in dag_ids:
+ cleared_events[key] = self.event_buffer.pop(key)
return cleared_events
@@ -529,21 +573,26 @@ class BaseExecutor(LoggingMixin):
@property
def slots_available(self):
- """Number of new tasks this executor instance can accept."""
- return self.parallelism - len(self.running) - len(self.queued_tasks)
+ """Number of new workloads (tasks and callbacks) this executor
instance can accept."""
+ return self.parallelism - len(self.running) - len(self.queued_tasks) -
len(self.queued_callbacks)
@property
def slots_occupied(self):
- """Number of tasks this executor instance is currently managing."""
- return len(self.running) + len(self.queued_tasks)
+ """Number of workloads (tasks and callbacks) this executor instance is
currently managing."""
+ return len(self.running) + len(self.queued_tasks) +
len(self.queued_callbacks)
def debug_dump(self):
"""Get called in response to SIGUSR2 by the scheduler."""
self.log.info(
- "executor.queued (%d)\n\t%s",
+ "executor.queued_tasks (%d)\n\t%s",
len(self.queued_tasks),
"\n\t".join(map(repr, self.queued_tasks.items())),
)
+ self.log.info(
+ "executor.queued_callbacks (%d)\n\t%s",
+ len(self.queued_callbacks),
+ "\n\t".join(map(repr, self.queued_callbacks.items())),
+ )
self.log.info("executor.running (%d)\n\t%s", len(self.running),
"\n\t".join(map(repr, self.running)))
self.log.info(
"executor.event_buffer (%d)\n\t%s",
diff --git a/airflow-core/src/airflow/executors/local_executor.py
b/airflow-core/src/airflow/executors/local_executor.py
index 604de7c7f00..9b5939a0bd2 100644
--- a/airflow-core/src/airflow/executors/local_executor.py
+++ b/airflow-core/src/airflow/executors/local_executor.py
@@ -37,7 +37,8 @@ import structlog
from airflow.executors import workloads
from airflow.executors.base_executor import BaseExecutor
-from airflow.utils.state import TaskInstanceState
+from airflow.executors.workloads.callback import execute_callback_workload
+from airflow.utils.state import CallbackState, TaskInstanceState
# add logger to parameter of setproctitle to support logging
if sys.platform == "darwin":
@@ -50,13 +51,23 @@ else:
if TYPE_CHECKING:
from structlog.typing import FilteringBoundLogger as Logger
- TaskInstanceStateType = tuple[workloads.TaskInstance, TaskInstanceState,
Exception | None]
+ from airflow.executors.workloads.types import WorkloadResultType
+
+
+def _get_executor_process_title_prefix(team_name: str | None) -> str:
+ """
+ Build the process title prefix for LocalExecutor workers.
+
+ :param team_name: Team name from executor configuration
+ """
+ team_suffix = f" [{team_name}]" if team_name else ""
+ return f"airflow worker -- LocalExecutor{team_suffix}:"
def _run_worker(
logger_name: str,
input: SimpleQueue[workloads.All | None],
- output: Queue[TaskInstanceStateType],
+ output: Queue[WorkloadResultType],
unread_messages: multiprocessing.sharedctypes.Synchronized[int],
team_conf,
):
@@ -68,11 +79,8 @@ def _run_worker(
log = structlog.get_logger(logger_name)
log.info("Worker starting up pid=%d", os.getpid())
- # Create team suffix for process title
- team_suffix = f" [{team_conf.team_name}]" if team_conf.team_name else ""
-
while True:
- setproctitle(f"airflow worker -- LocalExecutor{team_suffix}: <idle>",
log)
+
setproctitle(f"{_get_executor_process_title_prefix(team_conf.team_name)}
<idle>", log)
try:
workload = input.get()
except EOFError:
@@ -87,25 +95,30 @@ def _run_worker(
# Received poison pill, no more tasks to run
return
- if not isinstance(workload, workloads.ExecuteTask):
- raise ValueError(f"LocalExecutor does not know how to handle
{type(workload)}")
-
# Decrement this as soon as we pick up a message off the queue
with unread_messages:
unread_messages.value -= 1
- key = None
- if ti := getattr(workload, "ti", None):
- key = ti.key
- else:
- raise TypeError(f"Don't know how to get ti key from
{type(workload).__name__}")
- try:
- _execute_work(log, workload, team_conf)
+ # Handle different workload types
+ if isinstance(workload, workloads.ExecuteTask):
+ try:
+ _execute_work(log, workload, team_conf)
+ output.put((workload.ti.key, TaskInstanceState.SUCCESS, None))
+ except Exception as e:
+ log.exception("Task execution failed.")
+ output.put((workload.ti.key, TaskInstanceState.FAILED, e))
+
+ elif isinstance(workload, workloads.ExecuteCallback):
+ output.put((workload.callback.id, CallbackState.RUNNING, None))
+ try:
+ _execute_callback(log, workload, team_conf)
+ output.put((workload.callback.id, CallbackState.SUCCESS, None))
+ except Exception as e:
+ log.exception("Callback execution failed")
+ output.put((workload.callback.id, CallbackState.FAILED, e))
- output.put((key, TaskInstanceState.SUCCESS, None))
- except Exception as e:
- log.exception("uhoh")
- output.put((key, TaskInstanceState.FAILED, e))
+ else:
+ raise ValueError(f"LocalExecutor does not know how to handle
{type(workload)}")
def _execute_work(log: Logger, workload: workloads.ExecuteTask, team_conf) ->
None:
@@ -118,9 +131,7 @@ def _execute_work(log: Logger, workload:
workloads.ExecuteTask, team_conf) -> No
"""
from airflow.sdk.execution_time.supervisor import supervise
- # Create team suffix for process title
- team_suffix = f" [{team_conf.team_name}]" if team_conf.team_name else ""
- setproctitle(f"airflow worker -- LocalExecutor{team_suffix}:
{workload.ti.id}", log)
+ setproctitle(f"{_get_executor_process_title_prefix(team_conf.team_name)}
{workload.ti.id}", log)
base_url = team_conf.get("api", "base_url", fallback="/")
# If it's a relative URL, use localhost:8080 as the default
@@ -141,6 +152,22 @@ def _execute_work(log: Logger, workload:
workloads.ExecuteTask, team_conf) -> No
)
+def _execute_callback(log: Logger, workload: workloads.ExecuteCallback,
team_conf) -> None:
+ """
+ Execute a callback workload.
+
+ :param log: Logger instance
+ :param workload: The ExecuteCallback workload to execute
+ :param team_conf: Team-specific executor configuration
+ """
+ setproctitle(f"{_get_executor_process_title_prefix(team_conf.team_name)}
{workload.callback.id}", log)
+
+ success, error_msg = execute_callback_workload(workload.callback, log)
+
+ if not success:
+ raise RuntimeError(error_msg or "Callback execution failed")
+
+
class LocalExecutor(BaseExecutor):
"""
LocalExecutor executes tasks locally in parallel.
@@ -155,9 +182,10 @@ class LocalExecutor(BaseExecutor):
supports_multi_team: bool = True
serve_logs: bool = True
+ supports_callbacks: bool = True
activity_queue: SimpleQueue[workloads.All | None]
- result_queue: SimpleQueue[TaskInstanceStateType]
+ result_queue: SimpleQueue[WorkloadResultType]
workers: dict[int, multiprocessing.Process]
_unread_messages: multiprocessing.sharedctypes.Synchronized[int]
@@ -300,10 +328,14 @@ class LocalExecutor(BaseExecutor):
def terminate(self):
"""Terminate the executor is not doing anything."""
- def _process_workloads(self, workloads):
- for workload in workloads:
+ def _process_workloads(self, workload_list):
+ for workload in workload_list:
self.activity_queue.put(workload)
- del self.queued_tasks[workload.ti.key]
+ # Remove from appropriate queue based on workload type
+ if isinstance(workload, workloads.ExecuteTask):
+ del self.queued_tasks[workload.ti.key]
+ elif isinstance(workload, workloads.ExecuteCallback):
+ del self.queued_callbacks[workload.callback.id]
with self._unread_messages:
- self._unread_messages.value += len(workloads)
+ self._unread_messages.value += len(workload_list)
self._check_workers()
diff --git a/airflow-core/src/airflow/executors/workloads.py
b/airflow-core/src/airflow/executors/workloads.py
deleted file mode 100644
index 7cf1aae60ff..00000000000
--- a/airflow-core/src/airflow/executors/workloads.py
+++ /dev/null
@@ -1,210 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-from __future__ import annotations
-
-import os
-import uuid
-from abc import ABC
-from datetime import datetime
-from pathlib import Path
-from typing import TYPE_CHECKING, Annotated, Literal
-
-import structlog
-from pydantic import BaseModel, Field
-
-if TYPE_CHECKING:
- from airflow.api_fastapi.auth.tokens import JWTGenerator
- from airflow.models import DagRun
- from airflow.models.callback import Callback as CallbackModel,
CallbackFetchMethod
- from airflow.models.taskinstance import TaskInstance as TIModel
- from airflow.models.taskinstancekey import TaskInstanceKey
-
-
-__all__ = ["All", "ExecuteTask", "ExecuteCallback"]
-
-log = structlog.get_logger(__name__)
-
-
-class BaseWorkload(BaseModel):
- token: str
- """The identity token for this workload"""
-
- @staticmethod
- def generate_token(sub_id: str, generator: JWTGenerator | None = None) ->
str:
- return generator.generate({"sub": sub_id}) if generator else ""
-
-
-class BundleInfo(BaseModel):
- """Schema for telling task which bundle to run with."""
-
- name: str
- version: str | None = None
-
-
-class TaskInstance(BaseModel):
- """Schema for TaskInstance with minimal required fields needed for
Executors and Task SDK."""
-
- id: uuid.UUID
- dag_version_id: uuid.UUID
- task_id: str
- dag_id: str
- run_id: str
- try_number: int
- map_index: int = -1
-
- pool_slots: int
- queue: str
- priority_weight: int
- executor_config: dict | None = Field(default=None, exclude=True)
-
- parent_context_carrier: dict | None = None
- context_carrier: dict | None = None
-
- # TODO: Task-SDK: Can we replace TastInstanceKey with just the uuid across
the codebase?
- @property
- def key(self) -> TaskInstanceKey:
- from airflow.models.taskinstancekey import TaskInstanceKey
-
- return TaskInstanceKey(
- dag_id=self.dag_id,
- task_id=self.task_id,
- run_id=self.run_id,
- try_number=self.try_number,
- map_index=self.map_index,
- )
-
-
-class Callback(BaseModel):
- """Schema for Callback with minimal required fields needed for Executors
and Task SDK."""
-
- id: uuid.UUID
- fetch_type: CallbackFetchMethod
- data: dict
-
-
-class BaseDagBundleWorkload(BaseWorkload, ABC):
- """Base class for Workloads that are associated with a DAG bundle."""
-
- dag_rel_path: os.PathLike[str]
- """The filepath where the DAG can be found (likely prefixed with
`DAG_FOLDER/`)"""
-
- bundle_info: BundleInfo
-
- log_path: str | None
- """The rendered relative log filename template the task logs should be
written to"""
-
-
-class ExecuteTask(BaseDagBundleWorkload):
- """Execute the given Task."""
-
- ti: TaskInstance
- sentry_integration: str = ""
-
- type: Literal["ExecuteTask"] = Field(init=False, default="ExecuteTask")
-
- @classmethod
- def make(
- cls,
- ti: TIModel,
- dag_rel_path: Path | None = None,
- generator: JWTGenerator | None = None,
- bundle_info: BundleInfo | None = None,
- sentry_integration: str = "",
- ) -> ExecuteTask:
- from airflow.utils.helpers import log_filename_template_renderer
-
- ser_ti = TaskInstance.model_validate(ti, from_attributes=True)
- ser_ti.parent_context_carrier = ti.dag_run.context_carrier
- if not bundle_info:
- bundle_info = BundleInfo(
- name=ti.dag_model.bundle_name,
- version=ti.dag_run.bundle_version,
- )
- fname = log_filename_template_renderer()(ti=ti)
-
- return cls(
- ti=ser_ti,
- dag_rel_path=dag_rel_path or Path(ti.dag_model.relative_fileloc or
""),
- token=cls.generate_token(str(ti.id), generator),
- log_path=fname,
- bundle_info=bundle_info,
- sentry_integration=sentry_integration,
- )
-
-
-class ExecuteCallback(BaseDagBundleWorkload):
- """Execute the given Callback."""
-
- callback: Callback
-
- type: Literal["ExecuteCallback"] = Field(init=False,
default="ExecuteCallback")
-
- @classmethod
- def make(
- cls,
- callback: CallbackModel,
- dag_run: DagRun,
- dag_rel_path: Path | None = None,
- generator: JWTGenerator | None = None,
- bundle_info: BundleInfo | None = None,
- ) -> ExecuteCallback:
- if not bundle_info:
- bundle_info = BundleInfo(
- name=dag_run.dag_model.bundle_name,
- version=dag_run.bundle_version,
- )
- fname = f"executor_callbacks/{callback.id}" # TODO: better log file
template
-
- return cls(
- callback=Callback.model_validate(callback, from_attributes=True),
- dag_rel_path=dag_rel_path or
Path(dag_run.dag_model.relative_fileloc or ""),
- token=cls.generate_token(str(callback.id), generator),
- log_path=fname,
- bundle_info=bundle_info,
- )
-
-
-class RunTrigger(BaseModel):
- """Execute an async "trigger" process that yields events."""
-
- id: int
-
- ti: TaskInstance | None
- """
- The task instance associated with this trigger.
-
- Could be none for asset-based triggers.
- """
-
- classpath: str
- """
- Dot-separated name of the module+fn to import and run this workload.
-
- Consumers of this Workload must perform their own validation of this input.
- """
-
- encrypted_kwargs: str
-
- timeout_after: datetime | None = None
-
- type: Literal["RunTrigger"] = Field(init=False, default="RunTrigger")
-
-
-All = Annotated[
- ExecuteTask | RunTrigger,
- Field(discriminator="type"),
-]
diff --git a/airflow-core/src/airflow/executors/workloads/__init__.py
b/airflow-core/src/airflow/executors/workloads/__init__.py
new file mode 100644
index 00000000000..dca4c991f63
--- /dev/null
+++ b/airflow-core/src/airflow/executors/workloads/__init__.py
@@ -0,0 +1,35 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Workload schemas for executor communication."""
+
+from __future__ import annotations
+
+from typing import Annotated
+
+from pydantic import Field
+
+from airflow.executors.workloads.base import BaseWorkload, BundleInfo
+from airflow.executors.workloads.callback import CallbackFetchMethod,
ExecuteCallback
+from airflow.executors.workloads.task import ExecuteTask
+from airflow.executors.workloads.trigger import RunTrigger
+
+All = Annotated[
+ ExecuteTask | ExecuteCallback | RunTrigger,
+ Field(discriminator="type"),
+]
+
+__all__ = ["All", "BaseWorkload", "BundleInfo", "CallbackFetchMethod",
"ExecuteCallback", "ExecuteTask"]
diff --git a/airflow-core/src/airflow/executors/workloads/base.py
b/airflow-core/src/airflow/executors/workloads/base.py
new file mode 100644
index 00000000000..cf622209d67
--- /dev/null
+++ b/airflow-core/src/airflow/executors/workloads/base.py
@@ -0,0 +1,85 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""ORM models and Pydantic schemas for BaseWorkload."""
+
+from __future__ import annotations
+
+import os
+from abc import ABC
+from typing import TYPE_CHECKING
+
+from pydantic import BaseModel, ConfigDict
+
+if TYPE_CHECKING:
+ from airflow.api_fastapi.auth.tokens import JWTGenerator
+
+
+class BaseWorkload:
+ """
+ Mixin for ORM models that can be scheduled as workloads.
+
+ This mixin defines the interface that scheduler workloads (TaskInstance,
+ ExecutorCallback, etc.) must implement to provide routing information to
the scheduler.
+
+ Subclasses must override:
+ - get_dag_id() -> str | None
+ - get_executor_name() -> str | None
+ """
+
+ def get_dag_id(self) -> str | None:
+ """
+ Return the DAG ID for scheduler routing.
+
+ Must be implemented by subclasses.
+ """
+ raise NotImplementedError(f"{self.__class__.__name__} must implement
get_dag_id()")
+
+ def get_executor_name(self) -> str | None:
+ """
+ Return the executor name for scheduler routing.
+
+ Must be implemented by subclasses.
+ """
+ raise NotImplementedError(f"{self.__class__.__name__} must implement
get_executor_name()")
+
+
+class BundleInfo(BaseModel):
+ """Schema for telling task which bundle to run with."""
+
+ name: str
+ version: str | None = None
+
+
+class BaseWorkloadSchema(BaseModel):
+ """Base Pydantic schema for executor workload DTOs."""
+
+ model_config = ConfigDict(populate_by_name=True)
+
+ token: str
+ """The identity token for this workload"""
+
+ @staticmethod
+ def generate_token(sub_id: str, generator: JWTGenerator | None = None) ->
str:
+ return generator.generate({"sub": sub_id}) if generator else ""
+
+
+class BaseDagBundleWorkload(BaseWorkloadSchema, ABC):
+ """Base class for Workloads that are associated with a DAG bundle."""
+
+ dag_rel_path: os.PathLike[str] # Filepath where the DAG can be found
(likely prefixed with `DAG_FOLDER/`)
+ bundle_info: BundleInfo
+ log_path: str | None # Rendered relative log filename template the task
logs should be written to.
diff --git a/airflow-core/src/airflow/executors/workloads/callback.py
b/airflow-core/src/airflow/executors/workloads/callback.py
new file mode 100644
index 00000000000..c15bb33fba7
--- /dev/null
+++ b/airflow-core/src/airflow/executors/workloads/callback.py
@@ -0,0 +1,158 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Callback workload schemas for executor communication."""
+
+from __future__ import annotations
+
+from enum import Enum
+from importlib import import_module
+from pathlib import Path
+from typing import TYPE_CHECKING, Literal
+from uuid import UUID
+
+import structlog
+from pydantic import BaseModel, Field, field_validator
+
+from airflow.executors.workloads.base import BaseDagBundleWorkload, BundleInfo
+
+if TYPE_CHECKING:
+ from airflow.api_fastapi.auth.tokens import JWTGenerator
+ from airflow.models import DagRun
+ from airflow.models.callback import Callback as CallbackModel, CallbackKey
+
+log = structlog.get_logger(__name__)
+
+
+class CallbackFetchMethod(str, Enum):
+ """Methods used to fetch callback at runtime."""
+
+ # For future use once Dag Processor callbacks
(on_success_callback/on_failure_callback) get moved to executors
+ DAG_ATTRIBUTE = "dag_attribute"
+
+ # For deadline callbacks since they import callbacks through the import
path
+ IMPORT_PATH = "import_path"
+
+
+class CallbackDTO(BaseModel):
+ """Schema for Callback with minimal required fields needed for Executors
and Task SDK."""
+
+ id: str # A uuid.UUID stored as a string
+ fetch_method: CallbackFetchMethod
+ data: dict
+
+ @field_validator("id", mode="before")
+ @classmethod
+ def validate_id(cls, v):
+ """Convert UUID to str if needed."""
+ if isinstance(v, UUID):
+ return str(v)
+ return v
+
+ @property
+ def key(self) -> CallbackKey:
+ """Return callback ID as key (CallbackKey = str)."""
+ return self.id
+
+
+class ExecuteCallback(BaseDagBundleWorkload):
+ """Execute the given Callback."""
+
+ callback: CallbackDTO
+
+ type: Literal["ExecuteCallback"] = Field(init=False,
default="ExecuteCallback")
+
+ @classmethod
+ def make(
+ cls,
+ callback: CallbackModel,
+ dag_run: DagRun,
+ dag_rel_path: Path | None = None,
+ generator: JWTGenerator | None = None,
+ bundle_info: BundleInfo | None = None,
+ ) -> ExecuteCallback:
+ """Create an ExecuteCallback workload from a Callback ORM model."""
+ if not bundle_info:
+ bundle_info = BundleInfo(
+ name=dag_run.dag_model.bundle_name,
+ version=dag_run.bundle_version,
+ )
+ fname = f"executor_callbacks/{callback.id}" # TODO: better log file
template
+
+ return cls(
+ callback=CallbackDTO.model_validate(callback,
from_attributes=True),
+ dag_rel_path=dag_rel_path or
Path(dag_run.dag_model.relative_fileloc or ""),
+ token=cls.generate_token(str(callback.id), generator),
+ log_path=fname,
+ bundle_info=bundle_info,
+ )
+
+
+def execute_callback_workload(
+ callback: CallbackDTO,
+ log,
+) -> tuple[bool, str | None]:
+ """
+ Execute a callback function by importing and calling it, returning the
success state.
+
+ Supports two patterns:
+ 1. Functions - called directly with kwargs
+ 2. Classes that return callable instances (like BaseNotifier) -
instantiated first, then the instance is called with the context
+
+ Example:
+ # Function callback
+ callback.data = {"path": "my_module.alert_func", "kwargs": {"msg":
"Alert!", "context": {...}}}
+ execute_callback_workload(callback, log) # Calls
alert_func(msg="Alert!", context={...})
+
+ # Notifier callback
+ callback.data = {"path":
"airflow.providers.slack...SlackWebhookNotifier", "kwargs": {"text": "Alert!",
"context": {...}}}
+ execute_callback_workload(callback, log) #
SlackWebhookNotifier(text=..., context=...) then calls instance(context)
+
+ :param callback: The Callback schema containing path and kwargs
+ :param log: Logger instance for recording execution
+ :return: Tuple of (success: bool, error_message: str | None)
+ """
+ callback_path = callback.data.get("path")
+ callback_kwargs = callback.data.get("kwargs", {})
+
+ if not callback_path:
+ return False, "Callback path not found in data."
+
+ try:
+ # Import the callback callable
+ # Expected format: "module.path.to.function_or_class"
+ module_path, function_name = callback_path.rsplit(".", 1)
+ module = import_module(module_path)
+ callback_callable = getattr(module, function_name)
+
+ log.debug("Executing callback %s(%s)...", callback_path,
callback_kwargs)
+
+ # Call the imported object with the kwargs. If it is a class, this
call instantiates it.
+ result = callback_callable(**callback_kwargs)
+
+ # If the callback is a class then it is now instantiated and callable,
call it.
+ if callable(result):
+ context = callback_kwargs.get("context", {})
+ log.debug("Calling result with context for %s", callback_path)
+ result = result(context)
+
+ log.info("Callback %s executed successfully.", callback_path)
+ return True, None
+
+ except Exception as e:
+ error_msg = f"Callback execution failed: {type(e).__name__}: {str(e)}"
+ log.exception("Callback %s(%s) execution failed: %s", callback_path,
callback_kwargs, error_msg)
+ return False, error_msg
diff --git a/airflow-core/src/airflow/executors/workloads/task.py
b/airflow-core/src/airflow/executors/workloads/task.py
new file mode 100644
index 00000000000..d691dcb6f09
--- /dev/null
+++ b/airflow-core/src/airflow/executors/workloads/task.py
@@ -0,0 +1,104 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Task workload schemas for executor communication."""
+
+from __future__ import annotations
+
+import uuid
+from pathlib import Path
+from typing import TYPE_CHECKING, Literal
+
+from pydantic import BaseModel, Field
+
+from airflow.executors.workloads.base import BaseDagBundleWorkload, BundleInfo
+
+if TYPE_CHECKING:
+ from airflow.api_fastapi.auth.tokens import JWTGenerator
+ from airflow.models.taskinstance import TaskInstance as TIModel
+ from airflow.models.taskinstancekey import TaskInstanceKey
+
+
+class TaskInstanceDTO(BaseModel):
+ """Schema for TaskInstance with minimal required fields needed for
Executors and Task SDK."""
+
+ id: uuid.UUID
+ dag_version_id: uuid.UUID
+ task_id: str
+ dag_id: str
+ run_id: str
+ try_number: int
+ map_index: int = -1
+
+ pool_slots: int
+ queue: str
+ priority_weight: int
+ executor_config: dict | None = Field(default=None, exclude=True)
+
+ parent_context_carrier: dict | None = None
+ context_carrier: dict | None = None
+
+ # TODO: Task-SDK: Can we replace TaskInstanceKey with just the uuid across
the codebase?
+ @property
+ def key(self) -> TaskInstanceKey:
+ from airflow.models.taskinstancekey import TaskInstanceKey
+
+ return TaskInstanceKey(
+ dag_id=self.dag_id,
+ task_id=self.task_id,
+ run_id=self.run_id,
+ try_number=self.try_number,
+ map_index=self.map_index,
+ )
+
+
+class ExecuteTask(BaseDagBundleWorkload):
+ """Execute the given Task."""
+
+ ti: TaskInstanceDTO
+ sentry_integration: str = ""
+
+ type: Literal["ExecuteTask"] = Field(init=False, default="ExecuteTask")
+
+ @classmethod
+ def make(
+ cls,
+ ti: TIModel,
+ dag_rel_path: Path | None = None,
+ generator: JWTGenerator | None = None,
+ bundle_info: BundleInfo | None = None,
+ sentry_integration: str = "",
+ ) -> ExecuteTask:
+ """Create an ExecuteTask workload from a TaskInstance ORM model."""
+ from airflow.utils.helpers import log_filename_template_renderer
+
+ ser_ti = TaskInstanceDTO.model_validate(ti, from_attributes=True)
+ ser_ti.parent_context_carrier = ti.dag_run.context_carrier
+ if not bundle_info:
+ bundle_info = BundleInfo(
+ name=ti.dag_model.bundle_name,
+ version=ti.dag_run.bundle_version,
+ )
+ fname = log_filename_template_renderer()(ti=ti)
+
+ return cls(
+ ti=ser_ti,
+ dag_rel_path=dag_rel_path or Path(ti.dag_model.relative_fileloc or
""),
+ token=cls.generate_token(str(ti.id), generator),
+ log_path=fname,
+ bundle_info=bundle_info,
+ sentry_integration=sentry_integration,
+ )
diff --git a/airflow-core/src/airflow/executors/workloads/trigger.py
b/airflow-core/src/airflow/executors/workloads/trigger.py
new file mode 100644
index 00000000000..25bca9ce44b
--- /dev/null
+++ b/airflow-core/src/airflow/executors/workloads/trigger.py
@@ -0,0 +1,42 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Trigger workload schemas for executor communication."""
+
+from __future__ import annotations
+
+from datetime import datetime
+from typing import Literal
+
+from pydantic import BaseModel, Field
+
+# Using noqa because Ruff wants this in a TYPE_CHECKING block but Pydantic
fails if it is.
+from airflow.executors.workloads.task import TaskInstanceDTO # noqa: TCH001
+
+
+class RunTrigger(BaseModel):
+ """
+ Execute an async "trigger" process that yields events.
+
+ Consumers of this Workload must perform their own validation of the
classpath input.
+ """
+
+ id: int
+ ti: TaskInstanceDTO | None # Could be none for asset-based triggers.
+ classpath: str # Dot-separated name of the module+fn to import and run
this workload.
+ encrypted_kwargs: str
+ timeout_after: datetime | None = None
+ type: Literal["RunTrigger"] = Field(init=False, default="RunTrigger")
diff --git a/airflow-core/src/airflow/executors/workloads/types.py
b/airflow-core/src/airflow/executors/workloads/types.py
new file mode 100644
index 00000000000..31cda702846
--- /dev/null
+++ b/airflow-core/src/airflow/executors/workloads/types.py
@@ -0,0 +1,40 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Type aliases for Workloads."""
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, TypeAlias
+
+from airflow.models.callback import ExecutorCallback
+from airflow.models.taskinstance import TaskInstance
+
+if TYPE_CHECKING:
+ from airflow.models.callback import CallbackKey
+ from airflow.models.taskinstancekey import TaskInstanceKey
+ from airflow.utils.state import CallbackState, TaskInstanceState
+
+ # Type aliases for workload keys and states (used by executor layer)
+ WorkloadKey: TypeAlias = TaskInstanceKey | CallbackKey
+ WorkloadState: TypeAlias = TaskInstanceState | CallbackState
+
+ # Type alias for executor workload results (used by executor
implementations)
+ WorkloadResultType: TypeAlias = tuple[WorkloadKey, WorkloadState,
Exception | None]
+
+# Type alias for scheduler workloads (ORM models that can be routed to
executors)
+# Must be outside TYPE_CHECKING for use in function signatures
+SchedulerWorkload: TypeAlias = TaskInstance | ExecutorCallback
diff --git a/airflow-core/src/airflow/jobs/scheduler_job_runner.py
b/airflow-core/src/airflow/jobs/scheduler_job_runner.py
index 4117153d4e7..d1927980042 100644
--- a/airflow-core/src/airflow/jobs/scheduler_job_runner.py
+++ b/airflow-core/src/airflow/jobs/scheduler_job_runner.py
@@ -85,7 +85,7 @@ from airflow.models.asset import (
TaskOutletAssetReference,
)
from airflow.models.backfill import Backfill
-from airflow.models.callback import Callback
+from airflow.models.callback import Callback, CallbackType, ExecutorCallback
from airflow.models.dag import DagModel
from airflow.models.dag_version import DagVersion
from airflow.models.dagbag import DBDagBag
@@ -95,6 +95,7 @@ from airflow.models.dagwarning import DagWarning,
DagWarningType
from airflow.models.pool import normalize_pool_name_for_stats
from airflow.models.serialized_dag import SerializedDagModel
from airflow.models.taskinstance import TaskInstance
+from airflow.models.taskinstancekey import TaskInstanceKey
from airflow.models.team import Team
from airflow.models.trigger import TRIGGER_FAIL_REPR, Trigger,
TriggerFailureReason
from airflow.observability.metrics import stats_utils
@@ -115,7 +116,7 @@ from airflow.utils.sqlalchemy import (
prohibit_commit,
with_row_locks,
)
-from airflow.utils.state import DagRunState, State, TaskInstanceState
+from airflow.utils.state import CallbackState, DagRunState, State,
TaskInstanceState
from airflow.utils.thread_safe_dict import ThreadSafeDict
from airflow.utils.types import DagRunTriggeredByType, DagRunType
@@ -130,7 +131,7 @@ if TYPE_CHECKING:
from airflow._shared.logging.types import Logger
from airflow.executors.base_executor import BaseExecutor
from airflow.executors.executor_utils import ExecutorName
- from airflow.models.taskinstance import TaskInstanceKey
+ from airflow.executors.workloads.types import SchedulerWorkload
from airflow.serialization.definitions.dag import SerializedDAG
from airflow.utils.sqlalchemy import CommitProhibitorGuard
@@ -359,32 +360,35 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
# Return dict with all None values to ensure graceful degradation
return {}
- def _get_task_team_name(self, task_instance: TaskInstance, session:
Session) -> str | None:
+ def _get_workload_team_name(self, workload: SchedulerWorkload, session:
Session) -> str | None:
"""
- Resolve team name for a task instance using the DAG > Bundle > Team
relationship chain.
+ Resolve team name for a workload using the DAG > Bundle > Team
relationship chain.
- TaskInstance > DagModel (via dag_id) > DagBundleModel (via
bundle_name) > Team
+ Workload > DagModel (via dag_id) > DagBundleModel (via bundle_name) >
Team
- :param task_instance: The TaskInstance to resolve team name for
+ :param workload: The Workload to resolve team name for
:param session: Database session for queries
:return: Team name if found or None
"""
# Use the batch query function with a single DAG ID
- dag_id_to_team_name =
self._get_team_names_for_dag_ids([task_instance.dag_id], session)
- team_name = dag_id_to_team_name.get(task_instance.dag_id)
+ if dag_id := workload.get_dag_id():
+ dag_id_to_team_name = self._get_team_names_for_dag_ids([dag_id],
session)
+ team_name = dag_id_to_team_name.get(dag_id)
+ else:
+ team_name = None # mypy didn't like the implicit defaulting to
None
if team_name:
self.log.debug(
- "Resolved team name '%s' for task %s (dag_id=%s)",
+ "Resolved team name '%s' for task or callback %s (dag_id=%s)",
team_name,
- task_instance.task_id,
- task_instance.dag_id,
+ workload,
+ dag_id,
)
else:
self.log.debug(
- "No team found for task %s (dag_id=%s) - DAG may not have
bundle or team association",
- task_instance.task_id,
- task_instance.dag_id,
+ "No team found for task or callback %s (dag_id=%s) - DAG may
not have bundle or team association",
+ workload,
+ dag_id,
)
return team_name
@@ -981,7 +985,7 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
queued_tis = self._executable_task_instances_to_queued(max_tis,
session=session)
# Sort queued TIs to their respective executor
- executor_to_queued_tis = self._executor_to_tis(queued_tis, session)
+ executor_to_queued_tis = self._executor_to_workloads(queued_tis,
session)
for executor, queued_tis_per_executor in
executor_to_queued_tis.items():
self.log.info(
"Trying to enqueue tasks: %s for executor: %s",
@@ -993,6 +997,75 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
return len(queued_tis)
+ def _enqueue_executor_callbacks(self, session: Session) -> None:
+ """
+ Enqueue ExecutorCallback workloads to executors.
+
+ Similar to _enqueue_task_instances, but for callbacks that need to run
on executors.
+ Queries for QUEUED ExecutorCallback instances and routes them to the
appropriate executor.
+
+ :param session: The database session
+ """
+ num_occupied_slots = sum(executor.slots_occupied for executor in
self.executors)
+ max_callbacks = conf.getint("core", "parallelism") - num_occupied_slots
+
+ if max_callbacks <= 0:
+ self.log.debug("No available slots for callbacks; all executors at
capacity")
+ return
+
+ pending_callbacks = session.scalars(
+ select(ExecutorCallback)
+ .where(ExecutorCallback.type == CallbackType.EXECUTOR)
+ .where(ExecutorCallback.state == CallbackState.PENDING)
+ .order_by(ExecutorCallback.priority_weight.desc())
+ .limit(max_callbacks)
+ ).all()
+
+ if not pending_callbacks:
+ return
+
+ # Route callbacks to executors using the generalized routing method
+ executor_to_callbacks = self._executor_to_workloads(pending_callbacks,
session)
+
+ # Enqueue callbacks for each executor
+ for executor, callbacks in executor_to_callbacks.items():
+ for callback in callbacks:
+ if not isinstance(callback, ExecutorCallback):
+ # Can't happen since we queried ExecutorCallback, but
satisfies mypy.
+ continue
+
+ # TODO: Add dagrun_id as a proper ORM foreign key on the
callback table instead of storing in data dict.
+ # This would eliminate this reconstruction step. For
now, all ExecutorCallbacks
+ # are expected to have dag_run_id set in their data dict
(e.g., by Deadline.handle_miss).
+ if not isinstance(callback.data, dict) or "dag_run_id" not in
callback.data:
+ self.log.error(
+ "ExecutorCallback %s is missing required 'dag_run_id'
in data dict. "
+ "This indicates a bug in callback creation. Skipping
callback.",
+ callback.id,
+ )
+ continue
+
+ dag_run_id = callback.data["dag_run_id"]
+ dag_run = session.get(DagRun, dag_run_id)
+
+ if dag_run is None:
+ self.log.warning(
+ "Could not find DagRun with id=%s for callback %s.
DagRun may have been deleted.",
+ dag_run_id,
+ callback.id,
+ )
+ continue
+
+ workload = workloads.ExecuteCallback.make(
+ callback=callback,
+ dag_run=dag_run,
+ generator=executor.jwt_generator,
+ )
+
+ executor.queue_workload(workload, session=session)
+ callback.state = CallbackState.QUEUED
+ session.add(callback)
+
@staticmethod
def _process_task_event_logs(log_records: deque[Log], session: Session):
objects = (log_records.popleft() for _ in range(len(log_records)))
@@ -1055,21 +1128,50 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
ti_primary_key_to_try_number_map: dict[tuple[str, str, str, int], int]
= {}
event_buffer = executor.get_event_buffer()
tis_with_right_state: list[TaskInstanceKey] = []
+ callback_keys_with_events: list[str] = []
+
+ # Report execution - handle both task and callback events
+ for key, (state, _) in event_buffer.items():
+ if isinstance(key, TaskInstanceKey):
+ ti_primary_key_to_try_number_map[key.primary] = key.try_number
+ cls.logger().info("Received executor event with state %s for
task instance %s", state, key)
+ if state in (
+ TaskInstanceState.FAILED,
+ TaskInstanceState.SUCCESS,
+ TaskInstanceState.QUEUED,
+ TaskInstanceState.RUNNING,
+ TaskInstanceState.RESTARTING,
+ ):
+ tis_with_right_state.append(key)
+ else:
+ # Callback event (key is string UUID)
+ cls.logger().info("Received executor event with state %s for
callback %s", state, key)
+ if state in (CallbackState.RUNNING, CallbackState.FAILED,
CallbackState.SUCCESS):
+ callback_keys_with_events.append(key)
+
+ # Handle callback state events
+ for callback_id in callback_keys_with_events:
+ state, info = event_buffer.pop(callback_id)
+ callback = session.get(Callback, callback_id)
+ if not callback:
+ # This should not normally happen - we just received an event
for this callback.
+ # Only possible if callback was deleted mid-execution (e.g.,
cascade delete from DagRun deletion).
+ cls.logger().warning(
+ "Callback %s not found in database (may have been cascade
deleted)", callback_id
+ )
+ continue
- # Report execution
- for ti_key, (state, _) in event_buffer.items():
- # We create map (dag_id, task_id, logical_date) -> in-memory
try_number
- ti_primary_key_to_try_number_map[ti_key.primary] =
ti_key.try_number
-
- cls.logger().info("Received executor event with state %s for task
instance %s", state, ti_key)
- if state in (
- TaskInstanceState.FAILED,
- TaskInstanceState.SUCCESS,
- TaskInstanceState.QUEUED,
- TaskInstanceState.RUNNING,
- TaskInstanceState.RESTARTING,
- ):
- tis_with_right_state.append(ti_key)
+ if state == CallbackState.RUNNING:
+ callback.state = CallbackState.RUNNING
+ cls.logger().info("Callback %s is currently running",
callback_id)
+ elif state == CallbackState.SUCCESS:
+ callback.state = CallbackState.SUCCESS
+ cls.logger().info("Callback %s completed successfully",
callback_id)
+ elif state == CallbackState.FAILED:
+ callback.state = CallbackState.FAILED
+ callback.output = str(info) if info else "Execution failed"
+ cls.logger().error("Callback %s failed: %s", callback_id,
callback.output)
+ session.add(callback)
# Return if no finished tasks
if not tis_with_right_state:
@@ -1657,6 +1759,9 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
):
deadline.handle_miss(session)
+ # Route ExecutorCallback workloads to executors (similar
to task routing)
+ self._enqueue_executor_callbacks(session)
+
# Heartbeat the scheduler periodically
perform_heartbeat(
job=self.job, heartbeat_callback=self.heartbeat_callback,
only_if_necessary=True
@@ -2434,7 +2539,7 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
scheduled) up to 2 times before failing the task.
"""
tasks_stuck_in_queued = self._get_tis_stuck_in_queued(session)
- for executor, stuck_tis in
self._executor_to_tis(tasks_stuck_in_queued, session).items():
+ for executor, stuck_tis in
self._executor_to_workloads(tasks_stuck_in_queued, session).items():
try:
for ti in stuck_tis:
executor.revoke_task(ti=ti)
@@ -2725,7 +2830,7 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
)
to_reset: list[TaskInstance] = []
- exec_to_tis = self._executor_to_tis(tis_to_adopt_or_reset,
session)
+ exec_to_tis =
self._executor_to_workloads(tis_to_adopt_or_reset, session)
for executor, tis in exec_to_tis.items():
to_reset.extend(executor.try_adopt_task_instances(tis))
@@ -3074,50 +3179,57 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
session.add(warning)
existing_warned_dag_ids.add(warning.dag_id)
- def _executor_to_tis(
+ def _executor_to_workloads(
self,
- tis: Iterable[TaskInstance],
+ workloads: Iterable[SchedulerWorkload],
session,
dag_id_to_team_name: dict[str, str | None] | None = None,
- ) -> dict[BaseExecutor, list[TaskInstance]]:
- """Organize TIs into lists per their respective executor."""
- tis_iter: Iterable[TaskInstance]
+ ) -> dict[BaseExecutor, list[SchedulerWorkload]]:
+ """Organize workloads into lists per their respective executor."""
+ workloads_iter: Iterable[SchedulerWorkload]
if conf.getboolean("core", "multi_team"):
if dag_id_to_team_name is None:
- if isinstance(tis, list):
- tis_list = tis
+ if isinstance(workloads, list):
+ workloads_list = workloads
else:
- tis_list = list(tis)
- if tis_list:
+ workloads_list = list(workloads)
+ if workloads_list:
dag_id_to_team_name = self._get_team_names_for_dag_ids(
- {ti.dag_id for ti in tis_list}, session
+ {
+ dag_id
+ for workload in workloads_list
+ if (dag_id := workload.get_dag_id()) is not None
+ },
+ session,
)
else:
dag_id_to_team_name = {}
- tis_iter = tis_list
+ workloads_iter = workloads_list
else:
- tis_iter = tis
+ workloads_iter = workloads
else:
dag_id_to_team_name = {}
- tis_iter = tis
+ workloads_iter = workloads
- _executor_to_tis: defaultdict[BaseExecutor, list[TaskInstance]] =
defaultdict(list)
- for ti in tis_iter:
- if executor_obj := self._try_to_load_executor(
- ti, session, team_name=dag_id_to_team_name.get(ti.dag_id,
NOTSET)
- ):
- _executor_to_tis[executor_obj].append(ti)
+ _executor_to_workloads: defaultdict[BaseExecutor,
list[SchedulerWorkload]] = defaultdict(list)
+ for workload in workloads_iter:
+ _dag_id = workload.get_dag_id()
+ _team = dag_id_to_team_name.get(_dag_id, NOTSET) if _dag_id else
NOTSET
+ if executor_obj := self._try_to_load_executor(workload, session,
team_name=_team):
+ _executor_to_workloads[executor_obj].append(workload)
- return _executor_to_tis
+ return _executor_to_workloads
- def _try_to_load_executor(self, ti: TaskInstance, session,
team_name=NOTSET) -> BaseExecutor | None:
+ def _try_to_load_executor(
+ self, workload: SchedulerWorkload, session, team_name=NOTSET
+ ) -> BaseExecutor | None:
"""
Try to load the given executor.
In this context, we don't want to fail if the executor does not exist.
Catch the exception and
log to the user.
- :param ti: TaskInstance to load executor for
+ :param workload: SchedulerWorkload (TaskInstance or ExecutorCallback)
to load executor for
:param session: Database session for queries
:param team_name: Optional pre-resolved team name. If NOTSET and
multi-team is enabled,
will query the database to resolve team name. None
indicates global team.
@@ -3126,17 +3238,16 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
if conf.getboolean("core", "multi_team"):
# Use provided team_name if available, otherwise query the database
if team_name is NOTSET:
- team_name = self._get_task_team_name(ti, session)
+ team_name = self._get_workload_team_name(workload, session)
else:
team_name = None
- # Firstly, check if there is no executor set on the TaskInstance, if
not, we need to fetch the default
- # (either globally or for the team)
- if ti.executor is None:
+ # If there is no executor set on the workload fetch the default
(either globally or for the team)
+ if workload.get_executor_name() is None:
if not team_name:
- # No team is specified, so just use the global default executor
+ # No team is specified, use the global default executor
executor = self.executor
else:
- # We do have a team, so we need to find the default executor
for that team
+ # We do have a team, use the default executor for that team
for _executor in self.executors:
# First executor that resolves should be the default for
that team
if _executor.team_name == team_name:
@@ -3146,22 +3257,30 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
# No executor found for that team, fall back to global
default
executor = self.executor
else:
- # An executor is specified on the TaskInstance (as a str), so we
need to find it in the list of executors
+ # An executor is specified on the workload (as a str), so we need
to find it in the list of executors
for _executor in self.executors:
- if _executor.name and ti.executor in (_executor.name.alias,
_executor.name.module_path):
+ if _executor.name and workload.get_executor_name() in (
+ _executor.name.alias,
+ _executor.name.module_path,
+ ):
# The executor must either match the team or be global
(i.e. team_name is None)
if team_name and _executor.team_name == team_name or
_executor.team_name is None:
executor = _executor
if executor is not None:
- self.log.debug("Found executor %s for task %s (team: %s)",
executor.name, ti, team_name)
+ self.log.debug(
+ "Found executor %s for task or callback %s (team: %s)",
executor.name, workload, team_name
+ )
else:
# This case should not happen unless some (as of now unknown) edge
case occurs or direct DB
# modification, since the DAG parser will validate the tasks in
the DAG and ensure the executor
# they request is available and if not, disallow the DAG to be
scheduled.
# Keeping this exception handling because this is a critical issue
if we do somehow find
# ourselves here and the user should get some feedback about that.
- self.log.warning("Executor, %s, was not found but a Task was
configured to use it", ti.executor)
+ self.log.warning(
+ "Executor, %s, was not found but a Task or Callback was
configured to use it",
+ workload.get_executor_name(),
+ )
return executor
diff --git a/airflow-core/src/airflow/jobs/triggerer_job_runner.py
b/airflow-core/src/airflow/jobs/triggerer_job_runner.py
index be8213c423f..5567e4763ea 100644
--- a/airflow-core/src/airflow/jobs/triggerer_job_runner.py
+++ b/airflow-core/src/airflow/jobs/triggerer_job_runner.py
@@ -45,6 +45,7 @@ from airflow._shared.observability.metrics.stats import Stats
from airflow._shared.timezones import timezone
from airflow.configuration import conf
from airflow.executors import workloads
+from airflow.executors.workloads.task import TaskInstanceDTO
from airflow.jobs.base_job_runner import BaseJobRunner
from airflow.jobs.job import perform_heartbeat
from airflow.models.trigger import Trigger
@@ -687,9 +688,7 @@ class TriggerRunnerSupervisor(WatchedSubprocess):
ti_id=new_trigger_orm.task_instance.id,
)
continue
- ser_ti = workloads.TaskInstance.model_validate(
- new_trigger_orm.task_instance, from_attributes=True
- )
+ ser_ti =
TaskInstanceDTO.model_validate(new_trigger_orm.task_instance,
from_attributes=True)
# When producing logs from TIs, include the job id producing
the logs to disambiguate it.
self.logger_cache[new_id] = TriggerLoggingFactory(
log_path=f"{log_path}.trigger.{self.job.id}.log",
diff --git a/airflow-core/src/airflow/models/callback.py
b/airflow-core/src/airflow/models/callback.py
index ea45a10f1f1..ea482ab7ba8 100644
--- a/airflow-core/src/airflow/models/callback.py
+++ b/airflow-core/src/airflow/models/callback.py
@@ -29,8 +29,13 @@ from sqlalchemy.orm import Mapped, mapped_column,
relationship
from airflow._shared.observability.metrics.stats import Stats
from airflow._shared.timezones import timezone
+from airflow.executors.workloads import BaseWorkload
+from airflow.executors.workloads.callback import CallbackFetchMethod
from airflow.models import Base
from airflow.utils.sqlalchemy import ExtendedJSON, UtcDateTime
+from airflow.utils.state import CallbackState
+
+CallbackKey = str # Callback keys are str(UUID)
if TYPE_CHECKING:
from sqlalchemy.orm import Session
@@ -41,20 +46,7 @@ if TYPE_CHECKING:
log = structlog.get_logger(__name__)
-class CallbackState(str, Enum):
- """All possible states of callbacks."""
-
- PENDING = "pending"
- QUEUED = "queued"
- RUNNING = "running"
- SUCCESS = "success"
- FAILED = "failed"
-
- def __str__(self) -> str:
- return self.value
-
-
-ACTIVE_STATES = frozenset((CallbackState.QUEUED, CallbackState.RUNNING))
+ACTIVE_STATES = frozenset((CallbackState.PENDING, CallbackState.QUEUED,
CallbackState.RUNNING))
TERMINAL_STATES = frozenset((CallbackState.SUCCESS, CallbackState.FAILED))
@@ -70,16 +62,6 @@ class CallbackType(str, Enum):
DAG_PROCESSOR = "dag_processor"
-class CallbackFetchMethod(str, Enum):
- """Methods used to fetch callback at runtime."""
-
- # For future use once Dag Processor callbacks
(on_success_callback/on_failure_callback) get moved to executors
- DAG_ATTRIBUTE = "dag_attribute"
-
- # For deadline callbacks since they import callbacks through the import
path
- IMPORT_PATH = "import_path"
-
-
class CallbackDefinitionProtocol(Protocol):
"""Protocol for TaskSDK Callback definition."""
@@ -103,7 +85,7 @@ class
ImportPathExecutorCallbackDefProtocol(ImportPathCallbackDefProtocol, Proto
executor: str | None
-class Callback(Base):
+class Callback(Base, BaseWorkload):
"""Base class for callbacks."""
__tablename__ = "callback"
@@ -147,7 +129,7 @@ class Callback(Base):
:param prefix: Optional prefix for metric names
:param kwargs: Additional data emitted in metric tags
"""
- self.state = CallbackState.PENDING
+ self.state = CallbackState.SCHEDULED
self.priority_weight = priority_weight
self.data = kwargs # kwargs can be used to include additional info in
metric tags
if prefix:
@@ -169,6 +151,14 @@ class Callback(Base):
return {"stat": name, "tags": tags}
+ def get_dag_id(self) -> str | None:
+ """Return the DAG ID for scheduler routing."""
+ return self.data.get("dag_id")
+
+ def get_executor_name(self) -> str | None:
+ """Return the executor name for scheduler routing."""
+ return self.data.get("executor")
+
@staticmethod
def create_from_sdk_def(callback_def: CallbackDefinitionProtocol,
**kwargs) -> Callback:
# Cannot check actual type using isinstance() because that would
require SDK import
diff --git a/airflow-core/src/airflow/models/deadline.py
b/airflow-core/src/airflow/models/deadline.py
index bbf24cc2842..debfe949b31 100644
--- a/airflow-core/src/airflow/models/deadline.py
+++ b/airflow-core/src/airflow/models/deadline.py
@@ -32,10 +32,15 @@ from sqlalchemy.orm import Mapped, mapped_column,
relationship
from airflow._shared.observability.metrics.stats import Stats
from airflow._shared.timezones import timezone
from airflow.models.base import Base
-from airflow.models.callback import Callback, CallbackDefinitionProtocol
+from airflow.models.callback import (
+ Callback,
+ ExecutorCallback,
+ TriggererCallback,
+)
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.session import provide_session
from airflow.utils.sqlalchemy import UtcDateTime, get_dialect_name
+from airflow.utils.state import CallbackState
if TYPE_CHECKING:
from sqlalchemy.orm import Session
@@ -224,9 +229,36 @@ class Deadline(Base):
"deadline": {"id": self.id, "deadline_time":
self.deadline_time},
}
-        self.callback.data["kwargs"] = self.callback.data["kwargs"] | {"context": get_simple_context()}
+ if isinstance(self.callback, TriggererCallback):
+ # Update the callback with context before queuing
+ if "kwargs" not in self.callback.data:
+ self.callback.data["kwargs"] = {}
+ self.callback.data["kwargs"] = (self.callback.data.get("kwargs")
or {}) | {
+ "context": get_simple_context()
+ }
+
+ self.callback.queue()
+ session.add(self.callback)
+ session.flush()
+
+ elif isinstance(self.callback, ExecutorCallback):
+ if "kwargs" not in self.callback.data:
+ self.callback.data["kwargs"] = {}
+ self.callback.data["kwargs"] = (self.callback.data.get("kwargs")
or {}) | {
+ "context": get_simple_context()
+ }
+ self.callback.data["deadline_id"] = str(self.id)
+ self.callback.data["dag_run_id"] = str(self.dagrun.id)
+ self.callback.data["dag_id"] = self.dagrun.dag_id
+
+ self.callback.state = CallbackState.PENDING
+ session.add(self.callback)
+ session.flush()
+
+        else:
+            raise TypeError(f"Unknown Callback type: {type(self.callback).__name__}")
+
self.missed = True
- self.callback.queue()
session.add(self)
Stats.incr(
"deadline_alerts.deadline_missed",
diff --git a/airflow-core/src/airflow/models/taskinstance.py
b/airflow-core/src/airflow/models/taskinstance.py
index 475cbd7ae68..8b12db483af 100644
--- a/airflow-core/src/airflow/models/taskinstance.py
+++ b/airflow-core/src/airflow/models/taskinstance.py
@@ -70,6 +70,7 @@ from airflow._shared.observability.metrics.stats import Stats
from airflow._shared.timezones import timezone
from airflow.assets.manager import asset_manager
from airflow.configuration import conf
+from airflow.executors.workloads import BaseWorkload
from airflow.listeners.listener import get_listener_manager
from airflow.models.asset import AssetModel
from airflow.models.base import Base, StringID, TaskInstanceDependencies
@@ -406,7 +407,7 @@ def uuid7() -> UUID:
return uuid6.uuid7()
-class TaskInstance(Base, LoggingMixin):
+class TaskInstance(Base, LoggingMixin, BaseWorkload):
"""
Task instances store the state of a task instance.
@@ -802,6 +803,14 @@ class TaskInstance(Base, LoggingMixin):
"""Returns a tuple that identifies the task instance uniquely."""
return TaskInstanceKey(self.dag_id, self.task_id, self.run_id,
self.try_number, self.map_index)
+ def get_dag_id(self) -> str:
+ """Return the DAG ID for scheduler routing."""
+ return self.dag_id
+
+ def get_executor_name(self) -> str | None:
+ """Return the executor name for scheduler routing."""
+ return self.executor
+
@provide_session
def set_state(self, state: str | None, session: Session = NEW_SESSION) ->
bool:
"""
diff --git a/airflow-core/src/airflow/utils/state.py
b/airflow-core/src/airflow/utils/state.py
index b392a023525..332efb10553 100644
--- a/airflow-core/src/airflow/utils/state.py
+++ b/airflow-core/src/airflow/utils/state.py
@@ -20,6 +20,20 @@ from __future__ import annotations
from enum import Enum
+class CallbackState(str, Enum):
+ """All possible states of callbacks."""
+
+ SCHEDULED = "scheduled"
+ PENDING = "pending"
+ QUEUED = "queued"
+ RUNNING = "running"
+ SUCCESS = "success"
+ FAILED = "failed"
+
+ def __str__(self) -> str:
+ return self.value
+
+
class TerminalTIState(str, Enum):
"""States that a Task Instance can be in that indicate it has reached a
terminal state."""
diff --git a/airflow-core/tests/unit/executors/test_base_executor.py
b/airflow-core/tests/unit/executors/test_base_executor.py
index 5c2a3d6d549..fa0f311d018 100644
--- a/airflow-core/tests/unit/executors/test_base_executor.py
+++ b/airflow-core/tests/unit/executors/test_base_executor.py
@@ -25,6 +25,7 @@ from uuid import UUID
import pendulum
import pytest
+import structlog
import time_machine
from airflow._shared.timezones import timezone
@@ -34,6 +35,9 @@ from airflow.cli.cli_parser import AirflowHelpFormatter
from airflow.executors import workloads
from airflow.executors.base_executor import BaseExecutor,
RunningRetryAttemptType
from airflow.executors.local_executor import LocalExecutor
+from airflow.executors.workloads.base import BundleInfo
+from airflow.executors.workloads.callback import CallbackDTO,
execute_callback_workload
+from airflow.models.callback import CallbackFetchMethod
from airflow.models.taskinstance import TaskInstance, TaskInstanceKey
from airflow.sdk import BaseOperator
from airflow.serialization.definitions.baseoperator import
SerializedBaseOperator
@@ -573,3 +577,148 @@ class TestExecutorConf:
team_executor_conf = ExecutorConf(team_name="test_team")
assert team_executor_conf.get_mandatory_value("celery", "broker_url")
== "redis://team-redis"
+
+
+class TestCallbackSupport:
+ def test_supports_callbacks_flag_default_false(self):
+ executor = BaseExecutor()
+ assert executor.supports_callbacks is False
+
+ def test_local_executor_supports_callbacks_true(self):
+ """Test that LocalExecutor sets supports_callbacks to True."""
+ executor = LocalExecutor()
+ assert executor.supports_callbacks is True
+
+ @pytest.mark.db_test
+ def test_queue_callback_without_support_raises_error(self, dag_maker,
session):
+ executor = BaseExecutor() # supports_callbacks = False by default
+ callback_data = CallbackDTO(
+ id="12345678-1234-5678-1234-567812345678",
+ fetch_method=CallbackFetchMethod.IMPORT_PATH,
+ data={"path": "test.func", "kwargs": {}},
+ )
+ callback_workload = workloads.ExecuteCallback(
+ callback=callback_data,
+ dag_rel_path="test.py",
+ bundle_info=BundleInfo(name="test_bundle", version="1.0"),
+ token="test_token",
+ log_path="test.log",
+ )
+
+ with pytest.raises(NotImplementedError, match="does not support
ExecuteCallback"):
+ executor.queue_workload(callback_workload, session)
+
+ @pytest.mark.db_test
+ def test_queue_workload_with_execute_callback(self, dag_maker, session):
+ executor = BaseExecutor()
+ executor.supports_callbacks = True # Enable for this test
+ callback_data = CallbackDTO(
+ id="12345678-1234-5678-1234-567812345678",
+ fetch_method=CallbackFetchMethod.IMPORT_PATH,
+ data={"path": "test.func", "kwargs": {}},
+ )
+ callback_workload = workloads.ExecuteCallback(
+ callback=callback_data,
+ dag_rel_path="test.py",
+ bundle_info=BundleInfo(name="test_bundle", version="1.0"),
+ token="test_token",
+ log_path="test.log",
+ )
+
+ executor.queue_workload(callback_workload, session)
+
+ assert len(executor.queued_callbacks) == 1
+ assert callback_data.id in executor.queued_callbacks
+
+ @pytest.mark.db_test
+ def test_get_workloads_prioritizes_callbacks(self, dag_maker, session):
+ executor = BaseExecutor()
+ executor.supports_callbacks = True # Enable for this test
+ dagrun = setup_dagrun(dag_maker)
+ callback_data = CallbackDTO(
+ id="12345678-1234-5678-1234-567812345678",
+ fetch_method=CallbackFetchMethod.IMPORT_PATH,
+ data={"path": "test.func", "kwargs": {}},
+ )
+ callback_workload = workloads.ExecuteCallback(
+ callback=callback_data,
+ dag_rel_path="test.py",
+ bundle_info=BundleInfo(name="test_bundle", version="1.0"),
+ token="test_token",
+ log_path="test.log",
+ )
+ executor.queue_workload(callback_workload, session)
+
+ for ti in dagrun.task_instances:
+ task_workload = workloads.ExecuteTask.make(ti)
+ executor.queue_workload(task_workload, session)
+
+ workloads_to_schedule =
executor._get_workloads_to_schedule(open_slots=10)
+
+ assert len(workloads_to_schedule) == 4 # 1 callback + 3 tasks
+ _, first_workload = workloads_to_schedule[0]
+ assert isinstance(first_workload, workloads.ExecuteCallback) # Assert
callback comes first
+
+
+class TestExecuteCallbackWorkload:
+ def test_execute_function_callback_success(self):
+ callback_data = CallbackDTO(
+ id="12345678-1234-5678-1234-567812345678",
+ fetch_method=CallbackFetchMethod.IMPORT_PATH,
+ data={
+ "path": "builtins.dict",
+ "kwargs": {"a": 1, "b": 2, "c": 3},
+ },
+ )
+ log = structlog.get_logger()
+
+ success, error = execute_callback_workload(callback_data, log)
+
+ assert success is True
+ assert error is None
+
+ def test_execute_callback_missing_path(self):
+ callback_data = CallbackDTO(
+ id="12345678-1234-5678-1234-567812345678",
+ fetch_method=CallbackFetchMethod.IMPORT_PATH,
+ data={"kwargs": {}}, # Missing 'path'
+ )
+ log = structlog.get_logger()
+
+ success, error = execute_callback_workload(callback_data, log)
+
+ assert success is False
+ assert "Callback path not found" in error
+
+ def test_execute_callback_import_error(self):
+ callback_data = CallbackDTO(
+ id="12345678-1234-5678-1234-567812345678",
+ fetch_method=CallbackFetchMethod.IMPORT_PATH,
+ data={
+ "path": "nonexistent.module.function",
+ "kwargs": {},
+ },
+ )
+ log = structlog.get_logger()
+
+ success, error = execute_callback_workload(callback_data, log)
+
+ assert success is False
+ assert "ModuleNotFoundError" in error
+
+ def test_execute_callback_execution_error(self):
+ # Use a function that will raise an error; len() requires an argument
+ callback_data = CallbackDTO(
+ id="12345678-1234-5678-1234-567812345678",
+ fetch_method=CallbackFetchMethod.IMPORT_PATH,
+ data={
+ "path": "builtins.len",
+ "kwargs": {},
+ },
+ )
+ log = structlog.get_logger()
+
+ success, error = execute_callback_workload(callback_data, log)
+
+ assert success is False
+ assert "TypeError" in error
diff --git a/airflow-core/tests/unit/executors/test_local_executor.py
b/airflow-core/tests/unit/executors/test_local_executor.py
index 5f216cca2e7..34e8f818aa9 100644
--- a/airflow-core/tests/unit/executors/test_local_executor.py
+++ b/airflow-core/tests/unit/executors/test_local_executor.py
@@ -29,6 +29,10 @@ from uuid6 import uuid7
from airflow._shared.timezones import timezone
from airflow.executors import workloads
from airflow.executors.local_executor import LocalExecutor, _execute_work
+from airflow.executors.workloads.base import BundleInfo
+from airflow.executors.workloads.callback import CallbackDTO
+from airflow.executors.workloads.task import TaskInstanceDTO
+from airflow.models.callback import CallbackFetchMethod
from airflow.settings import Session
from airflow.utils.state import State
@@ -81,7 +85,7 @@ class TestLocalExecutor:
@mock.patch("airflow.sdk.execution_time.supervisor.supervise")
def test_execution(self, mock_supervise):
success_tis = [
- workloads.TaskInstance(
+ TaskInstanceDTO(
id=uuid7(),
dag_version_id=uuid7(),
task_id=f"success_{i}",
@@ -327,3 +331,40 @@ class TestLocalExecutor:
assert len(executor.workers) == 2
executor.end()
+
+
+class TestLocalExecutorCallbackSupport:
+ def test_supports_callbacks_flag_is_true(self):
+ executor = LocalExecutor()
+ assert executor.supports_callbacks is True
+
+    @skip_spawn_mp_start
+    @mock.patch("airflow.executors.workloads.callback.execute_callback_workload")
+ def test_process_callback_workload(self, mock_execute_callback):
+ mock_execute_callback.return_value = (True, None)
+
+ executor = LocalExecutor(parallelism=1)
+ callback_data = CallbackDTO(
+ id="12345678-1234-5678-1234-567812345678",
+ fetch_method=CallbackFetchMethod.IMPORT_PATH,
+ data={"path": "test.func", "kwargs": {}},
+ )
+ callback_workload = workloads.ExecuteCallback(
+ callback=callback_data,
+ dag_rel_path="test.py",
+ bundle_info=BundleInfo(name="test_bundle", version="1.0"),
+ token="test_token",
+ log_path="test.log",
+ )
+
+ executor.start()
+
+ try:
+ executor.queued_callbacks[callback_data.id] = callback_workload
+ executor._process_workloads([callback_workload])
+ assert len(executor.queued_callbacks) == 0
+ # We can't easily verify worker execution without running the
worker,
+ # but we can verify the helper is called via mock
+
+ finally:
+ executor.end()
diff --git a/airflow-core/tests/unit/jobs/test_scheduler_job.py
b/airflow-core/tests/unit/jobs/test_scheduler_job.py
index ea9c51220ca..abbd5f20067 100644
--- a/airflow-core/tests/unit/jobs/test_scheduler_job.py
+++ b/airflow-core/tests/unit/jobs/test_scheduler_job.py
@@ -65,6 +65,7 @@ from airflow.models.asset import (
PartitionedAssetKeyLog,
)
from airflow.models.backfill import Backfill, _create_backfill
+from airflow.models.callback import ExecutorCallback
from airflow.models.dag import DagModel, get_last_dagrun,
infer_automated_data_interval
from airflow.models.dag_version import DagVersion
from airflow.models.dagbundle import DagBundleModel
@@ -92,7 +93,7 @@ from airflow.serialization.serialized_objects import
LazyDeserializedDAG
from airflow.timetables.base import DagRunInfo, DataInterval
from airflow.utils.session import create_session, provide_session
from airflow.utils.span_status import SpanStatus
-from airflow.utils.state import DagRunState, State, TaskInstanceState
+from airflow.utils.state import CallbackState, DagRunState, State,
TaskInstanceState
from airflow.utils.thread_safe_dict import ThreadSafeDict
from airflow.utils.types import DagRunTriggeredByType, DagRunType
@@ -556,6 +557,48 @@ class TestSchedulerJob:
any_order=True,
)
+ def test_enqueue_executor_callbacks_only_selects_pending_state(self,
dag_maker, session):
+ def test_callback():
+ pass
+
+ def create_callback_in_state(state: CallbackState):
+ callback = Deadline(
+ deadline_time=timezone.utcnow(),
+ callback=SyncCallback(test_callback),
+ dagrun_id=dag_run.id,
+ deadline_alert_id=None,
+ ).callback
+ callback.state = state
+ callback.data["dag_run_id"] = dag_run.id
+ callback.data["dag_id"] = dag_run.dag_id
+ return callback
+
+ with dag_maker(dag_id="test_callback_states"):
+ pass
+ dag_run = dag_maker.create_dagrun()
+
+ scheduled_callback = create_callback_in_state(CallbackState.SCHEDULED)
+ pending_callback = create_callback_in_state(CallbackState.PENDING)
+ queued_callback = create_callback_in_state(CallbackState.QUEUED)
+ running_callback = create_callback_in_state(CallbackState.RUNNING)
+ session.add_all([scheduled_callback, pending_callback,
queued_callback, running_callback])
+ session.flush()
+
+ scheduler_job = Job()
+ self.job_runner = SchedulerJobRunner(job=scheduler_job)
+
+ # Verify initial state before calling _enqueue_executor_callbacks
+ assert session.get(ExecutorCallback, pending_callback.id).state ==
CallbackState.PENDING
+
+ self.job_runner._enqueue_executor_callbacks(session)
+ # PENDING should progress to QUEUED after _enqueue_executor_callbacks
+ assert session.get(ExecutorCallback, pending_callback.id).state ==
CallbackState.QUEUED
+
+ # Other callbacks should remain in their original states
+ assert session.get(ExecutorCallback, scheduled_callback.id).state ==
CallbackState.SCHEDULED
+ assert session.get(ExecutorCallback, queued_callback.id).state ==
CallbackState.QUEUED
+ assert session.get(ExecutorCallback, running_callback.id).state ==
CallbackState.RUNNING
+
@mock.patch("airflow.jobs.scheduler_job_runner.TaskCallbackRequest")
@mock.patch("airflow.jobs.scheduler_job_runner.Stats.incr")
def test_process_executor_events_with_callback(
@@ -1311,7 +1354,7 @@ class TestSchedulerJob:
assert len(res) == 5
# Verify that each task is routed to the correct executor
- executor_to_tis = self.job_runner._executor_to_tis(res, session)
+ executor_to_tis = self.job_runner._executor_to_workloads(res, session)
# Team pi tasks should go to mock_executors[0] (configured for team_pi)
a_tis_in_executor = [ti for ti in
executor_to_tis.get(mock_executors[0], []) if ti.dag_id == "dag_a"]
@@ -7909,7 +7952,7 @@ class TestSchedulerJob:
assert result == {}
mock_log.exception.assert_called_once()
- def test_multi_team_get_task_team_name_success(self, dag_maker, session):
+ def test_multi_team_get_workload_team_name_success(self, dag_maker,
session):
"""Test successful team name resolution for a single task."""
clear_db_teams()
clear_db_dag_bundles()
@@ -7932,10 +7975,10 @@ class TestSchedulerJob:
scheduler_job = Job()
self.job_runner = SchedulerJobRunner(job=scheduler_job)
- result = self.job_runner._get_task_team_name(ti, session)
+ result = self.job_runner._get_workload_team_name(ti, session)
assert result == "team_a"
- def test_multi_team_get_task_team_name_no_team(self, dag_maker, session):
+ def test_multi_team_get_workload_team_name_no_team(self, dag_maker,
session):
"""Test team resolution when no team is associated with the DAG."""
with dag_maker(dag_id="dag_no_team", session=session):
task = EmptyOperator(task_id="task_no_team")
@@ -7946,10 +7989,10 @@ class TestSchedulerJob:
scheduler_job = Job()
self.job_runner = SchedulerJobRunner(job=scheduler_job)
- result = self.job_runner._get_task_team_name(ti, session)
+ result = self.job_runner._get_workload_team_name(ti, session)
assert result is None
- def test_multi_team_get_task_team_name_database_error(self, dag_maker,
session):
+ def test_multi_team_get_workload_team_name_database_error(self, dag_maker,
session):
"""Test graceful error handling when individual task team resolution
fails. This code should _not_ fail the scheduler."""
with dag_maker(dag_id="dag_test", session=session):
task = EmptyOperator(task_id="task_test")
@@ -7962,7 +8005,7 @@ class TestSchedulerJob:
# Mock _get_team_names_for_dag_ids to return empty dict (simulates
database error handling in that function)
with mock.patch.object(self.job_runner, "_get_team_names_for_dag_ids",
return_value={}) as mock_batch:
- result = self.job_runner._get_task_team_name(ti, session)
+ result = self.job_runner._get_workload_team_name(ti, session)
mock_batch.assert_called_once_with([ti.dag_id], session)
# Should return None when batch function returns empty dict
@@ -7980,7 +8023,7 @@ class TestSchedulerJob:
scheduler_job = Job()
self.job_runner = SchedulerJobRunner(job=scheduler_job)
- with mock.patch.object(self.job_runner, "_get_task_team_name") as
mock_team_resolve:
+ with mock.patch.object(self.job_runner, "_get_workload_team_name") as
mock_team_resolve:
result = self.job_runner._try_to_load_executor(ti, session)
# Should not call team resolution when multi_team is disabled
mock_team_resolve.assert_not_called()
@@ -8177,7 +8220,8 @@ class TestSchedulerJob:
# Should log a warning when no executor is found
mock_log.warning.assert_called_once_with(
- "Executor, %s, was not found but a Task was configured to use
it", "secondary_exec"
+ "Executor, %s, was not found but a Task or Callback was
configured to use it",
+ "secondary_exec",
)
# Should return None since we failed to resolve an executor due to the
mismatch. In practice, this
@@ -8229,7 +8273,7 @@ class TestSchedulerJob:
self.job_runner = SchedulerJobRunner(job=scheduler_job)
# Call with pre-resolved team name (as done in the scheduling loop)
- with mock.patch.object(self.job_runner, "_get_task_team_name") as
mock_team_resolve:
+ with mock.patch.object(self.job_runner, "_get_workload_team_name") as
mock_team_resolve:
result = self.job_runner._try_to_load_executor(ti, session,
team_name="team_a")
mock_team_resolve.assert_not_called() # We don't query for the
team if it is pre-resolved
@@ -8320,13 +8364,13 @@ class TestSchedulerJob:
with (
assert_queries_count(1, session=session),
- mock.patch.object(self.job_runner, "_get_task_team_name") as
mock_single,
+ mock.patch.object(self.job_runner, "_get_workload_team_name") as
mock_single,
):
- executor_to_tis = self.job_runner._executor_to_tis([ti1, ti2],
session)
+ executor_to_workloads =
self.job_runner._executor_to_workloads([ti1, ti2], session)
mock_single.assert_not_called()
- assert executor_to_tis[mock_executors[0]] == [ti1]
- assert executor_to_tis[mock_executors[1]] == [ti2]
+ assert executor_to_workloads[mock_executors[0]] == [ti1]
+ assert executor_to_workloads[mock_executors[1]] == [ti2]
@conf_vars({("core", "multi_team"): "false"})
def test_multi_team_config_disabled_uses_legacy_behavior(self, dag_maker,
mock_executors, session):
@@ -8342,7 +8386,7 @@ class TestSchedulerJob:
scheduler_job = Job()
self.job_runner = SchedulerJobRunner(job=scheduler_job)
- with mock.patch.object(self.job_runner, "_get_task_team_name") as
mock_team_resolve:
+ with mock.patch.object(self.job_runner, "_get_workload_team_name") as
mock_team_resolve:
result1 = self.job_runner._try_to_load_executor(ti1, session)
result2 = self.job_runner._try_to_load_executor(ti2, session)
diff --git a/airflow-core/tests/unit/models/test_callback.py
b/airflow-core/tests/unit/models/test_callback.py
index dfc19fc61a3..6ab6ad2d02d 100644
--- a/airflow-core/tests/unit/models/test_callback.py
+++ b/airflow-core/tests/unit/models/test_callback.py
@@ -123,7 +123,7 @@ class TestTriggererCallback:
assert isinstance(retrieved, TriggererCallback)
assert retrieved.fetch_method == CallbackFetchMethod.IMPORT_PATH
assert retrieved.data == TEST_ASYNC_CALLBACK.serialize()
- assert retrieved.state == CallbackState.PENDING.value
+ assert retrieved.state == CallbackState.SCHEDULED.value
assert retrieved.output is None
assert retrieved.priority_weight == 1
assert retrieved.created_at is not None
@@ -131,7 +131,7 @@ class TestTriggererCallback:
def test_queue(self, session):
callback = TriggererCallback(TEST_ASYNC_CALLBACK)
- assert callback.state == CallbackState.PENDING
+ assert callback.state == CallbackState.SCHEDULED
assert callback.trigger is None
callback.queue()
@@ -193,7 +193,7 @@ class TestExecutorCallback:
assert isinstance(retrieved, ExecutorCallback)
assert retrieved.fetch_method == CallbackFetchMethod.IMPORT_PATH
assert retrieved.data == TEST_SYNC_CALLBACK.serialize()
- assert retrieved.state == CallbackState.PENDING.value
+ assert retrieved.state == CallbackState.SCHEDULED.value
assert retrieved.output is None
assert retrieved.priority_weight == 1
assert retrieved.created_at is not None
@@ -201,7 +201,7 @@ class TestExecutorCallback:
def test_queue(self):
callback = ExecutorCallback(TEST_SYNC_CALLBACK,
fetch_method=CallbackFetchMethod.DAG_ATTRIBUTE)
- assert callback.state == CallbackState.PENDING
+ assert callback.state == CallbackState.SCHEDULED
callback.queue()
assert callback.state == CallbackState.QUEUED
diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt
index e7bef2859af..12f7e3ac427 100644
--- a/docs/spelling_wordlist.txt
+++ b/docs/spelling_wordlist.txt
@@ -2092,6 +2092,7 @@ winrm
WIT
workgroup
workgroups
+WorkloadKey
workspaces
writeable
wsman
diff --git
a/providers/celery/src/airflow/providers/celery/executors/celery_executor.py
b/providers/celery/src/airflow/providers/celery/executors/celery_executor.py
index 7144425b2c3..cd142acc7a1 100644
--- a/providers/celery/src/airflow/providers/celery/executors/celery_executor.py
+++ b/providers/celery/src/airflow/providers/celery/executors/celery_executor.py
@@ -39,7 +39,7 @@ from deprecated import deprecated
from airflow.exceptions import AirflowProviderDeprecationWarning
from airflow.executors.base_executor import BaseExecutor
-from airflow.providers.celery.version_compat import AIRFLOW_V_3_0_PLUS
+from airflow.providers.celery.version_compat import AIRFLOW_V_3_0_PLUS,
AIRFLOW_V_3_2_PLUS
from airflow.providers.common.compat.sdk import AirflowTaskTimeout, Stats
from airflow.utils.state import TaskInstanceState
@@ -52,13 +52,11 @@ CELERY_SEND_ERR_MSG_HEADER = "Error sending Celery task"
if TYPE_CHECKING:
from collections.abc import Sequence
- from sqlalchemy.orm import Session
-
from airflow.cli.cli_config import GroupCommand
from airflow.executors import workloads
from airflow.models.taskinstance import TaskInstance
from airflow.models.taskinstancekey import TaskInstanceKey
- from airflow.providers.celery.executors.celery_executor_utils import
TaskInstanceInCelery, TaskTuple
+ from airflow.providers.celery.executors.celery_executor_utils import
TaskTuple, WorkloadInCelery
# PEP562
@@ -91,16 +89,17 @@ class CeleryExecutor(BaseExecutor):
"""
supports_ad_hoc_ti_run: bool = True
+ supports_callbacks: bool = True
sentry_integration: str =
"sentry_sdk.integrations.celery.CeleryIntegration"
# TODO: Remove this flag once providers depend on Airflow 3.2.
supports_sentry: bool = True
supports_multi_team: bool = True
- if TYPE_CHECKING and AIRFLOW_V_3_0_PLUS:
- # In the v3 path, we store workloads, not commands as strings.
- # TODO: TaskSDK: move this type change into BaseExecutor
- queued_tasks: dict[TaskInstanceKey, workloads.All] # type:
ignore[assignment]
+ if TYPE_CHECKING:
+ if AIRFLOW_V_3_0_PLUS:
+ # TODO: TaskSDK: move this type change into BaseExecutor
+ queued_tasks: dict[TaskInstanceKey, workloads.All] # type:
ignore[assignment]
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@@ -160,18 +159,25 @@ class CeleryExecutor(BaseExecutor):
# Airflow V3 version -- have to delay imports until we know we are on
v3
from airflow.executors.workloads import ExecuteTask
- tasks = [
- (workload.ti.key, workload, workload.ti.queue, self.team_name)
- for workload in workloads
- if isinstance(workload, ExecuteTask)
- ]
- if len(tasks) != len(workloads):
- invalid = list(workload for workload in workloads if not
isinstance(workload, ExecuteTask))
- raise ValueError(f"{type(self)}._process_workloads cannot handle
{invalid}")
+ if AIRFLOW_V_3_2_PLUS:
+ from airflow.executors.workloads import ExecuteCallback
+
+ tasks: list[WorkloadInCelery] = []
+ for workload in workloads:
+ if isinstance(workload, ExecuteTask):
+ tasks.append((workload.ti.key, workload, workload.ti.queue,
self.team_name))
+ elif AIRFLOW_V_3_2_PLUS and isinstance(workload, ExecuteCallback):
+ # Use default queue for callbacks, or extract from callback
data if available
+ queue = "default"
+ if isinstance(workload.callback.data, dict) and "queue" in
workload.callback.data:
+ queue = workload.callback.data["queue"]
+ tasks.append((workload.callback.key, workload, queue,
self.team_name))
+ else:
+ raise ValueError(f"{type(self)}._process_workloads cannot
handle {type(workload)}")
self._send_tasks(tasks)
- def _send_tasks(self, task_tuples_to_send: Sequence[TaskInstanceInCelery]):
+ def _send_tasks(self, task_tuples_to_send: Sequence[WorkloadInCelery]):
# Celery state queries will be stuck if we do not use one same backend
# for all tasks.
cached_celery_backend = self.celery_app.backend
@@ -195,7 +201,10 @@ class CeleryExecutor(BaseExecutor):
)
self.task_publish_retries[key] = retries + 1
continue
- self.queued_tasks.pop(key)
+ if key in self.queued_tasks:
+ self.queued_tasks.pop(key)
+ else:
+ self.queued_callbacks.pop(key, None)
self.task_publish_retries.pop(key, None)
if isinstance(result, ExceptionWithTraceback):
self.log.error("%s: %s\n%s\n", CELERY_SEND_ERR_MSG_HEADER,
result.exception, result.traceback)
@@ -210,7 +219,7 @@ class CeleryExecutor(BaseExecutor):
# which point we don't need the ID anymore anyway
self.event_buffer[key] = (TaskInstanceState.QUEUED,
result.task_id)
- def _send_tasks_to_celery(self, task_tuples_to_send:
Sequence[TaskInstanceInCelery]):
+ def _send_tasks_to_celery(self, task_tuples_to_send:
Sequence[WorkloadInCelery]):
from airflow.providers.celery.executors.celery_executor_utils import
send_task_to_executor
if len(task_tuples_to_send) == 1 or self._sync_parallelism == 1:
@@ -375,11 +384,3 @@ class CeleryExecutor(BaseExecutor):
from airflow.providers.celery.cli.definition import
get_celery_cli_commands
return get_celery_cli_commands()
-
- def queue_workload(self, workload: workloads.All, session: Session | None)
-> None:
- from airflow.executors import workloads
-
- if not isinstance(workload, workloads.ExecuteTask):
- raise RuntimeError(f"{type(self)} cannot handle workloads of type
{type(workload)}")
- ti = workload.ti
- self.queued_tasks[ti.key] = workload
diff --git
a/providers/celery/src/airflow/providers/celery/executors/celery_executor_utils.py
b/providers/celery/src/airflow/providers/celery/executors/celery_executor_utils.py
index b09737701f3..578d0a909ac 100644
---
a/providers/celery/src/airflow/providers/celery/executors/celery_executor_utils.py
+++
b/providers/celery/src/airflow/providers/celery/executors/celery_executor_utils.py
@@ -53,6 +53,9 @@ try:
except ImportError:
from airflow.utils.dag_parsing_context import
_airflow_parsing_context_manager
+if AIRFLOW_V_3_2_PLUS:
+ from airflow.executors.workloads.callback import execute_callback_workload
+
log = logging.getLogger(__name__)
if sys.platform == "darwin":
@@ -67,16 +70,21 @@ if TYPE_CHECKING:
from airflow.executors import workloads
from airflow.executors.base_executor import EventBufferValueType,
ExecutorConf
+ from airflow.executors.workloads.types import WorkloadKey
from airflow.models.taskinstance import TaskInstanceKey
# We can't use `if AIRFLOW_V_3_0_PLUS` conditions in type checks, so
unfortunately we just have to define
# the type as the union of both kinds
CommandType = Sequence[str]
- TaskInstanceInCelery: TypeAlias = tuple[
- TaskInstanceKey, workloads.All | CommandType, str | None, str | None
+ WorkloadInCelery: TypeAlias = tuple[WorkloadKey, workloads.All |
CommandType, str | None, str | None]
+ WorkloadInCeleryResult: TypeAlias = tuple[
+ WorkloadKey, CommandType, AsyncResult | "ExceptionWithTraceback"
]
+ # Deprecated alias for backward compatibility
+ TaskInstanceInCelery: TypeAlias = WorkloadInCelery
+
TaskTuple = tuple[TaskInstanceKey, CommandType, str | None, Any | None]
OPERATION_TIMEOUT = conf.getfloat("celery", "operation_timeout")
@@ -182,9 +190,6 @@ def execute_workload(input: str) -> None:
celery_task_id = app.current_task.request.id
- if not isinstance(workload, workloads.ExecuteTask):
- raise ValueError(f"CeleryExecutor does not know how to handle
{type(workload)}")
-
log.info("[%s] Executing workload in Celery: %s", celery_task_id, workload)
base_url = conf.get("api", "base_url", fallback="/")
@@ -193,15 +198,22 @@ def execute_workload(input: str) -> None:
base_url = f"http://localhost:8080{base_url}"
default_execution_api_server = f"{base_url.rstrip('/')}/execution/"
- supervise(
- # This is the "wrong" ti type, but it duck types the same. TODO:
Create a protocol for this.
- ti=workload.ti, # type: ignore[arg-type]
- dag_rel_path=workload.dag_rel_path,
- bundle_info=workload.bundle_info,
- token=workload.token,
- server=conf.get("core", "execution_api_server_url",
fallback=default_execution_api_server),
- log_path=workload.log_path,
- )
+ if isinstance(workload, workloads.ExecuteTask):
+ supervise(
+ # This is the "wrong" ti type, but it duck types the same. TODO:
Create a protocol for this.
+ ti=workload.ti, # type: ignore[arg-type]
+ dag_rel_path=workload.dag_rel_path,
+ bundle_info=workload.bundle_info,
+ token=workload.token,
+ server=conf.get("core", "execution_api_server_url",
fallback=default_execution_api_server),
+ log_path=workload.log_path,
+ )
+ elif isinstance(workload, workloads.ExecuteCallback):
+ success, error_msg = execute_callback_workload(workload.callback, log)
+ if not success:
+ raise RuntimeError(error_msg or "Callback execution failed")
+ else:
+ raise ValueError(f"CeleryExecutor does not know how to handle
{type(workload)}")
if not AIRFLOW_V_3_0_PLUS:
@@ -303,16 +315,16 @@ class ExceptionWithTraceback:
self.traceback = exception_traceback
-def send_task_to_executor(
- task_tuple: TaskInstanceInCelery,
-) -> tuple[TaskInstanceKey, CommandType, AsyncResult | ExceptionWithTraceback]:
+def send_workload_to_executor(
+ workload_tuple: WorkloadInCelery,
+) -> WorkloadInCeleryResult:
"""
- Send task to executor.
+ Send workload to executor.
This function is called in ProcessPoolExecutor subprocesses. To avoid
pickling issues with
team-specific Celery apps, we pass the team_name and reconstruct the
Celery app here.
"""
- key, args, queue, team_name = task_tuple
+ key, args, queue, team_name = workload_tuple
# Reconstruct the Celery app from configuration, which may or may not be
team-specific.
# ExecutorConf wraps config access to automatically use team-specific
config where present.
@@ -326,7 +338,6 @@ def send_task_to_executor(
else:
# Airflow <3.2 ExecutorConf doesn't exist (at least not with the
required attributes), fall back to global conf
_conf = conf
-
# Create the Celery app with the correct configuration
celery_app = create_celery_app(_conf)
@@ -362,6 +373,10 @@ def send_task_to_executor(
return key, args, result
+# Backward compatibility alias
+send_task_to_executor = send_workload_to_executor
+
+
def fetch_celery_task_state(async_result: AsyncResult) -> tuple[str, str |
ExceptionWithTraceback, Any]:
"""
Fetch and return the state of the given celery task.
diff --git
a/providers/celery/src/airflow/providers/celery/executors/celery_kubernetes_executor.py
b/providers/celery/src/airflow/providers/celery/executors/celery_kubernetes_executor.py
index 49ae5b35b6f..34cfb27e86a 100644
---
a/providers/celery/src/airflow/providers/celery/executors/celery_kubernetes_executor.py
+++
b/providers/celery/src/airflow/providers/celery/executors/celery_kubernetes_executor.py
@@ -104,18 +104,18 @@ class CeleryKubernetesExecutor(BaseExecutor):
def queued_tasks(self) -> dict[TaskInstanceKey, Any]:
"""Return queued tasks from celery and kubernetes executor."""
queued_tasks = self.celery_executor.queued_tasks.copy()
- queued_tasks.update(self.kubernetes_executor.queued_tasks)
+ queued_tasks.update(self.kubernetes_executor.queued_tasks) # type:
ignore[arg-type]
- return queued_tasks
+ return queued_tasks # type: ignore[return-value]
@queued_tasks.setter
def queued_tasks(self, value) -> None:
"""Not implemented for hybrid executors."""
- @property
+ @property # type: ignore[override]
def running(self) -> set[TaskInstanceKey]:
"""Return running tasks from celery and kubernetes executor."""
- return
self.celery_executor.running.union(self.kubernetes_executor.running)
+ return
self.celery_executor.running.union(self.kubernetes_executor.running) # type:
ignore[return-value, arg-type]
@running.setter
def running(self, value) -> None:
@@ -225,7 +225,7 @@ class CeleryKubernetesExecutor(BaseExecutor):
self.celery_executor.heartbeat()
self.kubernetes_executor.heartbeat()
- def get_event_buffer(
+ def get_event_buffer( # type: ignore[override]
self, dag_ids: list[str] | None = None
) -> dict[TaskInstanceKey, EventBufferValueType]:
"""
@@ -237,7 +237,7 @@ class CeleryKubernetesExecutor(BaseExecutor):
cleared_events_from_celery =
self.celery_executor.get_event_buffer(dag_ids)
cleared_events_from_kubernetes =
self.kubernetes_executor.get_event_buffer(dag_ids)
- return {**cleared_events_from_celery, **cleared_events_from_kubernetes}
+ return {**cleared_events_from_celery,
**cleared_events_from_kubernetes} # type: ignore[dict-item]
def try_adopt_task_instances(self, tis: Sequence[TaskInstance]) ->
Sequence[TaskInstance]:
"""
diff --git a/providers/celery/tests/integration/celery/test_celery_executor.py
b/providers/celery/tests/integration/celery/test_celery_executor.py
index 3b055dd0b4d..da8ee15571b 100644
--- a/providers/celery/tests/integration/celery/test_celery_executor.py
+++ b/providers/celery/tests/integration/celery/test_celery_executor.py
@@ -42,6 +42,7 @@ from uuid6 import uuid7
from airflow._shared.timezones import timezone
from airflow.configuration import conf
from airflow.executors import workloads
+from airflow.executors.workloads.task import TaskInstanceDTO
from airflow.models.dag import DAG
from airflow.models.taskinstance import TaskInstance
from airflow.providers.common.compat.sdk import AirflowException,
AirflowTaskTimeout, TaskInstanceKey
@@ -197,7 +198,7 @@ def setup_dagrun_with_success_and_fail_tasks(dag_maker):
executor.start()
with start_worker(app=app, logfile=sys.stdout, loglevel="info"):
- ti_success = workloads.TaskInstance.model_construct(
+ ti_success = TaskInstanceDTO.model_construct(
id=uuid7(),
task_id="success",
dag_id="id",
@@ -257,7 +258,7 @@ def setup_dagrun_with_success_and_fail_tasks(dag_maker):
else:
ti = TaskInstance(task=task, run_id="abc")
workload = workloads.ExecuteTask.model_construct(
- ti=workloads.TaskInstance.model_validate(ti,
from_attributes=True),
+ ti=TaskInstanceDTO.model_validate(ti, from_attributes=True),
)
key = (task.dag.dag_id, task.task_id, ti.run_id, 0, -1)
@@ -309,7 +310,7 @@ def setup_dagrun_with_success_and_fail_tasks(dag_maker):
else:
ti = TaskInstance(task=task, run_id="abc")
workload = workloads.ExecuteTask.model_construct(
- ti=workloads.TaskInstance.model_validate(ti,
from_attributes=True),
+ ti=TaskInstanceDTO.model_validate(ti, from_attributes=True),
)
key = (task.dag.dag_id, task.task_id, ti.run_id, 0, -1)
diff --git
a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/executors/local_kubernetes_executor.py
b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/executors/local_kubernetes_executor.py
index 114da7ec36f..f2eb64e46c5 100644
---
a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/executors/local_kubernetes_executor.py
+++
b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/executors/local_kubernetes_executor.py
@@ -108,10 +108,10 @@ class LocalKubernetesExecutor(BaseExecutor):
def queued_tasks(self, value) -> None:
"""Not implemented for hybrid executors."""
- @property
+ @property # type: ignore[override]
def running(self) -> set[TaskInstanceKey]:
"""Return running tasks from local and kubernetes executor."""
- return
self.local_executor.running.union(self.kubernetes_executor.running)
+ return
self.local_executor.running.union(self.kubernetes_executor.running) # type:
ignore[return-value, arg-type]
@running.setter
def running(self, value) -> None:
@@ -219,7 +219,7 @@ class LocalKubernetesExecutor(BaseExecutor):
self.local_executor.heartbeat()
self.kubernetes_executor.heartbeat()
- def get_event_buffer(
+ def get_event_buffer( # type: ignore[override]
self, dag_ids: list[str] | None = None
) -> dict[TaskInstanceKey, EventBufferValueType]:
"""
@@ -231,7 +231,7 @@ class LocalKubernetesExecutor(BaseExecutor):
cleared_events_from_local =
self.local_executor.get_event_buffer(dag_ids)
cleared_events_from_kubernetes =
self.kubernetes_executor.get_event_buffer(dag_ids)
- return {**cleared_events_from_local, **cleared_events_from_kubernetes}
+ return {**cleared_events_from_local, **cleared_events_from_kubernetes}
# type: ignore[dict-item]
def try_adopt_task_instances(self, tis: Sequence[TaskInstance]) ->
Sequence[TaskInstance]:
"""
diff --git
a/providers/edge3/src/airflow/providers/edge3/worker_api/v2-edge-generated.yaml
b/providers/edge3/src/airflow/providers/edge3/worker_api/v2-edge-generated.yaml
index 293d8700ec7..94feded9ff9 100644
---
a/providers/edge3/src/airflow/providers/edge3/worker_api/v2-edge-generated.yaml
+++
b/providers/edge3/src/airflow/providers/edge3/worker_api/v2-edge-generated.yaml
@@ -1005,7 +1005,7 @@ components:
- type: 'null'
title: Log Path
ti:
- $ref: '#/components/schemas/TaskInstance'
+ $ref: '#/components/schemas/TaskInstanceDTO'
sentry_integration:
type: string
title: Sentry Integration
@@ -1151,7 +1151,7 @@ components:
- log_chunk_data
title: PushLogsBody
description: Incremental new log content from worker.
- TaskInstance:
+ TaskInstanceDTO:
properties:
id:
type: string
@@ -1209,7 +1209,7 @@ components:
- pool_slots
- queue
- priority_weight
- title: TaskInstance
+ title: TaskInstanceDTO
description: Schema for TaskInstance with minimal required fields needed
for
Executors and Task SDK.
TaskInstanceState:
diff --git a/task-sdk/src/airflow/sdk/definitions/deadline.py
b/task-sdk/src/airflow/sdk/definitions/deadline.py
index aeb1ff89010..8c55e10d45c 100644
--- a/task-sdk/src/airflow/sdk/definitions/deadline.py
+++ b/task-sdk/src/airflow/sdk/definitions/deadline.py
@@ -21,7 +21,7 @@ from datetime import datetime, timedelta
from typing import TYPE_CHECKING
from airflow.models.deadline import DeadlineReferenceType, ReferenceModels
-from airflow.sdk.definitions.callback import AsyncCallback, Callback
+from airflow.sdk.definitions.callback import AsyncCallback, Callback,
SyncCallback
if TYPE_CHECKING:
from collections.abc import Callable
@@ -44,7 +44,7 @@ class DeadlineAlert:
self.reference = reference
self.interval = interval
- if not isinstance(callback, AsyncCallback):
+ if not isinstance(callback, (AsyncCallback, SyncCallback)):
raise ValueError(f"Callbacks of type {type(callback).__name__} are
not currently supported")
self.callback = callback
diff --git a/task-sdk/tests/task_sdk/definitions/test_deadline.py
b/task-sdk/tests/task_sdk/definitions/test_deadline.py
index 1025cfc27a3..8e9e816b307 100644
--- a/task-sdk/tests/task_sdk/definitions/test_deadline.py
+++ b/task-sdk/tests/task_sdk/definitions/test_deadline.py
@@ -138,10 +138,27 @@ class TestDeadlineAlert:
alert_set = {alert1, alert2}
assert len(alert_set) == 1
- def test_deadline_alert_unsupported_callback(self):
- with pytest.raises(ValueError, match="Callbacks of type SyncCallback
are not currently supported"):
+ @pytest.mark.parametrize(
+ ("callback_class"),
+ [
+ pytest.param(AsyncCallback, id="async_callback"),
+ pytest.param(SyncCallback, id="sync_callback"),
+ ],
+ )
+ def test_deadline_alert_accepts_all_callbacks(self, callback_class):
+ alert = DeadlineAlert(
+ reference=DeadlineReference.DAGRUN_QUEUED_AT,
+ interval=timedelta(hours=1),
+ callback=callback_class(TEST_CALLBACK_PATH),
+ )
+ assert alert.callback is not None
+ assert isinstance(alert.callback, callback_class)
+
+ def test_deadline_alert_rejects_invalid_callback(self):
+ """Test that DeadlineAlert rejects non-callback types."""
+ with pytest.raises(ValueError, match="Callbacks of type str are not
currently supported"):
DeadlineAlert(
reference=DeadlineReference.DAGRUN_QUEUED_AT,
interval=timedelta(hours=1),
- callback=SyncCallback(TEST_CALLBACK_PATH),
+ callback="not_a_callback", # type: ignore
)