This is an automated email from the ASF dual-hosted git repository.

ephraimanierobi pushed a commit to branch v2-6-test
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit 4fac9455cd980d66926e004d4c7c921b83084036
Author: Hussein Awala <[email protected]>
AuthorDate: Fri Apr 14 14:16:11 2023 +0200

    Fix mapped tasks partial arguments when DAG default args are provided 
(#29913)
    
    * Add a failing test to make it pass
    
    * use partial_kwargs when they are provided and override only None values with 
dag default values
    
    * update the test and check if the values are filled in the right order
    
    * fix overriding retry_delay with default value when it is equal to 0
    
    * add missing default value for inlets and outlets
    
    * set partial_kwargs dict type to dict[str, Any] and remove type ignore 
comments
    
    * create a dict for default values and use NOTSET instead of None to 
support None as an accepted value
    
    * update partial typing by removing None type from some args and setting 
NOTSET as the default for all args
    
    * Tweak kwarg merging slightly
    
    This should improve iteration a bit, I think.
    
    * Fix unit tests
    
    ---------
    
    Co-authored-by: Tzu-ping Chung <[email protected]>
    (cherry picked from commit f01051a75e217d5f20394b8c890425915383101f)
---
 airflow/models/baseoperator.py      | 187 +++++++++++++++++++++---------------
 tests/models/test_mappedoperator.py |  14 +++
 2 files changed, 122 insertions(+), 79 deletions(-)

diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py
index d81d45dc7e..37106c580f 100644
--- a/airflow/models/baseoperator.py
+++ b/airflow/models/baseoperator.py
@@ -90,6 +90,7 @@ from airflow.utils.operator_resources import Resources
 from airflow.utils.session import NEW_SESSION, provide_session
 from airflow.utils.setup_teardown import SetupTeardownContext
 from airflow.utils.trigger_rule import TriggerRule
+from airflow.utils.types import NOTSET, ArgNotSet
 from airflow.utils.weight_rule import WeightRule
 from airflow.utils.xcom import XCOM_RETURN_KEY
 
@@ -184,6 +185,26 @@ class _PartialDescriptor:
         return self.class_method.__get__(cls, cls)
 
 
+_PARTIAL_DEFAULTS = {
+    "owner": DEFAULT_OWNER,
+    "trigger_rule": DEFAULT_TRIGGER_RULE,
+    "depends_on_past": False,
+    "ignore_first_depends_on_past": DEFAULT_IGNORE_FIRST_DEPENDS_ON_PAST,
+    "wait_for_past_depends_before_skipping": 
DEFAULT_WAIT_FOR_PAST_DEPENDS_BEFORE_SKIPPING,
+    "wait_for_downstream": False,
+    "retries": DEFAULT_RETRIES,
+    "queue": DEFAULT_QUEUE,
+    "pool_slots": DEFAULT_POOL_SLOTS,
+    "execution_timeout": DEFAULT_TASK_EXECUTION_TIMEOUT,
+    "retry_delay": DEFAULT_RETRY_DELAY,
+    "retry_exponential_backoff": False,
+    "priority_weight": DEFAULT_PRIORITY_WEIGHT,
+    "weight_rule": DEFAULT_WEIGHT_RULE,
+    "inlets": [],
+    "outlets": [],
+}
+
+
 # This is what handles the actual mapping.
 def partial(
     operator_class: type[BaseOperator],
@@ -191,43 +212,43 @@ def partial(
     task_id: str,
     dag: DAG | None = None,
     task_group: TaskGroup | None = None,
-    start_date: datetime | None = None,
-    end_date: datetime | None = None,
-    owner: str = DEFAULT_OWNER,
-    email: None | str | Iterable[str] = None,
+    start_date: datetime | ArgNotSet = NOTSET,
+    end_date: datetime | ArgNotSet = NOTSET,
+    owner: str | ArgNotSet = NOTSET,
+    email: None | str | Iterable[str] | ArgNotSet = NOTSET,
     params: collections.abc.MutableMapping | None = None,
-    resources: dict[str, Any] | None = None,
-    trigger_rule: str = DEFAULT_TRIGGER_RULE,
-    depends_on_past: bool = False,
-    ignore_first_depends_on_past: bool = DEFAULT_IGNORE_FIRST_DEPENDS_ON_PAST,
-    wait_for_past_depends_before_skipping: bool = 
DEFAULT_WAIT_FOR_PAST_DEPENDS_BEFORE_SKIPPING,
-    wait_for_downstream: bool = False,
-    retries: int | None = DEFAULT_RETRIES,
-    queue: str = DEFAULT_QUEUE,
-    pool: str | None = None,
-    pool_slots: int = DEFAULT_POOL_SLOTS,
-    execution_timeout: timedelta | None = DEFAULT_TASK_EXECUTION_TIMEOUT,
-    max_retry_delay: None | timedelta | float = None,
-    retry_delay: timedelta | float = DEFAULT_RETRY_DELAY,
-    retry_exponential_backoff: bool = False,
-    priority_weight: int = DEFAULT_PRIORITY_WEIGHT,
-    weight_rule: str = DEFAULT_WEIGHT_RULE,
-    sla: timedelta | None = None,
-    max_active_tis_per_dag: int | None = None,
-    max_active_tis_per_dagrun: int | None = None,
-    on_execute_callback: None | TaskStateChangeCallback | 
list[TaskStateChangeCallback] = None,
-    on_failure_callback: None | TaskStateChangeCallback | 
list[TaskStateChangeCallback] = None,
-    on_success_callback: None | TaskStateChangeCallback | 
list[TaskStateChangeCallback] = None,
-    on_retry_callback: None | TaskStateChangeCallback | 
list[TaskStateChangeCallback] = None,
-    run_as_user: str | None = None,
-    executor_config: dict | None = None,
-    inlets: Any | None = None,
-    outlets: Any | None = None,
-    doc: str | None = None,
-    doc_md: str | None = None,
-    doc_json: str | None = None,
-    doc_yaml: str | None = None,
-    doc_rst: str | None = None,
+    resources: dict[str, Any] | None | ArgNotSet = NOTSET,
+    trigger_rule: str | ArgNotSet = NOTSET,
+    depends_on_past: bool | ArgNotSet = NOTSET,
+    ignore_first_depends_on_past: bool | ArgNotSet = NOTSET,
+    wait_for_past_depends_before_skipping: bool | ArgNotSet = NOTSET,
+    wait_for_downstream: bool | ArgNotSet = NOTSET,
+    retries: int | None | ArgNotSet = NOTSET,
+    queue: str | ArgNotSet = NOTSET,
+    pool: str | ArgNotSet = NOTSET,
+    pool_slots: int | ArgNotSet = NOTSET,
+    execution_timeout: timedelta | None | ArgNotSet = NOTSET,
+    max_retry_delay: None | timedelta | float | ArgNotSet = NOTSET,
+    retry_delay: timedelta | float | ArgNotSet = NOTSET,
+    retry_exponential_backoff: bool | ArgNotSet = NOTSET,
+    priority_weight: int | ArgNotSet = NOTSET,
+    weight_rule: str | ArgNotSet = NOTSET,
+    sla: timedelta | None | ArgNotSet = NOTSET,
+    max_active_tis_per_dag: int | None | ArgNotSet = NOTSET,
+    max_active_tis_per_dagrun: int | None | ArgNotSet = NOTSET,
+    on_execute_callback: None | TaskStateChangeCallback | 
list[TaskStateChangeCallback] | ArgNotSet = NOTSET,
+    on_failure_callback: None | TaskStateChangeCallback | 
list[TaskStateChangeCallback] | ArgNotSet = NOTSET,
+    on_success_callback: None | TaskStateChangeCallback | 
list[TaskStateChangeCallback] | ArgNotSet = NOTSET,
+    on_retry_callback: None | TaskStateChangeCallback | 
list[TaskStateChangeCallback] | ArgNotSet = NOTSET,
+    run_as_user: str | None | ArgNotSet = NOTSET,
+    executor_config: dict | None | ArgNotSet = NOTSET,
+    inlets: Any | None | ArgNotSet = NOTSET,
+    outlets: Any | None | ArgNotSet = NOTSET,
+    doc: str | None | ArgNotSet = NOTSET,
+    doc_md: str | None | ArgNotSet = NOTSET,
+    doc_json: str | None | ArgNotSet = NOTSET,
+    doc_yaml: str | None | ArgNotSet = NOTSET,
+    doc_rst: str | None | ArgNotSet = NOTSET,
     **kwargs,
 ) -> OperatorPartial:
     from airflow.models.dag import DagContext
@@ -242,54 +263,62 @@ def partial(
         task_id = task_group.child_id(task_id)
 
     # Merge DAG and task group level defaults into user-supplied values.
-    partial_kwargs, partial_params = get_merged_defaults(
+    dag_default_args, partial_params = get_merged_defaults(
         dag=dag,
         task_group=task_group,
         task_params=params,
         task_default_args=kwargs.pop("default_args", None),
     )
-    partial_kwargs.update(kwargs)
-
-    # Always fully populate partial kwargs to exclude them from map().
-    partial_kwargs.setdefault("dag", dag)
-    partial_kwargs.setdefault("task_group", task_group)
-    partial_kwargs.setdefault("task_id", task_id)
-    partial_kwargs.setdefault("start_date", start_date)
-    partial_kwargs.setdefault("end_date", end_date)
-    partial_kwargs.setdefault("owner", owner)
-    partial_kwargs.setdefault("email", email)
-    partial_kwargs.setdefault("trigger_rule", trigger_rule)
-    partial_kwargs.setdefault("depends_on_past", depends_on_past)
-    partial_kwargs.setdefault("ignore_first_depends_on_past", 
ignore_first_depends_on_past)
-    partial_kwargs.setdefault("wait_for_past_depends_before_skipping", 
wait_for_past_depends_before_skipping)
-    partial_kwargs.setdefault("wait_for_downstream", wait_for_downstream)
-    partial_kwargs.setdefault("retries", retries)
-    partial_kwargs.setdefault("queue", queue)
-    partial_kwargs.setdefault("pool", pool)
-    partial_kwargs.setdefault("pool_slots", pool_slots)
-    partial_kwargs.setdefault("execution_timeout", execution_timeout)
-    partial_kwargs.setdefault("max_retry_delay", max_retry_delay)
-    partial_kwargs.setdefault("retry_delay", retry_delay)
-    partial_kwargs.setdefault("retry_exponential_backoff", 
retry_exponential_backoff)
-    partial_kwargs.setdefault("priority_weight", priority_weight)
-    partial_kwargs.setdefault("weight_rule", weight_rule)
-    partial_kwargs.setdefault("sla", sla)
-    partial_kwargs.setdefault("max_active_tis_per_dag", max_active_tis_per_dag)
-    partial_kwargs.setdefault("max_active_tis_per_dagrun", 
max_active_tis_per_dagrun)
-    partial_kwargs.setdefault("on_execute_callback", on_execute_callback)
-    partial_kwargs.setdefault("on_failure_callback", on_failure_callback)
-    partial_kwargs.setdefault("on_retry_callback", on_retry_callback)
-    partial_kwargs.setdefault("on_success_callback", on_success_callback)
-    partial_kwargs.setdefault("run_as_user", run_as_user)
-    partial_kwargs.setdefault("executor_config", executor_config)
-    partial_kwargs.setdefault("inlets", inlets or [])
-    partial_kwargs.setdefault("outlets", outlets or [])
-    partial_kwargs.setdefault("resources", resources)
-    partial_kwargs.setdefault("doc", doc)
-    partial_kwargs.setdefault("doc_json", doc_json)
-    partial_kwargs.setdefault("doc_md", doc_md)
-    partial_kwargs.setdefault("doc_rst", doc_rst)
-    partial_kwargs.setdefault("doc_yaml", doc_yaml)
+
+    # Create partial_kwargs from args and kwargs
+    partial_kwargs: dict[str, Any] = {
+        **kwargs,
+        "dag": dag,
+        "task_group": task_group,
+        "task_id": task_id,
+        "start_date": start_date,
+        "end_date": end_date,
+        "owner": owner,
+        "email": email,
+        "trigger_rule": trigger_rule,
+        "depends_on_past": depends_on_past,
+        "ignore_first_depends_on_past": ignore_first_depends_on_past,
+        "wait_for_past_depends_before_skipping": 
wait_for_past_depends_before_skipping,
+        "wait_for_downstream": wait_for_downstream,
+        "retries": retries,
+        "queue": queue,
+        "pool": pool,
+        "pool_slots": pool_slots,
+        "execution_timeout": execution_timeout,
+        "max_retry_delay": max_retry_delay,
+        "retry_delay": retry_delay,
+        "retry_exponential_backoff": retry_exponential_backoff,
+        "priority_weight": priority_weight,
+        "weight_rule": weight_rule,
+        "sla": sla,
+        "max_active_tis_per_dag": max_active_tis_per_dag,
+        "max_active_tis_per_dagrun": max_active_tis_per_dagrun,
+        "on_execute_callback": on_execute_callback,
+        "on_failure_callback": on_failure_callback,
+        "on_retry_callback": on_retry_callback,
+        "on_success_callback": on_success_callback,
+        "run_as_user": run_as_user,
+        "executor_config": executor_config,
+        "inlets": inlets,
+        "outlets": outlets,
+        "resources": resources,
+        "doc": doc,
+        "doc_json": doc_json,
+        "doc_md": doc_md,
+        "doc_rst": doc_rst,
+        "doc_yaml": doc_yaml,
+    }
+
+    # Inject DAG-level default args into args provided to this function.
+    partial_kwargs.update((k, v) for k, v in dag_default_args.items() if 
partial_kwargs.get(k) is NOTSET)
+
+    # Fill fields not provided by the user with default values.
+    partial_kwargs = {k: _PARTIAL_DEFAULTS.get(k) if v is NOTSET else v for k, 
v in partial_kwargs.items()}
 
     # Post-process arguments. Should be kept in sync with 
_TaskDecorator.expand().
     if "task_concurrency" in kwargs:  # Reject deprecated option.
diff --git a/tests/models/test_mappedoperator.py 
b/tests/models/test_mappedoperator.py
index 931d262d24..84ddd9fb66 100644
--- a/tests/models/test_mappedoperator.py
+++ b/tests/models/test_mappedoperator.py
@@ -85,6 +85,20 @@ def test_task_mapping_default_args():
     assert mapped.start_date == pendulum.instance(default_args["start_date"])
 
 
+def test_task_mapping_override_default_args():
+    default_args = {"retries": 2, "start_date": DEFAULT_DATE.now()}
+    with DAG("test-dag", start_date=DEFAULT_DATE, default_args=default_args):
+        literal = ["a", "b", "c"]
+        mapped = MockOperator.partial(task_id="task", 
retries=1).expand(arg2=literal)
+
+    # retries should be 1 because it is provided as a partial arg
+    assert mapped.partial_kwargs["retries"] == 1
+    # start_date should be equal to default_args["start_date"] because it is 
not provided as partial arg
+    assert mapped.start_date == pendulum.instance(default_args["start_date"])
+    # owner should be equal to Airflow default owner (airflow) because it is 
not provided at all
+    assert mapped.owner == "airflow"
+
+
 def test_map_unknown_arg_raises():
     with pytest.raises(TypeError, match=r"argument 'file'"):
         BaseOperator.partial(task_id="a").expand(file=[1, 2, {"a": "b"}])

Reply via email to