This is an automated email from the ASF dual-hosted git repository.

dstandish pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new ea3be1a602 Use time not tries for queued & running re-checks. (#28586)
ea3be1a602 is described below

commit ea3be1a602b3e109169c6e90e555a418e2649f9a
Author: Daniel Standish <[email protected]>
AuthorDate: Thu Dec 29 14:25:58 2022 -0800

    Use time not tries for queued & running re-checks. (#28586)
    
    Co-authored-by: Jed Cunningham 
<[email protected]>
---
 airflow/executors/base_executor.py    | 66 ++++++++++++++++++++-----
 tests/executors/test_base_executor.py | 93 ++++++++++++++++++++++++++++-------
 2 files changed, 131 insertions(+), 28 deletions(-)

diff --git a/airflow/executors/base_executor.py 
b/airflow/executors/base_executor.py
index 83410c90e2..3ea5eeba25 100644
--- a/airflow/executors/base_executor.py
+++ b/airflow/executors/base_executor.py
@@ -17,10 +17,15 @@
 """Base executor - this is the base class for all the implemented executors."""
 from __future__ import annotations
 
+import logging
 import sys
 import warnings
-from collections import OrderedDict
-from typing import Any, Counter, List, Optional, Sequence, Tuple
+from collections import OrderedDict, defaultdict
+from dataclasses import dataclass, field
+from datetime import datetime
+from typing import Any, List, Optional, Sequence, Tuple
+
+import pendulum
 
 from airflow.callbacks.base_callback_sink import BaseCallbackSink
 from airflow.callbacks.callback_requests import CallbackRequest
@@ -33,8 +38,6 @@ from airflow.utils.state import State
 
 PARALLELISM: int = conf.getint("core", "PARALLELISM")
 
-QUEUEING_ATTEMPTS = 5
-
 # Command to execute - list of strings
 # the first element is always "airflow".
 # It should be result of TaskInstance.generate_command method.q
@@ -54,6 +57,44 @@ EventBufferValueType = Tuple[Optional[str], Any]
 # Task tuple to send to be executed
 TaskTuple = Tuple[TaskInstanceKey, CommandType, Optional[str], Optional[Any]]
 
+log = logging.getLogger(__name__)
+
+
+@dataclass
+class RunningRetryAttemptType:
+    """
+    Tracks attempts to re-queue a task that still appears to be running.
+
+    We don't want to slow down the loop, so we don't block, but we allow it to be
+    re-checked for at least MIN_SECONDS seconds.
+    """
+
+    MIN_SECONDS = 10
+    total_tries: int = field(default=0, init=False)
+    tries_after_min: int = field(default=0, init=False)
+    first_attempt_time: datetime = field(default_factory=lambda: 
pendulum.now("UTC"), init=False)
+
+    @property
+    def elapsed(self):
+        """Seconds since first attempt"""
+        return (pendulum.now("UTC") - self.first_attempt_time).total_seconds()
+
+    def can_try_again(self):
+        """
+        If a try has already been made more than MIN_SECONDS after the first attempt,
+        then return False.  Otherwise, return True.
+        """
+        if self.tries_after_min > 0:
+            return False
+
+        self.total_tries += 1
+
+        elapsed = self.elapsed
+        if elapsed > self.MIN_SECONDS:
+            self.tries_after_min += 1
+        log.debug("elapsed=%s tries=%s", elapsed, self.total_tries)
+        return True
+
 
 class BaseExecutor(LoggingMixin):
     """
@@ -77,7 +118,7 @@ class BaseExecutor(LoggingMixin):
         self.queued_tasks: OrderedDict[TaskInstanceKey, 
QueuedTaskInstanceType] = OrderedDict()
         self.running: set[TaskInstanceKey] = set()
         self.event_buffer: dict[TaskInstanceKey, EventBufferValueType] = {}
-        self.attempts: Counter[TaskInstanceKey] = Counter()
+        self.attempts: dict[TaskInstanceKey, RunningRetryAttemptType] = 
defaultdict(RunningRetryAttemptType)
 
     def __repr__(self):
         return f"{self.__class__.__name__}(parallelism={self.parallelism})"
@@ -212,16 +253,19 @@ class BaseExecutor(LoggingMixin):
             # removed from the running set in the meantime.
             if key in self.running:
                 attempt = self.attempts[key]
-                if attempt < QUEUEING_ATTEMPTS - 1:
-                    self.attempts[key] = attempt + 1
-                    self.log.info("task %s is still running", key)
+                if attempt.can_try_again():
+                    # if it hasn't been much time since first check, let it be 
checked again next time
+                    self.log.info("queued but still running; attempt=%s 
task=%s", attempt.total_tries, key)
                     continue
-
-                # We give up and remove the task from the queue.
-                self.log.error("could not queue task %s (still running after 
%d attempts)", key, attempt)
+                # Otherwise, we give up and remove the task from the queue.
+                self.log.error(
+                    "could not queue task %s (still running after %d 
attempts)", key, attempt.total_tries
+                )
                 del self.attempts[key]
                 del self.queued_tasks[key]
             else:
+                if key in self.attempts:
+                    del self.attempts[key]
                 task_tuples.append((key, command, queue, ti.executor_config))
 
         if task_tuples:
diff --git a/tests/executors/test_base_executor.py 
b/tests/executors/test_base_executor.py
index 4d7553710c..c88bd333dc 100644
--- a/tests/executors/test_base_executor.py
+++ b/tests/executors/test_base_executor.py
@@ -20,9 +20,12 @@ from __future__ import annotations
 from datetime import timedelta
 from unittest import mock
 
+import pendulum
+import pytest
+import time_machine
 from pytest import mark
 
-from airflow.executors.base_executor import QUEUEING_ATTEMPTS, BaseExecutor
+from airflow.executors.base_executor import BaseExecutor, 
RunningRetryAttemptType
 from airflow.models.baseoperator import BaseOperator
 from airflow.models.taskinstance import TaskInstanceKey
 from airflow.utils import timezone
@@ -104,31 +107,63 @@ def test_trigger_queued_tasks(dag_maker, open_slots):
     assert executor.execute_async.call_count == open_slots
 
 
-@mark.parametrize("change_state_attempt", range(QUEUEING_ATTEMPTS + 2))
-def test_trigger_running_tasks(dag_maker, change_state_attempt):
+@pytest.mark.parametrize(
+    "can_try_num, change_state_num, second_exec",
+    [
+        (2, 3, False),
+        (3, 3, True),
+        (4, 3, True),
+    ],
+)
+@mock.patch("airflow.executors.base_executor.RunningRetryAttemptType.can_try_again")
+def test_trigger_running_tasks(can_try_mock, dag_maker, can_try_num, 
change_state_num, second_exec):
+    can_try_mock.side_effect = [True for _ in range(can_try_num)] + [False]
     executor, dagrun = setup_trigger_tasks(dag_maker)
     open_slots = 100
     executor.trigger_tasks(open_slots)
     expected_calls = len(dagrun.task_instances)  # initially `execute_async` 
called for each task
-    assert len(executor.execute_async.mock_calls) == expected_calls
+    assert executor.execute_async.call_count == expected_calls
 
     # All the tasks are now "running", so while we enqueue them again here,
     # they won't be executed again until the executor has been notified of a 
state change.
-    enqueue_tasks(executor, dagrun)
+    ti = dagrun.task_instances[0]
+    assert ti.key in executor.running
+    assert ti.key not in executor.queued_tasks
+    executor.queue_command(ti, ["airflow"])
+
+    # this is the problem we're dealing with: ti.key both queued and running
+    assert ti.key in executor.queued_tasks and ti.key in executor.running
+    assert len(executor.attempts) == 0
+    executor.trigger_tasks(open_slots)
 
-    for attempt in range(QUEUEING_ATTEMPTS + 2):
-        # On the configured attempt, we notify the executor that the task has 
succeeded.
-        if attempt == change_state_attempt:
-            executor.change_state(dagrun.task_instances[0].key, State.SUCCESS)
-            # If we have not exceeded QUEUEING_ATTEMPTS, we should expect an 
additional "execute" call
-            if attempt < QUEUEING_ATTEMPTS:
-                expected_calls += 1
+    # first trigger call after queueing again creates an attempt object
+    assert len(executor.attempts) == 1
+    assert ti.key in executor.attempts
+
+    for attempt in range(2, change_state_num + 2):
         executor.trigger_tasks(open_slots)
-        assert len(executor.execute_async.mock_calls) == expected_calls
-    if change_state_attempt < QUEUEING_ATTEMPTS:
-        assert len(executor.execute_async.mock_calls) == 
len(dagrun.task_instances) + 1
-    else:
-        assert len(executor.execute_async.mock_calls) == 
len(dagrun.task_instances)
+        if attempt <= min(can_try_num, change_state_num):
+            assert ti.key in executor.queued_tasks and ti.key in 
executor.running
+        # On the configured attempt, we notify the executor that the task has 
succeeded.
+        if attempt == change_state_num:
+            executor.change_state(ti.key, State.SUCCESS)
+            assert ti.key not in executor.running
+    # retry was ok when state changed, ti.key will be in running (for the 
second time)
+    if can_try_num >= change_state_num:
+        assert ti.key in executor.running
+    else:  # otherwise, it won't be
+        assert ti.key not in executor.running
+    # either way, ti.key not in queued -- it was either removed because never 
left running
+    # or it was moved out when run 2nd time
+    assert ti.key not in executor.queued_tasks
+    assert not executor.attempts
+
+    # we expect one more "execute_async" if TI was marked successful
+    # this would move it out of running set and free the queued TI to be 
executed again
+    if second_exec is True:
+        expected_calls += 1
+
+    assert executor.execute_async.call_count == expected_calls
 
 
 def test_validate_airflow_tasks_run_command(dag_maker):
@@ -136,3 +171,27 @@ def test_validate_airflow_tasks_run_command(dag_maker):
     tis = dagrun.task_instances
     dag_id, task_id = 
BaseExecutor.validate_airflow_tasks_run_command(tis[0].command_as_list())
     assert dag_id == dagrun.dag_id and task_id == tis[0].task_id
+
+
+@pytest.mark.parametrize("loop_duration, total_tries", [(0.5, 12), (1.0, 7), (1.7, 4), (10, 2)])
+def test_running_retry_attempt_type(loop_duration, total_tries):
+    """
+    Verify can_try_again returns True until at least 5 seconds have passed.
+
+    For faster loops, total tries will be higher.  If loops take longer than 5 seconds, it should still
+    end up trying 2 times.
+    """
+    min_seconds_for_test = 5
+
+    with time_machine.travel(pendulum.now("UTC"), tick=False) as t:
+
+        # set MIN_SECONDS so tests don't break if the value is changed
+        RunningRetryAttemptType.MIN_SECONDS = min_seconds_for_test
+        a = RunningRetryAttemptType()
+        while True:
+            if not a.can_try_again():
+                break
+            t.shift(loop_duration)
+        assert a.elapsed > min_seconds_for_test
+    assert a.total_tries == total_tries
+    assert a.tries_after_min == 1

Reply via email to