This is an automated email from the ASF dual-hosted git repository.

uranusjr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new d07ff507b78 Allow setting run_id in xcom_pull method (#41343)
d07ff507b78 is described below

commit d07ff507b786f560ffe2d73dd05388ee06d11ebe
Author: Fred Thomsen <[email protected]>
AuthorDate: Tue Nov 12 02:02:29 2024 -0500

    Allow setting run_id in xcom_pull method (#41343)
---
 airflow/models/taskinstance.py    | 11 ++++++++++-
 tests/models/test_taskinstance.py | 34 ++++++++++++++++++++++++++++++++++
 2 files changed, 44 insertions(+), 1 deletion(-)

diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index f2b5f5cceea..c05b1dd62ec 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -569,6 +569,7 @@ def _xcom_pull(
     session: Session = NEW_SESSION,
     map_indexes: int | Iterable[int] | None = None,
     default: Any = None,
+    run_id: str | None = None,
 ) -> Any:
     """
     Pull XComs that optionally meet certain criteria.
@@ -588,6 +589,8 @@ def _xcom_pull(
     :param include_prior_dates: If False, only XComs from the current
         execution_date are returned. If *True*, XComs from previous dates
         are returned as well.
+    :param run_id: If provided, only pulls XComs from a DagRun with a matching run_id.
+        If *None* (default), the run_id of the calling task is used.
 
     When pulling one single task (``task_id`` is *None* or a str) without
     specifying ``map_indexes``, the return value is inferred from whether
@@ -603,10 +606,12 @@ def _xcom_pull(
     """
     if dag_id is None:
         dag_id = ti.dag_id
+    if run_id is None:
+        run_id = ti.run_id
 
     query = XCom.get_many(
         key=key,
-        run_id=ti.run_id,
+        run_id=run_id,
         dag_ids=dag_id,
         task_ids=task_ids,
         map_indexes=map_indexes,
@@ -3472,6 +3477,7 @@ class TaskInstance(Base, LoggingMixin):
         *,
         map_indexes: int | Iterable[int] | None = None,
         default: Any = None,
+        run_id: str | None = None,
     ) -> Any:
         """
         Pull XComs that optionally meet certain criteria.
@@ -3491,6 +3497,8 @@ class TaskInstance(Base, LoggingMixin):
         :param include_prior_dates: If False, only XComs from the current
             execution_date are returned. If *True*, XComs from previous dates
             are returned as well.
+        :param run_id: If provided, only pulls XComs from a DagRun with a matching run_id.
+            If *None* (default), the run_id of the calling task is used.
 
         When pulling one single task (``task_id`` is *None* or a str) without
         specifying ``map_indexes``, the return value is inferred from whether
@@ -3513,6 +3521,7 @@ class TaskInstance(Base, LoggingMixin):
             session=session,
             map_indexes=map_indexes,
             default=default,
+            run_id=run_id,
         )
 
     @provide_session
diff --git a/tests/models/test_taskinstance.py b/tests/models/test_taskinstance.py
index d749edce1f2..8f15005f43d 100644
--- a/tests/models/test_taskinstance.py
+++ b/tests/models/test_taskinstance.py
@@ -1762,6 +1762,40 @@ class TestTaskInstance:
         # We *should* get a value using 'include_prior_dates'
         assert ti.xcom_pull(task_ids="test_xcom", key=key, include_prior_dates=True) == value
 
+    def test_xcom_pull_different_run_ids(self, create_task_instance):
+        """
+        Tests XCom fetch behavior with different run IDs.
+        """
+        key = "xcom_key"
+        task_id = "test_xcom"
+        diff_run_id = "diff_run_id"
+        same_run_id_value = "xcom_value_same_run_id"
+        diff_run_id_value = "xcom_value_different_run_id"
+
+        ti_same_run_id = create_task_instance(
+            dag_id="test_xcom",
+            task_id=task_id,
+        )
+        ti_same_run_id.run(mark_success=True)
+        ti_same_run_id.xcom_push(key=key, value=same_run_id_value)
+
+        ti_diff_run_id = create_task_instance(
+            dag_id="test_xcom",
+            task_id=task_id,
+            run_id=diff_run_id,
+        )
+        ti_diff_run_id.run(mark_success=True)
+        ti_diff_run_id.xcom_push(key=key, value=diff_run_id_value)
+
+        assert (
+            ti_same_run_id.xcom_pull(run_id=ti_same_run_id.dag_run.run_id, task_ids=task_id, key=key)
+            == same_run_id_value
+        )
+        assert (
+            ti_same_run_id.xcom_pull(run_id=ti_diff_run_id.dag_run.run_id, task_ids=task_id, key=key)
+            == diff_run_id_value
+        )
+
     def test_xcom_push_flag(self, dag_maker):
         """
         Tests the option for Operators to push XComs

Reply via email to