This is an automated email from the ASF dual-hosted git repository.
ash pushed a commit to branch task-sdk-first-code
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/task-sdk-first-code by this
push:
new 44efb0b7c1 Move over more of BaseOperator and DAG, along with their
tests
44efb0b7c1 is described below
commit 44efb0b7c1d383ed8b71d2e6ec25573f542abe56
Author: Ash Berlin-Taylor <[email protected]>
AuthorDate: Mon Oct 14 22:16:00 2024 +0100
Move over more of BaseOperator and DAG, along with their tests
---
airflow/task/priority_strategy.py | 4 +-
dev/tests_common/test_utils/mock_operators.py | 10 -
task_sdk/pyproject.toml | 19 ++
.../airflow/sdk/definitions/abstractoperator.py | 5 +-
.../src/airflow/sdk/definitions/baseoperator.py | 352 ++++++++++++++-------
task_sdk/src/airflow/sdk/definitions/dag.py | 114 +++++--
task_sdk/src/airflow/sdk/definitions/node.py | 2 +-
task_sdk/src/airflow/sdk/types.py | 19 +-
task_sdk/tests/defintions/test_baseoperator.py | 305 ++++++++++++++++++
task_sdk/tests/defintions/test_dag.py | 96 +++++-
tests/models/test_baseoperator.py | 253 +--------------
tests/models/test_dag.py | 143 +--------
uv.lock | 11 +
13 files changed, 768 insertions(+), 565 deletions(-)
diff --git a/airflow/task/priority_strategy.py b/airflow/task/priority_strategy.py
index c22bdfa994..dcef1c865b 100644
--- a/airflow/task/priority_strategy.py
+++ b/airflow/task/priority_strategy.py
@@ -22,8 +22,6 @@ from __future__ import annotations
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any
-from airflow.exceptions import AirflowException
-
if TYPE_CHECKING:
from airflow.models.taskinstance import TaskInstance
@@ -150,5 +148,5 @@ def validate_and_load_priority_weight_strategy(
priority_weight_strategy_class = qualname(priority_weight_strategy)
    loaded_priority_weight_strategy = _get_registered_priority_weight_strategy(priority_weight_strategy_class)
    if loaded_priority_weight_strategy is None:
-        raise AirflowException(f"Unknown priority strategy {priority_weight_strategy_class}")
+        raise ValueError(f"Unknown priority strategy {priority_weight_strategy_class}")
return loaded_priority_weight_strategy()
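The hunk above also changes the failure mode: resolving an unknown strategy now raises a plain ValueError instead of AirflowException, which is what lets the airflow.exceptions import go away. A minimal sketch of the new contract (assuming the built-in "downstream"/"upstream"/"absolute" strategy names stay registered; the unknown name below is purely illustrative):

    from airflow.task.priority_strategy import validate_and_load_priority_weight_strategy

    strategy = validate_and_load_priority_weight_strategy("downstream")  # built-in name resolves fine
    try:
        validate_and_load_priority_weight_strategy("no.such.Strategy")  # hypothetical unknown name
    except ValueError as err:
        print(err)  # Unknown priority strategy no.such.Strategy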
diff --git a/dev/tests_common/test_utils/mock_operators.py b/dev/tests_common/test_utils/mock_operators.py
index 0df0afec82..ecf8989f4b 100644
--- a/dev/tests_common/test_utils/mock_operators.py
+++ b/dev/tests_common/test_utils/mock_operators.py
@@ -16,7 +16,6 @@
# under the License.
from __future__ import annotations
-import warnings
from typing import TYPE_CHECKING, Any, Sequence
import attr
@@ -200,12 +199,3 @@ class GithubLink(BaseOperatorLink):
def get_link(self, operator, *, ti_key):
return "https://github.com/apache/airflow"
-
-
-class DeprecatedOperator(BaseOperator):
- def __init__(self, **kwargs):
-        warnings.warn("This operator is deprecated.", DeprecationWarning, stacklevel=2)
- super().__init__(**kwargs)
-
- def execute(self, context: Context):
- pass
diff --git a/task_sdk/pyproject.toml b/task_sdk/pyproject.toml
index f83a4b7ec2..f371c3744a 100644
--- a/task_sdk/pyproject.toml
+++ b/task_sdk/pyproject.toml
@@ -46,7 +46,26 @@ namespace-packages = ["src/airflow"]
[tool.uv]
dev-dependencies = [
+ "kgb>=7.1.1",
"pytest-asyncio>=0.24.0",
"pytest-mock>=3.14.0",
"pytest>=8.3.3",
]
+
+[tool.coverage.run]
+branch = true
+relative_files = true
+source = ["src/airflow"]
+include_namespace_packages = true
+
+[tool.coverage.report]
+skip_empty = true
+exclude_also = [
+ "def __repr__",
+ "raise AssertionError",
+ "raise NotImplementedError",
+ "if __name__ == .__main__.:",
+ "@(abc\\.)?abstractmethod",
+ "@(typing(_extensions)?\\.)?overload",
+ "if (typing(_extensions)?\\.)?TYPE_CHECKING:",
+]
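
kgb, newly added to the dev-dependencies above, supplies function "spies" for the Task SDK tests (see the spy_agency fixture used in test_baseoperator.py further down). A minimal sketch of kgb's documented context-manager form; the greet function here is purely illustrative:

    import kgb

    def greet():
        return "hello"

    # Spy without calling through to the original implementation.
    with kgb.spy_on(greet, call_original=False):
        greet()
        assert greet.called  # kgb records calls on the spied function
    assert greet() == "hello"  # original behaviour restored outside the spy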
diff --git a/task_sdk/src/airflow/sdk/definitions/abstractoperator.py b/task_sdk/src/airflow/sdk/definitions/abstractoperator.py
index d37a478df3..2dd0d48804 100644
--- a/task_sdk/src/airflow/sdk/definitions/abstractoperator.py
+++ b/task_sdk/src/airflow/sdk/definitions/abstractoperator.py
@@ -50,12 +50,9 @@ DEFAULT_RETRIES: int = 0
DEFAULT_RETRY_DELAY: datetime.timedelta = datetime.timedelta(300)
MAX_RETRY_DELAY: int = 24 * 60 * 60
-# DEFAULT_WEIGHT_RULE: WeightRule = WeightRule(
-#     conf.get("core", "default_task_weight_rule", fallback=WeightRule.DOWNSTREAM)
-# )
# DEFAULT_TRIGGER_RULE: TriggerRule = TriggerRule.ALL_SUCCESS
-DEFAULT_WEIGHT_RULE = 0
DEFAULT_TRIGGER_RULE = "ALL_SUCCESS" # TriggerRule.ALL_SUCCESS
+DEFAULT_WEIGHT_RULE = "downstream"
DEFAULT_TASK_EXECUTION_TIMEOUT: datetime.timedelta | None = None
diff --git a/task_sdk/src/airflow/sdk/definitions/baseoperator.py b/task_sdk/src/airflow/sdk/definitions/baseoperator.py
index d52802db21..6480a89780 100644
--- a/task_sdk/src/airflow/sdk/definitions/baseoperator.py
+++ b/task_sdk/src/airflow/sdk/definitions/baseoperator.py
@@ -22,8 +22,9 @@ import collections.abc
import contextlib
import copy
import inspect
-from collections.abc import Sequence
+from collections.abc import Iterable, Sequence
from dataclasses import dataclass
+from datetime import datetime, timedelta
from functools import total_ordering, wraps
from types import FunctionType
from typing import TYPE_CHECKING, Any, ClassVar, TypeVar, cast
@@ -44,6 +45,9 @@ from airflow.sdk.definitions.abstractoperator import (
)
from airflow.sdk.definitions.decorators import fixup_decorator_warning_stack
from airflow.sdk.definitions.node import validate_key
+from airflow.sdk.types import NOTSET, validate_instance_args
+from airflow.task.priority_strategy import PriorityWeightStrategy, validate_and_load_priority_weight_strategy
+from airflow.utils.types import AttributeRemoved
T = TypeVar("T", bound=FunctionType)
@@ -59,7 +63,7 @@ if TYPE_CHECKING:
# TODO: Task-SDK
AirflowException = RuntimeError
-ParamsDict = object
+ParamsDict = dict
def _get_parent_defaults(dag: DAG | None, task_group: TaskGroup | None) -> tuple[dict, ParamsDict]:
@@ -83,11 +87,11 @@ def get_merged_defaults(
args, params = _get_parent_defaults(dag, task_group)
if task_params:
if not isinstance(task_params, collections.abc.Mapping):
- raise TypeError("params must be a mapping")
+            raise TypeError(f"params must be a mapping, got {type(task_params)}")
params.update(task_params)
if task_default_args:
if not isinstance(task_default_args, collections.abc.Mapping):
- raise TypeError("default_args must be a mapping")
+            raise TypeError(f"default_args must be a mapping, got {type(task_default_args)}")
args.update(task_default_args)
with contextlib.suppress(KeyError):
params.update(task_default_args["params"] or {})
@@ -130,7 +134,7 @@ class BaseOperatorMeta(abc.ABCMeta):
        from airflow.sdk.definitions.contextmanager import DagContext, TaskGroupContext
        if args:
-            raise AirflowException("Use keyword arguments when initializing operators")
+            raise TypeError("Use keyword arguments when initializing operators")
instantiated_from_mapped = kwargs.pop(
"_airflow_from_mapped",
@@ -155,10 +159,10 @@ class BaseOperatorMeta(abc.ABCMeta):
missing_args = non_optional_args.difference(kwargs)
if len(missing_args) == 1:
-            raise AirflowException(f"missing keyword argument {missing_args.pop()!r}")
+            raise TypeError(f"missing keyword argument {missing_args.pop()!r}")
elif missing_args:
display = ", ".join(repr(a) for a in sorted(missing_args))
- raise AirflowException(f"missing keyword arguments {display}")
+ raise TypeError(f"missing keyword arguments {display}")
if merged_params:
kwargs["params"] = merged_params
@@ -169,8 +173,8 @@ class BaseOperatorMeta(abc.ABCMeta):
default_args = kwargs.pop("default_args", {})
if not hasattr(self, "_BaseOperator__init_kwargs"):
- self._BaseOperator__init_kwargs = {}
- self._BaseOperator__from_mapped = instantiated_from_mapped
+ object.__setattr__(self, "_BaseOperator__init_kwargs", {})
+            object.__setattr__(self, "_BaseOperator__from_mapped", instantiated_from_mapped)
result = func(self, **kwargs, default_args=default_args)
@@ -180,9 +184,9 @@ class BaseOperatorMeta(abc.ABCMeta):
        # Set upstream task defined by XComArgs passed to template fields of the operator.
        # BUT: only do this _ONCE_, not once for each class in the hierarchy
        if not instantiated_from_mapped and func == self.__init__.__wrapped__:  # type: ignore[misc]
- self.set_xcomargs_dependencies()
- # Mark instance as instantiated.
- self._BaseOperator__instantiated = True
+ self._set_xcomargs_dependencies()
+            # Mark instance as instantiated so that future attr setting updates xcomarg-based deps.
+ object.__setattr__(self, "_BaseOperator__instantiated", True)
return result
@@ -213,11 +217,60 @@ class BaseOperatorMeta(abc.ABCMeta):
return new_cls
+# TODO: The following mapping is used to validate that the arguments passed to the BaseOperator are of the
+# correct type. This is a temporary solution until we find a more sophisticated method for argument
+# validation. One potential method is to use `get_type_hints` from the typing module. However, this is not
+# fully compatible with future annotations for Python versions below 3.10. Once we require a minimum Python
+# version that supports `get_type_hints` effectively or find a better approach, we can replace this
+# manual type-checking method.
+BASEOPERATOR_ARGS_EXPECTED_TYPES = {
+ "task_id": str,
+ "email": (str, Sequence),
+ "email_on_retry": bool,
+ "email_on_failure": bool,
+ "retries": int,
+ "retry_exponential_backoff": bool,
+ "depends_on_past": bool,
+ "ignore_first_depends_on_past": bool,
+ "wait_for_past_depends_before_skipping": bool,
+ "wait_for_downstream": bool,
+ "priority_weight": int,
+ "queue": str,
+ "pool": str,
+ "pool_slots": int,
+ "trigger_rule": str,
+ "run_as_user": str,
+ "task_concurrency": int,
+ "map_index_template": str,
+ "max_active_tis_per_dag": int,
+ "max_active_tis_per_dagrun": int,
+ "executor": str,
+ "do_xcom_push": bool,
+ "multiple_outputs": bool,
+ "doc": str,
+ "doc_md": str,
+ "doc_json": str,
+ "doc_yaml": str,
+ "doc_rst": str,
+ "task_display_name": str,
+ "logger_name": str,
+ "allow_nested_operators": bool,
+ "start_date": datetime,
+ "end_date": datetime,
+}
+
+
+# Note: BaseOperator is defined as a dataclass, and not an attrs class, as we do too much metaprogramming in
+# here (metaclass, custom `__setattr__` behaviour) and this fights with attrs too much to make it worth it.
+#
+# To future reader: if you want to try and make this a "normal" attrs class, go ahead and attempt it. If you
+# get nowhere, leave your record here for the next poor soul of what problems you ran into.
+#
+# @ashb, 2024/10/14
+# - "Can't combine custom __setattr__ with on_setattr hooks"
+# - Setting class-wide `define(on_setattr=...)` isn't called for non-attrs subclasses
@total_ordering
-@dataclass(
- init=False,
- repr=False,
-)
+@dataclass(repr=False, kw_only=True)
class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta):
r"""
Abstract base class for all operators.
@@ -433,7 +486,59 @@ class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta):
hello_world_task.execute(context)
"""
- # Implementing Operator.
+ task_id: str
+ owner: str = DEFAULT_OWNER
+ email: str | Sequence[str] | None = None
+ retries: int | None = DEFAULT_RETRIES
+ retry_delay: timedelta | float = DEFAULT_RETRY_DELAY
+ retry_exponential_backoff: bool = False
+ max_retry_delay: timedelta | float | None = None
+ start_date: datetime | None = None
+ end_date: datetime | None = None
+ depends_on_past: bool = False
+ ignore_first_depends_on_past: bool = DEFAULT_IGNORE_FIRST_DEPENDS_ON_PAST
+    wait_for_past_depends_before_skipping: bool = DEFAULT_WAIT_FOR_PAST_DEPENDS_BEFORE_SKIPPING
+ wait_for_downstream: bool = False
+ dag: DAG | None = None
+ params: MutableMapping | None = None
+ default_args: dict | None = None
+ priority_weight: int = DEFAULT_PRIORITY_WEIGHT
+ # TODO:
+ weight_rule: PriorityWeightStrategy | str = DEFAULT_WEIGHT_RULE
+ queue: str = DEFAULT_QUEUE
+ pool: str = "default"
+ pool_slots: int = DEFAULT_POOL_SLOTS
+ execution_timeout: timedelta | None = DEFAULT_TASK_EXECUTION_TIMEOUT
+    # on_execute_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = None
+    # on_failure_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = None
+    # on_success_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = None
+    # on_retry_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = None
+    # on_skipped_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = None
+ # pre_execute: TaskPreExecuteHook | None = None
+ # post_execute: TaskPostExecuteHook | None = None
+ trigger_rule: str = DEFAULT_TRIGGER_RULE
+ resources: dict[str, Any] | None = None
+ run_as_user: str | None = None
+ task_concurrency: int | None = None
+ map_index_template: str | None = None
+ max_active_tis_per_dag: int | None = None
+ max_active_tis_per_dagrun: int | None = None
+ executor: str | None = None
+ executor_config: dict | None = None
+ do_xcom_push: bool = True
+ multiple_outputs: bool = False
+ inlets: Any | None = None
+ outlets: Any | None = None
+ task_group: TaskGroup | None = None
+ doc: str | None = None
+ doc_md: str | None = None
+ doc_json: str | None = None
+ doc_yaml: str | None = None
+ doc_rst: str | None = None
+ _task_display_name: str
+ logger_name: str | None = None
+ allow_nested_operators: bool = True
+
template_fields: ClassVar[Sequence[str]] = ()
template_ext: ClassVar[Sequence[str]] = ()
@@ -443,12 +548,10 @@ class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta):
ui_color: str = "#fff"
ui_fgcolor: str = "#000"
- pool: str = ""
-
- # TODO: Mapping
+ # TODO: Task-SDK Mapping
    # partial: Callable[..., OperatorPartial] = _PartialDescriptor()  # type: ignore
- _comps = {
+ _comps: ClassVar[set[str]] = {
"task_id",
"dag_id",
"owner",
@@ -476,29 +579,45 @@ class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta):
}
# Defines if the operator supports lineage without manual definitions
- supports_lineage = False
+ supports_lineage: bool = False
# If True then the class constructor was called
- __instantiated = False
+ __instantiated: bool = False
    # List of args as passed to `init()`, after apply_defaults() has been updated. Used to "recreate" the task
# when mapping
__init_kwargs: dict[str, Any]
# Set to True before calling execute method
- _lock_for_execution = False
+ _lock_for_execution: bool = False
# Set to True for an operator instantiated by a mapped operator.
- __from_mapped = False
+ __from_mapped: bool = False
# TODO:
# start_trigger_args: StartTriggerArgs | None = None
# start_from_trigger: bool = False
+ def __setattr__(self: BaseOperator, key: str, value: Any):
+ if converter := getattr(self, f"_convert_{key}", None):
+ value = converter(value)
+ super().__setattr__(key, value)
+ if self.__from_mapped or self._lock_for_execution:
+            return  # Skip any custom behavior for validation and during execute.
+ if key in self.__init_kwargs:
+ self.__init_kwargs[key] = value
+ if self.__instantiated and key in self.template_fields:
+ # Resolve upstreams set by assigning an XComArg after initializing
+ # an operator, example:
+ # op = BashOperator()
+ # op.bash_command = "sleep 1"
+ self._set_xcomargs_dependency(key, value)
+
def __init__(
self,
+ *,
task_id: str,
owner: str = DEFAULT_OWNER,
- email: str | Iterable[str] | None = None,
+ email: str | Sequence[str] | None = None,
retries: int | None = DEFAULT_RETRIES,
retry_delay: timedelta | float = DEFAULT_RETRY_DELAY,
retry_exponential_backoff: bool = False,
@@ -513,9 +632,7 @@ class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta):
params: MutableMapping | None = None,
default_args: dict | None = None,
priority_weight: int = DEFAULT_PRIORITY_WEIGHT,
- # TODO:
- # weight_rule: str | PriorityWeightStrategy = DEFAULT_WEIGHT_RULE,
- weight_rule: str = DEFAULT_WEIGHT_RULE,
+ weight_rule: str | PriorityWeightStrategy = DEFAULT_WEIGHT_RULE,
queue: str = DEFAULT_QUEUE,
pool: str | None = None,
pool_slots: int = DEFAULT_POOL_SLOTS,
@@ -531,7 +648,6 @@ class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta):
trigger_rule: str = DEFAULT_TRIGGER_RULE,
resources: dict[str, Any] | None = None,
run_as_user: str | None = None,
- task_concurrency: int | None = None,
map_index_template: str | None = None,
max_active_tis_per_dag: int | None = None,
max_active_tis_per_dagrun: int | None = None,
@@ -554,25 +670,23 @@ class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta):
):
        from airflow.sdk.definitions.contextmanager import DagContext, TaskGroupContext
- self.__init_kwargs = {}
+ self.task_id = task_group.child_id(task_id) if task_group else task_id
+ if not self.__from_mapped and task_group:
+ task_group.add(self)
+
+ dag = dag or DagContext.get_current()
+ task_group = task_group or TaskGroupContext.get_current(dag)
- super().__init__()
+ super().__init__(dag=dag, task_group=task_group)
kwargs.pop("_airflow_mapped_validation_only", None)
if kwargs:
- raise RuntimeError(
+ raise TypeError(
f"Invalid arguments were passed to {self.__class__.__name__}
(task_id: {task_id}). "
+ f"Invalid arguments were:\n**kwargs: {kwargs}",
)
validate_key(task_id)
- dag = dag or DagContext.get_current()
- task_group = task_group or TaskGroupContext.get_current(dag)
-
- self.task_id = task_group.child_id(task_id) if task_group else task_id
- if not self.__from_mapped and task_group:
- task_group.add(self)
-
self.owner = owner
self.email = email
@@ -591,9 +705,7 @@ class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta):
# self._pre_execute_hook = pre_execute
# self._post_execute_hook = post_execute
- if start_date and not isinstance(start_date, datetime):
- self.log.warning("start_date for %s isn't datetime.datetime", self)
- elif start_date:
+ if start_date:
self.start_date = timezone.convert_to_utc(start_date)
if end_date:
@@ -610,7 +722,7 @@ class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta):
self.run_as_user = run_as_user
# TODO:
# self.retries = parse_retries(retries)
- self.retries = int(retries)
+ self.retries = retries
self.queue = queue
self.pool = "default" if pool is None else pool
self.pool_slots = pool_slots
@@ -620,15 +732,6 @@ class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta):
self.sla = sla
"""
- if trigger_rule == "none_failed_or_skipped":
- warnings.warn(
- "none_failed_or_skipped Trigger Rule is deprecated. "
- "Please use `none_failed_min_one_success`.",
- RemovedInAirflow3Warning,
- stacklevel=2,
- )
- trigger_rule = TriggerRule.NONE_FAILED_MIN_ONE_SUCCESS
-
# if not TriggerRule.is_valid(trigger_rule):
# raise AirflowException(
        #         f"The trigger_rule must be one of {TriggerRule.all_triggers()},"
@@ -637,6 +740,7 @@ class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta):
#
# self.trigger_rule: TriggerRule = TriggerRule(trigger_rule)
        # FailStopDagInvalidTriggerRule.check(dag=dag, trigger_rule=self.trigger_rule)
+ """
self.depends_on_past: bool = depends_on_past
self.ignore_first_depends_on_past: bool = ignore_first_depends_on_past
@@ -645,34 +749,21 @@ class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta):
if wait_for_downstream:
self.depends_on_past = True
- self.retry_delay = coerce_timedelta(retry_delay, key="retry_delay")
+ self.retry_delay = retry_delay
self.retry_exponential_backoff = retry_exponential_backoff
- self.max_retry_delay = (
- max_retry_delay
- if max_retry_delay is None
- else coerce_timedelta(max_retry_delay, key="max_retry_delay")
- )
+ if max_retry_delay is not None:
+ self.max_retry_delay = max_retry_delay
+
+ self.resources = resources
+ """
# At execution_time this becomes a normal dict
self.params: ParamsDict | dict = ParamsDict(params)
-        if priority_weight is not None and not isinstance(priority_weight, int):
-            raise AirflowException(
-                f"`priority_weight` for task '{self.task_id}' only accepts integers, "
- f"received '{type(priority_weight)}'."
- )
+ """
+
self.priority_weight = priority_weight
        self.weight_rule = validate_and_load_priority_weight_strategy(weight_rule)
- self.resources = coerce_resources(resources)
- if task_concurrency and not max_active_tis_per_dag:
- # TODO: Remove in Airflow 3.0
- warnings.warn(
-                "The 'task_concurrency' parameter is deprecated. Please use 'max_active_tis_per_dag'.",
- RemovedInAirflow3Warning,
- stacklevel=2,
- )
- max_active_tis_per_dag = task_concurrency
- """
self.max_active_tis_per_dag: int | None = max_active_tis_per_dag
self.max_active_tis_per_dagrun: int | None = max_active_tis_per_dagrun
self.do_xcom_push: bool = do_xcom_push
@@ -684,22 +775,18 @@ class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta):
self.doc_yaml = doc_yaml
self.doc_rst = doc_rst
self.doc = doc
-        # Populate the display field only if provided and different from task id
-        self._task_display_property_value = (
-            task_display_name if task_display_name and task_display_name != task_id else None
-        )
- if dag:
- self.dag = dag
+ self._task_display_name = task_display_name
+
+ self.allow_nested_operators = allow_nested_operators
+ self.inlets: list = []
+ self.outlets: list = []
"""
self._log_config_logger_name = "airflow.task.operators"
self._logger_name = logger_name
- self.allow_nested_operators: bool = allow_nested_operators
# Lineage
- self.inlets: list = []
- self.outlets: list = []
if inlets:
self.inlets = (
@@ -718,6 +805,7 @@ class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta):
outlets,
]
)
+ """
if isinstance(self.template_fields, str):
warnings.warn(
@@ -731,11 +819,11 @@ class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta):
self._is_setup = False
self._is_teardown = False
- if SetupTeardownContext.active:
- SetupTeardownContext.update_context_map(self)
+ # TODO: Task-SDK
+ # if SetupTeardownContext.active:
+ # SetupTeardownContext.update_context_map(self)
validate_instance_args(self, BASEOPERATOR_ARGS_EXPECTED_TYPES)
- """
def __eq__(self, other):
if type(self) is type(other):
@@ -810,19 +898,6 @@ class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta):
return self
- def __setattr__(self, key, value):
- super().__setattr__(key, value)
- if self.__from_mapped or self._lock_for_execution:
-            return  # Skip any custom behavior for validation and during execute.
- if key in self.__init_kwargs:
- self.__init_kwargs[key] = value
- if self.__instantiated and key in self.template_fields:
- # Resolve upstreams set by assigning an XComArg after initializing
- # an operator, example:
- # op = BashOperator()
- # op.bash_command = "sleep 1"
- self.set_xcomargs_dependencies()
-
def add_inlets(self, inlets: Iterable[Any]):
"""Set inlets to this operator."""
self.inlets.extend(inlets)
@@ -832,45 +907,77 @@ class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta):
self.outlets.extend(outlets)
def get_dag(self) -> DAG | None:
- return self._dag
+ return self.dag
- @property # type: ignore[override]
- def dag(self) -> DAG: # type: ignore[override]
- """Returns the Operator's DAG if set, otherwise raises an error."""
- if self._dag:
- return self._dag
- else:
-            raise RuntimeError(f"Operator {self} has not been assigned to a DAG yet")
-
- @dag.setter
- def dag(self, dag: DAG | None):
+    def _convert_dag(self, dag: DAG | None | AttributeRemoved) -> DAG | None | AttributeRemoved:
"""Operators can be assigned to one DAG, one time. Repeat assignments
to that same DAG are ok."""
- from .dag import DAG
+ from airflow.sdk.definitions.dag import DAG
if dag is None:
- self._dag = None
- return
+ return dag
+
+ # if set to removed, then just set and exit
+ if self.dag.__class__ is AttributeRemoved:
+ return dag
+ # if setting to removed, then just set and exit
+ if dag.__class__ is AttributeRemoved:
+ return AttributeRemoved("_dag") # type: ignore[assignment]
+
if not isinstance(dag, DAG):
raise TypeError(f"Expected DAG; received {dag.__class__.__name__}")
- elif self.has_dag() and self.dag is not dag:
+ elif self.dag is not None and self.dag is not dag:
raise ValueError(f"The DAG assigned to {self} can not be changed.")
if self.__from_mapped:
pass # Don't add to DAG -- the mapped task takes the place.
elif dag.task_dict.get(self.task_id) is not self:
dag.add_task(self)
-
- self._dag = dag
+ return dag
+
+ @staticmethod
+ def _convert_retries(retries: Any) -> int | None:
+ if retries is None:
+ return 0
+ elif type(retries) == int: # noqa: E721
+ return retries
+ try:
+ parsed_retries = int(retries)
+ except (TypeError, ValueError):
+            raise TypeError(f"'retries' type must be int, not {type(retries).__name__}")
+ return parsed_retries
+
+ @staticmethod
+ def _convert_timedelta(value: float | timedelta) -> timedelta:
+ if isinstance(value, timedelta):
+ return value
+ return timedelta(seconds=value)
+
+ _convert_retry_delay = _convert_timedelta
+ _convert_max_retry_delay = _convert_timedelta
+
+ @staticmethod
+    def _convert_resources(resources: dict[str, Any] | None) -> Resources | None:
+ if resources is None:
+ return None
+ return Resources(**resources)
@property
def task_display_name(self) -> str:
- return self._task_display_property_value or self.task_id
+ return self._task_display_name or self.task_id
def has_dag(self):
"""Return True if the Operator has been assigned to a DAG."""
- return self._dag is not None
+ return self.dag is not None
+
+ def _set_xcomargs_dependencies(self) -> None:
+ from airflow.models.xcom_arg import XComArg
+
+ for field in self.template_fields:
+ arg = getattr(self, field, NOTSET)
+ if arg is not NOTSET:
+ XComArg.apply_upstream_relationship(self, arg)
- def set_xcomargs_dependencies(self) -> None:
+ def _set_xcomargs_dependency(self, field: str, newvalue: Any) -> None:
"""
Resolve upstream dependencies of a task.
@@ -892,10 +999,9 @@ class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta):
"""
from airflow.models.xcom_arg import XComArg
- for field in self.template_fields:
- if hasattr(self, field):
- arg = getattr(self, field)
- XComArg.apply_upstream_relationship(self, arg)
+ if field not in self.template_fields:
+ return
+ XComArg.apply_upstream_relationship(self, newvalue)
def on_kill(self) -> None:
"""
diff --git a/task_sdk/src/airflow/sdk/definitions/dag.py b/task_sdk/src/airflow/sdk/definitions/dag.py
index f80d8f7f71..5cf76da488 100644
--- a/task_sdk/src/airflow/sdk/definitions/dag.py
+++ b/task_sdk/src/airflow/sdk/definitions/dag.py
@@ -27,7 +27,7 @@ import sys
import weakref
from collections import abc
from collections.abc import Collection, Iterable, Iterator
-from datetime import datetime, timedelta
+from datetime import datetime, timedelta, timezone
from inspect import signature
from re import Pattern
from typing import (
@@ -51,7 +51,6 @@ from airflow import settings
from airflow.assets import Asset, AssetAlias, BaseAsset
from airflow.configuration import conf as airflow_conf
from airflow.exceptions import (
- AirflowException,
DuplicateTaskIdFound,
FailStopDagInvalidTriggerRule,
ParamValidationError,
@@ -62,6 +61,13 @@ from airflow.sdk.definitions.abstractoperator import AbstractOperator
from airflow.sdk.definitions.baseoperator import BaseOperator
from airflow.stats import Stats
from airflow.timetables.base import Timetable
+from airflow.timetables.interval import CronDataIntervalTimetable, DeltaDataIntervalTimetable
+from airflow.timetables.simple import (
+ AssetTriggeredTimetable,
+ ContinuousTimetable,
+ NullTimetable,
+ OnceTimetable,
+)
from airflow.utils.dag_cycle_tester import check_cycle
from airflow.utils.decorators import fixup_decorator_warning_stack
from airflow.utils.trigger_rule import TriggerRule
@@ -70,15 +76,17 @@ from airflow.utils.types import NOTSET, EdgeInfoType
if TYPE_CHECKING:
from airflow.decorators import TaskDecoratorCollection
from airflow.models.operator import Operator
- from airflow.utils.taskgroup import TaskGroup
+ from airflow.sdk.definitions.taskgroup import TaskGroup
log = logging.getLogger(__name__)
-DEFAULT_VIEW_PRESETS = ["grid", "graph", "duration", "gantt", "landing_times"]
-ORIENTATION_PRESETS = ["LR", "TB", "RL", "BT"]
-
TAG_MAX_LEN = 100
+__all__ = [
+ "DAG",
+ "dag",
+]
+
# TODO: Task-SDK
class Context: ...
@@ -135,7 +143,25 @@ DAG_ARGS_EXPECTED_TYPES = {
}
-@attrs.define
+def _create_timetable(interval: ScheduleInterval, timezone: Timezone | FixedTimezone) -> Timetable:
+ """Create a Timetable instance from a plain ``schedule`` value."""
+ if interval is None:
+ return NullTimetable()
+ if interval == "@once":
+ return OnceTimetable()
+ if interval == "@continuous":
+ return ContinuousTimetable()
+ if isinstance(interval, (timedelta, relativedelta)):
+ return DeltaDataIntervalTimetable(interval)
+ if isinstance(interval, str):
+ if airflow_conf.getboolean("scheduler", "create_cron_data_intervals"):
+ return CronDataIntervalTimetable(interval, timezone)
+ else:
+ return CronTriggerTimetable(interval, timezone=timezone)
+ raise ValueError(f"{interval!r} is not a valid schedule.")
+
+
+@attrs.define(kw_only=True)
class DAG:
"""
A dag (directed acyclic graph) is a collection of tasks with directional
dependencies.
@@ -266,11 +292,11 @@ class DAG:
# below in sync. (Search for 'def dag(' in this file.)
dag_id: str
description: str | None = None
- schedule: ScheduleArg = NOTSET
- schedule_interval: ScheduleIntervalArg = NOTSET
- timetable: Timetable | None = None
start_date: datetime | None = None
end_date: datetime | None = None
+ timezone: timezone = timezone.utc
+    schedule: ScheduleArg = attrs.field(default=None, on_setattr=attrs.setters.NO_OP)
+ timetable: Timetable = attrs.field(init=False)
full_filepath: str | None = None
template_searchpath: str | Iterable[str] | None = None
# template_undefined: type[jinja2.StrictUndefined] = jinja2.StrictUndefined
@@ -285,7 +311,7 @@ class DAG:
    # on_success_callback: None | DagStateChangeCallback | list[DagStateChangeCallback] = None
    # on_failure_callback: None | DagStateChangeCallback | list[DagStateChangeCallback] = None
doc_md: str | None = None
- params: abc.MutableMapping | None = None
+ params: abc.MutableMapping | None = attrs.field(default=None)
access_control: dict | None = None
is_paused_upon_creation: bool | None = None
jinja_environment_kwargs: dict | None = None
@@ -310,6 +336,55 @@ class DAG:
return TaskGroup.create_root(dag=self)
+ @timetable.default
+ def _set_schedule(self):
+ schedule = self.schedule
+ delattr(self, "schedule")
+ if isinstance(schedule, Timetable):
+ return schedule
+ elif isinstance(schedule, BaseAsset):
+ return AssetTriggeredTimetable(schedule)
+        elif isinstance(schedule, Collection) and not isinstance(schedule, str):
+            if not all(isinstance(x, (Asset, AssetAlias)) for x in schedule):
+                raise ValueError("All elements in 'schedule' should be assets or asset aliases")
+ return AssetTriggeredTimetable(AssetAll(*schedule))
+ else:
+ return _create_timetable(schedule, self.timezone)
+
+ @params.validator
+ def _validate_params(self, attr, val: abc.MutableMapping | None):
+ """
+ Validate Param values when the DAG has schedule defined.
+
+        Raise exception if there are any Params which can not be resolved by their schema definition.
+
+ This will also merge in params from default_args
+ """
+ # TODO: Task-SDK
+ from airflow.models.param import ParamsDict
+
+ val = val or {}
+
+ # merging potentially conflicting default_args['params'] into params
+ if "params" in self.default_args:
+ val.update(self.default_args["params"])
+ del self.default_args["params"]
+
+ params = ParamsDict(val)
+ object.__setattr__(self, "params", params)
+ if not self.timetable or not self.timetable.can_be_scheduled:
+ return
+
+ try:
+ params.validate()
+ except ParamValidationError as pverr:
+ raise ValueError(
+ f"DAG {self.dag_id!r} is not allowed to define a Schedule, "
+                "as there are required params without default values, or the default values are not valid."
+ ) from pverr
+
+ # check self.params and convert them into ParamsDict
+
def __repr__(self):
return f"<DAG: {self.dag_id}>"
@@ -832,23 +907,6 @@ class DAG:
"""
        self.edge_info.setdefault(upstream_task_id, {})[downstream_task_id] = info
- def validate_schedule_and_params(self):
- """
- Validate Param values when the DAG has schedule defined.
-
-        Raise exception if there are any Params which can not be resolved by their schema definition.
- """
- if not self.timetable.can_be_scheduled:
- return
-
- try:
- self.params.validate()
- except ParamValidationError as pverr:
- raise AirflowException(
- "DAG is not allowed to define a Schedule, "
-                "if there are any required params without default values or default values are not valid."
- ) from pverr
-
def iter_invalid_owner_links(self) -> Iterator[tuple[str, str]]:
"""
        Parse a given link, and verifies if it's a valid URL, or a 'mailto' link.
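
With the attrs conversion above, timetable is now derived once from schedule via the @timetable.default hook instead of being passed in. A rough sketch of the mapping implied by _create_timetable, assuming the attrs-based DAG can be constructed standalone (the cron branch depends on the create_cron_data_intervals config option shown in the diff):

    from datetime import timedelta
    from airflow.sdk.definitions.dag import DAG

    DAG(dag_id="d1", schedule=None)               # -> NullTimetable
    DAG(dag_id="d2", schedule="@once")            # -> OnceTimetable
    DAG(dag_id="d3", schedule="@continuous")      # -> ContinuousTimetable
    DAG(dag_id="d4", schedule=timedelta(days=1))  # -> DeltaDataIntervalTimetable
    DAG(dag_id="d5", schedule="0 0 * * *")        # -> CronDataIntervalTimetable or CronTriggerTimetable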
diff --git a/task_sdk/src/airflow/sdk/definitions/node.py b/task_sdk/src/airflow/sdk/definitions/node.py
index 1d98028467..40236cac8b 100644
--- a/task_sdk/src/airflow/sdk/definitions/node.py
+++ b/task_sdk/src/airflow/sdk/definitions/node.py
@@ -108,8 +108,8 @@ class DAGNode(DependencyMixin, metaclass=ABCMeta):
edge_modifier: EdgeModifier | None = None,
) -> None:
"""Set relatives for the task or task list."""
- from airflow.models.baseoperator import BaseOperator
from airflow.models.mappedoperator import MappedOperator
+ from airflow.sdk.definitions.baseoperator import BaseOperator
if not isinstance(task_or_task_list, Sequence):
task_or_task_list = [task_or_task_list]
diff --git a/task_sdk/src/airflow/sdk/types.py b/task_sdk/src/airflow/sdk/types.py
index 505ee4cb19..a412509ba5 100644
--- a/task_sdk/src/airflow/sdk/types.py
+++ b/task_sdk/src/airflow/sdk/types.py
@@ -17,7 +17,7 @@
from __future__ import annotations
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Any
class ArgNotSet:
@@ -43,7 +43,24 @@ NOTSET = ArgNotSet()
if TYPE_CHECKING:
import logging
+ from airflow.sdk.definitions.node import DAGNode
+
Logger = logging.Logger
else:
class Logger: ...
+
+
+def validate_instance_args(instance: DAGNode, expected_arg_types: dict[str, Any]) -> None:
+ """Validate that the instance has the expected types for the arguments."""
+ from airflow.sdk.definitions.taskgroup import TaskGroup
+
+ typ = "task group" if isinstance(instance, TaskGroup) else "task"
+
+ for arg_name, expected_arg_type in expected_arg_types.items():
+ instance_arg_value = getattr(instance, arg_name, None)
+        if instance_arg_value is not None and not isinstance(instance_arg_value, expected_arg_type):
+            raise TypeError(
+                f"{arg_name!r} for {typ} {instance.node_id!r} expects {expected_arg_type}, got {type(instance_arg_value)} with value "
+ f"{instance_arg_value!r}"
+ )
diff --git a/task_sdk/tests/defintions/test_baseoperator.py b/task_sdk/tests/defintions/test_baseoperator.py
new file mode 100644
index 0000000000..0222de3ee3
--- /dev/null
+++ b/task_sdk/tests/defintions/test_baseoperator.py
@@ -0,0 +1,305 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+import warnings
+from collections.abc import Sequence
+from datetime import UTC, datetime, timedelta
+from typing import TYPE_CHECKING, Any
+
+import pytest
+
+from airflow.sdk.definitions.baseoperator import BaseOperator, BaseOperatorMeta
+from airflow.sdk.definitions.dag import DAG
+from airflow.task.priority_strategy import _DownstreamPriorityWeightStrategy, _UpstreamPriorityWeightStrategy
+
+if TYPE_CHECKING:
+    from airflow.utils.context import Context
+
+DEFAULT_DATE = datetime(2016, 1, 1, tzinfo=UTC)
+
+
+# Essentially similar to airflow.models.baseoperator.BaseOperator
+class FakeOperator(metaclass=BaseOperatorMeta):
+ def __init__(self, test_param, params=None, default_args=None):
+ self.test_param = test_param
+
+ def _set_xcomargs_dependencies(self): ...
+
+
+class FakeSubClass(FakeOperator):
+ def __init__(self, test_sub_param, test_param, **kwargs):
+ super().__init__(test_param=test_param, **kwargs)
+ self.test_sub_param = test_sub_param
+
+
+class DeprecatedOperator(BaseOperator):
+ def __init__(self, **kwargs):
+        warnings.warn("This operator is deprecated.", DeprecationWarning, stacklevel=2)
+ super().__init__(**kwargs)
+
+ def execute(self, context: Context):
+ pass
+
+
+class MockOperator(BaseOperator):
+ """Operator for testing purposes."""
+
+ template_fields: Sequence[str] = ("arg1", "arg2")
+
+ def __init__(self, arg1: str = "", arg2: str = "", **kwargs):
+ super().__init__(**kwargs)
+ self.arg1 = arg1
+ self.arg2 = arg2
+
+ def execute(self, context: Context):
+ pass
+
+
+class TestBaseOperator:
+    # Since we have a custom metaclass, let's double check the behaviour of passing args in the wrong way
+    # (positional args etc.)
+ def test_kwargs_only(self):
+ with pytest.raises(TypeError, match="keyword arguments"):
+ BaseOperator("task_id")
+
+ def test_missing_kwarg(self):
+ with pytest.raises(TypeError, match="missing keyword argument"):
+ FakeOperator(task_id="task_id")
+
+ def test_missing_kwargs(self):
+ with pytest.raises(TypeError, match="missing keyword arguments"):
+ FakeSubClass(task_id="task_id")
+
+ def test_hash(self):
+        """Two operators created equally should hash equally"""
+ # Include a "non-hashable" type too
+        assert hash(MockOperator(task_id="one", retries=1024 * 1024, arg1="abcef", params={"a": 1})) == hash(
+            MockOperator(task_id="one", retries=1024 * 1024, arg1="abcef", params={"a": 2})
+        )
+
+ def test_expand(self):
+ op = FakeOperator(test_param=True)
+ assert op.test_param
+
+        with pytest.raises(TypeError, match="missing keyword argument 'test_param'"):
+ FakeSubClass(test_sub_param=True)
+
+ def test_default_args(self):
+ default_args = {"test_param": True}
+ op = FakeOperator(default_args=default_args)
+ assert op.test_param
+
+ default_args = {"test_param": True, "test_sub_param": True}
+ op = FakeSubClass(default_args=default_args)
+ assert op.test_param
+ assert op.test_sub_param
+
+ default_args = {"test_param": True}
+ op = FakeSubClass(default_args=default_args, test_sub_param=True)
+ assert op.test_param
+ assert op.test_sub_param
+
+        with pytest.raises(TypeError, match="missing keyword argument 'test_sub_param'"):
+ FakeSubClass(default_args=default_args)
+
+ def test_execution_timeout_type(self):
+ with pytest.raises(
+            ValueError, match="execution_timeout must be timedelta object but passed as type: <class 'str'>"
+ ):
+ BaseOperator(task_id="test", execution_timeout="1")
+
+ with pytest.raises(
+            ValueError, match="execution_timeout must be timedelta object but passed as type: <class 'int'>"
+ ):
+ BaseOperator(task_id="test", execution_timeout=1)
+
+ def test_incorrect_default_args(self):
+ default_args = {"test_param": True, "extra_param": True}
+ op = FakeOperator(default_args=default_args)
+ assert op.test_param
+
+ default_args = {"random_params": True}
+        with pytest.raises(TypeError, match="missing keyword argument 'test_param'"):
+ FakeOperator(default_args=default_args)
+
+ def test_incorrect_priority_weight(self):
+        error_msg = "'priority_weight' for task 'test_op' expects <class 'int'>, got <class 'str'>"
+ with pytest.raises(TypeError, match=error_msg):
+ BaseOperator(task_id="test_op", priority_weight="2")
+
+ def test_illegal_args_forbidden(self):
+ """
+ Tests that operators raise exceptions on illegal arguments when
+ illegal arguments are not allowed.
+ """
+        msg = r"Invalid arguments were passed to BaseOperator \(task_id: test_illegal_args\)"
+ with pytest.raises(TypeError, match=msg):
+ BaseOperator(
+ task_id="test_illegal_args",
+ illegal_argument_1234="hello?",
+ )
+
+ def test_invalid_type_for_default_arg(self):
+        error_msg = "'max_active_tis_per_dag' for task 'test' expects <class 'int'>, got <class 'str'> with value 'not_an_int'"
+ with pytest.raises(TypeError, match=error_msg):
+            BaseOperator(task_id="test", default_args={"max_active_tis_per_dag": "not_an_int"})
+
+ def test_invalid_type_for_operator_arg(self):
+        error_msg = "'max_active_tis_per_dag' for task 'test' expects <class 'int'>, got <class 'str'> with value 'not_an_int'"
+ with pytest.raises(TypeError, match=error_msg):
+ BaseOperator(task_id="test", max_active_tis_per_dag="not_an_int")
+
+ def test_weight_rule_default(self):
+ op = BaseOperator(task_id="test_task")
+ assert _DownstreamPriorityWeightStrategy() == op.weight_rule
+
+ def test_weight_rule_override(self):
+ op = BaseOperator(task_id="test_task", weight_rule="upstream")
+ assert _UpstreamPriorityWeightStrategy() == op.weight_rule
+
+ def test_warnings_are_properly_propagated(self):
+ with pytest.warns(DeprecationWarning) as warnings:
+ DeprecatedOperator(task_id="test")
+ assert len(warnings) == 1
+ warning = warnings[0]
+ # Here we check that the trace points to the place
+ # where the deprecated class was used
+ assert warning.filename == __file__
+
+    def test_setattr_performs_no_custom_action_at_execute_time(self, spy_agency):
+ from airflow.models.xcom_arg import XComArg
+
+ op = MockOperator(task_id="test_task")
+ # TODO: Task-SDK
+ # op_copy = op.prepare_for_execution()
+ op_copy = op
+
+        spy_agency.spy_on(XComArg.apply_upstream_relationship, call_original=False)
+ op_copy.execute({})
+ assert XComArg.apply_upstream_relationship.called == False
+
+ def test_upstream_is_set_when_template_field_is_xcomarg(self):
+ with DAG("xcomargs_test", schedule=None):
+ op1 = BaseOperator(task_id="op1")
+ op2 = MockOperator(task_id="op2", arg1=op1.output)
+
+ assert op1.task_id in op2.upstream_task_ids
+ assert op2.task_id in op1.downstream_task_ids
+
+ def test_set_xcomargs_dependencies_works_recursively(self):
+ with DAG("xcomargs_test", schedule=None):
+ op1 = BaseOperator(task_id="op1")
+ op2 = BaseOperator(task_id="op2")
+ op3 = MockOperator(task_id="op3", arg1=[op1.output, op2.output])
+            op4 = MockOperator(task_id="op4", arg1={"op1": op1.output, "op2": op2.output})
+
+ assert op1.task_id in op3.upstream_task_ids
+ assert op2.task_id in op3.upstream_task_ids
+ assert op1.task_id in op4.upstream_task_ids
+ assert op2.task_id in op4.upstream_task_ids
+
+ def test_set_xcomargs_dependencies_works_when_set_after_init(self):
+ with DAG(dag_id="xcomargs_test", schedule=None):
+ op1 = BaseOperator(task_id="op1")
+ op2 = MockOperator(task_id="op2")
+ op2.arg1 = op1.output # value is set after init
+
+ assert op1.task_id in op2.upstream_task_ids
+
+ def test_set_xcomargs_dependencies_error_when_outside_dag(self):
+ op1 = BaseOperator(task_id="op1")
+ with pytest.raises(ValueError):
+ MockOperator(task_id="op2", arg1=op1.output)
+
+ def test_cannot_change_dag(self):
+ with DAG(dag_id="dag1", schedule=None):
+ op1 = BaseOperator(task_id="op1")
+ with pytest.raises(ValueError, match="can not be changed"):
+ op1.dag = DAG(dag_id="dag2")
+
+
+def test_init_subclass_args():
+ class InitSubclassOp(BaseOperator):
+ _class_arg: Any
+
+ def __init_subclass__(cls, class_arg=None, **kwargs) -> None:
+ cls._class_arg = class_arg
+ super().__init_subclass__()
+
+ def execute(self, context: Context):
+ self.context_arg = context
+
+ class_arg = "foo"
+ context = {"key": "value"}
+
+ class ConcreteSubclassOp(InitSubclassOp, class_arg=class_arg):
+ pass
+
+ task = ConcreteSubclassOp(task_id="op1")
+ # TODO: Task-SDK
+ # task_copy = task.prepare_for_execution()
+ task_copy = task
+
+ task_copy.execute(context)
+
+ assert task_copy._class_arg == class_arg
+ assert task_copy.context_arg == context
+
+
+class CustomInt(int):
+ def __int__(self):
+ raise ValueError("Cannot cast to int")
+
+
+@pytest.mark.parametrize(
+ ("retries", "expected"),
+ [
+        pytest.param("foo", "'retries' type must be int, not str", id="string"),
+        pytest.param(CustomInt(10), "'retries' type must be int, not CustomInt", id="custom int"),
+ ],
+)
+def test_operator_retries_invalid(dag_maker, retries, expected):
+ with pytest.raises(TypeError) as ctx:
+ BaseOperator(task_id="test_illegal_args", retries=retries)
+ assert str(ctx.value) == expected
+
+
+@pytest.mark.parametrize(
+ ("retries", "expected"),
+ [
+ pytest.param(None, 0, id="None"),
+ pytest.param("5", 5, id="str"),
+ pytest.param(1, 1, id="int"),
+ ],
+)
+def test_operator_retries_conversion(retries, expected):
+ op = BaseOperator(
+ task_id="test_illegal_args",
+ retries=retries,
+ )
+ assert op.retries == expected
+
+
+def test_dag_level_retry_delay(dag_maker):
+    with DAG(dag_id="test_dag_level_retry_delay", default_args={"retry_delay": timedelta(seconds=100)}):
+ task1 = BaseOperator(task_id="test_no_explicit_retry_delay")
+
+ assert task1.retry_delay == timedelta(seconds=100)
+
+
+def test_task_level_retry_delay():
+    with DAG(dag_id="test_task_level_retry_delay", default_args={"retry_delay": timedelta(seconds=100)}):
+        task1 = BaseOperator(task_id="test_no_explicit_retry_delay", retry_delay=200)
+
+ assert task1.retry_delay == timedelta(seconds=200)
diff --git a/task_sdk/tests/defintions/test_dag.py b/task_sdk/tests/defintions/test_dag.py
index e07e3a8bfa..29249745ff 100644
--- a/task_sdk/tests/defintions/test_dag.py
+++ b/task_sdk/tests/defintions/test_dag.py
@@ -14,11 +14,14 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
from __future__ import annotations
-from datetime import UTC, datetime
+from datetime import UTC, datetime, timedelta
+
+import pytest
+from airflow.exceptions import DuplicateTaskIdFound
+from airflow.models.param import Param, ParamsDict
from airflow.sdk.definitions.baseoperator import BaseOperator
from airflow.sdk.definitions.dag import DAG
@@ -76,3 +79,92 @@ class TestDag:
assert op7.dag == dag
assert op8.dag == dag
assert op9.dag == dag2
+
+ def test_params_not_passed_is_empty_dict(self):
+ """
+ Test that when 'params' is _not_ passed to a new Dag, that the params
+ attribute is set to an empty dictionary.
+ """
+ dag = DAG("test-dag", schedule=None)
+
+ assert isinstance(dag.params, ParamsDict)
+ assert 0 == len(dag.params)
+
+ def test_params_passed_and_params_in_default_args_no_override(self):
+ """
+ Test that when 'params' exists as a key passed to the default_args dict
+ in addition to params being passed explicitly as an argument to the
+ dag, that the 'params' key of the default_args dict is merged with the
+ dict of the params argument.
+ """
+ params1 = {"parameter1": 1}
+ params2 = {"parameter2": 2}
+
+        dag = DAG("test-dag", schedule=None, default_args={"params": params1}, params=params2)
+
+ assert params1["parameter1"] == dag.params["parameter1"]
+ assert params2["parameter2"] == dag.params["parameter2"]
+
+ def test_not_none_schedule_with_non_default_params(self):
+ """
+        Test that a DAG with a schedule and some params without default values raises an
+        error during DAG parsing. (Because we can't schedule it if we don't know what values to use.)
+ """
+ params = {"param1": Param(type="string")}
+
+ with pytest.raises(ValueError):
+            DAG("my-dag", schedule=timedelta(days=1), start_date=DEFAULT_DATE, params=params)
+
+ def test_roots(self):
+ """Verify if dag.roots returns the root tasks of a DAG."""
+ with DAG("test_dag", schedule=None, start_date=DEFAULT_DATE) as dag:
+ op1 = BaseOperator(task_id="t1")
+ op2 = BaseOperator(task_id="t2")
+ op3 = BaseOperator(task_id="t3")
+ op4 = BaseOperator(task_id="t4")
+ op5 = BaseOperator(task_id="t5")
+ [op1, op2] >> op3 >> [op4, op5]
+
+ assert set(dag.roots) == {op1, op2}
+
+ def test_leaves(self):
+ """Verify if dag.leaves returns the leaf tasks of a DAG."""
+ with DAG("test_dag", schedule=None, start_date=DEFAULT_DATE) as dag:
+ op1 = BaseOperator(task_id="t1")
+ op2 = BaseOperator(task_id="t2")
+ op3 = BaseOperator(task_id="t3")
+ op4 = BaseOperator(task_id="t4")
+ op5 = BaseOperator(task_id="t5")
+ [op1, op2] >> op3 >> [op4, op5]
+
+ assert set(dag.leaves) == {op4, op5}
+
+ def test_duplicate_task_ids_not_allowed_with_dag_context_manager(self):
+ """Verify tasks with Duplicate task_id raises error"""
+ with DAG("test_dag", schedule=None, start_date=DEFAULT_DATE) as dag:
+ op1 = BaseOperator(task_id="t1")
+            with pytest.raises(DuplicateTaskIdFound, match="Task id 't1' has already been added to the DAG"):
+ BaseOperator(task_id="t1")
+
+ assert dag.task_dict == {op1.task_id: op1}
+
+ def test_duplicate_task_ids_not_allowed_without_dag_context_manager(self):
+ """Verify tasks with Duplicate task_id raises error"""
+ dag = DAG("test_dag", schedule=None, start_date=DEFAULT_DATE)
+ op1 = BaseOperator(task_id="t1", dag=dag)
+        with pytest.raises(DuplicateTaskIdFound, match="Task id 't1' has already been added to the DAG"):
+ BaseOperator(task_id="t1", dag=dag)
+
+ assert dag.task_dict == {op1.task_id: op1}
+
+ def test_duplicate_task_ids_for_same_task_is_allowed(self):
+ """Verify that same tasks with Duplicate task_id do not raise error"""
+ with DAG("test_dag", schedule=None, start_date=DEFAULT_DATE) as dag:
+ op1 = op2 = BaseOperator(task_id="t1")
+ op3 = BaseOperator(task_id="t3")
+ op1 >> op3
+ op2 >> op3
+
+ assert op1 == op2
+ assert dag.task_dict == {op1.task_id: op1, op3.task_id: op3}
+ assert dag.task_dict == {op2.task_id: op2, op3.task_id: op3}
diff --git a/tests/models/test_baseoperator.py b/tests/models/test_baseoperator.py
index 999529e14a..eea5dae269 100644
--- a/tests/models/test_baseoperator.py
+++ b/tests/models/test_baseoperator.py
@@ -22,7 +22,7 @@ import logging
import uuid
from collections import defaultdict
from datetime import date, datetime, timedelta
-from typing import TYPE_CHECKING, Any, NamedTuple
+from typing import NamedTuple
from unittest import mock
import jinja2
@@ -32,9 +32,7 @@ from airflow.decorators import task as task_decorator
from airflow.exceptions import AirflowException, FailStopDagInvalidTriggerRule
from airflow.lineage.entities import File
from airflow.models.baseoperator import (
- BASEOPERATOR_ARGS_EXPECTED_TYPES,
BaseOperator,
- BaseOperatorMeta,
chain,
chain_linear,
cross_downstream,
@@ -43,7 +41,6 @@ from airflow.models.dag import DAG
from airflow.models.dagrun import DagRun
from airflow.models.taskinstance import TaskInstance
from airflow.providers.common.sql.operators import sql
-from airflow.task.priority_strategy import _DownstreamPriorityWeightStrategy, _UpstreamPriorityWeightStrategy
from airflow.utils.edgemodifier import Label
from airflow.utils.task_group import TaskGroup
from airflow.utils.template import literal
@@ -51,10 +48,7 @@ from airflow.utils.trigger_rule import TriggerRule
from airflow.utils.types import DagRunType
from tests.models import DEFAULT_DATE
-from dev.tests_common.test_utils.mock_operators import DeprecatedOperator, MockOperator
-
-if TYPE_CHECKING:
- from airflow.utils.context import Context
+from dev.tests_common.test_utils.mock_operators import MockOperator
class ClassWithCustomAttributes:
@@ -83,93 +77,12 @@ object2 = ClassWithCustomAttributes(attr="{{ foo }}_2", ref=object1, template_fi
setattr(object1, "ref", object2)
-# Essentially similar to airflow.models.baseoperator.BaseOperator
-class DummyClass(metaclass=BaseOperatorMeta):
- def __init__(self, test_param, params=None, default_args=None):
- self.test_param = test_param
-
- def set_xcomargs_dependencies(self): ...
-
-
-class DummySubClass(DummyClass):
- def __init__(self, test_sub_param, **kwargs):
- super().__init__(**kwargs)
- self.test_sub_param = test_sub_param
-
-
class MockNamedTuple(NamedTuple):
var1: str
var2: str
-class CustomInt(int):
- def __int__(self):
- raise ValueError("Cannot cast to int")
-
-
class TestBaseOperator:
- def test_expand(self):
- dummy = DummyClass(test_param=True)
- assert dummy.test_param
-
-        with pytest.raises(AirflowException, match="missing keyword argument 'test_param'"):
- DummySubClass(test_sub_param=True)
-
- def test_default_args(self):
- default_args = {"test_param": True}
- dummy_class = DummyClass(default_args=default_args)
- assert dummy_class.test_param
-
- default_args = {"test_param": True, "test_sub_param": True}
- dummy_subclass = DummySubClass(default_args=default_args)
- assert dummy_class.test_param
- assert dummy_subclass.test_sub_param
-
- default_args = {"test_param": True}
-        dummy_subclass = DummySubClass(default_args=default_args, test_sub_param=True)
- assert dummy_class.test_param
- assert dummy_subclass.test_sub_param
-
-        with pytest.raises(AirflowException, match="missing keyword argument 'test_sub_param'"):
- DummySubClass(default_args=default_args)
-
- def test_execution_timeout_type(self):
- with pytest.raises(
-            ValueError, match="execution_timeout must be timedelta object but passed as type: <class 'str'>"
- ):
- BaseOperator(task_id="test", execution_timeout="1")
-
- with pytest.raises(
-            ValueError, match="execution_timeout must be timedelta object but passed as type: <class 'int'>"
- ):
- BaseOperator(task_id="test", execution_timeout=1)
-
- def test_incorrect_default_args(self):
- default_args = {"test_param": True, "extra_param": True}
- dummy_class = DummyClass(default_args=default_args)
- assert dummy_class.test_param
-
- default_args = {"random_params": True}
-        with pytest.raises(AirflowException, match="missing keyword argument 'test_param'"):
- DummyClass(default_args=default_args)
-
- def test_incorrect_priority_weight(self):
-        error_msg = "`priority_weight` for task 'test_op' only accepts integers, received '<class 'str'>'."
- with pytest.raises(AirflowException, match=error_msg):
- BaseOperator(task_id="test_op", priority_weight="2")
-
- def test_illegal_args_forbidden(self):
- """
- Tests that operators raise exceptions on illegal arguments when
- illegal arguments are not allowed.
- """
-        msg = r"Invalid arguments were passed to BaseOperator \(task_id: test_illegal_args\)"
- with pytest.raises(AirflowException, match=msg):
- BaseOperator(
- task_id="test_illegal_args",
- illegal_argument_1234="hello?",
- )
-
def test_trigger_rule_validation(self):
from airflow.models.abstractoperator import DEFAULT_TRIGGER_RULE
@@ -659,15 +572,6 @@ class TestBaseOperator:
task4 > [inlet, outlet, extra]
assert task4.get_outlet_defs() == [inlet, outlet, extra]
- def test_warnings_are_properly_propagated(self):
- with pytest.warns(DeprecationWarning) as warnings:
- DeprecatedOperator(task_id="test")
- assert len(warnings) == 1
- warning = warnings[0]
- # Here we check that the trace points to the place
- # where the deprecated class was used
- assert warning.filename == __file__
-
def test_pre_execute_hook(self):
hook = mock.MagicMock()
@@ -694,47 +598,6 @@ class TestBaseOperator:
assert op_no_dag.start_date.tzinfo
assert op_no_dag.end_date.tzinfo
- def test_setattr_performs_no_custom_action_at_execute_time(self):
- op = MockOperator(task_id="test_task")
- op_copy = op.prepare_for_execution()
-
-        with mock.patch("airflow.models.baseoperator.BaseOperator.set_xcomargs_dependencies") as method_mock:
- op_copy.execute({})
- assert method_mock.call_count == 0
-
- def test_upstream_is_set_when_template_field_is_xcomarg(self):
-        with DAG("xcomargs_test", schedule=None, default_args={"start_date": datetime.today()}):
- op1 = BaseOperator(task_id="op1")
- op2 = MockOperator(task_id="op2", arg1=op1.output)
-
- assert op1 in op2.upstream_list
- assert op2 in op1.downstream_list
-
- def test_set_xcomargs_dependencies_works_recursively(self):
-        with DAG("xcomargs_test", schedule=None, default_args={"start_date": datetime.today()}):
- op1 = BaseOperator(task_id="op1")
- op2 = BaseOperator(task_id="op2")
- op3 = MockOperator(task_id="op3", arg1=[op1.output, op2.output])
-            op4 = MockOperator(task_id="op4", arg1={"op1": op1.output, "op2": op2.output})
-
- assert op1 in op3.upstream_list
- assert op2 in op3.upstream_list
- assert op1 in op4.upstream_list
- assert op2 in op4.upstream_list
-
- def test_set_xcomargs_dependencies_works_when_set_after_init(self):
-        with DAG(dag_id="xcomargs_test", schedule=None, default_args={"start_date": datetime.today()}):
- op1 = BaseOperator(task_id="op1")
- op2 = MockOperator(task_id="op2")
- op2.arg1 = op1.output # value is set after init
-
- assert op1 in op2.upstream_list
-
- def test_set_xcomargs_dependencies_error_when_outside_dag(self):
- op1 = BaseOperator(task_id="op1")
- with pytest.raises(AirflowException):
- MockOperator(task_id="op2", arg1=op1.output)
-
def test_invalid_trigger_rule(self):
with pytest.raises(
AirflowException,
@@ -745,14 +608,6 @@ class TestBaseOperator:
):
BaseOperator(task_id="op1", trigger_rule="some_rule")
- def test_weight_rule_default(self):
- op = BaseOperator(task_id="test_task")
- assert _DownstreamPriorityWeightStrategy() == op.weight_rule
-
- def test_weight_rule_override(self):
- op = BaseOperator(task_id="test_task", weight_rule="upstream")
- assert _UpstreamPriorityWeightStrategy() == op.weight_rule
-
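
The two weight_rule tests above assert only the default strategy and the
string override. Roughly, assuming current BaseOperator behavior (task ids
illustrative):

    from airflow.models.baseoperator import BaseOperator

    # Default: a task's priority weight is aggregated over its downstream tasks.
    default_op = BaseOperator(task_id="uses_default")

    # String aliases ("upstream", "downstream", "absolute") select a strategy.
    upstream_op = BaseOperator(task_id="uses_upstream", weight_rule="upstream")
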
# ensure the default logging config is used for this test, no matter what ran before
@pytest.mark.usefixtures("reset_logging_config")
def test_logging_propogated_by_default(self, caplog):
@@ -763,92 +618,6 @@ class TestBaseOperator:
# leaking a lot of state)
assert caplog.messages == ["test"]
- def test_invalid_type_for_default_arg(self):
- error_msg = "'max_active_tis_per_dag' has an invalid type <class
'str'> with value not_an_int, expected type is <class 'int'>"
- with pytest.raises(TypeError, match=error_msg):
- BaseOperator(task_id="test",
default_args={"max_active_tis_per_dag": "not_an_int"})
-
- def test_invalid_type_for_operator_arg(self):
- error_msg = "'max_active_tis_per_dag' has an invalid type <class
'str'> with value not_an_int, expected type is <class 'int'>"
- with pytest.raises(TypeError, match=error_msg):
- BaseOperator(task_id="test", max_active_tis_per_dag="not_an_int")
-
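
Both type-validation tests above reduce to the same contract: arguments listed
in BASEOPERATOR_ARGS_EXPECTED_TYPES are checked at construction time, whether
they arrive directly or through default_args. A minimal sketch (task id
illustrative):

    import pytest

    from airflow.models.baseoperator import BaseOperator

    with pytest.raises(TypeError):
        BaseOperator(task_id="typed", max_active_tis_per_dag="not_an_int")
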
- @mock.patch("airflow.models.baseoperator.validate_instance_args")
- def test_baseoperator_init_validates_arg_types(self, mock_validate_instance_args):
- operator = BaseOperator(task_id="test")
-
- mock_validate_instance_args.assert_called_once_with(operator, BASEOPERATOR_ARGS_EXPECTED_TYPES)
-
-
-def test_init_subclass_args():
- class InitSubclassOp(BaseOperator):
- _class_arg: Any
-
- def __init_subclass__(cls, class_arg=None, **kwargs) -> None:
- cls._class_arg = class_arg
- super().__init_subclass__()
-
- def execute(self, context: Context):
- self.context_arg = context
-
- class_arg = "foo"
- context = {"key": "value"}
-
- class ConcreteSubclassOp(InitSubclassOp, class_arg=class_arg):
- pass
-
- task = ConcreteSubclassOp(task_id="op1")
- task_copy = task.prepare_for_execution()
-
- task_copy.execute(context)
-
- assert task_copy._class_arg == class_arg
- assert task_copy.context_arg == context
-
-
[email protected]_test
[email protected](
- ("retries", "expected"),
- [
- pytest.param("foo", "'retries' type must be int, not str",
id="string"),
- pytest.param(CustomInt(10), "'retries' type must be int, not CustomInt", id="custom int"),
- ],
-)
-def test_operator_retries_invalid(dag_maker, retries, expected):
- with pytest.raises(AirflowException) as ctx:
- with dag_maker():
- BaseOperator(task_id="test_illegal_args", retries=retries)
- assert str(ctx.value) == expected
-
-
[email protected]_test
[email protected](
- ("retries", "expected"),
- [
- pytest.param(None, [], id="None"),
- pytest.param(5, [], id="5"),
- pytest.param(
- "1",
- [
- (
- "airflow.models.baseoperator.BaseOperator",
- logging.WARNING,
- "Implicitly converting 'retries' from '1' to int",
- ),
- ],
- id="str",
- ),
- ],
-)
-def test_operator_retries(caplog, dag_maker, retries, expected):
- with caplog.at_level(logging.WARNING):
- with dag_maker():
- BaseOperator(
- task_id="test_illegal_args",
- retries=retries,
- )
- assert caplog.record_tuples == expected
-
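
The parametrized retries tests encode two behaviors: values that cannot be
coerced to int fail fast with an exception, while numeric strings are
converted with a logged warning. Condensed sketch (assuming current Airflow
coercion behavior; task ids illustrative):

    import pytest

    from airflow.exceptions import AirflowException
    from airflow.models.baseoperator import BaseOperator

    coerced = BaseOperator(task_id="coerced", retries="1")  # warns, converts to int
    assert coerced.retries == 1

    with pytest.raises(AirflowException):
        BaseOperator(task_id="broken", retries="foo")  # not coercible, so it raises
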
@pytest.mark.db_test
def test_default_retry_delay(dag_maker):
@@ -858,24 +627,6 @@ def test_default_retry_delay(dag_maker):
assert task1.retry_delay == timedelta(seconds=300)
[email protected]_test
-def test_dag_level_retry_delay(dag_maker):
- with dag_maker(dag_id="test_dag_level_retry_delay",
default_args={"retry_delay": timedelta(seconds=100)}):
- task1 = BaseOperator(task_id="test_no_explicit_retry_delay")
-
- assert task1.retry_delay == timedelta(seconds=100)
-
-
[email protected]_test
-def test_task_level_retry_delay(dag_maker):
- with dag_maker(
- dag_id="test_task_level_retry_delay", default_args={"retry_delay":
timedelta(seconds=100)}
- ):
- task1 = BaseOperator(task_id="test_no_explicit_retry_delay",
retry_delay=timedelta(seconds=200))
-
- assert task1.retry_delay == timedelta(seconds=200)
-
-
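
The retry_delay tests being moved document the precedence rule: an explicit
operator argument wins over the DAG's default_args. Sketch (ids illustrative):

    from datetime import datetime, timedelta

    from airflow import DAG
    from airflow.models.baseoperator import BaseOperator

    with DAG(
        "retry_delay_demo",
        schedule=None,
        start_date=datetime(2024, 1, 1),
        default_args={"retry_delay": timedelta(seconds=100)},
    ):
        inherited = BaseOperator(task_id="inherited")
        explicit = BaseOperator(task_id="explicit", retry_delay=timedelta(seconds=200))

    assert inherited.retry_delay == timedelta(seconds=100)  # from default_args
    assert explicit.retry_delay == timedelta(seconds=200)   # task argument wins
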
def test_deepcopy():
# Test bug when copying an operator attached to a DAG
with DAG("dag0", schedule=None, start_date=DEFAULT_DATE) as dag:
diff --git a/tests/models/test_dag.py b/tests/models/test_dag.py
index 67dc699fc3..c0b705d641 100644
--- a/tests/models/test_dag.py
+++ b/tests/models/test_dag.py
@@ -42,7 +42,6 @@ from airflow.configuration import conf
from airflow.decorators import setup, task as task_decorator, teardown
from airflow.exceptions import (
AirflowException,
- DuplicateTaskIdFound,
ParamValidationError,
UnknownExecutorException,
)
@@ -65,7 +64,7 @@ from airflow.models.dag import (
get_asset_triggered_next_run_info,
)
from airflow.models.dagrun import DagRun
-from airflow.models.param import DagParam, Param, ParamsDict
+from airflow.models.param import DagParam, Param
from airflow.models.serialized_dag import SerializedDagModel
from airflow.models.taskfail import TaskFail
from airflow.models.taskinstance import TaskInstance as TI
@@ -175,41 +174,6 @@ class TestDag:
b_index = i
return 0 <= a_index < b_index
- def test_params_not_passed_is_empty_dict(self):
- """
- Test that when 'params' is _not_ passed to a new DAG, the params
- attribute is set to an empty dictionary.
- """
- dag = DAG("test-dag", schedule=None)
-
- assert isinstance(dag.params, ParamsDict)
- assert 0 == len(dag.params)
-
- def test_params_passed_and_params_in_default_args_no_override(self):
- """
- Test that when 'params' is passed both inside the default_args dict
- and explicitly as a DAG argument, the two dicts are merged into
- dag.params rather than one overriding the other.
- """
- params1 = {"parameter1": 1}
- params2 = {"parameter2": 2}
-
- dag = DAG("test-dag", schedule=None, default_args={"params": params1},
params=params2)
-
- assert params1["parameter1"] == dag.params["parameter1"]
- assert params2["parameter2"] == dag.params["parameter2"]
-
- def test_not_none_schedule_with_non_default_params(self):
- """
- Test that a DAG with a non-None schedule and params lacking default
- values raises an error during DAG parsing.
- """
- params = {"param1": Param(type="string")}
-
- with pytest.raises(AirflowException):
- DAG("dummy-dag", schedule=timedelta(days=1),
start_date=DEFAULT_DATE, params=params)
-
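
The params tests removed here cover two contracts: params defaults to an
empty ParamsDict when omitted, and a "params" key inside default_args is
merged with, not overridden by, the explicit params argument. Sketch (dag id
illustrative):

    from airflow import DAG

    dag = DAG(
        "params_demo",
        schedule=None,
        default_args={"params": {"parameter1": 1}},
        params={"parameter2": 2},
    )

    assert dag.params["parameter1"] == 1  # merged in from default_args
    assert dag.params["parameter2"] == 2  # from the explicit argument
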
def test_dag_invalid_default_view(self):
"""
Test invalid `default_view` of DAG initialization
@@ -238,57 +202,6 @@ class TestDag:
dag = DAG(dag_id="test-default_orientation", schedule=None)
assert conf.get("webserver", "dag_orientation") == dag.orientation
- def test_dag_as_context_manager(self):
- """
- Test DAG as a context manager.
- When used as a context manager, Operators are automatically added to
- the DAG (unless they specify a different DAG)
- """
- dag = DAG("dag", schedule=None, start_date=DEFAULT_DATE,
default_args={"owner": "owner1"})
- dag2 = DAG("dag2", schedule=None, start_date=DEFAULT_DATE,
default_args={"owner": "owner2"})
-
- with dag:
- op1 = EmptyOperator(task_id="op1")
- op2 = EmptyOperator(task_id="op2", dag=dag2)
-
- assert op1.dag is dag
- assert op1.owner == "owner1"
- assert op2.dag is dag2
- assert op2.owner == "owner2"
-
- with dag2:
- op3 = EmptyOperator(task_id="op3")
-
- assert op3.dag is dag2
- assert op3.owner == "owner2"
-
- with dag:
- with dag2:
- op4 = EmptyOperator(task_id="op4")
- op5 = EmptyOperator(task_id="op5")
-
- assert op4.dag is dag2
- assert op5.dag is dag
- assert op4.owner == "owner2"
- assert op5.owner == "owner1"
-
- with DAG("creating_dag_in_cm", schedule=None, start_date=DEFAULT_DATE)
as dag:
- EmptyOperator(task_id="op6")
-
- assert dag.dag_id == "creating_dag_in_cm"
- assert dag.tasks[0].task_id == "op6"
-
- with dag:
- with dag:
- op7 = EmptyOperator(task_id="op7")
- op8 = EmptyOperator(task_id="op8")
- op9 = EmptyOperator(task_id="op8")
- op9.dag = dag2
-
- assert op7.dag == dag
- assert op8.dag == dag
- assert op9.dag == dag2
-
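
The context-manager test above boils down to three rules: the innermost
active DAG claims new operators, an explicit dag= argument always wins, and
contexts can be nested and re-entered. A reduced sketch (ids illustrative):

    from datetime import datetime

    from airflow import DAG
    from airflow.operators.empty import EmptyOperator

    outer = DAG("outer", schedule=None, start_date=datetime(2024, 1, 1))
    inner = DAG("inner", schedule=None, start_date=datetime(2024, 1, 1))

    with outer:
        a = EmptyOperator(task_id="a")              # claimed by the active DAG
        with inner:
            b = EmptyOperator(task_id="b")          # innermost context wins
        c = EmptyOperator(task_id="c", dag=inner)   # explicit dag= overrides context

    assert a.dag is outer and b.dag is inner and c.dag is inner
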
def test_dag_topological_sort_dag_without_tasks(self):
dag = DAG("dag", schedule=None, start_date=DEFAULT_DATE,
default_args={"owner": "owner1"})
@@ -1287,60 +1200,6 @@ class TestDag:
dag = DAG("DAG", schedule=None, default_args=default_args)
assert dag.timezone.name == local_tz.name
- def test_roots(self):
- """Verify if dag.roots returns the root tasks of a DAG."""
- with DAG("test_dag", schedule=None, start_date=DEFAULT_DATE) as dag:
- op1 = EmptyOperator(task_id="t1")
- op2 = EmptyOperator(task_id="t2")
- op3 = EmptyOperator(task_id="t3")
- op4 = EmptyOperator(task_id="t4")
- op5 = EmptyOperator(task_id="t5")
- [op1, op2] >> op3 >> [op4, op5]
-
- assert set(dag.roots) == {op1, op2}
-
- def test_leaves(self):
- """Verify if dag.leaves returns the leaf tasks of a DAG."""
- with DAG("test_dag", schedule=None, start_date=DEFAULT_DATE) as dag:
- op1 = EmptyOperator(task_id="t1")
- op2 = EmptyOperator(task_id="t2")
- op3 = EmptyOperator(task_id="t3")
- op4 = EmptyOperator(task_id="t4")
- op5 = EmptyOperator(task_id="t5")
- [op1, op2] >> op3 >> [op4, op5]
-
- assert set(dag.leaves) == {op4, op5}
-
- def test_duplicate_task_ids_not_allowed_with_dag_context_manager(self):
- """Verify tasks with Duplicate task_id raises error"""
- with DAG("test_dag", schedule=None, start_date=DEFAULT_DATE) as dag:
- op1 = EmptyOperator(task_id="t1")
- with pytest.raises(DuplicateTaskIdFound, match="Task id 't1' has already been added to the DAG"):
- BashOperator(task_id="t1", bash_command="sleep 1")
-
- assert dag.task_dict == {op1.task_id: op1}
-
- def test_duplicate_task_ids_not_allowed_without_dag_context_manager(self):
- """Verify tasks with Duplicate task_id raises error"""
- dag = DAG("test_dag", schedule=None, start_date=DEFAULT_DATE)
- op1 = EmptyOperator(task_id="t1", dag=dag)
- with pytest.raises(DuplicateTaskIdFound, match="Task id 't1' has already been added to the DAG"):
- EmptyOperator(task_id="t1", dag=dag)
-
- assert dag.task_dict == {op1.task_id: op1}
-
- def test_duplicate_task_ids_for_same_task_is_allowed(self):
- """Verify that same tasks with Duplicate task_id do not raise error"""
- with DAG("test_dag", schedule=None, start_date=DEFAULT_DATE) as dag:
- op1 = op2 = EmptyOperator(task_id="t1")
- op3 = EmptyOperator(task_id="t3")
- op1 >> op3
- op2 >> op3
-
- assert op1 == op2
- assert dag.task_dict == {op1.task_id: op1, op3.task_id: op3}
- assert dag.task_dict == {op2.task_id: op2, op3.task_id: op3}
-
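
The duplicate task_id tests capture the contract enforced when tasks are
registered on a DAG: adding a different task under an existing task_id raises
DuplicateTaskIdFound, while re-adding the very same task object is a no-op.
Condensed sketch (ids illustrative):

    from datetime import datetime

    import pytest

    from airflow import DAG
    from airflow.exceptions import DuplicateTaskIdFound
    from airflow.operators.empty import EmptyOperator

    with DAG("dup_demo", schedule=None, start_date=datetime(2024, 1, 1)) as dag:
        t1 = EmptyOperator(task_id="t1")
        with pytest.raises(DuplicateTaskIdFound):
            EmptyOperator(task_id="t1")  # different object, same id: rejected

    assert dag.task_dict == {"t1": t1}
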
def test_partial_subset_updates_all_references_while_deepcopy(self):
with DAG("test_dag", schedule=None, start_date=DEFAULT_DATE) as dag:
op1 = EmptyOperator(task_id="t1")
diff --git a/uv.lock b/uv.lock
index 3a3c2db04b..1fecc75854 100644
--- a/uv.lock
+++ b/uv.lock
@@ -1373,6 +1373,7 @@ dependencies = [
[package.dev-dependencies]
dev = [
+ { name = "kgb" },
{ name = "pytest" },
{ name = "pytest-asyncio" },
{ name = "pytest-mock" },
@@ -1387,6 +1388,7 @@ requires-dist = [
[package.metadata.requires-dev]
dev = [
+ { name = "kgb", specifier = ">=7.1.1" },
{ name = "pytest", specifier = ">=8.3.3" },
{ name = "pytest-asyncio", specifier = ">=0.24.0" },
{ name = "pytest-mock", specifier = ">=3.14.0" },
@@ -2452,6 +2454,15 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/d1/0f/8910b19ac0670a0f80ce1008e5e751c4a57e14d2c4c13a482aa6079fa9d6/jsonschema_specifications-2024.10.1-py3-none-any.whl", hash = "sha256:a09a0680616357d9a0ecf05c12ad234479f549239d0f5b55f3deea67475da9bf", size = 18459 },
]
+[[package]]
+name = "kgb"
+version = "7.1.1"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/0c/2e/2b608fa158cd87d7372b1d1c94d70b9b90e4ab5316c77f26feb1e4b6549f/kgb-7.1.1.tar.gz", hash = "sha256:74912c8761651f2063151c6c2a36ebe023393de491ec86744771a2888ab9845b", size = 61504 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/80/45/ae8db25f019419b17359ca98f129c0a0d9fa40cadeaac3525b02b690e705/kgb-7.1.1-py2.py3-none-any.whl", hash = "sha256:ed535b25caa5d8151bb8700c653a73475a6d3937c75cd2b8ce93c84c97a86a6f", size = 58003 },
+]
+
[[package]]
name = "lazy-object-proxy"
version = "1.10.0"