vincbeck commented on code in PR #28900:
URL: https://github.com/apache/airflow/pull/28900#discussion_r1309075951


##########
airflow/models/taskinstance.py:
##########
@@ -371,6 +373,802 @@ def _creator_note(val):
         return TaskInstanceNote(*val)
 
 
+def _execute_task(task_instance, context, task_orig):
+    """
+    Execute Task (optionally with a Timeout) and push Xcom results.
+
+    :param task_instance: the task instance
+    :param context: Jinja2 context
+    :param task_orig: origin task
+
+    :meta private:
+    """
+    task_to_execute = task_instance.task
+
+    if isinstance(task_to_execute, MappedOperator):
+        raise AirflowException("MappedOperator cannot be executed.")
+
+    # If the task has been deferred and is being executed due to a trigger,
+    # then we need to pick the right method to come back to, otherwise
+    # we go for the default execute
+    if task_instance.next_method:
+        # __fail__ is a special signal value for next_method that indicates
+        # this task was scheduled specifically to fail.
+        if task_instance.next_method == "__fail__":
+            next_kwargs = task_instance.next_kwargs or {}
+            traceback = next_kwargs.get("traceback")
+            if traceback is not None:
+                log.error("Trigger failed:\n%s", "\n".join(traceback))
+            raise TaskDeferralError(next_kwargs.get("error", "Unknown"))
+        # Grab the callable off the Operator/Task and add in any kwargs
+        execute_callable = getattr(task_to_execute, task_instance.next_method)
+        if task_instance.next_kwargs:
+            execute_callable = partial(execute_callable, **task_instance.next_kwargs)
+    else:
+        execute_callable = task_to_execute.execute
+    # If a timeout is specified for the task, make it fail
+    # if it goes beyond
+    if task_to_execute.execution_timeout:
+        # If we are coming in with a next_method (i.e. from a deferral),
+        # calculate the timeout from our start_date.
+        if task_instance.next_method:
+            timeout_seconds = (
+                task_to_execute.execution_timeout - (timezone.utcnow() - task_instance.start_date)
+            ).total_seconds()
+        else:
+            timeout_seconds = task_to_execute.execution_timeout.total_seconds()
+        try:
+            # It's possible we're already timed out, so fast-fail if true
+            if timeout_seconds <= 0:
+                raise AirflowTaskTimeout()
+            # Run task in timeout wrapper
+            with timeout(timeout_seconds):
+                result = execute_callable(context=context)
+        except AirflowTaskTimeout:
+            task_to_execute.on_kill()
+            raise
+    else:
+        result = execute_callable(context=context)
+    with create_session() as session:
+        if task_to_execute.do_xcom_push:
+            xcom_value = result
+        else:
+            xcom_value = None
+        if xcom_value is not None:  # If the task returns a result, push an XCom containing it.
+            task_instance.xcom_push(key=XCOM_RETURN_KEY, value=xcom_value, session=session)
+        _record_task_map_for_downstreams(
+            task_instance=task_instance, task=task_orig, value=xcom_value, session=session
+        )
+    return result
+
+
+def _refresh_from_db(
+    *, task_instance: TaskInstance | TaskInstancePydantic, session: Session, lock_for_update: bool = False
+) -> None:
+    """
+    Refreshes the task instance from the database based on the primary key.
+
+    :param task_instance: the task instance
+    :param session: SQLAlchemy ORM Session
+    :param lock_for_update: if True, indicates that the database should
+        lock the TaskInstance (issuing a FOR UPDATE clause) until the
+        session is committed.
+
+    :meta private:
+    """
+    if task_instance in session:
+        session.refresh(task_instance, TaskInstance.__mapper__.column_attrs.keys())
+
+    ti = TaskInstance.get_task_instance(
+        dag_id=task_instance.dag_id,
+        task_id=task_instance.task_id,
+        run_id=task_instance.run_id,
+        map_index=task_instance.map_index,
+        select_columns=True,
+        lock_for_update=lock_for_update,
+        session=session,
+    )
+
+    if ti:
+        # Fields ordered per model definition
+        task_instance.start_date = ti.start_date
+        task_instance.end_date = ti.end_date
+        task_instance.duration = ti.duration
+        task_instance.state = ti.state
+        # Since we selected columns, not the object, this is the raw value
+        task_instance.try_number = ti.try_number
+        task_instance.max_tries = ti.max_tries
+        task_instance.hostname = ti.hostname
+        task_instance.unixname = ti.unixname
+        task_instance.job_id = ti.job_id
+        task_instance.pool = ti.pool
+        task_instance.pool_slots = ti.pool_slots or 1
+        task_instance.queue = ti.queue
+        task_instance.priority_weight = ti.priority_weight
+        task_instance.operator = ti.operator
+        task_instance.custom_operator_name = ti.custom_operator_name
+        task_instance.queued_dttm = ti.queued_dttm
+        task_instance.queued_by_job_id = ti.queued_by_job_id
+        task_instance.pid = ti.pid
+        task_instance.executor_config = ti.executor_config
+        task_instance.external_executor_id = ti.external_executor_id
+        task_instance.trigger_id = ti.trigger_id
+        task_instance.next_method = ti.next_method
+        task_instance.next_kwargs = ti.next_kwargs
+    else:
+        task_instance.state = None
+
+
+def _set_duration(*, task_instance: TaskInstance | TaskInstancePydantic) -> None:

Review Comment:
   100%



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to