potiuk commented on code in PR #38992:
URL: https://github.com/apache/airflow/pull/38992#discussion_r1570989592


##########
airflow/models/taskinstance.py:
##########
@@ -186,6 +186,193 @@ class TaskReturnCode(Enum):
     """When task exits with deferral to trigger."""
 
 
+@internal_api_call
+@provide_session
+def _merge_ti(ti, session: Session = NEW_SESSION):
+    session.merge(ti)
+    session.commit()
+
+
+@internal_api_call
+@provide_session
+def _add_log(
+    event,
+    task_instance=None,
+    owner=None,
+    owner_display_name=None,
+    extra=None,
+    session: Session = NEW_SESSION,
+    **kwargs,
+):
+    session.add(
+        Log(
+            event,
+            task_instance,
+            owner,
+            owner_display_name,
+            extra,
+            **kwargs,
+        )
+    )
+
+
+def _run_raw_task(
+    ti: TaskInstance | TaskInstancePydantic,
+    mark_success: bool = False,
+    test_mode: bool = False,
+    job_id: str | None = None,
+    pool: str | None = None,
+    raise_on_defer: bool = False,
+    session: Session | None = None,
+) -> TaskReturnCode | None:
+    """
+    Run a task, update the state upon completion, and run any appropriate 
callbacks.
+
+    Immediately runs the task (without checking or changing db state
+    before execution) and then sets the appropriate final state after
+    completion and runs any post-execute callbacks. Meant to be called
+    only after another function changes the state to running.
+
+    :param mark_success: Don't run the task, mark its state as success
+    :param test_mode: Doesn't record success or failure in the DB
+    :param pool: specifies the pool to use to run the task instance
+    :param session: SQLAlchemy ORM Session
+    """
+    if TYPE_CHECKING:
+        assert ti.task
+
+    ti.test_mode = test_mode
+    ti.refresh_from_task(ti.task, pool_override=pool)
+    ti.refresh_from_db(session=session)
+
+    ti.job_id = job_id
+    ti.hostname = get_hostname()
+    ti.pid = os.getpid()
+    if not test_mode:
+        TaskInstance.save_to_db(ti=ti, session=session)
+    actual_start_date = timezone.utcnow()
+    Stats.incr(f"ti.start.{ti.task.dag_id}.{ti.task.task_id}", 
tags=ti.stats_tags)
+    # Same metric with tagging
+    Stats.incr("ti.start", tags=ti.stats_tags)
+    # Initialize final state counters at zero
+    for state in State.task_states:
+        Stats.incr(
+            f"ti.finish.{ti.task.dag_id}.{ti.task.task_id}.{state}",
+            count=0,
+            tags=ti.stats_tags,
+        )
+        # Same metric with tagging
+        Stats.incr(
+            "ti.finish",
+            count=0,
+            tags={**ti.stats_tags, "state": str(state)},
+        )
+    with set_current_task_instance_session(session=session):
+        ti.task = ti.task.prepare_for_execution()
+        context = ti.get_template_context(ignore_param_exceptions=False, 
session=session)
+
+        try:
+            if not mark_success:
+                TaskInstance._execute_task_with_callbacks(
+                    self=ti,  # type: ignore[arg-type]
+                    context=context,
+                    test_mode=test_mode,
+                    session=session,
+                )
+            if not test_mode:
+                ti.refresh_from_db(lock_for_update=True, session=session)
+            ti.state = TaskInstanceState.SUCCESS
+        except TaskDeferred as defer:
+            # The task has signalled it wants to defer execution based on
+            # a trigger.
+            if raise_on_defer:
+                raise
+            ti.defer_task(exception=defer, session=session)
+            ti.log.info(
+                "Pausing task as DEFERRED. dag_id=%s, task_id=%s, 
execution_date=%s, start_date=%s",
+                ti.dag_id,
+                ti.task_id,
+                _date_or_empty(task_instance=ti, attr="execution_date"),
+                _date_or_empty(task_instance=ti, attr="start_date"),
+            )
+            if not test_mode:
+                _add_log(event=ti.state, task_instance=ti, session=session)
+                TaskInstance.save_to_db(ti=ti, session=session)

Review Comment:
   Yes, I understand that every RPC call is split, but that's exactly the 
problem: it should not be.
   
   > It may be that this old session.merge(self); session.commit() is 
completely unnecessary and we can simply remove it. but i kept it just to keep 
same behavior.
   
   No — as I understand it, it is necessary. It actually commits all the 
changes that the `defer_task` method made, because there is no commit in 
`defer_task`, only a flush. A flush does not commit the changes; it only 
synchronizes them with the DB, so it performs the INSERT but does not commit it.
   
   Previously, the commit happened only when `test_mode` was `False`, and it 
committed all of the following together:
   
   * Triggers added
   * Log entries added
   * All changes made to the TI (because of the merge)
   
   This was all a single commit, and we should retain that behaviour. The 
general assumption about internal_api calls was that when they make changes, 
they should span a single transaction — and here the transaction boundary is at 
the `commit`.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to