This is an automated email from the ASF dual-hosted git repository.

jscheffl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new fec2b10fb0 Fix SkipMixin with Database Isolation for AIP-44 (#40781)
fec2b10fb0 is described below

commit fec2b10fb01fecc3c25edcde51e28b12dcc55ce8
Author: Jens Scheffler <[email protected]>
AuthorDate: Wed Jul 17 22:04:49 2024 +0200

    Fix SkipMixin with Database Isolation for AIP-44 (#40781)
    
    * Fix SkipMixin with Database Isolation for AIP-44
    
    * Fix pytest of _log instance
    
    * Fix pytests
---
 airflow/api_internal/endpoints/rpc_api_endpoint.py |  3 +
 airflow/models/dagrun.py                           |  1 +
 airflow/models/skipmixin.py                        | 73 +++++++++++++++-------
 airflow/models/taskinstance.py                     | 16 +++++
 tests/models/test_skipmixin.py                     |  6 +-
 5 files changed, 74 insertions(+), 25 deletions(-)

diff --git a/airflow/api_internal/endpoints/rpc_api_endpoint.py b/airflow/api_internal/endpoints/rpc_api_endpoint.py
index 1ead0e7cec..608824982c 100644
--- a/airflow/api_internal/endpoints/rpc_api_endpoint.py
+++ b/airflow/api_internal/endpoints/rpc_api_endpoint.py
@@ -49,6 +49,7 @@ def _initialize_map() -> dict[str, Callable]:
     from airflow.models.dagrun import DagRun
     from airflow.models.dagwarning import DagWarning
     from airflow.models.serialized_dag import SerializedDagModel
+    from airflow.models.skipmixin import SkipMixin
     from airflow.models.taskinstance import (
         TaskInstance,
         _add_log,
@@ -110,6 +111,8 @@ def _initialize_map() -> dict[str, Callable]:
         DagRun.fetch_task_instance,
         DagRun._get_log_template,
         SerializedDagModel.get_serialized_dag,
+        SkipMixin._skip,
+        SkipMixin._skip_all_except,
         TaskInstance._check_and_change_state_before_execution,
         TaskInstance.get_task_instance,
         TaskInstance._get_dagrun,
diff --git a/airflow/models/dagrun.py b/airflow/models/dagrun.py
index 9b84bb9b3c..6c3d0715b9 100644
--- a/airflow/models/dagrun.py
+++ b/airflow/models/dagrun.py
@@ -651,6 +651,7 @@ class DagRun(Base, LoggingMixin):
         )
 
     @staticmethod
+    @internal_api_call
     @provide_session
     def fetch_task_instance(
         dag_id: str,
diff --git a/airflow/models/skipmixin.py b/airflow/models/skipmixin.py
index 3c89deda12..1ed56a43bf 100644
--- a/airflow/models/skipmixin.py
+++ b/airflow/models/skipmixin.py
@@ -18,15 +18,17 @@
 from __future__ import annotations
 
 import warnings
+from types import GeneratorType
 from typing import TYPE_CHECKING, Iterable, Sequence
 
 from sqlalchemy import select, update
 
+from airflow.api_internal.internal_api_call import internal_api_call
 from airflow.exceptions import AirflowException, RemovedInAirflow3Warning
 from airflow.models.taskinstance import TaskInstance
 from airflow.utils import timezone
 from airflow.utils.log.logging_mixin import LoggingMixin
-from airflow.utils.session import NEW_SESSION, create_session, provide_session
+from airflow.utils.session import NEW_SESSION, provide_session
 from airflow.utils.sqlalchemy import tuple_in_condition
 from airflow.utils.state import TaskInstanceState
 
@@ -60,8 +62,8 @@ def _ensure_tasks(nodes: Iterable[DAGNode]) -> Sequence[Operator]:
 class SkipMixin(LoggingMixin):
     """A Mixin to skip Tasks Instances."""
 
+    @staticmethod
     def _set_state_to_skipped(
-        self,
         dag_run: DagRun | DagRunPydantic,
         tasks: Sequence[str] | Sequence[tuple[str, int]],
         session: Session,
@@ -93,12 +95,28 @@ class SkipMixin(LoggingMixin):
                     .execution_options(synchronize_session=False)
                 )
 
-    @provide_session
     def skip(
         self,
         dag_run: DagRun | DagRunPydantic,
         execution_date: DateTime,
         tasks: Iterable[DAGNode],
+        map_index: int = -1,
+    ):
+        """Facade for compatibility for call to internal API."""
+        # SkipMixin may not necessarily have a task_id attribute. Only store to XCom if one is available.
+        task_id: str | None = getattr(self, "task_id", None)
+        SkipMixin._skip(
+            dag_run=dag_run, task_id=task_id, execution_date=execution_date, tasks=tasks, map_index=map_index
+        )
+
+    @staticmethod
+    @internal_api_call
+    @provide_session
+    def _skip(
+        dag_run: DagRun | DagRunPydantic,
+        task_id: str | None,
+        execution_date: DateTime,
+        tasks: Iterable[DAGNode],
         session: Session = NEW_SESSION,
         map_index: int = -1,
     ):
@@ -143,11 +161,9 @@ class SkipMixin(LoggingMixin):
             raise ValueError("dag_run is required")
 
         task_ids_list = [d.task_id for d in task_list]
-        self._set_state_to_skipped(dag_run, task_ids_list, session)
+        SkipMixin._set_state_to_skipped(dag_run, task_ids_list, session)
         session.commit()
 
-        # SkipMixin may not necessarily have a task_id attribute. Only store to XCom if one is available.
-        task_id: str | None = getattr(self, "task_id", None)
         if task_id is not None:
             from airflow.models.xcom import XCom
 
@@ -165,6 +181,21 @@ class SkipMixin(LoggingMixin):
         self,
         ti: TaskInstance | TaskInstancePydantic,
         branch_task_ids: None | str | Iterable[str],
+    ):
+        """Facade for compatibility for call to internal API."""
+        # Ensure we don't serialize a generator object
+        if branch_task_ids and isinstance(branch_task_ids, GeneratorType):
+            branch_task_ids = list(branch_task_ids)
+        SkipMixin._skip_all_except(ti=ti, branch_task_ids=branch_task_ids)
+
+    @classmethod
+    @internal_api_call
+    @provide_session
+    def _skip_all_except(
+        cls,
+        ti: TaskInstance | TaskInstancePydantic,
+        branch_task_ids: None | str | Iterable[str],
+        session: Session = NEW_SESSION,
     ):
         """
         Implement the logic for a branching operator.
@@ -175,6 +206,7 @@ class SkipMixin(LoggingMixin):
         branch_task_ids is stored to XCom so that NotPreviouslySkippedDep knows skipped tasks or
         newly added tasks should be skipped when they are cleared.
         """
+        log = cls().log  # Note: need to catch logger form instance, static logger breaks pytest
         if isinstance(branch_task_ids, str):
             branch_task_id_set = {branch_task_ids}
         elif isinstance(branch_task_ids, Iterable):
@@ -195,20 +227,15 @@ class SkipMixin(LoggingMixin):
                 f"but got {type(branch_task_ids).__name__!r}."
             )
 
-        self.log.info("Following branch %s", branch_task_id_set)
+        log.info("Following branch %s", branch_task_id_set)
 
-        dag_run = ti.get_dagrun()
+        dag_run = ti.get_dagrun(session=session)
         if TYPE_CHECKING:
             assert isinstance(dag_run, DagRun)
             assert ti.task
 
-        # TODO(potiuk): Handle TaskInstancePydantic case differently - we need to figure out the way to
-        # pass task that has been set in LocalTaskJob but in the way that TaskInstancePydantic definition
-        # does not attempt to serialize the field from/to ORM
         task = ti.task
-        dag = task.dag
-        if TYPE_CHECKING:
-            assert dag
+        dag = TaskInstance.ensure_dag(ti, session=session)
 
         valid_task_ids = set(dag.task_ids)
         invalid_task_ids = branch_task_id_set - valid_task_ids
@@ -239,15 +266,17 @@ class SkipMixin(LoggingMixin):
             skip_tasks = [
                 (t.task_id, downstream_ti.map_index)
                 for t in downstream_tasks
-                if (downstream_ti := dag_run.get_task_instance(t.task_id, map_index=ti.map_index))
+                if (
+                    downstream_ti := dag_run.get_task_instance(
+                        t.task_id, map_index=ti.map_index, session=session
+                    )
+                )
                 and t.task_id not in branch_task_id_set
             ]
 
             follow_task_ids = [t.task_id for t in downstream_tasks if t.task_id in branch_task_id_set]
-            self.log.info("Skipping tasks %s", skip_tasks)
-            with create_session() as session:
-                self._set_state_to_skipped(dag_run, skip_tasks, session=session)
-                # For some reason, session.commit() needs to happen before xcom_push.
-                # Otherwise the session is not committed.
-                session.commit()
-                ti.xcom_push(key=XCOM_SKIPMIXIN_KEY, value={XCOM_SKIPMIXIN_FOLLOWED: follow_task_ids})
+            log.info("Skipping tasks %s", skip_tasks)
+            SkipMixin._set_state_to_skipped(dag_run, skip_tasks, session=session)
+            ti.xcom_push(
+                key=XCOM_SKIPMIXIN_KEY, value={XCOM_SKIPMIXIN_FOLLOWED: follow_task_ids}, session=session
+            )
diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index 80b3eedbc8..1dacbe7525 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -2633,6 +2633,22 @@ class TaskInstance(Base, LoggingMixin):
 
         return dr
 
+    @classmethod
+    @provide_session
+    def ensure_dag(
+        cls, task_instance: TaskInstance | TaskInstancePydantic, session: Session = NEW_SESSION
+    ) -> DAG:
+        """Ensure that task has a dag object associated, might have been removed by serialization."""
+        if TYPE_CHECKING:
+            assert task_instance.task
+        if task_instance.task.dag is None or task_instance.task.dag is ATTRIBUTE_REMOVED:
+            task_instance.task.dag = DagBag(read_dags_from_db=True).get_dag(
+                dag_id=task_instance.dag_id, session=session
+            )
+        if TYPE_CHECKING:
+            assert task_instance.task.dag
+        return task_instance.task.dag
+
     @classmethod
     @internal_api_call
     @provide_session
diff --git a/tests/models/test_skipmixin.py b/tests/models/test_skipmixin.py
index 2fd5fb0fe6..465d15130f 100644
--- a/tests/models/test_skipmixin.py
+++ b/tests/models/test_skipmixin.py
@@ -65,7 +65,7 @@ class TestSkipMixin:
             execution_date=now,
             state=State.FAILED,
         )
-        SkipMixin().skip(dag_run=dag_run, execution_date=now, tasks=tasks, session=session)
+        SkipMixin().skip(dag_run=dag_run, execution_date=now, tasks=tasks)
 
         session.query(TI).filter(
             TI.dag_id == "dag",
@@ -91,7 +91,7 @@ class TestSkipMixin:
             RemovedInAirflow3Warning,
             match=r"Passing an execution_date to `skip\(\)` is deprecated in favour of passing a dag_run",
         ):
-            SkipMixin().skip(dag_run=None, execution_date=now, tasks=tasks, session=session)
+            SkipMixin().skip(dag_run=None, execution_date=now, tasks=tasks)
 
         session.query(TI).filter(
             TI.dag_id == "dag",
@@ -103,7 +103,7 @@ class TestSkipMixin:
 
     def test_skip_none_tasks(self):
         session = Mock()
-        SkipMixin().skip(dag_run=None, execution_date=None, tasks=[], session=session)
+        SkipMixin().skip(dag_run=None, execution_date=None, tasks=[])
         assert not session.query.called
         assert not session.commit.called
 

Reply via email to