kaxil commented on code in PR #44241:
URL: https://github.com/apache/airflow/pull/44241#discussion_r1858561808
##########
airflow/api_fastapi/execution_api/routes/task_instances.py:
##########
@@ -122,6 +126,39 @@ def ti_update_state(
)
elif isinstance(ti_patch_payload, TITerminalStatePayload):
query = TI.duration_expression_update(ti_patch_payload.end_date, query, session.bind)
+ elif isinstance(ti_patch_payload, TIDeferredStatePayload):
+ # Calculate timeout if it was passed
+ trigger_timeout: datetime | None = None
Review Comment:
```suggestion
```
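The empty suggestion above deletes the pre-declaration. In the fuller hunk quoted further down, both branches of the `if`/`else` assign `trigger_timeout`, so declaring it first is redundant, and the whole calculation fits in one conditional expression. This is a minimal sketch. It assumes stdlib `datetime.now(timezone.utc)` as a stand-in for Airflow's `timezone.utcnow()`, and `compute_trigger_timeout` is a hypothetical helper, not code from the PR:

```python
from __future__ import annotations

from datetime import datetime, timedelta, timezone


def compute_trigger_timeout(
    timeout: timedelta | None, now: datetime | None = None
) -> datetime | None:
    """Return the absolute deferral deadline, or None when no timeout was sent.

    `now` is injectable so the result can be checked deterministically.
    """
    if now is None:
        # Stand-in for Airflow's timezone.utcnow(): an aware UTC datetime.
        now = datetime.now(timezone.utc)
    # One expression covers both branches, so no up-front `= None` is needed.
    return now + timeout if timeout is not None else None
```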
##########
task_sdk/src/airflow/sdk/execution_time/task_runner.py:
##########
@@ -159,8 +159,18 @@ def run(ti: RuntimeTaskInstance, log: Logger):
# TODO next_method to support resuming from deferred
# TODO: Get a real context object
ti.task.execute({"task_instance": ti}) # type: ignore[attr-defined]
- except TaskDeferred:
- ...
+ except TaskDeferred as defer:
+ next_method = defer.method_name
+ timeout = defer.timeout
+ kw = defer.kwargs or {}
+ # handle classpath on the server side
+ msg = DeferTask(
+ kwargs=kw,
+ next_method=next_method,
+ timeout=timeout,
+ )
Review Comment:
You should be able to get classpath here with:
```
classpath = defer.trigger.classpath
```
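The client can resolve the classpath because the trigger object is still available when `TaskDeferred` is caught. The following is a minimal sketch under stated assumptions. `WaitTrigger`, `build_defer_msg`, and the `classpath`/`trigger_kwargs`/`next_kwargs` message fields are hypothetical. `TaskDeferred` is a local stand-in with the same attribute names the hunk reads. The sketch reads the classpath via `trigger.serialize()`, which in Airflow's `BaseTrigger` returns a `(classpath, kwargs)` tuple. The review's `defer.trigger.classpath` would serve the same purpose wherever that attribute exists.

```python
from __future__ import annotations

from dataclasses import dataclass, field
from datetime import timedelta
from typing import Any


class WaitTrigger:
    """Hypothetical trigger; real Airflow triggers subclass BaseTrigger."""

    def __init__(self, seconds: int) -> None:
        self.seconds = seconds

    def serialize(self) -> tuple[str, dict[str, Any]]:
        # Same contract as BaseTrigger.serialize(): (dotted classpath, constructor kwargs).
        return "my_provider.triggers.WaitTrigger", {"seconds": self.seconds}


class TaskDeferred(BaseException):
    """Stand-in for airflow.exceptions.TaskDeferred (same attribute names)."""

    def __init__(self, *, trigger: Any, method_name: str,
                 kwargs: dict[str, Any] | None = None,
                 timeout: timedelta | None = None) -> None:
        super().__init__()
        self.trigger = trigger
        self.method_name = method_name
        self.kwargs = kwargs
        self.timeout = timeout


@dataclass
class DeferTask:
    """Hypothetical message shape that carries the classpath from the client."""

    classpath: str
    trigger_kwargs: dict[str, Any]
    next_method: str
    next_kwargs: dict[str, Any] = field(default_factory=dict)
    timeout: timedelta | None = None


def build_defer_msg(defer: TaskDeferred) -> DeferTask:
    # Resolve the classpath here, where the trigger object is still in hand.
    classpath, trigger_kwargs = defer.trigger.serialize()
    return DeferTask(
        classpath=classpath,
        trigger_kwargs=trigger_kwargs,
        next_method=defer.method_name,
        next_kwargs=defer.kwargs or {},
        timeout=defer.timeout,
    )
```

Keeping trigger kwargs and next-method kwargs in separate fields also avoids sending `defer.kwargs` (the arguments passed to `next_method` on resume) as if they were the trigger's constructor arguments.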
##########
airflow/api_fastapi/execution_api/datamodels/taskinstance.py:
##########
@@ -61,6 +62,26 @@ class TITargetStatePayload(BaseModel):
state: IntermediateTIState
+class TIDeferredStatePayload(BaseModel):
+ """Schema for updating TaskInstance to a deferred state."""
+
+ state: Annotated[
+ Literal[IntermediateTIState.DEFERRED],
+ # Specify a default in the schema, but not in code, so Pydantic marks it as required.
+ WithJsonSchema(
+ {
+ "type": "string",
+ "enum": [IntermediateTIState.DEFERRED],
+ "default": IntermediateTIState.DEFERRED,
+ }
+ ),
+ ]
+
+ kwargs: dict[str, Any]
+ next_method: str
+ timeout: timedelta | None = None
Review Comment:
Should we call this `trigger_timeout`? I think this would go directly to the TI table.
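After the rename, each payload name would match the column it lands in. The sketch below shows how the deferral data could split between the trigger row and the task-instance row. It assumes TI columns named `next_method`, `next_kwargs`, and `trigger_timeout` (to my knowledge, the names Airflow's `TaskInstance` model uses) and a trigger row with `classpath` and `kwargs`, as in the route hunk further down. `deferral_rows` is hypothetical, and it takes `trigger_timeout` as an already-computed absolute deadline:

```python
from __future__ import annotations

from datetime import datetime, timezone
from typing import Any


def deferral_rows(
    *,
    classpath: str,
    trigger_kwargs: dict[str, Any],
    next_method: str,
    next_kwargs: dict[str, Any],
    trigger_timeout: datetime | None,
) -> tuple[dict[str, Any], dict[str, Any]]:
    """Split deferral data into (trigger_row, ti_update); names are illustrative."""
    # Only the trigger's own data goes to the trigger table...
    trigger_row = {"classpath": classpath, "kwargs": trigger_kwargs}
    # ...while resume information and the deadline land on the task instance row.
    ti_update = {
        "state": "deferred",
        "next_method": next_method,
        "next_kwargs": next_kwargs,
        "trigger_timeout": trigger_timeout,
    }
    return trigger_row, ti_update
```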
##########
airflow/api_fastapi/execution_api/datamodels/taskinstance.py:
##########
@@ -61,6 +62,26 @@ class TITargetStatePayload(BaseModel):
state: IntermediateTIState
+class TIDeferredStatePayload(BaseModel):
+ """Schema for updating TaskInstance to a deferred state."""
+
+ state: Annotated[
+ Literal[IntermediateTIState.DEFERRED],
+ # Specify a default in the schema, but not in code, so Pydantic marks it as required.
+ WithJsonSchema(
+ {
+ "type": "string",
+ "enum": [IntermediateTIState.DEFERRED],
+ "default": IntermediateTIState.DEFERRED,
+ }
+ ),
+ ]
+
+ kwargs: dict[str, Any]
Review Comment:
Also add a `default_factory` to the type, so `kwargs` becomes optional and callers don't have to send an empty dict:
```suggestion
kwargs: Annotated[dict[str, Any], Field(default_factory=dict)]
```
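With `Field(default_factory=dict)`, omitting `kwargs` yields a fresh empty dict instead of a validation error. The sketch below uses the same pattern with stdlib `dataclasses` (a swapped-in analog, so it runs without Pydantic). `DeferredPayloadSketch` is hypothetical and keeps only two of the PR's fields:

```python
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Any


@dataclass
class DeferredPayloadSketch:
    """Stdlib stand-in for TIDeferredStatePayload; field names follow the PR."""

    next_method: str
    # A fresh dict per instance, so callers may omit kwargs entirely.
    # (dataclasses rejects a bare `= {}` default with ValueError.)
    kwargs: dict[str, Any] = field(default_factory=dict)
```

Unlike Pydantic, dataclasses require defaulted fields to come after required ones, which is why `next_method` is listed first here.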
##########
task_sdk/src/airflow/sdk/execution_time/supervisor.py:
##########
@@ -504,11 +518,16 @@ def handle_requests(self, log: FilteringBoundLogger) -> Generator[None, bytes, N
elif isinstance(msg, GetVariable):
var = self.client.variables.get(msg.key)
resp = var.model_dump_json(exclude_unset=True).encode()
+ elif isinstance(msg, DeferTask):
+ self.final_state = IntermediateTIState.DEFERRED
+ self.client.task_instances.defer(self.ti_id, msg.model_dump_json())
Review Comment:
```suggestion
self.client.task_instances.defer(self.ti_id, msg.model_dump_json(exclude_unset=True))
```
##########
airflow/api_fastapi/execution_api/routes/task_instances.py:
##########
@@ -122,6 +126,39 @@ def ti_update_state(
)
elif isinstance(ti_patch_payload, TITerminalStatePayload):
query = TI.duration_expression_update(ti_patch_payload.end_date, query, session.bind)
+ elif isinstance(ti_patch_payload, TIDeferredStatePayload):
+ # Calculate timeout if it was passed
+ trigger_timeout: datetime | None = None
+ if ti_patch_payload.timeout is not None:
+ trigger_timeout = timezone.utcnow() + ti_patch_payload.timeout
+ else:
+ trigger_timeout = None
+
+ defer = TaskDeferred(
+ trigger=None,
+ method_name=ti_patch_payload.next_method,
+ kwargs=ti_patch_payload.kwargs,
+ timeout=None,
+ )
+
+ classpath, _, _ = defer.serialize()
+ trigger_row = Trigger(
+ classpath=classpath,
+ kwargs=ti_patch_payload.kwargs,
+ )
+ session.add(trigger_row)
Review Comment:
This is wrong: it will always set `classpath` to `airflow.exceptions.TaskDeferred`, which means the Trigger will try to use that class when it runs after deferral, and fail.
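The failure can be reproduced with a stand-in whose `serialize()` returns the exception's own dotted path as its first element, which is the behaviour this comment describes for `TaskDeferred`. `server_side_classpath_buggy` is a hypothetical reduction of the hunk's logic. With `trigger=None` there is no trigger to describe, so the only classpath available is the exception's. Sending the classpath from the client, as suggested in the `task_runner.py` comment, avoids this.

```python
from __future__ import annotations

from datetime import timedelta
from typing import Any


class TaskDeferred(BaseException):
    """Stand-in whose serialize() mirrors the behaviour described in the review."""

    def __init__(self, *, trigger: Any, method_name: str,
                 kwargs: dict[str, Any] | None = None,
                 timeout: timedelta | None = None) -> None:
        super().__init__()
        self.trigger = trigger
        self.method_name = method_name
        self.kwargs = kwargs
        self.timeout = timeout

    def serialize(self) -> tuple[str, tuple, dict[str, Any]]:
        cls = self.__class__
        # The first element is the *exception's* dotted path, not a trigger's.
        return f"{cls.__module__}.{cls.__name__}", (), {
            "trigger": self.trigger,
            "method_name": self.method_name,
            "kwargs": self.kwargs,
            "timeout": self.timeout,
        }


def server_side_classpath_buggy(next_method: str, kwargs: dict[str, Any]) -> str:
    # What the route hunk does: serialize a TaskDeferred built with trigger=None.
    classpath, _, _ = TaskDeferred(
        trigger=None, method_name=next_method, kwargs=kwargs
    ).serialize()
    return classpath
```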
##########
task_sdk/src/airflow/sdk/execution_time/supervisor.py:
##########
@@ -504,11 +518,16 @@ def handle_requests(self, log: FilteringBoundLogger) -> Generator[None, bytes, N
elif isinstance(msg, GetVariable):
var = self.client.variables.get(msg.key)
resp = var.model_dump_json(exclude_unset=True).encode()
+ elif isinstance(msg, DeferTask):
+ self.final_state = IntermediateTIState.DEFERRED
+ self.client.task_instances.defer(self.ti_id, msg.model_dump_json())
Review Comment:
So it doesn't need to pass defaults: fields that were never explicitly set are left out of the request body.
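Pydantic's `model_dump_json(exclude_unset=True)` omits fields the caller never set explicitly. The request body then carries only what the task actually provided, and the server applies its own defaults. Below is a stdlib-only sketch of that behaviour. `PayloadSketch` is a hypothetical class that tracks which fields were set, as Pydantic does in `model_fields_set`:

```python
from __future__ import annotations

import json
from typing import Any


class PayloadSketch:
    """Hypothetical stand-in for a Pydantic model; not the PR's DeferTask."""

    def __init__(self, next_method: str, **overrides: Any) -> None:
        unknown = set(overrides) - {"kwargs", "timeout"}
        if unknown:
            raise TypeError(f"unexpected fields: {sorted(unknown)}")
        # Defaults are built fresh per instance; explicit values override them.
        self.values: dict[str, Any] = {
            "kwargs": {}, "timeout": None, "next_method": next_method, **overrides,
        }
        # Mirrors Pydantic's model_fields_set: only fields passed by the caller.
        self.fields_set = {"next_method", *overrides}

    def dump_json(self, exclude_unset: bool = False) -> str:
        data = {
            key: value
            for key, value in self.values.items()
            if not exclude_unset or key in self.fields_set
        }
        return json.dumps(data, sort_keys=True)
```

For example, `PayloadSketch("execute_complete").dump_json(exclude_unset=True)` produces `{"next_method": "execute_complete"}`, while the default dump also includes `"kwargs": {}` and `"timeout": null`.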
##########
airflow/api_fastapi/execution_api/datamodels/taskinstance.py:
##########
@@ -61,6 +62,26 @@ class TITargetStatePayload(BaseModel):
state: IntermediateTIState
+class TIDeferredStatePayload(BaseModel):
+ """Schema for updating TaskInstance to a deferred state."""
+
+ state: Annotated[
+ Literal[IntermediateTIState.DEFERRED],
+ # Specify a default in the schema, but not in code, so Pydantic marks it as required.
+ WithJsonSchema(
+ {
+ "type": "string",
+ "enum": [IntermediateTIState.DEFERRED],
+ "default": IntermediateTIState.DEFERRED,
+ }
+ ),
+ ]
+
+ kwargs: dict[str, Any]
Review Comment:
Should we call this `trigger_kwargs` or `next_kwargs` instead so it is clear?
##########
airflow/api_fastapi/execution_api/datamodels/taskinstance.py:
##########
@@ -61,6 +62,26 @@ class TITargetStatePayload(BaseModel):
state: IntermediateTIState
+class TIDeferredStatePayload(BaseModel):
+ """Schema for updating TaskInstance to a deferred state."""
+
+ state: Annotated[
+ Literal[IntermediateTIState.DEFERRED],
+ # Specify a default in the schema, but not in code, so Pydantic marks it as required.
+ WithJsonSchema(
+ {
+ "type": "string",
+ "enum": [IntermediateTIState.DEFERRED],
+ "default": IntermediateTIState.DEFERRED,
+ }
+ ),
+ ]
+
+ kwargs: dict[str, Any]
+ next_method: str
+ timeout: timedelta | None = None
Review Comment:
Fine if we keep it as-is too
##########
task_sdk/src/airflow/sdk/execution_time/task_runner.py:
##########
@@ -159,8 +159,18 @@ def run(ti: RuntimeTaskInstance, log: Logger):
# TODO next_method to support resuming from deferred
# TODO: Get a real context object
ti.task.execute({"task_instance": ti}) # type: ignore[attr-defined]
- except TaskDeferred:
- ...
+ except TaskDeferred as defer:
+ next_method = defer.method_name
+ timeout = defer.timeout
+ kw = defer.kwargs or {}
+ # handle classpath on the server side
+ msg = DeferTask(
+ kwargs=kw,
+ next_method=next_method,
+ timeout=timeout,
+ )
+ global SUPERVISOR_COMMS
+ SUPERVISOR_COMMS.send_request(msg=msg, log=log)
Review Comment:
We should add a test case for this, i.e. check which msg is sent when `run` raises `TaskDeferred`.
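Here is a self-contained sketch of such a test. `RecordingComms` is a hypothetical fake that records requests. `run`, `TaskDeferred`, and `DeferTask` are simplified stand-ins mirroring the hunk, and the comms object is injected rather than read from the `SUPERVISOR_COMMS` global. Against the real `task_runner`, the test would more likely patch `SUPERVISOR_COMMS` (for example with `unittest.mock.patch`) and assert on the message passed to `send_request`:

```python
from __future__ import annotations

from dataclasses import dataclass, field
from datetime import timedelta
from typing import Any, Callable


class TaskDeferred(BaseException):
    """Stand-in for airflow.exceptions.TaskDeferred."""

    def __init__(self, *, trigger: Any, method_name: str,
                 kwargs: dict[str, Any] | None = None,
                 timeout: timedelta | None = None) -> None:
        super().__init__()
        self.trigger, self.method_name = trigger, method_name
        self.kwargs, self.timeout = kwargs, timeout


@dataclass
class DeferTask:
    """Stand-in for the PR's DeferTask message (fields as in the hunk)."""

    next_method: str
    kwargs: dict[str, Any] = field(default_factory=dict)
    timeout: timedelta | None = None


class RecordingComms:
    """Hypothetical fake for SUPERVISOR_COMMS that records every request."""

    def __init__(self) -> None:
        self.sent: list[Any] = []

    def send_request(self, msg: Any, log: Any) -> None:
        self.sent.append(msg)


def run(task_execute: Callable[[], None], comms: RecordingComms, log: Any = None) -> None:
    """Simplified run(): comms is injected instead of using a module global."""
    try:
        task_execute()
    except TaskDeferred as defer:
        msg = DeferTask(
            kwargs=defer.kwargs or {},
            next_method=defer.method_name,
            timeout=defer.timeout,
        )
        comms.send_request(msg=msg, log=log)


def test_run_sends_defer_task_on_deferral() -> None:
    def execute() -> None:
        raise TaskDeferred(trigger=object(), method_name="execute_complete",
                           timeout=timedelta(minutes=5))

    comms = RecordingComms()
    run(execute, comms)

    assert comms.sent == [
        DeferTask(next_method="execute_complete", kwargs={}, timeout=timedelta(minutes=5))
    ]
```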