This is an automated email from the ASF dual-hosted git repository. ephraimanierobi pushed a commit to branch v2-6-test in repository https://gitbox.apache.org/repos/asf/airflow.git
commit 11f0a9c060f192d72b6a2512a10654f7d2f44a97 Author: Ephraim Anierobi <[email protected]> AuthorDate: Wed Jun 21 10:55:57 2023 +0100 Return None if an XComArg fails to resolve in a multiple_outputs Task (#32027) * Return None if an XComArg fails to resolve in a multiple_outputs Task Tasks with multiple_outputs set to True returns XComs with different keys which are not known to Airflow. Because they have multiple_outputs set, we should return None if we can't find the XCom, just like we return None when the key is equal to XCOM_RETURN_KEY known to Airflow. Closes: https://github.com/apache/airflow/issues/29199 * Apply suggestions from code review (cherry picked from commit 79eac7687cf7c6bcaa4df2b8735efaad79a7fee2) --- airflow/models/xcom_arg.py | 7 ++++++ tests/decorators/test_python.py | 47 +++++++++++++++++++++++++++++++++++++++-- 2 files changed, 52 insertions(+), 2 deletions(-) diff --git a/airflow/models/xcom_arg.py b/airflow/models/xcom_arg.py index d8b42ba819..ab9a46641b 100644 --- a/airflow/models/xcom_arg.py +++ b/airflow/models/xcom_arg.py @@ -370,6 +370,13 @@ class PlainXComArg(XComArg): return result if self.key == XCOM_RETURN_KEY: return None + if getattr(self.operator, "multiple_outputs", False): + # If the operator is set to have multiple outputs and it was not executed, + # we should return "None" instead of showing an error. This is because when + # multiple outputs XComs are created, the XCom keys associated with them will have + # different names than the predefined "XCOM_RETURN_KEY" and won't be found. + # Therefore, it's better to return "None" like we did above where self.key==XCOM_RETURN_KEY. 
+ return None raise XComNotFound(ti.dag_id, task_id, self.key) diff --git a/tests/decorators/test_python.py b/tests/decorators/test_python.py index 435c56de2e..7eebe1cf5f 100644 --- a/tests/decorators/test_python.py +++ b/tests/decorators/test_python.py @@ -18,14 +18,14 @@ import sys from collections import namedtuple from datetime import date, timedelta -from typing import TYPE_CHECKING, Dict, Tuple +from typing import TYPE_CHECKING, Dict, Tuple, Union import pytest from airflow import PY38, PY311 from airflow.decorators import task as task_decorator from airflow.decorators.base import DecoratedMappedOperator -from airflow.exceptions import AirflowException +from airflow.exceptions import AirflowException, XComNotFound from airflow.models import DAG from airflow.models.baseoperator import BaseOperator from airflow.models.expandinput import DictOfListsExpandInput @@ -803,6 +803,49 @@ def test_upstream_exception_produces_none_xcom(dag_maker, session): assert result == "'example' None" +@pytest.mark.parametrize("multiple_outputs", [True, False]) +def test_multiple_outputs_produces_none_xcom_when_task_is_skipped(dag_maker, session, multiple_outputs): + from airflow.exceptions import AirflowSkipException + from airflow.utils.trigger_rule import TriggerRule + + result = None + + with dag_maker(session=session) as dag: + + @dag.task() + def up1() -> str: + return "example" + + @dag.task(multiple_outputs=multiple_outputs) + def up2(x) -> Union[dict, None]: + if x == 2: + return {"x": "example"} + raise AirflowSkipException() + + @dag.task(trigger_rule=TriggerRule.NONE_FAILED_MIN_ONE_SUCCESS) + def down(a, b): + nonlocal result + result = f"{a!r} {b!r}" + + down(up1(), up2(1)["x"]) + + dr = dag_maker.create_dagrun() + + decision = dr.task_instance_scheduling_decisions(session=session) + assert len(decision.schedulable_tis) == 2 # "up1" and "up2" + for ti in decision.schedulable_tis: + ti.run(session=session) + + decision = dr.task_instance_scheduling_decisions(session=session) +
assert len(decision.schedulable_tis) == 1 # "down" + if multiple_outputs: + decision.schedulable_tis[0].run(session=session) + assert result == "'example' None" + else: + with pytest.raises(XComNotFound): + decision.schedulable_tis[0].run(session=session) + + @pytest.mark.filterwarnings("error") def test_no_warnings(reset_logging_config, caplog): @task_decorator
