This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new c8ce604739 fix: task flow dynamic mapping with default_args (#41592)
c8ce604739 is described below

commit c8ce604739bcbb15865e2d5d198fc5d83f562668
Author: phi-friday <[email protected]>
AuthorDate: Wed Oct 2 10:13:38 2024 +0900

    fix: task flow dynamic mapping with default_args (#41592)
---
 airflow/decorators/base.py      | 25 ++++++++++++++++++-------
 tests/decorators/test_mapped.py | 24 ++++++++++++++++++++++++
 2 files changed, 42 insertions(+), 7 deletions(-)

diff --git a/airflow/decorators/base.py b/airflow/decorators/base.py
index 1ef2c12c70..e650c1920a 100644
--- a/airflow/decorators/base.py
+++ b/airflow/decorators/base.py
@@ -431,18 +431,29 @@ class _TaskDecorator(ExpandableFactory, Generic[FParams, FReturn, OperatorSubcla
         dag = task_kwargs.pop("dag", None) or DagContext.get_current_dag()
         task_group = task_kwargs.pop("task_group", None) or TaskGroupContext.get_current_task_group(dag)
 
-        partial_kwargs, partial_params = get_merged_defaults(
+        default_args, partial_params = get_merged_defaults(
             dag=dag,
             task_group=task_group,
             task_params=task_kwargs.pop("params", None),
             task_default_args=task_kwargs.pop("default_args", None),
         )
-        partial_kwargs.update(
-            task_kwargs,
-            is_setup=self.is_setup,
-            is_teardown=self.is_teardown,
-            on_failure_fail_dagrun=self.on_failure_fail_dagrun,
-        )
+        partial_kwargs: dict[str, Any] = {
+            "is_setup": self.is_setup,
+            "is_teardown": self.is_teardown,
+            "on_failure_fail_dagrun": self.on_failure_fail_dagrun,
+        }
+        base_signature = inspect.signature(BaseOperator)
+        ignore = {
+            "default_args",  # This is target we are working on now.
+            "kwargs",  # A common name for a keyword argument.
+            "do_xcom_push",  # In the same boat as `multiple_outputs`
+            "multiple_outputs",  # We will use `self.multiple_outputs` instead.
+            "params",  # Already handled above `partial_params`.
+            "task_concurrency",  # Deprecated(replaced by `max_active_tis_per_dag`).
+        }
+        partial_keys = set(base_signature.parameters) - ignore
+        partial_kwargs.update({key: value for key, value in default_args.items() if key in partial_keys})
+        partial_kwargs.update(task_kwargs)
 
         task_id = get_unique_task_id(partial_kwargs.pop("task_id"), dag, task_group)
         if task_group:
diff --git a/tests/decorators/test_mapped.py b/tests/decorators/test_mapped.py
index 3812367425..2d3747b5f3 100644
--- a/tests/decorators/test_mapped.py
+++ b/tests/decorators/test_mapped.py
@@ -17,6 +17,9 @@
 # under the License.
 from __future__ import annotations
 
+import pytest
+
+from airflow.decorators import task
 from airflow.models.dag import DAG
 from airflow.utils.task_group import TaskGroup
 from tests.models import DEFAULT_DATE
@@ -36,3 +39,24 @@ def test_mapped_task_group_id_prefix_task_id():
 
     dag.get_task("t1") == x1.operator
     dag.get_task("g.t2") == x2.operator
+
+
+@pytest.mark.db_test
+def test_mapped_task_with_arbitrary_default_args(dag_maker, session):
+    default_args = {"some": "value", "not": "in", "the": "task", "or": "dag"}
+    with dag_maker(session=session, default_args=default_args):
+
+        @task.python(do_xcom_push=True)
+        def f(x: int, y: int) -> int:
+            return x + y
+
+        f.partial(y=10).expand(x=[1, 2, 3])
+
+    dag_run = dag_maker.create_dagrun(session=session)
+    decision = dag_run.task_instance_scheduling_decisions(session=session)
+    xcoms = set()
+    for ti in decision.schedulable_tis:
+        ti.run(session=session)
+        xcoms.add(ti.xcom_pull(session=session, task_ids=ti.task_id, map_indexes=ti.map_index))
+
+    assert xcoms == {11, 12, 13}

Reply via email to