ashb closed pull request #3994: [AIRFLOW-3136] Add retry_number to TaskInstance Key property to avoid race condition
URL: https://github.com/apache/incubator-airflow/pull/3994
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied
below (as it won't show otherwise due to GitHub magic):

diff --git a/airflow/contrib/executors/kubernetes_executor.py 
b/airflow/contrib/executors/kubernetes_executor.py
index de1f9f4235..cf58169345 100644
--- a/airflow/contrib/executors/kubernetes_executor.py
+++ b/airflow/contrib/executors/kubernetes_executor.py
@@ -332,7 +332,7 @@ def run_next(self, next_job):
         """
         self.log.info('Kubernetes job is %s', str(next_job))
         key, command, kube_executor_config = next_job
-        dag_id, task_id, execution_date = key
+        dag_id, task_id, execution_date, try_number = key
         self.log.debug("Kubernetes running for command %s", command)
         self.log.debug("Kubernetes launching image %s", 
self.kube_config.kube_image)
         pod = self.worker_configuration.make_pod(
@@ -453,7 +453,8 @@ def _labels_to_key(self, labels):
         try:
             return (
                 labels['dag_id'], labels['task_id'],
-                
self._label_safe_datestring_to_datetime(labels['execution_date']))
+                
self._label_safe_datestring_to_datetime(labels['execution_date']),
+                labels['try_number'])
         except Exception as e:
             self.log.warn(
                 'Error while converting labels to key; labels: %s; exception: 
%s',
@@ -612,7 +613,7 @@ def _change_state(self, key, state, pod_id):
                 self.log.debug('Could not find key: %s', str(key))
                 pass
         self.event_buffer[key] = state
-        (dag_id, task_id, ex_time) = key
+        (dag_id, task_id, ex_time, try_number) = key
         item = self._session.query(TaskInstance).filter_by(
             dag_id=dag_id,
             task_id=task_id,
diff --git a/airflow/executors/base_executor.py 
b/airflow/executors/base_executor.py
index a989dc4408..0979ba07fe 100644
--- a/airflow/executors/base_executor.py
+++ b/airflow/executors/base_executor.py
@@ -175,7 +175,7 @@ def get_event_buffer(self, dag_ids=None):
             self.event_buffer = dict()
         else:
             for key in list(self.event_buffer.keys()):
-                dag_id, _, _ = key
+                dag_id, _, _, _ = key
                 if dag_id in dag_ids:
                     cleared_events[key] = self.event_buffer.pop(key)
 
diff --git a/airflow/jobs.py b/airflow/jobs.py
index 48e15f758d..15b9c65b82 100644
--- a/airflow/jobs.py
+++ b/airflow/jobs.py
@@ -1442,10 +1442,10 @@ def _process_executor_events(self, simple_dag_bag, 
session=None):
         TI = models.TaskInstance
         for key, state in 
list(self.executor.get_event_buffer(simple_dag_bag.dag_ids)
                                    .items()):
-            dag_id, task_id, execution_date = key
+            dag_id, task_id, execution_date, try_number = key
             self.log.info(
-                "Executor reports %s.%s execution_date=%s as %s",
-                dag_id, task_id, execution_date, state
+                "Executor reports %s.%s execution_date=%s as %s for try_number 
%s",
+                dag_id, task_id, execution_date, state, try_number
             )
             if state == State.FAILED or state == State.SUCCESS:
                 qry = session.query(TI).filter(TI.dag_id == dag_id,
@@ -1457,7 +1457,7 @@ def _process_executor_events(self, simple_dag_bag, 
session=None):
                     continue
 
                 # TODO: should we fail RUNNING as well, as we do in Backfills?
-                if ti.state == State.QUEUED:
+                if ti.try_number == try_number and ti.state == State.QUEUED:
                     msg = ("Executor reports task instance {} finished ({}) "
                            "although the task says its {}. Was the task "
                            "killed externally?".format(ti, state, ti.state))
diff --git a/airflow/models.py b/airflow/models.py
index 97a0fa92fc..0644dcdee5 100755
--- a/airflow/models.py
+++ b/airflow/models.py
@@ -1230,7 +1230,7 @@ def key(self):
         """
         Returns a tuple that identifies the task instance uniquely
         """
-        return self.dag_id, self.task_id, self.execution_date
+        return self.dag_id, self.task_id, self.execution_date, self.try_number
 
     @provide_session
     def set_state(self, state, session=None):
diff --git a/tests/executors/test_base_executor.py 
b/tests/executors/test_base_executor.py
index f640a75e01..29a953ece9 100644
--- a/tests/executors/test_base_executor.py
+++ b/tests/executors/test_base_executor.py
@@ -30,10 +30,10 @@ def test_get_event_buffer(self):
         executor = BaseExecutor()
 
         date = datetime.utcnow()
-
-        key1 = ("my_dag1", "my_task1", date)
-        key2 = ("my_dag2", "my_task1", date)
-        key3 = ("my_dag2", "my_task2", date)
+        try_number = 1
+        key1 = ("my_dag1", "my_task1", date, try_number)
+        key2 = ("my_dag2", "my_task1", date, try_number)
+        key3 = ("my_dag2", "my_task2", date, try_number)
         state = State.SUCCESS
         executor.event_buffer[key1] = state
         executor.event_buffer[key2] = state
diff --git a/tests/jobs.py b/tests/jobs.py
index bb714bd201..d9eb28c68e 100644
--- a/tests/jobs.py
+++ b/tests/jobs.py
@@ -73,7 +73,7 @@
 
 DEV_NULL = '/dev/null'
 DEFAULT_DATE = timezone.datetime(2016, 1, 1)
-
+TRY_NUMBER = 1
 # Include the words "airflow" and "dag" in the file contents,
 # tricking airflow into thinking these
 # files contain a DAG (otherwise Airflow will skip them)
@@ -2316,7 +2316,7 @@ def test_scheduler_process_task_instances(self):
         scheduler._process_task_instances(dag, queue=queue)
 
         queue.append.assert_called_with(
-            (dag.dag_id, dag_task1.task_id, DEFAULT_DATE)
+            (dag.dag_id, dag_task1.task_id, DEFAULT_DATE, TRY_NUMBER)
         )
 
     def test_scheduler_do_not_schedule_removed_task(self):
@@ -2584,7 +2584,7 @@ def 
test_scheduler_max_active_runs_respected_after_clear(self):
         scheduler._process_task_instances(dag, queue=queue)
 
         queue.append.assert_called_with(
-            (dag.dag_id, dag_task1.task_id, DEFAULT_DATE)
+            (dag.dag_id, dag_task1.task_id, DEFAULT_DATE, TRY_NUMBER)
         )
 
     @patch.object(TI, 'pool_full')
@@ -2937,13 +2937,18 @@ def run_with_error(task):
         do_schedule()
         self.assertTrue(executor.has_task(ti))
         ti.refresh_from_db()
-        self.assertEqual(ti.state, State.SCHEDULED)
+        # removing self.assertEqual(ti.state, State.SCHEDULED)
+        # as scheduler will move state from SCHEDULED to QUEUED
 
         # now the executor has cleared and it should be allowed the re-queue
         executor.queued_tasks.clear()
         do_schedule()
         ti.refresh_from_db()
         self.assertEqual(ti.state, State.QUEUED)
+        # calling below again in order to ensure with try_number 2,
+        # scheduler doesn't put task in queue
+        do_schedule()
+        self.assertEquals(1, len(executor.queued_tasks))
 
     @unittest.skipUnless("INTEGRATION" in os.environ, "Can only run end to 
end")
     def test_retry_handling_job(self):


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services

Reply via email to