This is an automated email from the ASF dual-hosted git repository.
amoghdesai pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new b18fccb5763 AIP-72: Handle clearing of XComs when task starts
execution (#45506)
b18fccb5763 is described below
commit b18fccb576381f8ca378acfef34103f8e691b751
Author: Amogh Desai <[email protected]>
AuthorDate: Thu Jan 9 16:38:02 2025 +0530
AIP-72: Handle clearing of XComs when task starts execution (#45506)
---
.../execution_api/routes/task_instances.py | 23 +++++-
.../execution_api/routes/test_task_instances.py | 88 +++++++++++++++++++++-
2 files changed, 108 insertions(+), 3 deletions(-)
diff --git a/airflow/api_fastapi/execution_api/routes/task_instances.py
b/airflow/api_fastapi/execution_api/routes/task_instances.py
index 4956466ca70..6086e1093ce 100644
--- a/airflow/api_fastapi/execution_api/routes/task_instances.py
+++ b/airflow/api_fastapi/execution_api/routes/task_instances.py
@@ -43,6 +43,7 @@ from airflow.models.dagrun import DagRun as DR
from airflow.models.taskinstance import TaskInstance as TI, _update_rtif
from airflow.models.taskreschedule import TaskReschedule
from airflow.models.trigger import Trigger
+from airflow.models.xcom import XCom
from airflow.utils import timezone
from airflow.utils.state import State, TerminalTIState
@@ -73,9 +74,13 @@ def ti_run(
# We only use UUID above for validation purposes
ti_id_str = str(task_instance_id)
- old = select(TI.state, TI.dag_id, TI.run_id).where(TI.id ==
ti_id_str).with_for_update()
+ old = (
+ select(TI.state, TI.dag_id, TI.run_id, TI.task_id, TI.map_index,
TI.next_method)
+ .where(TI.id == ti_id_str)
+ .with_for_update()
+ )
try:
- (previous_state, dag_id, run_id) = session.execute(old).one()
+ (previous_state, dag_id, run_id, task_id, map_index, next_method) =
session.execute(old).one()
except NoResultFound:
log.error("Task Instance %s not found", ti_id_str)
raise HTTPException(
@@ -144,6 +149,20 @@ def ti_run(
if not dr:
raise ValueError(f"DagRun with dag_id={dag_id} and run_id={run_id}
not found.")
+ # Clear XCom data for the task instance since we are certain it is
executing
+ # However, do not clear it for deferral
+ if not next_method:
+ if map_index < 0:
+ map_index = None
+ log.info("Clearing xcom data for task id: %s", ti_id_str)
+ XCom.clear(
+ dag_id=dag_id,
+ task_id=task_id,
+ run_id=run_id,
+ map_index=map_index,
+ session=session,
+ )
+
return TIRunContext(
dag_run=DagRun.model_validate(dr, from_attributes=True),
# TODO: Add variables and connections that are needed (and has
perms) for the task
diff --git a/tests/api_fastapi/execution_api/routes/test_task_instances.py
b/tests/api_fastapi/execution_api/routes/test_task_instances.py
index 497c5fbaf3f..3941835b314 100644
--- a/tests/api_fastapi/execution_api/routes/test_task_instances.py
+++ b/tests/api_fastapi/execution_api/routes/test_task_instances.py
@@ -22,7 +22,7 @@ from unittest import mock
import pytest
import uuid6
-from sqlalchemy import select
+from sqlalchemy import select, update
from sqlalchemy.exc import SQLAlchemyError
from airflow.models import RenderedTaskInstanceFields, TaskReschedule, Trigger
@@ -136,6 +136,92 @@ class TestTIRunState:
assert session.scalar(select(TaskInstance.state).where(TaskInstance.id
== ti.id)) == initial_ti_state
+ def test_xcom_cleared_when_ti_runs(self, client, session,
create_task_instance, time_machine):
+ """
+ Test that the xcoms are cleared when the Task Instance state is
updated to running.
+ """
+ instant_str = "2024-09-30T12:00:00Z"
+ instant = timezone.parse(instant_str)
+ time_machine.move_to(instant, tick=False)
+
+ ti = create_task_instance(
+ task_id="test_xcom_cleared_when_ti_runs",
+ state=State.QUEUED,
+ session=session,
+ start_date=instant,
+ )
+ session.commit()
+
+ # Let's stage an XCom push
+ ti.xcom_push(key="key", value="value")
+
+ response = client.patch(
+ f"/execution/task-instances/{ti.id}/run",
+ json={
+ "state": "running",
+ "hostname": "random-hostname",
+ "unixname": "random-unixname",
+ "pid": 100,
+ "start_date": instant_str,
+ },
+ )
+
+ assert response.status_code == 200
+ # Once the task is running, we can check if xcom is cleared
+ assert ti.xcom_pull(task_ids="test_xcom_cleared_when_ti_runs",
key="key") is None
+
+ def test_xcom_not_cleared_for_deferral(self, client, session,
create_task_instance, time_machine):
+ """
+ Test that the xcoms are not cleared when the Task Instance state is
re-running after deferral.
+ """
+ instant_str = "2024-09-30T12:00:00Z"
+ instant = timezone.parse(instant_str)
+ time_machine.move_to(instant, tick=False)
+
+ ti = create_task_instance(
+ task_id="test_xcom_not_cleared_for_deferral",
+ state=State.RUNNING,
+ session=session,
+ start_date=instant,
+ )
+ session.commit()
+
+ # Move this task to deferred
+ payload = {
+ "state": "deferred",
+ "trigger_kwargs": {"key": "value", "moment":
"2024-12-18T00:00:00Z"},
+ "classpath": "my-classpath",
+ "next_method": "execute_callback",
+ "trigger_timeout": "P1D", # 1 day
+ }
+
+ response = client.patch(f"/execution/task-instances/{ti.id}/state",
json=payload)
+ assert response.status_code == 204
+ assert response.text == ""
+ session.expire_all()
+
+ # Deferred -> Queued so that we can run it again
+ query = update(TaskInstance).where(TaskInstance.id ==
ti.id).values(state="queued")
+ session.execute(query)
+ session.commit()
+
+ # Let's stage an XCom push
+ ti.xcom_push(key="key", value="value")
+
+ response = client.patch(
+ f"/execution/task-instances/{ti.id}/run",
+ json={
+ "state": "running",
+ "hostname": "random-hostname",
+ "unixname": "random-unixname",
+ "pid": 100,
+ "start_date": instant_str,
+ },
+ )
+
+ assert response.status_code == 200
+ assert ti.xcom_pull(task_ids="test_xcom_not_cleared_for_deferral",
key="key") == "value"
+
class TestTIUpdateState:
def setup_method(self):