This is an automated email from the ASF dual-hosted git repository.
eladkal pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 9985c35711 Fix logic of the skip_all_except method (#31153)
9985c35711 is described below
commit 9985c3571175d054bfabef02979ecc934e6aae73
Author: Dmitry Zhyhimont <[email protected]>
AuthorDate: Thu Jul 6 19:06:22 2023 +0300
Fix logic of the skip_all_except method (#31153)
* Fix logic of the skip_all_except method to work correctly with a mapped
branch operator
* Address feedback
* Add a unit test
* Skip empty tasks list
* Fix static checks
* Address feedback
* Use fully qualified import
---------
Co-authored-by: Zhyhimont Dmitry <[email protected]>
Co-authored-by: zhyhimont <[email protected]>
Co-authored-by: Tzu-ping Chung <[email protected]>
---
airflow/models/skipmixin.py | 56 ++++++++++++++++++++++++++----------------
tests/models/test_skipmixin.py | 36 +++++++++++++++++++++++++++
2 files changed, 71 insertions(+), 21 deletions(-)
diff --git a/airflow/models/skipmixin.py b/airflow/models/skipmixin.py
index d75a4a0e4d..849083e38b 100644
--- a/airflow/models/skipmixin.py
+++ b/airflow/models/skipmixin.py
@@ -21,18 +21,19 @@ import warnings
from typing import TYPE_CHECKING, Iterable, Sequence
from airflow.exceptions import AirflowException, RemovedInAirflow3Warning
+from airflow.models.dagrun import DagRun
from airflow.models.taskinstance import TaskInstance
from airflow.serialization.pydantic.dag_run import DagRunPydantic
from airflow.utils import timezone
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.session import NEW_SESSION, create_session, provide_session
+from airflow.utils.sqlalchemy import tuple_in_condition
from airflow.utils.state import State
if TYPE_CHECKING:
from pendulum import DateTime
from sqlalchemy import Session
- from airflow.models.dagrun import DagRun
from airflow.models.operator import Operator
from airflow.models.taskmixin import DAGNode
from airflow.serialization.pydantic.taskinstance import
TaskInstancePydantic
@@ -60,24 +61,30 @@ class SkipMixin(LoggingMixin):
def _set_state_to_skipped(
self,
dag_run: DagRun | DagRunPydantic,
- tasks: Iterable[Operator],
+ tasks: Sequence[str] | Sequence[tuple[str, int]],
session: Session,
) -> None:
"""Used internally to set state of task instances to skipped from the
same dag run."""
- now = timezone.utcnow()
-
- session.query(TaskInstance).filter(
- TaskInstance.dag_id == dag_run.dag_id,
- TaskInstance.run_id == dag_run.run_id,
- TaskInstance.task_id.in_(d.task_id for d in tasks),
- ).update(
- {
- TaskInstance.state: State.SKIPPED,
- TaskInstance.start_date: now,
- TaskInstance.end_date: now,
- },
- synchronize_session=False,
- )
+ if tasks:
+ now = timezone.utcnow()
+ TI = TaskInstance
+ query = session.query(TI).filter(
+ TI.dag_id == dag_run.dag_id,
+ TI.run_id == dag_run.run_id,
+ )
+ if isinstance(tasks[0], tuple):
+ query = query.filter(tuple_in_condition((TI.task_id,
TI.map_index), tasks))
+ else:
+ query = query.filter(TI.task_id.in_(tasks))
+
+ query.update(
+ {
+ TaskInstance.state: State.SKIPPED,
+ TaskInstance.start_date: now,
+ TaskInstance.end_date: now,
+ },
+ synchronize_session=False,
+ )
@provide_session
def skip(
@@ -130,7 +137,8 @@ class SkipMixin(LoggingMixin):
if dag_run is None:
raise ValueError("dag_run is required")
- self._set_state_to_skipped(dag_run, task_list, session)
+ task_ids_list = [d.task_id for d in task_list]
+ self._set_state_to_skipped(dag_run, task_ids_list, session)
session.commit()
# SkipMixin may not necessarily have a task_id attribute. Only store
to XCom if one is available.
@@ -140,7 +148,7 @@ class SkipMixin(LoggingMixin):
XCom.set(
key=XCOM_SKIPMIXIN_KEY,
- value={XCOM_SKIPMIXIN_SKIPPED: [d.task_id for d in task_list]},
+ value={XCOM_SKIPMIXIN_SKIPPED: task_ids_list},
task_id=task_id,
dag_id=dag_run.dag_id,
run_id=dag_run.run_id,
@@ -183,6 +191,7 @@ class SkipMixin(LoggingMixin):
)
dag_run = ti.get_dagrun()
+ assert isinstance(dag_run, DagRun)
# TODO(potiuk): Handle TaskInstancePydantic case differently - we need
to figure out the way to
# pass task that has been set in LocalTaskJob but in the way that
TaskInstancePydantic definition
@@ -218,10 +227,15 @@ class SkipMixin(LoggingMixin):
for branch_task_id in list(branch_task_id_set):
branch_task_id_set.update(dag.get_task(branch_task_id).get_flat_relative_ids(upstream=False))
- skip_tasks = [t for t in downstream_tasks if t.task_id not in
branch_task_id_set]
- follow_task_ids = [t.task_id for t in downstream_tasks if
t.task_id in branch_task_id_set]
+ skip_tasks = [
+ (t.task_id, downstream_ti.map_index)
+ for t in downstream_tasks
+ if (downstream_ti := dag_run.get_task_instance(t.task_id,
map_index=ti.map_index))
+ and t.task_id not in branch_task_id_set
+ ]
- self.log.info("Skipping tasks %s", [t.task_id for t in skip_tasks])
+ follow_task_ids = [t.task_id for t in downstream_tasks if
t.task_id in branch_task_id_set]
+ self.log.info("Skipping tasks %s", skip_tasks)
with create_session() as session:
self._set_state_to_skipped(dag_run, skip_tasks,
session=session)
# For some reason, session.commit() needs to happen before
xcom_push.
diff --git a/tests/models/test_skipmixin.py b/tests/models/test_skipmixin.py
index c9912a64d4..547dbec5b4 100644
--- a/tests/models/test_skipmixin.py
+++ b/tests/models/test_skipmixin.py
@@ -24,6 +24,7 @@ import pendulum
import pytest
from airflow import settings
+from airflow.decorators import task, task_group
from airflow.exceptions import AirflowException
from airflow.models.skipmixin import SkipMixin
from airflow.models.taskinstance import TaskInstance as TI
@@ -133,6 +134,41 @@ class TestSkipMixin:
assert executed_states == expected_states
+ def test_mapped_tasks_skip_all_except(self, dag_maker):
+ with dag_maker("dag_test_skip_all_except") as dag:
+
+ @task
+ def branch_op(k):
+ ...
+
+ @task_group
+ def task_group_op(k):
+ branch_a = EmptyOperator(task_id="branch_a")
+ branch_b = EmptyOperator(task_id="branch_b")
+ branch_op(k) >> [branch_a, branch_b]
+
+ task_group_op.expand(k=[i for i in range(2)])
+
+ dag_maker.create_dagrun()
+ branch_op_ti_0 = TI(dag.get_task("task_group_op.branch_op"),
execution_date=DEFAULT_DATE, map_index=0)
+ branch_op_ti_1 = TI(dag.get_task("task_group_op.branch_op"),
execution_date=DEFAULT_DATE, map_index=1)
+ branch_a_ti_0 = TI(dag.get_task("task_group_op.branch_a"),
execution_date=DEFAULT_DATE, map_index=0)
+ branch_a_ti_1 = TI(dag.get_task("task_group_op.branch_a"),
execution_date=DEFAULT_DATE, map_index=1)
+ branch_b_ti_0 = TI(dag.get_task("task_group_op.branch_b"),
execution_date=DEFAULT_DATE, map_index=0)
+ branch_b_ti_1 = TI(dag.get_task("task_group_op.branch_b"),
execution_date=DEFAULT_DATE, map_index=1)
+
+ SkipMixin().skip_all_except(ti=branch_op_ti_0,
branch_task_ids="task_group_op.branch_a")
+ SkipMixin().skip_all_except(ti=branch_op_ti_1,
branch_task_ids="task_group_op.branch_b")
+
+ def get_state(ti):
+ ti.refresh_from_db()
+ return ti.state
+
+ assert get_state(branch_a_ti_0) == State.NONE
+ assert get_state(branch_b_ti_0) == State.SKIPPED
+ assert get_state(branch_a_ti_1) == State.SKIPPED
+ assert get_state(branch_b_ti_1) == State.NONE
+
def test_raise_exception_on_not_accepted_branch_task_ids_type(self,
dag_maker):
with dag_maker("dag_test_skip_all_except_wrong_type"):
task = EmptyOperator(task_id="task")