This is an automated email from the ASF dual-hosted git repository. kaxilnaik pushed a commit to branch v1-10-test in repository https://gitbox.apache.org/repos/asf/airflow.git
commit ad70ce214fc8d252f3f7c9c10c06043372481628
Author: yuqian90 <[email protected]>
AuthorDate: Wed Sep 23 04:08:12 2020 +0800

    SkipMixin: Add missing session.commit() and test (#10421)

    (cherry picked from commit 423a382678deac5cb161d38e9266ce47b5666344)
---
 airflow/models/skipmixin.py    |  3 +++
 tests/models/test_skipmixin.py | 28 ++++++++++++++++++++++++++++
 2 files changed, 31 insertions(+)

diff --git a/airflow/models/skipmixin.py b/airflow/models/skipmixin.py
index f45cac6..a65d484 100644
--- a/airflow/models/skipmixin.py
+++ b/airflow/models/skipmixin.py
@@ -160,6 +160,9 @@ class SkipMixin(LoggingMixin):
                 self._set_state_to_skipped(
                     dag_run, ti.execution_date, skip_tasks, session=session
                 )
+                # For some reason, session.commit() needs to happen before xcom_push.
+                # Otherwise the session is not committed.
+                session.commit()
                 ti.xcom_push(
                     key=XCOM_SKIPMIXIN_KEY, value={XCOM_SKIPMIXIN_FOLLOWED: follow_task_ids}
                 )
diff --git a/tests/models/test_skipmixin.py b/tests/models/test_skipmixin.py
index df21d4a..92a072b 100644
--- a/tests/models/test_skipmixin.py
+++ b/tests/models/test_skipmixin.py
@@ -94,3 +94,31 @@ class TestSkipMixin(unittest.TestCase):
         SkipMixin().skip(dag_run=None, execution_date=None, tasks=[], session=session)
         self.assertFalse(session.query.called)
         self.assertFalse(session.commit.called)
+
+    def test_skip_all_except(self):
+        dag = DAG(
+            'dag_test_skip_all_except',
+            start_date=DEFAULT_DATE,
+        )
+        with dag:
+            task1 = DummyOperator(task_id='task1')
+            task2 = DummyOperator(task_id='task2')
+            task3 = DummyOperator(task_id='task3')
+
+            task1 >> [task2, task3]
+
+        ti1 = TI(task1, execution_date=DEFAULT_DATE)
+        ti2 = TI(task2, execution_date=DEFAULT_DATE)
+        ti3 = TI(task3, execution_date=DEFAULT_DATE)
+
+        SkipMixin().skip_all_except(
+            ti=ti1,
+            branch_task_ids=['task2']
+        )
+
+        def get_state(ti):
+            ti.refresh_from_db()
+            return ti.state
+
+        assert get_state(ti2) == State.NONE
+        assert get_state(ti3) == State.SKIPPED
