This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new c8ce604739 fix: task flow dynamic mapping with default_args (#41592)
c8ce604739 is described below
commit c8ce604739bcbb15865e2d5d198fc5d83f562668
Author: phi-friday <[email protected]>
AuthorDate: Wed Oct 2 10:13:38 2024 +0900
fix: task flow dynamic mapping with default_args (#41592)
---
airflow/decorators/base.py | 25 ++++++++++++++++++-------
tests/decorators/test_mapped.py | 24 ++++++++++++++++++++++++
2 files changed, 42 insertions(+), 7 deletions(-)
diff --git a/airflow/decorators/base.py b/airflow/decorators/base.py
index 1ef2c12c70..e650c1920a 100644
--- a/airflow/decorators/base.py
+++ b/airflow/decorators/base.py
@@ -431,18 +431,29 @@ class _TaskDecorator(ExpandableFactory, Generic[FParams, FReturn, OperatorSubcla
dag = task_kwargs.pop("dag", None) or DagContext.get_current_dag()
task_group = task_kwargs.pop("task_group", None) or TaskGroupContext.get_current_task_group(dag)
- partial_kwargs, partial_params = get_merged_defaults(
+ default_args, partial_params = get_merged_defaults(
dag=dag,
task_group=task_group,
task_params=task_kwargs.pop("params", None),
task_default_args=task_kwargs.pop("default_args", None),
)
- partial_kwargs.update(
- task_kwargs,
- is_setup=self.is_setup,
- is_teardown=self.is_teardown,
- on_failure_fail_dagrun=self.on_failure_fail_dagrun,
- )
+ partial_kwargs: dict[str, Any] = {
+ "is_setup": self.is_setup,
+ "is_teardown": self.is_teardown,
+ "on_failure_fail_dagrun": self.on_failure_fail_dagrun,
+ }
+ base_signature = inspect.signature(BaseOperator)
+ ignore = {
+ "default_args", # This is target we are working on now.
+ "kwargs", # A common name for a keyword argument.
+ "do_xcom_push", # In the same boat as `multiple_outputs`
+ "multiple_outputs", # We will use `self.multiple_outputs` instead.
+ "params", # Already handled above `partial_params`.
+ "task_concurrency", # Deprecated(replaced by `max_active_tis_per_dag`).
+ }
+ partial_keys = set(base_signature.parameters) - ignore
+ partial_kwargs.update({key: value for key, value in default_args.items() if key in partial_keys})
+ partial_kwargs.update(task_kwargs)
task_id = get_unique_task_id(partial_kwargs.pop("task_id"), dag, task_group)
if task_group:
diff --git a/tests/decorators/test_mapped.py b/tests/decorators/test_mapped.py
index 3812367425..2d3747b5f3 100644
--- a/tests/decorators/test_mapped.py
+++ b/tests/decorators/test_mapped.py
@@ -17,6 +17,9 @@
# under the License.
from __future__ import annotations
+import pytest
+
+from airflow.decorators import task
from airflow.models.dag import DAG
from airflow.utils.task_group import TaskGroup
from tests.models import DEFAULT_DATE
@@ -36,3 +39,24 @@ def test_mapped_task_group_id_prefix_task_id():
dag.get_task("t1") == x1.operator
dag.get_task("g.t2") == x2.operator
+
+
+@pytest.mark.db_test
+def test_mapped_task_with_arbitrary_default_args(dag_maker, session):
+ default_args = {"some": "value", "not": "in", "the": "task", "or": "dag"}
+ with dag_maker(session=session, default_args=default_args):
+
+ @task.python(do_xcom_push=True)
+ def f(x: int, y: int) -> int:
+ return x + y
+
+ f.partial(y=10).expand(x=[1, 2, 3])
+
+ dag_run = dag_maker.create_dagrun(session=session)
+ decision = dag_run.task_instance_scheduling_decisions(session=session)
+ xcoms = set()
+ for ti in decision.schedulable_tis:
+ ti.run(session=session)
+ xcoms.add(ti.xcom_pull(session=session, task_ids=ti.task_id, map_indexes=ti.map_index))
+
+ assert xcoms == {11, 12, 13}