This is an automated email from the ASF dual-hosted git repository.

eladkal pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 9985c35711 Fix logic of the skip_all_except method (#31153)
9985c35711 is described below

commit 9985c3571175d054bfabef02979ecc934e6aae73
Author: Dmitry Zhyhimont <[email protected]>
AuthorDate: Thu Jul 6 19:06:22 2023 +0300

    Fix logic of the skip_all_except method (#31153)
    
    * Fix logic of the skip_all_except method to work correctly with a mapped branch operator
    
    * Address feedback
    
    * Add a unit test
    
    * Skip empty tasks list
    
    * Fix static checks
    
    * Address feedback
    
    * Use fully qualified import
    
    ---------
    
    Co-authored-by: Zhyhimont Dmitry <[email protected]>
    Co-authored-by: zhyhimont <[email protected]>
    Co-authored-by: Tzu-ping Chung <[email protected]>
---
 airflow/models/skipmixin.py    | 56 ++++++++++++++++++++++++++----------------
 tests/models/test_skipmixin.py | 36 +++++++++++++++++++++++++++
 2 files changed, 71 insertions(+), 21 deletions(-)

diff --git a/airflow/models/skipmixin.py b/airflow/models/skipmixin.py
index d75a4a0e4d..849083e38b 100644
--- a/airflow/models/skipmixin.py
+++ b/airflow/models/skipmixin.py
@@ -21,18 +21,19 @@ import warnings
 from typing import TYPE_CHECKING, Iterable, Sequence
 
 from airflow.exceptions import AirflowException, RemovedInAirflow3Warning
+from airflow.models.dagrun import DagRun
 from airflow.models.taskinstance import TaskInstance
 from airflow.serialization.pydantic.dag_run import DagRunPydantic
 from airflow.utils import timezone
 from airflow.utils.log.logging_mixin import LoggingMixin
 from airflow.utils.session import NEW_SESSION, create_session, provide_session
+from airflow.utils.sqlalchemy import tuple_in_condition
 from airflow.utils.state import State
 
 if TYPE_CHECKING:
     from pendulum import DateTime
     from sqlalchemy import Session
 
-    from airflow.models.dagrun import DagRun
     from airflow.models.operator import Operator
     from airflow.models.taskmixin import DAGNode
     from airflow.serialization.pydantic.taskinstance import 
TaskInstancePydantic
@@ -60,24 +61,30 @@ class SkipMixin(LoggingMixin):
     def _set_state_to_skipped(
         self,
         dag_run: DagRun | DagRunPydantic,
-        tasks: Iterable[Operator],
+        tasks: Sequence[str] | Sequence[tuple[str, int]],
         session: Session,
     ) -> None:
         """Used internally to set state of task instances to skipped from the 
same dag run."""
-        now = timezone.utcnow()
-
-        session.query(TaskInstance).filter(
-            TaskInstance.dag_id == dag_run.dag_id,
-            TaskInstance.run_id == dag_run.run_id,
-            TaskInstance.task_id.in_(d.task_id for d in tasks),
-        ).update(
-            {
-                TaskInstance.state: State.SKIPPED,
-                TaskInstance.start_date: now,
-                TaskInstance.end_date: now,
-            },
-            synchronize_session=False,
-        )
+        if tasks:
+            now = timezone.utcnow()
+            TI = TaskInstance
+            query = session.query(TI).filter(
+                TI.dag_id == dag_run.dag_id,
+                TI.run_id == dag_run.run_id,
+            )
+            if isinstance(tasks[0], tuple):
+                query = query.filter(tuple_in_condition((TI.task_id, 
TI.map_index), tasks))
+            else:
+                query = query.filter(TI.task_id.in_(tasks))
+
+            query.update(
+                {
+                    TaskInstance.state: State.SKIPPED,
+                    TaskInstance.start_date: now,
+                    TaskInstance.end_date: now,
+                },
+                synchronize_session=False,
+            )
 
     @provide_session
     def skip(
@@ -130,7 +137,8 @@ class SkipMixin(LoggingMixin):
         if dag_run is None:
             raise ValueError("dag_run is required")
 
-        self._set_state_to_skipped(dag_run, task_list, session)
+        task_ids_list = [d.task_id for d in task_list]
+        self._set_state_to_skipped(dag_run, task_ids_list, session)
         session.commit()
 
         # SkipMixin may not necessarily have a task_id attribute. Only store 
to XCom if one is available.
@@ -140,7 +148,7 @@ class SkipMixin(LoggingMixin):
 
             XCom.set(
                 key=XCOM_SKIPMIXIN_KEY,
-                value={XCOM_SKIPMIXIN_SKIPPED: [d.task_id for d in task_list]},
+                value={XCOM_SKIPMIXIN_SKIPPED: task_ids_list},
                 task_id=task_id,
                 dag_id=dag_run.dag_id,
                 run_id=dag_run.run_id,
@@ -183,6 +191,7 @@ class SkipMixin(LoggingMixin):
             )
 
         dag_run = ti.get_dagrun()
+        assert isinstance(dag_run, DagRun)
 
         # TODO(potiuk): Handle TaskInstancePydantic case differently - we need 
to figure out the way to
         # pass task that has been set in LocalTaskJob but in the way that 
TaskInstancePydantic definition
@@ -218,10 +227,15 @@ class SkipMixin(LoggingMixin):
             for branch_task_id in list(branch_task_id_set):
                 
branch_task_id_set.update(dag.get_task(branch_task_id).get_flat_relative_ids(upstream=False))
 
-            skip_tasks = [t for t in downstream_tasks if t.task_id not in 
branch_task_id_set]
-            follow_task_ids = [t.task_id for t in downstream_tasks if 
t.task_id in branch_task_id_set]
+            skip_tasks = [
+                (t.task_id, downstream_ti.map_index)
+                for t in downstream_tasks
+                if (downstream_ti := dag_run.get_task_instance(t.task_id, 
map_index=ti.map_index))
+                and t.task_id not in branch_task_id_set
+            ]
 
-            self.log.info("Skipping tasks %s", [t.task_id for t in skip_tasks])
+            follow_task_ids = [t.task_id for t in downstream_tasks if 
t.task_id in branch_task_id_set]
+            self.log.info("Skipping tasks %s", skip_tasks)
             with create_session() as session:
                 self._set_state_to_skipped(dag_run, skip_tasks, 
session=session)
                 # For some reason, session.commit() needs to happen before 
xcom_push.
diff --git a/tests/models/test_skipmixin.py b/tests/models/test_skipmixin.py
index c9912a64d4..547dbec5b4 100644
--- a/tests/models/test_skipmixin.py
+++ b/tests/models/test_skipmixin.py
@@ -24,6 +24,7 @@ import pendulum
 import pytest
 
 from airflow import settings
+from airflow.decorators import task, task_group
 from airflow.exceptions import AirflowException
 from airflow.models.skipmixin import SkipMixin
 from airflow.models.taskinstance import TaskInstance as TI
@@ -133,6 +134,41 @@ class TestSkipMixin:
 
         assert executed_states == expected_states
 
+    def test_mapped_tasks_skip_all_except(self, dag_maker):
+        with dag_maker("dag_test_skip_all_except") as dag:
+
+            @task
+            def branch_op(k):
+                ...
+
+            @task_group
+            def task_group_op(k):
+                branch_a = EmptyOperator(task_id="branch_a")
+                branch_b = EmptyOperator(task_id="branch_b")
+                branch_op(k) >> [branch_a, branch_b]
+
+            task_group_op.expand(k=[i for i in range(2)])
+
+        dag_maker.create_dagrun()
+        branch_op_ti_0 = TI(dag.get_task("task_group_op.branch_op"), 
execution_date=DEFAULT_DATE, map_index=0)
+        branch_op_ti_1 = TI(dag.get_task("task_group_op.branch_op"), 
execution_date=DEFAULT_DATE, map_index=1)
+        branch_a_ti_0 = TI(dag.get_task("task_group_op.branch_a"), 
execution_date=DEFAULT_DATE, map_index=0)
+        branch_a_ti_1 = TI(dag.get_task("task_group_op.branch_a"), 
execution_date=DEFAULT_DATE, map_index=1)
+        branch_b_ti_0 = TI(dag.get_task("task_group_op.branch_b"), 
execution_date=DEFAULT_DATE, map_index=0)
+        branch_b_ti_1 = TI(dag.get_task("task_group_op.branch_b"), 
execution_date=DEFAULT_DATE, map_index=1)
+
+        SkipMixin().skip_all_except(ti=branch_op_ti_0, 
branch_task_ids="task_group_op.branch_a")
+        SkipMixin().skip_all_except(ti=branch_op_ti_1, 
branch_task_ids="task_group_op.branch_b")
+
+        def get_state(ti):
+            ti.refresh_from_db()
+            return ti.state
+
+        assert get_state(branch_a_ti_0) == State.NONE
+        assert get_state(branch_b_ti_0) == State.SKIPPED
+        assert get_state(branch_a_ti_1) == State.SKIPPED
+        assert get_state(branch_b_ti_1) == State.NONE
+
     def test_raise_exception_on_not_accepted_branch_task_ids_type(self, 
dag_maker):
         with dag_maker("dag_test_skip_all_except_wrong_type"):
             task = EmptyOperator(task_id="task")

Reply via email to