This is an automated email from the ASF dual-hosted git repository.
kaxilnaik pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new a969694c2a1 Get `skipmixin` working temporarily (#45824)
a969694c2a1 is described below
commit a969694c2a17c07d7b7d91a884391f6b818117e4
Author: Kaxil Naik <[email protected]>
AuthorDate: Tue Jan 21 17:52:36 2025 +0530
Get `skipmixin` working temporarily (#45824)
---
airflow/models/skipmixin.py | 8 +++-----
airflow/models/taskinstance.py | 1 -
tests/models/test_skipmixin.py | 4 +++-
3 files changed, 6 insertions(+), 7 deletions(-)
diff --git a/airflow/models/skipmixin.py b/airflow/models/skipmixin.py
index 5e3c47ad3a1..3b7d21d7b38 100644
--- a/airflow/models/skipmixin.py
+++ b/airflow/models/skipmixin.py
@@ -23,6 +23,7 @@ from typing import TYPE_CHECKING
from sqlalchemy import tuple_, update
+from airflow import settings
from airflow.exceptions import AirflowException
from airflow.models.taskinstance import TaskInstance
from airflow.utils import timezone
@@ -33,7 +34,6 @@ from airflow.utils.state import TaskInstanceState
if TYPE_CHECKING:
from sqlalchemy.orm import Session
- from airflow.models.dagrun import DagRun
from airflow.models.operator import Operator
from airflow.sdk.definitions._internal.node import DAGNode
@@ -136,12 +136,10 @@ class SkipMixin(LoggingMixin):
session=session,
)
- @provide_session
def skip_all_except(
self,
ti: TaskInstance,
branch_task_ids: None | str | Iterable[str],
- session: Session = NEW_SESSION,
):
"""
Implement the logic for a branching operator.
@@ -178,12 +176,11 @@ class SkipMixin(LoggingMixin):
log.info("Following branch %s", branch_task_id_set)
- dag_run = ti.get_dagrun(session=session)
if TYPE_CHECKING:
- assert isinstance(dag_run, DagRun)
assert ti.task
task = ti.task
+ session = settings.Session()
dag = TaskInstance.ensure_dag(ti, session=session)
valid_task_ids = set(dag.task_ids)
@@ -212,6 +209,7 @@ class SkipMixin(LoggingMixin):
for branch_task_id in list(branch_task_id_set):
branch_task_id_set.update(dag.get_task(branch_task_id).get_flat_relative_ids(upstream=False))
+ dag_run = ti.get_dagrun(session=session)
skip_tasks = [
(t.task_id, downstream_ti.map_index)
for t in downstream_tasks
diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index d61331dd620..9a3a85b8253 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -1896,7 +1896,6 @@ class TaskInstance(Base, LoggingMixin):
task=runtime_ti.task, # type: ignore[arg-type]
map_index=runtime_ti.map_index,
)
- ti.refresh_from_db()
if TYPE_CHECKING:
assert ti
diff --git a/tests/models/test_skipmixin.py b/tests/models/test_skipmixin.py
index 4c2e23e0ffd..71075209302 100644
--- a/tests/models/test_skipmixin.py
+++ b/tests/models/test_skipmixin.py
@@ -92,7 +92,7 @@ class TestSkipMixin:
],
ids=["list-of-task-ids", "tuple-of-task-ids", "str-task-id", "None",
"empty-list"],
)
- def test_skip_all_except(self, dag_maker, branch_task_ids,
expected_states):
+ def test_skip_all_except(self, dag_maker, branch_task_ids,
expected_states, session):
with dag_maker(
"dag_test_skip_all_except",
serialized=True,
@@ -110,6 +110,8 @@ class TestSkipMixin:
SkipMixin().skip_all_except(ti=ti1, branch_task_ids=branch_task_ids)
+ session.expire_all()
+
def get_state(ti):
ti.refresh_from_db()
return ti.state