This is an automated email from the ASF dual-hosted git repository.
weilee pushed a commit to branch v3-1-test
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/v3-1-test by this push:
new eda82d6f7d2 [v3-1-test] Simplify typing in TriggerRuleDep (#57733) (#57779) (#57719)
eda82d6f7d2 is described below
commit eda82d6f7d267986a38481645262bd7f87e01949
Author: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
AuthorDate: Mon Jan 26 15:30:20 2026 +0800
[v3-1-test] Simplify typing in TriggerRuleDep (#57733) (#57779) (#57719)
Co-authored-by: Wei Lee <[email protected]>
Co-authored-by: Tzu-ping Chung <[email protected]>
---
.../src/airflow/ti_deps/deps/trigger_rule_dep.py | 68 ++++++++++------------
1 file changed, 31 insertions(+), 37 deletions(-)
diff --git a/airflow-core/src/airflow/ti_deps/deps/trigger_rule_dep.py b/airflow-core/src/airflow/ti_deps/deps/trigger_rule_dep.py
index 971b156a067..0aad55a24dd 100644
--- a/airflow-core/src/airflow/ti_deps/deps/trigger_rule_dep.py
+++ b/airflow-core/src/airflow/ti_deps/deps/trigger_rule_dep.py
@@ -20,7 +20,7 @@ from __future__ import annotations
import collections.abc
import functools
from collections import Counter
-from collections.abc import Iterator, KeysView
+from collections.abc import Iterator, KeysView, Mapping
from typing import TYPE_CHECKING, NamedTuple
from sqlalchemy import and_, func, or_, select
@@ -34,8 +34,10 @@ if TYPE_CHECKING:
from sqlalchemy.orm import Session
from sqlalchemy.sql.expression import ColumnOperators
+ from airflow.models.mappedoperator import MappedOperator
from airflow.models.taskinstance import TaskInstance
from airflow.serialization.definitions.taskgroup import
SerializedMappedTaskGroup
+ from airflow.serialization.serialized_objects import SerializedBaseOperator
from airflow.ti_deps.dep_context import DepContext
from airflow.ti_deps.deps.base_ti_dep import TIDepStatus
@@ -131,6 +133,10 @@ class TriggerRuleDep(BaseTIDep):
from airflow.models.mappedoperator import is_mapped
from airflow.models.taskinstance import TaskInstance
+ task = ti.task
+ if TYPE_CHECKING:
+ assert task
+
@functools.lru_cache
def _get_expanded_ti_count() -> int:
"""
@@ -141,13 +147,10 @@ class TriggerRuleDep(BaseTIDep):
"""
from airflow.models.mappedoperator import get_mapped_ti_count
- if TYPE_CHECKING:
- assert ti.task
-
- return get_mapped_ti_count(ti.task, ti.run_id, session=session)
+ return get_mapped_ti_count(task, ti.run_id, session=session)
- def _iter_expansion_dependencies(task_group:
SerializedMappedTaskGroup) -> Iterator[str]:
- if (task := ti.task) is not None and is_mapped(task):
+ def _iter_expansion_dependencies(task_group: SerializedMappedTaskGroup
| None) -> Iterator[str]:
+ if is_mapped(task):
for op in task.iter_mapped_dependencies():
yield op.task_id
if task_group and task_group.iter_mapped_task_groups():
@@ -167,14 +170,13 @@ class TriggerRuleDep(BaseTIDep):
task instance of the same task).
"""
if TYPE_CHECKING:
- assert ti.task
- assert ti.task.dag
- assert ti.task.task_group
+ assert task.dag
+ assert task.task_group
- if is_mapped(ti.task.task_group):
- is_fast_triggered = ti.task.trigger_rule in (TR.ONE_SUCCESS,
TR.ONE_FAILED, TR.ONE_DONE)
+ if is_mapped(task.task_group):
+ is_fast_triggered = task.trigger_rule in (TR.ONE_SUCCESS,
TR.ONE_FAILED, TR.ONE_DONE)
if is_fast_triggered and upstream_id not in set(
- _iter_expansion_dependencies(task_group=ti.task.task_group)
+ _iter_expansion_dependencies(task_group=task.task_group)
):
return None
@@ -183,7 +185,7 @@ class TriggerRuleDep(BaseTIDep):
except (NotFullyPopulated, NotMapped):
return None
return ti.get_relevant_upstream_map_indexes(
- upstream=ti.task.dag.task_dict[upstream_id],
+ upstream=task.dag.task_dict[upstream_id],
ti_count=expanded_ti_count,
session=session,
)
@@ -198,15 +200,12 @@ class TriggerRuleDep(BaseTIDep):
2. ti is in a mapped task group and upstream has a map index
that ti does not depend on.
"""
- if TYPE_CHECKING:
- assert ti.task
-
# Not actually an upstream task.
if upstream.task_id not in relevant_ids:
return False
# The current task is not in a mapped task group. All tis from an
# upstream task are relevant.
- if ti.task.get_closest_mapped_task_group() is None:
+ if task.get_closest_mapped_task_group() is None:
return True
# The upstream ti is not expanded. The upstream may be mapped or
# not, but the ti is relevant either way.
@@ -228,10 +227,7 @@ class TriggerRuleDep(BaseTIDep):
# it depends on all upstream task instances.
from airflow.models.taskinstance import TaskInstance
- if TYPE_CHECKING:
- assert ti.task
-
- if ti.task.get_closest_mapped_task_group() is None:
+ if task.get_closest_mapped_task_group() is None:
yield TaskInstance.task_id.in_(relevant_tasks.keys())
return
# Otherwise we need to figure out which map indexes are depended on
@@ -257,16 +253,16 @@ class TriggerRuleDep(BaseTIDep):
else:
yield and_(TaskInstance.task_id == upstream_id,
TaskInstance.map_index == map_indexes)
- def _evaluate_setup_constraint(*, relevant_setups) ->
Iterator[tuple[TIDepStatus, bool]]:
+ def _evaluate_setup_constraint(
+ *, relevant_setups: Mapping[str, SerializedBaseOperator |
MappedOperator]
+ ) -> Iterator[tuple[TIDepStatus, bool]]:
"""
Evaluate whether ``ti``'s trigger rule was met as part of the
setup constraint.
:param relevant_setups: Relevant setups for the current task
instance.
"""
- if TYPE_CHECKING:
- assert ti.task
-
- task = ti.task
+ if not relevant_setups:
+ return
indirect_setups = {k: v for k, v in relevant_setups.items() if k
not in task.upstream_task_ids}
finished_upstream_tis = (
@@ -348,10 +344,6 @@ class TriggerRuleDep(BaseTIDep):
def _evaluate_direct_relatives() -> Iterator[TIDepStatus]:
"""Evaluate whether ``ti``'s trigger rule in direct relatives was
met."""
- if TYPE_CHECKING:
- assert ti.task
-
- task = ti.task
upstream_tasks = {t.task_id: t for t in task.upstream_list}
trigger_rule = task.trigger_rule
trigger_rule_str = getattr(trigger_rule, "value", trigger_rule)
@@ -359,7 +351,7 @@ class TriggerRuleDep(BaseTIDep):
finished_upstream_tis = (
finished_ti
for finished_ti in
dep_context.ensure_finished_tis(ti.get_dagrun(session), session)
- if _is_relevant_upstream(upstream=finished_ti,
relevant_ids=ti.task.upstream_task_ids)
+ if _is_relevant_upstream(upstream=finished_ti,
relevant_ids=task.upstream_task_ids)
)
upstream_states =
_UpstreamTIStates.calculate(finished_upstream_tis)
@@ -624,12 +616,14 @@ class TriggerRuleDep(BaseTIDep):
reason=f"No strategy to evaluate trigger rule
'{trigger_rule_str}'."
)
- if TYPE_CHECKING:
- assert ti.task
-
- if not ti.task.is_teardown:
+ if not task.is_teardown:
# a teardown cannot have any indirect setups
- relevant_setups = {t.task_id: t for t in
ti.task.get_upstreams_only_setups()}
+ relevant_setups: dict[str, MappedOperator |
SerializedBaseOperator] = {
+ # TODO (GH-52141): This should return scheduler types, but
+ # currently we reuse logic in SDK DAGNode.
+ t.task_id: t # type: ignore[misc]
+ for t in task.get_upstreams_only_setups()
+ }
if relevant_setups:
for status, changed in
_evaluate_setup_constraint(relevant_setups=relevant_setups):
yield status