This is an automated email from the ASF dual-hosted git repository.

kaxilnaik pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new a969694c2a1 Get `skipmixin` working temporarily (#45824)
a969694c2a1 is described below

commit a969694c2a17c07d7b7d91a884391f6b818117e4
Author: Kaxil Naik <[email protected]>
AuthorDate: Tue Jan 21 17:52:36 2025 +0530

    Get `skipmixin` working temporarily (#45824)
---
 airflow/models/skipmixin.py    | 8 +++-----
 airflow/models/taskinstance.py | 1 -
 tests/models/test_skipmixin.py | 4 +++-
 3 files changed, 6 insertions(+), 7 deletions(-)

diff --git a/airflow/models/skipmixin.py b/airflow/models/skipmixin.py
index 5e3c47ad3a1..3b7d21d7b38 100644
--- a/airflow/models/skipmixin.py
+++ b/airflow/models/skipmixin.py
@@ -23,6 +23,7 @@ from typing import TYPE_CHECKING
 
 from sqlalchemy import tuple_, update
 
+from airflow import settings
 from airflow.exceptions import AirflowException
 from airflow.models.taskinstance import TaskInstance
 from airflow.utils import timezone
@@ -33,7 +34,6 @@ from airflow.utils.state import TaskInstanceState
 if TYPE_CHECKING:
     from sqlalchemy.orm import Session
 
-    from airflow.models.dagrun import DagRun
     from airflow.models.operator import Operator
     from airflow.sdk.definitions._internal.node import DAGNode
 
@@ -136,12 +136,10 @@ class SkipMixin(LoggingMixin):
                 session=session,
             )
 
-    @provide_session
     def skip_all_except(
         self,
         ti: TaskInstance,
         branch_task_ids: None | str | Iterable[str],
-        session: Session = NEW_SESSION,
     ):
         """
         Implement the logic for a branching operator.
@@ -178,12 +176,11 @@ class SkipMixin(LoggingMixin):
 
         log.info("Following branch %s", branch_task_id_set)
 
-        dag_run = ti.get_dagrun(session=session)
         if TYPE_CHECKING:
-            assert isinstance(dag_run, DagRun)
             assert ti.task
 
         task = ti.task
+        session = settings.Session()
         dag = TaskInstance.ensure_dag(ti, session=session)
 
         valid_task_ids = set(dag.task_ids)
@@ -212,6 +209,7 @@ class SkipMixin(LoggingMixin):
             for branch_task_id in list(branch_task_id_set):
                 branch_task_id_set.update(dag.get_task(branch_task_id).get_flat_relative_ids(upstream=False))
 
+            dag_run = ti.get_dagrun(session=session)
             skip_tasks = [
                 (t.task_id, downstream_ti.map_index)
                 for t in downstream_tasks
diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index d61331dd620..9a3a85b8253 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -1896,7 +1896,6 @@ class TaskInstance(Base, LoggingMixin):
             task=runtime_ti.task,  # type: ignore[arg-type]
             map_index=runtime_ti.map_index,
         )
-        ti.refresh_from_db()
 
         if TYPE_CHECKING:
             assert ti
diff --git a/tests/models/test_skipmixin.py b/tests/models/test_skipmixin.py
index 4c2e23e0ffd..71075209302 100644
--- a/tests/models/test_skipmixin.py
+++ b/tests/models/test_skipmixin.py
@@ -92,7 +92,7 @@ class TestSkipMixin:
         ],
         ids=["list-of-task-ids", "tuple-of-task-ids", "str-task-id", "None", "empty-list"],
     )
-    def test_skip_all_except(self, dag_maker, branch_task_ids, expected_states):
+    def test_skip_all_except(self, dag_maker, branch_task_ids, expected_states, session):
         with dag_maker(
             "dag_test_skip_all_except",
             serialized=True,
@@ -110,6 +110,8 @@ class TestSkipMixin:
 
         SkipMixin().skip_all_except(ti=ti1, branch_task_ids=branch_task_ids)
 
+        session.expire_all()
+
         def get_state(ti):
             ti.refresh_from_db()
             return ti.state

Reply via email to