This is an automated email from the ASF dual-hosted git repository. ash pushed a commit to branch task-sdk-first-code in repository https://gitbox.apache.org/repos/asf/airflow.git
commit 0f92b1461d9c2c8e464ccd402e5e40aee1d9be22 Author: Ash Berlin-Taylor <[email protected]> AuthorDate: Mon Oct 21 18:31:51 2024 +0100 [skip ci] --- .pre-commit-config.yaml | 1 + airflow/decorators/base.py | 2 +- airflow/decorators/sensor.py | 4 +- airflow/models/abstractoperator.py | 11 +- airflow/models/baseoperator.py | 315 +++++++++++---------- airflow/models/dag.py | 6 +- airflow/models/mappedoperator.py | 4 +- airflow/models/skipmixin.py | 2 +- airflow/models/xcom_arg.py | 1 - airflow/operators/python.py | 4 +- airflow/serialization/serialized_objects.py | 2 +- airflow/utils/task_group.py | 5 +- airflow/utils/types.py | 2 +- .../pre_commit/base_operator_partial_arguments.py | 80 ++++-- task_sdk/pyproject.toml | 2 + .../airflow/sdk/definitions/abstractoperator.py | 6 +- .../src/airflow/sdk/definitions/baseoperator.py | 7 +- task_sdk/src/airflow/sdk/definitions/dag.py | 77 +++-- task_sdk/src/airflow/sdk/definitions/edges.py | 4 +- task_sdk/src/airflow/sdk/definitions/mixins.py | 6 +- task_sdk/src/airflow/sdk/definitions/node.py | 4 +- task_sdk/src/airflow/sdk/types.py | 2 +- task_sdk/tests/defintions/test_baseoperator.py | 4 +- task_sdk/tests/defintions/test_dag.py | 68 +++++ tests/models/test_baseoperator.py | 5 +- tests/models/test_dag.py | 65 +---- tests/utils/test_task_group.py | 1 - tests_common/test_utils/mock_operators.py | 3 +- 28 files changed, 365 insertions(+), 328 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 7e94e4191e8..7d9cd6b2292 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1189,6 +1189,7 @@ repos: ^airflow/utils/helpers.py$ | ^providers/src/airflow/providers/ | ^(providers/)?tests/ | + task_sdk/src/airflow/sdk/definitions/dag.py$ | ^dev/.*\.py$ | ^scripts/.*\.py$ | ^docker_tests/.*$ | diff --git a/airflow/decorators/base.py b/airflow/decorators/base.py index e7a21e2919b..3290f5342ce 100644 --- a/airflow/decorators/base.py +++ b/airflow/decorators/base.py @@ -187,7 +187,7 @@ class 
DecoratedOperator(BaseOperator): # since we won't mutate the arguments, we should just do the shallow copy # there are some cases we can't deepcopy the objects (e.g protobuf). - shallow_copy_attrs: Sequence[str] = ("python_callable",) + shallow_copy_attrs: ClassVar[Sequence[str]] = ("python_callable",) def __init__( self, diff --git a/airflow/decorators/sensor.py b/airflow/decorators/sensor.py index c332a78f95c..c37cd08d6b4 100644 --- a/airflow/decorators/sensor.py +++ b/airflow/decorators/sensor.py @@ -17,7 +17,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Callable, Sequence +from typing import TYPE_CHECKING, Callable, ClassVar, Sequence from airflow.decorators.base import get_unique_task_id, task_decorator_factory from airflow.sensors.python import PythonSensor @@ -48,7 +48,7 @@ class DecoratedSensorOperator(PythonSensor): # since we won't mutate the arguments, we should just do the shallow copy # there are some cases we can't deepcopy the objects (e.g protobuf). 
- shallow_copy_attrs: Sequence[str] = ("python_callable",) + shallow_copy_attrs: ClassVar[Sequence[str]] = ("python_callable",) def __init__( self, diff --git a/airflow/models/abstractoperator.py b/airflow/models/abstractoperator.py index a27d7e26fd1..93c0fd2d93a 100644 --- a/airflow/models/abstractoperator.py +++ b/airflow/models/abstractoperator.py @@ -20,7 +20,7 @@ from __future__ import annotations import datetime import inspect from functools import cached_property -from typing import TYPE_CHECKING, Any, Callable, Collection, Iterable, Iterator, Sequence +from typing import TYPE_CHECKING, Any, Callable, Iterable, Iterator, Sequence import methodtools from sqlalchemy import select @@ -28,7 +28,6 @@ from sqlalchemy import select from airflow.configuration import conf from airflow.exceptions import AirflowException from airflow.models.expandinput import NotFullyPopulated -from airflow.models.taskmixin import DependencyMixin from airflow.sdk.definitions.abstractoperator import AbstractOperator as TaskSDKAbstractOperator from airflow.template.templater import Templater from airflow.utils.context import Context @@ -39,7 +38,6 @@ from airflow.utils.sqlalchemy import with_row_locks from airflow.utils.state import State, TaskInstanceState from airflow.utils.task_group import MappedTaskGroup from airflow.utils.trigger_rule import TriggerRule -from airflow.sdk.types import NOTSET, ArgNotSet from airflow.utils.weight_rule import WeightRule TaskStateChangeCallback = Callable[[Context], None] @@ -53,10 +51,9 @@ if TYPE_CHECKING: from airflow.models.baseoperatorlink import BaseOperatorLink from airflow.models.dag import DAG as SchedulerDAG from airflow.models.mappedoperator import MappedOperator - from airflow.models.operator import Operator from airflow.models.taskinstance import TaskInstance - from airflow.sdk import BaseOperator, DAG - from airflow.sdk.defintions.node import DAGNode + from airflow.sdk import DAG, BaseOperator + from airflow.sdk.definitions.node import 
DAGNode from airflow.task.priority_strategy import PriorityWeightStrategy from airflow.triggers.base import StartTriggerArgs from airflow.utils.task_group import TaskGroup @@ -262,7 +259,7 @@ class AbstractOperator(Templater, TaskSDKAbstractOperator): """ if (group := self.task_group) is None: return - # TODO: Task-SDK: this type ignore shouldn't be necssary, revisit once mapping support is fully in the + # TODO: Task-SDK: this type ignore shouldn't be necessary, revisit once mapping support is fully in the # SDK yield from group.iter_mapped_task_groups() # type: ignore[misc] diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py index a9628873e5e..f0d1ee6f965 100644 --- a/airflow/models/baseoperator.py +++ b/airflow/models/baseoperator.py @@ -71,7 +71,6 @@ from airflow.models.abstractoperator import ( ) from airflow.models.base import _sentinel from airflow.models.mappedoperator import OperatorPartial, validate_mapping_kwargs -from airflow.models.pool import Pool from airflow.models.taskinstance import TaskInstance, clear_task_instances from airflow.models.taskmixin import DependencyMixin @@ -83,7 +82,6 @@ from airflow.sdk.definitions.baseoperator import ( get_merged_defaults, ) from airflow.serialization.enums import DagAttributeTypes -from airflow.task.priority_strategy import PriorityWeightStrategy from airflow.ti_deps.deps.mapped_task_upstream_dep import MappedTaskUpstreamDep from airflow.ti_deps.deps.not_in_retry_period_dep import NotInRetryPeriodDep from airflow.ti_deps.deps.not_previously_skipped_dep import NotPreviouslySkippedDep @@ -106,14 +104,15 @@ if TYPE_CHECKING: from airflow.models.abstractoperator import TaskStateChangeCallback from airflow.models.baseoperatorlink import BaseOperatorLink - from airflow.models.dag import DAG, DAG as SchedulerDAG + from airflow.models.dag import DAG as SchedulerDAG from airflow.models.operator import Operator + from airflow.task.priority_strategy import PriorityWeightStrategy from 
airflow.ti_deps.deps.base_ti_dep import BaseTIDep from airflow.triggers.base import BaseTrigger, StartTriggerArgs from airflow.utils.types import ArgNotSet -# Todo: AIP-44: Once we get rid of AIP-44 we can remove this. But without this here pydantic failes to resolve +# Todo: AIP-44: Once we get rid of AIP-44 we can remove this. But without this here pydantic fails to resolve # types for serialization from airflow.utils.task_group import TaskGroup # noqa: TCH001 @@ -192,160 +191,148 @@ _PARTIAL_DEFAULTS: dict[str, Any] = { # This is what handles the actual mapping. -def partial( - operator_class: type[BaseOperator], - *, - task_id: str, - dag: DAG | None = None, - task_group: TaskGroup | None = None, - start_date: datetime | ArgNotSet = NOTSET, - end_date: datetime | ArgNotSet = NOTSET, - owner: str | ArgNotSet = NOTSET, - email: None | str | Iterable[str] | ArgNotSet = NOTSET, - params: collections.abc.MutableMapping | None = None, - resources: dict[str, Any] | None | ArgNotSet = NOTSET, - trigger_rule: str | ArgNotSet = NOTSET, - depends_on_past: bool | ArgNotSet = NOTSET, - ignore_first_depends_on_past: bool | ArgNotSet = NOTSET, - wait_for_past_depends_before_skipping: bool | ArgNotSet = NOTSET, - wait_for_downstream: bool | ArgNotSet = NOTSET, - retries: int | None | ArgNotSet = NOTSET, - queue: str | ArgNotSet = NOTSET, - pool: str | ArgNotSet = NOTSET, - pool_slots: int | ArgNotSet = NOTSET, - execution_timeout: timedelta | None | ArgNotSet = NOTSET, - max_retry_delay: None | timedelta | float | ArgNotSet = NOTSET, - retry_delay: timedelta | float | ArgNotSet = NOTSET, - retry_exponential_backoff: bool | ArgNotSet = NOTSET, - priority_weight: int | ArgNotSet = NOTSET, - weight_rule: str | PriorityWeightStrategy | ArgNotSet = NOTSET, - sla: timedelta | None | ArgNotSet = NOTSET, - map_index_template: str | None | ArgNotSet = NOTSET, - max_active_tis_per_dag: int | None | ArgNotSet = NOTSET, - max_active_tis_per_dagrun: int | None | ArgNotSet = NOTSET, - 
on_execute_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] | ArgNotSet = NOTSET, - on_failure_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] | ArgNotSet = NOTSET, - on_success_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] | ArgNotSet = NOTSET, - on_retry_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] | ArgNotSet = NOTSET, - on_skipped_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] | ArgNotSet = NOTSET, - run_as_user: str | None | ArgNotSet = NOTSET, - executor: str | None | ArgNotSet = NOTSET, - executor_config: dict | None | ArgNotSet = NOTSET, - inlets: Any | None | ArgNotSet = NOTSET, - outlets: Any | None | ArgNotSet = NOTSET, - doc: str | None | ArgNotSet = NOTSET, - doc_md: str | None | ArgNotSet = NOTSET, - doc_json: str | None | ArgNotSet = NOTSET, - doc_yaml: str | None | ArgNotSet = NOTSET, - doc_rst: str | None | ArgNotSet = NOTSET, - task_display_name: str | None | ArgNotSet = NOTSET, - logger_name: str | None | ArgNotSet = NOTSET, - allow_nested_operators: bool = True, - **kwargs, -) -> OperatorPartial: - from airflow.sdk.definitions.contextmanager import DagContext, TaskGroupContext - - validate_mapping_kwargs(operator_class, "partial", kwargs) - - dag = dag or DagContext.get_current() - if dag: - task_group = task_group or TaskGroupContext.get_current(dag) - if task_group: - task_id = task_group.child_id(task_id) - - # Merge DAG and task group level defaults into user-supplied values. 
- dag_default_args, partial_params = get_merged_defaults( - dag=dag, - task_group=task_group, - task_params=params, - task_default_args=kwargs.pop("default_args", None), - ) - # Create partial_kwargs from args and kwargs - partial_kwargs: dict[str, Any] = { +if TYPE_CHECKING: + + def partial( + operator_class: type[BaseOperator], + *, + task_id: str, + dag: DAG | None = None, + task_group: TaskGroup | None = None, + start_date: datetime | ArgNotSet = NOTSET, + end_date: datetime | ArgNotSet = NOTSET, + owner: str | ArgNotSet = NOTSET, + email: None | str | Iterable[str] | ArgNotSet = NOTSET, + params: collections.abc.MutableMapping | None = None, + resources: dict[str, Any] | None | ArgNotSet = NOTSET, + trigger_rule: str | ArgNotSet = NOTSET, + depends_on_past: bool | ArgNotSet = NOTSET, + ignore_first_depends_on_past: bool | ArgNotSet = NOTSET, + wait_for_past_depends_before_skipping: bool | ArgNotSet = NOTSET, + wait_for_downstream: bool | ArgNotSet = NOTSET, + retries: int | None | ArgNotSet = NOTSET, + queue: str | ArgNotSet = NOTSET, + pool: str | ArgNotSet = NOTSET, + pool_slots: int | ArgNotSet = NOTSET, + execution_timeout: timedelta | None | ArgNotSet = NOTSET, + max_retry_delay: None | timedelta | float | ArgNotSet = NOTSET, + retry_delay: timedelta | float | ArgNotSet = NOTSET, + retry_exponential_backoff: bool | ArgNotSet = NOTSET, + priority_weight: int | ArgNotSet = NOTSET, + weight_rule: str | PriorityWeightStrategy | ArgNotSet = NOTSET, + sla: timedelta | None | ArgNotSet = NOTSET, + map_index_template: str | None | ArgNotSet = NOTSET, + max_active_tis_per_dag: int | None | ArgNotSet = NOTSET, + max_active_tis_per_dagrun: int | None | ArgNotSet = NOTSET, + on_execute_callback: None + | TaskStateChangeCallback + | list[TaskStateChangeCallback] + | ArgNotSet = NOTSET, + on_failure_callback: None + | TaskStateChangeCallback + | list[TaskStateChangeCallback] + | ArgNotSet = NOTSET, + on_success_callback: None + | TaskStateChangeCallback + | 
list[TaskStateChangeCallback] + | ArgNotSet = NOTSET, + on_retry_callback: None + | TaskStateChangeCallback + | list[TaskStateChangeCallback] + | ArgNotSet = NOTSET, + on_skipped_callback: None + | TaskStateChangeCallback + | list[TaskStateChangeCallback] + | ArgNotSet = NOTSET, + run_as_user: str | None | ArgNotSet = NOTSET, + executor: str | None | ArgNotSet = NOTSET, + executor_config: dict | None | ArgNotSet = NOTSET, + inlets: Any | None | ArgNotSet = NOTSET, + outlets: Any | None | ArgNotSet = NOTSET, + doc: str | None | ArgNotSet = NOTSET, + doc_md: str | None | ArgNotSet = NOTSET, + doc_json: str | None | ArgNotSet = NOTSET, + doc_yaml: str | None | ArgNotSet = NOTSET, + doc_rst: str | None | ArgNotSet = NOTSET, + task_display_name: str | None | ArgNotSet = NOTSET, + logger_name: str | None | ArgNotSet = NOTSET, + allow_nested_operators: bool = True, + **kwargs, + ) -> OperatorPartial: ... +else: + + def partial( + operator_class: type[BaseOperator], + *, + task_id: str, + dag: DAG | None = None, + task_group: TaskGroup | None = None, + params: collections.abc.MutableMapping | None = None, **kwargs, - "dag": dag, - "task_group": task_group, - "task_id": task_id, - "map_index_template": map_index_template, - "start_date": start_date, - "end_date": end_date, - "owner": owner, - "email": email, - "trigger_rule": trigger_rule, - "depends_on_past": depends_on_past, - "ignore_first_depends_on_past": ignore_first_depends_on_past, - "wait_for_past_depends_before_skipping": wait_for_past_depends_before_skipping, - "wait_for_downstream": wait_for_downstream, - "retries": retries, - "queue": queue, - "pool": pool, - "pool_slots": pool_slots, - "execution_timeout": execution_timeout, - "max_retry_delay": max_retry_delay, - "retry_delay": retry_delay, - "retry_exponential_backoff": retry_exponential_backoff, - "priority_weight": priority_weight, - "weight_rule": weight_rule, - "sla": sla, - "max_active_tis_per_dag": max_active_tis_per_dag, - "max_active_tis_per_dagrun": 
max_active_tis_per_dagrun, - "on_execute_callback": on_execute_callback, - "on_failure_callback": on_failure_callback, - "on_retry_callback": on_retry_callback, - "on_success_callback": on_success_callback, - "on_skipped_callback": on_skipped_callback, - "run_as_user": run_as_user, - "executor": executor, - "executor_config": executor_config, - "inlets": inlets, - "outlets": outlets, - "resources": resources, - "doc": doc, - "doc_json": doc_json, - "doc_md": doc_md, - "doc_rst": doc_rst, - "doc_yaml": doc_yaml, - "task_display_name": task_display_name, - "logger_name": logger_name, - "allow_nested_operators": allow_nested_operators, - } - - # Inject DAG-level default args into args provided to this function. - partial_kwargs.update((k, v) for k, v in dag_default_args.items() if partial_kwargs.get(k) is NOTSET) - - # Fill fields not provided by the user with default values. - partial_kwargs = {k: _PARTIAL_DEFAULTS.get(k) if v is NOTSET else v for k, v in partial_kwargs.items()} - - # Post-process arguments. Should be kept in sync with _TaskDecorator.expand(). - if "task_concurrency" in kwargs: # Reject deprecated option. 
- raise TypeError("unexpected argument: task_concurrency") - if partial_kwargs["wait_for_downstream"]: - partial_kwargs["depends_on_past"] = True - partial_kwargs["start_date"] = timezone.convert_to_utc(partial_kwargs["start_date"]) - partial_kwargs["end_date"] = timezone.convert_to_utc(partial_kwargs["end_date"]) - if partial_kwargs["pool"] is None: - partial_kwargs["pool"] = Pool.DEFAULT_POOL_NAME - if partial_kwargs["pool_slots"] < 1: - dag_str = "" + ): + from airflow.sdk.definitions.contextmanager import DagContext, TaskGroupContext + + validate_mapping_kwargs(operator_class, "partial", kwargs) + + dag = dag or DagContext.get_current() if dag: - dag_str = f" in dag {dag.dag_id}" - raise ValueError(f"pool slots for {task_id}{dag_str} cannot be less than 1") - partial_kwargs["retries"] = parse_retries(partial_kwargs["retries"]) - partial_kwargs["retry_delay"] = coerce_timedelta(partial_kwargs["retry_delay"], key="retry_delay") - if partial_kwargs["max_retry_delay"] is not None: - partial_kwargs["max_retry_delay"] = coerce_timedelta( - partial_kwargs["max_retry_delay"], - key="max_retry_delay", + task_group = task_group or TaskGroupContext.get_current(dag) + if task_group: + task_id = task_group.child_id(task_id) + + # Merge DAG and task group level defaults into user-supplied values. + dag_default_args, partial_params = get_merged_defaults( + dag=dag, + task_group=task_group, + task_params=params, + task_default_args=kwargs.pop("default_args", None), ) - partial_kwargs["executor_config"] = partial_kwargs["executor_config"] or {} - partial_kwargs["resources"] = coerce_resources(partial_kwargs["resources"]) - return OperatorPartial( - operator_class=operator_class, - kwargs=partial_kwargs, - params=partial_params, - ) + # Create partial_kwargs from args and kwargs + partial_kwargs: dict[str, Any] = { + "task_id": task_id, + "dag": dag, + "task_group": task_group, + **kwargs, + } + + # Inject DAG-level default args into args provided to this function. 
+ partial_kwargs.update((k, v) for k, v in dag_default_args.items() if partial_kwargs.get(k) is NOTSET) + + # Fill fields not provided by the user with default values. + for k, v in _PARTIAL_DEFAULTS.items(): + partial_kwargs.setdefault(k, v) + + # Post-process arguments. Should be kept in sync with _TaskDecorator.expand(). + if "task_concurrency" in kwargs: # Reject deprecated option. + raise TypeError("unexpected argument: task_concurrency") + if wait := partial_kwargs.get("wait_for_downstream", False): + partial_kwargs["depends_on_past"] = wait + if start_date := partial_kwargs.get("start_date", None): + partial_kwargs["start_date"] = timezone.convert_to_utc(start_date) + if end_date := partial_kwargs.get("end_date", None): + partial_kwargs["end_date"] = timezone.convert_to_utc(end_date) + if partial_kwargs["pool_slots"] < 1: + dag_str = "" + if dag: + dag_str = f" in dag {dag.dag_id}" + raise ValueError(f"pool slots for {task_id}{dag_str} cannot be less than 1") + if retries := partial_kwargs.get("retries"): + partial_kwargs["retries"] = parse_retries(retries) + partial_kwargs["retry_delay"] = coerce_timedelta(partial_kwargs["retry_delay"], key="retry_delay") + if partial_kwargs.get("max_retry_delay", None) is not None: + partial_kwargs["max_retry_delay"] = coerce_timedelta( + partial_kwargs["max_retry_delay"], + key="max_retry_delay", + ) + partial_kwargs.setdefault("executor_config", {}) + + return OperatorPartial( + operator_class=operator_class, + kwargs=partial_kwargs, + params=partial_params, + ) class ExecutorSafeguard: @@ -387,6 +374,8 @@ class ExecutorSafeguard: # TODO: Task-SDK - temporarily extend the metaclass to add in the ExecutorSafeguard. 
class BaseOperatorMeta(TaskSDKBaseOperatorMeta): + """:meta private:""" # noqa: D400 + def __new__(cls, name, bases, namespace, **kwargs): execute_method = namespace.get("execute") if callable(execute_method) and not getattr(execute_method, "__isabstractmethod__", False): @@ -617,7 +606,17 @@ class BaseOperator(TaskSDKBaseOperator, AbstractOperator, metaclass=BaseOperator _is_setup: bool = False _is_teardown: bool = False - def __init__(self, pre_execute=None, post_execute=None, **kwargs): + def __init__( + self, + pre_execute=None, + post_execute=None, + on_execute_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = None, + on_failure_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = None, + on_success_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = None, + on_retry_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = None, + on_skipped_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = None, + **kwargs, + ): if start_date := kwargs.get("start_date", None): kwargs["start_date"] = timezone.convert_to_utc(start_date) @@ -626,6 +625,10 @@ class BaseOperator(TaskSDKBaseOperator, AbstractOperator, metaclass=BaseOperator super().__init__(**kwargs) self._pre_execute_hook = pre_execute self._post_execute_hook = post_execute + self.on_execute_callback = on_execute_callback + self.on_failure_callback = on_failure_callback + self.on_success_callback = on_success_callback + self.on_skipped_callback = on_skipped_callback # Defines the operator level extra links operator_extra_links: Collection[BaseOperatorLink] = () @@ -764,7 +767,7 @@ class BaseOperator(TaskSDKBaseOperator, AbstractOperator, metaclass=BaseOperator if TYPE_CHECKING: # TODO: Task-SDK: We need to set this to the scheduler DAG until we fully separate scheduling and - # defintion code + # definition code assert isinstance(self.dag, SchedulerDAG) clear_task_instances(results, session, 
dag=self.dag) @@ -815,7 +818,7 @@ class BaseOperator(TaskSDKBaseOperator, AbstractOperator, metaclass=BaseOperator assert self.start_date # TODO: Task-SDK: We need to set this to the scheduler DAG until we fully separate scheduling and - # defintion code + # definition code assert isinstance(self.dag, SchedulerDAG) start_date = pendulum.instance(start_date or self.start_date) diff --git a/airflow/models/dag.py b/airflow/models/dag.py index 642b4c2f7e1..e26a696ee7a 100644 --- a/airflow/models/dag.py +++ b/airflow/models/dag.py @@ -83,7 +83,6 @@ from airflow.exceptions import ( UnknownExecutorException, ) from airflow.executors.executor_loader import ExecutorLoader -from airflow.models.abstractoperator import TaskStateChangeCallback from airflow.models.asset import ( AssetDagRunQueue, AssetModel, @@ -125,6 +124,7 @@ if TYPE_CHECKING: from sqlalchemy.orm.query import Query from sqlalchemy.orm.session import Session + from airflow.models.abstractoperator import TaskStateChangeCallback from airflow.models.dagbag import DagBag from airflow.models.operator import Operator from airflow.serialization.pydantic.dag import DagModelPydantic @@ -2433,9 +2433,7 @@ if STATICA_HACK: # pragma: no cover class DagContext(airflow.sdk.definitions.contextmanager.DagContext, share_parent_context=True): - """ - :meta private: - """ + """:meta private:""" # noqa: D400 @classmethod def push_context_managed_dag(cls, dag: DAG): diff --git a/airflow/models/mappedoperator.py b/airflow/models/mappedoperator.py index 8a9e790ea7f..52a08bce027 100644 --- a/airflow/models/mappedoperator.py +++ b/airflow/models/mappedoperator.py @@ -201,8 +201,8 @@ class OperatorPartial: task_id = partial_kwargs.pop("task_id") dag = partial_kwargs.pop("dag") task_group = partial_kwargs.pop("task_group") - start_date = partial_kwargs.pop("start_date") - end_date = partial_kwargs.pop("end_date") + start_date = partial_kwargs.pop("start_date", None) + end_date = partial_kwargs.pop("end_date", None) try: operator_name = 
self.operator_class.custom_operator_name # type: ignore diff --git a/airflow/models/skipmixin.py b/airflow/models/skipmixin.py index 00e81791246..a67c7cf310b 100644 --- a/airflow/models/skipmixin.py +++ b/airflow/models/skipmixin.py @@ -36,7 +36,7 @@ if TYPE_CHECKING: from airflow.models.dagrun import DagRun from airflow.models.operator import Operator - from airflow.sdk.defintions.node import DAGNode + from airflow.sdk.definitions.node import DAGNode from airflow.serialization.pydantic.dag_run import DagRunPydantic from airflow.serialization.pydantic.taskinstance import TaskInstancePydantic diff --git a/airflow/models/xcom_arg.py b/airflow/models/xcom_arg.py index 07ca0190285..940a7f1a066 100644 --- a/airflow/models/xcom_arg.py +++ b/airflow/models/xcom_arg.py @@ -43,7 +43,6 @@ if TYPE_CHECKING: # from airflow.models.dag import DAG from airflow.models.operator import Operator - from airflow.models.taskmixin import DAGNode from airflow.sdk import DAG, BaseOperator from airflow.utils.context import Context from airflow.utils.edgemodifier import EdgeModifier diff --git a/airflow/operators/python.py b/airflow/operators/python.py index b032b45ed3e..dc2e772af0e 100644 --- a/airflow/operators/python.py +++ b/airflow/operators/python.py @@ -33,7 +33,7 @@ from collections.abc import Container from functools import cache from pathlib import Path from tempfile import TemporaryDirectory -from typing import TYPE_CHECKING, Any, Callable, Collection, Iterable, Mapping, NamedTuple, Sequence +from typing import TYPE_CHECKING, Any, Callable, ClassVar, Collection, Iterable, Mapping, NamedTuple, Sequence import lazy_object_proxy @@ -197,7 +197,7 @@ class PythonOperator(BaseOperator): # since we won't mutate the arguments, we should just do the shallow copy # there are some cases we can't deepcopy the objects(e.g protobuf). 
- shallow_copy_attrs: Sequence[str] = ( + shallow_copy_attrs: ClassVar[Sequence[str]] = ( "python_callable", "op_kwargs", ) diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py index 9aafaf1f54a..e2e21a6686d 100644 --- a/airflow/serialization/serialized_objects.py +++ b/airflow/serialization/serialized_objects.py @@ -101,7 +101,7 @@ if TYPE_CHECKING: from airflow.models.baseoperatorlink import BaseOperatorLink from airflow.models.expandinput import ExpandInput from airflow.models.operator import Operator - from airflow.sdk.defintions.node import DAGNode + from airflow.sdk.definitions.node import DAGNode from airflow.serialization.json_schema import Validator from airflow.ti_deps.deps.base_ti_dep import BaseTIDep from airflow.timetables.base import Timetable diff --git a/airflow/utils/task_group.py b/airflow/utils/task_group.py index 1cdbe4b745d..6b760a112af 100644 --- a/airflow/utils/task_group.py +++ b/airflow/utils/task_group.py @@ -21,9 +21,7 @@ from __future__ import annotations import functools import operator -from typing import TYPE_CHECKING, Any, Iterator - -import methodtools +from typing import TYPE_CHECKING, Iterator import airflow.sdk.definitions.contextmanager import airflow.sdk.definitions.taskgroup @@ -32,7 +30,6 @@ if TYPE_CHECKING: from sqlalchemy.orm import Session from airflow.models.dag import DAG - from airflow.models.expandinput import ExpandInput from airflow.models.operator import Operator from airflow.typing_compat import TypeAlias diff --git a/airflow/utils/types.py b/airflow/utils/types.py index 6aba5c711f7..7dd1ce02b60 100644 --- a/airflow/utils/types.py +++ b/airflow/utils/types.py @@ -20,7 +20,7 @@ import enum from typing import TYPE_CHECKING import airflow.sdk.types -from airflow.typing_compat import TypedDict, TypeAlias +from airflow.typing_compat import TypeAlias, TypedDict if TYPE_CHECKING: from datetime import datetime diff --git 
a/scripts/ci/pre_commit/base_operator_partial_arguments.py b/scripts/ci/pre_commit/base_operator_partial_arguments.py index 14999e034ed..b5070533170 100755 --- a/scripts/ci/pre_commit/base_operator_partial_arguments.py +++ b/scripts/ci/pre_commit/base_operator_partial_arguments.py @@ -27,6 +27,7 @@ import typing ROOT_DIR = pathlib.Path(__file__).resolve().parents[3] BASEOPERATOR_PY = ROOT_DIR.joinpath("airflow", "models", "baseoperator.py") +SDK_BASEOPERATOR_PY = ROOT_DIR.joinpath("task_sdk", "src", "airflow", "sdk", "definitions", "baseoperator.py") MAPPEDOPERATOR_PY = ROOT_DIR.joinpath("airflow", "models", "mappedoperator.py") IGNORED = { @@ -51,12 +52,31 @@ IGNORED = { # Only on MappedOperator. "expand_input", "partial_kwargs", + "operator_class", + # Task-SDK migration ones. + "deps", + "downstream_task_ids", + "on_execute_callback", + "on_failure_callback", + "on_retry_callback", + "on_skipped_callback", + "on_success_callback", + "operator_extra_links", + "start_from_trigger", + "start_trigger_args", + "upstream_task_ids", + "logger_name", + "sla", } BO_MOD = ast.parse(BASEOPERATOR_PY.read_text("utf-8"), str(BASEOPERATOR_PY)) +SDK_BO_MOD = ast.parse(SDK_BASEOPERATOR_PY.read_text("utf-8"), str(SDK_BASEOPERATOR_PY)) MO_MOD = ast.parse(MAPPEDOPERATOR_PY.read_text("utf-8"), str(MAPPEDOPERATOR_PY)) +# TODO: Task-SDK: Look at the BaseOperator init functions in both airflow.models.baseoperator and combine +# them, until we fully remove BaseOperator class from core. 
+ BO_CLS = next( node for node in ast.iter_child_nodes(BO_MOD) @@ -67,9 +87,27 @@ BO_INIT = next( for node in ast.iter_child_nodes(BO_CLS) if isinstance(node, ast.FunctionDef) and node.name == "__init__" ) -BO_PARTIAL = next( + +SDK_BO_CLS = next( + node + for node in ast.iter_child_nodes(SDK_BO_MOD) + if isinstance(node, ast.ClassDef) and node.name == "BaseOperator" +) +SDK_BO_INIT = next( + node + for node in ast.iter_child_nodes(SDK_BO_CLS) + if isinstance(node, ast.FunctionDef) and node.name == "__init__" +) + +# We now define the signature in a type checking block, the runtime impl uses **kwargs +BO_TYPE_CHECKING_BLOCKS = ( node for node in ast.iter_child_nodes(BO_MOD) + if isinstance(node, ast.If) and node.test.id == "TYPE_CHECKING" # type: ignore[attr-defined] +) +BO_PARTIAL = next( + node + for node in itertools.chain.from_iterable(map(ast.iter_child_nodes, BO_TYPE_CHECKING_BLOCKS)) if isinstance(node, ast.FunctionDef) and node.name == "partial" ) MO_CLS = next( @@ -79,23 +117,27 @@ MO_CLS = next( ) -def _compare(a: set[str], b: set[str], *, excludes: set[str]) -> tuple[set[str], set[str]]: - only_in_a = {n for n in a if n not in b and n not in excludes and n[0] != "_"} - only_in_b = {n for n in b if n not in a and n not in excludes and n[0] != "_"} +def _compare(a: set[str], b: set[str]) -> tuple[set[str], set[str]]: + only_in_a = a - b - IGNORED + only_in_b = b - a - IGNORED return only_in_a, only_in_b -def _iter_arg_names(func: ast.FunctionDef) -> typing.Iterator[str]: - func_args = func.args - for arg in itertools.chain(func_args.args, getattr(func_args, "posonlyargs", ()), func_args.kwonlyargs): - yield arg.arg +def _iter_arg_names(*funcs: ast.FunctionDef) -> typing.Iterator[str]: + for func in funcs: + func_args = func.args + for arg in itertools.chain( + func_args.args, getattr(func_args, "posonlyargs", ()), func_args.kwonlyargs + ): + if arg.arg == "self" or arg.arg.startswith("_"): + continue + yield arg.arg def 
check_baseoperator_partial_arguments() -> bool: only_in_init, only_in_partial = _compare( - set(itertools.islice(_iter_arg_names(BO_INIT), 1, None)), - set(itertools.islice(_iter_arg_names(BO_PARTIAL), 1, None)), - excludes=IGNORED, + set(_iter_arg_names(SDK_BO_INIT, BO_INIT)), + set(_iter_arg_names(BO_PARTIAL)), ) if only_in_init: print("Arguments in BaseOperator missing from partial():", ", ".join(sorted(only_in_init))) @@ -109,6 +151,8 @@ def check_baseoperator_partial_arguments() -> bool: def _iter_assignment_to_self_attributes(targets: typing.Iterable[ast.expr]) -> typing.Iterator[str]: for t in targets: if isinstance(t, ast.Attribute) and isinstance(t.value, ast.Name) and t.value.id == "self": + if t.attr.startswith("_"): + continue yield t.attr # Something like "self.foo = ...". else: # Recursively visit nodes in unpacking assignments like "a, b = ...". @@ -132,20 +176,24 @@ def _is_property(f: ast.FunctionDef) -> bool: def _iter_member_names(klass: ast.ClassDef) -> typing.Iterator[str]: for node in ast.iter_child_nodes(klass): + name = "" if isinstance(node, ast.AnnAssign) and isinstance(node.target, ast.Name): - yield node.target.id + name = node.target.id elif isinstance(node, ast.FunctionDef) and _is_property(node): - yield node.name + name = node.name elif isinstance(node, ast.Assign): if len(node.targets) == 1 and isinstance(target := node.targets[0], ast.Name): - yield target.id + name = target.id + else: + continue + if not name.startswith("_"): + yield name def check_operator_member_parity() -> bool: only_in_base, only_in_mapped = _compare( - set(itertools.chain(_iter_assignment_targets(BO_INIT), _iter_member_names(BO_CLS))), + set(itertools.chain(_iter_assignment_targets(SDK_BO_INIT), _iter_member_names(SDK_BO_CLS))), set(_iter_member_names(MO_CLS)), - excludes=IGNORED, ) if only_in_base: print("Members on BaseOperator missing from MappedOperator:", ", ".join(sorted(only_in_base))) diff --git a/task_sdk/pyproject.toml b/task_sdk/pyproject.toml 
index 7c8a4e52472..37ea2d300ab 100644 --- a/task_sdk/pyproject.toml +++ b/task_sdk/pyproject.toml @@ -44,6 +44,8 @@ namespace-packages = ["src/airflow"] # Ignore Doc rules et al for anything outside of tests "!src/*" = ["D", "TID253", "S101", "TRY002"] +"src/airflow/sdk/__init__.py" = ["TCH004"] + [tool.uv] dev-dependencies = [ "kgb>=7.1.1", diff --git a/task_sdk/src/airflow/sdk/definitions/abstractoperator.py b/task_sdk/src/airflow/sdk/definitions/abstractoperator.py index 6f90ae7f118..bb5ddf88e23 100644 --- a/task_sdk/src/airflow/sdk/definitions/abstractoperator.py +++ b/task_sdk/src/airflow/sdk/definitions/abstractoperator.py @@ -22,7 +22,6 @@ from abc import abstractmethod from collections.abc import ( Collection, Iterable, - Mapping, ) from typing import ( TYPE_CHECKING, @@ -30,17 +29,14 @@ from typing import ( ClassVar, ) -from airflow.sdk.definitions.node import DAGNode from airflow.sdk.definitions.mixins import DependencyMixin -from airflow.utils.log.secrets_masker import redact +from airflow.sdk.definitions.node import DAGNode from airflow.utils.trigger_rule import TriggerRule from airflow.utils.weight_rule import WeightRule # TaskStateChangeCallback = Callable[[Context], None] if TYPE_CHECKING: - import jinja2 # Slow import. - from airflow.models.baseoperatorlink import BaseOperatorLink from airflow.models.operator import Operator from airflow.sdk.definitions.baseoperator import BaseOperator diff --git a/task_sdk/src/airflow/sdk/definitions/baseoperator.py b/task_sdk/src/airflow/sdk/definitions/baseoperator.py index 57a85988987..8349876cb8f 100644 --- a/task_sdk/src/airflow/sdk/definitions/baseoperator.py +++ b/task_sdk/src/airflow/sdk/definitions/baseoperator.py @@ -719,7 +719,7 @@ class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta): if kwargs: raise TypeError( f"Invalid arguments were passed to {self.__class__.__name__} (task_id: {task_id}). 
" - + f"Invalid arguments were:\n**kwargs: {kwargs}", + f"Invalid arguments were:\n**kwargs: {kwargs}", ) validate_key(task_id) @@ -762,6 +762,7 @@ class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta): # self.retries = parse_retries(retries) self.retries = retries self.queue = queue + # TODO: Task-SDK: pull this default name from Pool constant? self.pool = "default" if pool is None else pool self.pool_slots = pool_slots if self.pool_slots < 1: @@ -1070,8 +1071,8 @@ class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta): def _set_xcomargs_dependencies(self) -> None: from airflow.models.xcom_arg import XComArg - for field in self.template_fields: - arg = getattr(self, field, NOTSET) + for f in self.template_fields: + arg = getattr(self, f, NOTSET) if arg is not NOTSET: XComArg.apply_upstream_relationship(self, arg) diff --git a/task_sdk/src/airflow/sdk/definitions/dag.py b/task_sdk/src/airflow/sdk/definitions/dag.py index 21579e87356..ba82f19339e 100644 --- a/task_sdk/src/airflow/sdk/definitions/dag.py +++ b/task_sdk/src/airflow/sdk/definitions/dag.py @@ -46,7 +46,6 @@ from dateutil.relativedelta import relativedelta from airflow import settings from airflow.assets import Asset, AssetAlias, BaseAsset -from airflow.configuration import conf as airflow_conf from airflow.exceptions import ( DuplicateTaskIdFound, FailStopDagInvalidTriggerRule, @@ -119,29 +118,6 @@ _DAG_HASH_ATTRS = frozenset( } ) -# TODO: The following mapping is used to validate that the arguments passed to the DAG are of the correct -# type. This is a temporary solution until we find a more sophisticated method for argument validation. -# One potential method is to use `get_type_hints` from the typing module. However, this is not fully -# compatible with future annotations for Python versions below 3.10. Once we require a minimum Python -# version that supports `get_type_hints` effectively or find a better approach, we can replace this -# manual type-checking method. 
-DAG_ARGS_EXPECTED_TYPES = { - "dag_id": str, - "description": str, - "max_active_tasks": int, - "max_active_runs": int, - "max_consecutive_failed_dag_runs": int, - "dagrun_timeout": timedelta, - "catchup": bool, - "doc_md": str, - "is_paused_upon_creation": bool, - "render_template_as_native_obj": bool, - "tags": Collection, - "auto_register": bool, - "fail_stop": bool, - "dag_display_name": str, -} - def _create_timetable(interval: ScheduleInterval, timezone: Timezone | FixedTimezone) -> Timetable: """Create a Timetable instance from a plain ``schedule`` value.""" @@ -167,7 +143,7 @@ def _create_timetable(interval: ScheduleInterval, timezone: Timezone | FixedTime def _convert_params(val: abc.MutableMapping | None, self_: DAG) -> ParamsDict: """ - Convert the plain dict into a ParamsDict + Convert the plain dict into a ParamsDict. This will also merge in params from default_args """ @@ -336,8 +312,11 @@ class DAG: # NOTE: When updating arguments here, please also keep arguments in @dag() # below in sync. (Search for 'def dag(' in this file.) 
- dag_id: str = attrs.field(kw_only=False) - description: str | None = None + dag_id: str = attrs.field(kw_only=False, validator=attrs.validators.instance_of(str)) + description: str | None = attrs.field( + default=None, + validator=attrs.validators.optional(attrs.validators.instance_of(str)), + ) default_args: dict[str, Any] = attrs.field( factory=dict, validator=attrs.validators.instance_of(dict), converter=dict_copy ) @@ -355,12 +334,17 @@ class DAG: user_defined_macros: dict | None = None user_defined_filters: dict | None = None concurrency: int | None = None - max_active_tasks: int = 16 - max_active_runs: int = 16 - max_consecutive_failed_dag_runs: int = -1 - dagrun_timeout: timedelta | None = None + max_active_tasks: int = attrs.field(default=16, validator=attrs.validators.instance_of(int)) + max_active_runs: int = attrs.field(default=16, validator=attrs.validators.instance_of(int)) + max_consecutive_failed_dag_runs: int = attrs.field( + default=-1, validator=attrs.validators.instance_of(int) + ) + dagrun_timeout: timedelta | None = attrs.field( + default=None, + validator=attrs.validators.optional(attrs.validators.instance_of(timedelta)), + ) # sla_miss_callback: None | SLAMissCallback | list[SLAMissCallback] = None - catchup: bool = attrs.field(default=True) + catchup: bool = attrs.field(default=True, converter=bool) # on_success_callback: None | DagStateChangeCallback | list[DagStateChangeCallback] = None # on_failure_callback: None | DagStateChangeCallback | list[DagStateChangeCallback] = None doc_md: str | None = None @@ -372,12 +356,12 @@ class DAG: access_control: dict | None = None is_paused_upon_creation: bool | None = None jinja_environment_kwargs: dict | None = None - render_template_as_native_obj: bool = False + render_template_as_native_obj: bool = attrs.field(default=False, converter=bool) tags: MutableSet[str] = attrs.field(factory=set, converter=_convert_tags) owner_links: dict[str, str] = attrs.field(factory=dict) - auto_register: bool = True 
- fail_stop: bool = False - dag_display_name: str = attrs.field() + auto_register: bool = attrs.field(default=True, converter=bool) + fail_stop: bool = attrs.field(default=False, converter=bool) + dag_display_name: str = attrs.field(validator=attrs.validators.instance_of(str)) task_dict: dict[str, Operator] = attrs.field(factory=dict, init=False) @@ -453,7 +437,7 @@ class DAG: from airflow.utils import timezone - # TODO: Task-SDK: get default dag tz from settins + # TODO: Task-SDK: get default dag tz from settings tz = timezone.utc if self.start_date and (tzinfo := self.start_date.tzinfo): tzinfo = None if tzinfo else tz @@ -484,6 +468,11 @@ class DAG: if requires_automatic_backfilling and not ("start_date" in self.default_args or self.start_date): raise ValueError("start_date is required when catchup=True") + @tags.validator + def _validate_tags(self, _, tags: Collection[str]): + if tags and any(len(tag) > TAG_MAX_LEN for tag in tags): + raise ValueError(f"tag cannot be longer than {TAG_MAX_LEN} characters") + def __repr__(self): return f"<DAG: {self.dag_id}>" @@ -521,7 +510,7 @@ class DAG: return self def __exit__(self, _type, _value, _tb): - from .contextmanager import DagContext + from airflow.sdk.definitions.contextmanager import DagContext _ = DagContext.pop() @@ -692,7 +681,7 @@ class DAG: result.user_defined_macros = self.user_defined_macros result.user_defined_filters = self.user_defined_filters if hasattr(self, "_log"): - result._log = self._log + result._log = self._log # type: ignore[attr-defined] return result def partial_subset( @@ -1000,14 +989,12 @@ if TYPE_CHECKING: user_defined_macros: dict | None = None, user_defined_filters: dict | None = None, default_args: dict | None = None, - max_active_tasks: int = airflow_conf.getint("core", "max_active_tasks_per_dag"), - max_active_runs: int = airflow_conf.getint("core", "max_active_runs_per_dag"), - max_consecutive_failed_dag_runs: int = airflow_conf.getint( - "core", 
"max_consecutive_failed_dag_runs_per_dag" - ), + max_active_tasks: int = ..., + max_active_runs: int = ..., + max_consecutive_failed_dag_runs: int = ..., dagrun_timeout: timedelta | None = None, sla_miss_callback: Any = None, - catchup: bool = airflow_conf.getboolean("scheduler", "catchup_by_default"), + catchup: bool = ..., on_success_callback: None | DagStateChangeCallback | list[DagStateChangeCallback] = None, on_failure_callback: None | DagStateChangeCallback | list[DagStateChangeCallback] = None, doc_md: str | None = None, diff --git a/task_sdk/src/airflow/sdk/definitions/edges.py b/task_sdk/src/airflow/sdk/definitions/edges.py index 7c62c1bda80..7e50431b497 100644 --- a/task_sdk/src/airflow/sdk/definitions/edges.py +++ b/task_sdk/src/airflow/sdk/definitions/edges.py @@ -19,10 +19,10 @@ from __future__ import annotations from collections.abc import Sequence from typing import TYPE_CHECKING -from .mixins import DependencyMixin +from airflow.sdk.definitions.mixins import DependencyMixin if TYPE_CHECKING: - from .dag import DAG + from airflow.sdk.definitions.dag import DAG class EdgeModifier(DependencyMixin): diff --git a/task_sdk/src/airflow/sdk/definitions/mixins.py b/task_sdk/src/airflow/sdk/definitions/mixins.py index 73d455d6c92..e9d6e162927 100644 --- a/task_sdk/src/airflow/sdk/definitions/mixins.py +++ b/task_sdk/src/airflow/sdk/definitions/mixins.py @@ -21,11 +21,11 @@ from abc import abstractmethod from collections.abc import Iterable, Sequence from typing import TYPE_CHECKING, Any -from ..types import NOTSET, ArgNotSet +from airflow.sdk.types import NOTSET, ArgNotSet if TYPE_CHECKING: - from .baseoperator import BaseOperator - from .edges import EdgeModifier + from airflow.sdk.definitions.baseoperator import BaseOperator + from airflow.sdk.definitions.edges import EdgeModifier # TODO: Should this all just live on DAGNode? 
diff --git a/task_sdk/src/airflow/sdk/definitions/node.py b/task_sdk/src/airflow/sdk/definitions/node.py index ddc560b8bcd..7b877bbaf48 100644 --- a/task_sdk/src/airflow/sdk/definitions/node.py +++ b/task_sdk/src/airflow/sdk/definitions/node.py @@ -53,7 +53,7 @@ def validate_key(k: str, max_length: int = 250): if not KEY_REGEX.match(k): raise ValueError( f"The key {k!r} has to be made of alphanumeric characters, dashes, " - + "dots and underscores exclusively" + "dots and underscores exclusively" ) @@ -152,7 +152,7 @@ class DAGNode(DependencyMixin, metaclass=ABCMeta): else: raise ValueError( "Tried to create relationships between tasks that don't have DAGs yet. " - + f"Set the DAG for at least one task and try again: {[self, *task_list]}" + f"Set the DAG for at least one task and try again: {[self, *task_list]}" ) if not self.has_dag(): diff --git a/task_sdk/src/airflow/sdk/types.py b/task_sdk/src/airflow/sdk/types.py index ffde2170b17..232d08e27f9 100644 --- a/task_sdk/src/airflow/sdk/types.py +++ b/task_sdk/src/airflow/sdk/types.py @@ -57,7 +57,7 @@ if TYPE_CHECKING: Logger = logging.Logger else: - class Logger: ... + class Logger: ... 
# noqa: D101 def validate_instance_args(instance: DAGNode, expected_arg_types: dict[str, Any]) -> None: diff --git a/task_sdk/tests/defintions/test_baseoperator.py b/task_sdk/tests/defintions/test_baseoperator.py index 23ede320c87..427d1ee0e3e 100644 --- a/task_sdk/tests/defintions/test_baseoperator.py +++ b/task_sdk/tests/defintions/test_baseoperator.py @@ -222,7 +222,7 @@ class TestBaseOperator: spy_agency.spy_on(XComArg.apply_upstream_relationship, call_original=False) op_copy.arg1 = "b" - assert XComArg.apply_upstream_relationship.called == False + assert XComArg.apply_upstream_relationship.called is False def test_upstream_is_set_when_template_field_is_xcomarg(self): with DAG("xcomargs_test", schedule=None): @@ -273,7 +273,7 @@ class TestBaseOperator: def test_init_subclass_args(): class InitSubclassOp(BaseOperator): - class_arg: Any + class_arg = None def __init_subclass__(cls, class_arg=None, **kwargs) -> None: cls.class_arg = class_arg diff --git a/task_sdk/tests/defintions/test_dag.py b/task_sdk/tests/defintions/test_dag.py index 2300e97f07e..c3b3bbdce4d 100644 --- a/task_sdk/tests/defintions/test_dag.py +++ b/task_sdk/tests/defintions/test_dag.py @@ -18,6 +18,7 @@ from __future__ import annotations import weakref from datetime import datetime, timedelta, timezone +from typing import Any import pytest @@ -248,6 +249,73 @@ class TestDag: assert task.task_group.upstream_group_ids is not copied_task.task_group.upstream_group_ids +# Test some of the arg validation. This is not all the validations we perform, just some of them. 
[email protected]( + ["attr", "value"], + [ + pytest.param("max_consecutive_failed_dag_runs", "not_an_int", id="max_consecutive_failed_dag_runs"), + pytest.param("dagrun_timeout", "not_an_int", id="dagrun_timeout"), + ], +) +def test_invalid_type_for_args(attr: str, value: Any): + with pytest.raises(TypeError): + DAG("invalid-default-args", **{attr: value}) + + [email protected]( + "tags, should_pass", + [ + pytest.param([], True, id="empty tags"), + pytest.param(["a normal tag"], True, id="one tag"), + pytest.param(["a normal tag", "another normal tag"], True, id="two tags"), + pytest.param(["a" * 100], True, id="a tag that's of just length 100"), + pytest.param(["a normal tag", "a" * 101], False, id="two tags and one of them is of length > 100"), + ], +) +def test__tags_length(tags: list[str], should_pass: bool): + if should_pass: + DAG("test-dag", schedule=None, tags=tags) + else: + with pytest.raises(ValueError): + DAG("test-dag", schedule=None, tags=tags) + + [email protected]( + "input_tags, expected_result", + [ + pytest.param([], set(), id="empty tags"), + pytest.param( + ["a normal tag"], + {"a normal tag"}, + id="one tag", + ), + pytest.param( + ["a normal tag", "another normal tag"], + {"a normal tag", "another normal tag"}, + id="two different tags", + ), + pytest.param( + ["a", "a"], + {"a"}, + id="two same tags", + ), + ], +) +def test__tags_duplicates(input_tags: list[str], expected_result: set[str]): + result = DAG("test-dag", tags=input_tags) + assert result.tags == expected_result + + +def test__tags_mutable(): + expected_tags = {"6", "7"} + test_dag = DAG("test-dag") + test_dag.tags.add("6") + test_dag.tags.add("7") + test_dag.tags.add("8") + test_dag.tags.remove("8") + assert test_dag.tags == expected_tags + + class TestDagDecorator: DEFAULT_ARGS = { "owner": "test", diff --git a/tests/models/test_baseoperator.py b/tests/models/test_baseoperator.py index 0af6f69df01..1d8e4f4d32a 100644 --- a/tests/models/test_baseoperator.py +++ 
b/tests/models/test_baseoperator.py @@ -29,7 +29,7 @@ import jinja2 import pytest from airflow.decorators import task as task_decorator -from airflow.exceptions import AirflowException, FailStopDagInvalidTriggerRule +from airflow.exceptions import AirflowException from airflow.lineage.entities import File from airflow.models.baseoperator import ( BaseOperator, @@ -47,9 +47,8 @@ from airflow.utils.template import literal from airflow.utils.trigger_rule import TriggerRule from airflow.utils.types import DagRunType -from tests_common.test_utils.mock_operators import MockOperator - from tests.models import DEFAULT_DATE +from tests_common.test_utils.mock_operators import MockOperator class ClassWithCustomAttributes: diff --git a/tests/models/test_dag.py b/tests/models/test_dag.py index be8dde52331..43029fbab98 100644 --- a/tests/models/test_dag.py +++ b/tests/models/test_dag.py @@ -1496,10 +1496,10 @@ class TestDag: fail_stop_dag.add_task(task_with_default_trigger_rule) # a fail stop dag should not allow a non-default trigger rule + task_with_non_default_trigger_rule = EmptyOperator( + task_id="task_with_non_default_trigger_rule", trigger_rule=TriggerRule.ALWAYS + ) with pytest.raises(FailStopDagInvalidTriggerRule): - task_with_non_default_trigger_rule = EmptyOperator( - task_id="task_with_non_default_trigger_rule", trigger_rule=TriggerRule.ALWAYS - ) fail_stop_dag.add_task(task_with_non_default_trigger_rule) def test_dag_add_task_sets_default_task_group(self): @@ -3033,60 +3033,6 @@ def test__time_restriction(dag_maker, dag_date, tasks_date, restrict): assert dag._time_restriction == restrict [email protected]( - "tags, should_pass", - [ - pytest.param([], True, id="empty tags"), - pytest.param(["a normal tag"], True, id="one tag"), - pytest.param(["a normal tag", "another normal tag"], True, id="two tags"), - pytest.param(["a" * 100], True, id="a tag that's of just length 100"), - pytest.param(["a normal tag", "a" * 101], False, id="two tags and one of them is of 
length > 100"), - ], -) -def test__tags_length(tags: list[str], should_pass: bool): - if should_pass: - DAG("test-dag", schedule=None, tags=tags) - else: - with pytest.raises(AirflowException): - DAG("test-dag", schedule=None, tags=tags) - - [email protected]( - "input_tags, expected_result", - [ - pytest.param([], set(), id="empty tags"), - pytest.param( - ["a normal tag"], - {"a normal tag"}, - id="one tag", - ), - pytest.param( - ["a normal tag", "another normal tag"], - {"a normal tag", "another normal tag"}, - id="two different tags", - ), - pytest.param( - ["a", "a"], - {"a"}, - id="two same tags", - ), - ], -) -def test__tags_duplicates(input_tags: list[str], expected_result: set[str]): - result = DAG("test-dag", tags=input_tags) - assert result.tags == expected_result - - -def test__tags_mutable(): - expected_tags = {"6", "7"} - test_dag = DAG("test-dag") - test_dag.tags.add("6") - test_dag.tags.add("7") - test_dag.tags.add("8") - test_dag.tags.remove("8") - assert test_dag.tags == expected_tags - - @pytest.mark.need_serialized_dag def test_get_asset_triggered_next_run_info(dag_maker, clear_assets): asset1 = Asset(uri="ds1") @@ -3195,11 +3141,6 @@ def test_create_dagrun_disallow_manual_to_use_automated_run_id(run_id_type: DagR ) -def test_invalid_type_for_args(): - with pytest.raises(TypeError): - DAG("invalid-default-args", schedule=None, max_consecutive_failed_dag_runs="not_an_int") - - class TestTaskClearingSetupTeardownBehavior: """ Task clearing behavior is mainly controlled by dag.partial_subset. 
diff --git a/tests/utils/test_task_group.py b/tests/utils/test_task_group.py index 43abcf9012e..85c225cc9b6 100644 --- a/tests/utils/test_task_group.py +++ b/tests/utils/test_task_group.py @@ -18,7 +18,6 @@ from __future__ import annotations from datetime import timedelta -from unittest import mock import pendulum import pytest diff --git a/tests_common/test_utils/mock_operators.py b/tests_common/test_utils/mock_operators.py index ef941abb13f..6a88f41c22a 100644 --- a/tests_common/test_utils/mock_operators.py +++ b/tests_common/test_utils/mock_operators.py @@ -16,6 +16,7 @@ # under the License. from __future__ import annotations +from collections.abc import Mapping from typing import TYPE_CHECKING, Any, Sequence import attr @@ -66,7 +67,7 @@ class MockOperatorWithNestedFields(BaseOperator): def _render_nested_template_fields( self, content: Any, - context: Context, + context: Mapping[str, Any], jinja_env: jinja2.Environment, seen_oids: set, ) -> None:
