jason810496 commented on code in PR #57198:
URL: https://github.com/apache/airflow/pull/57198#discussion_r2462451998


##########
task-sdk-tests/tests/task_sdk_tests/test_task_instance_operations.py:
##########
@@ -98,24 +100,91 @@ def test_ti_get_task_states(sdk_client, dag_info):
     console.print("[green]✅ Task states test passed!")
 
 
-def test_ti_finish_failed(sdk_client, task_instance_id):
+def test_ti_set_rtif(sdk_client, task_instance_id):
+    """
+    Test setting Rendered Task Instance Fields (RTIF).
+    """
+    console.print("[yellow]Setting Rendered Task Instance Fields...")
+
+    rtif_data = {
+        "rendered_field_1": "test_value_1",
+        "rendered_field_2": "1234",
+    }
+
+    response = sdk_client.task_instances.set_rtif(task_instance_id, rtif_data)
+
+    console.print(" RTIF Response ".center(72, "="))
+    console.print(f"[bright_blue]Response Type:[/] {type(response).__name__}")
+    console.print(f"[bright_blue]Status:[/] {response.ok}")
+    console.print(f"[bright_blue]Task Instance ID:[/] {task_instance_id}")
+    console.print(f"[bright_blue]Fields Set:[/] {list(rtif_data.keys())}")
+    console.print("=" * 72)
+
+    assert response.ok is True
+    console.print("[green]✅ RTIF test passed!")
+
+
+def test_ti_heartbeat(sdk_client, task_instance_id, core_api_headers, 
dag_info, monkeypatch):
     """
-    Test finishing a task instance with failed state.
+    Test sending heartbeat for a running task instance.
 
-    This is the LAST test and will terminate the long-running task.
-    It must run after all other tests that need the task to be running.
+    This test fetches the actual worker's PID and hostname from core API,
+    then patches get_hostname() to return the worker's hostname, allowing
+    the heartbeat to be accepted by the server.
     """
-    console.print("[yellow]Finishing task instance as FAILED...")
+    console.print("[yellow]Getting task instance details for heartbeat...")
 
-    # Finish the task with failed state
+    ti_url = (
+        f"http://localhost:8080/api/v2/dags/{dag_info['dag_id']}/"
+        
f"dagRuns/{dag_info['dag_run_id']}/taskInstances/long_running_task/tries/1"
+    )
+    ti_response = requests.get(ti_url, headers=core_api_headers, timeout=10)
+    ti_response.raise_for_status()
+
+    ti_data = ti_response.json()
+    worker_hostname = ti_data.get("hostname")
+    worker_pid = ti_data.get("pid")
+
+    console.print(" Worker Information ".center(72, "="))
+    console.print(f"[bright_blue]Worker Hostname:[/] {worker_hostname}")
+    console.print(f"[bright_blue]Worker PID:[/] {worker_pid}")
+    console.print("=" * 72)
+
+    assert worker_hostname is not None
+    assert worker_pid is not None
+
+    # Patch get_hostname to return the worker's hostname
+    from airflow.sdk.api import client as sdk_client_module
+
+    monkeypatch.setattr(sdk_client_module, "get_hostname", lambda: 
worker_hostname)
+
+    console.print("[yellow]Sending heartbeat with worker's PID/hostname...")
+
+    sdk_client.task_instances.heartbeat(task_instance_id, pid=worker_pid)
+

Review Comment:
   Do we need to validate the TI's `last_heartbeat_at` after calling `sdk_client.task_instances.heartbeat`?
   However, the `last_heartbeat_at` field does not seem to be included in the data model returned by the `get_task_instance_try_details` route.



##########
task-sdk-tests/tests/task_sdk_tests/test_task_instance_operations.py:
##########
@@ -98,24 +100,91 @@ def test_ti_get_task_states(sdk_client, dag_info):
     console.print("[green]✅ Task states test passed!")
 
 
-def test_ti_finish_failed(sdk_client, task_instance_id):
+def test_ti_set_rtif(sdk_client, task_instance_id):
+    """
+    Test setting Rendered Task Instance Fields (RTIF).
+    """
+    console.print("[yellow]Setting Rendered Task Instance Fields...")
+
+    rtif_data = {
+        "rendered_field_1": "test_value_1",
+        "rendered_field_2": "1234",
+    }
+
+    response = sdk_client.task_instances.set_rtif(task_instance_id, rtif_data)
+
+    console.print(" RTIF Response ".center(72, "="))
+    console.print(f"[bright_blue]Response Type:[/] {type(response).__name__}")
+    console.print(f"[bright_blue]Status:[/] {response.ok}")
+    console.print(f"[bright_blue]Task Instance ID:[/] {task_instance_id}")
+    console.print(f"[bright_blue]Fields Set:[/] {list(rtif_data.keys())}")
+    console.print("=" * 72)
+
+    assert response.ok is True
+    console.print("[green]✅ RTIF test passed!")
+
+
+def test_ti_heartbeat(sdk_client, task_instance_id, core_api_headers, 
dag_info, monkeypatch):
     """
-    Test finishing a task instance with failed state.
+    Test sending heartbeat for a running task instance.
 
-    This is the LAST test and will terminate the long-running task.
-    It must run after all other tests that need the task to be running.
+    This test fetches the actual worker's PID and hostname from core API,
+    then patches get_hostname() to return the worker's hostname, allowing
+    the heartbeat to be accepted by the server.
     """
-    console.print("[yellow]Finishing task instance as FAILED...")
+    console.print("[yellow]Getting task instance details for heartbeat...")
 
-    # Finish the task with failed state
+    ti_url = (
+        f"http://localhost:8080/api/v2/dags/{dag_info['dag_id']}/"
+        
f"dagRuns/{dag_info['dag_run_id']}/taskInstances/long_running_task/tries/1"
+    )
+    ti_response = requests.get(ti_url, headers=core_api_headers, timeout=10)
+    ti_response.raise_for_status()
+
+    ti_data = ti_response.json()
+    worker_hostname = ti_data.get("hostname")
+    worker_pid = ti_data.get("pid")
+
+    console.print(" Worker Information ".center(72, "="))
+    console.print(f"[bright_blue]Worker Hostname:[/] {worker_hostname}")
+    console.print(f"[bright_blue]Worker PID:[/] {worker_pid}")
+    console.print("=" * 72)
+
+    assert worker_hostname is not None
+    assert worker_pid is not None
+
+    # Patch get_hostname to return the worker's hostname
+    from airflow.sdk.api import client as sdk_client_module
+
+    monkeypatch.setattr(sdk_client_module, "get_hostname", lambda: 
worker_hostname)
+
+    console.print("[yellow]Sending heartbeat with worker's PID/hostname...")
+
+    sdk_client.task_instances.heartbeat(task_instance_id, pid=worker_pid)
+
+    console.print(" Heartbeat Response ".center(72, "="))
+    console.print("[bright_blue]Status:[/] Success (204 No Content)")
+    console.print(f"[bright_blue]Task Instance ID:[/] {task_instance_id}")
+    console.print(f"[bright_blue]Used PID:[/] {worker_pid}")
+    console.print(f"[bright_blue]Used Hostname:[/] {worker_hostname}")
+    console.print("=" * 72)
+
+    console.print("[green]✅ Heartbeat test passed!")
+
+
+def test_ti_state_transitions(sdk_client, task_instance_id):
+    """
+    Test task instance state transition to terminal state.
+    """
+    console.print("[yellow]Testing state transition: RUNNING → FAILED...")
     sdk_client.task_instances.finish(
         id=task_instance_id, state=TerminalStateNonSuccess.FAILED, 
when=utcnow(), rendered_map_index="-1"
     )
 

Review Comment:
   Similar question here: do we need to retrieve the TI again after calling `finish()` to validate that its state is now FAILED?



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to