kaxil commented on code in PR #44241:
URL: https://github.com/apache/airflow/pull/44241#discussion_r1856503655


##########
airflow/api_fastapi/execution_api/routes/task_instances.py:
##########
@@ -30,11 +31,13 @@
 from airflow.api_fastapi.common.db.common import get_session
 from airflow.api_fastapi.common.router import AirflowRouter
 from airflow.api_fastapi.execution_api.datamodels.taskinstance import (
+    TIDeferredStatePayload,
     TIEnterRunningPayload,
     TIHeartbeatInfo,
     TIStateUpdate,
     TITerminalStatePayload,
 )
+from airflow.models import Trigger

Review Comment:
   ```suggestion
   from airflow.models.trigger import Trigger
   ```
   
   for consistency with L41



##########
task_sdk/src/airflow/sdk/execution_time/supervisor.py:
##########
@@ -504,6 +516,12 @@ def handle_requests(self, log: FilteringBoundLogger) -> Generator[None, bytes, N
             elif isinstance(msg, GetVariable):
                 var = self.client.variables.get(msg.key)
                 resp = var.model_dump_json(exclude_unset=True).encode()
+            elif isinstance(msg, PatchTIToDeferred):
+                self.final_state = IntermediateTIState.DEFERRED
+                self.client.task_instances.defer(self.ti_id, msg.model_dump_json())
+                # hmmm, can we do better here
+                # setting to "\n" as we do not have a response to return

Review Comment:
   Two options:
   
   1) Move `self.stdin.write(resp + b"\n")` into every branch where we are sending a msg
   2) Set `resp` to `None` and change L529 to:
   
   ```python
   if resp:
       self.stdin.write(resp + b"\n")
   ```
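
A minimal, self-contained sketch of option 2, assuming a simplified dispatcher. The names `handle` and `pump` and the string messages are hypothetical stand-ins for illustration, not the supervisor's real API:

```python
from __future__ import annotations

import io


def handle(msg: str) -> bytes | None:
    """Return the encoded reply for a message, or None when there is nothing to send."""
    if msg == "get_variable":
        return b'{"key": "k", "value": "v"}'
    if msg == "defer":
        # Assumption: deferring produces no reply for the task process.
        return None
    raise ValueError(f"unknown message: {msg}")


def pump(messages: list[str], stdin: io.BytesIO) -> None:
    """Dispatch each message and write a reply line only when one exists."""
    for msg in messages:
        resp = handle(msg)
        if resp:
            stdin.write(resp + b"\n")


buf = io.BytesIO()
pump(["get_variable", "defer"], buf)
```

With this shape the write stays in one place, and branches without a reply simply leave `resp` as `None`.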



##########
airflow/api_fastapi/execution_api/routes/task_instances.py:
##########
@@ -122,6 +124,53 @@ def ti_update_state(
         )
     elif isinstance(ti_patch_payload, TITerminalStatePayload):
         query = TI.duration_expression_update(ti_patch_payload.end_date, query, session.bind)
+    elif isinstance(ti_patch_payload, TIDeferredStatePayload):
+        trigger_row = Trigger(
+            classpath=ti_patch_payload.classpath,
+            kwargs=ti_patch_payload.kwargs,
+            created_date=ti_patch_payload.created_date,
+        )
+        session.add(trigger_row)
+        session.flush()

Review Comment:
   Yeah, but why do you need `session.flush`? The session will already be committed and handled via `get_session`:
   
   
https://github.com/apache/airflow/blob/4d3140dba5976acb38c3dffd4fd963514f8da634/airflow/api_fastapi/execution_api/routes/task_instances.py#L67



##########
airflow/api_fastapi/execution_api/routes/task_instances.py:
##########
@@ -122,6 +125,44 @@ def ti_update_state(
         )
     elif isinstance(ti_patch_payload, TITerminalStatePayload):
         query = TI.duration_expression_update(ti_patch_payload.end_date, query, session.bind)
+    elif isinstance(ti_patch_payload, TIDeferredStatePayload):
+        trigger_row = Trigger(
+            classpath=ti_patch_payload.classpath,
+            kwargs=ti_patch_payload.kwargs,
+            created_date=ti_patch_payload.created_date,
+        )
+        session.add(trigger_row)
+        session.flush()
+
+        ti = session.query(TI).filter(TI.id == ti_id_str).one_or_none()

Review Comment:
   We already have this in the `old` variable:
   
   
https://github.com/apache/airflow/blob/4d3140dba5976acb38c3dffd4fd963514f8da634/airflow/api_fastapi/execution_api/routes/task_instances.py#L78



##########
airflow/api_fastapi/execution_api/routes/task_instances.py:
##########
@@ -122,6 +125,44 @@ def ti_update_state(
         )
     elif isinstance(ti_patch_payload, TITerminalStatePayload):
         query = TI.duration_expression_update(ti_patch_payload.end_date, query, session.bind)
+    elif isinstance(ti_patch_payload, TIDeferredStatePayload):
+        trigger_row = Trigger(
+            classpath=ti_patch_payload.classpath,
+            kwargs=ti_patch_payload.kwargs,
+            created_date=ti_patch_payload.created_date,
+        )
+        session.add(trigger_row)
+        session.flush()
+
+        ti = session.query(TI).filter(TI.id == ti_id_str).one_or_none()
+
+        # Calculate timeout too if it was passed
+        trigger_timeout: datetime | None = None
+        if ti_patch_payload.timeout is not None:
+            trigger_timeout = timezone.utcnow() + ti_patch_payload.timeout
+        else:
+            trigger_timeout = None
+
+        # If an execution_timeout is set, set the timeout to the minimum of
+        # it and the trigger timeout
+        if ti.task:
+            execution_timeout = ti.task.execution_timeout
+            if execution_timeout:
+                if TYPE_CHECKING:
+                    assert ti.start_date
+                if ti.trigger_timeout:
+                    trigger_timeout = min(ti.start_date + execution_timeout, ti.trigger_timeout)
+                else:
+                    trigger_timeout = ti.start_date + execution_timeout
+
+        query = update(TI).where(TI.id == ti_id_str)
+        query = query.values(
+            state=State.DEFERRED,
+            trigger_id=trigger_row.id,
+            next_method=ti_patch_payload.next_method,
+            next_kwargs=ti_patch_payload.kwargs or {},

Review Comment:
   Do we need the `or` part here? When would it be `None`?
   
   
   I think Pydantic would validate that here:
   
   
https://github.com/apache/airflow/blob/4d3140dba5976acb38c3dffd4fd963514f8da634/airflow/api_fastapi/execution_api/datamodels/taskinstance.py#L81



##########
task_sdk/src/airflow/sdk/execution_time/comms.py:
##########
@@ -103,6 +104,12 @@ class TaskState(BaseModel):
     type: Literal["TaskState"] = "TaskState"
 
 
+class PatchTIToDeferred(TIDeferredStatePayload):

Review Comment:
   ```suggestion
   class DeferTask(TIDeferredStatePayload):
   ```
   
   might be a better name 🤷 



##########
task_sdk/src/airflow/sdk/execution_time/supervisor.py:
##########
@@ -463,10 +467,18 @@ def final_state(self):
 
         Not valid before the process has finished.
         """
+        if self._final_state:
+            return self._final_state
         if self._exit_code == 0:
             return self._terminal_state or TerminalTIState.SUCCESS
         return TerminalTIState.FAILED
 
+    @final_state.setter
+    def final_state(self, value):
+        """Setter for final_state for certain task instance stated present in IntermediateTIState."""
+        if value not in TerminalTIState:
+            self._final_state = value

Review Comment:
   Why do we need a `setter`? Shouldn't we set the TI state to failed if `exit_code != 0`, like we already do here:
   
   
https://github.com/apache/airflow/blob/4d3140dba5976acb38c3dffd4fd963514f8da634/task_sdk/src/airflow/sdk/execution_time/supervisor.py#L472-L474
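
One possible shape for this, sketched under assumptions: the class `Supervised` and the attribute `_requested_state` are hypothetical, and the enums are trimmed stand-ins for the real ones. The point is only that a non-zero exit code wins over any state recorded earlier, so no public setter is required:

```python
from __future__ import annotations

from enum import Enum


class TerminalTIState(str, Enum):
    SUCCESS = "success"
    FAILED = "failed"


class IntermediateTIState(str, Enum):
    DEFERRED = "deferred"


class Supervised:
    """Hypothetical stand-in for the supervised process wrapper."""

    def __init__(self) -> None:
        self._exit_code: int | None = None
        # State requested by the task while it ran (e.g. DEFERRED), if any.
        self._requested_state: TerminalTIState | IntermediateTIState | None = None

    @property
    def final_state(self) -> TerminalTIState | IntermediateTIState:
        # A non-zero (or unknown) exit code always means failure,
        # whatever the task requested before it exited.
        if self._exit_code != 0:
            return TerminalTIState.FAILED
        return self._requested_state or TerminalTIState.SUCCESS
```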



##########
task_sdk/src/airflow/sdk/execution_time/task_runner.py:
##########
@@ -159,8 +160,23 @@ def run(ti: RuntimeTaskInstance, log: Logger):
         # TODO next_method to support resuming from deferred
         # TODO: Get a real context object
         ti.task.execute({"task_instance": ti})  # type: ignore[attr-defined]
-    except TaskDeferred:
-        ...
+    except TaskDeferred as defer:
+        trigger = Trigger.from_object(defer.trigger)

Review Comment:
   Why use the `Trigger` model here?
   
   i.e. why do we need L164 to L169?



##########
task_sdk/src/airflow/sdk/execution_time/supervisor.py:
##########
@@ -389,9 +392,10 @@ def wait(self) -> int:
         # If it hasn't, assume it's failed
         self._exit_code = self._exit_code if self._exit_code is not None else 1
 
-        self.client.task_instances.finish(
-            id=self.ti_id, state=self.final_state, when=datetime.now(tz=timezone.utc)
-        )
+        if self.final_state in TerminalTIState:

Review Comment:
   Not 100% certain we need this check. I am thinking that, by definition, `self.final_state` has to be in `TerminalTIState`; if it is not, that is an error!



##########
airflow/api_fastapi/execution_api/datamodels/taskinstance.py:
##########
@@ -61,6 +62,28 @@ class TITargetStatePayload(BaseModel):
     state: IntermediateTIState
 
 
+class TIDeferredStatePayload(BaseModel):
+    """Schema for updating TaskInstance to a deferred state."""
+
+    state: Annotated[
+        Literal[IntermediateTIState.DEFERRED],
+        # Specify a default in the schema, but not in code, so Pydantic marks it as required.
+        WithJsonSchema(
+            {
+                "type": "string",
+                "enum": [IntermediateTIState.DEFERRED],
+                "default": IntermediateTIState.DEFERRED,
+            }
+        ),
+    ]
+
+    classpath: str
+    kwargs: dict[str, Any]
+    created_date: UtcDateTime

Review Comment:
   > If we do not allow to override it,
   
   I can't think of scenarios where we would want to override the `created_date` from the worker. So it is better to remove it from here and handle it in the API logic, i.e. in `ti_update_state`.
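
A rough sketch of what that could look like, using plain dataclasses instead of the real Pydantic and SQLAlchemy models. `DeferPayload` and `build_trigger_row` are hypothetical names for illustration only:

```python
from __future__ import annotations

from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import Any


@dataclass
class DeferPayload:
    """Hypothetical worker payload: note there is no created_date field."""

    classpath: str
    kwargs: dict[str, Any] = field(default_factory=dict)


def build_trigger_row(payload: DeferPayload) -> dict[str, Any]:
    """Build the trigger row on the API side, stamping the creation time there."""
    return {
        "classpath": payload.classpath,
        "kwargs": payload.kwargs,
        # Set by the server, so a worker cannot override it.
        "created_date": datetime.now(timezone.utc),
    }


row = build_trigger_row(DeferPayload(classpath="my.module.MyTrigger"))
```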



##########
airflow/api_fastapi/execution_api/routes/task_instances.py:
##########
@@ -122,6 +124,53 @@ def ti_update_state(
         )
     elif isinstance(ti_patch_payload, TITerminalStatePayload):
         query = TI.duration_expression_update(ti_patch_payload.end_date, query, session.bind)
+    elif isinstance(ti_patch_payload, TIDeferredStatePayload):
+        trigger_row = Trigger(
+            classpath=ti_patch_payload.classpath,
+            kwargs=ti_patch_payload.kwargs,
+            created_date=ti_patch_payload.created_date,
+        )
+        session.add(trigger_row)
+        session.flush()
+
+        ti = session.query(TI).filter(TI.id == ti_id_str).one_or_none()
+
+        if not ti:
+            raise HTTPException(
+                status_code=status.HTTP_400_BAD_REQUEST,
+                detail={
+                    "message": f"TaskInstance with id {ti_id_str} not found.",
+                },
+            )
+
+        ti.state = TaskInstanceState.DEFERRED
+        ti.trigger_id = trigger_row.id
+        ti.next_method = ti_patch_payload.next_method
+        ti.next_kwargs = ti_patch_payload.kwargs or {}
+        timeout = ti_patch_payload.timeout
+
+        # Calculate timeout too if it was passed
+        if timeout is not None:
+            ti.trigger_timeout = timezone.utcnow() + timeout
+        else:
+            ti.trigger_timeout = None
+
+        # If an execution_timeout is set, set the timeout to the minimum of
+        # it and the trigger timeout
+        if ti.task:

Review Comment:
   The TI is fetched from the DB, so `task` would be `None`. For now, add a "TODO" to handle it later. We can get it from the serialized DAG or via the API itself, but we can figure that out later.
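
For whenever that TODO is picked up, the timeout rule described in the diff comments can be sketched as a pure function. `resolve_trigger_timeout` is a hypothetical helper, and it assumes `execution_timeout` has already been obtained from somewhere (for example the serialized DAG):

```python
from __future__ import annotations

from datetime import datetime, timedelta, timezone


def resolve_trigger_timeout(
    now: datetime,
    start_date: datetime,
    trigger_timeout: timedelta | None,
    execution_timeout: timedelta | None,
) -> datetime | None:
    """Return the earlier of the trigger deadline and the execution deadline."""
    # Deadline requested by the deferral itself, counted from now.
    deadline = now + trigger_timeout if trigger_timeout is not None else None
    if execution_timeout is not None:
        # Deadline implied by execution_timeout, counted from the TI start.
        execution_deadline = start_date + execution_timeout
        deadline = execution_deadline if deadline is None else min(deadline, execution_deadline)
    return deadline


start = datetime(2024, 1, 1, tzinfo=timezone.utc)
now = start + timedelta(minutes=10)
both = resolve_trigger_timeout(now, start, timedelta(hours=1), timedelta(minutes=30))
```

Here the execution deadline (30 minutes after `start`) is earlier than the trigger deadline (70 minutes after `start`), so it is the one returned.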



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.