This is an automated email from the ASF dual-hosted git repository.

kaxilnaik pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new af130c0df1c AIP-72: Push XCom on Task Return (#45245)
af130c0df1c is described below

commit af130c0df1cd02bc0738504c1af63522d4f5be41
Author: Kaxil Naik <[email protected]>
AuthorDate: Fri Dec 27 23:27:52 2024 +0530

    AIP-72: Push XCom on Task Return (#45245)
    
    closes https://github.com/apache/airflow/issues/45230
---
 .../src/airflow/sdk/execution_time/task_runner.py  |  48 +++-
 task_sdk/tests/execution_time/conftest.py          |   9 +
 task_sdk/tests/execution_time/test_context.py      |  39 +--
 task_sdk/tests/execution_time/test_task_runner.py  | 290 +++++++++++++++------
 4 files changed, 273 insertions(+), 113 deletions(-)

diff --git a/task_sdk/src/airflow/sdk/execution_time/task_runner.py 
b/task_sdk/src/airflow/sdk/execution_time/task_runner.py
index d8d540318f1..810b108e221 100644
--- a/task_sdk/src/airflow/sdk/execution_time/task_runner.py
+++ b/task_sdk/src/airflow/sdk/execution_time/task_runner.py
@@ -21,7 +21,7 @@ from __future__ import annotations
 
 import os
 import sys
-from collections.abc import Iterable
+from collections.abc import Iterable, Mapping
 from datetime import datetime, timezone
 from io import FileIO
 from typing import TYPE_CHECKING, Annotated, Any, Generic, TextIO, TypeVar
@@ -197,7 +197,11 @@ class RuntimeTaskInstance(TaskInstance):
 
         value = msg.value
         if value is not None:
-            return value
+            from airflow.models.xcom import XCom
+
+            # TODO: Move XCom serialization & deserialization to Task SDK
+            #   https://github.com/apache/airflow/issues/45231
+            return XCom.deserialize_value(value)
         return default
 
     def xcom_push(self, key: str, value: Any):
@@ -207,6 +211,12 @@ class RuntimeTaskInstance(TaskInstance):
         :param key: Key to store the value under.
         :param value: Value to store. Only JSON-serializable values may be 
used.
         """
+        from airflow.models.xcom import XCom
+
+        # TODO: Move XCom serialization & deserialization to Task SDK
+        #   https://github.com/apache/airflow/issues/45231
+        value = XCom.serialize_value(value)
+
         log = structlog.get_logger(logger_name="task")
         SUPERVISOR_COMMS.send_request(
             log=log,
@@ -381,7 +391,9 @@ def run(ti: RuntimeTaskInstance, log: Logger):
         #   - Update RTIF
         #   - Pre Execute
         #   etc
-        ti.task.execute(context)  # type: ignore[attr-defined]
+        result = ti.task.execute(context)  # type: ignore[attr-defined]
+        _push_xcom_if_needed(result, ti)
+
         msg = TaskState(state=TerminalTIState.SUCCESS, 
end_date=datetime.now(tz=timezone.utc))
     except TaskDeferred as defer:
         classpath, trigger_kwargs = defer.trigger.serialize()
@@ -436,6 +448,36 @@ def run(ti: RuntimeTaskInstance, log: Logger):
         SUPERVISOR_COMMS.send_request(msg=msg, log=log)
 
 
+def _push_xcom_if_needed(result: Any, ti: RuntimeTaskInstance):
+    """Push XCom values when task has ``do_xcom_push`` set to ``True`` and the 
task returns a result."""
+    if ti.task.do_xcom_push:
+        xcom_value = result
+    else:
+        xcom_value = None
+
+    # If the task returns a result, push an XCom containing it.
+    if xcom_value is None:
+        return
+
+    # If the task has multiple outputs, push each output as a separate XCom.
+    if ti.task.multiple_outputs:
+        if not isinstance(xcom_value, Mapping):
+            raise TypeError(
+                f"Returned output was type {type(xcom_value)} expected 
dictionary for multiple_outputs"
+            )
+        for key in xcom_value.keys():
+            if not isinstance(key, str):
+                raise TypeError(
+                    "Returned dictionary keys must be strings when using "
+                    f"multiple_outputs, found {key} ({type(key)}) instead"
+                )
+        for k, v in result.items():
+            ti.xcom_push(k, v)
+
+    # TODO: Use constant for XCom return key & use serialize_value from Task 
SDK
+    ti.xcom_push("return_value", result)
+
+
 def finalize(log: Logger): ...
 
 
diff --git a/task_sdk/tests/execution_time/conftest.py 
b/task_sdk/tests/execution_time/conftest.py
index 4a537373363..bf482e5ec7b 100644
--- a/task_sdk/tests/execution_time/conftest.py
+++ b/task_sdk/tests/execution_time/conftest.py
@@ -18,6 +18,7 @@
 from __future__ import annotations
 
 import sys
+from unittest import mock
 
 import pytest
 
@@ -31,3 +32,11 @@ def disable_capturing():
     sys.stderr = sys.__stderr__
     yield
     sys.stdin, sys.stdout, sys.stderr = old_in, old_out, old_err
+
+
[email protected]
+def mock_supervisor_comms():
+    with mock.patch(
+        "airflow.sdk.execution_time.task_runner.SUPERVISOR_COMMS", create=True
+    ) as supervisor_comms:
+        yield supervisor_comms
diff --git a/task_sdk/tests/execution_time/test_context.py 
b/task_sdk/tests/execution_time/test_context.py
index a3220c3bef1..34502a1a917 100644
--- a/task_sdk/tests/execution_time/test_context.py
+++ b/task_sdk/tests/execution_time/test_context.py
@@ -17,8 +17,6 @@
 
 from __future__ import annotations
 
-from unittest import mock
-
 from airflow.sdk.definitions.connection import Connection
 from airflow.sdk.exceptions import ErrorType
 from airflow.sdk.execution_time.comms import ConnectionResult, ErrorResponse
@@ -51,7 +49,7 @@ def test_convert_connection_result_conn():
 
 
 class TestConnectionAccessor:
-    def test_getattr_connection(self):
+    def test_getattr_connection(self, mock_supervisor_comms):
         """
         Test that the connection is fetched when accessed via __getattr__.
 
@@ -62,31 +60,25 @@ class TestConnectionAccessor:
         # Conn from the supervisor / API Server
         conn_result = ConnectionResult(conn_id="mysql_conn", 
conn_type="mysql", host="mysql", port=3306)
 
-        with mock.patch(
-            "airflow.sdk.execution_time.task_runner.SUPERVISOR_COMMS", 
create=True
-        ) as mock_supervisor_comms:
-            mock_supervisor_comms.get_message.return_value = conn_result
+        mock_supervisor_comms.get_message.return_value = conn_result
 
-            # Fetch the connection; triggers __getattr__
-            conn = accessor.mysql_conn
+        # Fetch the connection; triggers __getattr__
+        conn = accessor.mysql_conn
 
-            expected_conn = Connection(conn_id="mysql_conn", 
conn_type="mysql", host="mysql", port=3306)
-            assert conn == expected_conn
+        expected_conn = Connection(conn_id="mysql_conn", conn_type="mysql", 
host="mysql", port=3306)
+        assert conn == expected_conn
 
-    def test_get_method_valid_connection(self):
+    def test_get_method_valid_connection(self, mock_supervisor_comms):
         """Test that the get method returns the requested connection using 
`conn.get`."""
         accessor = ConnectionAccessor()
         conn_result = ConnectionResult(conn_id="mysql_conn", 
conn_type="mysql", host="mysql", port=3306)
 
-        with mock.patch(
-            "airflow.sdk.execution_time.task_runner.SUPERVISOR_COMMS", 
create=True
-        ) as mock_supervisor_comms:
-            mock_supervisor_comms.get_message.return_value = conn_result
+        mock_supervisor_comms.get_message.return_value = conn_result
 
-            conn = accessor.get("mysql_conn")
-            assert conn == Connection(conn_id="mysql_conn", conn_type="mysql", 
host="mysql", port=3306)
+        conn = accessor.get("mysql_conn")
+        assert conn == Connection(conn_id="mysql_conn", conn_type="mysql", 
host="mysql", port=3306)
 
-    def test_get_method_with_default(self):
+    def test_get_method_with_default(self, mock_supervisor_comms):
         """Test that the get method returns the default connection when the 
requested connection is not found."""
         accessor = ConnectionAccessor()
         default_conn = {"conn_id": "default_conn", "conn_type": "sqlite"}
@@ -94,10 +86,7 @@ class TestConnectionAccessor:
             error=ErrorType.CONNECTION_NOT_FOUND, detail={"conn_id": 
"nonexistent_conn"}
         )
 
-        with mock.patch(
-            "airflow.sdk.execution_time.task_runner.SUPERVISOR_COMMS", 
create=True
-        ) as mock_supervisor_comms:
-            mock_supervisor_comms.get_message.return_value = error_response
+        mock_supervisor_comms.get_message.return_value = error_response
 
-            conn = accessor.get("nonexistent_conn", default_conn=default_conn)
-            assert conn == default_conn
+        conn = accessor.get("nonexistent_conn", default_conn=default_conn)
+        assert conn == default_conn
diff --git a/task_sdk/tests/execution_time/test_task_runner.py 
b/task_sdk/tests/execution_time/test_task_runner.py
index 46bb75ac61a..48ab35709bb 100644
--- a/task_sdk/tests/execution_time/test_task_runner.py
+++ b/task_sdk/tests/execution_time/test_task_runner.py
@@ -43,7 +43,14 @@ from airflow.sdk.execution_time.comms import (
     TaskState,
 )
 from airflow.sdk.execution_time.context import ConnectionAccessor
-from airflow.sdk.execution_time.task_runner import CommsDecoder, 
RuntimeTaskInstance, parse, run, startup
+from airflow.sdk.execution_time.task_runner import (
+    CommsDecoder,
+    RuntimeTaskInstance,
+    _push_xcom_if_needed,
+    parse,
+    run,
+    startup,
+)
 from airflow.utils import timezone
 
 
@@ -147,7 +154,7 @@ def test_parse(test_dags_dir: Path, make_ti_context):
     assert isinstance(ti.task.dag, DAG)
 
 
-def test_run_basic(time_machine, mocked_parse, make_ti_context, spy_agency):
+def test_run_basic(time_machine, mocked_parse, make_ti_context, spy_agency, 
mock_supervisor_comms):
     """Test running a basic task."""
     what = StartupDetails(
         ti=TaskInstance(id=uuid7(), task_id="hello", dag_id="super_basic_run", 
run_id="c", try_number=1),
@@ -159,26 +166,23 @@ def test_run_basic(time_machine, mocked_parse, 
make_ti_context, spy_agency):
     instant = timezone.datetime(2024, 12, 3, 10, 0)
     time_machine.move_to(instant, tick=False)
 
-    with mock.patch(
-        "airflow.sdk.execution_time.task_runner.SUPERVISOR_COMMS", create=True
-    ) as mock_supervisor_comms:
-        ti = mocked_parse(what, "super_basic_run", 
CustomOperator(task_id="hello"))
+    ti = mocked_parse(what, "super_basic_run", CustomOperator(task_id="hello"))
 
-        # Ensure that task is locked for execution
-        spy_agency.spy_on(ti.task.prepare_for_execution)
-        assert not ti.task._lock_for_execution
+    # Ensure that task is locked for execution
+    spy_agency.spy_on(ti.task.prepare_for_execution)
+    assert not ti.task._lock_for_execution
 
-        run(ti, log=mock.MagicMock())
+    run(ti, log=mock.MagicMock())
 
-        spy_agency.assert_spy_called(ti.task.prepare_for_execution)
-        assert ti.task._lock_for_execution
+    spy_agency.assert_spy_called(ti.task.prepare_for_execution)
+    assert ti.task._lock_for_execution
 
-        mock_supervisor_comms.send_request.assert_called_once_with(
-            msg=TaskState(state=TerminalTIState.SUCCESS, end_date=instant), 
log=mock.ANY
-        )
+    mock_supervisor_comms.send_request.assert_called_once_with(
+        msg=TaskState(state=TerminalTIState.SUCCESS, end_date=instant), 
log=mock.ANY
+    )
 
 
-def test_run_deferred_basic(time_machine, mocked_parse, make_ti_context):
+def test_run_deferred_basic(time_machine, mocked_parse, make_ti_context, 
mock_supervisor_comms):
     """Test that a task can transition to a deferred state."""
     import datetime
 
@@ -213,17 +217,14 @@ def test_run_deferred_basic(time_machine, mocked_parse, 
make_ti_context):
     )
 
     # Run the task
-    with mock.patch(
-        "airflow.sdk.execution_time.task_runner.SUPERVISOR_COMMS", create=True
-    ) as mock_supervisor_comms:
-        ti = mocked_parse(what, "basic_deferred_run", task)
-        run(ti, log=mock.MagicMock())
+    ti = mocked_parse(what, "basic_deferred_run", task)
+    run(ti, log=mock.MagicMock())
 
-        # send_request will only be called when the TaskDeferred exception is 
raised
-        
mock_supervisor_comms.send_request.assert_called_once_with(msg=expected_defer_task,
 log=mock.ANY)
+    # send_request will only be called when the TaskDeferred exception is 
raised
+    
mock_supervisor_comms.send_request.assert_called_once_with(msg=expected_defer_task,
 log=mock.ANY)
 
 
-def test_run_basic_skipped(time_machine, mocked_parse, make_ti_context):
+def test_run_basic_skipped(time_machine, mocked_parse, make_ti_context, 
mock_supervisor_comms):
     """Test running a basic task that marks itself skipped."""
     from airflow.providers.standard.operators.python import PythonOperator
 
@@ -246,14 +247,11 @@ def test_run_basic_skipped(time_machine, mocked_parse, 
make_ti_context):
     instant = timezone.datetime(2024, 12, 3, 10, 0)
     time_machine.move_to(instant, tick=False)
 
-    with mock.patch(
-        "airflow.sdk.execution_time.task_runner.SUPERVISOR_COMMS", create=True
-    ) as mock_supervisor_comms:
-        run(ti, log=mock.MagicMock())
+    run(ti, log=mock.MagicMock())
 
-        mock_supervisor_comms.send_request.assert_called_once_with(
-            msg=TaskState(state=TerminalTIState.SKIPPED, end_date=instant), 
log=mock.ANY
-        )
+    mock_supervisor_comms.send_request.assert_called_once_with(
+        msg=TaskState(state=TerminalTIState.SKIPPED, end_date=instant), 
log=mock.ANY
+    )
 
 
 @pytest.mark.parametrize(
@@ -283,7 +281,7 @@ def test_run_basic_skipped(time_machine, mocked_parse, 
make_ti_context):
     ],
 )
 def test_startup_and_run_dag_with_templated_fields(
-    mocked_parse, task_params, expected_rendered_fields, make_ti_context, 
time_machine
+    mocked_parse, task_params, expected_rendered_fields, make_ti_context, 
time_machine, mock_supervisor_comms
 ):
     """Test startup of a DAG with various templated fields."""
 
@@ -311,24 +309,22 @@ def test_startup_and_run_dag_with_templated_fields(
 
     instant = timezone.datetime(2024, 12, 3, 10, 0)
     time_machine.move_to(instant, tick=False)
-    with mock.patch(
-        "airflow.sdk.execution_time.task_runner.SUPERVISOR_COMMS", create=True
-    ) as mock_supervisor_comms:
-        mock_supervisor_comms.get_message.return_value = what
 
-        startup()
-        run(ti, log=mock.MagicMock())
-        expected_calls = [
-            mock.call.send_request(
-                
msg=SetRenderedFields(rendered_fields=expected_rendered_fields),
-                log=mock.ANY,
-            ),
-            mock.call.send_request(
-                msg=TaskState(end_date=instant, state=TerminalTIState.SUCCESS),
-                log=mock.ANY,
-            ),
-        ]
-        mock_supervisor_comms.assert_has_calls(expected_calls)
+    mock_supervisor_comms.get_message.return_value = what
+
+    startup()
+    run(ti, log=mock.MagicMock())
+    expected_calls = [
+        mock.call.send_request(
+            msg=SetRenderedFields(rendered_fields=expected_rendered_fields),
+            log=mock.ANY,
+        ),
+        mock.call.send_request(
+            msg=TaskState(end_date=instant, state=TerminalTIState.SUCCESS),
+            log=mock.ANY,
+        ),
+    ]
+    mock_supervisor_comms.assert_has_calls(expected_calls)
 
 
 @pytest.mark.parametrize(
@@ -349,7 +345,9 @@ def test_startup_and_run_dag_with_templated_fields(
         ),
     ],
 )
-def test_run_basic_failed(time_machine, mocked_parse, dag_id, task_id, 
fail_with_exception, make_ti_context):
+def test_run_basic_failed(
+    time_machine, mocked_parse, dag_id, task_id, fail_with_exception, 
make_ti_context, mock_supervisor_comms
+):
     """Test running a basic task that marks itself as failed by raising 
exception."""
 
     class CustomOperator(BaseOperator):
@@ -375,14 +373,11 @@ def test_run_basic_failed(time_machine, mocked_parse, 
dag_id, task_id, fail_with
     instant = timezone.datetime(2024, 12, 3, 10, 0)
     time_machine.move_to(instant, tick=False)
 
-    with mock.patch(
-        "airflow.sdk.execution_time.task_runner.SUPERVISOR_COMMS", create=True
-    ) as mock_supervisor_comms:
-        run(ti, log=mock.MagicMock())
+    run(ti, log=mock.MagicMock())
 
-        mock_supervisor_comms.send_request.assert_called_once_with(
-            msg=TaskState(state=TerminalTIState.FAILED, end_date=instant), 
log=mock.ANY
-        )
+    mock_supervisor_comms.send_request.assert_called_once_with(
+        msg=TaskState(state=TerminalTIState.FAILED, end_date=instant), 
log=mock.ANY
+    )
 
 
 class TestRuntimeTaskInstance:
@@ -456,7 +451,7 @@ class TestRuntimeTaskInstance:
             "ts_nodash_with_tz": "20241201T010000+0000",
         }
 
-    def test_get_connection_from_context(self, mocked_parse, make_ti_context):
+    def test_get_connection_from_context(self, mocked_parse, make_ti_context, 
mock_supervisor_comms):
         """Test that the connection is fetched from the API server via the 
Supervisor lazily when accessed"""
 
         task = BaseOperator(task_id="hello")
@@ -478,34 +473,159 @@ class TestRuntimeTaskInstance:
 
         what = StartupDetails(ti=ti, file="", requests_fd=0, 
ti_context=make_ti_context())
         runtime_ti = mocked_parse(what, ti.dag_id, task)
-        with mock.patch(
-            "airflow.sdk.execution_time.task_runner.SUPERVISOR_COMMS", 
create=True
-        ) as mock_supervisor_comms:
-            mock_supervisor_comms.get_message.return_value = conn
+        mock_supervisor_comms.get_message.return_value = conn
 
-            context = runtime_ti.get_template_context()
+        context = runtime_ti.get_template_context()
 
-            # Assert that the connection is not fetched from the API server 
yet!
-            # The connection should be only fetched connection is accessed
-            mock_supervisor_comms.send_request.assert_not_called()
-            mock_supervisor_comms.get_message.assert_not_called()
+        # Assert that the connection is not fetched from the API server yet!
+        # The connection should only be fetched when it is accessed
+        mock_supervisor_comms.send_request.assert_not_called()
+        mock_supervisor_comms.get_message.assert_not_called()
 
-            # Access the connection from the context
-            conn_from_context = context["conn"].test_conn
+        # Access the connection from the context
+        conn_from_context = context["conn"].test_conn
 
-            mock_supervisor_comms.send_request.assert_called_once_with(
-                log=mock.ANY, msg=GetConnection(conn_id="test_conn")
-            )
-            mock_supervisor_comms.get_message.assert_called_once_with()
-
-            assert conn_from_context == Connection(
-                conn_id="test_conn",
-                conn_type="mysql",
-                description=None,
-                host="mysql",
-                schema="airflow",
-                login="root",
-                password="password",
-                port=1234,
-                extra='{"extra_key": "extra_value"}',
-            )
+        mock_supervisor_comms.send_request.assert_called_once_with(
+            log=mock.ANY, msg=GetConnection(conn_id="test_conn")
+        )
+        mock_supervisor_comms.get_message.assert_called_once_with()
+
+        assert conn_from_context == Connection(
+            conn_id="test_conn",
+            conn_type="mysql",
+            description=None,
+            host="mysql",
+            schema="airflow",
+            login="root",
+            password="password",
+            port=1234,
+            extra='{"extra_key": "extra_value"}',
+        )
+
+
+class TestXComAfterTaskExecution:
+    @pytest.mark.parametrize(
+        ["do_xcom_push", "should_push_xcom", "expected_xcom_value"],
+        [
+            pytest.param(False, False, None, id="do_xcom_push_false"),
+            pytest.param(True, True, "Hello World!", id="do_xcom_push_true"),
+        ],
+    )
+    def test_xcom_push_flag(
+        self,
+        mocked_parse,
+        make_ti_context,
+        mock_supervisor_comms,
+        spy_agency,
+        do_xcom_push: bool,
+        should_push_xcom: bool,
+        expected_xcom_value,
+    ):
+        """Test that the do_xcom_push flag controls whether the task pushes to 
XCom."""
+
+        class CustomOperator(BaseOperator):
+            def execute(self, context):
+                return "Hello World!"
+
+        task = CustomOperator(task_id="hello", do_xcom_push=do_xcom_push)
+
+        ti = TaskInstance(
+            id=uuid7(), task_id=task.task_id, dag_id="xcom_push_flag", 
run_id="test_run", try_number=1
+        )
+
+        what = StartupDetails(ti=ti, file="", requests_fd=0, 
ti_context=make_ti_context())
+        runtime_ti = mocked_parse(what, ti.dag_id, task)
+
+        spy_agency.spy_on(_push_xcom_if_needed, call_original=True)
+        spy_agency.spy_on(runtime_ti.xcom_push, call_original=False)
+
+        run(runtime_ti, log=mock.MagicMock())
+
+        spy_agency.assert_spy_called(_push_xcom_if_needed)
+
+        if should_push_xcom:
+            spy_agency.assert_spy_called_with(runtime_ti.xcom_push, 
"return_value", expected_xcom_value)
+        else:
+            spy_agency.assert_spy_not_called(runtime_ti.xcom_push)
+
+    def test_xcom_with_multiple_outputs(self, mocked_parse, spy_agency):
+        """Test that the task pushes to XCom when multiple outputs are 
returned."""
+        result = {"key1": "value1", "key2": "value2"}
+
+        class CustomOperator(BaseOperator):
+            def execute(self, context):
+                return result
+
+        task = CustomOperator(
+            task_id="test_xcom_push_with_multiple_outputs", do_xcom_push=True, 
multiple_outputs=True
+        )
+        dag = get_inline_dag(dag_id="test_dag", task=task)
+        ti = TaskInstance(
+            id=uuid7(), task_id=task.task_id, dag_id=dag.dag_id, 
run_id="test_run", try_number=1
+        )
+
+        runtime_ti = 
RuntimeTaskInstance.model_construct(**ti.model_dump(exclude_unset=True), 
task=task)
+
+        spy_agency.spy_on(runtime_ti.xcom_push, call_original=False)
+        _push_xcom_if_needed(result=result, ti=runtime_ti)
+
+        expected_calls = [
+            ("key1", "value1"),
+            ("key2", "value2"),
+            ("return_value", result),
+        ]
+        spy_agency.assert_spy_call_count(runtime_ti.xcom_push, 
len(expected_calls))
+        for key, value in expected_calls:
+            spy_agency.assert_spy_called_with(runtime_ti.xcom_push, key, value)
+
+    def test_xcom_with_multiple_outputs_and_no_mapping_result(self, 
mocked_parse, spy_agency):
+        """Test that error is raised when multiple outputs are returned 
without mapping."""
+        result = "value1"
+
+        class CustomOperator(BaseOperator):
+            def execute(self, context):
+                return result
+
+        task = CustomOperator(
+            task_id="test_xcom_push_with_multiple_outputs", do_xcom_push=True, 
multiple_outputs=True
+        )
+        dag = get_inline_dag(dag_id="test_dag", task=task)
+        ti = TaskInstance(
+            id=uuid7(), task_id=task.task_id, dag_id=dag.dag_id, 
run_id="test_run", try_number=1
+        )
+
+        runtime_ti = 
RuntimeTaskInstance.model_construct(**ti.model_dump(exclude_unset=True), 
task=task)
+
+        spy_agency.spy_on(runtime_ti.xcom_push, call_original=False)
+        with pytest.raises(
+            TypeError,
+            match=f"Returned output was type {type(result)} expected 
dictionary for multiple_outputs",
+        ):
+            _push_xcom_if_needed(result=result, ti=runtime_ti)
+
+    def test_xcom_with_multiple_outputs_and_key_is_not_string(self, 
mocked_parse, spy_agency):
+        """Test that error is raised when multiple outputs are returned and 
key isn't string."""
+        result = {2: "value1", "key2": "value2"}
+
+        class CustomOperator(BaseOperator):
+            def execute(self, context):
+                return result
+
+        task = CustomOperator(
+            task_id="test_xcom_push_with_multiple_outputs", do_xcom_push=True, 
multiple_outputs=True
+        )
+        dag = get_inline_dag(dag_id="test_dag", task=task)
+        ti = TaskInstance(
+            id=uuid7(), task_id=task.task_id, dag_id=dag.dag_id, 
run_id="test_run", try_number=1
+        )
+
+        runtime_ti = 
RuntimeTaskInstance.model_construct(**ti.model_dump(exclude_unset=True), 
task=task)
+
+        spy_agency.spy_on(runtime_ti.xcom_push, call_original=False)
+
+        with pytest.raises(TypeError) as exc_info:
+            _push_xcom_if_needed(result=result, ti=runtime_ti)
+
+        assert str(exc_info.value) == (
+            f"Returned dictionary keys must be strings when using 
multiple_outputs, found 2 ({int}) instead"
+        )

Reply via email to