amoghrajesh commented on code in PR #44241:
URL: https://github.com/apache/airflow/pull/44241#discussion_r1853724490


##########
airflow/api_fastapi/execution_api/routes/task_instances.py:
##########
@@ -122,6 +124,53 @@ def ti_update_state(
         )
     elif isinstance(ti_patch_payload, TITerminalStatePayload):
         query = TI.duration_expression_update(ti_patch_payload.end_date, 
query, session.bind)
+    elif isinstance(ti_patch_payload, TIDeferredStatePayload):
+        trigger_row = Trigger(
+            classpath=ti_patch_payload.classpath,
+            kwargs=ti_patch_payload.kwargs,
+            created_date=ti_patch_payload.created_date,
+        )
+        session.add(trigger_row)
+        session.flush()
+
+        ti = session.query(TI).filter(TI.id == ti_id_str).one_or_none()
+
+        if not ti:
+            raise HTTPException(
+                status_code=status.HTTP_400_BAD_REQUEST,
+                detail={
+                    "message": f"TaskInstance with id {ti_id_str} not found.",
+                },
+            )
+
+        ti.state = TaskInstanceState.DEFERRED
+        ti.trigger_id = trigger_row.id
+        ti.next_method = ti_patch_payload.next_method
+        ti.next_kwargs = ti_patch_payload.kwargs or {}
+        timeout = ti_patch_payload.timeout
+
+        # Calculate timeout too if it was passed
+        if timeout is not None:
+            ti.trigger_timeout = timezone.utcnow() + timeout
+        else:
+            ti.trigger_timeout = None
+
+        # If an execution_timeout is set, set the timeout to the minimum of
+        # it and the trigger timeout
+        if ti.task:
+            execution_timeout = ti.task.execution_timeout
+            if execution_timeout:
+                if TYPE_CHECKING:
+                    assert ti.start_date
+                if ti.trigger_timeout:
+                    ti.trigger_timeout = min(ti.start_date + 
execution_timeout, ti.trigger_timeout)
+                else:
+                    ti.trigger_timeout = ti.start_date + execution_timeout
+
+        session.commit()
+
+        log.info("TI %s state updated to: deferred", ti_id_str)
+        return

Review Comment:
   It will return a 204 here



##########
airflow/api_fastapi/execution_api/routes/task_instances.py:
##########
@@ -122,6 +124,53 @@ def ti_update_state(
         )
     elif isinstance(ti_patch_payload, TITerminalStatePayload):
         query = TI.duration_expression_update(ti_patch_payload.end_date, 
query, session.bind)
+    elif isinstance(ti_patch_payload, TIDeferredStatePayload):
+        trigger_row = Trigger(
+            classpath=ti_patch_payload.classpath,
+            kwargs=ti_patch_payload.kwargs,
+            created_date=ti_patch_payload.created_date,
+        )
+        session.add(trigger_row)
+        session.flush()
+
+        ti = session.query(TI).filter(TI.id == ti_id_str).one_or_none()
+
+        if not ti:
+            raise HTTPException(
+                status_code=status.HTTP_400_BAD_REQUEST,
+                detail={
+                    "message": f"TaskInstance with id {ti_id_str} not found.",
+                },
+            )
+
+        ti.state = TaskInstanceState.DEFERRED
+        ti.trigger_id = trigger_row.id
+        ti.next_method = ti_patch_payload.next_method
+        ti.next_kwargs = ti_patch_payload.kwargs or {}
+        timeout = ti_patch_payload.timeout
+
+        # Calculate timeout too if it was passed
+        if timeout is not None:
+            ti.trigger_timeout = timezone.utcnow() + timeout

Review Comment:
   `timeout` is a `timedelta` here



##########
task_sdk/tests/execution_time/test_supervisor.py:
##########
@@ -297,7 +331,10 @@ def test_handle_requests(
         generator.send(msg)
 
         # Verify the correct client method was called
-        mock_client_method.assert_called_once_with(method_arg)
+        mock_client_method.assert_called_once_with(*method_arg)
 
+        if mock_response == "":
+            # for task instance endpoints, there won't be any response

Review Comment:
   We should handle this based on how we handle the client response for task 
instance endpoints



##########
task_sdk/src/airflow/sdk/execution_time/supervisor.py:
##########
@@ -504,6 +516,12 @@ def handle_requests(self, log: FilteringBoundLogger) -> 
Generator[None, bytes, N
             elif isinstance(msg, GetVariable):
                 var = self.client.variables.get(msg.key)
                 resp = var.model_dump_json(exclude_unset=True).encode()
+            elif isinstance(msg, PatchTIToDeferred):
+                self.final_state = IntermediateTIState.DEFERRED
+                self.client.task_instances.defer(self.ti_id, 
msg.model_dump_json())
+                # hmmm, can we do better here
+                # setting to "\n" as we do not have a response to return

Review Comment:
   This needs some thought, @ashb @kaxil



##########
task_sdk/tests/execution_time/test_supervisor.py:
##########
@@ -35,13 +35,21 @@
 from airflow.sdk.api import client as sdk_client
 from airflow.sdk.api.datamodels._generated import TaskInstance
 from airflow.sdk.api.datamodels.activities import ExecuteTaskActivity
-from airflow.sdk.execution_time.comms import ConnectionResult, GetConnection, 
GetVariable, VariableResult
+from airflow.sdk.execution_time.comms import (
+    ConnectionResult,
+    GetConnection,
+    GetVariable,
+    PatchTIToDeferred,
+    VariableResult,
+)
 from airflow.sdk.execution_time.supervisor import WatchedSubprocess, supervise
 from airflow.utils import timezone as tz
 
 if TYPE_CHECKING:
     import kgb
 
+TI_ID = uuid7()

Review Comment:
   This makes the ID predictable. It is needed for asserting calls to the task 
instance endpoints, since we assert the ID there too



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to