potiuk commented on code in PR #38992:
URL: https://github.com/apache/airflow/pull/38992#discussion_r1570785504
##########
airflow/models/taskinstance.py:
##########
@@ -186,6 +186,193 @@ class TaskReturnCode(Enum):
"""When task exits with deferral to trigger."""
+@internal_api_call
+@provide_session
+def _merge_ti(ti, session: Session = NEW_SESSION):
+ session.merge(ti)
+ session.commit()
+
+
+@internal_api_call
+@provide_session
+def _add_log(
+ event,
+ task_instance=None,
+ owner=None,
+ owner_display_name=None,
+ extra=None,
+ session: Session = NEW_SESSION,
+ **kwargs,
+):
+ session.add(
+ Log(
+ event,
+ task_instance,
+ owner,
+ owner_display_name,
+ extra,
+ **kwargs,
+ )
+ )
+
+
+def _run_raw_task(
+ ti: TaskInstance | TaskInstancePydantic,
+ mark_success: bool = False,
+ test_mode: bool = False,
+ job_id: str | None = None,
+ pool: str | None = None,
+ raise_on_defer: bool = False,
+ session: Session | None = None,
+) -> TaskReturnCode | None:
+ """
+ Run a task, update the state upon completion, and run any appropriate
callbacks.
+
+ Immediately runs the task (without checking or changing db state
+ before execution) and then sets the appropriate final state after
+ completion and runs any post-execute callbacks. Meant to be called
+ only after another function changes the state to running.
+
+ :param mark_success: Don't run the task, mark its state as success
+ :param test_mode: Doesn't record success or failure in the DB
+ :param pool: specifies the pool to use to run the task instance
+ :param session: SQLAlchemy ORM Session
+ """
+ if TYPE_CHECKING:
+ assert ti.task
+
+ ti.test_mode = test_mode
+ ti.refresh_from_task(ti.task, pool_override=pool)
+ ti.refresh_from_db(session=session)
+
+ ti.job_id = job_id
+ ti.hostname = get_hostname()
+ ti.pid = os.getpid()
+ if not test_mode:
+ TaskInstance.save_to_db(ti=ti, session=session)
+ actual_start_date = timezone.utcnow()
+ Stats.incr(f"ti.start.{ti.task.dag_id}.{ti.task.task_id}",
tags=ti.stats_tags)
+ # Same metric with tagging
+ Stats.incr("ti.start", tags=ti.stats_tags)
+ # Initialize final state counters at zero
+ for state in State.task_states:
+ Stats.incr(
+ f"ti.finish.{ti.task.dag_id}.{ti.task.task_id}.{state}",
+ count=0,
+ tags=ti.stats_tags,
+ )
+ # Same metric with tagging
+ Stats.incr(
+ "ti.finish",
+ count=0,
+ tags={**ti.stats_tags, "state": str(state)},
+ )
+ with set_current_task_instance_session(session=session):
+ ti.task = ti.task.prepare_for_execution()
+ context = ti.get_template_context(ignore_param_exceptions=False,
session=session)
+
+ try:
+ if not mark_success:
+ TaskInstance._execute_task_with_callbacks(
+ self=ti, # type: ignore[arg-type]
+ context=context,
+ test_mode=test_mode,
+ session=session,
+ )
+ if not test_mode:
+ ti.refresh_from_db(lock_for_update=True, session=session)
+ ti.state = TaskInstanceState.SUCCESS
+ except TaskDeferred as defer:
+ # The task has signalled it wants to defer execution based on
+ # a trigger.
+ if raise_on_defer:
+ raise
+ ti.defer_task(exception=defer, session=session)
+ ti.log.info(
+ "Pausing task as DEFERRED. dag_id=%s, task_id=%s,
execution_date=%s, start_date=%s",
+ ti.dag_id,
+ ti.task_id,
+ _date_or_empty(task_instance=ti, attr="execution_date"),
+ _date_or_empty(task_instance=ti, attr="start_date"),
+ )
+ if not test_mode:
+ _add_log(event=ti.state, task_instance=ti, session=session)
+ TaskInstance.save_to_db(ti=ti, session=session)
+ return TaskReturnCode.DEFERRED
+ except AirflowSkipException as e:
+ # Recording SKIP
+ # log only if exception has any arguments to prevent log flooding
+ if e.args:
+ ti.log.info(e)
+ if not test_mode:
+ ti.refresh_from_db(lock_for_update=True, session=session)
+ ti.state = TaskInstanceState.SKIPPED
+ _run_finished_callback(callbacks=ti.task.on_skipped_callback,
context=context)
+ TaskInstance.save_to_db(ti=ti, session=session)
+ except AirflowRescheduleException as reschedule_exception:
+ ti._handle_reschedule(actual_start_date, reschedule_exception,
test_mode, session=session)
Review Comment:
I think `_handle_reschedule` should simply be turned into an `@internal_api_call`,
the same as `handle_failure` (then `save_to_db` would not be needed at all,
because the merge/commit is already done inside `_handle_reschedule`).
That would also resolve the `todo:` below, where this was supposed to be done.
##########
airflow/models/taskinstance.py:
##########
@@ -186,6 +186,193 @@ class TaskReturnCode(Enum):
"""When task exits with deferral to trigger."""
+@internal_api_call
+@provide_session
+def _merge_ti(ti, session: Session = NEW_SESSION):
+ session.merge(ti)
+ session.commit()
+
+
+@internal_api_call
+@provide_session
+def _add_log(
+ event,
+ task_instance=None,
+ owner=None,
+ owner_display_name=None,
+ extra=None,
+ session: Session = NEW_SESSION,
+ **kwargs,
+):
+ session.add(
+ Log(
+ event,
+ task_instance,
+ owner,
+ owner_display_name,
+ extra,
+ **kwargs,
+ )
+ )
+
+
+def _run_raw_task(
+ ti: TaskInstance | TaskInstancePydantic,
+ mark_success: bool = False,
+ test_mode: bool = False,
+ job_id: str | None = None,
+ pool: str | None = None,
+ raise_on_defer: bool = False,
+ session: Session | None = None,
+) -> TaskReturnCode | None:
+ """
+ Run a task, update the state upon completion, and run any appropriate
callbacks.
+
+ Immediately runs the task (without checking or changing db state
+ before execution) and then sets the appropriate final state after
+ completion and runs any post-execute callbacks. Meant to be called
+ only after another function changes the state to running.
+
+ :param mark_success: Don't run the task, mark its state as success
+ :param test_mode: Doesn't record success or failure in the DB
+ :param pool: specifies the pool to use to run the task instance
+ :param session: SQLAlchemy ORM Session
+ """
+ if TYPE_CHECKING:
+ assert ti.task
+
+ ti.test_mode = test_mode
+ ti.refresh_from_task(ti.task, pool_override=pool)
+ ti.refresh_from_db(session=session)
+
+ ti.job_id = job_id
+ ti.hostname = get_hostname()
+ ti.pid = os.getpid()
+ if not test_mode:
+ TaskInstance.save_to_db(ti=ti, session=session)
+ actual_start_date = timezone.utcnow()
+ Stats.incr(f"ti.start.{ti.task.dag_id}.{ti.task.task_id}",
tags=ti.stats_tags)
+ # Same metric with tagging
+ Stats.incr("ti.start", tags=ti.stats_tags)
+ # Initialize final state counters at zero
+ for state in State.task_states:
+ Stats.incr(
+ f"ti.finish.{ti.task.dag_id}.{ti.task.task_id}.{state}",
+ count=0,
+ tags=ti.stats_tags,
+ )
+ # Same metric with tagging
+ Stats.incr(
+ "ti.finish",
+ count=0,
+ tags={**ti.stats_tags, "state": str(state)},
+ )
+ with set_current_task_instance_session(session=session):
+ ti.task = ti.task.prepare_for_execution()
+ context = ti.get_template_context(ignore_param_exceptions=False,
session=session)
+
+ try:
+ if not mark_success:
+ TaskInstance._execute_task_with_callbacks(
+ self=ti, # type: ignore[arg-type]
+ context=context,
+ test_mode=test_mode,
+ session=session,
+ )
+ if not test_mode:
+ ti.refresh_from_db(lock_for_update=True, session=session)
+ ti.state = TaskInstanceState.SUCCESS
+ except TaskDeferred as defer:
+ # The task has signalled it wants to defer execution based on
+ # a trigger.
+ if raise_on_defer:
+ raise
+ ti.defer_task(exception=defer, session=session)
+ ti.log.info(
+ "Pausing task as DEFERRED. dag_id=%s, task_id=%s,
execution_date=%s, start_date=%s",
+ ti.dag_id,
+ ti.task_id,
+ _date_or_empty(task_instance=ti, attr="execution_date"),
+ _date_or_empty(task_instance=ti, attr="start_date"),
+ )
+ if not test_mode:
+ _add_log(event=ti.state, task_instance=ti, session=session)
+ TaskInstance.save_to_db(ti=ti, session=session)
Review Comment:
I think we should merge this `save_to_db` into the `defer_task`
method. There is no point (and it's even harmful) in splitting the creation of
the Trigger and the change of the task instance's state - this should all be done in a
single `internal_api` call, I think.
##########
airflow/models/taskinstance.py:
##########
@@ -2555,7 +2811,7 @@ def _run_raw_task(
# a trigger.
if raise_on_defer:
raise
- self.defer_task(defer=defer, session=session)
+ self.defer_task(exception=defer, session=session)
Review Comment:
I believe this whole `_run_raw_task` implementation should not be duplicated
here. We should be able to provide one implementation that serves both the
DB-isolation and non-DB-isolation cases.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]