This is an automated email from the ASF dual-hosted git repository. ephraimanierobi pushed a commit to branch v3-1-test in repository https://gitbox.apache.org/repos/asf/airflow.git
commit 31cdf87d1e15206b04581b03c3e5675afa1ebd52 Author: Wei Lee <[email protected]> AuthorDate: Tue Nov 4 14:51:25 2025 +0800 [v3-1-test] Simplify typing in TriggerRuleDep (#57733) (#57779) Co-authored-by: Tzu-ping Chung <[email protected]> --- .../src/airflow/ti_deps/deps/trigger_rule_dep.py | 68 ++++++++++------------ 1 file changed, 31 insertions(+), 37 deletions(-) diff --git a/airflow-core/src/airflow/ti_deps/deps/trigger_rule_dep.py b/airflow-core/src/airflow/ti_deps/deps/trigger_rule_dep.py index 971b156a067..0aad55a24dd 100644 --- a/airflow-core/src/airflow/ti_deps/deps/trigger_rule_dep.py +++ b/airflow-core/src/airflow/ti_deps/deps/trigger_rule_dep.py @@ -20,7 +20,7 @@ from __future__ import annotations import collections.abc import functools from collections import Counter -from collections.abc import Iterator, KeysView +from collections.abc import Iterator, KeysView, Mapping from typing import TYPE_CHECKING, NamedTuple from sqlalchemy import and_, func, or_, select @@ -34,8 +34,10 @@ if TYPE_CHECKING: from sqlalchemy.orm import Session from sqlalchemy.sql.expression import ColumnOperators + from airflow.models.mappedoperator import MappedOperator from airflow.models.taskinstance import TaskInstance from airflow.serialization.definitions.taskgroup import SerializedMappedTaskGroup + from airflow.serialization.serialized_objects import SerializedBaseOperator from airflow.ti_deps.dep_context import DepContext from airflow.ti_deps.deps.base_ti_dep import TIDepStatus @@ -131,6 +133,10 @@ class TriggerRuleDep(BaseTIDep): from airflow.models.mappedoperator import is_mapped from airflow.models.taskinstance import TaskInstance + task = ti.task + if TYPE_CHECKING: + assert task + @functools.lru_cache def _get_expanded_ti_count() -> int: """ @@ -141,13 +147,10 @@ class TriggerRuleDep(BaseTIDep): """ from airflow.models.mappedoperator import get_mapped_ti_count - if TYPE_CHECKING: - assert ti.task - - return get_mapped_ti_count(ti.task, ti.run_id, session=session) + return get_mapped_ti_count(task, ti.run_id, session=session) - def _iter_expansion_dependencies(task_group: SerializedMappedTaskGroup) -> Iterator[str]: - if (task := ti.task) is not None and is_mapped(task): + def _iter_expansion_dependencies(task_group: SerializedMappedTaskGroup | None) -> Iterator[str]: + if is_mapped(task): for op in task.iter_mapped_dependencies(): yield op.task_id if task_group and task_group.iter_mapped_task_groups(): @@ -167,14 +170,13 @@ class TriggerRuleDep(BaseTIDep): task instance of the same task). """ if TYPE_CHECKING: - assert ti.task - assert ti.task.dag - assert ti.task.task_group + assert task.dag + assert task.task_group - if is_mapped(ti.task.task_group): - is_fast_triggered = ti.task.trigger_rule in (TR.ONE_SUCCESS, TR.ONE_FAILED, TR.ONE_DONE) + if is_mapped(task.task_group): + is_fast_triggered = task.trigger_rule in (TR.ONE_SUCCESS, TR.ONE_FAILED, TR.ONE_DONE) if is_fast_triggered and upstream_id not in set( - _iter_expansion_dependencies(task_group=ti.task.task_group) + _iter_expansion_dependencies(task_group=task.task_group) ): return None @@ -183,7 +185,7 @@ class TriggerRuleDep(BaseTIDep): except (NotFullyPopulated, NotMapped): return None return ti.get_relevant_upstream_map_indexes( - upstream=ti.task.dag.task_dict[upstream_id], + upstream=task.dag.task_dict[upstream_id], ti_count=expanded_ti_count, session=session, ) @@ -198,15 +200,12 @@ class TriggerRuleDep(BaseTIDep): 2. ti is in a mapped task group and upstream has a map index that ti does not depend on. """ - if TYPE_CHECKING: - assert ti.task - # Not actually an upstream task. if upstream.task_id not in relevant_ids: return False # The current task is not in a mapped task group. All tis from an # upstream task are relevant. - if ti.task.get_closest_mapped_task_group() is None: + if task.get_closest_mapped_task_group() is None: return True # The upstream ti is not expanded. The upstream may be mapped or # not, but the ti is relevant either way. @@ -228,10 +227,7 @@ class TriggerRuleDep(BaseTIDep): # it depends on all upstream task instances. from airflow.models.taskinstance import TaskInstance - if TYPE_CHECKING: - assert ti.task - - if ti.task.get_closest_mapped_task_group() is None: + if task.get_closest_mapped_task_group() is None: yield TaskInstance.task_id.in_(relevant_tasks.keys()) return # Otherwise we need to figure out which map indexes are depended on @@ -257,16 +253,16 @@ class TriggerRuleDep(BaseTIDep): else: yield and_(TaskInstance.task_id == upstream_id, TaskInstance.map_index == map_indexes) - def _evaluate_setup_constraint(*, relevant_setups) -> Iterator[tuple[TIDepStatus, bool]]: + def _evaluate_setup_constraint( + *, relevant_setups: Mapping[str, SerializedBaseOperator | MappedOperator] + ) -> Iterator[tuple[TIDepStatus, bool]]: """ Evaluate whether ``ti``'s trigger rule was met as part of the setup constraint. :param relevant_setups: Relevant setups for the current task instance. """ - if TYPE_CHECKING: - assert ti.task - - task = ti.task + if not relevant_setups: + return indirect_setups = {k: v for k, v in relevant_setups.items() if k not in task.upstream_task_ids} finished_upstream_tis = ( @@ -348,10 +344,6 @@ class TriggerRuleDep(BaseTIDep): def _evaluate_direct_relatives() -> Iterator[TIDepStatus]: """Evaluate whether ``ti``'s trigger rule in direct relatives was met.""" - if TYPE_CHECKING: - assert ti.task - - task = ti.task upstream_tasks = {t.task_id: t for t in task.upstream_list} trigger_rule = task.trigger_rule trigger_rule_str = getattr(trigger_rule, "value", trigger_rule) @@ -359,7 +351,7 @@ class TriggerRuleDep(BaseTIDep): finished_upstream_tis = ( finished_ti for finished_ti in dep_context.ensure_finished_tis(ti.get_dagrun(session), session) - if _is_relevant_upstream(upstream=finished_ti, relevant_ids=ti.task.upstream_task_ids) + if _is_relevant_upstream(upstream=finished_ti, relevant_ids=task.upstream_task_ids) ) upstream_states = _UpstreamTIStates.calculate(finished_upstream_tis) @@ -624,12 +616,14 @@ class TriggerRuleDep(BaseTIDep): reason=f"No strategy to evaluate trigger rule '{trigger_rule_str}'." ) - if TYPE_CHECKING: - assert ti.task - - if not ti.task.is_teardown: + if not task.is_teardown: # a teardown cannot have any indirect setups - relevant_setups = {t.task_id: t for t in ti.task.get_upstreams_only_setups()} + relevant_setups: dict[str, MappedOperator | SerializedBaseOperator] = { + # TODO (GH-52141): This should return scheduler types, but + # currently we reuse logic in SDK DAGNode. + t.task_id: t # type: ignore[misc] + for t in task.get_upstreams_only_setups() + } if relevant_setups: for status, changed in _evaluate_setup_constraint(relevant_setups=relevant_setups): yield status
