This is an automated email from the ASF dual-hosted git repository.
jscheffl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new fec2b10fb0 Fix SkipMixin with Database Isolation for AIP-44 (#40781)
fec2b10fb0 is described below
commit fec2b10fb01fecc3c25edcde51e28b12dcc55ce8
Author: Jens Scheffler <[email protected]>
AuthorDate: Wed Jul 17 22:04:49 2024 +0200
Fix SkipMixin with Database Isolation for AIP-44 (#40781)
* Fix SkipMixin with Database Isolation for AIP-44
* Fix pytest of _log instance
* Fix pytests
---
airflow/api_internal/endpoints/rpc_api_endpoint.py | 3 +
airflow/models/dagrun.py | 1 +
airflow/models/skipmixin.py | 73 +++++++++++++++-------
airflow/models/taskinstance.py | 16 +++++
tests/models/test_skipmixin.py | 6 +-
5 files changed, 74 insertions(+), 25 deletions(-)
diff --git a/airflow/api_internal/endpoints/rpc_api_endpoint.py b/airflow/api_internal/endpoints/rpc_api_endpoint.py
index 1ead0e7cec..608824982c 100644
--- a/airflow/api_internal/endpoints/rpc_api_endpoint.py
+++ b/airflow/api_internal/endpoints/rpc_api_endpoint.py
@@ -49,6 +49,7 @@ def _initialize_map() -> dict[str, Callable]:
from airflow.models.dagrun import DagRun
from airflow.models.dagwarning import DagWarning
from airflow.models.serialized_dag import SerializedDagModel
+ from airflow.models.skipmixin import SkipMixin
from airflow.models.taskinstance import (
TaskInstance,
_add_log,
@@ -110,6 +111,8 @@ def _initialize_map() -> dict[str, Callable]:
DagRun.fetch_task_instance,
DagRun._get_log_template,
SerializedDagModel.get_serialized_dag,
+ SkipMixin._skip,
+ SkipMixin._skip_all_except,
TaskInstance._check_and_change_state_before_execution,
TaskInstance.get_task_instance,
TaskInstance._get_dagrun,
diff --git a/airflow/models/dagrun.py b/airflow/models/dagrun.py
index 9b84bb9b3c..6c3d0715b9 100644
--- a/airflow/models/dagrun.py
+++ b/airflow/models/dagrun.py
@@ -651,6 +651,7 @@ class DagRun(Base, LoggingMixin):
)
@staticmethod
+ @internal_api_call
@provide_session
def fetch_task_instance(
dag_id: str,
diff --git a/airflow/models/skipmixin.py b/airflow/models/skipmixin.py
index 3c89deda12..1ed56a43bf 100644
--- a/airflow/models/skipmixin.py
+++ b/airflow/models/skipmixin.py
@@ -18,15 +18,17 @@
from __future__ import annotations
import warnings
+from types import GeneratorType
from typing import TYPE_CHECKING, Iterable, Sequence
from sqlalchemy import select, update
+from airflow.api_internal.internal_api_call import internal_api_call
from airflow.exceptions import AirflowException, RemovedInAirflow3Warning
from airflow.models.taskinstance import TaskInstance
from airflow.utils import timezone
from airflow.utils.log.logging_mixin import LoggingMixin
-from airflow.utils.session import NEW_SESSION, create_session, provide_session
+from airflow.utils.session import NEW_SESSION, provide_session
from airflow.utils.sqlalchemy import tuple_in_condition
from airflow.utils.state import TaskInstanceState
@@ -60,8 +62,8 @@ def _ensure_tasks(nodes: Iterable[DAGNode]) -> Sequence[Operator]:
class SkipMixin(LoggingMixin):
"""A Mixin to skip Tasks Instances."""
+ @staticmethod
def _set_state_to_skipped(
- self,
dag_run: DagRun | DagRunPydantic,
tasks: Sequence[str] | Sequence[tuple[str, int]],
session: Session,
@@ -93,12 +95,28 @@ class SkipMixin(LoggingMixin):
.execution_options(synchronize_session=False)
)
- @provide_session
def skip(
self,
dag_run: DagRun | DagRunPydantic,
execution_date: DateTime,
tasks: Iterable[DAGNode],
+ map_index: int = -1,
+ ):
+ """Facade for compatibility for call to internal API."""
+ # SkipMixin may not necessarily have a task_id attribute. Only store to XCom if one is available.
+ task_id: str | None = getattr(self, "task_id", None)
+ SkipMixin._skip(
+ dag_run=dag_run, task_id=task_id, execution_date=execution_date, tasks=tasks, map_index=map_index
+ )
+
+ @staticmethod
+ @internal_api_call
+ @provide_session
+ def _skip(
+ dag_run: DagRun | DagRunPydantic,
+ task_id: str | None,
+ execution_date: DateTime,
+ tasks: Iterable[DAGNode],
session: Session = NEW_SESSION,
map_index: int = -1,
):
@@ -143,11 +161,9 @@ class SkipMixin(LoggingMixin):
raise ValueError("dag_run is required")
task_ids_list = [d.task_id for d in task_list]
- self._set_state_to_skipped(dag_run, task_ids_list, session)
+ SkipMixin._set_state_to_skipped(dag_run, task_ids_list, session)
session.commit()
- # SkipMixin may not necessarily have a task_id attribute. Only store to XCom if one is available.
- task_id: str | None = getattr(self, "task_id", None)
if task_id is not None:
from airflow.models.xcom import XCom
@@ -165,6 +181,21 @@ class SkipMixin(LoggingMixin):
self,
ti: TaskInstance | TaskInstancePydantic,
branch_task_ids: None | str | Iterable[str],
+ ):
+ """Facade for compatibility for call to internal API."""
+ # Ensure we don't serialize a generator object
+ if branch_task_ids and isinstance(branch_task_ids, GeneratorType):
+ branch_task_ids = list(branch_task_ids)
+ SkipMixin._skip_all_except(ti=ti, branch_task_ids=branch_task_ids)
+
+ @classmethod
+ @internal_api_call
+ @provide_session
+ def _skip_all_except(
+ cls,
+ ti: TaskInstance | TaskInstancePydantic,
+ branch_task_ids: None | str | Iterable[str],
+ session: Session = NEW_SESSION,
):
"""
Implement the logic for a branching operator.
@@ -175,6 +206,7 @@ class SkipMixin(LoggingMixin):
branch_task_ids is stored to XCom so that NotPreviouslySkippedDep knows skipped tasks or
newly added tasks should be skipped when they are cleared.
"""
+ log = cls().log # Note: need to catch logger form instance, static logger breaks pytest
if isinstance(branch_task_ids, str):
branch_task_id_set = {branch_task_ids}
elif isinstance(branch_task_ids, Iterable):
@@ -195,20 +227,15 @@ class SkipMixin(LoggingMixin):
f"but got {type(branch_task_ids).__name__!r}."
)
- self.log.info("Following branch %s", branch_task_id_set)
+ log.info("Following branch %s", branch_task_id_set)
- dag_run = ti.get_dagrun()
+ dag_run = ti.get_dagrun(session=session)
if TYPE_CHECKING:
assert isinstance(dag_run, DagRun)
assert ti.task
- # TODO(potiuk): Handle TaskInstancePydantic case differently - we need to figure out the way to
- # pass task that has been set in LocalTaskJob but in the way that TaskInstancePydantic definition
- # does not attempt to serialize the field from/to ORM
task = ti.task
- dag = task.dag
- if TYPE_CHECKING:
- assert dag
+ dag = TaskInstance.ensure_dag(ti, session=session)
valid_task_ids = set(dag.task_ids)
invalid_task_ids = branch_task_id_set - valid_task_ids
@@ -239,15 +266,17 @@ class SkipMixin(LoggingMixin):
skip_tasks = [
(t.task_id, downstream_ti.map_index)
for t in downstream_tasks
- if (downstream_ti := dag_run.get_task_instance(t.task_id, map_index=ti.map_index))
+ if (
+ downstream_ti := dag_run.get_task_instance(
+ t.task_id, map_index=ti.map_index, session=session
+ )
+ )
and t.task_id not in branch_task_id_set
]
follow_task_ids = [t.task_id for t in downstream_tasks if t.task_id in branch_task_id_set]
- self.log.info("Skipping tasks %s", skip_tasks)
- with create_session() as session:
- self._set_state_to_skipped(dag_run, skip_tasks, session=session)
- # For some reason, session.commit() needs to happen before xcom_push.
- # Otherwise the session is not committed.
- session.commit()
- ti.xcom_push(key=XCOM_SKIPMIXIN_KEY, value={XCOM_SKIPMIXIN_FOLLOWED: follow_task_ids})
+ log.info("Skipping tasks %s", skip_tasks)
+ SkipMixin._set_state_to_skipped(dag_run, skip_tasks, session=session)
+ ti.xcom_push(
+ key=XCOM_SKIPMIXIN_KEY, value={XCOM_SKIPMIXIN_FOLLOWED: follow_task_ids}, session=session
+ )
diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index 80b3eedbc8..1dacbe7525 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -2633,6 +2633,22 @@ class TaskInstance(Base, LoggingMixin):
return dr
+ @classmethod
+ @provide_session
+ def ensure_dag(
+ cls, task_instance: TaskInstance | TaskInstancePydantic, session: Session = NEW_SESSION
+ ) -> DAG:
+ """Ensure that task has a dag object associated, might have been removed by serialization."""
+ if TYPE_CHECKING:
+ assert task_instance.task
+ if task_instance.task.dag is None or task_instance.task.dag is ATTRIBUTE_REMOVED:
+ task_instance.task.dag = DagBag(read_dags_from_db=True).get_dag(
+ dag_id=task_instance.dag_id, session=session
+ )
+ if TYPE_CHECKING:
+ assert task_instance.task.dag
+ return task_instance.task.dag
+
@classmethod
@internal_api_call
@provide_session
diff --git a/tests/models/test_skipmixin.py b/tests/models/test_skipmixin.py
index 2fd5fb0fe6..465d15130f 100644
--- a/tests/models/test_skipmixin.py
+++ b/tests/models/test_skipmixin.py
@@ -65,7 +65,7 @@ class TestSkipMixin:
execution_date=now,
state=State.FAILED,
)
- SkipMixin().skip(dag_run=dag_run, execution_date=now, tasks=tasks, session=session)
+ SkipMixin().skip(dag_run=dag_run, execution_date=now, tasks=tasks)
session.query(TI).filter(
TI.dag_id == "dag",
@@ -91,7 +91,7 @@ class TestSkipMixin:
RemovedInAirflow3Warning,
match=r"Passing an execution_date to `skip\(\)` is deprecated in favour of passing a dag_run",
):
- SkipMixin().skip(dag_run=None, execution_date=now, tasks=tasks, session=session)
+ SkipMixin().skip(dag_run=None, execution_date=now, tasks=tasks)
session.query(TI).filter(
TI.dag_id == "dag",
@@ -103,7 +103,7 @@ class TestSkipMixin:
def test_skip_none_tasks(self):
session = Mock()
- SkipMixin().skip(dag_run=None, execution_date=None, tasks=[], session=session)
+ SkipMixin().skip(dag_run=None, execution_date=None, tasks=[])
assert not session.query.called
assert not session.commit.called