This is an automated email from the ASF dual-hosted git repository.

kaxilnaik pushed a commit to branch v1-10-test
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/v1-10-test by this push:
     new ff3658a  SkipMixin: Add missing session.commit() and test (#10421)
ff3658a is described below

commit ff3658a4fb5a36186bebc604b2ba06a351dcf0df
Author: yuqian90 <[email protected]>
AuthorDate: Wed Sep 23 04:08:12 2020 +0800

    SkipMixin: Add missing session.commit() and test (#10421)
    
    (cherry picked from commit 423a382678deac5cb161d38e9266ce47b5666344)
---
 airflow/models/skipmixin.py    |  3 +++
 tests/models/test_skipmixin.py | 28 ++++++++++++++++++++++++++++
 2 files changed, 31 insertions(+)

diff --git a/airflow/models/skipmixin.py b/airflow/models/skipmixin.py
index f45cac6..a65d484 100644
--- a/airflow/models/skipmixin.py
+++ b/airflow/models/skipmixin.py
@@ -160,6 +160,9 @@ class SkipMixin(LoggingMixin):
                 self._set_state_to_skipped(
                     dag_run, ti.execution_date, skip_tasks, session=session
                 )
+                # For some reason, session.commit() needs to happen before xcom_push.
+                # Otherwise the session is not committed.
+                session.commit()
                 ti.xcom_push(
                     key=XCOM_SKIPMIXIN_KEY, value={XCOM_SKIPMIXIN_FOLLOWED: follow_task_ids}
                 )
diff --git a/tests/models/test_skipmixin.py b/tests/models/test_skipmixin.py
index df21d4a..92a072b 100644
--- a/tests/models/test_skipmixin.py
+++ b/tests/models/test_skipmixin.py
@@ -94,3 +94,31 @@ class TestSkipMixin(unittest.TestCase):
         SkipMixin().skip(dag_run=None, execution_date=None, tasks=[], session=session)
         self.assertFalse(session.query.called)
         self.assertFalse(session.commit.called)
+
+    def test_skip_all_except(self):
+        dag = DAG(
+            'dag_test_skip_all_except',
+            start_date=DEFAULT_DATE,
+        )
+        with dag:
+            task1 = DummyOperator(task_id='task1')
+            task2 = DummyOperator(task_id='task2')
+            task3 = DummyOperator(task_id='task3')
+
+            task1 >> [task2, task3]
+
+        ti1 = TI(task1, execution_date=DEFAULT_DATE)
+        ti2 = TI(task2, execution_date=DEFAULT_DATE)
+        ti3 = TI(task3, execution_date=DEFAULT_DATE)
+
+        SkipMixin().skip_all_except(
+            ti=ti1,
+            branch_task_ids=['task2']
+        )
+
+        def get_state(ti):
+            ti.refresh_from_db()
+            return ti.state
+
+        assert get_state(ti2) == State.NONE
+        assert get_state(ti3) == State.SKIPPED

Reply via email to