kaxil commented on code in PR #44241:
URL: https://github.com/apache/airflow/pull/44241#discussion_r1856503655
##########
airflow/api_fastapi/execution_api/routes/task_instances.py:
##########
@@ -30,11 +31,13 @@
from airflow.api_fastapi.common.db.common import get_session
from airflow.api_fastapi.common.router import AirflowRouter
from airflow.api_fastapi.execution_api.datamodels.taskinstance import (
+ TIDeferredStatePayload,
TIEnterRunningPayload,
TIHeartbeatInfo,
TIStateUpdate,
TITerminalStatePayload,
)
+from airflow.models import Trigger
Review Comment:
```suggestion
from airflow.models.trigger import Trigger
```
for consistency with L41
##########
task_sdk/src/airflow/sdk/execution_time/supervisor.py:
##########
@@ -504,6 +516,12 @@ def handle_requests(self, log: FilteringBoundLogger) -> Generator[None, bytes, N
elif isinstance(msg, GetVariable):
var = self.client.variables.get(msg.key)
resp = var.model_dump_json(exclude_unset=True).encode()
+ elif isinstance(msg, PatchTIToDeferred):
+ self.final_state = IntermediateTIState.DEFERRED
+ self.client.task_instances.defer(self.ti_id, msg.model_dump_json())
+ # hmmm, can we do better here
+ # setting to "\n" as we do not have a response to return
Review Comment:
Two options:
1) Move `self.stdin.write(resp + b"\n")` into every branch where we are
sending a msg
2) Set `resp` to `None` and change L529 to:
```python
if resp:
    self.stdin.write(resp + b"\n")
```
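A minimal sketch of option 2, outside the real supervisor: `handle_message`, `dispatch`, and the dict-based messages are hypothetical stand-ins for the actual message classes, and `io.BytesIO` stands in for the child's stdin. It only shows that a branch returning `None` produces no write at all.

```python
from __future__ import annotations

import io
import json


def handle_message(msg: dict) -> bytes | None:
    """Return an encoded response, or None when the message needs no reply.

    The message shapes here are assumptions made for illustration only.
    """
    if msg["type"] == "GetVariable":
        return json.dumps({"key": msg["key"], "value": "demo"}).encode()
    if msg["type"] == "DeferTask":
        # Nothing to send back to the task process for a deferral.
        return None
    raise ValueError(f"unknown message type: {msg['type']}")


def dispatch(msg: dict, stdin: io.BytesIO) -> None:
    resp = handle_message(msg)
    # Only write when there is something to send, so no placeholder "\n" is needed.
    if resp:
        stdin.write(resp + b"\n")


stdin = io.BytesIO()
dispatch({"type": "DeferTask"}, stdin)
dispatch({"type": "GetVariable", "key": "k"}, stdin)
```

With this shape the deferral branch does not have to invent a response just to satisfy the unconditional write.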
##########
airflow/api_fastapi/execution_api/routes/task_instances.py:
##########
@@ -122,6 +124,53 @@ def ti_update_state(
)
elif isinstance(ti_patch_payload, TITerminalStatePayload):
query = TI.duration_expression_update(ti_patch_payload.end_date, query, session.bind)
+ elif isinstance(ti_patch_payload, TIDeferredStatePayload):
+ trigger_row = Trigger(
+ classpath=ti_patch_payload.classpath,
+ kwargs=ti_patch_payload.kwargs,
+ created_date=ti_patch_payload.created_date,
+ )
+ session.add(trigger_row)
+ session.flush()
Review Comment:
Yeah, but why do you need `session.flush`? The session will already be
committed and handled via `get_session`:
https://github.com/apache/airflow/blob/4d3140dba5976acb38c3dffd4fd963514f8da634/airflow/api_fastapi/execution_api/routes/task_instances.py#L67
##########
airflow/api_fastapi/execution_api/routes/task_instances.py:
##########
@@ -122,6 +125,44 @@ def ti_update_state(
)
elif isinstance(ti_patch_payload, TITerminalStatePayload):
query = TI.duration_expression_update(ti_patch_payload.end_date, query, session.bind)
+ elif isinstance(ti_patch_payload, TIDeferredStatePayload):
+ trigger_row = Trigger(
+ classpath=ti_patch_payload.classpath,
+ kwargs=ti_patch_payload.kwargs,
+ created_date=ti_patch_payload.created_date,
+ )
+ session.add(trigger_row)
+ session.flush()
+
+ ti = session.query(TI).filter(TI.id == ti_id_str).one_or_none()
Review Comment:
We already have this in the `old` variable:
https://github.com/apache/airflow/blob/4d3140dba5976acb38c3dffd4fd963514f8da634/airflow/api_fastapi/execution_api/routes/task_instances.py#L78
##########
airflow/api_fastapi/execution_api/routes/task_instances.py:
##########
@@ -122,6 +125,44 @@ def ti_update_state(
)
elif isinstance(ti_patch_payload, TITerminalStatePayload):
query = TI.duration_expression_update(ti_patch_payload.end_date, query, session.bind)
+ elif isinstance(ti_patch_payload, TIDeferredStatePayload):
+ trigger_row = Trigger(
+ classpath=ti_patch_payload.classpath,
+ kwargs=ti_patch_payload.kwargs,
+ created_date=ti_patch_payload.created_date,
+ )
+ session.add(trigger_row)
+ session.flush()
+
+ ti = session.query(TI).filter(TI.id == ti_id_str).one_or_none()
+
+ # Calculate timeout too if it was passed
+ trigger_timeout: datetime | None = None
+ if ti_patch_payload.timeout is not None:
+ trigger_timeout = timezone.utcnow() + ti_patch_payload.timeout
+ else:
+ trigger_timeout = None
+
+ # If an execution_timeout is set, set the timeout to the minimum of
+ # it and the trigger timeout
+ if ti.task:
+ execution_timeout = ti.task.execution_timeout
+ if execution_timeout:
+ if TYPE_CHECKING:
+ assert ti.start_date
+ if ti.trigger_timeout:
+ trigger_timeout = min(ti.start_date + execution_timeout, ti.trigger_timeout)
+ else:
+ trigger_timeout = ti.start_date + execution_timeout
+
+ query = update(TI).where(TI.id == ti_id_str)
+ query = query.values(
+ state=State.DEFERRED,
+ trigger_id=trigger_row.id,
+ next_method=ti_patch_payload.next_method,
+ next_kwargs=ti_patch_payload.kwargs or {},
Review Comment:
Do we need the `or` part here? When would it be `None`?
I think Pydantic would validate that here:
https://github.com/apache/airflow/blob/4d3140dba5976acb38c3dffd4fd963514f8da634/airflow/api_fastapi/execution_api/datamodels/taskinstance.py#L81
##########
task_sdk/src/airflow/sdk/execution_time/comms.py:
##########
@@ -103,6 +104,12 @@ class TaskState(BaseModel):
type: Literal["TaskState"] = "TaskState"
+class PatchTIToDeferred(TIDeferredStatePayload):
Review Comment:
```suggestion
class DeferTask(TIDeferredStatePayload):
```
might be a better name 🤷
##########
task_sdk/src/airflow/sdk/execution_time/supervisor.py:
##########
@@ -463,10 +467,18 @@ def final_state(self):
Not valid before the process has finished.
"""
+ if self._final_state:
+ return self._final_state
if self._exit_code == 0:
return self._terminal_state or TerminalTIState.SUCCESS
return TerminalTIState.FAILED
+ @final_state.setter
+ def final_state(self, value):
+ """Setter for final_state for certain task instance stated present in
IntermediateTIState."""
+ if value not in TerminalTIState:
+ self._final_state = value
Review Comment:
Why do we need a `setter`? Shouldn't we set the TI state to failed if
exit_code != 0, like we already do in
https://github.com/apache/airflow/blob/4d3140dba5976acb38c3dffd4fd963514f8da634/task_sdk/src/airflow/sdk/execution_time/supervisor.py#L472-L474
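To make the concern concrete, here is a small self-contained sketch of the property/setter pair from the hunk above. `ProcessSketch` and the two trimmed-down enums are hypothetical stand-ins, not the real supervisor classes, and the setter uses `isinstance` instead of `in` for the membership test.

```python
from enum import Enum


class TerminalTIState(str, Enum):
    SUCCESS = "success"
    FAILED = "failed"


class IntermediateTIState(str, Enum):
    DEFERRED = "deferred"


class ProcessSketch:
    """Hypothetical stand-in mirroring the final_state logic in the diff."""

    def __init__(self):
        self._exit_code = None
        self._terminal_state = None
        self._final_state = None

    @property
    def final_state(self):
        # An explicitly set intermediate state takes precedence over the exit code.
        if self._final_state:
            return self._final_state
        if self._exit_code == 0:
            return self._terminal_state or TerminalTIState.SUCCESS
        return TerminalTIState.FAILED

    @final_state.setter
    def final_state(self, value):
        # Terminal states are ignored here; only intermediate ones are stored.
        if not isinstance(value, TerminalTIState):
            self._final_state = value


proc = ProcessSketch()
proc.final_state = IntermediateTIState.DEFERRED
proc._exit_code = 1  # simulate a non-zero exit after the deferral was recorded
state_after_bad_exit = proc.final_state
```

In this sketch a non-zero exit code after a deferral still reports `DEFERRED`, because the stored value short-circuits the exit-code check.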
##########
task_sdk/src/airflow/sdk/execution_time/task_runner.py:
##########
@@ -159,8 +160,23 @@ def run(ti: RuntimeTaskInstance, log: Logger):
# TODO next_method to support resuming from deferred
# TODO: Get a real context object
ti.task.execute({"task_instance": ti}) # type: ignore[attr-defined]
- except TaskDeferred:
- ...
+ except TaskDeferred as defer:
+ trigger = Trigger.from_object(defer.trigger)
Review Comment:
Why the `Trigger` model here?
i.e. why do we need L164 to L169?
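If the model is only there to obtain the classpath and kwargs, one possible alternative is to read them straight from the trigger. The sketch below assumes the trigger exposes `serialize()` returning a `(classpath, kwargs)` tuple, as Airflow's `BaseTrigger` does; `SleepTrigger` and `build_defer_payload` are hypothetical names, and the payload is a plain dict rather than the real message class.

```python
from __future__ import annotations


class SleepTrigger:
    """Hypothetical trigger used only for this illustration."""

    def __init__(self, seconds: int):
        self.seconds = seconds

    def serialize(self) -> tuple[str, dict]:
        # Same contract as BaseTrigger.serialize(): (classpath, kwargs).
        return ("example.triggers.SleepTrigger", {"seconds": self.seconds})


def build_defer_payload(trigger, next_method: str) -> dict:
    """Build the deferral fields without going through the ORM Trigger model."""
    classpath, kwargs = trigger.serialize()
    return {
        "state": "deferred",
        "classpath": classpath,
        "kwargs": kwargs,
        "next_method": next_method,
    }


payload = build_defer_payload(SleepTrigger(seconds=30), next_method="execute_complete")
```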
##########
task_sdk/src/airflow/sdk/execution_time/supervisor.py:
##########
@@ -389,9 +392,10 @@ def wait(self) -> int:
# If it hasn't, assume it's failed
self._exit_code = self._exit_code if self._exit_code is not None else 1
- self.client.task_instances.finish(
- id=self.ti_id, state=self.final_state, when=datetime.now(tz=timezone.utc)
- )
+ if self.final_state in TerminalTIState:
Review Comment:
Not 100% certain we need this check -- I am thinking that by definition
`self.final_state` has to be in `TerminalTIState` -- if not, it is an error!
##########
airflow/api_fastapi/execution_api/datamodels/taskinstance.py:
##########
@@ -61,6 +62,28 @@ class TITargetStatePayload(BaseModel):
state: IntermediateTIState
+class TIDeferredStatePayload(BaseModel):
+ """Schema for updating TaskInstance to a deferred state."""
+
+ state: Annotated[
+ Literal[IntermediateTIState.DEFERRED],
+ # Specify a default in the schema, but not in code, so Pydantic marks it as required.
+ WithJsonSchema(
+ {
+ "type": "string",
+ "enum": [IntermediateTIState.DEFERRED],
+ "default": IntermediateTIState.DEFERRED,
+ }
+ ),
+ ]
+
+ classpath: str
+ kwargs: dict[str, Any]
+ created_date: UtcDateTime
Review Comment:
>If we do not allow to override it,
I can't think of scenarios where we would want to override the `created_date`
from the worker. So better to remove it from here and handle it in the API
logic, i.e. in `ti_update_state`.
##########
airflow/api_fastapi/execution_api/routes/task_instances.py:
##########
@@ -122,6 +124,53 @@ def ti_update_state(
)
elif isinstance(ti_patch_payload, TITerminalStatePayload):
query = TI.duration_expression_update(ti_patch_payload.end_date, query, session.bind)
+ elif isinstance(ti_patch_payload, TIDeferredStatePayload):
+ trigger_row = Trigger(
+ classpath=ti_patch_payload.classpath,
+ kwargs=ti_patch_payload.kwargs,
+ created_date=ti_patch_payload.created_date,
+ )
+ session.add(trigger_row)
+ session.flush()
+
+ ti = session.query(TI).filter(TI.id == ti_id_str).one_or_none()
+
+ if not ti:
+ raise HTTPException(
+ status_code=status.HTTP_400_BAD_REQUEST,
+ detail={
+ "message": f"TaskInstance with id {ti_id_str} not found.",
+ },
+ )
+
+ ti.state = TaskInstanceState.DEFERRED
+ ti.trigger_id = trigger_row.id
+ ti.next_method = ti_patch_payload.next_method
+ ti.next_kwargs = ti_patch_payload.kwargs or {}
+ timeout = ti_patch_payload.timeout
+
+ # Calculate timeout too if it was passed
+ if timeout is not None:
+ ti.trigger_timeout = timezone.utcnow() + timeout
+ else:
+ ti.trigger_timeout = None
+
+ # If an execution_timeout is set, set the timeout to the minimum of
+ # it and the trigger timeout
+ if ti.task:
Review Comment:
The TI is fetched from the DB, so `task` would be `None`. For now add a "TODO"
to handle it later -- we can get it from the serialized DAG or via the API
itself -- but we can figure that out later.
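Whichever source `execution_timeout` ends up coming from, the calculation itself can be kept independent of `ti.task`. A rough sketch, where `compute_trigger_timeout` and its parameters are hypothetical and the execution timeout is simply passed in:

```python
from __future__ import annotations

from datetime import datetime, timedelta, timezone


def compute_trigger_timeout(
    now: datetime,
    start_date: datetime,
    timeout: timedelta | None,
    execution_timeout: timedelta | None,
) -> datetime | None:
    """Return the earlier of the trigger deadline and the execution deadline."""
    # Deadline requested by the deferral itself, if any.
    trigger_timeout = now + timeout if timeout is not None else None
    if execution_timeout is not None:
        # Deadline implied by the task's overall execution_timeout.
        deadline = start_date + execution_timeout
        trigger_timeout = min(deadline, trigger_timeout) if trigger_timeout else deadline
    return trigger_timeout


now = datetime(2024, 1, 1, 12, 0, tzinfo=timezone.utc)
start = now - timedelta(minutes=10)
both = compute_trigger_timeout(now, start, timedelta(hours=1), timedelta(minutes=30))
only_trigger = compute_trigger_timeout(now, start, timedelta(hours=1), None)
neither = compute_trigger_timeout(now, start, None, None)
```

Keeping it as a pure function also makes the TODO easier to close later, since only the caller has to change once the task's `execution_timeout` is available.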