jscheffl commented on code in PR #38992:
URL: https://github.com/apache/airflow/pull/38992#discussion_r1606094032
##########
airflow/serialization/pydantic/taskinstance.py:
##########
@@ -21,9 +21,17 @@
from typing_extensions import Annotated
+from airflow.exceptions import AirflowRescheduleException, TaskDeferred
from airflow.models import Operator
from airflow.models.baseoperator import BaseOperator
-from airflow.models.taskinstance import TaskInstance
+from airflow.models.taskinstance import (
+    TaskInstance,
+    TaskReturnCode,
+    _defer_task,
+    _handle_reschedule,
+    _run_raw_task,
+    _set_ti_attrs,
Review Comment:
Because of the overlapping names and the potential for confusion, I'd recommend giving these explicitly different names, e.g.:
```suggestion
    _defer_task as _ti_defer_task,
    _handle_reschedule as _ti_handle_reschedule,
    _run_raw_task as _ti_run_raw_task,
    _set_ti_attrs as _ti_set_ti_attrs,
```
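For illustration, a rough sketch of how a call site in the pydantic serialization module could then read (the `_defer_and_sync` wrapper below is hypothetical, not the PR's actual code; only the aliased imports come from the suggestion above):
```python
# Sketch only: the _defer_and_sync wrapper and its signature are illustrative.
from airflow.models.taskinstance import (
    _defer_task as _ti_defer_task,
    _set_ti_attrs as _ti_set_ti_attrs,
)


def _defer_and_sync(pydantic_ti, exception, session):
    # The _ti_ prefix makes it obvious these are the module-level helpers from
    # airflow.models.taskinstance, not methods on TaskInstancePydantic.
    updated_ti = _ti_defer_task(ti=pydantic_ti, exception=exception, session=session)
    _ti_set_ti_attrs(pydantic_ti, updated_ti)
```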
##########
airflow/models/taskinstance.py:
##########
@@ -872,6 +1163,8 @@ def _is_eligible_to_retry(*, task_instance: TaskInstance | TaskInstancePydantic)
    return task_instance.task.retries and task_instance.try_number <= task_instance.max_tries
+@provide_session
+@internal_api_call
def _handle_failure(
Review Comment:
As above, make the naming distinct:
```suggestion
def _handle_ti_failure(
```
##########
airflow/models/taskinstance.py:
##########
@@ -1265,6 +1558,132 @@ def _update_rtif(ti, rendered_fields, session: Session | None = None):
    RenderedTaskInstanceFields.delete_old_records(ti.task_id, ti.dag_id, session=session)
+def _coalesce_to_orm_ti(*, ti: TaskInstancePydantic | TaskInstance, session: Session):
+    from airflow.models.dagrun import DagRun
+    from airflow.serialization.pydantic.taskinstance import TaskInstancePydantic
+
+    if isinstance(ti, TaskInstancePydantic):
+        orm_ti = DagRun.fetch_task_instance(
+            dag_id=ti.dag_id,
+            dag_run_id=ti.run_id,
+            task_id=ti.task_id,
+            map_index=ti.map_index,
+            session=session,
+        )
+        if TYPE_CHECKING:
+            assert orm_ti
+        ti, pydantic_ti = orm_ti, ti
+        _set_ti_attrs(ti, pydantic_ti)
+        ti.task = pydantic_ti.task
+    return ti
+
+
+@internal_api_call
+@provide_session
+def _defer_task(
Review Comment:
As above, make the naming distinct:
```suggestion
def _defer_ti_task(
```
##########
airflow/models/taskinstance.py:
##########
@@ -504,6 +791,34 @@ def _execute_callable(context: Context, **execute_callable_kwargs):
    return result
+def _set_ti_attrs(target, source):
+    # Fields ordered per model definition
+    target.start_date = source.start_date
+    target.end_date = source.end_date
+    target.duration = source.duration
+    target.state = source.state
+    target.try_number = source.try_number
+    target.max_tries = source.max_tries
+    target.hostname = source.hostname
+    target.unixname = source.unixname
+    target.job_id = source.job_id
+    target.pool = source.pool
+    target.pool_slots = source.pool_slots or 1
+    target.queue = source.queue
+    target.priority_weight = source.priority_weight
+    target.operator = source.operator
+    target.custom_operator_name = source.custom_operator_name
+    target.queued_dttm = source.queued_dttm
+    target.queued_by_job_id = source.queued_by_job_id
+    target.pid = source.pid
+    target.executor = source.executor
+    target.executor_config = source.executor_config
+    target.external_executor_id = source.external_executor_id
+    target.trigger_id = source.trigger_id
+    target.next_method = source.next_method
+    target.next_kwargs = source.next_kwargs
Review Comment:
I don't know if this is an inconsistency from a previous change, but looking at the fields in the model, I am missing:
- updated_at
- rendered_map_index
- trigger_timeout
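If those three columns are indeed meant to be mirrored as well (something for the PR author to confirm), the additions could look roughly like this:
```python
# Sketch only: possible additions to _set_ti_attrs for the fields listed above,
# assuming they should be copied like the other columns.
target.updated_at = source.updated_at
target.rendered_map_index = source.rendered_map_index
target.trigger_timeout = source.trigger_timeout
```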
##########
airflow/models/taskinstance.py:
##########
@@ -187,6 +187,191 @@ class TaskReturnCode(Enum):
"""When task exits with deferral to trigger."""
+@internal_api_call
+@provide_session
+def _merge_ti(ti, session: Session = NEW_SESSION):
+    session.merge(ti)
+    session.commit()
+
+
+@internal_api_call
+@provide_session
+def _add_log(
+    event,
+    task_instance=None,
+    owner=None,
+    owner_display_name=None,
+    extra=None,
+    session: Session = NEW_SESSION,
+    **kwargs,
+):
+    session.add(
+        Log(
+            event,
+            task_instance,
+            owner,
+            owner_display_name,
+            extra,
+            **kwargs,
+        )
+    )
+
+
+def _run_raw_task(
Review Comment:
Can we give this function a slightly different name to better distinguish it from `TaskInstance._run_raw_task()`? Otherwise it is easy to confuse the module-level function with the class method.
Proposal:
```suggestion
def _run_ti_raw_task(
```
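To illustrate the intent of the rename (a sketch only; the delegation and the signatures below are assumptions, not the PR's exact code):
```python
# Sketch only: simplified, self-contained illustration; in the PR both names
# would live in airflow.models.taskinstance.
def _run_ti_raw_task(ti, *args, **kwargs):
    ...  # module-level implementation (as added by this PR, just renamed)


class TaskInstance:
    def _run_raw_task(self, *args, **kwargs):
        # With the distinct module-level name, this reads unambiguously as a
        # call to the free function, not as recursion into the bound method.
        return _run_ti_raw_task(self, *args, **kwargs)
```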
##########
airflow/models/taskinstance.py:
##########
@@ -1265,6 +1558,132 @@ def _update_rtif(ti, rendered_fields, session: Session | None = None):
    RenderedTaskInstanceFields.delete_old_records(ti.task_id, ti.dag_id, session=session)
+def _coalesce_to_orm_ti(*, ti: TaskInstancePydantic | TaskInstance, session: Session):
+    from airflow.models.dagrun import DagRun
+    from airflow.serialization.pydantic.taskinstance import TaskInstancePydantic
+
+    if isinstance(ti, TaskInstancePydantic):
+        orm_ti = DagRun.fetch_task_instance(
+            dag_id=ti.dag_id,
+            dag_run_id=ti.run_id,
+            task_id=ti.task_id,
+            map_index=ti.map_index,
+            session=session,
+        )
+        if TYPE_CHECKING:
+            assert orm_ti
+        ti, pydantic_ti = orm_ti, ti
+        _set_ti_attrs(ti, pydantic_ti)
+        ti.task = pydantic_ti.task
+    return ti
+
+
+@internal_api_call
+@provide_session
+def _defer_task(
+    ti: TaskInstance | TaskInstancePydantic, exception: TaskDeferred, session: Session = NEW_SESSION
+) -> TaskInstancePydantic | TaskInstance:
+    from airflow.models.trigger import Trigger
+
+    # First, make the trigger entry
+    trigger_row = Trigger.from_object(exception.trigger)
+    session.add(trigger_row)
+    session.flush()
+
+    ti = _coalesce_to_orm_ti(ti=ti, session=session)  # ensure orm obj in case it's pydantic
+
+    if TYPE_CHECKING:
+        assert ti.task
+
+    # Then, update ourselves so it matches the deferral request
+    # Keep an eye on the logic in `check_and_change_state_before_execution()`
+    # depending on self.next_method semantics
+    ti.state = TaskInstanceState.DEFERRED
+    ti.trigger_id = trigger_row.id
+    ti.next_method = exception.method_name
+    ti.next_kwargs = exception.kwargs or {}
+
+    # Calculate timeout too if it was passed
+    if exception.timeout is not None:
+        ti.trigger_timeout = timezone.utcnow() + exception.timeout
+    else:
+        ti.trigger_timeout = None
+
+    # If an execution_timeout is set, set the timeout to the minimum of
+    # it and the trigger timeout
+    execution_timeout = ti.task.execution_timeout
+    if execution_timeout:
+        if TYPE_CHECKING:
+            assert ti.start_date
+        if ti.trigger_timeout:
+            ti.trigger_timeout = min(ti.start_date + execution_timeout, ti.trigger_timeout)
+        else:
+            ti.trigger_timeout = ti.start_date + execution_timeout
+    if ti.test_mode:
+        _add_log(event=ti.state, task_instance=ti, session=session)
+    session.merge(ti)
+    session.commit()
+    return ti
+
+
+@internal_api_call
+@provide_session
+def _handle_reschedule(
Review Comment:
As above, make the naming distinct:
```suggestion
def _handle_ti_reschedule(
```
##########
airflow/models/taskinstance.py:
##########
@@ -1265,6 +1558,132 @@ def _update_rtif(ti, rendered_fields, session: Session | None = None):
    RenderedTaskInstanceFields.delete_old_records(ti.task_id, ti.dag_id, session=session)
+def _coalesce_to_orm_ti(*, ti: TaskInstancePydantic | TaskInstance, session: Session):
Review Comment:
Can we ensure this with typing?
```suggestion
def _coalesce_to_orm_ti(*, ti: TaskInstancePydantic | TaskInstance, session: Session) -> TaskInstance:
```
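As a side note (a sketch, assuming mypy runs over this module): with the explicit return type, call sites such as the one already present in `_defer_task` get a properly narrowed ORM object without any cast:
```python
# Sketch only: same call as in _defer_task above; with `-> TaskInstance` the
# checker narrows `ti` to the ORM class, so ORM-only attributes such as
# trigger_timeout are recognised without a cast or type: ignore.
ti = _coalesce_to_orm_ti(ti=ti, session=session)
ti.trigger_timeout = None
```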
##########
airflow/models/taskinstance.py:
##########
@@ -504,6 +791,34 @@ def _execute_callable(context: Context,
**execute_callable_kwargs):
return result
+def _set_ti_attrs(target, source):
Review Comment:
Can you add some typing?
```suggestion
def _set_ti_attrs(target: TaskInstance | TaskInstancePydantic, source: TaskInstance | TaskInstancePydantic):
```
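One caveat (an assumption about the import layout, worth double-checking): referencing `TaskInstancePydantic` in annotations inside `airflow/models/taskinstance.py` probably has to stay behind the `TYPE_CHECKING` guard to avoid a circular import with the serialization package, roughly like this:
```python
# Sketch only: keep the pydantic model import type-checking-only so that
# airflow.models.taskinstance does not import the serialization package at runtime.
from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from airflow.serialization.pydantic.taskinstance import TaskInstancePydantic


def _set_ti_attrs(target: TaskInstance | TaskInstancePydantic, source: TaskInstance | TaskInstancePydantic) -> None:
    # TaskInstance is defined further up in the same module; the string-style
    # annotations above are never evaluated at runtime.
    ...
```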
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]