This is an automated email from the ASF dual-hosted git repository.
kaxilnaik pushed a commit to branch v1-10-test
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/v1-10-test by this push:
new ff3658a SkipMixin: Add missing session.commit() and test (#10421)
ff3658a is described below
commit ff3658a4fb5a36186bebc604b2ba06a351dcf0df
Author: yuqian90 <[email protected]>
AuthorDate: Wed Sep 23 04:08:12 2020 +0800
SkipMixin: Add missing session.commit() and test (#10421)
(cherry picked from commit 423a382678deac5cb161d38e9266ce47b5666344)
---
airflow/models/skipmixin.py | 3 +++
tests/models/test_skipmixin.py | 28 ++++++++++++++++++++++++++++
2 files changed, 31 insertions(+)
diff --git a/airflow/models/skipmixin.py b/airflow/models/skipmixin.py
index f45cac6..a65d484 100644
--- a/airflow/models/skipmixin.py
+++ b/airflow/models/skipmixin.py
@@ -160,6 +160,9 @@ class SkipMixin(LoggingMixin):
self._set_state_to_skipped(
dag_run, ti.execution_date, skip_tasks, session=session
)
+ # For some reason, session.commit() needs to happen before xcom_push.
+ # Otherwise the session is not committed.
+ session.commit()
ti.xcom_push(
key=XCOM_SKIPMIXIN_KEY, value={XCOM_SKIPMIXIN_FOLLOWED:
follow_task_ids}
)
diff --git a/tests/models/test_skipmixin.py b/tests/models/test_skipmixin.py
index df21d4a..92a072b 100644
--- a/tests/models/test_skipmixin.py
+++ b/tests/models/test_skipmixin.py
@@ -94,3 +94,31 @@ class TestSkipMixin(unittest.TestCase):
SkipMixin().skip(dag_run=None, execution_date=None, tasks=[],
session=session)
self.assertFalse(session.query.called)
self.assertFalse(session.commit.called)
+
+ def test_skip_all_except(self):  # skip_all_except() should mark SKIPPED only the tasks downstream of ti that are not in branch_task_ids
+ dag = DAG(
+ 'dag_test_skip_all_except',
+ start_date=DEFAULT_DATE,
+ )
+ with dag:
+ task1 = DummyOperator(task_id='task1')
+ task2 = DummyOperator(task_id='task2')
+ task3 = DummyOperator(task_id='task3')
+
+ task1 >> [task2, task3]  # task1 fans out to task2 and task3
+
+ ti1 = TI(task1, execution_date=DEFAULT_DATE)  # NOTE(review): no DagRun is created in this method; presumably skip_all_except copes with that - confirm
+ ti2 = TI(task2, execution_date=DEFAULT_DATE)
+ ti3 = TI(task3, execution_date=DEFAULT_DATE)  # NOTE(review): rows written by this test are not cleaned up here; confirm setUp/tearDown clears them
+
+ SkipMixin().skip_all_except(  # follow only 'task2'; the other downstream task of task1 should be skipped
+ ti=ti1,
+ branch_task_ids=['task2']
+ )
+
+ def get_state(ti):  # re-read state from the DB so the assertions see what skip_all_except persisted
+ ti.refresh_from_db()
+ return ti.state
+
+ assert get_state(ti2) == State.NONE  # the followed branch is left untouched
+ assert get_state(ti3) == State.SKIPPED  # the non-followed downstream task is skipped