potiuk commented on code in PR #28900:
URL: https://github.com/apache/airflow/pull/28900#discussion_r1308462937
##########
airflow/models/taskinstance.py:
##########
@@ -371,6 +373,802 @@ def _creator_note(val):
return TaskInstanceNote(*val)
+def _execute_task(task_instance, context, task_orig):
+ """
+ Execute Task (optionally with a Timeout) and push Xcom results.
+
+ :param task_instance: the task instance
+ :param context: Jinja2 context
+ :param task_orig: origin task
+
+ :meta private:
+ """
+ task_to_execute = task_instance.task
+
+ if isinstance(task_to_execute, MappedOperator):
+ raise AirflowException("MappedOperator cannot be executed.")
+
+ # If the task has been deferred and is being executed due to a trigger,
+ # then we need to pick the right method to come back to, otherwise
+ # we go for the default execute
+ if task_instance.next_method:
+ # __fail__ is a special signal value for next_method that indicates
+ # this task was scheduled specifically to fail.
+ if task_instance.next_method == "__fail__":
+ next_kwargs = task_instance.next_kwargs or {}
+ traceback = next_kwargs.get("traceback")
+ if traceback is not None:
+ log.error("Trigger failed:\n%s", "\n".join(traceback))
+ raise TaskDeferralError(next_kwargs.get("error", "Unknown"))
+ # Grab the callable off the Operator/Task and add in any kwargs
+ execute_callable = getattr(task_to_execute, task_instance.next_method)
+ if task_instance.next_kwargs:
+ execute_callable = partial(execute_callable,
**task_instance.next_kwargs)
+ else:
+ execute_callable = task_to_execute.execute
+ # If a timeout is specified for the task, make it fail
+ # if it goes beyond
+ if task_to_execute.execution_timeout:
+ # If we are coming in with a next_method (i.e. from a deferral),
+ # calculate the timeout from our start_date.
+ if task_instance.next_method:
+ timeout_seconds = (
+ task_to_execute.execution_timeout - (timezone.utcnow() -
task_instance.start_date)
+ ).total_seconds()
+ else:
+ timeout_seconds = task_to_execute.execution_timeout.total_seconds()
+ try:
+ # It's possible we're already timed out, so fast-fail if true
+ if timeout_seconds <= 0:
+ raise AirflowTaskTimeout()
+ # Run task in timeout wrapper
+ with timeout(timeout_seconds):
+ result = execute_callable(context=context)
+ except AirflowTaskTimeout:
+ task_to_execute.on_kill()
+ raise
+ else:
+ result = execute_callable(context=context)
+ with create_session() as session:
+ if task_to_execute.do_xcom_push:
+ xcom_value = result
+ else:
+ xcom_value = None
+ if xcom_value is not None: # If the task returns a result, push an
XCom containing it.
+ task_instance.xcom_push(key=XCOM_RETURN_KEY, value=xcom_value,
session=session)
+ _record_task_map_for_downstreams(
+ task_instance=task_instance, task=task_orig, value=xcom_value,
session=session
+ )
+ return result
+
+
+def _refresh_from_db(
+ *, task_instance: TaskInstance | TaskInstancePydantic, session: Session,
lock_for_update: bool = False
+) -> None:
+ """
+ Refreshes the task instance from the database based on the primary key.
+
+ :param task_instance: the task instance
+ :param session: SQLAlchemy ORM Session
+ :param lock_for_update: if True, indicates that the database should
+ lock the TaskInstance (issuing a FOR UPDATE clause) until the
+ session is committed.
+
+ :meta private:
+ """
+ if task_instance in session:
+ session.refresh(task_instance,
TaskInstance.__mapper__.column_attrs.keys())
+
+ ti = TaskInstance.get_task_instance(
+ dag_id=task_instance.dag_id,
+ task_id=task_instance.task_id,
+ run_id=task_instance.run_id,
+ map_index=task_instance.map_index,
+ select_columns=True,
+ lock_for_update=lock_for_update,
+ session=session,
+ )
+
+ if ti:
+ # Fields ordered per model definition
+ task_instance.start_date = ti.start_date
+ task_instance.end_date = ti.end_date
+ task_instance.duration = ti.duration
+ task_instance.state = ti.state
+ # Since we selected columns, not the object, this is the raw value
+ task_instance.try_number = ti.try_number
+ task_instance.max_tries = ti.max_tries
+ task_instance.hostname = ti.hostname
+ task_instance.unixname = ti.unixname
+ task_instance.job_id = ti.job_id
+ task_instance.pool = ti.pool
+ task_instance.pool_slots = ti.pool_slots or 1
+ task_instance.queue = ti.queue
+ task_instance.priority_weight = ti.priority_weight
+ task_instance.operator = ti.operator
+ task_instance.custom_operator_name = ti.custom_operator_name
+ task_instance.queued_dttm = ti.queued_dttm
+ task_instance.queued_by_job_id = ti.queued_by_job_id
+ task_instance.pid = ti.pid
+ task_instance.executor_config = ti.executor_config
+ task_instance.external_executor_id = ti.external_executor_id
+ task_instance.trigger_id = ti.trigger_id
+ task_instance.next_method = ti.next_method
+ task_instance.next_kwargs = ti.next_kwargs
+ else:
+ task_instance.state = None
+
+
+def _set_duration(*, task_instance: TaskInstance | TaskInstancePydantic) ->
None:
Review Comment:
This is an example where we want to convert the method into an `internal_api` one.
This one should get "task_id", "dag_id", "execution_date" — i.e. the primary key of
TaskInstance — as parameters, plus the modified value (in this case the
"end_date"), and it should perform an update of the TaskInstance.
If you look at all the places where `_set_duration` is called, it is always
this:
```
self.end_date = timezone.utcnow()
self.set_duration()
```
What REALLY happens here is that we want to update the end date and recalculate
the duration.
So what we should replace it with should IMHO look similar to:
```
@internal_api_call
@provide_session
# <-- primary key
-> , <updated value>
def set_end_date(dag_id: str, task_id: str, execution_date: date, end_date:
date):
task_instance = session.get(dag_id=dag_id, task_id=task_id,
execution_date=execution_date | None)
task_instance.end_date = end_date
if task_instance.end_date ..... # rest follows from the original code
.....
....
session.commit()
```
This is the example where Local task (client) wants to update task instance
DB model - and we need to make an internal_api call. Basically everywhere where
from the client side we want to modify the DB model, we should make such a
conversion - to pass primary key of the object we want to update, and the value
to update as parameters, not the ORM object itself.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]