This is an automated email from the ASF dual-hosted git repository.

uranusjr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 9214018153 Disallow calling expand with no arguments (#23463)
9214018153 is described below

commit 9214018153dd193be6b1147629f73b23d8195cce
Author: Tzu-ping Chung <[email protected]>
AuthorDate: Fri May 27 00:25:13 2022 -0400

    Disallow calling expand with no arguments (#23463)
---
 airflow/decorators/base.py                          | 3 +++
 airflow/models/mappedoperator.py                    | 5 +++++
 airflow/serialization/serialized_objects.py         | 3 ++-
 tests/api_connexion/endpoints/test_task_endpoint.py | 6 ++++--
 tests/models/test_taskinstance.py                   | 8 ++++++--
 5 files changed, 20 insertions(+), 5 deletions(-)

diff --git a/airflow/decorators/base.py b/airflow/decorators/base.py
index 79277c7281..1b14cd0668 100644
--- a/airflow/decorators/base.py
+++ b/airflow/decorators/base.py
@@ -312,6 +312,9 @@ class _TaskDecorator(Generic[Function, OperatorSubclass]):
             raise TypeError(f"{func}() got unexpected keyword arguments 
{names}")
 
     def expand(self, **map_kwargs: "Mappable") -> XComArg:
+        if not map_kwargs:
+            raise TypeError("no arguments to expand against")
+
         self._validate_arg_names("expand", map_kwargs)
         prevent_duplicates(self.kwargs, map_kwargs, fail_reason="mapping 
already partial")
         ensure_xcomarg_return_value(map_kwargs)
diff --git a/airflow/models/mappedoperator.py b/airflow/models/mappedoperator.py
index c522cefb2c..663ceeece1 100644
--- a/airflow/models/mappedoperator.py
+++ b/airflow/models/mappedoperator.py
@@ -191,6 +191,11 @@ class OperatorPartial:
             warnings.warn(f"Task {task_id} was never mapped!")
 
     def expand(self, **mapped_kwargs: "Mappable") -> "MappedOperator":
+        if not mapped_kwargs:
+            raise TypeError("no arguments to expand against")
+        return self._expand(**mapped_kwargs)
+
+    def _expand(self, **mapped_kwargs: "Mappable") -> "MappedOperator":
         self._expand_called = True
 
         from airflow.operators.empty import EmptyOperator
diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py
index 8d21bca8ee..3e674b2f8d 100644
--- a/airflow/serialization/serialized_objects.py
+++ b/airflow/serialization/serialized_objects.py
@@ -96,7 +96,8 @@ def _get_default_mapped_partial() -> Dict[str, Any]:
     are defaults, they are automatically supplied on de-serialization, so we
     don't need to store them.
     """
-    default_partial_kwargs = 
BaseOperator.partial(task_id="_").expand().partial_kwargs
+    # Use the private _expand() method to avoid the empty kwargs check.
+    default_partial_kwargs = 
BaseOperator.partial(task_id="_")._expand().partial_kwargs
     return BaseSerialization._serialize(default_partial_kwargs)[Encoding.VAR]
 
 
diff --git a/tests/api_connexion/endpoints/test_task_endpoint.py b/tests/api_connexion/endpoints/test_task_endpoint.py
index 9748305d8c..7509a89032 100644
--- a/tests/api_connexion/endpoints/test_task_endpoint.py
+++ b/tests/api_connexion/endpoints/test_task_endpoint.py
@@ -67,8 +67,10 @@ class TestTaskEndpoint:
             task2 = EmptyOperator(task_id=self.task_id2, 
start_date=self.task2_start_date)
 
         with DAG(self.mapped_dag_id, start_date=self.task1_start_date) as 
mapped_dag:
-            task3 = EmptyOperator(task_id=self.task_id3)  # noqa
-            mapped_task = 
EmptyOperator.partial(task_id=self.mapped_task_id).expand()  # noqa
+            EmptyOperator(task_id=self.task_id3)
+            # Use the private _expand() method to avoid the empty kwargs check.
+            # We don't care about how the operator runs here, only its 
presence.
+            EmptyOperator.partial(task_id=self.mapped_task_id)._expand()
 
         task1 >> task2
         dag_bag = DagBag(os.devnull, include_examples=False)
diff --git a/tests/models/test_taskinstance.py b/tests/models/test_taskinstance.py
index 975d44cf92..a1d180fa1e 100644
--- a/tests/models/test_taskinstance.py
+++ b/tests/models/test_taskinstance.py
@@ -1053,7 +1053,9 @@ class TestTaskInstance:
 
     def test_xcom_pull_mapped(self, dag_maker, session):
         with dag_maker(dag_id="test_xcom", session=session):
-            task_1 = EmptyOperator.partial(task_id="task_1").expand()
+            # Use the private _expand() method to avoid the empty kwargs check.
+            # We don't care about how the operator runs here, only its 
presence.
+            task_1 = EmptyOperator.partial(task_id="task_1")._expand()
             EmptyOperator(task_id="task_2")
 
         dagrun = dag_maker.create_dagrun(start_date=timezone.datetime(2016, 6, 
1, 0, 0, 0))
@@ -2763,7 +2765,9 @@ class TestMappedTaskInstanceReceiveValue:
 def 
test_ti_xcom_pull_on_mapped_operator_return_lazy_iterable(mock_deserialize_value,
 dag_maker, session):
     """Ensure we access XCom lazily when pulling from a mapped operator."""
     with dag_maker(dag_id="test_xcom", session=session):
-        task_1 = EmptyOperator.partial(task_id="task_1").expand()
+        # Use the private _expand() method to avoid the empty kwargs check.
+        # We don't care about how the operator runs here, only its presence.
+        task_1 = EmptyOperator.partial(task_id="task_1")._expand()
         EmptyOperator(task_id="task_2")
 
     dagrun = dag_maker.create_dagrun()

Reply via email to