dstandish commented on code in PR #38992:
URL: https://github.com/apache/airflow/pull/38992#discussion_r1571078294


##########
airflow/models/taskinstance.py:
##########
@@ -186,6 +186,193 @@ class TaskReturnCode(Enum):
     """When task exits with deferral to trigger."""
 
 
+@internal_api_call
+@provide_session
+def _merge_ti(ti, session: Session = NEW_SESSION):
+    session.merge(ti)
+    session.commit()
+
+
+@internal_api_call
+@provide_session
+def _add_log(
+    event,
+    task_instance=None,
+    owner=None,
+    owner_display_name=None,
+    extra=None,
+    session: Session = NEW_SESSION,
+    **kwargs,
+):
+    session.add(
+        Log(
+            event,
+            task_instance,
+            owner,
+            owner_display_name,
+            extra,
+            **kwargs,
+        )
+    )
+
+
+def _run_raw_task(
+    ti: TaskInstance | TaskInstancePydantic,
+    mark_success: bool = False,
+    test_mode: bool = False,
+    job_id: str | None = None,
+    pool: str | None = None,
+    raise_on_defer: bool = False,
+    session: Session | None = None,
+) -> TaskReturnCode | None:
+    """
+    Run a task, update the state upon completion, and run any appropriate callbacks.
+
+    Immediately runs the task (without checking or changing db state
+    before execution) and then sets the appropriate final state after
+    completion and runs any post-execute callbacks. Meant to be called
+    only after another function changes the state to running.
+
+    :param mark_success: Don't run the task, mark its state as success
+    :param test_mode: Doesn't record success or failure in the DB
+    :param pool: specifies the pool to use to run the task instance
+    :param session: SQLAlchemy ORM Session
+    """
+    if TYPE_CHECKING:
+        assert ti.task
+
+    ti.test_mode = test_mode
+    ti.refresh_from_task(ti.task, pool_override=pool)
+    ti.refresh_from_db(session=session)
+
+    ti.job_id = job_id
+    ti.hostname = get_hostname()
+    ti.pid = os.getpid()
+    if not test_mode:
+        TaskInstance.save_to_db(ti=ti, session=session)
+    actual_start_date = timezone.utcnow()
+    Stats.incr(f"ti.start.{ti.task.dag_id}.{ti.task.task_id}", tags=ti.stats_tags)
+    # Same metric with tagging
+    Stats.incr("ti.start", tags=ti.stats_tags)
+    # Initialize final state counters at zero
+    for state in State.task_states:
+        Stats.incr(
+            f"ti.finish.{ti.task.dag_id}.{ti.task.task_id}.{state}",
+            count=0,
+            tags=ti.stats_tags,
+        )
+        # Same metric with tagging
+        Stats.incr(
+            "ti.finish",
+            count=0,
+            tags={**ti.stats_tags, "state": str(state)},
+        )
+    with set_current_task_instance_session(session=session):
+        ti.task = ti.task.prepare_for_execution()
+        context = ti.get_template_context(ignore_param_exceptions=False, session=session)
+
+        try:
+            if not mark_success:
+                TaskInstance._execute_task_with_callbacks(
+                    self=ti,  # type: ignore[arg-type]
+                    context=context,
+                    test_mode=test_mode,
+                    session=session,
+                )
+            if not test_mode:
+                ti.refresh_from_db(lock_for_update=True, session=session)
+            ti.state = TaskInstanceState.SUCCESS
+        except TaskDeferred as defer:
+            # The task has signalled it wants to defer execution based on
+            # a trigger.
+            if raise_on_defer:
+                raise
+            ti.defer_task(exception=defer, session=session)
+            ti.log.info(
+                "Pausing task as DEFERRED. dag_id=%s, task_id=%s, execution_date=%s, start_date=%s",
+                ti.dag_id,
+                ti.task_id,
+                _date_or_empty(task_instance=ti, attr="execution_date"),
+                _date_or_empty(task_instance=ti, attr="start_date"),
+            )
+            if not test_mode:
+                _add_log(event=ti.state, task_instance=ti, session=session)
+                TaskInstance.save_to_db(ti=ti, session=session)
+            return TaskReturnCode.DEFERRED
+        except AirflowSkipException as e:
+            # Recording SKIP
+            # log only if exception has any arguments to prevent log flooding
+            if e.args:
+                ti.log.info(e)
+            if not test_mode:
+                ti.refresh_from_db(lock_for_update=True, session=session)
+            ti.state = TaskInstanceState.SKIPPED
+            _run_finished_callback(callbacks=ti.task.on_skipped_callback, context=context)
+            TaskInstance.save_to_db(ti=ti, session=session)
+        except AirflowRescheduleException as reschedule_exception:
+            ti._handle_reschedule(actual_start_date, reschedule_exception, test_mode, session=session)

Review Comment:
   done



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to