This is an automated email from the ASF dual-hosted git repository.

onikolas pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new a74b5f0694 ECS Executor: Set tasks to RUNNING state once active (#39212)
a74b5f0694 is described below

commit a74b5f069481e1a2339cfd95e137619b16390906
Author: Niko Oliveira <[email protected]>
AuthorDate: Mon May 6 10:29:40 2024 -0700

    ECS Executor: Set tasks to RUNNING state once active (#39212)
    
    Tasks were previously put into the QUEUED state once they became active
    in the ECS executor. This was done to store executor state for task
    adoption, but it had the side effect of removing them from the set of
    running task instances (which has other knock-on effects). Instead, move
    tasks to the RUNNING state and do not remove them from the set of running
    tasks.
    
    * Update change_state usage in debug and celery executor
    
    - DebugExecutor: it was overriding the change_state method from the base
    executor without changing any behaviour, so use the base executor
    implementation instead.
    - CeleryExecutor: plumb through the new remove_running parameter so that
    the signature matches the base executor's.
    
    * Call running_state in a try/except for backward compatibility
---
 airflow/executors/base_executor.py                 | 25 +++++++---
 airflow/executors/debug_executor.py                |  5 --
 airflow/jobs/scheduler_job_runner.py               |  9 +++-
 .../amazon/aws/executors/ecs/ecs_executor.py       |  7 ++-
 .../providers/celery/executors/celery_executor.py  | 10 +++-
 tests/executors/test_base_executor.py              | 53 +++++++++++++++++++++-
 .../amazon/aws/executors/ecs/test_ecs_executor.py  |  6 ++-
 7 files changed, 97 insertions(+), 18 deletions(-)

diff --git a/airflow/executors/base_executor.py 
b/airflow/executors/base_executor.py
index a091e0c3f9..d2921ccf67 100644
--- a/airflow/executors/base_executor.py
+++ b/airflow/executors/base_executor.py
@@ -303,19 +303,23 @@ class BaseExecutor(LoggingMixin):
             self.execute_async(key=key, command=command, queue=queue, 
executor_config=executor_config)
             self.running.add(key)
 
-    def change_state(self, key: TaskInstanceKey, state: TaskInstanceState, 
info=None) -> None:
+    def change_state(
+        self, key: TaskInstanceKey, state: TaskInstanceState, info=None, 
remove_running=True
+    ) -> None:
         """
         Change state of the task.
 
-        :param info: Executor information for the task instance
         :param key: Unique key for the task instance
         :param state: State to set for the task.
+        :param info: Executor information for the task instance
+        :param remove_running: Whether or not to remove the TI key from 
running set
         """
         self.log.debug("Changing state: %s", key)
-        try:
-            self.running.remove(key)
-        except KeyError:
-            self.log.debug("Could not find key: %s", key)
+        if remove_running:
+            try:
+                self.running.remove(key)
+            except KeyError:
+                self.log.debug("Could not find key: %s", key)
         self.event_buffer[key] = state, info
 
     def fail(self, key: TaskInstanceKey, info=None) -> None:
@@ -345,6 +349,15 @@ class BaseExecutor(LoggingMixin):
         """
         self.change_state(key, TaskInstanceState.QUEUED, info)
 
+    def running_state(self, key: TaskInstanceKey, info=None) -> None:
+        """
+        Set running state for the event.
+
+        :param info: Executor information for the task instance
+        :param key: Unique key for the task instance
+        """
+        self.change_state(key, TaskInstanceState.RUNNING, info, 
remove_running=False)
+
     def get_event_buffer(self, dag_ids=None) -> dict[TaskInstanceKey, 
EventBufferValueType]:
         """
         Return and flush the event buffer.
diff --git a/airflow/executors/debug_executor.py 
b/airflow/executors/debug_executor.py
index a315ee31f9..80fb673cab 100644
--- a/airflow/executors/debug_executor.py
+++ b/airflow/executors/debug_executor.py
@@ -155,8 +155,3 @@ class DebugExecutor(BaseExecutor):
 
     def terminate(self) -> None:
         self._terminated.set()
-
-    def change_state(self, key: TaskInstanceKey, state: TaskInstanceState, 
info=None) -> None:
-        self.log.debug("Popping %s from executor task queue.", key)
-        self.running.remove(key)
-        self.event_buffer[key] = state, info
diff --git a/airflow/jobs/scheduler_job_runner.py 
b/airflow/jobs/scheduler_job_runner.py
index 631de5692e..49a065b5f5 100644
--- a/airflow/jobs/scheduler_job_runner.py
+++ b/airflow/jobs/scheduler_job_runner.py
@@ -692,7 +692,12 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
             ti_primary_key_to_try_number_map[ti_key.primary] = 
ti_key.try_number
 
             self.log.info("Received executor event with state %s for task 
instance %s", state, ti_key)
-            if state in (TaskInstanceState.FAILED, TaskInstanceState.SUCCESS, 
TaskInstanceState.QUEUED):
+            if state in (
+                TaskInstanceState.FAILED,
+                TaskInstanceState.SUCCESS,
+                TaskInstanceState.QUEUED,
+                TaskInstanceState.RUNNING,
+            ):
                 tis_with_right_state.append(ti_key)
 
         # Return if no finished tasks
@@ -711,7 +716,7 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
             buffer_key = ti.key.with_try_number(try_number)
             state, info = event_buffer.pop(buffer_key)
 
-            if state == TaskInstanceState.QUEUED:
+            if state in (TaskInstanceState.QUEUED, TaskInstanceState.RUNNING):
                 ti.external_executor_id = info
                 self.log.info("Setting external_id for %s to %s", ti, info)
                 continue
diff --git a/airflow/providers/amazon/aws/executors/ecs/ecs_executor.py 
b/airflow/providers/amazon/aws/executors/ecs/ecs_executor.py
index 6730e57168..c5e7e3d6b4 100644
--- a/airflow/providers/amazon/aws/executors/ecs/ecs_executor.py
+++ b/airflow/providers/amazon/aws/executors/ecs/ecs_executor.py
@@ -400,7 +400,12 @@ class AwsEcsExecutor(BaseExecutor):
             else:
                 task = run_task_response["tasks"][0]
                 self.active_workers.add_task(task, task_key, queue, cmd, 
exec_config, attempt_number)
-                self.queued(task_key, task.task_arn)
+                try:
+                    self.running_state(task_key, task.task_arn)
+                except AttributeError:
+                    # running_state is newly added, and only needed to support 
task adoption (an optional
+                    # executor feature).
+                    pass
         if failure_reasons:
             self.log.error(
                 "Pending ECS tasks failed to launch for the following reasons: 
%s. Retrying later.",
diff --git a/airflow/providers/celery/executors/celery_executor.py 
b/airflow/providers/celery/executors/celery_executor.py
index 0b4293cde7..1d4342f294 100644
--- a/airflow/providers/celery/executors/celery_executor.py
+++ b/airflow/providers/celery/executors/celery_executor.py
@@ -368,8 +368,14 @@ class CeleryExecutor(BaseExecutor):
             if state:
                 self.update_task_state(key, state, info)
 
-    def change_state(self, key: TaskInstanceKey, state: TaskInstanceState, 
info=None) -> None:
-        super().change_state(key, state, info)
+    def change_state(
+        self, key: TaskInstanceKey, state: TaskInstanceState, info=None, 
remove_running=True
+    ) -> None:
+        try:
+            super().change_state(key, state, info, 
remove_running=remove_running)
+        except AttributeError:
+            # Earlier versions of the BaseExecutor don't accept the 
remove_running parameter for this method
+            super().change_state(key, state, info)
         self.tasks.pop(key, None)
 
     def update_task_state(self, key: TaskInstanceKey, state: str, info: Any) 
-> None:
diff --git a/tests/executors/test_base_executor.py 
b/tests/executors/test_base_executor.py
index 432ed867ac..0e75751faf 100644
--- a/tests/executors/test_base_executor.py
+++ b/tests/executors/test_base_executor.py
@@ -33,7 +33,7 @@ from airflow.executors.base_executor import BaseExecutor, 
RunningRetryAttemptTyp
 from airflow.models.baseoperator import BaseOperator
 from airflow.models.taskinstance import TaskInstance, TaskInstanceKey
 from airflow.utils import timezone
-from airflow.utils.state import State
+from airflow.utils.state import State, TaskInstanceState
 
 
 def test_supports_sentry():
@@ -363,3 +363,54 @@ def test_running_retry_attempt_type(loop_duration, 
total_tries):
         assert a.elapsed > min_seconds_for_test
     assert a.total_tries == total_tries
     assert a.tries_after_min == 1
+
+
+def test_state_fail():
+    executor = BaseExecutor()
+    key = TaskInstanceKey("my_dag1", "my_task1", timezone.utcnow(), 1)
+    executor.running.add(key)
+    info = "info"
+    executor.fail(key, info=info)
+    assert not executor.running
+    assert executor.event_buffer[key] == (TaskInstanceState.FAILED, info)
+
+
+def test_state_success():
+    executor = BaseExecutor()
+    key = TaskInstanceKey("my_dag1", "my_task1", timezone.utcnow(), 1)
+    executor.running.add(key)
+    info = "info"
+    executor.success(key, info=info)
+    assert not executor.running
+    assert executor.event_buffer[key] == (TaskInstanceState.SUCCESS, info)
+
+
+def test_state_queued():
+    executor = BaseExecutor()
+    key = TaskInstanceKey("my_dag1", "my_task1", timezone.utcnow(), 1)
+    executor.running.add(key)
+    info = "info"
+    executor.queued(key, info=info)
+    assert not executor.running
+    assert executor.event_buffer[key] == (TaskInstanceState.QUEUED, info)
+
+
+def test_state_generic():
+    executor = BaseExecutor()
+    key = TaskInstanceKey("my_dag1", "my_task1", timezone.utcnow(), 1)
+    executor.running.add(key)
+    info = "info"
+    executor.queued(key, info=info)
+    assert not executor.running
+    assert executor.event_buffer[key] == (TaskInstanceState.QUEUED, info)
+
+
+def test_state_running():
+    executor = BaseExecutor()
+    key = TaskInstanceKey("my_dag1", "my_task1", timezone.utcnow(), 1)
+    executor.running.add(key)
+    info = "info"
+    executor.running_state(key, info=info)
+    # Running state should not remove a command as running
+    assert executor.running
+    assert executor.event_buffer[key] == (TaskInstanceState.RUNNING, info)
diff --git a/tests/providers/amazon/aws/executors/ecs/test_ecs_executor.py 
b/tests/providers/amazon/aws/executors/ecs/test_ecs_executor.py
index fd7bf67726..524360dbac 100644
--- a/tests/providers/amazon/aws/executors/ecs/test_ecs_executor.py
+++ b/tests/providers/amazon/aws/executors/ecs/test_ecs_executor.py
@@ -367,7 +367,8 @@ class TestEcsExecutorTask:
 class TestAwsEcsExecutor:
     """Tests the AWS ECS Executor."""
 
-    def test_execute(self, mock_airflow_key, mock_executor):
+    
@mock.patch("airflow.providers.amazon.aws.executors.ecs.ecs_executor.AwsEcsExecutor.change_state")
+    def test_execute(self, change_state_mock, mock_airflow_key, mock_executor):
         """Test execution from end-to-end."""
         airflow_key = mock_airflow_key()
 
@@ -393,6 +394,9 @@ class TestAwsEcsExecutor:
         # Task is stored in active worker.
         assert 1 == len(mock_executor.active_workers)
         assert ARN1 in 
mock_executor.active_workers.task_by_key(airflow_key).task_arn
+        change_state_mock.assert_called_once_with(
+            airflow_key, TaskInstanceState.RUNNING, ARN1, remove_running=False
+        )
 
     @mock.patch.object(ecs_executor, "calculate_next_attempt_delay", 
return_value=dt.timedelta(seconds=0))
     def test_success_execute_api_exception(self, mock_backoff, mock_executor):

Reply via email to