This is an automated email from the ASF dual-hosted git repository.

dheerajturaga pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new da57d5ecf81 Add regression tests for task_instance_mutation_hook under scheduler commit guard (#67980)
da57d5ecf81 is described below

commit da57d5ecf81e43e0f5b89e08a8289de0fbb3b9ac
Author: Dheeraj Turaga <[email protected]>
AuthorDate: Sun Jun 7 21:20:24 2026 -0500

    Add regression tests for task_instance_mutation_hook under scheduler commit guard (#67980)
    
    * Add regression tests for task_instance_mutation_hook under scheduler commit guard
    
    Pin the behavior of task_instance_mutation_hook during mapped-task
    expansion while the scheduler's prohibit_commit guard is active — the
    exact path that crashes the scheduler when a hook opens a nested
    committing session (e.g. a no-arg get_dagrun()).
    
    Adds three module-level tests next to TestDagRun in test_dagrun.py:
    - a naive DB-touching hook raises UNEXPECTED COMMIT under the guard,
    - a session-reusing hook survives the guard and routes queue from
      DagRun.conf on every expanded mapped TI,
    - a deterministic hook called repeatedly per TI yields a stable result.
    
    These close gaps where existing tests only asserted the hook was
    invoked, never that its mutation survived mapped expansion or that the
    guard was active on that path.
    
    * Pin dag_run resolution contract for conf-routing mutation hooks
    
    Add three tests that lock down the properties a non-committing,
    DagRun-reading task_instance_mutation_hook depends on across upgrades:
    
    - a freshly-built mapped TaskInstance exposes dag_run as loaded-None
      (attribute access returns None without a DB hit or DetachedInstanceError),
    - a committing hook crashes on the real DagRun.verify_integrity path under
      prohibit_commit, proving the guard is live on the production method and
      not just the hand-rolled expand_mapped_task wrapper,
    - an attribute-access hook that resolves the DagRun via ti.dag_run (never
      opening a committing session) survives verify_integrity under the guard
      and leaves expansion intact.
    
    These close gaps left by the earlier expand_mapped_task tests, which used
    an explicit SELECT rather than attribute access and never drove the real
    guarded scheduler method.
---
 airflow-core/tests/unit/models/test_dagrun.py | 257 ++++++++++++++++++++++++++
 1 file changed, 257 insertions(+)

diff --git a/airflow-core/tests/unit/models/test_dagrun.py 
b/airflow-core/tests/unit/models/test_dagrun.py
index 1154a19617e..4ceefc30dfd 100644
--- a/airflow-core/tests/unit/models/test_dagrun.py
+++ b/airflow-core/tests/unit/models/test_dagrun.py
@@ -20,6 +20,7 @@ from __future__ import annotations
 import datetime
 from collections import defaultdict
 from collections.abc import Mapping
+from contextlib import contextmanager
 from functools import reduce
 from typing import TYPE_CHECKING
 from unittest import mock
@@ -35,6 +36,7 @@ from opentelemetry.trace import StatusCode
 from opentelemetry.trace.propagation.tracecontext import 
TraceContextTextMapPropagator
 from sqlalchemy import (
     func,
+    inspect as sa_inspect,
     select,
     update,
 )
@@ -70,12 +72,14 @@ from airflow.settings import get_policy_plugin_manager
 from airflow.task.trigger_rule import TriggerRule
 from airflow.triggers.base import StartTriggerArgs
 from airflow.utils.session import create_session
+from airflow.utils.sqlalchemy import prohibit_commit
 from airflow.utils.state import DagRunState, State, TaskInstanceState
 from airflow.utils.types import DagRunTriggeredByType, DagRunType
 
 from tests_common.test_utils import db
 from tests_common.test_utils.config import conf_vars
 from tests_common.test_utils.dag import sync_dag_to_db
+from tests_common.test_utils.mapping import expand_mapped_task
 from tests_common.test_utils.mock_operators import MockOperator
 from tests_common.test_utils.taskinstance import create_task_instance, 
run_task_instance
 from unit.models import DEFAULT_DATE as _DEFAULT_DATE
@@ -1609,6 +1613,259 @@ def 
test_expand_mapped_task_instance_task_decorator(is_noop, dag_maker, session)
         assert indices == [0, 1, 2, 3]
 
 
+def _make_mapped_dag_for_expansion(dag_maker, session, *, dag_id, conf=None):
+    """Build a DAG whose mapped task expands from an upstream task's output 
and create its DagRun.
+
+    Returns (dag, mapped, dr). Expansion is not performed here -- callers 
drive it explicitly
+    via expand_mapped_task so they can wrap it in a prohibit_commit guard, 
mirroring how the
+    scheduler expands mapped tasks inside _schedule_all_dag_runs.
+    """
+    with dag_maker(dag_id=dag_id, session=session, serialized=True) as dag:
+        upstream = BaseOperator(task_id="op1")
+        mapped = 
MockOperator.partial(task_id="task_2").expand(arg2=upstream.output)
+    dr = dag_maker.create_dagrun(conf=conf or {})
+    return dag, mapped, dr
+
+
+@contextmanager
+def _registered_mutation_hook(hook):
+    """Register hook as the real task_instance_mutation_hook on the policy 
plugin manager.
+
+    Patching at the plugin-manager level (rather than airflow.settings) 
ensures both call sites
+    see it: TaskMap.expand_mapped_task resolves the wrapper lazily, while 
refresh_from_task
+    holds a module-level reference bound at import time.
+    """
+    with mock.patch.object(
+        get_policy_plugin_manager().hook, "task_instance_mutation_hook", 
autospec=True
+    ) as mock_hook:
+        mock_hook.side_effect = hook
+        yield mock_hook
+
+
+def 
test_mutation_hook_committing_session_crashes_under_prohibit_commit(dag_maker, 
session):
+    """A mutation hook that opens a nested committing session crashes mapped 
expansion under the guard.
+
+    This pins the exact scheduler crash path: during mapped-task expansion 
(TaskMap.expand_mapped_task)
+    the hook is invoked while the outer session is wrapped in prohibit_commit. 
A hook that calls the
+    @provide_session-decorated TaskInstance.get_dagrun() with no session 
argument reuses the
+    guarded scoped session; the create_session() context manager then commits 
on exit, tripping the
+    before_commit guard with RuntimeError("UNEXPECTED COMMIT - THIS WILL BREAK 
HA LOCKS!").
+    """
+
+    def naive_hook(task_instance):
+        # Fetches the DagRun the unsafe way -- with no session passed, create_session() commits on exit.
+        task_instance.get_dagrun()
+
+    dag, mapped, dr = _make_mapped_dag_for_expansion(
+        dag_maker, session, dag_id="test_mutation_hook_committing_session"
+    )
+
+    with _registered_mutation_hook(naive_hook):
+        with prohibit_commit(session):
+            # The guard fires inside expand_mapped_task when the hook's nested 
create_session()
+            # commits on exit -- expansion never completes, so there is 
nothing to guard.commit().
+            with pytest.raises(RuntimeError, match="UNEXPECTED COMMIT"):
+                expand_mapped_task(dag.task_dict[mapped.task_id], dr.run_id, 
"op1", length=3, session=session)
+
+
+def 
test_mutation_hook_safe_session_reuse_routes_mapped_tis_under_prohibit_commit(dag_maker,
 session):
+    """A session-reusing mutation hook survives mapped expansion under the 
guard and routes every TI.
+
+    Positive counterpart to the crash test. During expansion the hook is 
invoked both on transient,
+    session-less TaskInstance objects and (after session.merge) on instances 
attached to the
+    guarded outer session. A safe hook reads DagRun.conf only through the 
attached session -- a
+    plain SELECT that reuses the outer transaction, never opening a committing 
create_session()
+    -- so the prohibit_commit guard never fires. The routed queue is then 
asserted to persist on
+    every expanded mapped TI, closing the gap where existing tests only check 
that the hook was called.
+
+    Note: task_instance.dag_run is not usable here -- a freshly-built mapped 
TI has dag_run
+    marked "loaded as None" (see TaskInstance.get_dagrun), so the relationship 
never lazy-loads.
+    Resolving via the attached session is the discipline a real conf-routing 
hook must follow.
+    """
+
+    def safe_hook(task_instance):
+        attached_session = sa_inspect(task_instance).session
+        if attached_session is None:
+            # Transient instance (pre-merge); it will be re-invoked once 
attached. Nothing safe to do.
+            return
+        dag_run = attached_session.scalar(
+            select(DagRun).where(DagRun.dag_id == task_instance.dag_id, 
DagRun.run_id == task_instance.run_id)
+        )
+        if dag_run is not None and dag_run.conf.get("route") == "high":
+            task_instance.queue = "high_queue"
+
+    dag, mapped, dr = _make_mapped_dag_for_expansion(
+        dag_maker, session, dag_id="test_mutation_hook_safe_routing", 
conf={"route": "high"}
+    )
+
+    with _registered_mutation_hook(safe_hook):
+        with prohibit_commit(session) as guard:
+            expand_mapped_task(dag.task_dict[mapped.task_id], dr.run_id, 
"op1", length=3, session=session)
+            guard.commit()
+
+    queues = session.scalars(
+        select(TI.queue)
+        .where(TI.task_id == mapped.task_id, TI.dag_id == mapped.dag_id, 
TI.run_id == dr.run_id)
+        .order_by(TI.map_index)
+    ).all()
+    assert queues == ["high_queue", "high_queue", "high_queue"]
+
+
+def 
test_mutation_hook_deterministic_across_repeated_invocation_during_expansion(dag_maker,
 session):
+    """A mutation hook may be invoked more than once per TI during expansion; 
the result must be stable.
+
+    TaskMap.expand_mapped_task invokes the hook on the transient TI and again 
via refresh_from_task
+    after session.merge, so a given mapped index is mutated multiple times. 
This asserts both that the
+    re-invocation really happens (at least one index sees >1 call) and that a 
deterministic hook -- one that
+    sets queue as a pure function of TI identity -- yields the same persisted 
value regardless of how
+    many times it ran.
+    """
+    call_counts: dict[int, int] = defaultdict(int)
+
+    def deterministic_hook(task_instance):
+        call_counts[task_instance.map_index] += 1
+        task_instance.queue = f"q_{task_instance.map_index}"
+
+    dag, mapped, dr = _make_mapped_dag_for_expansion(
+        dag_maker, session, dag_id="test_mutation_hook_deterministic"
+    )
+
+    with _registered_mutation_hook(deterministic_hook):
+        with prohibit_commit(session) as guard:
+            expand_mapped_task(dag.task_dict[mapped.task_id], dr.run_id, 
"op1", length=3, session=session)
+            guard.commit()
+
+    # Re-invocation is real: at least one mapped index was mutated more than 
once.
+    assert max(call_counts.values()) > 1
+    # Despite repeated invocation, the deterministic hook leaves each TI with 
a stable, identity-derived queue.
+    rows = session.execute(
+        select(TI.map_index, TI.queue)
+        .where(TI.task_id == mapped.task_id, TI.dag_id == mapped.dag_id, 
TI.run_id == dr.run_id)
+        .order_by(TI.map_index)
+    ).all()
+    assert rows == [(0, "q_0"), (1, "q_1"), (2, "q_2")]
+
+
+def _make_literal_mapped_dagrun(dag_maker, session, *, dag_id, conf=None):
+    """Build a literal-mapped DAG and its running DagRun, returning (dr, 
dag_version_id).
+
+    Unlike _make_mapped_dag_for_expansion (which leaves an xcom-mapped task 
unexpanded so callers
+    can drive TaskMap.expand_mapped_task by hand), this builds a literal 
.expand([...]) so that
+    create_dagrun materializes the mapped TIs immediately. Callers can then 
re-invoke the mutation
+    hook on those persisted TIs by calling dr.verify_integrity(...) -- the 
real scheduler method --
+    inside their own prohibit_commit guard.
+    """
+    with dag_maker(dag_id=dag_id, session=session, serialized=True):
+
+        @task
+        def mapped_task(arg):
+            return arg
+
+        mapped_task.expand(arg=[1, 2, 3])
+    dr = dag_maker.create_dagrun(conf=conf or {})
+    dag_version_id = DagVersion.get_latest_version(dag_id=dr.dag_id, 
session=session).id
+    return dr, dag_version_id
+
+
+def test_freshly_built_mapped_ti_exposes_dag_run_as_loaded_none(dag_maker, 
session):
+    """A freshly-built mapped TaskInstance exposes dag_run as loaded-None, not 
a lazy-load or raise.
+
+    TaskMap.expand_mapped_task constructs each expanded TI with 
TaskInstance(task, run_id=..., ...)
+    and invokes the mutation hook on it before it is merged into a session. A 
conf-routing hook that
+    resolves the DagRun by attribute access (the _resolve_dagrun discipline) 
relies on ti.dag_run
+    returning None here -- without hitting the DB and without raising 
DetachedInstanceError -- so it
+    can fall through to a non-DB resolution path. This pins that lifecycle 
fact (documented on
+    TaskInstance.get_dagrun); if a future change made the relationship 
eager-load or raise on a
+    transient instance, the workaround would silently change behavior and this 
test would flag it.
+    """
+    with dag_maker(dag_id="test_loaded_none_canary", session=session, 
serialized=True) as dag:
+        EmptyOperator(task_id="solo")
+    dr = dag_maker.create_dagrun()
+
+    task = dag.task_dict["solo"]
+    ti = TI(
+        task,
+        run_id=dr.run_id,
+        map_index=0,
+        dag_version_id=DagVersion.get_latest_version(dag_id=dr.dag_id, 
session=session).id,
+    )
+
+    # Transient: not attached to any session, so a real _resolve_dagrun cannot 
lazy-load via it.
+    assert sa_inspect(ti).session is None
+    # Attribute access returns None rather than raising or emitting a query -- 
the load-as-None state.
+    assert ti.dag_run is None
+
+
+def 
test_naive_committing_hook_crashes_on_verify_integrity_under_guard(dag_maker, 
session):
+    """A committing hook crashes on the real verify_integrity path under the 
guard, not just the helper.
+
+    The committed expand_mapped_task tests hand-roll the prohibit_commit 
wrapper. This drives the
+    actual scheduler method dr.verify_integrity(...) -- which re-invokes the 
mutation hook on every
+    task instance -- inside the guard, proving the guard is genuinely live on 
the production path. A
+    hook calling the @provide_session-decorated get_dagrun() with no session 
reuses the guarded scoped
+    session; create_session() commits on exit and trips the before_commit 
guard.
+    """
+
+    def naive_hook(task_instance):
+        task_instance.get_dagrun()
+
+    dr, dag_version_id = _make_literal_mapped_dagrun(
+        dag_maker, session, dag_id="test_verify_integrity_naive_hook"
+    )
+
+    with _registered_mutation_hook(naive_hook):
+        with prohibit_commit(session):
+            with pytest.raises(RuntimeError, match="UNEXPECTED COMMIT"):
+                dr.verify_integrity(dag_version_id=dag_version_id, 
session=session)
+
+
+def 
test_resolve_dagrun_attribute_access_is_safe_on_verify_integrity_under_guard(dag_maker,
 session):
+    """An attribute-access conf-routing hook survives the real 
verify_integrity path under the guard.
+
+    Mirrors the _resolve_dagrun mechanism a conf-routing cluster policy uses: 
inspect whether dag_run
+    is loaded and whether the TI is attached, then resolve only via ti.dag_run 
attribute access --
+    never opening a committing create_session(). Driven through 
dr.verify_integrity(...) under a real
+    prohibit_commit guard, this pins that the attribute-access path is 
non-committing and non-raising
+    on the production method, which the committed safe_hook test (explicit 
SELECT on the attached
+    session) does not exercise.
+
+    The if/elif branches deliberately mirror the real (out-of-repo) 
_resolve_dagrun so they act as a
+    regression anchor: if a future change to ti.dag_run's lazy-load semantics 
broke the eager-loaded
+    or attached path, attribute access here would raise and the test would 
catch it.
+    """
+    seen_map_indices = []
+
+    def resolve_dagrun_like_hook(task_instance):
+        state = sa_inspect(task_instance)
+        if "dag_run" not in state.unloaded:
+            _ = task_instance.dag_run  # eager-loaded: cheap attribute read
+        elif state.session is not None:
+            _ = task_instance.dag_run  # attached: lazy-load via the outer 
session, no new session
+        # else: transient -> a real _resolve_dagrun walks the stack; nothing 
to do here.
+        seen_map_indices.append(task_instance.map_index)
+
+    dr, dag_version_id = _make_literal_mapped_dagrun(
+        dag_maker, session, dag_id="test_verify_integrity_resolve_dagrun", 
conf={"route": "high"}
+    )
+
+    with _registered_mutation_hook(resolve_dagrun_like_hook):
+        with prohibit_commit(session) as guard:
+            dr.verify_integrity(dag_version_id=dag_version_id, session=session)
+            guard.commit()
+
+    # The attribute-access hook ran (without raising or committing) on every 
persisted mapped TI:
+    # verify_integrity re-invokes it via _check_for_removed_or_restored_tasks, 
and the expanded TIs
+    # all survive. Any raise inside the hook would have propagated out of 
verify_integrity and failed
+    # the test directly -- no flag needed.
+    assert sorted(seen_map_indices) == [0, 1, 2]
+    indices = session.scalars(
+        select(TI.map_index)
+        .where(TI.task_id == "mapped_task", TI.dag_id == dr.dag_id, TI.run_id 
== dr.run_id)
+        .order_by(TI.map_index)
+    ).all()
+    assert indices == [0, 1, 2]
+
+
 def test_verify_integrity_handles_stale_data_error(dag_maker, session):
     """Test that StaleDataError during _create_task_instances is caught and 
session is rolled back."""
     with dag_maker("test_stale_data_error_dag", session=session) as dag:

Reply via email to