This is an automated email from the ASF dual-hosted git repository.

ash pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new e9b59fc2592 Ensure that the Task SDK regularly sends heartbeats for running tasks (#44162)
e9b59fc2592 is described below

commit e9b59fc2592d9a70ec4bfbb5052b9a729205e07e
Author: Ash Berlin-Taylor <[email protected]>
AuthorDate: Mon Nov 18 22:13:36 2024 +0000

    Ensure that the Task SDK regularly sends heartbeats for running tasks (#44162)
    
    There are more nuances and edge cases to support, but this is the crux of the
    behaviour we want.
    
    This fixes the payload to be what the server expects, and fixes the URL suffix
    to match the latest changes too.
---
 task_sdk/src/airflow/sdk/api/client.py             | 10 +++---
 .../src/airflow/sdk/execution_time/supervisor.py   |  4 +--
 task_sdk/tests/execution_time/test_supervisor.py   | 38 ++++++++++++++++++++++
 3 files changed, 46 insertions(+), 6 deletions(-)

diff --git a/task_sdk/src/airflow/sdk/api/client.py 
b/task_sdk/src/airflow/sdk/api/client.py
index f52e6dde622..eaae71dae3f 100644
--- a/task_sdk/src/airflow/sdk/api/client.py
+++ b/task_sdk/src/airflow/sdk/api/client.py
@@ -32,6 +32,7 @@ from airflow.sdk.api.datamodels._generated import (
     ConnectionResponse,
     TerminalTIState,
     TIEnterRunningPayload,
+    TIHeartbeatInfo,
     TITerminalStatePayload,
     ValidationError as RemoteValidationError,
 )
@@ -109,16 +110,17 @@ class TaskInstanceOperations:
         """Tell the API server that this TI has started running."""
         body = TIEnterRunningPayload(pid=pid, hostname=get_hostname(), 
unixname=getuser(), start_date=when)
 
-        self.client.patch(f"task-instance/{id}/state", 
content=body.model_dump_json())
+        self.client.patch(f"task-instances/{id}/state", 
content=body.model_dump_json())
 
     def finish(self, id: uuid.UUID, state: TerminalTIState, when: datetime):
         """Tell the API server that this TI has reached a terminal state."""
         body = TITerminalStatePayload(end_date=when, 
state=TerminalTIState(state))
 
-        self.client.patch(f"task-instance/{id}/state", 
content=body.model_dump_json())
+        self.client.patch(f"task-instances/{id}/state", 
content=body.model_dump_json())
 
-    def heartbeat(self, id: uuid.UUID):
-        self.client.put(f"task-instance/{id}/heartbeat")
+    def heartbeat(self, id: uuid.UUID, pid: int):
+        body = TIHeartbeatInfo(pid=pid, hostname=get_hostname())
+        self.client.put(f"task-instances/{id}/heartbeat", 
content=body.model_dump_json())
 
 
 class ConnectionOperations:
diff --git a/task_sdk/src/airflow/sdk/execution_time/supervisor.py 
b/task_sdk/src/airflow/sdk/execution_time/supervisor.py
index 6ecd8ff5698..7faddebb25c 100644
--- a/task_sdk/src/airflow/sdk/execution_time/supervisor.py
+++ b/task_sdk/src/airflow/sdk/execution_time/supervisor.py
@@ -129,7 +129,7 @@ def _reopen_std_io_handles(child_stdin, child_stdout, 
child_stderr):
         sys.stderr = sys.__stderr__
 
     # Ensure that sys.stdout et al (and the underlying filehandles for C 
libraries etc) are connected to the
-    # pipes form the supervisor
+    # pipes from the supervisor
 
     for handle_name, sock, mode, close in (
         ("stdin", child_stdin, "r", True),
@@ -403,7 +403,7 @@ class WatchedSubprocess:
                     continue
 
                 try:
-                    self.client.task_instances.heartbeat(self.ti_id)
+                    self.client.task_instances.heartbeat(self.ti_id, 
pid=self._process.pid)
                     self._last_heartbeat = time.monotonic()
                 except Exception:
                     log.warning("Couldn't heartbeat", exc_info=True)
diff --git a/task_sdk/tests/execution_time/test_supervisor.py 
b/task_sdk/tests/execution_time/test_supervisor.py
index 428ade1c35a..f741d7c21b5 100644
--- a/task_sdk/tests/execution_time/test_supervisor.py
+++ b/task_sdk/tests/execution_time/test_supervisor.py
@@ -22,7 +22,10 @@ import logging
 import os
 import signal
 import sys
+from time import sleep
+from typing import TYPE_CHECKING
 from unittest.mock import MagicMock
+from uuid import UUID
 
 import pytest
 import structlog
@@ -33,6 +36,9 @@ from airflow.sdk.api.datamodels._generated import TaskInstance
 from airflow.sdk.execution_time.supervisor import WatchedSubprocess
 from airflow.utils import timezone as tz
 
+if TYPE_CHECKING:
+    import kgb
+
 
 def lineno():
     """Returns the current line number in our program."""
@@ -153,3 +159,35 @@ class TestWatchedSubprocess:
         rc = proc.wait()
 
         assert rc == -9
+
+    def test_regular_heartbeat(self, spy_agency: kgb.SpyAgency, monkeypatch):
+        """Test that the WatchedSubprocess class regularly sends heartbeat 
requests, up to a certain frequency"""
+        import airflow.sdk.execution_time.supervisor
+
+        monkeypatch.setattr(airflow.sdk.execution_time.supervisor, 
"FASTEST_HEARTBEAT_INTERVAL", 0.1)
+
+        def subprocess_main():
+            sys.stdin.readline()
+
+            for _ in range(5):
+                print("output", flush=True)
+                sleep(0.05)
+
+        id = UUID("4d828a62-a417-4936-a7a6-2b3fabacecab")
+        spy = spy_agency.spy_on(sdk_client.TaskInstanceOperations.heartbeat)
+        proc = WatchedSubprocess.start(
+            path=os.devnull,
+            ti=TaskInstance(
+                id=id,
+                task_id="b",
+                dag_id="c",
+                run_id="d",
+                try_number=1,
+            ),
+            client=sdk_client.Client(base_url="", dry_run=True, token=""),
+            target=subprocess_main,
+        )
+        assert proc.wait() == 0
+        assert spy.called_with(id, pid=proc.pid)  # noqa: PGH005
+        # The exact number we get will depend on timing behaviour, so be a 
little lenient
+        assert 2 <= len(spy.calls) <= 4

Reply via email to