This is an automated email from the ASF dual-hosted git repository.

ash pushed a commit to branch task-sdk-first-code
in repository https://gitbox.apache.org/repos/asf/airflow.git
commit d1d891a178e8d6b9993190c12d416962e616255a Author: Ash Berlin-Taylor <[email protected]> AuthorDate: Sat Oct 19 12:46:54 2024 +0100 [skip ci] --- airflow/models/abstractoperator.py | 118 ------------- airflow/models/baseoperator.py | 41 ----- airflow/models/dag.py | 10 +- airflow/serialization/schema.json | 10 +- airflow/serialization/serialized_objects.py | 22 +-- airflow/utils/decorators.py | 9 +- .../airflow/sdk/definitions/abstractoperator.py | 106 +++++++++++ .../src/airflow/sdk/definitions/baseoperator.py | 18 ++ task_sdk/src/airflow/sdk/definitions/dag.py | 196 +++++++++------------ task_sdk/tests/defintions/test_dag.py | 72 +++++++- tests/models/test_dag.py | 67 ------- 11 files changed, 314 insertions(+), 355 deletions(-) diff --git a/airflow/models/abstractoperator.py b/airflow/models/abstractoperator.py index a29ef09f270..a27d7e26fd1 100644 --- a/airflow/models/abstractoperator.py +++ b/airflow/models/abstractoperator.py @@ -106,22 +106,6 @@ class AbstractOperator(Templater, TaskSDKAbstractOperator): trigger_rule: TriggerRule weight_rule: PriorityWeightStrategy - @property - def is_setup(self) -> bool: - raise NotImplementedError() - - @is_setup.setter - def is_setup(self, value: bool) -> None: - raise NotImplementedError() - - @property - def is_teardown(self) -> bool: - raise NotImplementedError() - - @is_teardown.setter - def is_teardown(self, value: bool) -> None: - raise NotImplementedError() - @property def on_failure_fail_dagrun(self): """ @@ -211,108 +195,6 @@ class AbstractOperator(Templater, TaskSDKAbstractOperator): else: setattr(parent, attr_name, rendered_content) - def as_setup(self): - self.is_setup = True - return self - - def as_teardown( - self, - *, - setups: BaseOperator | Iterable[BaseOperator] | ArgNotSet = NOTSET, - on_failure_fail_dagrun=NOTSET, - ): - self.is_teardown = True - self.trigger_rule = TriggerRule.ALL_DONE_SETUP_SUCCESS - if on_failure_fail_dagrun is not NOTSET: - self.on_failure_fail_dagrun = on_failure_fail_dagrun - if not isinstance(setups, ArgNotSet): - setups = [setups] if isinstance(setups, DependencyMixin) else setups - for s in setups: - s.is_setup = True - s >> self - return self - - def get_flat_relative_ids(self, *, upstream: bool = False) -> set[str]: - """ - Get a flat set of relative IDs, upstream or downstream. - - Will recurse each relative found in the direction specified. - - :param upstream: Whether to look for upstream or downstream relatives. - """ - dag = self.get_dag() - if not dag: - return set() - - relatives: set[str] = set() - - # This is intentionally implemented as a loop, instead of calling - # get_direct_relative_ids() recursively, since Python has significant - # limitation on stack level, and a recursive implementation can blow up - # if a DAG contains very long routes. 
- task_ids_to_trace = self.get_direct_relative_ids(upstream) - while task_ids_to_trace: - task_ids_to_trace_next: set[str] = set() - for task_id in task_ids_to_trace: - if task_id in relatives: - continue - task_ids_to_trace_next.update(dag.task_dict[task_id].get_direct_relative_ids(upstream)) - relatives.add(task_id) - task_ids_to_trace = task_ids_to_trace_next - - return relatives - - def get_flat_relatives(self, upstream: bool = False) -> Collection[Operator]: - """Get a flat list of relatives, either upstream or downstream.""" - dag = self.get_dag() - if not dag: - return set() - return [dag.task_dict[task_id] for task_id in self.get_flat_relative_ids(upstream=upstream)] - - def get_upstreams_follow_setups(self) -> Iterable[Operator]: - """All upstreams and, for each upstream setup, its respective teardowns.""" - for task in self.get_flat_relatives(upstream=True): - yield task - if task.is_setup: - for t in task.downstream_list: - if t.is_teardown and t != self: - yield t - - def get_upstreams_only_setups_and_teardowns(self) -> Iterable[Operator]: - """ - Only *relevant* upstream setups and their teardowns. - - This method is meant to be used when we are clearing the task (non-upstream) and we need - to add in the *relevant* setups and their teardowns. - - Relevant in this case means, the setup has a teardown that is downstream of ``self``, - or the setup has no teardowns. - """ - downstream_teardown_ids = { - x.task_id for x in self.get_flat_relatives(upstream=False) if x.is_teardown - } - for task in self.get_flat_relatives(upstream=True): - if not task.is_setup: - continue - has_no_teardowns = not any(True for x in task.downstream_list if x.is_teardown) - # if task has no teardowns or has teardowns downstream of self - if has_no_teardowns or task.downstream_task_ids.intersection(downstream_teardown_ids): - yield task - for t in task.downstream_list: - if t.is_teardown and t != self: - yield t - - def get_upstreams_only_setups(self) -> Iterable[Operator]: - """ - Return relevant upstream setups. - - This method is meant to be used when we are checking task dependencies where we need - to wait for all the upstream setups to complete before we can run the task. - """ - for task in self.get_upstreams_only_setups_and_teardowns(): - if task.is_setup: - yield task - def _iter_all_mapped_downstreams(self) -> Iterator[MappedOperator | MappedTaskGroup]: """ Return mapped nodes that are direct dependencies of the current task. diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py index 0dc533bef7c..a9628873e5e 100644 --- a/airflow/models/baseoperator.py +++ b/airflow/models/baseoperator.py @@ -28,7 +28,6 @@ import contextlib import copy import functools import logging -import sys from datetime import datetime, timedelta from functools import wraps from threading import local @@ -879,46 +878,6 @@ class BaseOperator(TaskSDKBaseOperator, AbstractOperator, metaclass=BaseOperator else: return self.downstream_list - @property - def is_setup(self) -> bool: - """ - Whether the operator is a setup task. - - :meta private: - """ - return self._is_setup - - @is_setup.setter - def is_setup(self, value: bool) -> None: - """ - Setter for is_setup property. - - :meta private: - """ - if self.is_teardown and value: - raise ValueError(f"Cannot mark task '{self.task_id}' as setup; task is already a teardown.") - self._is_setup = value - - @property - def is_teardown(self) -> bool: - """ - Whether the operator is a teardown task. 
- - :meta private: - """ - return self._is_teardown - - @is_teardown.setter - def is_teardown(self, value: bool) -> None: - """ - Setter for is_teardown property. - - :meta private: - """ - if self.is_setup and value: - raise ValueError(f"Cannot mark task '{self.task_id}' as teardown; task is already a setup.") - self._is_teardown = value - @staticmethod def xcom_push( context: Any, diff --git a/airflow/models/dag.py b/airflow/models/dag.py index 691a64ed290..642b4c2f7e1 100644 --- a/airflow/models/dag.py +++ b/airflow/models/dag.py @@ -100,7 +100,7 @@ from airflow.models.taskinstance import ( clear_task_instances, ) from airflow.models.tasklog import LogTemplate -from airflow.sdk import DAG as TaskSDKDag, dag as dag +from airflow.sdk import DAG as TaskSDKDag, dag as task_sdk_dag_decorator from airflow.secrets.local_filesystem import LocalFilesystemBackend from airflow.security import permissions from airflow.settings import json @@ -296,6 +296,14 @@ def _create_orm_dagrun( return run +if TYPE_CHECKING: + dag = task_sdk_dag_decorator +else: + + def dag(dag_id: str = "", **kwargs): + return task_sdk_dag_decorator(dag_id, __DAG_class=DAG, __warnings_stacklevel_delta=3) + + @functools.total_ordering @attrs.define(hash=False, repr=False, eq=False) class DAG(TaskSDKDag, LoggingMixin): diff --git a/airflow/serialization/schema.json b/airflow/serialization/schema.json index fe1e63c4903..e313e2c7af7 100644 --- a/airflow/serialization/schema.json +++ b/airflow/serialization/schema.json @@ -137,7 +137,7 @@ "type": "object", "properties": { "params": { "$ref": "#/definitions/params" }, - "_dag_id": { "type": "string" }, + "dag_id": { "type": "string" }, "tasks": { "$ref": "#/definitions/tasks" }, "timezone": { "$ref": "#/definitions/timezone" }, "owner_links": { "type": "object" }, @@ -157,10 +157,10 @@ ] }, "orientation": { "type" : "string"}, - "_dag_display_property_value": { "type" : "string"}, + "dag_display_name": { "type" : "string"}, "_description": { "type" : "string"}, "_concurrency": { "type" : "number"}, - "_max_active_tasks": { "type" : "number"}, + "max_active_tasks": { "type" : "number"}, "max_active_runs": { "type" : "number"}, "max_consecutive_failed_dag_runs": { "type" : "number"}, "default_args": { "$ref": "#/definitions/dict" }, @@ -175,7 +175,7 @@ "has_on_failure_callback": { "type": "boolean" }, "render_template_as_native_obj": { "type": "boolean" }, "tags": { "type": "array" }, - "_task_group": {"anyOf": [ + "task_group": {"anyOf": [ { "type": "null" }, { "$ref": "#/definitions/task_group" } ]}, @@ -183,7 +183,7 @@ "dag_dependencies": { "$ref": "#/definitions/dag_dependencies" } }, "required": [ - "_dag_id", + "dag_id", "fileloc", "tasks" ], diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py index 8b674b2aa0f..9aafaf1f54a 100644 --- a/airflow/serialization/serialized_objects.py +++ b/airflow/serialization/serialized_objects.py @@ -920,10 +920,11 @@ class BaseSerialization: to account for the case where the default value of the field is None but has the ``field = field or {}`` set. 
""" - if attrname in cls._CONSTRUCTOR_PARAMS and ( - cls._CONSTRUCTOR_PARAMS[attrname] is value or (value in [{}, []]) - ): - return True + if attrname in cls._CONSTRUCTOR_PARAMS: + if cls._CONSTRUCTOR_PARAMS[attrname] is value or (value in [{}, []]): + return True + if cls._CONSTRUCTOR_PARAMS[attrname] is attrs.NOTHING and value is None: + return True return False @classmethod @@ -1613,7 +1614,7 @@ class SerializedDAG(DAG, BaseSerialization): ] dag_deps.extend(DependencyDetector.detect_dag_dependencies(dag)) serialized_dag["dag_dependencies"] = [x.__dict__ for x in sorted(dag_deps)] - serialized_dag["_task_group"] = TaskGroupSerialization.serialize_task_group(dag.task_group) + serialized_dag["task_group"] = TaskGroupSerialization.serialize_task_group(dag.task_group) # Edge info in the JSON exactly matches our internal structure serialized_dag["edge_info"] = dag.edge_info @@ -1633,7 +1634,7 @@ class SerializedDAG(DAG, BaseSerialization): @classmethod def deserialize_dag(cls, encoded_dag: dict[str, Any]) -> SerializedDAG: """Deserializes a DAG from a JSON object.""" - dag = SerializedDAG(dag_id=encoded_dag["_dag_id"], schedule=None) + dag = SerializedDAG(dag_id=encoded_dag["dag_id"], schedule=None) for k, v in encoded_dag.items(): if k == "_downstream_task_ids": @@ -1668,16 +1669,17 @@ class SerializedDAG(DAG, BaseSerialization): v = set(v) # else use v as it is - setattr(dag, k, v) + object.__setattr__(dag, k, v) # Set _task_group - if "_task_group" in encoded_dag: - dag.task_group = TaskGroupSerialization.deserialize_task_group( - encoded_dag["_task_group"], + if "task_group" in encoded_dag: + tg = TaskGroupSerialization.deserialize_task_group( + encoded_dag["task_group"], None, dag.task_dict, dag, ) + object.__setattr__(dag, "task_group", tg) else: # This must be old data that had no task_group. Create a root TaskGroup and add # all tasks to it. diff --git a/airflow/utils/decorators.py b/airflow/utils/decorators.py index e299999423e..78044e4e357 100644 --- a/airflow/utils/decorators.py +++ b/airflow/utils/decorators.py @@ -69,8 +69,9 @@ def _balance_parens(after_decorator): class _autostacklevel_warn: - def __init__(self): + def __init__(self, delta): self.warnings = __import__("warnings") + self.delta = delta def __getattr__(self, name): return getattr(self.warnings, name) @@ -79,11 +80,11 @@ class _autostacklevel_warn: return dir(self.warnings) def warn(self, message, category=None, stacklevel=1, source=None): - self.warnings.warn(message, category, stacklevel + 2, source) + self.warnings.warn(message, category, stacklevel + self.delta, source) -def fixup_decorator_warning_stack(func): +def fixup_decorator_warning_stack(func, delta: int = 2): if func.__globals__.get("warnings") is sys.modules["warnings"]: # Yes, this is more than slightly hacky, but it _automatically_ sets the right stacklevel parameter to # `warnings.warn` to ignore the decorator. 
- func.__globals__["warnings"] = _autostacklevel_warn() + func.__globals__["warnings"] = _autostacklevel_warn(delta) diff --git a/task_sdk/src/airflow/sdk/definitions/abstractoperator.py b/task_sdk/src/airflow/sdk/definitions/abstractoperator.py index 54b1e30ab81..6f90ae7f118 100644 --- a/task_sdk/src/airflow/sdk/definitions/abstractoperator.py +++ b/task_sdk/src/airflow/sdk/definitions/abstractoperator.py @@ -31,6 +31,7 @@ from typing import ( ) from airflow.sdk.definitions.node import DAGNode +from airflow.sdk.definitions.mixins import DependencyMixin from airflow.utils.log.secrets_masker import redact from airflow.utils.trigger_rule import TriggerRule from airflow.utils.weight_rule import WeightRule @@ -41,6 +42,7 @@ if TYPE_CHECKING: import jinja2 # Slow import. from airflow.models.baseoperatorlink import BaseOperatorLink + from airflow.models.operator import Operator from airflow.sdk.definitions.baseoperator import BaseOperator from airflow.sdk.definitions.dag import DAG @@ -99,6 +101,8 @@ class AbstractOperator(DAGNode): trigger_rule: TriggerRule _needs_expansion: bool | None = None _on_failure_fail_dagrun = False + is_setup: bool = False + is_teardown: bool = False HIDE_ATTRS_FROM_UI: ClassVar[frozenset[str]] = frozenset( ( @@ -163,3 +167,105 @@ class AbstractOperator(DAGNode): # "task_group_id.task_id" -> "task_id" return self.task_id[len(tg.node_id) + 1 :] return self.task_id + + def as_setup(self): + self.is_setup = True + return self + + def as_teardown( + self, + *, + setups: BaseOperator | Iterable[BaseOperator] | None = None, + on_failure_fail_dagrun: bool | None = None, + ): + self.is_teardown = True + self.trigger_rule = TriggerRule.ALL_DONE_SETUP_SUCCESS + if on_failure_fail_dagrun is not None: + self.on_failure_fail_dagrun = on_failure_fail_dagrun + if setups is not None: + setups = [setups] if isinstance(setups, DependencyMixin) else setups + for s in setups: + s.is_setup = True + s >> self + return self + + def get_flat_relative_ids(self, *, upstream: bool = False) -> set[str]: + """ + Get a flat set of relative IDs, upstream or downstream. + + Will recurse each relative found in the direction specified. + + :param upstream: Whether to look for upstream or downstream relatives. + """ + dag = self.get_dag() + if not dag: + return set() + + relatives: set[str] = set() + + # This is intentionally implemented as a loop, instead of calling + # get_direct_relative_ids() recursively, since Python has significant + # limitation on stack level, and a recursive implementation can blow up + # if a DAG contains very long routes. 
+ task_ids_to_trace = self.get_direct_relative_ids(upstream) + while task_ids_to_trace: + task_ids_to_trace_next: set[str] = set() + for task_id in task_ids_to_trace: + if task_id in relatives: + continue + task_ids_to_trace_next.update(dag.task_dict[task_id].get_direct_relative_ids(upstream)) + relatives.add(task_id) + task_ids_to_trace = task_ids_to_trace_next + + return relatives + + def get_flat_relatives(self, upstream: bool = False) -> Collection[Operator]: + """Get a flat list of relatives, either upstream or downstream.""" + dag = self.get_dag() + if not dag: + return set() + return [dag.task_dict[task_id] for task_id in self.get_flat_relative_ids(upstream=upstream)] + + def get_upstreams_follow_setups(self) -> Iterable[Operator]: + """All upstreams and, for each upstream setup, its respective teardowns.""" + for task in self.get_flat_relatives(upstream=True): + yield task + if task.is_setup: + for t in task.downstream_list: + if t.is_teardown and t != self: + yield t + + def get_upstreams_only_setups_and_teardowns(self) -> Iterable[Operator]: + """ + Only *relevant* upstream setups and their teardowns. + + This method is meant to be used when we are clearing the task (non-upstream) and we need + to add in the *relevant* setups and their teardowns. + + Relevant in this case means, the setup has a teardown that is downstream of ``self``, + or the setup has no teardowns. + """ + downstream_teardown_ids = { + x.task_id for x in self.get_flat_relatives(upstream=False) if x.is_teardown + } + for task in self.get_flat_relatives(upstream=True): + if not task.is_setup: + continue + has_no_teardowns = not any(True for x in task.downstream_list if x.is_teardown) + # if task has no teardowns or has teardowns downstream of self + if has_no_teardowns or task.downstream_task_ids.intersection(downstream_teardown_ids): + yield task + for t in task.downstream_list: + if t.is_teardown and t != self: + yield t + + def get_upstreams_only_setups(self) -> Iterable[Operator]: + """ + Return relevant upstream setups. + + This method is meant to be used when we are checking task dependencies where we need + to wait for all the upstream setups to complete before we can run the task. + """ + for task in self.get_upstreams_only_setups_and_teardowns(): + if task.is_setup: + yield task diff --git a/task_sdk/src/airflow/sdk/definitions/baseoperator.py b/task_sdk/src/airflow/sdk/definitions/baseoperator.py index eb7e57907eb..57a85988987 100644 --- a/task_sdk/src/airflow/sdk/definitions/baseoperator.py +++ b/task_sdk/src/airflow/sdk/definitions/baseoperator.py @@ -557,6 +557,9 @@ class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta): logger_name: str | None = None allow_nested_operators: bool = True + is_setup: bool = False + is_teardown: bool = False + template_fields: Collection[str] = () template_ext: Sequence[str] = () @@ -1041,6 +1044,21 @@ class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta): return Resources(**resources) + def _convert_is_setup(self, value: bool) -> bool: + """ + Setter for is_setup property. 
+ + :meta private: + """ + if self.is_teardown and value: + raise ValueError(f"Cannot mark task '{self.task_id}' as setup; task is already a teardown.") + return value + + def _convert_is_teardown(self, value: bool) -> bool: + if self.is_setup and value: + raise ValueError(f"Cannot mark task '{self.task_id}' as teardown; task is already a setup.") + return value + @property def task_display_name(self) -> str: return self._task_display_name or self.task_id diff --git a/task_sdk/src/airflow/sdk/definitions/dag.py b/task_sdk/src/airflow/sdk/definitions/dag.py index 0ad82e52455..21579e87356 100644 --- a/task_sdk/src/airflow/sdk/definitions/dag.py +++ b/task_sdk/src/airflow/sdk/definitions/dag.py @@ -947,6 +947,7 @@ class DAG: "has_on_failure_callback", "auto_register", "fail_stop", + "schedule", } cls.__serialized_fields = frozenset(vars(DAG(dag_id="test", schedule=None))) - exclusion_list return cls.__serialized_fields @@ -984,114 +985,93 @@ class DAG: yield owner, link -# NOTE: Please keep the list of arguments in sync with DAG.__init__. -# Only exception: dag_id here should have a default value, but not in DAG. -def dag( - dag_id: str = "", - description: str | None = None, - schedule: ScheduleArg = None, - start_date: datetime | None = None, - end_date: datetime | None = None, - template_searchpath: str | Iterable[str] | None = None, - template_undefined: type[jinja2.StrictUndefined] = jinja2.StrictUndefined, - user_defined_macros: dict | None = None, - user_defined_filters: dict | None = None, - default_args: dict | None = None, - max_active_tasks: int = airflow_conf.getint("core", "max_active_tasks_per_dag"), - max_active_runs: int = airflow_conf.getint("core", "max_active_runs_per_dag"), - max_consecutive_failed_dag_runs: int = airflow_conf.getint( - "core", "max_consecutive_failed_dag_runs_per_dag" - ), - dagrun_timeout: timedelta | None = None, - sla_miss_callback: Any = None, - catchup: bool = airflow_conf.getboolean("scheduler", "catchup_by_default"), - on_success_callback: None | DagStateChangeCallback | list[DagStateChangeCallback] = None, - on_failure_callback: None | DagStateChangeCallback | list[DagStateChangeCallback] = None, - doc_md: str | None = None, - params: abc.MutableMapping | None = None, - access_control: dict[str, dict[str, Collection[str]]] | dict[str, Collection[str]] | None = None, - is_paused_upon_creation: bool | None = None, - jinja_environment_kwargs: dict | None = None, - render_template_as_native_obj: bool = False, - tags: Collection[str] | None = None, - owner_links: dict[str, str] | None = None, - auto_register: bool = True, - fail_stop: bool = False, - dag_display_name: str | None = None, -) -> Callable[[Callable], Callable[..., DAG]]: - """ - Python dag decorator which wraps a function into an Airflow DAG. +if TYPE_CHECKING: + # NOTE: Please keep the list of arguments in sync with DAG.__init__. + # Only exception: dag_id here should have a default value, but not in DAG. 
+ def dag( + dag_id: str = "", + *, + description: str | None = None, + schedule: ScheduleArg = None, + start_date: datetime | None = None, + end_date: datetime | None = None, + template_searchpath: str | Iterable[str] | None = None, + template_undefined: type[jinja2.StrictUndefined] = jinja2.StrictUndefined, + user_defined_macros: dict | None = None, + user_defined_filters: dict | None = None, + default_args: dict | None = None, + max_active_tasks: int = airflow_conf.getint("core", "max_active_tasks_per_dag"), + max_active_runs: int = airflow_conf.getint("core", "max_active_runs_per_dag"), + max_consecutive_failed_dag_runs: int = airflow_conf.getint( + "core", "max_consecutive_failed_dag_runs_per_dag" + ), + dagrun_timeout: timedelta | None = None, + sla_miss_callback: Any = None, + catchup: bool = airflow_conf.getboolean("scheduler", "catchup_by_default"), + on_success_callback: None | DagStateChangeCallback | list[DagStateChangeCallback] = None, + on_failure_callback: None | DagStateChangeCallback | list[DagStateChangeCallback] = None, + doc_md: str | None = None, + params: abc.MutableMapping | None = None, + access_control: dict[str, dict[str, Collection[str]]] | dict[str, Collection[str]] | None = None, + is_paused_upon_creation: bool | None = None, + jinja_environment_kwargs: dict | None = None, + render_template_as_native_obj: bool = False, + tags: Collection[str] | None = None, + owner_links: dict[str, str] | None = None, + auto_register: bool = True, + fail_stop: bool = False, + dag_display_name: str | None = None, + ) -> Callable[[Callable], Callable[..., DAG]]: + """ + Python dag decorator which wraps a function into an Airflow DAG. - Accepts kwargs for operator kwarg. Can be used to parameterize DAGs. + Accepts kwargs for operator kwarg. Can be used to parameterize DAGs. - :param dag_args: Arguments for DAG object - :param dag_kwargs: Kwargs for DAG object. - """ + :param dag_args: Arguments for DAG object + :param dag_kwargs: Kwargs for DAG object. + """ +else: - def wrapper(f: Callable) -> Callable[..., DAG]: - @functools.wraps(f) - def factory(*args, **kwargs): - # Generate signature for decorated function and bind the arguments when called - # we do this to extract parameters, so we can annotate them on the DAG object. - # In addition, this fails if we are missing any args/kwargs with TypeError as expected. - f_sig = signature(f).bind(*args, **kwargs) - # Apply defaults to capture default values if set. 
- f_sig.apply_defaults() - - # Initialize DAG with bound arguments - with DAG( - dag_id or f.__name__, - description=description, - start_date=start_date, - end_date=end_date, - template_searchpath=template_searchpath, - template_undefined=template_undefined, - user_defined_macros=user_defined_macros, - user_defined_filters=user_defined_filters, - default_args=default_args, - max_active_tasks=max_active_tasks, - max_active_runs=max_active_runs, - max_consecutive_failed_dag_runs=max_consecutive_failed_dag_runs, - dagrun_timeout=dagrun_timeout, - sla_miss_callback=sla_miss_callback, - catchup=catchup, - on_success_callback=on_success_callback, - on_failure_callback=on_failure_callback, - doc_md=doc_md, - params=params, - access_control=access_control, - is_paused_upon_creation=is_paused_upon_creation, - jinja_environment_kwargs=jinja_environment_kwargs, - render_template_as_native_obj=render_template_as_native_obj, - tags=tags, - schedule=schedule, - owner_links=owner_links, - auto_register=auto_register, - fail_stop=fail_stop, - dag_display_name=dag_display_name, - ) as dag_obj: - # Set DAG documentation from function documentation if it exists and doc_md is not set. - if f.__doc__ and not dag_obj.doc_md: - dag_obj.doc_md = f.__doc__ - - # Generate DAGParam for each function arg/kwarg and replace it for calling the function. - # All args/kwargs for function will be DAGParam object and replaced on execution time. - f_kwargs = {} - for name, value in f_sig.arguments.items(): - f_kwargs[name] = dag_obj.param(name, value) - - # set file location to caller source path - back = sys._getframe().f_back - dag_obj.fileloc = back.f_code.co_filename if back else "" - - # Invoke function to create operators in the DAG scope. - f(**f_kwargs) - - # Return dag object such that it's accessible in Globals. - return dag_obj - - # Ensure that warnings from inside DAG() are emitted from the caller, not here - fixup_decorator_warning_stack(factory) - return factory - - return wrapper + def dag(dag_id="", __DAG_class=DAG, __warnings_stacklevel_delta=2, **decorator_kwargs): + # TODO: Task-SDK: remove __DAG_class + # __DAG_class is a temporary hack to allow the dag decorator in airflow.models.dag to continue to + # return SchedulerDag objects + DAG = __DAG_class + + def wrapper(f: Callable) -> Callable[..., DAG]: + @functools.wraps(f) + def factory(*args, **kwargs): + # Generate signature for decorated function and bind the arguments when called + # we do this to extract parameters, so we can annotate them on the DAG object. + # In addition, this fails if we are missing any args/kwargs with TypeError as expected. + f_sig = signature(f).bind(*args, **kwargs) + # Apply defaults to capture default values if set. + f_sig.apply_defaults() + + # Initialize DAG with bound arguments + with DAG(dag_id or f.__name__, **decorator_kwargs) as dag_obj: + # Set DAG documentation from function documentation if it exists and doc_md is not set. + if f.__doc__ and not dag_obj.doc_md: + dag_obj.doc_md = f.__doc__ + + # Generate DAGParam for each function arg/kwarg and replace it for calling the function. + # All args/kwargs for function will be DAGParam object and replaced on execution time. + f_kwargs = {} + for name, value in f_sig.arguments.items(): + f_kwargs[name] = dag_obj.param(name, value) + + # set file location to caller source path + back = sys._getframe().f_back + dag_obj.fileloc = back.f_code.co_filename if back else "" + + # Invoke function to create operators in the DAG scope. 
+ f(**f_kwargs) + + # Return dag object such that it's accessible in Globals. + return dag_obj + + # Ensure that warnings from inside DAG() are emitted from the caller, not here + fixup_decorator_warning_stack(factory) + return factory + + return wrapper diff --git a/task_sdk/tests/defintions/test_dag.py b/task_sdk/tests/defintions/test_dag.py index dd4ded2f4fe..2300e97f07e 100644 --- a/task_sdk/tests/defintions/test_dag.py +++ b/task_sdk/tests/defintions/test_dag.py @@ -16,6 +16,7 @@ # under the License. from __future__ import annotations +import weakref from datetime import datetime, timedelta, timezone import pytest @@ -23,7 +24,7 @@ import pytest from airflow.exceptions import DuplicateTaskIdFound from airflow.models.param import Param, ParamsDict from airflow.sdk.definitions.baseoperator import BaseOperator -from airflow.sdk.definitions.dag import DAG +from airflow.sdk.definitions.dag import DAG, dag as dag_decorator DEFAULT_DATE = datetime(2016, 1, 1, tzinfo=timezone.utc) @@ -245,3 +246,72 @@ class TestDag: # Make sure we don't affect the original! assert task.task_group.upstream_group_ids is not copied_task.task_group.upstream_group_ids + + +class TestDagDecorator: + DEFAULT_ARGS = { + "owner": "test", + "depends_on_past": True, + "start_date": datetime.now(tz=timezone.utc), + "retries": 1, + "retry_delay": timedelta(minutes=1), + } + VALUE = 42 + + def test_fileloc(self): + @dag_decorator(schedule=None, default_args=self.DEFAULT_ARGS) + def noop_pipeline(): ... + + dag = noop_pipeline() + assert isinstance(dag, DAG) + assert dag.dag_id == "noop_pipeline" + assert dag.fileloc == __file__ + + def test_set_dag_id(self): + """Test that checks you can set dag_id from decorator.""" + + @dag_decorator("test", schedule=None, default_args=self.DEFAULT_ARGS) + def noop_pipeline(): ... + + dag = noop_pipeline() + assert isinstance(dag, DAG) + assert dag.dag_id == "test" + + def test_default_dag_id(self): + """Test that @dag uses function name as default dag id.""" + + @dag_decorator(schedule=None, default_args=self.DEFAULT_ARGS) + def noop_pipeline(): ... + + dag = noop_pipeline() + assert isinstance(dag, DAG) + assert dag.dag_id == "noop_pipeline" + + @pytest.mark.parametrize( + argnames=["dag_doc_md", "expected_doc_md"], + argvalues=[ + pytest.param("dag docs.", "dag docs.", id="use_dag_doc_md"), + pytest.param(None, "Regular DAG documentation", id="use_dag_docstring"), + ], + ) + def test_documentation_added(self, dag_doc_md, expected_doc_md): + """Test that @dag uses function docs as doc_md for DAG object if doc_md is not explicitly set.""" + + @dag_decorator(schedule=None, default_args=self.DEFAULT_ARGS, doc_md=dag_doc_md) + def noop_pipeline(): + """Regular DAG documentation""" + + dag = noop_pipeline() + assert isinstance(dag, DAG) + assert dag.dag_id == "noop_pipeline" + assert dag.doc_md == expected_doc_md + + def test_fails_if_arg_not_set(self): + """Test that @dag decorated function fails if positional argument is not set""" + + @dag_decorator(schedule=None, default_args=self.DEFAULT_ARGS) + def noop_pipeline(value): ... + + # Test that if arg is not passed it raises a type error as expected. 
+ with pytest.raises(TypeError): + noop_pipeline() diff --git a/tests/models/test_dag.py b/tests/models/test_dag.py index 0b0b3c340f7..be8dde52331 100644 --- a/tests/models/test_dag.py +++ b/tests/models/test_dag.py @@ -23,7 +23,6 @@ import logging import os import pickle import re -import weakref from datetime import timedelta from pathlib import Path from typing import TYPE_CHECKING @@ -92,7 +91,6 @@ from airflow.utils.weight_rule import WeightRule from tests.models import DEFAULT_DATE from tests.plugins.priority_weight_strategy import ( FactorPriorityWeightStrategy, - NotRegisteredPriorityWeightStrategy, StaticTestPriorityWeightStrategy, TestPriorityWeightStrategyPlugin, ) @@ -2517,54 +2515,6 @@ class TestDagDecorator: def teardown_method(self): clear_db_runs() - def test_fileloc(self): - @dag_decorator(schedule=None, default_args=self.DEFAULT_ARGS) - def noop_pipeline(): ... - - dag = noop_pipeline() - assert isinstance(dag, DAG) - assert dag.dag_id == "noop_pipeline" - assert dag.fileloc == __file__ - - def test_set_dag_id(self): - """Test that checks you can set dag_id from decorator.""" - - @dag_decorator("test", schedule=None, default_args=self.DEFAULT_ARGS) - def noop_pipeline(): ... - - dag = noop_pipeline() - assert isinstance(dag, DAG) - assert dag.dag_id == "test" - - def test_default_dag_id(self): - """Test that @dag uses function name as default dag id.""" - - @dag_decorator(schedule=None, default_args=self.DEFAULT_ARGS) - def noop_pipeline(): ... - - dag = noop_pipeline() - assert isinstance(dag, DAG) - assert dag.dag_id == "noop_pipeline" - - @pytest.mark.parametrize( - argnames=["dag_doc_md", "expected_doc_md"], - argvalues=[ - pytest.param("dag docs.", "dag docs.", id="use_dag_doc_md"), - pytest.param(None, "Regular DAG documentation", id="use_dag_docstring"), - ], - ) - def test_documentation_added(self, dag_doc_md, expected_doc_md): - """Test that @dag uses function docs as doc_md for DAG object if doc_md is not explicitly set.""" - - @dag_decorator(schedule=None, default_args=self.DEFAULT_ARGS, doc_md=dag_doc_md) - def noop_pipeline(): - """Regular DAG documentation""" - - dag = noop_pipeline() - assert isinstance(dag, DAG) - assert dag.dag_id == "noop_pipeline" - assert dag.doc_md == expected_doc_md - def test_documentation_template_rendered(self): """Test that @dag uses function docs as doc_md for DAG object""" @@ -2577,7 +2527,6 @@ class TestDagDecorator: """ dag = noop_pipeline() - assert isinstance(dag, DAG) assert dag.dag_id == "noop_pipeline" assert "Regular DAG documentation" in dag.doc_md @@ -2597,25 +2546,9 @@ class TestDagDecorator: def markdown_docs(): ... dag = markdown_docs() - assert isinstance(dag, DAG) assert dag.dag_id == "test-dag" assert dag.doc_md == raw_content - def test_fails_if_arg_not_set(self): - """Test that @dag decorated function fails if positional argument is not set""" - - @dag_decorator(schedule=None, default_args=self.DEFAULT_ARGS) - def noop_pipeline(value): - @task_decorator - def return_num(num): - return num - - return_num(value) - - # Test that if arg is not passed it raises a type error as expected. - with pytest.raises(TypeError): - noop_pipeline() - def test_dag_param_resolves(self): """Test that dag param is correctly resolved by operator"""
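For DAG authors the relocated decorator keeps its existing call pattern; only the implementation moved into the task SDK, with airflow.models.dag.dag delegating to it. A minimal usage sketch, mirroring the tests added in task_sdk/tests/defintions/test_dag.py (it assumes the task-sdk package from this branch is importable; the pipeline name is made up for illustration):

    from airflow.sdk.definitions.dag import DAG, dag

    @dag(schedule=None)
    def noop_pipeline():
        """Docstring becomes doc_md when doc_md is not set explicitly."""

    d = noop_pipeline()
    assert isinstance(d, DAG)
    # The function name is the default dag_id; @dag("my_id", ...) overrides it.
    assert d.dag_id == "noop_pipeline"
    assert d.doc_md == "Docstring becomes doc_md when doc_md is not set explicitly."

Because the real keyword signature now lives only in the TYPE_CHECKING branch, the runtime decorator simply forwards **decorator_kwargs to the DAG constructor, which is why missing positional arguments of the decorated function still surface as a TypeError, as the test above checks.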

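The setup/teardown helpers (as_setup, as_teardown, get_upstreams_only_setups, and friends) now live on the task-SDK AbstractOperator, with is_setup/is_teardown as plain attrs fields whose mutual-exclusion check moved into _convert_is_setup/_convert_is_teardown on the SDK BaseOperator. Their call pattern is unchanged; a hypothetical sketch, instantiating BaseOperator directly only for illustration:

    from airflow.sdk.definitions.baseoperator import BaseOperator
    from airflow.sdk.definitions.dag import DAG

    with DAG(dag_id="setup_teardown_demo", schedule=None):
        create = BaseOperator(task_id="create_resource").as_setup()
        work = BaseOperator(task_id="do_work")
        # as_teardown() flips is_teardown, sets trigger_rule to
        # ALL_DONE_SETUP_SUCCESS, and wires each given setup upstream of it.
        drop = BaseOperator(task_id="drop_resource").as_teardown(setups=create)
        create >> work >> drop

    assert create.is_setup and drop.is_teardown

Marking the same task as both setup and teardown still raises ValueError, now via the converters in task_sdk/src/airflow/sdk/definitions/baseoperator.py rather than the removed property setters in airflow/models/baseoperator.py.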