This is an automated email from the ASF dual-hosted git repository.
husseinawala pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new f01051a75e Fix mapped tasks partial arguments when DAG default args
are provided (#29913)
f01051a75e is described below
commit f01051a75e217d5f20394b8c890425915383101f
Author: Hussein Awala <[email protected]>
AuthorDate: Fri Apr 14 14:16:11 2023 +0200
Fix mapped tasks partial arguments when DAG default args are provided
(#29913)
* Add a failing test to make it pass
* use partial_kwargs when they are provided and override only None values by
dag default values
* update the test and check if the values are filled in the right order
* fix overriding retry_delay with default value when it is equal to 0
* add missing default value for inlets and outlets
* set partial_kwargs dict type to dict[str, Any] and remove type ignore
comments
* create a dict for default values and use NotSet instead of None to
support None as accepted value
* update partial typing by removing None type from some args and set NotSet
for all args
* Tweak kwarg merging slightly
This should improve iteration a bit, I think.
* Fix unit tests
---------
Co-authored-by: Tzu-ping Chung <[email protected]>
---
airflow/models/baseoperator.py | 187 +++++++++++++++++++++---------------
tests/models/test_mappedoperator.py | 14 +++
2 files changed, 122 insertions(+), 79 deletions(-)
diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py
index d81d45dc7e..37106c580f 100644
--- a/airflow/models/baseoperator.py
+++ b/airflow/models/baseoperator.py
@@ -90,6 +90,7 @@ from airflow.utils.operator_resources import Resources
from airflow.utils.session import NEW_SESSION, provide_session
from airflow.utils.setup_teardown import SetupTeardownContext
from airflow.utils.trigger_rule import TriggerRule
+from airflow.utils.types import NOTSET, ArgNotSet
from airflow.utils.weight_rule import WeightRule
from airflow.utils.xcom import XCOM_RETURN_KEY
@@ -184,6 +185,26 @@ class _PartialDescriptor:
return self.class_method.__get__(cls, cls)
+_PARTIAL_DEFAULTS = {
+ "owner": DEFAULT_OWNER,
+ "trigger_rule": DEFAULT_TRIGGER_RULE,
+ "depends_on_past": False,
+ "ignore_first_depends_on_past": DEFAULT_IGNORE_FIRST_DEPENDS_ON_PAST,
+ "wait_for_past_depends_before_skipping":
DEFAULT_WAIT_FOR_PAST_DEPENDS_BEFORE_SKIPPING,
+ "wait_for_downstream": False,
+ "retries": DEFAULT_RETRIES,
+ "queue": DEFAULT_QUEUE,
+ "pool_slots": DEFAULT_POOL_SLOTS,
+ "execution_timeout": DEFAULT_TASK_EXECUTION_TIMEOUT,
+ "retry_delay": DEFAULT_RETRY_DELAY,
+ "retry_exponential_backoff": False,
+ "priority_weight": DEFAULT_PRIORITY_WEIGHT,
+ "weight_rule": DEFAULT_WEIGHT_RULE,
+ "inlets": [],
+ "outlets": [],
+}
+
+
# This is what handles the actual mapping.
def partial(
operator_class: type[BaseOperator],
@@ -191,43 +212,43 @@ def partial(
task_id: str,
dag: DAG | None = None,
task_group: TaskGroup | None = None,
- start_date: datetime | None = None,
- end_date: datetime | None = None,
- owner: str = DEFAULT_OWNER,
- email: None | str | Iterable[str] = None,
+ start_date: datetime | ArgNotSet = NOTSET,
+ end_date: datetime | ArgNotSet = NOTSET,
+ owner: str | ArgNotSet = NOTSET,
+ email: None | str | Iterable[str] | ArgNotSet = NOTSET,
params: collections.abc.MutableMapping | None = None,
- resources: dict[str, Any] | None = None,
- trigger_rule: str = DEFAULT_TRIGGER_RULE,
- depends_on_past: bool = False,
- ignore_first_depends_on_past: bool = DEFAULT_IGNORE_FIRST_DEPENDS_ON_PAST,
- wait_for_past_depends_before_skipping: bool =
DEFAULT_WAIT_FOR_PAST_DEPENDS_BEFORE_SKIPPING,
- wait_for_downstream: bool = False,
- retries: int | None = DEFAULT_RETRIES,
- queue: str = DEFAULT_QUEUE,
- pool: str | None = None,
- pool_slots: int = DEFAULT_POOL_SLOTS,
- execution_timeout: timedelta | None = DEFAULT_TASK_EXECUTION_TIMEOUT,
- max_retry_delay: None | timedelta | float = None,
- retry_delay: timedelta | float = DEFAULT_RETRY_DELAY,
- retry_exponential_backoff: bool = False,
- priority_weight: int = DEFAULT_PRIORITY_WEIGHT,
- weight_rule: str = DEFAULT_WEIGHT_RULE,
- sla: timedelta | None = None,
- max_active_tis_per_dag: int | None = None,
- max_active_tis_per_dagrun: int | None = None,
- on_execute_callback: None | TaskStateChangeCallback |
list[TaskStateChangeCallback] = None,
- on_failure_callback: None | TaskStateChangeCallback |
list[TaskStateChangeCallback] = None,
- on_success_callback: None | TaskStateChangeCallback |
list[TaskStateChangeCallback] = None,
- on_retry_callback: None | TaskStateChangeCallback |
list[TaskStateChangeCallback] = None,
- run_as_user: str | None = None,
- executor_config: dict | None = None,
- inlets: Any | None = None,
- outlets: Any | None = None,
- doc: str | None = None,
- doc_md: str | None = None,
- doc_json: str | None = None,
- doc_yaml: str | None = None,
- doc_rst: str | None = None,
+ resources: dict[str, Any] | None | ArgNotSet = NOTSET,
+ trigger_rule: str | ArgNotSet = NOTSET,
+ depends_on_past: bool | ArgNotSet = NOTSET,
+ ignore_first_depends_on_past: bool | ArgNotSet = NOTSET,
+ wait_for_past_depends_before_skipping: bool | ArgNotSet = NOTSET,
+ wait_for_downstream: bool | ArgNotSet = NOTSET,
+ retries: int | None | ArgNotSet = NOTSET,
+ queue: str | ArgNotSet = NOTSET,
+ pool: str | ArgNotSet = NOTSET,
+ pool_slots: int | ArgNotSet = NOTSET,
+ execution_timeout: timedelta | None | ArgNotSet = NOTSET,
+ max_retry_delay: None | timedelta | float | ArgNotSet = NOTSET,
+ retry_delay: timedelta | float | ArgNotSet = NOTSET,
+ retry_exponential_backoff: bool | ArgNotSet = NOTSET,
+ priority_weight: int | ArgNotSet = NOTSET,
+ weight_rule: str | ArgNotSet = NOTSET,
+ sla: timedelta | None | ArgNotSet = NOTSET,
+ max_active_tis_per_dag: int | None | ArgNotSet = NOTSET,
+ max_active_tis_per_dagrun: int | None | ArgNotSet = NOTSET,
+ on_execute_callback: None | TaskStateChangeCallback |
list[TaskStateChangeCallback] | ArgNotSet = NOTSET,
+ on_failure_callback: None | TaskStateChangeCallback |
list[TaskStateChangeCallback] | ArgNotSet = NOTSET,
+ on_success_callback: None | TaskStateChangeCallback |
list[TaskStateChangeCallback] | ArgNotSet = NOTSET,
+ on_retry_callback: None | TaskStateChangeCallback |
list[TaskStateChangeCallback] | ArgNotSet = NOTSET,
+ run_as_user: str | None | ArgNotSet = NOTSET,
+ executor_config: dict | None | ArgNotSet = NOTSET,
+ inlets: Any | None | ArgNotSet = NOTSET,
+ outlets: Any | None | ArgNotSet = NOTSET,
+ doc: str | None | ArgNotSet = NOTSET,
+ doc_md: str | None | ArgNotSet = NOTSET,
+ doc_json: str | None | ArgNotSet = NOTSET,
+ doc_yaml: str | None | ArgNotSet = NOTSET,
+ doc_rst: str | None | ArgNotSet = NOTSET,
**kwargs,
) -> OperatorPartial:
from airflow.models.dag import DagContext
@@ -242,54 +263,62 @@ def partial(
task_id = task_group.child_id(task_id)
# Merge DAG and task group level defaults into user-supplied values.
- partial_kwargs, partial_params = get_merged_defaults(
+ dag_default_args, partial_params = get_merged_defaults(
dag=dag,
task_group=task_group,
task_params=params,
task_default_args=kwargs.pop("default_args", None),
)
- partial_kwargs.update(kwargs)
-
- # Always fully populate partial kwargs to exclude them from map().
- partial_kwargs.setdefault("dag", dag)
- partial_kwargs.setdefault("task_group", task_group)
- partial_kwargs.setdefault("task_id", task_id)
- partial_kwargs.setdefault("start_date", start_date)
- partial_kwargs.setdefault("end_date", end_date)
- partial_kwargs.setdefault("owner", owner)
- partial_kwargs.setdefault("email", email)
- partial_kwargs.setdefault("trigger_rule", trigger_rule)
- partial_kwargs.setdefault("depends_on_past", depends_on_past)
- partial_kwargs.setdefault("ignore_first_depends_on_past",
ignore_first_depends_on_past)
- partial_kwargs.setdefault("wait_for_past_depends_before_skipping",
wait_for_past_depends_before_skipping)
- partial_kwargs.setdefault("wait_for_downstream", wait_for_downstream)
- partial_kwargs.setdefault("retries", retries)
- partial_kwargs.setdefault("queue", queue)
- partial_kwargs.setdefault("pool", pool)
- partial_kwargs.setdefault("pool_slots", pool_slots)
- partial_kwargs.setdefault("execution_timeout", execution_timeout)
- partial_kwargs.setdefault("max_retry_delay", max_retry_delay)
- partial_kwargs.setdefault("retry_delay", retry_delay)
- partial_kwargs.setdefault("retry_exponential_backoff",
retry_exponential_backoff)
- partial_kwargs.setdefault("priority_weight", priority_weight)
- partial_kwargs.setdefault("weight_rule", weight_rule)
- partial_kwargs.setdefault("sla", sla)
- partial_kwargs.setdefault("max_active_tis_per_dag", max_active_tis_per_dag)
- partial_kwargs.setdefault("max_active_tis_per_dagrun",
max_active_tis_per_dagrun)
- partial_kwargs.setdefault("on_execute_callback", on_execute_callback)
- partial_kwargs.setdefault("on_failure_callback", on_failure_callback)
- partial_kwargs.setdefault("on_retry_callback", on_retry_callback)
- partial_kwargs.setdefault("on_success_callback", on_success_callback)
- partial_kwargs.setdefault("run_as_user", run_as_user)
- partial_kwargs.setdefault("executor_config", executor_config)
- partial_kwargs.setdefault("inlets", inlets or [])
- partial_kwargs.setdefault("outlets", outlets or [])
- partial_kwargs.setdefault("resources", resources)
- partial_kwargs.setdefault("doc", doc)
- partial_kwargs.setdefault("doc_json", doc_json)
- partial_kwargs.setdefault("doc_md", doc_md)
- partial_kwargs.setdefault("doc_rst", doc_rst)
- partial_kwargs.setdefault("doc_yaml", doc_yaml)
+
+ # Create partial_kwargs from args and kwargs
+ partial_kwargs: dict[str, Any] = {
+ **kwargs,
+ "dag": dag,
+ "task_group": task_group,
+ "task_id": task_id,
+ "start_date": start_date,
+ "end_date": end_date,
+ "owner": owner,
+ "email": email,
+ "trigger_rule": trigger_rule,
+ "depends_on_past": depends_on_past,
+ "ignore_first_depends_on_past": ignore_first_depends_on_past,
+ "wait_for_past_depends_before_skipping":
wait_for_past_depends_before_skipping,
+ "wait_for_downstream": wait_for_downstream,
+ "retries": retries,
+ "queue": queue,
+ "pool": pool,
+ "pool_slots": pool_slots,
+ "execution_timeout": execution_timeout,
+ "max_retry_delay": max_retry_delay,
+ "retry_delay": retry_delay,
+ "retry_exponential_backoff": retry_exponential_backoff,
+ "priority_weight": priority_weight,
+ "weight_rule": weight_rule,
+ "sla": sla,
+ "max_active_tis_per_dag": max_active_tis_per_dag,
+ "max_active_tis_per_dagrun": max_active_tis_per_dagrun,
+ "on_execute_callback": on_execute_callback,
+ "on_failure_callback": on_failure_callback,
+ "on_retry_callback": on_retry_callback,
+ "on_success_callback": on_success_callback,
+ "run_as_user": run_as_user,
+ "executor_config": executor_config,
+ "inlets": inlets,
+ "outlets": outlets,
+ "resources": resources,
+ "doc": doc,
+ "doc_json": doc_json,
+ "doc_md": doc_md,
+ "doc_rst": doc_rst,
+ "doc_yaml": doc_yaml,
+ }
+
+ # Inject DAG-level default args into args provided to this function.
+ partial_kwargs.update((k, v) for k, v in dag_default_args.items() if
partial_kwargs.get(k) is NOTSET)
+
+ # Fill fields not provided by the user with default values.
+ partial_kwargs = {k: _PARTIAL_DEFAULTS.get(k) if v is NOTSET else v for k,
v in partial_kwargs.items()}
# Post-process arguments. Should be kept in sync with
_TaskDecorator.expand().
if "task_concurrency" in kwargs: # Reject deprecated option.
diff --git a/tests/models/test_mappedoperator.py
b/tests/models/test_mappedoperator.py
index 931d262d24..84ddd9fb66 100644
--- a/tests/models/test_mappedoperator.py
+++ b/tests/models/test_mappedoperator.py
@@ -85,6 +85,20 @@ def test_task_mapping_default_args():
assert mapped.start_date == pendulum.instance(default_args["start_date"])
+def test_task_mapping_override_default_args():
+ default_args = {"retries": 2, "start_date": DEFAULT_DATE.now()}
+ with DAG("test-dag", start_date=DEFAULT_DATE, default_args=default_args):
+ literal = ["a", "b", "c"]
+ mapped = MockOperator.partial(task_id="task",
retries=1).expand(arg2=literal)
+
+ # retries should be 1 because it is provided as a partial arg
+ assert mapped.partial_kwargs["retries"] == 1
+ # start_date should be equal to default_args["start_date"] because it is
not provided as partial arg
+ assert mapped.start_date == pendulum.instance(default_args["start_date"])
+ # owner should be equal to Airflow default owner (airflow) because it is
not provided at all
+ assert mapped.owner == "airflow"
+
+
def test_map_unknown_arg_raises():
with pytest.raises(TypeError, match=r"argument 'file'"):
BaseOperator.partial(task_id="a").expand(file=[1, 2, {"a": "b"}])