amoghrajesh commented on code in PR #44899: URL: https://github.com/apache/airflow/pull/44899#discussion_r1886694675
########## airflow/api_fastapi/execution_api/routes/task_instances.py: ########## @@ -48,6 +51,108 @@ log = logging.getLogger(__name__) +@router.patch( + "/{task_instance_id}/run", Review Comment: I think `run` sounds better too ########## airflow/api_fastapi/execution_api/routes/task_instances.py: ########## @@ -92,35 +197,15 @@ def ti_update_state( query = update(TI).where(TI.id == ti_id_str).values(data) + # TODO: Instead remove this payload from discriminator accepted by this endpoint if isinstance(ti_patch_payload, TIEnterRunningPayload): - if previous_state != State.QUEUED: - log.warning( - "Can not start Task Instance ('%s') in invalid state: %s", - ti_id_str, - previous_state, - ) - - # TODO: Pass a RFC 9457 compliant error message in "detail" field - # https://datatracker.ietf.org/doc/html/rfc9457 - # to provide more information about the error - # FastAPI will automatically convert this to a JSON response - # This might be added in FastAPI in https://github.com/fastapi/fastapi/issues/10370 - raise HTTPException( - status_code=status.HTTP_409_CONFLICT, - detail={ - "reason": "invalid_state", - "message": "TI was not in a state where it could be marked as running", - "previous_state": previous_state, - }, - ) - log.info("Task with %s state started on %s ", previous_state, ti_patch_payload.hostname) - # Ensure there is no end date set. - query = query.values( - end_date=None, - hostname=ti_patch_payload.hostname, - unixname=ti_patch_payload.unixname, - pid=ti_patch_payload.pid, - state=State.RUNNING, + raise HTTPException( + status_code=status.HTTP_409_CONFLICT, Review Comment: Would a 400 Bad Request serve better here? https://developer.mozilla.org/en-US/docs/Web/HTTP/Status/400 > The reason for a 400 response is typically due to malformed request syntax, invalid request message framing, or deceptive request routing. 
########## airflow/api_fastapi/execution_api/routes/task_instances.py: ########## @@ -48,6 +51,108 @@ log = logging.getLogger(__name__) +@router.patch( + "/{task_instance_id}/run", + status_code=status.HTTP_200_OK, + responses={ + status.HTTP_404_NOT_FOUND: {"description": "Task Instance not found"}, + status.HTTP_409_CONFLICT: {"description": "The TI is already in the requested state"}, + status.HTTP_422_UNPROCESSABLE_ENTITY: {"description": "Invalid payload for the state transition"}, + }, +) +def ti_run( + task_instance_id: UUID, ti_run_payload: Annotated[TIEnterRunningPayload, Body()], session: SessionDep +) -> TIRunContext: + """ + Run a TaskInstance. + + This endpoint is used to start a TaskInstance that is in the QUEUED state. + """ + # We only use UUID above for validation purposes + ti_id_str = str(task_instance_id) + + old = select(TI.state, TI.dag_id, TI.run_id).where(TI.id == ti_id_str).with_for_update() + try: + (previous_state, dag_id, run_id) = session.execute(old).one() + except NoResultFound: + log.error("Task Instance %s not found", ti_id_str) + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail={ + "reason": "not_found", + "message": "Task Instance not found", + }, + ) + + # We exclude_unset to avoid updating fields that are not set in the payload + data = ti_run_payload.model_dump(exclude_unset=True) + + query = update(TI).where(TI.id == ti_id_str).values(data) + + if previous_state != State.QUEUED: + log.warning( + "Can not start Task Instance ('%s') in invalid state: %s", + ti_id_str, + previous_state, + ) + + # TODO: Pass a RFC 9457 compliant error message in "detail" field + # https://datatracker.ietf.org/doc/html/rfc9457 + # to provide more information about the error + # FastAPI will automatically convert this to a JSON response + # This might be added in FastAPI in https://github.com/fastapi/fastapi/issues/10370 + raise HTTPException( + status_code=status.HTTP_409_CONFLICT, + detail={ + "reason": "invalid_state", + 
"message": "TI was not in a state where it could be marked as running", + "previous_state": previous_state, + }, + ) + log.info("Task with %s state started on %s ", previous_state, ti_run_payload.hostname) + # Ensure there is no end date set. + query = query.values( + end_date=None, + hostname=ti_run_payload.hostname, + unixname=ti_run_payload.unixname, + pid=ti_run_payload.pid, + state=State.RUNNING, + ) + + try: + result = session.execute(query) + log.info("TI %s state updated: %s row(s) affected", ti_id_str, result.rowcount) + + dr = session.execute( + select( + DR.run_id, + DR.dag_id, + DR.data_interval_start, + DR.data_interval_end, + DR.start_date, + DR.end_date, + DR.run_type, + DR.conf, + DR.logical_date, + ).filter_by(dag_id=dag_id, run_id=run_id) + ).one_or_none() + + if not dr: + raise ValueError(f"DagRun with dag_id={dag_id} and run_id={run_id} not found.") + + return TIRunContext( + dag_run=DagRun.model_validate(dr, from_attributes=True), + # TODO: Add variables and connections that are needed (and has perms) for the task + variables=[], + connections=[], Review Comment: Looks good! ########## airflow/api_fastapi/execution_api/datamodels/taskinstance.py: ########## @@ -139,3 +142,34 @@ class TaskInstance(BaseModel): """Schema for setting RTIF for a task instance.""" RTIFPayload = RootModel[dict[str, str]] + + +class DagRun(BaseModel): + """Schema for DagRun model with minimal required fields needed for Runtime.""" + + # TODO: `dag_id` and `run_id` are duplicated from TaskInstance + # See if we can avoid sending these fields from API server and instead + # use the TaskInstance data to get the DAG run information in the client (Task Execution Interface). 
Review Comment: Yeah sounds good ########## airflow/api_fastapi/execution_api/routes/task_instances.py: ########## @@ -48,6 +51,108 @@ log = logging.getLogger(__name__) +@router.patch( + "/{task_instance_id}/run", + status_code=status.HTTP_200_OK, + responses={ + status.HTTP_404_NOT_FOUND: {"description": "Task Instance not found"}, + status.HTTP_409_CONFLICT: {"description": "The TI is already in the requested state"}, + status.HTTP_422_UNPROCESSABLE_ENTITY: {"description": "Invalid payload for the state transition"}, + }, +) +def ti_run( + task_instance_id: UUID, ti_run_payload: Annotated[TIEnterRunningPayload, Body()], session: SessionDep +) -> TIRunContext: + """ + Run a TaskInstance. + + This endpoint is used to start a TaskInstance that is in the QUEUED state. + """ + # We only use UUID above for validation purposes + ti_id_str = str(task_instance_id) + + old = select(TI.state, TI.dag_id, TI.run_id).where(TI.id == ti_id_str).with_for_update() + try: + (previous_state, dag_id, run_id) = session.execute(old).one() + except NoResultFound: + log.error("Task Instance %s not found", ti_id_str) + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail={ + "reason": "not_found", + "message": "Task Instance not found", + }, + ) + + # We exclude_unset to avoid updating fields that are not set in the payload + data = ti_run_payload.model_dump(exclude_unset=True) + + query = update(TI).where(TI.id == ti_id_str).values(data) + + if previous_state != State.QUEUED: Review Comment: I believe a rescheduled task will come from the "up_for_reschedule" state and a deferred task will come from the "deferred" state. We might want to add checks for both of those states as well. 
########## task_sdk/tests/execution_time/test_task_runner.py: ########## @@ -318,3 +318,83 @@ def __init__(self, *args, **kwargs): msg=SetRenderedFields(rendered_fields=expected_rendered_fields), log=mock.ANY, ) + + +class TestRuntimeTaskInstance: + def test_get_context_without_ti_context_from_server(self, mocked_parse): + """Test get_template_context without ti_context_from_server.""" + from airflow.providers.standard.operators.python import PythonOperator + + task = PythonOperator( + task_id="hello", + python_callable=lambda: print("hello"), + ) Review Comment: Yeah good idea. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
