This is an automated email from the ASF dual-hosted git repository.
uranusjr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 9214018153 Disallow calling expand with no arguments (#23463)
9214018153 is described below
commit 9214018153dd193be6b1147629f73b23d8195cce
Author: Tzu-ping Chung <[email protected]>
AuthorDate: Fri May 27 00:25:13 2022 -0400
Disallow calling expand with no arguments (#23463)
---
airflow/decorators/base.py | 3 +++
airflow/models/mappedoperator.py | 5 +++++
airflow/serialization/serialized_objects.py | 3 ++-
tests/api_connexion/endpoints/test_task_endpoint.py | 6 ++++--
tests/models/test_taskinstance.py | 8 ++++++--
5 files changed, 20 insertions(+), 5 deletions(-)
diff --git a/airflow/decorators/base.py b/airflow/decorators/base.py
index 79277c7281..1b14cd0668 100644
--- a/airflow/decorators/base.py
+++ b/airflow/decorators/base.py
@@ -312,6 +312,9 @@ class _TaskDecorator(Generic[Function, OperatorSubclass]):
raise TypeError(f"{func}() got unexpected keyword arguments
{names}")
def expand(self, **map_kwargs: "Mappable") -> XComArg:
+ if not map_kwargs:
+ raise TypeError("no arguments to expand against")
+
self._validate_arg_names("expand", map_kwargs)
prevent_duplicates(self.kwargs, map_kwargs, fail_reason="mapping
already partial")
ensure_xcomarg_return_value(map_kwargs)
diff --git a/airflow/models/mappedoperator.py b/airflow/models/mappedoperator.py
index c522cefb2c..663ceeece1 100644
--- a/airflow/models/mappedoperator.py
+++ b/airflow/models/mappedoperator.py
@@ -191,6 +191,11 @@ class OperatorPartial:
warnings.warn(f"Task {task_id} was never mapped!")
def expand(self, **mapped_kwargs: "Mappable") -> "MappedOperator":
+ if not mapped_kwargs:
+ raise TypeError("no arguments to expand against")
+ return self._expand(**mapped_kwargs)
+
+ def _expand(self, **mapped_kwargs: "Mappable") -> "MappedOperator":
self._expand_called = True
from airflow.operators.empty import EmptyOperator
diff --git a/airflow/serialization/serialized_objects.py
b/airflow/serialization/serialized_objects.py
index 8d21bca8ee..3e674b2f8d 100644
--- a/airflow/serialization/serialized_objects.py
+++ b/airflow/serialization/serialized_objects.py
@@ -96,7 +96,8 @@ def _get_default_mapped_partial() -> Dict[str, Any]:
are defaults, they are automatically supplied on de-serialization, so we
don't need to store them.
"""
- default_partial_kwargs =
BaseOperator.partial(task_id="_").expand().partial_kwargs
+ # Use the private _expand() method to avoid the empty kwargs check.
+ default_partial_kwargs =
BaseOperator.partial(task_id="_")._expand().partial_kwargs
return BaseSerialization._serialize(default_partial_kwargs)[Encoding.VAR]
diff --git a/tests/api_connexion/endpoints/test_task_endpoint.py
b/tests/api_connexion/endpoints/test_task_endpoint.py
index 9748305d8c..7509a89032 100644
--- a/tests/api_connexion/endpoints/test_task_endpoint.py
+++ b/tests/api_connexion/endpoints/test_task_endpoint.py
@@ -67,8 +67,10 @@ class TestTaskEndpoint:
task2 = EmptyOperator(task_id=self.task_id2,
start_date=self.task2_start_date)
with DAG(self.mapped_dag_id, start_date=self.task1_start_date) as
mapped_dag:
- task3 = EmptyOperator(task_id=self.task_id3) # noqa
- mapped_task =
EmptyOperator.partial(task_id=self.mapped_task_id).expand() # noqa
+ EmptyOperator(task_id=self.task_id3)
+ # Use the private _expand() method to avoid the empty kwargs check.
+ # We don't care about how the operator runs here, only its
presence.
+ EmptyOperator.partial(task_id=self.mapped_task_id)._expand()
task1 >> task2
dag_bag = DagBag(os.devnull, include_examples=False)
diff --git a/tests/models/test_taskinstance.py
b/tests/models/test_taskinstance.py
index 975d44cf92..a1d180fa1e 100644
--- a/tests/models/test_taskinstance.py
+++ b/tests/models/test_taskinstance.py
@@ -1053,7 +1053,9 @@ class TestTaskInstance:
def test_xcom_pull_mapped(self, dag_maker, session):
with dag_maker(dag_id="test_xcom", session=session):
- task_1 = EmptyOperator.partial(task_id="task_1").expand()
+ # Use the private _expand() method to avoid the empty kwargs check.
+ # We don't care about how the operator runs here, only its
presence.
+ task_1 = EmptyOperator.partial(task_id="task_1")._expand()
EmptyOperator(task_id="task_2")
dagrun = dag_maker.create_dagrun(start_date=timezone.datetime(2016, 6,
1, 0, 0, 0))
@@ -2763,7 +2765,9 @@ class TestMappedTaskInstanceReceiveValue:
def
test_ti_xcom_pull_on_mapped_operator_return_lazy_iterable(mock_deserialize_value,
dag_maker, session):
"""Ensure we access XCom lazily when pulling from a mapped operator."""
with dag_maker(dag_id="test_xcom", session=session):
- task_1 = EmptyOperator.partial(task_id="task_1").expand()
+ # Use the private _expand() method to avoid the empty kwargs check.
+ # We don't care about how the operator runs here, only its presence.
+ task_1 = EmptyOperator.partial(task_id="task_1")._expand()
EmptyOperator(task_id="task_2")
dagrun = dag_maker.create_dagrun()