kaxil commented on code in PR #66141:
URL: https://github.com/apache/airflow/pull/66141#discussion_r3307494951


##########
task-sdk/src/airflow/sdk/execution_time/callback_supervisor.py:
##########
@@ -356,6 +357,15 @@ def supervise_callback(
         logger = structlog.get_logger(logger_name="callback").bind()
 
     with _ensure_client(server, token, client=client) as client:
+        # Mark the callback as RUNNING via the API. This is the single endpoint
+        # that accepts a workload-scoped token; it returns a fresh execution
+        # token via the Refreshed-API-Token header which the Client's response
+        # hook adopts automatically for the rest of the run.
+        client.callbacks.start(id)

Review Comment:
   `start()` sits outside the `try/finally` below, so if it raises (network 
blip, 5xx from the API, JWT validator hiccup), the inner `try:` never executes 
and `finish()` is never called. The scheduler fallback in 
`_process_executor_events` now covers this (QUEUED to terminal on executor 
event), so the row doesn't actually get stuck. But the supervisor's own 
contract becomes asymmetric and reads as a hidden dependency on the scheduler 
safety net.
   
   Two options worth considering:
   
   1. Move `start()` inside the `try`. If it raises, `finally` calls 
`finish(state=FAILED)` against a row still in QUEUED, the API returns 409, and 
the existing `try/except` around `finish()` swallows it. Cost: an extra log 
line on the failure path.
   2. Keep `start()` where it is, but add a one-line comment noting that the 
scheduler executor-event path is what catches the stuck-QUEUED case.
   
   Either is fine. Just flagging the asymmetry so it's a conscious choice 
rather than an accidental one.
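A minimal sketch of option 1, assuming the client surfaces the 409 from `finish()` as an exception. `FakeCallbacks`, `ConflictError`, and `supervise` are hypothetical stand-ins for illustration, not the real `Client` or `supervise_callback`:

```python
from __future__ import annotations

import logging

log = logging.getLogger("callback")


class ConflictError(Exception):
    """Hypothetical stand-in for a 409 response from the API."""


class FakeCallbacks:
    """Hypothetical client: only RUNNING rows may be moved to a terminal state."""

    def __init__(self, fail_start: bool) -> None:
        self.fail_start = fail_start
        self.state = "queued"
        self.calls: list[str] = []

    def start(self, cb_id: str) -> None:
        self.calls.append("start")
        if self.fail_start:
            raise ConnectionError("API unreachable")
        self.state = "running"

    def finish(self, cb_id: str, state: str, output: str | None = None) -> None:
        self.calls.append(f"finish:{state}")
        if self.state != "running":
            # A row still in QUEUED is rejected, like the 409 described above.
            raise ConflictError(f"409: callback {cb_id} is {self.state}")
        self.state = state


def supervise(callbacks: FakeCallbacks, cb_id: str) -> int:
    state, output = "failed", None
    try:
        # Option 1: start() lives inside the try, so the finally below always runs.
        callbacks.start(cb_id)
        # ... the callback subprocess would run here ...
        state = "success"
        return 0
    except Exception as exc:
        output = f"{type(exc).__name__}: {exc}"
        raise
    finally:
        try:
            callbacks.finish(cb_id, state=state, output=output)
        except ConflictError:
            # The guard around finish() swallows the 409; the cost is this log line.
            log.warning("finish() rejected for callback %s", cb_id)
```

With `fail_start=True`, the `ConnectionError` still propagates, `calls` ends as `["start", "finish:failed"]`, and the row stays `"queued"` for the scheduler's executor-event path to resolve.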
   



##########
task-sdk/tests/task_sdk/execution_time/test_callback_supervisor.py:
##########
@@ -239,3 +245,106 @@ def test_handle_requests(
 
         if client_mock:
             mock_client_method.assert_called_once_with(*client_mock.args, **client_mock.kwargs)
+
+
+class TestSuperviseCallback:
+    """supervise_callback drives every callback state transition through the API."""
+
+    def _make_mock_client(self, mocker):
+        client = mocker.Mock(spec=Client)
+        client.callbacks = mocker.Mock(spec=CallbackOperations)
+        return client
+
+    def test_start_called_before_subprocess_then_finish_success(self, mocker):
+        cb_id = str(uuid4())
+        client = self._make_mock_client(mocker)
+
+        order: list[str] = []
+        client.callbacks.start.side_effect = lambda _id: order.append("start_api")
+        client.callbacks.finish.side_effect = lambda _id, state, output=None: order.append(f"finish:{state}")
+
+        proc = mocker.Mock()
+        proc.wait.return_value = 0
+
+        def _subprocess_start(**_):
+            order.append("subprocess_start")
+            return proc
+
+        mocker.patch.object(CallbackSubprocess, "start", side_effect=_subprocess_start)
+
+        exit_code = supervise_callback(
+            id=cb_id,
+            callback_path="tests.fake.callback",
+            callback_kwargs={},
+            token="workload-token",
+            client=client,
+        )
+
+        assert exit_code == 0
+        client.callbacks.start.assert_called_once_with(cb_id)
+        client.callbacks.finish.assert_called_once_with(cb_id, state="success", output=None)
+        assert order == ["start_api", "subprocess_start", "finish:success"]
+
+    def test_finish_called_with_failed_when_subprocess_exits_nonzero(self, mocker):
+        cb_id = str(uuid4())
+        client = self._make_mock_client(mocker)
+
+        proc = mocker.Mock()
+        proc.wait.return_value = 1
+        mocker.patch.object(CallbackSubprocess, "start", return_value=proc)
+
+        with pytest.raises(RuntimeError, match="exited with code 1"):
+            supervise_callback(
+                id=cb_id,
+                callback_path="tests.fake.callback",
+                callback_kwargs={},
+                token="workload-token",
+                client=client,
+            )
+
+        client.callbacks.start.assert_called_once_with(cb_id)
+        client.callbacks.finish.assert_called_once()
+        kwargs = client.callbacks.finish.call_args.kwargs
+        assert kwargs["state"] == "failed"
+        assert "exited with code 1" in kwargs["output"]
+
+    def test_finish_called_with_failed_when_subprocess_raises(self, mocker):
+        cb_id = str(uuid4())
+        client = self._make_mock_client(mocker)
+
+        mocker.patch.object(CallbackSubprocess, "start", side_effect=RuntimeError("boom"))
+
+        with pytest.raises(RuntimeError, match="boom"):
+            supervise_callback(
+                id=cb_id,
+                callback_path="tests.fake.callback",
+                callback_kwargs={},
+                token="workload-token",
+                client=client,
+            )
+
+        client.callbacks.finish.assert_called_once()
+        kwargs = client.callbacks.finish.call_args.kwargs
+        assert kwargs["state"] == "failed"
+        assert "RuntimeError" in kwargs["output"]

Review Comment:
   Test gap that pairs with the supervisor-side finding on 
`callback_supervisor.py:364`. The suite covers `CallbackSubprocess.start` 
raising and `client.callbacks.finish` raising, but not `client.callbacks.start` 
raising. A regression test would lock in whichever behaviour you pick for the 
start-raises path:
   
   ```python
   def test_callbacks_start_failure_propagates_and_does_not_call_finish(self, mocker):
       cb_id = str(uuid4())
       client = self._make_mock_client(mocker)
       client.callbacks.start.side_effect = RuntimeError("API unreachable")
   
       with pytest.raises(RuntimeError, match="API unreachable"):
           supervise_callback(
               id=cb_id,
               callback_path="tests.fake.callback",
               callback_kwargs={},
               token="workload-token",
               client=client,
           )
       client.callbacks.finish.assert_not_called()
   ```
   
   If you go with option 1 in the other thread (move `start()` inside the 
`try`), this test would instead assert `finish` is called with `state=FAILED` 
and the 409 is swallowed.
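Under option 1, the regression test might look roughly like the sketch below. `supervise_sketch` is a hypothetical stand-in that mirrors the option-1 shape (the real test would call `supervise_callback` with `self._make_mock_client(mocker)` and use `pytest.raises`), and `ConflictError` is a hypothetical placeholder for however the client surfaces the 409:

```python
from __future__ import annotations

import logging
from unittest import mock

log = logging.getLogger("callback")


class ConflictError(Exception):
    """Hypothetical placeholder for the 409 raised by callbacks.finish()."""


def supervise_sketch(client: mock.Mock, cb_id: str) -> int:
    """Hypothetical stand-in for supervise_callback with start() moved inside the try."""
    state, output = "failed", None
    try:
        client.callbacks.start(cb_id)
        state = "success"
        return 0
    except Exception as exc:
        output = f"{type(exc).__name__}: {exc}"
        raise
    finally:
        try:
            client.callbacks.finish(cb_id, state=state, output=output)
        except ConflictError:
            log.warning("finish() rejected for callback %s", cb_id)


def test_callbacks_start_failure_propagates_and_calls_finish_failed() -> None:
    client = mock.Mock()
    client.callbacks.start.side_effect = RuntimeError("API unreachable")
    # The row never left QUEUED, so the API rejects the terminal transition.
    client.callbacks.finish.side_effect = ConflictError("409")

    try:
        supervise_sketch(client, "cb-1")
    except RuntimeError as exc:
        assert "API unreachable" in str(exc)
    else:
        raise AssertionError("expected the start() failure to propagate")

    # finish() was attempted with FAILED, and its 409 did not mask the original error.
    client.callbacks.finish.assert_called_once()
    kwargs = client.callbacks.finish.call_args.kwargs
    assert kwargs["state"] == "failed"
    assert "RuntimeError" in kwargs["output"]
```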
   



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
