This is an automated email from the ASF dual-hosted git repository.

jedcunningham pushed a commit to branch v2-2-test
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit 94a0a0e8ce4d2b54cd6a08301684e299ca3c36cb
Author: Rocco Pascale <[email protected]>
AuthorDate: Sun Oct 24 11:58:35 2021 -0400

    Clear ti.next_method and ti.next_kwargs on task finish (#19183)
    
    (cherry picked from commit 8a0d6c2af86e6b7c3e73ba2ee8b16c1f18ad3771)
---
 airflow/models/taskinstance.py    | 19 ++++++++++---
 tests/models/test_taskinstance.py | 57 ++++++++++++++++++++++++++++++++-------
 2 files changed, 62 insertions(+), 14 deletions(-)

diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index a9f30dc..55e11fb 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -1280,6 +1280,14 @@ class TaskInstance(Base, LoggingMixin):
             self._date_or_empty('end_date'),
         )
 
+    # Ensure we unset next_method and next_kwargs to ensure that any
+    # retries don't re-use them.
+    def clear_next_method_args(self):
+        self.log.debug("Clearing next_method and next_kwargs.")
+
+        self.next_method = None
+        self.next_kwargs = None
+
     @provide_session
     @Sentry.enrich_errors
     def _run_raw_task(
@@ -1369,6 +1377,9 @@ class TaskInstance(Base, LoggingMixin):
             # or dagrun timed out and task is marked as skipped
             # current behavior doesn't hit the callbacks
             if self.state in State.finished:
+                self.clear_next_method_args()
+                session.merge(self)
+                session.commit()
                 return
             else:
                 self.handle_failure(e, test_mode, error_file=error_file, session=session)
@@ -1382,6 +1393,7 @@ class TaskInstance(Base, LoggingMixin):
             Stats.incr(f'ti.finish.{self.task.dag_id}.{self.task.task_id}.{self.state}')
 
         # Recording SKIPPED or SUCCESS
+        self.clear_next_method_args()
         self.end_date = timezone.utcnow()
         self._log_state()
         self.set_duration()
@@ -1667,6 +1679,8 @@ class TaskInstance(Base, LoggingMixin):
         # to same log file.
         self._try_number -= 1
 
+        self.clear_next_method_args()
+
         session.merge(self)
         session.commit()
         self.log.info('Rescheduling task, marking task as UP_FOR_RESCHEDULE')
@@ -1708,10 +1722,7 @@ class TaskInstance(Base, LoggingMixin):
             dag_run = self.get_dagrun(session=session)  # self.dag_run not populated by refresh_from_db
             session.add(TaskFail(task, dag_run.execution_date, self.start_date, self.end_date))
 
-        # Ensure we unset next_method and next_kwargs to ensure that any
-        # retries don't re-use them.
-        self.next_method = None
-        self.next_kwargs = None
+        self.clear_next_method_args()
 
         # Set state correctly and figure out how to log it and decide whether
         # to email
diff --git a/tests/models/test_taskinstance.py b/tests/models/test_taskinstance.py
index bf58aac..8b758f2 100644
--- a/tests/models/test_taskinstance.py
+++ b/tests/models/test_taskinstance.py
@@ -33,6 +33,7 @@ from airflow import models, settings
 from airflow.exceptions import (
     AirflowException,
     AirflowFailException,
+    AirflowRescheduleException,
     AirflowSensorTimeout,
     AirflowSkipException,
 )
@@ -481,18 +482,51 @@ class TestTaskInstance:
         ti.refresh_from_db()
         ti.state == state
 
-    def test_task_retry_wipes_next_fields(self, session, dag_maker):
+    @pytest.mark.parametrize(
+        "state",
+        [State.FAILED, State.SKIPPED, State.SUCCESS, State.UP_FOR_RESCHEDULE, State.UP_FOR_RETRY],
+    )
+    def test_task_wipes_next_fields(self, session, state, dag_maker):
         """
         Test that ensures that tasks wipe their next_method and next_kwargs
-        fields when they are queued for retry after a failure.
+        when they go into a state of FAILED, SKIPPED, SUCCESS, UP_FOR_RESCHEDULE, or UP_FOR_RETRY.
         """
 
-        with dag_maker('test_mark_failure_2'):
-            task = BashOperator(
-                task_id='test_retry_handling_op',
-                bash_command='exit 1',
-                retries=1,
-                retry_delay=datetime.timedelta(seconds=2),
+        def failure():
+            raise AirflowException
+
+        def skip():
+            raise AirflowSkipException
+
+        def success():
+            return None
+
+        def reschedule():
+            reschedule_date = timezone.utcnow()
+            raise AirflowRescheduleException(reschedule_date)
+
+        _retries = 0
+        _retry_delay = datetime.timedelta(seconds=0)
+
+        if state == State.FAILED:
+            _python_callable = failure
+        elif state == State.SKIPPED:
+            _python_callable = skip
+        elif state == State.SUCCESS:
+            _python_callable = success
+        elif state == State.UP_FOR_RESCHEDULE:
+            _python_callable = reschedule
+        elif state in [State.FAILED, State.UP_FOR_RETRY]:
+            _python_callable = failure
+            _retries = 1
+            _retry_delay = datetime.timedelta(seconds=2)
+
+        with dag_maker("test_deferred_method_clear"):
+            task = PythonOperator(
+                task_id="test_deferred_method_clear_task",
+                python_callable=_python_callable,
+                retries=_retries,
+                retry_delay=_retry_delay,
             )
 
         dr = dag_maker.create_dagrun()
@@ -503,13 +537,16 @@ class TestTaskInstance:
         session.commit()
 
         ti.task = task
-        with pytest.raises(AirflowException):
+        if state in [State.FAILED, State.UP_FOR_RETRY]:
+            with pytest.raises(AirflowException):
+                ti.run()
+        elif state in [State.SKIPPED, State.SUCCESS, State.UP_FOR_RESCHEDULE]:
             ti.run()
         ti.refresh_from_db()
 
         assert ti.next_method is None
         assert ti.next_kwargs is None
-        assert ti.state == State.UP_FOR_RETRY
+        assert ti.state == state
 
     @freeze_time('2021-09-19 04:56:35', as_kwarg='frozen_time')
     def test_retry_delay(self, dag_maker, frozen_time=None):

Reply via email to