This is an automated email from the ASF dual-hosted git repository.

ephraimanierobi pushed a commit to branch v3-1-test
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit 31cdf87d1e15206b04581b03c3e5675afa1ebd52
Author: Wei Lee <[email protected]>
AuthorDate: Tue Nov 4 14:51:25 2025 +0800

    [v3-1-test] Simplify typing in TriggerRuleDep (#57733) (#57779)
    
    Co-authored-by: Tzu-ping Chung <[email protected]>
---
 .../src/airflow/ti_deps/deps/trigger_rule_dep.py   | 68 ++++++++++------------
 1 file changed, 31 insertions(+), 37 deletions(-)

diff --git a/airflow-core/src/airflow/ti_deps/deps/trigger_rule_dep.py b/airflow-core/src/airflow/ti_deps/deps/trigger_rule_dep.py
index 971b156a067..0aad55a24dd 100644
--- a/airflow-core/src/airflow/ti_deps/deps/trigger_rule_dep.py
+++ b/airflow-core/src/airflow/ti_deps/deps/trigger_rule_dep.py
@@ -20,7 +20,7 @@ from __future__ import annotations
 import collections.abc
 import functools
 from collections import Counter
-from collections.abc import Iterator, KeysView
+from collections.abc import Iterator, KeysView, Mapping
 from typing import TYPE_CHECKING, NamedTuple
 
 from sqlalchemy import and_, func, or_, select
@@ -34,8 +34,10 @@ if TYPE_CHECKING:
     from sqlalchemy.orm import Session
     from sqlalchemy.sql.expression import ColumnOperators
 
+    from airflow.models.mappedoperator import MappedOperator
     from airflow.models.taskinstance import TaskInstance
     from airflow.serialization.definitions.taskgroup import SerializedMappedTaskGroup
+    from airflow.serialization.serialized_objects import SerializedBaseOperator
     from airflow.ti_deps.dep_context import DepContext
     from airflow.ti_deps.deps.base_ti_dep import TIDepStatus
 
@@ -131,6 +133,10 @@ class TriggerRuleDep(BaseTIDep):
         from airflow.models.mappedoperator import is_mapped
         from airflow.models.taskinstance import TaskInstance
 
+        task = ti.task
+        if TYPE_CHECKING:
+            assert task
+
         @functools.lru_cache
         def _get_expanded_ti_count() -> int:
             """
@@ -141,13 +147,10 @@ class TriggerRuleDep(BaseTIDep):
             """
             from airflow.models.mappedoperator import get_mapped_ti_count
 
-            if TYPE_CHECKING:
-                assert ti.task
-
-            return get_mapped_ti_count(ti.task, ti.run_id, session=session)
+            return get_mapped_ti_count(task, ti.run_id, session=session)
 
-        def _iter_expansion_dependencies(task_group: SerializedMappedTaskGroup) -> Iterator[str]:
-            if (task := ti.task) is not None and is_mapped(task):
+        def _iter_expansion_dependencies(task_group: SerializedMappedTaskGroup | None) -> Iterator[str]:
+            if is_mapped(task):
                 for op in task.iter_mapped_dependencies():
                     yield op.task_id
             if task_group and task_group.iter_mapped_task_groups():
@@ -167,14 +170,13 @@ class TriggerRuleDep(BaseTIDep):
             task instance of the same task).
             """
             if TYPE_CHECKING:
-                assert ti.task
-                assert ti.task.dag
-                assert ti.task.task_group
+                assert task.dag
+                assert task.task_group
 
-            if is_mapped(ti.task.task_group):
-                is_fast_triggered = ti.task.trigger_rule in (TR.ONE_SUCCESS, TR.ONE_FAILED, TR.ONE_DONE)
+            if is_mapped(task.task_group):
+                is_fast_triggered = task.trigger_rule in (TR.ONE_SUCCESS, TR.ONE_FAILED, TR.ONE_DONE)
                 if is_fast_triggered and upstream_id not in set(
-                    _iter_expansion_dependencies(task_group=ti.task.task_group)
+                    _iter_expansion_dependencies(task_group=task.task_group)
                 ):
                     return None
 
@@ -183,7 +185,7 @@ class TriggerRuleDep(BaseTIDep):
             except (NotFullyPopulated, NotMapped):
                 return None
             return ti.get_relevant_upstream_map_indexes(
-                upstream=ti.task.dag.task_dict[upstream_id],
+                upstream=task.dag.task_dict[upstream_id],
                 ti_count=expanded_ti_count,
                 session=session,
             )
@@ -198,15 +200,12 @@ class TriggerRuleDep(BaseTIDep):
                 2. ti is in a mapped task group and upstream has a map index
                   that ti does not depend on.
             """
-            if TYPE_CHECKING:
-                assert ti.task
-
             # Not actually an upstream task.
             if upstream.task_id not in relevant_ids:
                 return False
             # The current task is not in a mapped task group. All tis from an
             # upstream task are relevant.
-            if ti.task.get_closest_mapped_task_group() is None:
+            if task.get_closest_mapped_task_group() is None:
                 return True
             # The upstream ti is not expanded. The upstream may be mapped or
             # not, but the ti is relevant either way.
@@ -228,10 +227,7 @@ class TriggerRuleDep(BaseTIDep):
             # it depends on all upstream task instances.
             from airflow.models.taskinstance import TaskInstance
 
-            if TYPE_CHECKING:
-                assert ti.task
-
-            if ti.task.get_closest_mapped_task_group() is None:
+            if task.get_closest_mapped_task_group() is None:
                 yield TaskInstance.task_id.in_(relevant_tasks.keys())
                 return
             # Otherwise we need to figure out which map indexes are depended on
@@ -257,16 +253,16 @@ class TriggerRuleDep(BaseTIDep):
                 else:
                     yield and_(TaskInstance.task_id == upstream_id, TaskInstance.map_index == map_indexes)
 
-        def _evaluate_setup_constraint(*, relevant_setups) -> Iterator[tuple[TIDepStatus, bool]]:
+        def _evaluate_setup_constraint(
+            *, relevant_setups: Mapping[str, SerializedBaseOperator | MappedOperator]
+        ) -> Iterator[tuple[TIDepStatus, bool]]:
             """
             Evaluate whether ``ti``'s trigger rule was met as part of the setup constraint.
 
             :param relevant_setups: Relevant setups for the current task instance.
             """
-            if TYPE_CHECKING:
-                assert ti.task
-
-            task = ti.task
+            if not relevant_setups:
+                return
 
             indirect_setups = {k: v for k, v in relevant_setups.items() if k not in task.upstream_task_ids}
             finished_upstream_tis = (
@@ -348,10 +344,6 @@ class TriggerRuleDep(BaseTIDep):
 
         def _evaluate_direct_relatives() -> Iterator[TIDepStatus]:
             """Evaluate whether ``ti``'s trigger rule in direct relatives was met."""
-            if TYPE_CHECKING:
-                assert ti.task
-
-            task = ti.task
             upstream_tasks = {t.task_id: t for t in task.upstream_list}
             trigger_rule = task.trigger_rule
             trigger_rule_str = getattr(trigger_rule, "value", trigger_rule)
@@ -359,7 +351,7 @@ class TriggerRuleDep(BaseTIDep):
             finished_upstream_tis = (
                 finished_ti
                 for finished_ti in dep_context.ensure_finished_tis(ti.get_dagrun(session), session)
-                if _is_relevant_upstream(upstream=finished_ti, relevant_ids=ti.task.upstream_task_ids)
+                if _is_relevant_upstream(upstream=finished_ti, relevant_ids=task.upstream_task_ids)
             )
             upstream_states = _UpstreamTIStates.calculate(finished_upstream_tis)
 
@@ -624,12 +616,14 @@ class TriggerRuleDep(BaseTIDep):
                     reason=f"No strategy to evaluate trigger rule '{trigger_rule_str}'."
                 )
 
-        if TYPE_CHECKING:
-            assert ti.task
-
-        if not ti.task.is_teardown:
+        if not task.is_teardown:
             # a teardown cannot have any indirect setups
-            relevant_setups = {t.task_id: t for t in ti.task.get_upstreams_only_setups()}
+            relevant_setups: dict[str, MappedOperator | SerializedBaseOperator] = {
+                # TODO (GH-52141): This should return scheduler types, but
+                # currently we reuse logic in SDK DAGNode.
+                t.task_id: t  # type: ignore[misc]
+                for t in task.get_upstreams_only_setups()
+            }
             if relevant_setups:
                 for status, changed in _evaluate_setup_constraint(relevant_setups=relevant_setups):
                     yield status

Reply via email to