kaxil commented on code in PR #44241:
URL: https://github.com/apache/airflow/pull/44241#discussion_r1854113365


##########
airflow/api_fastapi/execution_api/datamodels/taskinstance.py:
##########
@@ -61,6 +62,28 @@ class TITargetStatePayload(BaseModel):
     state: IntermediateTIState
 
 
+class TIDeferredStatePayload(BaseModel):
+    """Schema for updating TaskInstance to a deferred state."""
+
+    state: Annotated[
+        Literal[IntermediateTIState.DEFERRED],
+        # Specify a default in the schema, but not in code, so Pydantic marks it as required.
+        WithJsonSchema(
+            {
+                "type": "string",
+                "enum": [IntermediateTIState.DEFERRED],
+                "default": IntermediateTIState.DEFERRED,
+            }
+        ),
+    ]
+
+    classpath: str
+    kwargs: dict[str, Any]
+    created_date: UtcDateTime

Review Comment:
   I don't think we need `created_date` here; we can handle it on the server side, in `update_state` itself.
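A minimal sketch of that idea, assuming the payload simply drops `created_date` and the server stamps it when it builds the trigger row. `DeferPayload`, `TriggerRow` and `build_trigger` are hypothetical stand-ins written with stdlib dataclasses, not Airflow's Pydantic model or ORM class:

```python
from __future__ import annotations

from dataclasses import dataclass
from datetime import datetime, timezone
from typing import Any


@dataclass
class DeferPayload:
    """Hypothetical stand-in for TIDeferredStatePayload, without created_date."""

    classpath: str
    kwargs: dict[str, Any]


@dataclass
class TriggerRow:
    """Hypothetical stand-in for the Trigger ORM row."""

    classpath: str
    kwargs: dict[str, Any]
    created_date: datetime


def build_trigger(payload: DeferPayload) -> TriggerRow:
    # The server decides the creation time, so a client clock can't skew it.
    return TriggerRow(
        classpath=payload.classpath,
        kwargs=payload.kwargs,
        created_date=datetime.now(timezone.utc),
    )
```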



##########
airflow/api_fastapi/execution_api/routes/task_instances.py:
##########
@@ -122,6 +124,53 @@ def ti_update_state(
         )
     elif isinstance(ti_patch_payload, TITerminalStatePayload):
         query = TI.duration_expression_update(ti_patch_payload.end_date, query, session.bind)
+    elif isinstance(ti_patch_payload, TIDeferredStatePayload):
+        trigger_row = Trigger(
+            classpath=ti_patch_payload.classpath,
+            kwargs=ti_patch_payload.kwargs,
+            created_date=ti_patch_payload.created_date,
+        )
+        session.add(trigger_row)
+        session.flush()

Review Comment:
   Shouldn't need it



##########
airflow/api_fastapi/execution_api/routes/task_instances.py:
##########
@@ -122,6 +124,53 @@ def ti_update_state(
         )
     elif isinstance(ti_patch_payload, TITerminalStatePayload):
         query = TI.duration_expression_update(ti_patch_payload.end_date, query, session.bind)
+    elif isinstance(ti_patch_payload, TIDeferredStatePayload):
+        trigger_row = Trigger(
+            classpath=ti_patch_payload.classpath,
+            kwargs=ti_patch_payload.kwargs,
+            created_date=ti_patch_payload.created_date,
+        )
+        session.add(trigger_row)
+        session.flush()
+
+        ti = session.query(TI).filter(TI.id == ti_id_str).one_or_none()
+
+        if not ti:
+            raise HTTPException(
+                status_code=status.HTTP_400_BAD_REQUEST,
+                detail={
+                    "message": f"TaskInstance with id {ti_id_str} not found.",
+                },
+            )
+
+        ti.state = TaskInstanceState.DEFERRED
+        ti.trigger_id = trigger_row.id
+        ti.next_method = ti_patch_payload.next_method
+        ti.next_kwargs = ti_patch_payload.kwargs or {}
+        timeout = ti_patch_payload.timeout
+
+        # Calculate timeout too if it was passed
+        if timeout is not None:
+            ti.trigger_timeout = timezone.utcnow() + timeout
+        else:
+            ti.trigger_timeout = None
+
+        # If an execution_timeout is set, set the timeout to the minimum of
+        # it and the trigger timeout
+        if ti.task:

Review Comment:
   I don't think `ti.task` would ever be available here!
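If `ti.task` is indeed unavailable on the API server, one option (an assumption, not something this PR does) is to compute the capped deadline where the task object does exist, for example on the worker, and send only the result. The hypothetical helper `effective_trigger_timeout` below reproduces the min-of-two-deadlines logic from the diff as a pure function:

```python
from __future__ import annotations

from datetime import datetime, timedelta


def effective_trigger_timeout(
    now: datetime,
    start_date: datetime | None,
    trigger_timeout: timedelta | None,
    execution_timeout: timedelta | None,
) -> datetime | None:
    """Earliest of (now + trigger_timeout) and (start_date + execution_timeout)."""
    # Deadline from the trigger's own timeout, if one was given.
    deadline = now + trigger_timeout if trigger_timeout is not None else None
    # Cap it by the task's execution_timeout, measured from when the TI started.
    if execution_timeout is not None and start_date is not None:
        exec_deadline = start_date + execution_timeout
        deadline = exec_deadline if deadline is None else min(deadline, exec_deadline)
    return deadline
```

Skipping the cap when `start_date` is `None` is a choice made for this sketch; the diff instead asserts `ti.start_date` under `TYPE_CHECKING`.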



##########
airflow/api_fastapi/execution_api/datamodels/taskinstance.py:
##########
@@ -78,6 +101,8 @@ def ti_state_discriminator(v: dict[str, str] | BaseModel) -> str:
         return str(state)
     elif state in set(TerminalTIState):
         return "_terminal_"
+    elif state == "deferred":

Review Comment:
   ```suggestion
       elif state == TIState.DEFERRED:
   ```
   to be consistent with L100
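A sketch of why the enum comparison is a drop-in replacement, assuming `TIState` is a `str`-based enum: comparing against the member matches both the raw string from JSON and the enum value on a parsed model. The member subset, the `TERMINAL` tuple and the returned tags below are simplified and hypothetical:

```python
from enum import Enum
from typing import Any


class TIState(str, Enum):
    """Hypothetical subset of the execution API's TI states."""

    RUNNING = "running"
    SUCCESS = "success"
    FAILED = "failed"
    DEFERRED = "deferred"


TERMINAL = (TIState.SUCCESS, TIState.FAILED)


def ti_state_discriminator(v: Any) -> str:
    # Raw JSON arrives as a dict of plain strings; parsed models carry enum values.
    state = v.get("state") if isinstance(v, dict) else getattr(v, "state", None)
    if state == TIState.RUNNING:
        return "running"
    elif state in TERMINAL:
        return "_terminal_"
    elif state == TIState.DEFERRED:  # matches "deferred" and TIState.DEFERRED alike
        return "deferred"
    return "_other_"
```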



##########
airflow/api_fastapi/execution_api/routes/task_instances.py:
##########
@@ -122,6 +124,53 @@ def ti_update_state(
         )
     elif isinstance(ti_patch_payload, TITerminalStatePayload):
         query = TI.duration_expression_update(ti_patch_payload.end_date, query, session.bind)
+    elif isinstance(ti_patch_payload, TIDeferredStatePayload):
+        trigger_row = Trigger(
+            classpath=ti_patch_payload.classpath,
+            kwargs=ti_patch_payload.kwargs,
+            created_date=ti_patch_payload.created_date,
+        )
+        session.add(trigger_row)
+        session.flush()
+
+        ti = session.query(TI).filter(TI.id == ti_id_str).one_or_none()
+
+        if not ti:
+            raise HTTPException(
+                status_code=status.HTTP_400_BAD_REQUEST,
+                detail={
+                    "message": f"TaskInstance with id {ti_id_str} not found.",
+                },
+            )
+
+        ti.state = TaskInstanceState.DEFERRED

Review Comment:
   ```suggestion
           ti.state = State.DEFERRED
   ```
   
   Or we change the other entries above to match.



##########
airflow/api_fastapi/execution_api/routes/task_instances.py:
##########
@@ -122,6 +124,53 @@ def ti_update_state(
         )
     elif isinstance(ti_patch_payload, TITerminalStatePayload):
         query = TI.duration_expression_update(ti_patch_payload.end_date, query, session.bind)
+    elif isinstance(ti_patch_payload, TIDeferredStatePayload):
+        trigger_row = Trigger(
+            classpath=ti_patch_payload.classpath,
+            kwargs=ti_patch_payload.kwargs,
+            created_date=ti_patch_payload.created_date,
+        )
+        session.add(trigger_row)
+        session.flush()
+
+        ti = session.query(TI).filter(TI.id == ti_id_str).one_or_none()
+
+        if not ti:
+            raise HTTPException(
+                status_code=status.HTTP_400_BAD_REQUEST,
+                detail={
+                    "message": f"TaskInstance with id {ti_id_str} not found.",
+                },
+            )

Review Comment:
   It is already covered by 
   
   
https://github.com/apache/airflow/blob/6dd8f148ac837fa99fa02c82a06fbe4307b11baf/airflow/api_fastapi/execution_api/routes/task_instances.py#L75-L88



##########
airflow/api_fastapi/execution_api/routes/task_instances.py:
##########
@@ -122,6 +124,53 @@ def ti_update_state(
         )
     elif isinstance(ti_patch_payload, TITerminalStatePayload):
         query = TI.duration_expression_update(ti_patch_payload.end_date, query, session.bind)
+    elif isinstance(ti_patch_payload, TIDeferredStatePayload):
+        trigger_row = Trigger(
+            classpath=ti_patch_payload.classpath,
+            kwargs=ti_patch_payload.kwargs,
+            created_date=ti_patch_payload.created_date,
+        )
+        session.add(trigger_row)
+        session.flush()
+
+        ti = session.query(TI).filter(TI.id == ti_id_str).one_or_none()
+
+        if not ti:
+            raise HTTPException(
+                status_code=status.HTTP_400_BAD_REQUEST,
+                detail={
+                    "message": f"TaskInstance with id {ti_id_str} not found.",
+                },
+            )
+
+        ti.state = TaskInstanceState.DEFERRED
+        ti.trigger_id = trigger_row.id
+        ti.next_method = ti_patch_payload.next_method
+        ti.next_kwargs = ti_patch_payload.kwargs or {}
+        timeout = ti_patch_payload.timeout

Review Comment:
   You can reuse the following to update the values:
   
   
https://github.com/apache/airflow/blob/6dd8f148ac837fa99fa02c82a06fbe4307b11baf/airflow/api_fastapi/execution_api/routes/task_instances.py#L93
   
   Example: 
   
   
https://github.com/apache/airflow/blob/6dd8f148ac837fa99fa02c82a06fbe4307b11baf/airflow/api_fastapi/execution_api/routes/task_instances.py#L118-L124



##########
airflow/api_fastapi/execution_api/routes/task_instances.py:
##########
@@ -122,6 +124,53 @@ def ti_update_state(
         )
     elif isinstance(ti_patch_payload, TITerminalStatePayload):
         query = TI.duration_expression_update(ti_patch_payload.end_date, query, session.bind)
+    elif isinstance(ti_patch_payload, TIDeferredStatePayload):
+        trigger_row = Trigger(
+            classpath=ti_patch_payload.classpath,
+            kwargs=ti_patch_payload.kwargs,
+            created_date=ti_patch_payload.created_date,
+        )
+        session.add(trigger_row)
+        session.flush()
+
+        ti = session.query(TI).filter(TI.id == ti_id_str).one_or_none()
+
+        if not ti:
+            raise HTTPException(
+                status_code=status.HTTP_400_BAD_REQUEST,
+                detail={
+                    "message": f"TaskInstance with id {ti_id_str} not found.",
+                },
+            )
+
+        ti.state = TaskInstanceState.DEFERRED
+        ti.trigger_id = trigger_row.id
+        ti.next_method = ti_patch_payload.next_method
+        ti.next_kwargs = ti_patch_payload.kwargs or {}
+        timeout = ti_patch_payload.timeout
+
+        # Calculate timeout too if it was passed
+        if timeout is not None:
+            ti.trigger_timeout = timezone.utcnow() + timeout
+        else:
+            ti.trigger_timeout = None
+
+        # If an execution_timeout is set, set the timeout to the minimum of
+        # it and the trigger timeout
+        if ti.task:
+            execution_timeout = ti.task.execution_timeout
+            if execution_timeout:
+                if TYPE_CHECKING:
+                    assert ti.start_date
+                if ti.trigger_timeout:
+                    ti.trigger_timeout = min(ti.start_date + execution_timeout, ti.trigger_timeout)
+                else:
+                    ti.trigger_timeout = ti.start_date + execution_timeout
+
+        session.commit()
+
+        log.info("TI %s state updated to: deferred", ti_id_str)
+        return

Review Comment:
   We don't need to return here or call `session.commit()`; the code below will take care of it if you set the attributes correctly.
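A stdlib `sqlite3` sketch of the flow these review comments describe: look the TI up once at the top, let each branch only contribute column values (the deferred branch also inserts a trigger row and records its id), then run a single UPDATE and commit in the shared tail, with no per-branch commit or return. The table layout, column set and `update_state` are hypothetical simplifications standing in for the SQLAlchemy statement the real handler builds:

```python
from __future__ import annotations

import json
import sqlite3
from datetime import datetime, timezone
from typing import Any


def create_tables(conn: sqlite3.Connection) -> None:
    # Hypothetical, heavily simplified schema; not Airflow's real tables.
    conn.execute(
        "CREATE TABLE task_instance (id TEXT PRIMARY KEY, state TEXT, "
        "trigger_id INTEGER, next_method TEXT)"
    )
    conn.execute(
        "CREATE TABLE triggers (id INTEGER PRIMARY KEY AUTOINCREMENT, "
        "classpath TEXT, kwargs TEXT, created_date TEXT)"
    )


def update_state(conn: sqlite3.Connection, ti_id: str, payload: dict[str, Any]) -> None:
    # Single existence check up front; the branches below never re-query the TI.
    if conn.execute("SELECT 1 FROM task_instance WHERE id = ?", (ti_id,)).fetchone() is None:
        raise LookupError(f"TaskInstance {ti_id} not found")

    # Each branch only contributes column values to one shared UPDATE.
    values: dict[str, Any] = {"state": payload["state"]}
    if payload["state"] == "deferred":
        cur = conn.execute(
            "INSERT INTO triggers (classpath, kwargs, created_date) VALUES (?, ?, ?)",
            (
                payload["classpath"],
                json.dumps(payload["kwargs"]),
                datetime.now(timezone.utc).isoformat(),  # stamped server-side
            ),
        )
        values["trigger_id"] = cur.lastrowid
        values["next_method"] = payload["next_method"]
        # No commit and no early return: the shared tail below handles both.

    # Column names come from the fixed keys above, never from user input.
    assignments = ", ".join(f"{col} = ?" for col in values)
    conn.execute(
        f"UPDATE task_instance SET {assignments} WHERE id = ?",
        (*values.values(), ti_id),
    )
    conn.commit()
```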



##########
airflow/api_fastapi/execution_api/routes/task_instances.py:
##########
@@ -30,14 +30,16 @@
 from airflow.api_fastapi.common.db.common import get_session
 from airflow.api_fastapi.common.router import AirflowRouter
 from airflow.api_fastapi.execution_api.datamodels.taskinstance import (
+    TIDeferredStatePayload,
     TIEnterRunningPayload,
     TIHeartbeatInfo,
     TIStateUpdate,
     TITerminalStatePayload,
 )
+from airflow.models import Trigger
 from airflow.models.taskinstance import TaskInstance as TI
 from airflow.utils import timezone
-from airflow.utils.state import State
+from airflow.utils.state import State, TaskInstanceState

Review Comment:
   Since we already have State, we don't need to import `TaskInstanceState` too
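A minimal mirror of the aliasing this comment relies on, assuming (as in `airflow.utils.state`) that `State` exposes the `TaskInstanceState` members as plain class attributes, so both spellings name the very same object. Both classes here are hypothetical subsets:

```python
from enum import Enum


class TaskInstanceState(str, Enum):
    """Hypothetical subset of airflow.utils.state.TaskInstanceState."""

    RUNNING = "running"
    DEFERRED = "deferred"


class State:
    """Hypothetical subset of airflow.utils.state.State: aliases to enum members."""

    RUNNING = TaskInstanceState.RUNNING
    DEFERRED = TaskInstanceState.DEFERRED
```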



-- 
This is an automated message from the Apache Git Service.
