This is an automated email from the ASF dual-hosted git repository.

kaxilnaik pushed a commit to branch v3-1-test
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit 07173625db126208bfb756fff3649d266089b541
Author: Amogh Desai <[email protected]>
AuthorDate: Tue Oct 28 15:08:30 2025 +0530

    Respect task retries for signal killed tasks (#55767)
    
    (cherry picked from commit de0c78e782c0cd4a7c2f592d2a431d1dc8fd3cac)
---
 .../src/airflow/sdk/execution_time/supervisor.py   | 11 ++++
 .../task_sdk/execution_time/test_supervisor.py     | 70 ++++++++++++++++++++++
 2 files changed, 81 insertions(+)

diff --git a/task-sdk/src/airflow/sdk/execution_time/supervisor.py 
b/task-sdk/src/airflow/sdk/execution_time/supervisor.py
index fafd56209bd..70e293cbbe4 100644
--- a/task-sdk/src/airflow/sdk/execution_time/supervisor.py
+++ b/task-sdk/src/airflow/sdk/execution_time/supervisor.py
@@ -921,6 +921,9 @@ class ActivitySubprocess(WatchedSubprocess):
     _last_successful_heartbeat: float = attrs.field(default=0, init=False)
     _last_heartbeat_attempt: float = attrs.field(default=0, init=False)
 
+    _should_retry: bool = attrs.field(default=False, init=False)
+    """Whether the task should retry or not as decided by the API server."""
+
     # After the failure of a heartbeat, we'll increment this counter. If it 
reaches `MAX_FAILED_HEARTBEATS`, we
     # will kill the process. This is to handle temporary network issues etc. 
ensuring that the process
     # does not hang around forever.
@@ -960,6 +963,7 @@ class ActivitySubprocess(WatchedSubprocess):
             # message. But before we do that, we need to tell the server it's 
started (so it has the chance to
             # tell us "no, stop!" for any reason)
             ti_context = self.client.task_instances.start(ti.id, self.pid, 
start_date)
+            self._should_retry = ti_context.should_retry
             self._last_successful_heartbeat = time.monotonic()
         except Exception:
             # On any error kill that subprocess!
@@ -1158,6 +1162,13 @@ class ActivitySubprocess(WatchedSubprocess):
             return self._terminal_state or TaskInstanceState.SUCCESS
         if self._exit_code != 0 and self._terminal_state == SERVER_TERMINATED:
             return SERVER_TERMINATED
+
+        # Any negative exit code indicates a signal kill
+        # We consider all signal kills as potentially retryable
+        # since they're often transient issues that could succeed on retry
+        if self._exit_code < 0 and self._should_retry:
+            return TaskInstanceState.UP_FOR_RETRY
+
         return TaskInstanceState.FAILED
 
     def _handle_request(self, msg: ToSupervisor, log: FilteringBoundLogger, 
req_id: int):
diff --git a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py 
b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py
index 6001edf737f..10b945d7fd5 100644
--- a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py
+++ b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py
@@ -2430,6 +2430,76 @@ def test_remote_logging_conn(remote_logging, 
remote_conn, expected_env, monkeypa
             assert connection_available["conn_uri"] is not None, "Connection 
URI was None during upload"
 
 
+class TestSignalRetryLogic:
+    """Test signal based retry logic in ActivitySubprocess."""
+
+    @pytest.mark.parametrize(
+        "signal",
+        [
+            signal.SIGTERM,
+            signal.SIGKILL,
+            signal.SIGABRT,
+            signal.SIGSEGV,
+        ],
+    )
+    def test_signals_with_retry(self, mocker, signal):
+        """Test that signal-killed tasks go to up_for_retry when task retries are enabled."""
+        mock_watched_subprocess = ActivitySubprocess(
+            process_log=mocker.MagicMock(),
+            id=TI_ID,
+            pid=12345,
+            stdin=mocker.Mock(),
+            process=mocker.Mock(),
+            client=mocker.Mock(),
+        )
+
+        mock_watched_subprocess._exit_code = -signal
+        mock_watched_subprocess._should_retry = True
+
+        result = mock_watched_subprocess.final_state
+        assert result == TaskInstanceState.UP_FOR_RETRY
+
+    @pytest.mark.parametrize(
+        "signal",
+        [
+            signal.SIGKILL,
+            signal.SIGTERM,
+            signal.SIGABRT,
+            signal.SIGSEGV,
+        ],
+    )
+    def test_signals_without_retry_always_fail(self, mocker, signal):
+        """Test that signals without task retries enabled always fail."""
+        mock_watched_subprocess = ActivitySubprocess(
+            process_log=mocker.MagicMock(),
+            id=TI_ID,
+            pid=12345,
+            stdin=mocker.Mock(),
+            process=mocker.Mock(),
+            client=mocker.Mock(),
+        )
+        mock_watched_subprocess._should_retry = False
+        mock_watched_subprocess._exit_code = -signal
+
+        result = mock_watched_subprocess.final_state
+        assert result == TaskInstanceState.FAILED
+
+    def test_non_signal_exit_code_goes_to_failed(self, mocker):
+        """Test that non signal exit codes go to failed regardless of task 
retries."""
+        mock_watched_subprocess = ActivitySubprocess(
+            process_log=mocker.MagicMock(),
+            id=TI_ID,
+            pid=12345,
+            stdin=mocker.Mock(),
+            process=mocker.Mock(),
+            client=mocker.Mock(),
+        )
+        mock_watched_subprocess._exit_code = 1
+        mock_watched_subprocess._should_retry = True
+
+        assert mock_watched_subprocess.final_state == TaskInstanceState.FAILED
+
+
 def test_remote_logging_conn_caches_connection_not_client(monkeypatch):
     """Test that connection caching doesn't retain API client references."""
     import gc

Reply via email to