This is an automated email from the ASF dual-hosted git repository.
kaxilnaik pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new af130c0df1c AIP-72: Push XCom on Task Return (#45245)
af130c0df1c is described below
commit af130c0df1cd02bc0738504c1af63522d4f5be41
Author: Kaxil Naik <[email protected]>
AuthorDate: Fri Dec 27 23:27:52 2024 +0530
AIP-72: Push XCom on Task Return (#45245)
closes https://github.com/apache/airflow/issues/45230
---
.../src/airflow/sdk/execution_time/task_runner.py | 48 +++-
task_sdk/tests/execution_time/conftest.py | 9 +
task_sdk/tests/execution_time/test_context.py | 39 +--
task_sdk/tests/execution_time/test_task_runner.py | 290 +++++++++++++++------
4 files changed, 273 insertions(+), 113 deletions(-)
diff --git a/task_sdk/src/airflow/sdk/execution_time/task_runner.py b/task_sdk/src/airflow/sdk/execution_time/task_runner.py
index d8d540318f1..810b108e221 100644
--- a/task_sdk/src/airflow/sdk/execution_time/task_runner.py
+++ b/task_sdk/src/airflow/sdk/execution_time/task_runner.py
@@ -21,7 +21,7 @@ from __future__ import annotations
import os
import sys
-from collections.abc import Iterable
+from collections.abc import Iterable, Mapping
from datetime import datetime, timezone
from io import FileIO
from typing import TYPE_CHECKING, Annotated, Any, Generic, TextIO, TypeVar
@@ -197,7 +197,11 @@ class RuntimeTaskInstance(TaskInstance):
value = msg.value
if value is not None:
- return value
+ from airflow.models.xcom import XCom
+
+ # TODO: Move XCom serialization & deserialization to Task SDK
+ # https://github.com/apache/airflow/issues/45231
+ return XCom.deserialize_value(value)
return default
def xcom_push(self, key: str, value: Any):
@@ -207,6 +211,12 @@ class RuntimeTaskInstance(TaskInstance):
:param key: Key to store the value under.
:param value: Value to store. Only be JSON-serializable may be used
otherwise.
"""
+ from airflow.models.xcom import XCom
+
+ # TODO: Move XCom serialization & deserialization to Task SDK
+ # https://github.com/apache/airflow/issues/45231
+ value = XCom.serialize_value(value)
+
log = structlog.get_logger(logger_name="task")
SUPERVISOR_COMMS.send_request(
log=log,
@@ -381,7 +391,9 @@ def run(ti: RuntimeTaskInstance, log: Logger):
# - Update RTIF
# - Pre Execute
# etc
- ti.task.execute(context) # type: ignore[attr-defined]
+ result = ti.task.execute(context) # type: ignore[attr-defined]
+ _push_xcom_if_needed(result, ti)
+
msg = TaskState(state=TerminalTIState.SUCCESS, end_date=datetime.now(tz=timezone.utc))
except TaskDeferred as defer:
classpath, trigger_kwargs = defer.trigger.serialize()
@@ -436,6 +448,36 @@ def run(ti: RuntimeTaskInstance, log: Logger):
SUPERVISOR_COMMS.send_request(msg=msg, log=log)
+def _push_xcom_if_needed(result: Any, ti: RuntimeTaskInstance):
+ """Push XCom values when task has ``do_xcom_push`` set to ``True`` and the task returns a result."""
+ if ti.task.do_xcom_push:
+ xcom_value = result
+ else:
+ xcom_value = None
+
+ # If the task returns a result, push an XCom containing it.
+ if xcom_value is None:
+ return
+
+ # If the task has multiple outputs, push each output as a separate XCom.
+ if ti.task.multiple_outputs:
+ if not isinstance(xcom_value, Mapping):
+ raise TypeError(
+ f"Returned output was type {type(xcom_value)} expected dictionary for multiple_outputs"
+ )
+ for key in xcom_value.keys():
+ if not isinstance(key, str):
+ raise TypeError(
+ "Returned dictionary keys must be strings when using "
+ f"multiple_outputs, found {key} ({type(key)}) instead"
+ )
+ for k, v in result.items():
+ ti.xcom_push(k, v)
+
+ # TODO: Use constant for XCom return key & use serialize_value from Task SDK
+ ti.xcom_push("return_value", result)
+
+
def finalize(log: Logger): ...
diff --git a/task_sdk/tests/execution_time/conftest.py b/task_sdk/tests/execution_time/conftest.py
index 4a537373363..bf482e5ec7b 100644
--- a/task_sdk/tests/execution_time/conftest.py
+++ b/task_sdk/tests/execution_time/conftest.py
@@ -18,6 +18,7 @@
from __future__ import annotations
import sys
+from unittest import mock
import pytest
@@ -31,3 +32,11 @@ def disable_capturing():
sys.stderr = sys.__stderr__
yield
sys.stdin, sys.stdout, sys.stderr = old_in, old_out, old_err
+
+
+@pytest.fixture
+def mock_supervisor_comms():
+ with mock.patch(
+ "airflow.sdk.execution_time.task_runner.SUPERVISOR_COMMS", create=True
+ ) as supervisor_comms:
+ yield supervisor_comms
diff --git a/task_sdk/tests/execution_time/test_context.py b/task_sdk/tests/execution_time/test_context.py
index a3220c3bef1..34502a1a917 100644
--- a/task_sdk/tests/execution_time/test_context.py
+++ b/task_sdk/tests/execution_time/test_context.py
@@ -17,8 +17,6 @@
from __future__ import annotations
-from unittest import mock
-
from airflow.sdk.definitions.connection import Connection
from airflow.sdk.exceptions import ErrorType
from airflow.sdk.execution_time.comms import ConnectionResult, ErrorResponse
@@ -51,7 +49,7 @@ def test_convert_connection_result_conn():
class TestConnectionAccessor:
- def test_getattr_connection(self):
+ def test_getattr_connection(self, mock_supervisor_comms):
"""
Test that the connection is fetched when accessed via __getattr__.
@@ -62,31 +60,25 @@ class TestConnectionAccessor:
# Conn from the supervisor / API Server
conn_result = ConnectionResult(conn_id="mysql_conn",
conn_type="mysql", host="mysql", port=3306)
- with mock.patch(
- "airflow.sdk.execution_time.task_runner.SUPERVISOR_COMMS",
create=True
- ) as mock_supervisor_comms:
- mock_supervisor_comms.get_message.return_value = conn_result
+ mock_supervisor_comms.get_message.return_value = conn_result
- # Fetch the connection; triggers __getattr__
- conn = accessor.mysql_conn
+ # Fetch the connection; triggers __getattr__
+ conn = accessor.mysql_conn
- expected_conn = Connection(conn_id="mysql_conn",
conn_type="mysql", host="mysql", port=3306)
- assert conn == expected_conn
+ expected_conn = Connection(conn_id="mysql_conn", conn_type="mysql",
host="mysql", port=3306)
+ assert conn == expected_conn
- def test_get_method_valid_connection(self):
+ def test_get_method_valid_connection(self, mock_supervisor_comms):
"""Test that the get method returns the requested connection using
`conn.get`."""
accessor = ConnectionAccessor()
conn_result = ConnectionResult(conn_id="mysql_conn",
conn_type="mysql", host="mysql", port=3306)
- with mock.patch(
- "airflow.sdk.execution_time.task_runner.SUPERVISOR_COMMS",
create=True
- ) as mock_supervisor_comms:
- mock_supervisor_comms.get_message.return_value = conn_result
+ mock_supervisor_comms.get_message.return_value = conn_result
- conn = accessor.get("mysql_conn")
- assert conn == Connection(conn_id="mysql_conn", conn_type="mysql",
host="mysql", port=3306)
+ conn = accessor.get("mysql_conn")
+ assert conn == Connection(conn_id="mysql_conn", conn_type="mysql",
host="mysql", port=3306)
- def test_get_method_with_default(self):
+ def test_get_method_with_default(self, mock_supervisor_comms):
"""Test that the get method returns the default connection when the
requested connection is not found."""
accessor = ConnectionAccessor()
default_conn = {"conn_id": "default_conn", "conn_type": "sqlite"}
@@ -94,10 +86,7 @@ class TestConnectionAccessor:
error=ErrorType.CONNECTION_NOT_FOUND, detail={"conn_id":
"nonexistent_conn"}
)
- with mock.patch(
- "airflow.sdk.execution_time.task_runner.SUPERVISOR_COMMS",
create=True
- ) as mock_supervisor_comms:
- mock_supervisor_comms.get_message.return_value = error_response
+ mock_supervisor_comms.get_message.return_value = error_response
- conn = accessor.get("nonexistent_conn", default_conn=default_conn)
- assert conn == default_conn
+ conn = accessor.get("nonexistent_conn", default_conn=default_conn)
+ assert conn == default_conn
diff --git a/task_sdk/tests/execution_time/test_task_runner.py b/task_sdk/tests/execution_time/test_task_runner.py
index 46bb75ac61a..48ab35709bb 100644
--- a/task_sdk/tests/execution_time/test_task_runner.py
+++ b/task_sdk/tests/execution_time/test_task_runner.py
@@ -43,7 +43,14 @@ from airflow.sdk.execution_time.comms import (
TaskState,
)
from airflow.sdk.execution_time.context import ConnectionAccessor
-from airflow.sdk.execution_time.task_runner import CommsDecoder, RuntimeTaskInstance, parse, run, startup
+from airflow.sdk.execution_time.task_runner import (
+ CommsDecoder,
+ RuntimeTaskInstance,
+ _push_xcom_if_needed,
+ parse,
+ run,
+ startup,
+)
from airflow.utils import timezone
@@ -147,7 +154,7 @@ def test_parse(test_dags_dir: Path, make_ti_context):
assert isinstance(ti.task.dag, DAG)
-def test_run_basic(time_machine, mocked_parse, make_ti_context, spy_agency):
+def test_run_basic(time_machine, mocked_parse, make_ti_context, spy_agency,
mock_supervisor_comms):
"""Test running a basic task."""
what = StartupDetails(
ti=TaskInstance(id=uuid7(), task_id="hello", dag_id="super_basic_run",
run_id="c", try_number=1),
@@ -159,26 +166,23 @@ def test_run_basic(time_machine, mocked_parse,
make_ti_context, spy_agency):
instant = timezone.datetime(2024, 12, 3, 10, 0)
time_machine.move_to(instant, tick=False)
- with mock.patch(
- "airflow.sdk.execution_time.task_runner.SUPERVISOR_COMMS", create=True
- ) as mock_supervisor_comms:
- ti = mocked_parse(what, "super_basic_run",
CustomOperator(task_id="hello"))
+ ti = mocked_parse(what, "super_basic_run", CustomOperator(task_id="hello"))
- # Ensure that task is locked for execution
- spy_agency.spy_on(ti.task.prepare_for_execution)
- assert not ti.task._lock_for_execution
+ # Ensure that task is locked for execution
+ spy_agency.spy_on(ti.task.prepare_for_execution)
+ assert not ti.task._lock_for_execution
- run(ti, log=mock.MagicMock())
+ run(ti, log=mock.MagicMock())
- spy_agency.assert_spy_called(ti.task.prepare_for_execution)
- assert ti.task._lock_for_execution
+ spy_agency.assert_spy_called(ti.task.prepare_for_execution)
+ assert ti.task._lock_for_execution
- mock_supervisor_comms.send_request.assert_called_once_with(
- msg=TaskState(state=TerminalTIState.SUCCESS, end_date=instant),
log=mock.ANY
- )
+ mock_supervisor_comms.send_request.assert_called_once_with(
+ msg=TaskState(state=TerminalTIState.SUCCESS, end_date=instant),
log=mock.ANY
+ )
-def test_run_deferred_basic(time_machine, mocked_parse, make_ti_context):
+def test_run_deferred_basic(time_machine, mocked_parse, make_ti_context,
mock_supervisor_comms):
"""Test that a task can transition to a deferred state."""
import datetime
@@ -213,17 +217,14 @@ def test_run_deferred_basic(time_machine, mocked_parse,
make_ti_context):
)
# Run the task
- with mock.patch(
- "airflow.sdk.execution_time.task_runner.SUPERVISOR_COMMS", create=True
- ) as mock_supervisor_comms:
- ti = mocked_parse(what, "basic_deferred_run", task)
- run(ti, log=mock.MagicMock())
+ ti = mocked_parse(what, "basic_deferred_run", task)
+ run(ti, log=mock.MagicMock())
- # send_request will only be called when the TaskDeferred exception is raised
- mock_supervisor_comms.send_request.assert_called_once_with(msg=expected_defer_task, log=mock.ANY)
+ # send_request will only be called when the TaskDeferred exception is raised
+ mock_supervisor_comms.send_request.assert_called_once_with(msg=expected_defer_task, log=mock.ANY)
-def test_run_basic_skipped(time_machine, mocked_parse, make_ti_context):
+def test_run_basic_skipped(time_machine, mocked_parse, make_ti_context,
mock_supervisor_comms):
"""Test running a basic task that marks itself skipped."""
from airflow.providers.standard.operators.python import PythonOperator
@@ -246,14 +247,11 @@ def test_run_basic_skipped(time_machine, mocked_parse,
make_ti_context):
instant = timezone.datetime(2024, 12, 3, 10, 0)
time_machine.move_to(instant, tick=False)
- with mock.patch(
- "airflow.sdk.execution_time.task_runner.SUPERVISOR_COMMS", create=True
- ) as mock_supervisor_comms:
- run(ti, log=mock.MagicMock())
+ run(ti, log=mock.MagicMock())
- mock_supervisor_comms.send_request.assert_called_once_with(
- msg=TaskState(state=TerminalTIState.SKIPPED, end_date=instant),
log=mock.ANY
- )
+ mock_supervisor_comms.send_request.assert_called_once_with(
+ msg=TaskState(state=TerminalTIState.SKIPPED, end_date=instant),
log=mock.ANY
+ )
@pytest.mark.parametrize(
@@ -283,7 +281,7 @@ def test_run_basic_skipped(time_machine, mocked_parse,
make_ti_context):
],
)
def test_startup_and_run_dag_with_templated_fields(
- mocked_parse, task_params, expected_rendered_fields, make_ti_context,
time_machine
+ mocked_parse, task_params, expected_rendered_fields, make_ti_context,
time_machine, mock_supervisor_comms
):
"""Test startup of a DAG with various templated fields."""
@@ -311,24 +309,22 @@ def test_startup_and_run_dag_with_templated_fields(
instant = timezone.datetime(2024, 12, 3, 10, 0)
time_machine.move_to(instant, tick=False)
- with mock.patch(
- "airflow.sdk.execution_time.task_runner.SUPERVISOR_COMMS", create=True
- ) as mock_supervisor_comms:
- mock_supervisor_comms.get_message.return_value = what
- startup()
- run(ti, log=mock.MagicMock())
- expected_calls = [
- mock.call.send_request(
-
msg=SetRenderedFields(rendered_fields=expected_rendered_fields),
- log=mock.ANY,
- ),
- mock.call.send_request(
- msg=TaskState(end_date=instant, state=TerminalTIState.SUCCESS),
- log=mock.ANY,
- ),
- ]
- mock_supervisor_comms.assert_has_calls(expected_calls)
+ mock_supervisor_comms.get_message.return_value = what
+
+ startup()
+ run(ti, log=mock.MagicMock())
+ expected_calls = [
+ mock.call.send_request(
+ msg=SetRenderedFields(rendered_fields=expected_rendered_fields),
+ log=mock.ANY,
+ ),
+ mock.call.send_request(
+ msg=TaskState(end_date=instant, state=TerminalTIState.SUCCESS),
+ log=mock.ANY,
+ ),
+ ]
+ mock_supervisor_comms.assert_has_calls(expected_calls)
@pytest.mark.parametrize(
@@ -349,7 +345,9 @@ def test_startup_and_run_dag_with_templated_fields(
),
],
)
-def test_run_basic_failed(time_machine, mocked_parse, dag_id, task_id,
fail_with_exception, make_ti_context):
+def test_run_basic_failed(
+ time_machine, mocked_parse, dag_id, task_id, fail_with_exception,
make_ti_context, mock_supervisor_comms
+):
"""Test running a basic task that marks itself as failed by raising
exception."""
class CustomOperator(BaseOperator):
@@ -375,14 +373,11 @@ def test_run_basic_failed(time_machine, mocked_parse,
dag_id, task_id, fail_with
instant = timezone.datetime(2024, 12, 3, 10, 0)
time_machine.move_to(instant, tick=False)
- with mock.patch(
- "airflow.sdk.execution_time.task_runner.SUPERVISOR_COMMS", create=True
- ) as mock_supervisor_comms:
- run(ti, log=mock.MagicMock())
+ run(ti, log=mock.MagicMock())
- mock_supervisor_comms.send_request.assert_called_once_with(
- msg=TaskState(state=TerminalTIState.FAILED, end_date=instant),
log=mock.ANY
- )
+ mock_supervisor_comms.send_request.assert_called_once_with(
+ msg=TaskState(state=TerminalTIState.FAILED, end_date=instant),
log=mock.ANY
+ )
class TestRuntimeTaskInstance:
@@ -456,7 +451,7 @@ class TestRuntimeTaskInstance:
"ts_nodash_with_tz": "20241201T010000+0000",
}
- def test_get_connection_from_context(self, mocked_parse, make_ti_context):
+ def test_get_connection_from_context(self, mocked_parse, make_ti_context,
mock_supervisor_comms):
"""Test that the connection is fetched from the API server via the
Supervisor lazily when accessed"""
task = BaseOperator(task_id="hello")
@@ -478,34 +473,159 @@ class TestRuntimeTaskInstance:
what = StartupDetails(ti=ti, file="", requests_fd=0,
ti_context=make_ti_context())
runtime_ti = mocked_parse(what, ti.dag_id, task)
- with mock.patch(
- "airflow.sdk.execution_time.task_runner.SUPERVISOR_COMMS",
create=True
- ) as mock_supervisor_comms:
- mock_supervisor_comms.get_message.return_value = conn
+ mock_supervisor_comms.get_message.return_value = conn
- context = runtime_ti.get_template_context()
+ context = runtime_ti.get_template_context()
- # Assert that the connection is not fetched from the API server yet!
- # The connection should be only fetched connection is accessed
- mock_supervisor_comms.send_request.assert_not_called()
- mock_supervisor_comms.get_message.assert_not_called()
+ # Assert that the connection is not fetched from the API server yet!
+ # The connection should be only fetched connection is accessed
+ mock_supervisor_comms.send_request.assert_not_called()
+ mock_supervisor_comms.get_message.assert_not_called()
- # Access the connection from the context
- conn_from_context = context["conn"].test_conn
+ # Access the connection from the context
+ conn_from_context = context["conn"].test_conn
- mock_supervisor_comms.send_request.assert_called_once_with(
- log=mock.ANY, msg=GetConnection(conn_id="test_conn")
- )
- mock_supervisor_comms.get_message.assert_called_once_with()
-
- assert conn_from_context == Connection(
- conn_id="test_conn",
- conn_type="mysql",
- description=None,
- host="mysql",
- schema="airflow",
- login="root",
- password="password",
- port=1234,
- extra='{"extra_key": "extra_value"}',
- )
+ mock_supervisor_comms.send_request.assert_called_once_with(
+ log=mock.ANY, msg=GetConnection(conn_id="test_conn")
+ )
+ mock_supervisor_comms.get_message.assert_called_once_with()
+
+ assert conn_from_context == Connection(
+ conn_id="test_conn",
+ conn_type="mysql",
+ description=None,
+ host="mysql",
+ schema="airflow",
+ login="root",
+ password="password",
+ port=1234,
+ extra='{"extra_key": "extra_value"}',
+ )
+
+
+class TestXComAfterTaskExecution:
+ @pytest.mark.parametrize(
+ ["do_xcom_push", "should_push_xcom", "expected_xcom_value"],
+ [
+ pytest.param(False, False, None, id="do_xcom_push_false"),
+ pytest.param(True, True, "Hello World!", id="do_xcom_push_true"),
+ ],
+ )
+ def test_xcom_push_flag(
+ self,
+ mocked_parse,
+ make_ti_context,
+ mock_supervisor_comms,
+ spy_agency,
+ do_xcom_push: bool,
+ should_push_xcom: bool,
+ expected_xcom_value,
+ ):
+ """Test that the do_xcom_push flag controls whether the task pushes to
XCom."""
+
+ class CustomOperator(BaseOperator):
+ def execute(self, context):
+ return "Hello World!"
+
+ task = CustomOperator(task_id="hello", do_xcom_push=do_xcom_push)
+
+ ti = TaskInstance(
+ id=uuid7(), task_id=task.task_id, dag_id="xcom_push_flag",
run_id="test_run", try_number=1
+ )
+
+ what = StartupDetails(ti=ti, file="", requests_fd=0,
ti_context=make_ti_context())
+ runtime_ti = mocked_parse(what, ti.dag_id, task)
+
+ spy_agency.spy_on(_push_xcom_if_needed, call_original=True)
+ spy_agency.spy_on(runtime_ti.xcom_push, call_original=False)
+
+ run(runtime_ti, log=mock.MagicMock())
+
+ spy_agency.assert_spy_called(_push_xcom_if_needed)
+
+ if should_push_xcom:
+ spy_agency.assert_spy_called_with(runtime_ti.xcom_push,
"return_value", expected_xcom_value)
+ else:
+ spy_agency.assert_spy_not_called(runtime_ti.xcom_push)
+
+ def test_xcom_with_multiple_outputs(self, mocked_parse, spy_agency):
+ """Test that the task pushes to XCom when multiple outputs are
returned."""
+ result = {"key1": "value1", "key2": "value2"}
+
+ class CustomOperator(BaseOperator):
+ def execute(self, context):
+ return result
+
+ task = CustomOperator(
+ task_id="test_xcom_push_with_multiple_outputs", do_xcom_push=True,
multiple_outputs=True
+ )
+ dag = get_inline_dag(dag_id="test_dag", task=task)
+ ti = TaskInstance(
+ id=uuid7(), task_id=task.task_id, dag_id=dag.dag_id,
run_id="test_run", try_number=1
+ )
+
+ runtime_ti =
RuntimeTaskInstance.model_construct(**ti.model_dump(exclude_unset=True),
task=task)
+
+ spy_agency.spy_on(runtime_ti.xcom_push, call_original=False)
+ _push_xcom_if_needed(result=result, ti=runtime_ti)
+
+ expected_calls = [
+ ("key1", "value1"),
+ ("key2", "value2"),
+ ("return_value", result),
+ ]
+ spy_agency.assert_spy_call_count(runtime_ti.xcom_push,
len(expected_calls))
+ for key, value in expected_calls:
+ spy_agency.assert_spy_called_with(runtime_ti.xcom_push, key, value)
+
+ def test_xcom_with_multiple_outputs_and_no_mapping_result(self,
mocked_parse, spy_agency):
+ """Test that error is raised when multiple outputs are returned
without mapping."""
+ result = "value1"
+
+ class CustomOperator(BaseOperator):
+ def execute(self, context):
+ return result
+
+ task = CustomOperator(
+ task_id="test_xcom_push_with_multiple_outputs", do_xcom_push=True,
multiple_outputs=True
+ )
+ dag = get_inline_dag(dag_id="test_dag", task=task)
+ ti = TaskInstance(
+ id=uuid7(), task_id=task.task_id, dag_id=dag.dag_id,
run_id="test_run", try_number=1
+ )
+
+ runtime_ti =
RuntimeTaskInstance.model_construct(**ti.model_dump(exclude_unset=True),
task=task)
+
+ spy_agency.spy_on(runtime_ti.xcom_push, call_original=False)
+ with pytest.raises(
+ TypeError,
+ match=f"Returned output was type {type(result)} expected dictionary for multiple_outputs",
+ ):
+ _push_xcom_if_needed(result=result, ti=runtime_ti)
+
+ def test_xcom_with_multiple_outputs_and_key_is_not_string(self,
mocked_parse, spy_agency):
+ """Test that error is raised when multiple outputs are returned and
key isn't string."""
+ result = {2: "value1", "key2": "value2"}
+
+ class CustomOperator(BaseOperator):
+ def execute(self, context):
+ return result
+
+ task = CustomOperator(
+ task_id="test_xcom_push_with_multiple_outputs", do_xcom_push=True,
multiple_outputs=True
+ )
+ dag = get_inline_dag(dag_id="test_dag", task=task)
+ ti = TaskInstance(
+ id=uuid7(), task_id=task.task_id, dag_id=dag.dag_id,
run_id="test_run", try_number=1
+ )
+
+ runtime_ti =
RuntimeTaskInstance.model_construct(**ti.model_dump(exclude_unset=True),
task=task)
+
+ spy_agency.spy_on(runtime_ti.xcom_push, call_original=False)
+
+ with pytest.raises(TypeError) as exc_info:
+ _push_xcom_if_needed(result=result, ti=runtime_ti)
+
+ assert str(exc_info.value) == (
+ f"Returned dictionary keys must be strings when using multiple_outputs, found 2 ({int}) instead"
+ )