This is an automated email from the ASF dual-hosted git repository.
dheerajturaga pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new da57d5ecf81 Add regression tests for task_instance_mutation_hook under
scheduler commit guard (#67980)
da57d5ecf81 is described below
commit da57d5ecf81e43e0f5b89e08a8289de0fbb3b9ac
Author: Dheeraj Turaga <[email protected]>
AuthorDate: Sun Jun 7 21:20:24 2026 -0500
Add regression tests for task_instance_mutation_hook under scheduler commit
guard (#67980)
* Add regression tests for task_instance_mutation_hook under scheduler
commit guard
Pin the behavior of task_instance_mutation_hook during mapped-task
expansion while the scheduler's prohibit_commit guard is active — the
exact path that crashes the scheduler when a hook opens a nested
committing session (e.g. a no-arg get_dagrun()).
Adds three tests alongside TestDagRun in test_dagrun.py:
- a naive DB-touching hook raises UNEXPECTED COMMIT under the guard,
- a session-reusing hook survives the guard and routes queue from
DagRun.conf on every expanded mapped TI,
- a deterministic hook called repeatedly per TI yields a stable result.
These close gaps where existing tests only asserted the hook was
invoked, never that its mutation survived mapped expansion or that the
guard was active on that path.
* Pin dag_run resolution contract for conf-routing mutation hooks
Add three tests that lock down the properties a non-committing,
DagRun-reading task_instance_mutation_hook depends on across upgrades:
- a freshly-built mapped TaskInstance exposes dag_run as loaded-None
(attribute access returns None without a DB hit or DetachedInstanceError),
- a committing hook crashes on the real DagRun.verify_integrity path under
prohibit_commit, proving the guard is live on the production method and
not just the hand-rolled expand_mapped_task wrapper,
- an attribute-access hook that resolves the DagRun via ti.dag_run (never
opening a committing session) survives verify_integrity under the guard
and leaves expansion intact.
These close gaps left by the earlier expand_mapped_task tests, which used
an explicit SELECT rather than attribute access and never drove the real
guarded scheduler method.
---
airflow-core/tests/unit/models/test_dagrun.py | 257 ++++++++++++++++++++++++++
1 file changed, 257 insertions(+)
diff --git a/airflow-core/tests/unit/models/test_dagrun.py
b/airflow-core/tests/unit/models/test_dagrun.py
index 1154a19617e..4ceefc30dfd 100644
--- a/airflow-core/tests/unit/models/test_dagrun.py
+++ b/airflow-core/tests/unit/models/test_dagrun.py
@@ -20,6 +20,7 @@ from __future__ import annotations
import datetime
from collections import defaultdict
from collections.abc import Mapping
+from contextlib import contextmanager
from functools import reduce
from typing import TYPE_CHECKING
from unittest import mock
@@ -35,6 +36,7 @@ from opentelemetry.trace import StatusCode
from opentelemetry.trace.propagation.tracecontext import
TraceContextTextMapPropagator
from sqlalchemy import (
func,
+ inspect as sa_inspect,
select,
update,
)
@@ -70,12 +72,14 @@ from airflow.settings import get_policy_plugin_manager
from airflow.task.trigger_rule import TriggerRule
from airflow.triggers.base import StartTriggerArgs
from airflow.utils.session import create_session
+from airflow.utils.sqlalchemy import prohibit_commit
from airflow.utils.state import DagRunState, State, TaskInstanceState
from airflow.utils.types import DagRunTriggeredByType, DagRunType
from tests_common.test_utils import db
from tests_common.test_utils.config import conf_vars
from tests_common.test_utils.dag import sync_dag_to_db
+from tests_common.test_utils.mapping import expand_mapped_task
from tests_common.test_utils.mock_operators import MockOperator
from tests_common.test_utils.taskinstance import create_task_instance,
run_task_instance
from unit.models import DEFAULT_DATE as _DEFAULT_DATE
@@ -1609,6 +1613,259 @@ def
test_expand_mapped_task_instance_task_decorator(is_noop, dag_maker, session)
assert indices == [0, 1, 2, 3]
+def _make_mapped_dag_for_expansion(dag_maker, session, *, dag_id, conf=None):
+ """Build a DAG whose mapped task expands from an upstream task's output
and create its DagRun.
+
+ Returns (dag, mapped, dr). Expansion is not performed here -- callers
drive it explicitly
+ via expand_mapped_task so they can wrap it in a prohibit_commit guard,
mirroring how the
+ scheduler expands mapped tasks inside _schedule_all_dag_runs.
+ """
+ with dag_maker(dag_id=dag_id, session=session, serialized=True) as dag:
+ upstream = BaseOperator(task_id="op1")
+ mapped =
MockOperator.partial(task_id="task_2").expand(arg2=upstream.output)
+ dr = dag_maker.create_dagrun(conf=conf or {})
+ return dag, mapped, dr
+
+
+@contextmanager
+def _registered_mutation_hook(hook):
+ """Register hook as the real task_instance_mutation_hook on the policy
plugin manager.
+
+ Patching at the plugin-manager level (rather than airflow.settings)
ensures both call sites
+ see it: TaskMap.expand_mapped_task resolves the wrapper lazily, while
refresh_from_task
+ holds a module-level reference bound at import time.
+ """
+ with mock.patch.object(
+ get_policy_plugin_manager().hook, "task_instance_mutation_hook",
autospec=True
+ ) as mock_hook:
+ mock_hook.side_effect = hook
+ yield mock_hook
+
+
+def
test_mutation_hook_committing_session_crashes_under_prohibit_commit(dag_maker,
session):
+ """A mutation hook that opens a nested committing session crashes mapped
expansion under the guard.
+
+ This pins the exact scheduler crash path: during mapped-task expansion
(TaskMap.expand_mapped_task)
+ the hook is invoked while the outer session is wrapped in prohibit_commit.
A hook that calls the
+ @provide_session-decorated TaskInstance.get_dagrun() with no session
argument reuses the
+ guarded scoped session; the create_session() context manager then commits
on exit, tripping the
+ before_commit guard with RuntimeError("UNEXPECTED COMMIT - THIS WILL BREAK
HA LOCKS!").
+ """
+
+ def naive_hook(task_instance):
+ # Resolves the DagRun the unsafe way -- the no-arg @provide_session call
reuses the guarded scoped session and commits it on exit.
+ task_instance.get_dagrun()
+
+ dag, mapped, dr = _make_mapped_dag_for_expansion(
+ dag_maker, session, dag_id="test_mutation_hook_committing_session"
+ )
+
+ with _registered_mutation_hook(naive_hook):
+ with prohibit_commit(session):
+ # The guard fires inside expand_mapped_task when the hook's nested
create_session()
+ # commits on exit -- expansion never completes, so there is
nothing to guard.commit().
+ with pytest.raises(RuntimeError, match="UNEXPECTED COMMIT"):
+ expand_mapped_task(dag.task_dict[mapped.task_id], dr.run_id,
"op1", length=3, session=session)
+
+
+def
test_mutation_hook_safe_session_reuse_routes_mapped_tis_under_prohibit_commit(dag_maker,
session):
+ """A session-reusing mutation hook survives mapped expansion under the
guard and routes every TI.
+
+ Positive counterpart to the crash test. During expansion the hook is
invoked both on transient,
+ session-less TaskInstance objects and (after session.merge) on instances
attached to the
+ guarded outer session. A safe hook reads DagRun.conf only through the
attached session -- a
+ plain SELECT that reuses the outer transaction, never opening a committing
create_session()
+ -- so the prohibit_commit guard never fires. The routed queue is then
asserted to persist on
+ every expanded mapped TI, closing the gap where existing tests only check
that the hook was called.
+
+ Note: task_instance.dag_run is not usable here -- a freshly-built mapped
TI has dag_run
+ marked "loaded as None" (see TaskInstance.get_dagrun), so the relationship
never lazy-loads.
+ Resolving via the attached session is the discipline a real conf-routing
hook must follow.
+ """
+
+ def safe_hook(task_instance):
+ attached_session = sa_inspect(task_instance).session
+ if attached_session is None:
+ # Transient instance (pre-merge); it will be re-invoked once
attached. Nothing safe to do.
+ return
+ dag_run = attached_session.scalar(
+ select(DagRun).where(DagRun.dag_id == task_instance.dag_id,
DagRun.run_id == task_instance.run_id)
+ )
+ if dag_run is not None and dag_run.conf.get("route") == "high":
+ task_instance.queue = "high_queue"
+
+ dag, mapped, dr = _make_mapped_dag_for_expansion(
+ dag_maker, session, dag_id="test_mutation_hook_safe_routing",
conf={"route": "high"}
+ )
+
+ with _registered_mutation_hook(safe_hook):
+ with prohibit_commit(session) as guard:
+ expand_mapped_task(dag.task_dict[mapped.task_id], dr.run_id,
"op1", length=3, session=session)
+ guard.commit()
+
+ queues = session.scalars(
+ select(TI.queue)
+ .where(TI.task_id == mapped.task_id, TI.dag_id == mapped.dag_id,
TI.run_id == dr.run_id)
+ .order_by(TI.map_index)
+ ).all()
+ assert queues == ["high_queue", "high_queue", "high_queue"]
+
+
+def
test_mutation_hook_deterministic_across_repeated_invocation_during_expansion(dag_maker,
session):
+ """A mutation hook may be invoked more than once per TI during expansion;
the result must be stable.
+
+ TaskMap.expand_mapped_task invokes the hook on the transient TI and again
via refresh_from_task
+ after session.merge, so a given mapped index is mutated multiple times.
This asserts both that the
+ re-invocation really happens (at least one index sees >1 call) and that a
deterministic hook -- one that
+ sets queue as a pure function of TI identity -- yields the same persisted
value regardless of how
+ many times it ran.
+ """
+ call_counts: dict[int, int] = defaultdict(int)
+
+ def deterministic_hook(task_instance):
+ call_counts[task_instance.map_index] += 1
+ task_instance.queue = f"q_{task_instance.map_index}"
+
+ dag, mapped, dr = _make_mapped_dag_for_expansion(
+ dag_maker, session, dag_id="test_mutation_hook_deterministic"
+ )
+
+ with _registered_mutation_hook(deterministic_hook):
+ with prohibit_commit(session) as guard:
+ expand_mapped_task(dag.task_dict[mapped.task_id], dr.run_id,
"op1", length=3, session=session)
+ guard.commit()
+
+ # Re-invocation is real: at least one mapped index was mutated more than
once.
+ assert max(call_counts.values()) > 1
+ # Despite repeated invocation, the deterministic hook leaves each TI with
a stable, identity-derived queue.
+ rows = session.execute(
+ select(TI.map_index, TI.queue)
+ .where(TI.task_id == mapped.task_id, TI.dag_id == mapped.dag_id,
TI.run_id == dr.run_id)
+ .order_by(TI.map_index)
+ ).all()
+ assert rows == [(0, "q_0"), (1, "q_1"), (2, "q_2")]
+
+
+def _make_literal_mapped_dagrun(dag_maker, session, *, dag_id, conf=None):
+ """Build a literal-mapped DAG and its running DagRun, returning (dr,
dag_version_id).
+
+ Unlike _make_mapped_dag_for_expansion (which leaves an xcom-mapped task
unexpanded so callers
+ can drive TaskMap.expand_mapped_task by hand), this builds a literal
.expand([...]) so that
+ create_dagrun materializes the mapped TIs immediately. Callers can then
re-invoke the mutation
+ hook on those persisted TIs by calling dr.verify_integrity(...) -- the
real scheduler method --
+ inside their own prohibit_commit guard.
+ """
+ with dag_maker(dag_id=dag_id, session=session, serialized=True):
+
+ @task
+ def mapped_task(arg):
+ return arg
+
+ mapped_task.expand(arg=[1, 2, 3])
+ dr = dag_maker.create_dagrun(conf=conf or {})
+ dag_version_id = DagVersion.get_latest_version(dag_id=dr.dag_id,
session=session).id
+ return dr, dag_version_id
+
+
+def test_freshly_built_mapped_ti_exposes_dag_run_as_loaded_none(dag_maker,
session):
+ """A freshly-built mapped TaskInstance exposes dag_run as loaded-None, not
a lazy-load or raise.
+
+ TaskMap.expand_mapped_task constructs each expanded TI with
TaskInstance(task, run_id=..., ...)
+ and invokes the mutation hook on it before it is merged into a session. A
conf-routing hook that
+ resolves the DagRun by attribute access (the _resolve_dagrun discipline)
relies on ti.dag_run
+ returning None here -- without hitting the DB and without raising
DetachedInstanceError -- so it
+ can fall through to a non-DB resolution path. This pins that lifecycle
fact (documented on
+ TaskInstance.get_dagrun); if a future change made the relationship
eager-load or raise on a
+ transient instance, the workaround would silently change behavior and this
test would flag it.
+ """
+ with dag_maker(dag_id="test_loaded_none_canary", session=session,
serialized=True) as dag:
+ EmptyOperator(task_id="solo")
+ dr = dag_maker.create_dagrun()
+
+ task = dag.task_dict["solo"]
+ ti = TI(
+ task,
+ run_id=dr.run_id,
+ map_index=0,
+ dag_version_id=DagVersion.get_latest_version(dag_id=dr.dag_id,
session=session).id,
+ )
+
+ # Transient: not attached to any session, so a real _resolve_dagrun cannot
lazy-load via it.
+ assert sa_inspect(ti).session is None
+ # Attribute access returns None rather than raising or emitting a query --
the loaded-None state.
+ assert ti.dag_run is None
+
+
+def
test_naive_committing_hook_crashes_on_verify_integrity_under_guard(dag_maker,
session):
+ """A committing hook crashes on the real verify_integrity path under the
guard, not just the helper.
+
+ The expand_mapped_task tests above hand-roll the prohibit_commit
wrapper. This drives the
+ actual scheduler method dr.verify_integrity(...) -- which re-invokes the
mutation hook on every
+ task instance -- inside the guard, proving the guard is genuinely live on
the production path. A
+ hook calling the @provide_session-decorated get_dagrun() with no session
reuses the guarded scoped
+ session; create_session() commits on exit and trips the before_commit
guard.
+ """
+
+ def naive_hook(task_instance):
+ task_instance.get_dagrun()
+
+ dr, dag_version_id = _make_literal_mapped_dagrun(
+ dag_maker, session, dag_id="test_verify_integrity_naive_hook"
+ )
+
+ with _registered_mutation_hook(naive_hook):
+ with prohibit_commit(session):
+ with pytest.raises(RuntimeError, match="UNEXPECTED COMMIT"):
+ dr.verify_integrity(dag_version_id=dag_version_id,
session=session)
+
+
+def
test_resolve_dagrun_attribute_access_is_safe_on_verify_integrity_under_guard(dag_maker,
session):
+ """An attribute-access conf-routing hook survives the real
verify_integrity path under the guard.
+
+ Mirrors the _resolve_dagrun mechanism a conf-routing cluster policy uses:
inspect whether dag_run
+ is loaded and whether the TI is attached, then resolve only via ti.dag_run
attribute access --
+ never opening a committing create_session(). Driven through
dr.verify_integrity(...) under a real
+ prohibit_commit guard, this pins that the attribute-access path is
non-committing and non-raising
+ on the production method, which the earlier safe_hook test (explicit
SELECT on the attached
+ session) does not exercise.
+
+ The if/elif branches deliberately mirror the real (out-of-repo)
_resolve_dagrun so they act as a
+ regression anchor: if a future change to ti.dag_run's lazy-load semantics
broke the eager-loaded
+ or attached path, attribute access here would raise and the test would
catch it.
+ """
+ seen_map_indices = []
+
+ def resolve_dagrun_like_hook(task_instance):
+ state = sa_inspect(task_instance)
+ if "dag_run" not in state.unloaded:
+ _ = task_instance.dag_run # eager-loaded: cheap attribute read
+ elif state.session is not None:
+ _ = task_instance.dag_run # attached: lazy-load via the outer
session, no new session
+ # else: transient -> a real _resolve_dagrun walks the stack; nothing
to do here.
+ seen_map_indices.append(task_instance.map_index)
+
+ dr, dag_version_id = _make_literal_mapped_dagrun(
+ dag_maker, session, dag_id="test_verify_integrity_resolve_dagrun",
conf={"route": "high"}
+ )
+
+ with _registered_mutation_hook(resolve_dagrun_like_hook):
+ with prohibit_commit(session) as guard:
+ dr.verify_integrity(dag_version_id=dag_version_id, session=session)
+ guard.commit()
+
+ # The attribute-access hook ran (without raising or committing) on every
persisted mapped TI:
+ # verify_integrity re-invokes it via _check_for_removed_or_restored_tasks,
and the expanded TIs
+ # all survive. Any raise inside the hook would have propagated out of
verify_integrity and failed
+ # the test directly -- no flag needed.
+ assert sorted(seen_map_indices) == [0, 1, 2]
+ indices = session.scalars(
+ select(TI.map_index)
+ .where(TI.task_id == "mapped_task", TI.dag_id == dr.dag_id, TI.run_id
== dr.run_id)
+ .order_by(TI.map_index)
+ ).all()
+ assert indices == [0, 1, 2]
+
+
def test_verify_integrity_handles_stale_data_error(dag_maker, session):
"""Test that StaleDataError during _create_task_instances is caught and
session is rolled back."""
with dag_maker("test_stale_data_error_dag", session=session) as dag: