This is an automated email from the ASF dual-hosted git repository.

amoghdesai pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new b18fccb5763 AIP-72: Handle clearing of XComs when task starts execution (#45506)
b18fccb5763 is described below

commit b18fccb576381f8ca378acfef34103f8e691b751
Author: Amogh Desai <[email protected]>
AuthorDate: Thu Jan 9 16:38:02 2025 +0530

    AIP-72: Handle clearing of XComs when task starts execution (#45506)
---
 .../execution_api/routes/task_instances.py         | 23 +++++-
 .../execution_api/routes/test_task_instances.py    | 88 +++++++++++++++++++++-
 2 files changed, 108 insertions(+), 3 deletions(-)

diff --git a/airflow/api_fastapi/execution_api/routes/task_instances.py b/airflow/api_fastapi/execution_api/routes/task_instances.py
index 4956466ca70..6086e1093ce 100644
--- a/airflow/api_fastapi/execution_api/routes/task_instances.py
+++ b/airflow/api_fastapi/execution_api/routes/task_instances.py
@@ -43,6 +43,7 @@ from airflow.models.dagrun import DagRun as DR
 from airflow.models.taskinstance import TaskInstance as TI, _update_rtif
 from airflow.models.taskreschedule import TaskReschedule
 from airflow.models.trigger import Trigger
+from airflow.models.xcom import XCom
 from airflow.utils import timezone
 from airflow.utils.state import State, TerminalTIState
 
@@ -73,9 +74,13 @@ def ti_run(
     # We only use UUID above for validation purposes
     ti_id_str = str(task_instance_id)
 
-    old = select(TI.state, TI.dag_id, TI.run_id).where(TI.id == 
ti_id_str).with_for_update()
+    old = (
+        select(TI.state, TI.dag_id, TI.run_id, TI.task_id, TI.map_index, 
TI.next_method)
+        .where(TI.id == ti_id_str)
+        .with_for_update()
+    )
     try:
-        (previous_state, dag_id, run_id) = session.execute(old).one()
+        (previous_state, dag_id, run_id, task_id, map_index, next_method) = 
session.execute(old).one()
     except NoResultFound:
         log.error("Task Instance %s not found", ti_id_str)
         raise HTTPException(
@@ -144,6 +149,20 @@ def ti_run(
         if not dr:
             raise ValueError(f"DagRun with dag_id={dag_id} and run_id={run_id} 
not found.")
 
+        # Clear XCom data for the task instance since we are certain it is 
executing
+        # However, do not clear it for deferral
+        if not next_method:
+            if map_index < 0:
+                map_index = None
+            log.info("Clearing xcom data for task id: %s", ti_id_str)
+            XCom.clear(
+                dag_id=dag_id,
+                task_id=task_id,
+                run_id=run_id,
+                map_index=map_index,
+                session=session,
+            )
+
         return TIRunContext(
             dag_run=DagRun.model_validate(dr, from_attributes=True),
             # TODO: Add variables and connections that are needed (and has 
perms) for the task
diff --git a/tests/api_fastapi/execution_api/routes/test_task_instances.py b/tests/api_fastapi/execution_api/routes/test_task_instances.py
index 497c5fbaf3f..3941835b314 100644
--- a/tests/api_fastapi/execution_api/routes/test_task_instances.py
+++ b/tests/api_fastapi/execution_api/routes/test_task_instances.py
@@ -22,7 +22,7 @@ from unittest import mock
 
 import pytest
 import uuid6
-from sqlalchemy import select
+from sqlalchemy import select, update
 from sqlalchemy.exc import SQLAlchemyError
 
 from airflow.models import RenderedTaskInstanceFields, TaskReschedule, Trigger
@@ -136,6 +136,92 @@ class TestTIRunState:
 
         assert session.scalar(select(TaskInstance.state).where(TaskInstance.id 
== ti.id)) == initial_ti_state
 
+    def test_xcom_cleared_when_ti_runs(self, client, session, 
create_task_instance, time_machine):
+        """
+        Test that the xcoms are cleared when the Task Instance state is 
updated to running.
+        """
+        instant_str = "2024-09-30T12:00:00Z"
+        instant = timezone.parse(instant_str)
+        time_machine.move_to(instant, tick=False)
+
+        ti = create_task_instance(
+            task_id="test_xcom_cleared_when_ti_runs",
+            state=State.QUEUED,
+            session=session,
+            start_date=instant,
+        )
+        session.commit()
+
+        # Lets stage a xcom push
+        ti.xcom_push(key="key", value="value")
+
+        response = client.patch(
+            f"/execution/task-instances/{ti.id}/run",
+            json={
+                "state": "running",
+                "hostname": "random-hostname",
+                "unixname": "random-unixname",
+                "pid": 100,
+                "start_date": instant_str,
+            },
+        )
+
+        assert response.status_code == 200
+        # Once the task is running, we can check if xcom is cleared
+        assert ti.xcom_pull(task_ids="test_xcom_cleared_when_ti_runs", 
key="key") is None
+
+    def test_xcom_not_cleared_for_deferral(self, client, session, 
create_task_instance, time_machine):
+        """
+        Test that the xcoms are not cleared when the Task Instance state is 
re-running after deferral.
+        """
+        instant_str = "2024-09-30T12:00:00Z"
+        instant = timezone.parse(instant_str)
+        time_machine.move_to(instant, tick=False)
+
+        ti = create_task_instance(
+            task_id="test_xcom_not_cleared_for_deferral",
+            state=State.RUNNING,
+            session=session,
+            start_date=instant,
+        )
+        session.commit()
+
+        # Move this task to deferred
+        payload = {
+            "state": "deferred",
+            "trigger_kwargs": {"key": "value", "moment": 
"2024-12-18T00:00:00Z"},
+            "classpath": "my-classpath",
+            "next_method": "execute_callback",
+            "trigger_timeout": "P1D",  # 1 day
+        }
+
+        response = client.patch(f"/execution/task-instances/{ti.id}/state", 
json=payload)
+        assert response.status_code == 204
+        assert response.text == ""
+        session.expire_all()
+
+        # Deferred -> Queued so that we can run it again
+        query = update(TaskInstance).where(TaskInstance.id == 
ti.id).values(state="queued")
+        session.execute(query)
+        session.commit()
+
+        # Lets stage a xcom push
+        ti.xcom_push(key="key", value="value")
+
+        response = client.patch(
+            f"/execution/task-instances/{ti.id}/run",
+            json={
+                "state": "running",
+                "hostname": "random-hostname",
+                "unixname": "random-unixname",
+                "pid": 100,
+                "start_date": instant_str,
+            },
+        )
+
+        assert response.status_code == 200
+        assert ti.xcom_pull(task_ids="test_xcom_not_cleared_for_deferral", 
key="key") == "value"
+
 
 class TestTIUpdateState:
     def setup_method(self):

Reply via email to