This is an automated email from the ASF dual-hosted git repository.
ash pushed a commit to branch task-sdk-first-code
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/task-sdk-first-code by this push:
new 50081fada27 fixup! Move over more of BaseOperator and DAG, along with their tests
50081fada27 is described below
commit 50081fada2795dce76aa0a6fb89ad62b97c9dd07
Author: Ash Berlin-Taylor <[email protected]>
AuthorDate: Wed Oct 16 12:56:20 2024 +0100
fixup! Move over more of BaseOperator and DAG, along with their tests
---
airflow/dag_processing/collection.py | 5 +-
airflow/decorators/base.py | 3 +-
airflow/models/abstractoperator.py | 67 --
airflow/models/baseoperator.py | 760 ++-------------------
airflow/models/dag.py | 24 +-
airflow/utils/edgemodifier.py | 10 +-
airflow/utils/task_group.py | 11 +-
pyproject.toml | 1 +
task_sdk/src/airflow/sdk/__init__.py | 2 +
.../airflow/sdk/definitions/abstractoperator.py | 84 ++-
.../src/airflow/sdk/definitions/baseoperator.py | 61 +-
.../src/airflow/sdk/definitions/contextmanager.py | 10 +-
task_sdk/src/airflow/sdk/definitions/dag.py | 50 +-
task_sdk/src/airflow/sdk/definitions/node.py | 25 +-
task_sdk/src/airflow/sdk/definitions/taskgroup.py | 22 +-
task_sdk/tests/defintions/test_baseoperator.py | 36 +-
tests/models/test_baseoperator.py | 41 +-
17 files changed, 293 insertions(+), 919 deletions(-)
diff --git a/airflow/dag_processing/collection.py b/airflow/dag_processing/collection.py
index c8ce5dc873a..163ddf01c02 100644
--- a/airflow/dag_processing/collection.py
+++ b/airflow/dag_processing/collection.py
@@ -198,7 +198,10 @@ class DagModelOperation(NamedTuple):
dm.has_import_errors = False
dm.last_parsed_time = utcnow()
dm.default_view = dag.default_view
- dm._dag_display_property_value = dag._dag_display_property_value
+ if hasattr(dag, "_dag_display_property_value"):
+ dm._dag_display_property_value = dag._dag_display_property_value
+ else:
+ dm._dag_display_property_value = dag.dag_display_name
dm.description = dag.description
dm.max_active_tasks = dag.max_active_tasks
dm.max_active_runs = dag.max_active_runs
diff --git a/airflow/decorators/base.py b/airflow/decorators/base.py
index bb9602d50c1..e7a21e2919b 100644
--- a/airflow/decorators/base.py
+++ b/airflow/decorators/base.py
@@ -59,6 +59,7 @@ from airflow.models.expandinput import (
from airflow.models.mappedoperator import MappedOperator, ensure_xcomarg_return_value
from airflow.models.pool import Pool
from airflow.models.xcom_arg import XComArg
+from airflow.sdk.definitions.baseoperator import BaseOperator as TaskSDKBaseOperator
from airflow.typing_compat import ParamSpec, Protocol
from airflow.utils import timezone
from airflow.utils.context import KNOWN_CONTEXT_KEYS
@@ -442,7 +443,7 @@ class _TaskDecorator(ExpandableFactory, Generic[FParams, FReturn, OperatorSubcla
"is_teardown": self.is_teardown,
"on_failure_fail_dagrun": self.on_failure_fail_dagrun,
}
- base_signature = inspect.signature(BaseOperator)
+ base_signature = inspect.signature(TaskSDKBaseOperator)
ignore = {
"default_args", # This is target we are working on now.
"kwargs", # A common name for a keyword argument.
diff --git a/airflow/models/abstractoperator.py b/airflow/models/abstractoperator.py
index 5e5d13d5dc2..0c962307519 100644
--- a/airflow/models/abstractoperator.py
+++ b/airflow/models/abstractoperator.py
@@ -33,7 +33,6 @@ from airflow.models.taskmixin import DAGNode, DependencyMixin
from airflow.template.templater import Templater
from airflow.utils.context import Context
from airflow.utils.db import exists_query
-from airflow.utils.log.secrets_masker import redact
from airflow.utils.setup_teardown import SetupTeardownContext
from airflow.utils.sqlalchemy import with_row_locks
from airflow.utils.state import State, TaskInstanceState
@@ -719,72 +718,6 @@ class AbstractOperator(Templater, DAGNode):
"""
raise NotImplementedError()
- def _render(self, template, context, dag: DAG | None = None):
- if dag is None:
- dag = self.get_dag()
- return super()._render(template, context, dag=dag)
-
- def get_template_env(self, dag: DAG | None = None) -> jinja2.Environment:
- """Get the template environment for rendering templates."""
- if dag is None:
- dag = self.get_dag()
- return super().get_template_env(dag=dag)
-
- def _do_render_template_fields(
- self,
- parent: Any,
- template_fields: Iterable[str],
- context: Context,
- jinja_env: jinja2.Environment,
- seen_oids: set[int],
- ) -> None:
- """Override the base to use custom error logging."""
- for attr_name in template_fields:
- try:
- value = getattr(parent, attr_name)
- except AttributeError:
- raise AttributeError(
- f"{attr_name!r} is configured as a template field "
- f"but {parent.task_type} does not have this attribute."
- )
- try:
- if not value:
- continue
- except Exception:
- # This may happen if the templated field points to a class which does not support `__bool__`,
- # such as Pandas DataFrames:
- # https://github.com/pandas-dev/pandas/blob/9135c3aaf12d26f857fcc787a5b64d521c51e379/pandas/core/generic.py#L1465
- self.log.info(
- "Unable to check if the value of type '%s' is False for
task '%s', field '%s'.",
- type(value).__name__,
- self.task_id,
- attr_name,
- )
- # We may still want to render custom classes which do not support __bool__
- pass
-
- try:
- if callable(value):
- rendered_content = value(context=context, jinja_env=jinja_env)
- else:
- rendered_content = self.render_template(
- value,
- context,
- jinja_env,
- seen_oids,
- )
- except Exception:
- value_masked = redact(name=attr_name, value=value)
- self.log.exception(
- "Exception rendering Jinja template for task '%s', field
'%s'. Template: %r",
- self.task_id,
- attr_name,
- value_masked,
- )
- raise
- else:
- setattr(parent, attr_name, rendered_content)
-
def __enter__(self):
if not self.is_setup and not self.is_teardown:
raise AirflowException("Only setup/teardown tasks can be used as
context managers.")
diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py
index 39d56187371..2a9fed1503a 100644
--- a/airflow/models/baseoperator.py
+++ b/airflow/models/baseoperator.py
@@ -23,17 +23,14 @@ Base operator for all operators.
from __future__ import annotations
-import abc
import collections.abc
import contextlib
import copy
import functools
-import inspect
import logging
import sys
-import warnings
from datetime import datetime, timedelta
-from functools import total_ordering, wraps
+from functools import wraps
from threading import local
from types import FunctionType
from typing import (
@@ -46,10 +43,8 @@ from typing import (
Sequence,
TypeVar,
Union,
- cast,
)
-import attr
import pendulum
from sqlalchemy import select
from sqlalchemy.orm.exc import NoResultFound
@@ -57,7 +52,6 @@ from sqlalchemy.orm.exc import NoResultFound
from airflow.configuration import conf
from airflow.exceptions import (
AirflowException,
- FailStopDagInvalidTriggerRule,
TaskDeferralError,
TaskDeferred,
)
@@ -83,8 +77,15 @@ from airflow.models.param import ParamsDict
from airflow.models.pool import Pool
from airflow.models.taskinstance import TaskInstance, clear_task_instances
from airflow.models.taskmixin import DependencyMixin
+from airflow.sdk.definitions.baseoperator import BaseOperator as TaskSDKBaseOperator
+
+# Keeping this file at all is a temp thing as we migrate the repo to the task sdk as the base, but to keep
+# main working and useful for others to develop against we use the TaskSDK here but keep this file around
+from airflow.sdk.definitions.dag import DAG
+from airflow.sdk.definitions.edges import EdgeModifier as TaskSDKEdgeModifier
+from airflow.sdk.definitions.mixins import DependencyMixin as TaskSDKDependencyMixin
from airflow.serialization.enums import DagAttributeTypes
-from airflow.task.priority_strategy import PriorityWeightStrategy, validate_and_load_priority_weight_strategy
+from airflow.task.priority_strategy import PriorityWeightStrategy
from airflow.ti_deps.deps.mapped_task_upstream_dep import MappedTaskUpstreamDep
from airflow.ti_deps.deps.not_in_retry_period_dep import NotInRetryPeriodDep
from airflow.ti_deps.deps.not_previously_skipped_dep import NotPreviouslySkippedDep
@@ -92,15 +93,11 @@ from airflow.ti_deps.deps.prev_dagrun_dep import PrevDagrunDep
from airflow.ti_deps.deps.trigger_rule_dep import TriggerRuleDep
from airflow.utils import timezone
from airflow.utils.context import Context, context_get_outlet_events
-from airflow.utils.decorators import fixup_decorator_warning_stack
from airflow.utils.edgemodifier import EdgeModifier
-from airflow.utils.helpers import validate_instance_args, validate_key
from airflow.utils.operator_helpers import ExecutionCallableRunner
from airflow.utils.operator_resources import Resources
from airflow.utils.session import NEW_SESSION, provide_session
-from airflow.utils.setup_teardown import SetupTeardownContext
-from airflow.utils.trigger_rule import TriggerRule
-from airflow.utils.types import NOTSET, AttributeRemoved, DagRunTriggeredByType
+from airflow.utils.types import NOTSET, DagRunTriggeredByType
from airflow.utils.xcom import XCOM_RETURN_KEY
if TYPE_CHECKING:
@@ -113,7 +110,6 @@ if TYPE_CHECKING:
from airflow.models.baseoperatorlink import BaseOperatorLink
from airflow.models.dag import DAG
from airflow.models.operator import Operator
- from airflow.models.xcom_arg import XComArg
from airflow.ti_deps.deps.base_ti_dep import BaseTIDep
from airflow.triggers.base import BaseTrigger, StartTriggerArgs
from airflow.utils.task_group import TaskGroup
@@ -419,165 +415,16 @@ class ExecutorSafeguard:
return wrapper
-class BaseOperatorMeta(abc.ABCMeta):
- """Metaclass of BaseOperator."""
-
- @classmethod
- def _apply_defaults(cls, func: T) -> T:
- """
- Look for an argument named "default_args", and fill the unspecified arguments from it.
-
- Since python2.* isn't clear about which arguments are missing when
- calling a function, and that this can be quite confusing with multi-level
- inheritance and argument defaults, this decorator also alerts with
- specific information about the missing arguments.
- """
- # Cache inspect.signature for the wrapper closure to avoid calling it
- # at every decorated invocation. This is separate sig_cache created
- # per decoration, i.e. each function decorated using apply_defaults will
- # have a different sig_cache.
- sig_cache = inspect.signature(func)
- non_variadic_params = {
- name: param
- for (name, param) in sig_cache.parameters.items()
- if param.name != "self" and param.kind not in (param.VAR_POSITIONAL, param.VAR_KEYWORD)
- }
- non_optional_args = {
- name
- for name, param in non_variadic_params.items()
- if param.default == param.empty and name != "task_id"
- }
-
- fixup_decorator_warning_stack(func)
-
- @wraps(func)
- def apply_defaults(self: BaseOperator, *args: Any, **kwargs: Any) -> Any:
- from airflow.sdk.definitions.contextmanager import DagContext, TaskGroupContext
-
- if args:
- raise AirflowException("Use keyword arguments when
initializing operators")
-
- instantiated_from_mapped = kwargs.pop(
- "_airflow_from_mapped",
- getattr(self, "_BaseOperator__from_mapped", False),
- )
-
- dag: DAG | None = kwargs.get("dag") or DagContext.get_current()
- task_group: TaskGroup | None = kwargs.get("task_group")
- if dag and not task_group:
- task_group = TaskGroupContext.get_current(dag)
-
- default_args, merged_params = get_merged_defaults(
- dag=dag,
- task_group=task_group,
- task_params=kwargs.pop("params", None),
- task_default_args=kwargs.pop("default_args", None),
- )
-
- for arg in sig_cache.parameters:
- if arg not in kwargs and arg in default_args:
- kwargs[arg] = default_args[arg]
-
- missing_args = non_optional_args.difference(kwargs)
- if len(missing_args) == 1:
- raise AirflowException(f"missing keyword argument
{missing_args.pop()!r}")
- elif missing_args:
- display = ", ".join(repr(a) for a in sorted(missing_args))
- raise AirflowException(f"missing keyword arguments {display}")
-
- if merged_params:
- kwargs["params"] = merged_params
-
- hook = getattr(self, "_hook_apply_defaults", None)
- if hook:
- args, kwargs = hook(**kwargs, default_args=default_args)
- default_args = kwargs.pop("default_args", {})
-
- if not hasattr(self, "_BaseOperator__init_kwargs"):
- self._BaseOperator__init_kwargs = {}
- self._BaseOperator__from_mapped = instantiated_from_mapped
-
- result = func(self, **kwargs, default_args=default_args)
-
- # Store the args passed to init -- we need them to support task.map serialization!
- self._BaseOperator__init_kwargs.update(kwargs) # type: ignore
-
- # Set upstream task defined by XComArgs passed to template fields of the operator.
- # BUT: only do this _ONCE_, not once for each class in the hierarchy
- if not instantiated_from_mapped and func == self.__init__.__wrapped__: # type: ignore[misc]
- self.set_xcomargs_dependencies()
- # Mark instance as instantiated.
- self._BaseOperator__instantiated = True
-
- return result
-
- apply_defaults.__non_optional_args = non_optional_args # type: ignore
- apply_defaults.__param_names = set(non_variadic_params) # type: ignore
-
- return cast(T, apply_defaults)
-
+# TODO: Task-SDK - temporarily extend the metaclass to add in the ExecutorSafeguard.
+class BaseOperatorMeta(type(TaskSDKBaseOperator)):
def __new__(cls, name, bases, namespace, **kwargs):
execute_method = namespace.get("execute")
if callable(execute_method) and not getattr(execute_method, "__isabstractmethod__", False):
namespace["execute"] = ExecutorSafeguard().decorator(execute_method)
- new_cls = super().__new__(cls, name, bases, namespace, **kwargs)
- with contextlib.suppress(KeyError):
- # Update the partial descriptor with the class method, so it calls the actual function
- # (but let subclasses override it if they need to)
- partial_desc = vars(new_cls)["partial"]
- if isinstance(partial_desc, _PartialDescriptor):
- partial_desc.class_method = classmethod(partial)
-
- # We patch `__init__` only if the class defines it.
- if inspect.getmro(new_cls)[1].__init__ is not new_cls.__init__:
- new_cls.__init__ = cls._apply_defaults(new_cls.__init__)
-
- return new_cls
-
-
-# TODO: The following mapping is used to validate that the arguments passed to the BaseOperator are of the
-# correct type. This is a temporary solution until we find a more sophisticated method for argument
-# validation. One potential method is to use `get_type_hints` from the typing module. However, this is not
-# fully compatible with future annotations for Python versions below 3.10. Once we require a minimum Python
-# version that supports `get_type_hints` effectively or find a better approach, we can replace this
-# manual type-checking method.
-BASEOPERATOR_ARGS_EXPECTED_TYPES = {
- "task_id": str,
- "email": (str, Iterable),
- "email_on_retry": bool,
- "email_on_failure": bool,
- "retries": int,
- "retry_exponential_backoff": bool,
- "depends_on_past": bool,
- "ignore_first_depends_on_past": bool,
- "wait_for_past_depends_before_skipping": bool,
- "wait_for_downstream": bool,
- "priority_weight": int,
- "queue": str,
- "pool": str,
- "pool_slots": int,
- "trigger_rule": str,
- "run_as_user": str,
- "task_concurrency": int,
- "map_index_template": str,
- "max_active_tis_per_dag": int,
- "max_active_tis_per_dagrun": int,
- "executor": str,
- "do_xcom_push": bool,
- "multiple_outputs": bool,
- "doc": str,
- "doc_md": str,
- "doc_json": str,
- "doc_yaml": str,
- "doc_rst": str,
- "task_display_name": str,
- "logger_name": str,
- "allow_nested_operators": bool,
-}
+ return super().__new__(cls, name, bases, namespace, **kwargs)
-@total_ordering
-class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta):
+class BaseOperator(TaskSDKBaseOperator, AbstractOperator, metaclass=BaseOperatorMeta):
r"""
Abstract base class for all operators.
@@ -782,17 +629,24 @@ class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta):
hello_world_task.execute(context)
"""
- # Implementing Operator.
- template_fields: Sequence[str] = ()
- template_ext: Sequence[str] = ()
+ start_trigger_args: StartTriggerArgs | None = None
+ start_from_trigger: bool = False
- template_fields_renderers: dict[str, str] = {}
+ on_execute_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = None
+ on_failure_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = None
+ on_success_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = None
+ on_retry_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = None
+ on_skipped_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = None
- # Defines the color in the UI
- ui_color: str = "#fff"
- ui_fgcolor: str = "#000"
+ def __init__(self, pre_execute=None, post_execute=None, **kwargs):
+ if start_date := kwargs.get("start_date", None):
+ kwargs["start_date"] = timezone.convert_to_utc(start_date)
- pool: str = ""
+ if end_date := kwargs.get("end_date", None):
+ kwargs["end_date"] = timezone.convert_to_utc(end_date)
+ super().__init__(**kwargs)
+ self._pre_execute_hook = pre_execute
+ self._post_execute_hook = post_execute
# base list which includes all the attrs that don't need deep copy.
_base_operator_shallow_copy_attrs: tuple[str, ...] = (
@@ -807,368 +661,8 @@ class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta):
# Defines the operator level extra links
operator_extra_links: Collection[BaseOperatorLink] = ()
- # The _serialized_fields are lazily loaded when get_serialized_fields() method is called
- __serialized_fields: frozenset[str] | None = None
-
partial: Callable[..., OperatorPartial] = _PartialDescriptor() # type: ignore
- _comps = {
- "task_id",
- "dag_id",
- "owner",
- "email",
- "email_on_retry",
- "retry_delay",
- "retry_exponential_backoff",
- "max_retry_delay",
- "start_date",
- "end_date",
- "depends_on_past",
- "wait_for_downstream",
- "priority_weight",
- "sla",
- "execution_timeout",
- "on_execute_callback",
- "on_failure_callback",
- "on_success_callback",
- "on_retry_callback",
- "on_skipped_callback",
- "do_xcom_push",
- "multiple_outputs",
- "allow_nested_operators",
- "executor",
- }
-
- # Defines if the operator supports lineage without manual definitions
- supports_lineage = False
-
- # If True then the class constructor was called
- __instantiated = False
- # List of args as passed to `init()`, after apply_defaults() has been updated. Used to "recreate" the task
- # when mapping
- __init_kwargs: dict[str, Any]
-
- # Set to True before calling execute method
- _lock_for_execution = False
-
- _dag: DAG | None = None
- task_group: TaskGroup | None = None
-
- start_date: pendulum.DateTime | None = None
- end_date: pendulum.DateTime | None = None
-
- # Set to True for an operator instantiated by a mapped operator.
- __from_mapped = False
-
- start_trigger_args: StartTriggerArgs | None = None
- start_from_trigger: bool = False
-
- def __init__(
- self,
- task_id: str,
- owner: str = DEFAULT_OWNER,
- email: str | Iterable[str] | None = None,
- email_on_retry: bool = conf.getboolean("email", "default_email_on_retry", fallback=True),
- email_on_failure: bool = conf.getboolean("email", "default_email_on_failure", fallback=True),
- retries: int | None = DEFAULT_RETRIES,
- retry_delay: timedelta | float = DEFAULT_RETRY_DELAY,
- retry_exponential_backoff: bool = False,
- max_retry_delay: timedelta | float | None = None,
- start_date: datetime | None = None,
- end_date: datetime | None = None,
- depends_on_past: bool = False,
- ignore_first_depends_on_past: bool = DEFAULT_IGNORE_FIRST_DEPENDS_ON_PAST,
- wait_for_past_depends_before_skipping: bool = DEFAULT_WAIT_FOR_PAST_DEPENDS_BEFORE_SKIPPING,
- wait_for_downstream: bool = False,
- dag: DAG | None = None,
- params: collections.abc.MutableMapping | None = None,
- default_args: dict | None = None,
- priority_weight: int = DEFAULT_PRIORITY_WEIGHT,
- weight_rule: str | PriorityWeightStrategy = DEFAULT_WEIGHT_RULE,
- queue: str = DEFAULT_QUEUE,
- pool: str | None = None,
- pool_slots: int = DEFAULT_POOL_SLOTS,
- sla: timedelta | None = None,
- execution_timeout: timedelta | None = DEFAULT_TASK_EXECUTION_TIMEOUT,
- on_execute_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = None,
- on_failure_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = None,
- on_success_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = None,
- on_retry_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = None,
- on_skipped_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = None,
- pre_execute: TaskPreExecuteHook | None = None,
- post_execute: TaskPostExecuteHook | None = None,
- trigger_rule: str = DEFAULT_TRIGGER_RULE,
- resources: dict[str, Any] | None = None,
- run_as_user: str | None = None,
- map_index_template: str | None = None,
- max_active_tis_per_dag: int | None = None,
- max_active_tis_per_dagrun: int | None = None,
- executor: str | None = None,
- executor_config: dict | None = None,
- do_xcom_push: bool = True,
- multiple_outputs: bool = False,
- inlets: Any | None = None,
- outlets: Any | None = None,
- task_group: TaskGroup | None = None,
- doc: str | None = None,
- doc_md: str | None = None,
- doc_json: str | None = None,
- doc_yaml: str | None = None,
- doc_rst: str | None = None,
- task_display_name: str | None = None,
- logger_name: str | None = None,
- allow_nested_operators: bool = True,
- **kwargs,
- ):
- from airflow.sdk.definitions.contextmanager import DagContext, TaskGroupContext
-
- self.__init_kwargs = {}
-
- super().__init__()
-
- kwargs.pop("_airflow_mapped_validation_only", None)
- if kwargs:
- raise AirflowException(
- f"Invalid arguments were passed to {self.__class__.__name__}
(task_id: {task_id}). "
- f"Invalid arguments were:\n**kwargs: {kwargs}",
- )
- validate_key(task_id)
-
- dag = dag or DagContext.get_current()
- task_group = task_group or TaskGroupContext.get_current(dag)
-
- self.task_id = task_group.child_id(task_id) if task_group else task_id
- if not self.__from_mapped and task_group:
- task_group.add(self)
-
- self.owner = owner
- self.email = email
- self.email_on_retry = email_on_retry
- self.email_on_failure = email_on_failure
-
- if execution_timeout is not None and not isinstance(execution_timeout, timedelta):
- raise ValueError(
- f"execution_timeout must be timedelta object but passed as type: {type(execution_timeout)}"
- )
- self.execution_timeout = execution_timeout
-
- self.on_execute_callback = on_execute_callback
- self.on_failure_callback = on_failure_callback
- self.on_success_callback = on_success_callback
- self.on_retry_callback = on_retry_callback
- self.on_skipped_callback = on_skipped_callback
- self._pre_execute_hook = pre_execute
- self._post_execute_hook = post_execute
-
- if start_date and not isinstance(start_date, datetime):
- self.log.warning("start_date for %s isn't datetime.datetime", self)
- elif start_date:
- self.start_date = timezone.convert_to_utc(start_date)
-
- if end_date:
- self.end_date = timezone.convert_to_utc(end_date)
-
- self.executor = executor
- self.executor_config = executor_config or {}
- self.run_as_user = run_as_user
- self.retries = parse_retries(retries)
- self.queue = queue
- self.pool = Pool.DEFAULT_POOL_NAME if pool is None else pool
- self.pool_slots = pool_slots
- if self.pool_slots < 1:
- dag_str = f" in dag {dag.dag_id}" if dag else ""
- raise ValueError(f"pool slots for {self.task_id}{dag_str} cannot
be less than 1")
-
- if sla:
- self.log.warning(
- "The SLA feature is removed in Airflow 3.0, to be replaced
with a new implementation in 3.1"
- )
-
- if not TriggerRule.is_valid(trigger_rule):
- raise AirflowException(
- f"The trigger_rule must be one of
{TriggerRule.all_triggers()},"
- f"'{dag.dag_id if dag else ''}.{task_id}'; received
'{trigger_rule}'."
- )
-
- self.trigger_rule: TriggerRule = TriggerRule(trigger_rule)
- FailStopDagInvalidTriggerRule.check(dag=dag, trigger_rule=self.trigger_rule)
-
- self.depends_on_past: bool = depends_on_past
- self.ignore_first_depends_on_past: bool = ignore_first_depends_on_past
- self.wait_for_past_depends_before_skipping: bool = wait_for_past_depends_before_skipping
- self.wait_for_downstream: bool = wait_for_downstream
- if wait_for_downstream:
- self.depends_on_past = True
-
- self.retry_delay = coerce_timedelta(retry_delay, key="retry_delay")
- self.retry_exponential_backoff = retry_exponential_backoff
- self.max_retry_delay = (
- max_retry_delay
- if max_retry_delay is None
- else coerce_timedelta(max_retry_delay, key="max_retry_delay")
- )
-
- # At execution_time this becomes a normal dict
- self.params: ParamsDict | dict = ParamsDict(params)
- if priority_weight is not None and not isinstance(priority_weight, int):
- raise AirflowException(
- f"`priority_weight` for task '{self.task_id}' only accepts integers, "
- f"received '{type(priority_weight)}'."
- )
- self.priority_weight = priority_weight
- self.weight_rule = validate_and_load_priority_weight_strategy(weight_rule)
- self.resources = coerce_resources(resources)
- self.max_active_tis_per_dag: int | None = max_active_tis_per_dag
- self.max_active_tis_per_dagrun: int | None = max_active_tis_per_dagrun
- self.do_xcom_push: bool = do_xcom_push
- self.map_index_template: str | None = map_index_template
- self.multiple_outputs: bool = multiple_outputs
-
- self.doc_md = doc_md
- self.doc_json = doc_json
- self.doc_yaml = doc_yaml
- self.doc_rst = doc_rst
- self.doc = doc
- # Populate the display field only if provided and different from task id
- self._task_display_property_value = (
- task_display_name if task_display_name and task_display_name != task_id else None
- )
-
- self.upstream_task_ids: set[str] = set()
- self.downstream_task_ids: set[str] = set()
-
- if dag:
- self.dag = dag
-
- self._log_config_logger_name = "airflow.task.operators"
- self._logger_name = logger_name
- self.allow_nested_operators: bool = allow_nested_operators
-
- # Lineage
- self.inlets: list = []
- self.outlets: list = []
-
- if inlets:
- self.inlets = (
- inlets
- if isinstance(inlets, list)
- else [
- inlets,
- ]
- )
-
- if outlets:
- self.outlets = (
- outlets
- if isinstance(outlets, list)
- else [
- outlets,
- ]
- )
-
- if isinstance(self.template_fields, str):
- warnings.warn(
- f"The `template_fields` value for {self.task_type} is a string
"
- "but should be a list or tuple of string. Wrapping it in a
list for execution. "
- f"Please update {self.task_type} accordingly.",
- UserWarning,
- stacklevel=2,
- )
- self.template_fields = [self.template_fields]
-
- self._is_setup = False
- self._is_teardown = False
- if SetupTeardownContext.active:
- SetupTeardownContext.update_context_map(self)
-
- validate_instance_args(self, BASEOPERATOR_ARGS_EXPECTED_TYPES)
-
- def __eq__(self, other):
- if type(self) is type(other):
- # Use getattr() instead of __dict__ as __dict__ doesn't return
- # correct values for properties.
- return all(getattr(self, c, None) == getattr(other, c, None) for c in self._comps)
- return False
-
- def __ne__(self, other):
- return not self == other
-
- def __hash__(self):
- hash_components = [type(self)]
- for component in self._comps:
- val = getattr(self, component, None)
- try:
- hash(val)
- hash_components.append(val)
- except TypeError:
- hash_components.append(repr(val))
- return hash(tuple(hash_components))
-
- # including lineage information
- def __or__(self, other):
- """
- Return [This Operator] | [Operator].
-
- The inlets of other will be set to pick up the outlets from this operator.
- Other will be set as a downstream task of this operator.
- """
- if isinstance(other, BaseOperator):
- if not self.outlets and not self.supports_lineage:
- raise ValueError("No outlets defined for this operator")
- other.add_inlets([self.task_id])
- self.set_downstream(other)
- else:
- raise TypeError(f"Right hand side ({other}) is not an Operator")
-
- return self
-
- # /Composing Operators ---------------------------------------------
-
- def __gt__(self, other):
- """
- Return [Operator] > [Outlet].
-
- If other is an attr annotated object it is set as an outlet of this Operator.
- """
- if not isinstance(other, Iterable):
- other = [other]
-
- for obj in other:
- if not attr.has(obj):
- raise TypeError(f"Left hand side ({obj}) is not an outlet")
- self.add_outlets(other)
-
- return self
-
- def __lt__(self, other):
- """
- Return [Inlet] > [Operator] or [Operator] < [Inlet].
-
- If other is an attr annotated object it is set as an inlet to this operator.
- """
- if not isinstance(other, Iterable):
- other = [other]
-
- for obj in other:
- if not attr.has(obj):
- raise TypeError(f"{obj} cannot be an inlet")
- self.add_inlets(other)
-
- return self
-
- def __setattr__(self, key, value):
- super().__setattr__(key, value)
- if self.__from_mapped or self._lock_for_execution:
- return # Skip any custom behavior for validation and during execute.
- if key in self.__init_kwargs:
- self.__init_kwargs[key] = value
- if self.__instantiated and key in self.template_fields:
- # Resolve upstreams set by assigning an XComArg after initializing
- # an operator, example:
- # op = BashOperator()
- # op.bash_command = "sleep 1"
- self.set_xcomargs_dependencies()
-
def add_inlets(self, inlets: Iterable[Any]):
"""Set inlets to this operator."""
self.inlets.extend(inlets)
@@ -1193,55 +687,6 @@ class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta):
"""
return self.outlets
- def get_dag(self) -> DAG | None:
- return self._dag
-
- @property # type: ignore[override]
- def dag(self) -> DAG: # type: ignore[override]
- """Returns the Operator's DAG if set, otherwise raises an error."""
- if self._dag:
- return self._dag
- else:
- raise AirflowException(f"Operator {self} has not been assigned to
a DAG yet")
-
- @dag.setter
- def dag(self, dag: DAG | None):
- """Operators can be assigned to one DAG, one time. Repeat assignments
to that same DAG are ok."""
- if dag is None:
- self._dag = None
- return
-
- # if set to removed, then just set and exit
- if self._dag.__class__ is AttributeRemoved:
- self._dag = dag
- return
- # if setting to removed, then just set and exit
- if dag.__class__ is AttributeRemoved:
- self._dag = AttributeRemoved("_dag") # type: ignore[assignment]
- return
-
- from airflow.models.dag import DAG
-
- if not isinstance(dag, DAG):
- raise TypeError(f"Expected DAG; received {dag.__class__.__name__}")
- elif self.has_dag() and self.dag is not dag:
- raise AirflowException(f"The DAG assigned to {self} can not be
changed.")
-
- if self.__from_mapped:
- pass # Don't add to DAG -- the mapped task takes the place.
- elif dag.task_dict.get(self.task_id) is not self:
- dag.add_task(self)
-
- self._dag = dag
-
- @property
- def task_display_name(self) -> str:
- return self._task_display_property_value or self.task_id
-
- def has_dag(self):
- """Return True if the Operator has been assigned to a DAG."""
- return self._dag is not None
-
deps: frozenset[BaseTIDep] = frozenset(
{
NotInRetryPeriodDep(),
@@ -1263,33 +708,6 @@ class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta):
other._lock_for_execution = True
return other
- def set_xcomargs_dependencies(self) -> None:
- """
- Resolve upstream dependencies of a task.
-
- In this way passing an ``XComArg`` as value for a template field
- will result in creating upstream relation between two tasks.
-
- **Example**: ::
-
- with DAG(...):
- generate_content = GenerateContentOperator(task_id="generate_content")
- send_email = EmailOperator(..., html_content=generate_content.output)
-
- # This is equivalent to
- with DAG(...):
- generate_content = GenerateContentOperator(task_id="generate_content")
- send_email = EmailOperator(..., html_content="{{ task_instance.xcom_pull('generate_content') }}")
- generate_content >> send_email
-
- """
- from airflow.models.xcom_arg import XComArg
-
- for field in self.template_fields:
- if hasattr(self, field):
- arg = getattr(self, field)
- XComArg.apply_upstream_relationship(self, arg)
-
@prepare_lineage
def pre_execute(self, context: Any):
"""Execute right before self.execute() is called."""
@@ -1326,14 +744,6 @@ class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta):
logger=self.log,
).run(context, result)
- def on_kill(self) -> None:
- """
- Override this method to clean up subprocesses when a task instance gets killed.
-
- Any use of the threading, subprocess or multiprocessing module within an
- operator needs to be cleaned up, or it will leave ghost processes behind.
- """
-
def __deepcopy__(self, memo):
# Hack sorting double chained task lists by task_id to avoid hitting
# max_depth on deepcopy operations.
@@ -1518,43 +928,6 @@ class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta):
else:
return self.downstream_list
- def __repr__(self):
- return f"<Task({self.task_type}): {self.task_id}>"
-
- @property
- def operator_class(self) -> type[BaseOperator]: # type: ignore[override]
- return self.__class__
-
- @property
- def task_type(self) -> str:
- """@property: type of the task."""
- return self.__class__.__name__
-
- @property
- def operator_name(self) -> str:
- """@property: use a more friendly display name for the operator, if
set."""
- try:
- return self.custom_operator_name # type: ignore
- except AttributeError:
- return self.task_type
-
- @property
- def roots(self) -> list[BaseOperator]:
- """Required by DAGNode."""
- return [self]
-
- @property
- def leaves(self) -> list[BaseOperator]:
- """Required by DAGNode."""
- return [self]
-
- @property
- def output(self) -> XComArg:
- """Returns reference to XCom pushed by current operator."""
- from airflow.models.xcom_arg import XComArg
-
- return XComArg(operator=self)
-
@property
def is_setup(self) -> bool:
"""
@@ -1655,68 +1028,10 @@ class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta):
session=session,
)
- @classmethod
- def get_serialized_fields(cls):
- """Stringified DAGs and operators contain exactly these fields."""
- if not cls.__serialized_fields:
- from airflow.sdk.definitions.contextmanager import DagContext
-
- # make sure the following dummy task is not added to current active
- # dag in context, otherwise, it will result in
- # `RuntimeError: dictionary changed size during iteration`
- # Exception in SerializedDAG.serialize_dag() call.
- DagContext.push(None)
- cls.__serialized_fields = frozenset(
- vars(BaseOperator(task_id="test")).keys()
- - {
- "upstream_task_ids",
- "default_args",
- "dag",
- "_dag",
- "label",
- "_BaseOperator__instantiated",
- "_BaseOperator__init_kwargs",
- "_BaseOperator__from_mapped",
- "_is_setup",
- "_is_teardown",
- "_on_failure_fail_dagrun",
- }
- | { # Class level defaults need to be added to this list
- "start_date",
- "end_date",
- "_task_type",
- "_operator_name",
- "ui_color",
- "ui_fgcolor",
- "template_ext",
- "template_fields",
- "template_fields_renderers",
- "params",
- "is_setup",
- "is_teardown",
- "on_failure_fail_dagrun",
- "map_index_template",
- "start_trigger_args",
- "_needs_expansion",
- "start_from_trigger",
- }
- )
- DagContext.pop()
-
- return cls.__serialized_fields
-
def serialize_for_task_group(self) -> tuple[DagAttributeTypes, Any]:
"""Serialize; required by DAGNode."""
return DagAttributeTypes.OP, self.task_id
- @property
- def inherits_from_empty_operator(self):
- """Used to determine if an Operator is inherited from EmptyOperator."""
- # This looks like `isinstance(self, EmptyOperator) would work, but this also
- # needs to cope when `self` is a Serialized instance of a EmptyOperator or one
- # of its subclasses (which don't inherit from anything but BaseOperator).
- return getattr(self, "_is_empty", False)
-
def defer(
self,
*,
@@ -1786,11 +1101,12 @@ class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta):
return self.start_trigger_args
-# TODO: Deprecate for Airflow 3.0
-Chainable = Union[DependencyMixin, Sequence[DependencyMixin]]
+# TODO: Task-SDK: remove before Airflow 3.0
+# Temp for migration to Task-SDK only
+_HasDependency = Union[DependencyMixin | TaskSDKDependencyMixin]
-def chain(*tasks: DependencyMixin | Sequence[DependencyMixin]) -> None:
+def chain(*tasks: _HasDependency | Sequence[DependencyMixin]) -> None:
r"""
Given a number of tasks, builds a dependency chain.
@@ -1899,10 +1215,10 @@ def chain(*tasks: DependencyMixin | Sequence[DependencyMixin]) -> None:
:param tasks: Individual and/or list of tasks, EdgeModifiers, XComArgs, or TaskGroups to set dependencies
"""
for up_task, down_task in zip(tasks, tasks[1:]):
- if isinstance(up_task, DependencyMixin):
+ if isinstance(up_task, (DependencyMixin, TaskSDKDependencyMixin)):
up_task.set_downstream(down_task)
continue
- if isinstance(down_task, DependencyMixin):
+ if isinstance(down_task, (DependencyMixin, TaskSDKDependencyMixin)):
down_task.set_upstream(up_task)
continue
if not isinstance(up_task, Sequence) or not isinstance(down_task, Sequence):
@@ -2040,13 +1356,15 @@ def chain_linear(*elements: DependencyMixin | Sequence[DependencyMixin]):
prev_elem = None
deps_set = False
for curr_elem in elements:
- if isinstance(curr_elem, EdgeModifier):
+ if isinstance(curr_elem, (EdgeModifier, TaskSDKEdgeModifier)):
raise ValueError("Labels are not supported by chain_linear")
if prev_elem is not None:
for task in prev_elem:
task >> curr_elem
if not deps_set:
deps_set = True
- prev_elem = [curr_elem] if isinstance(curr_elem, DependencyMixin) else curr_elem
+ prev_elem = (
+ [curr_elem] if isinstance(curr_elem, (DependencyMixin, TaskSDKDependencyMixin)) else curr_elem
+ )
if not deps_set:
raise ValueError("No dependencies were set. Did you forget to expand
with `*`?")
diff --git a/airflow/models/dag.py b/airflow/models/dag.py
index 290f42ca387..8cb0e07d137 100644
--- a/airflow/models/dag.py
+++ b/airflow/models/dag.py
@@ -26,7 +26,7 @@ import pickle
import sys
import time
import traceback
-from collections import abc, defaultdict, deque
+from collections import abc, defaultdict
from contextlib import ExitStack
from datetime import datetime, timedelta
from inspect import signature
@@ -44,6 +44,7 @@ from typing import (
overload,
)
+import attrs
import jinja2
import pendulum
import re2
@@ -117,14 +118,13 @@ from airflow.utils import timezone
from airflow.utils.dag_cycle_tester import check_cycle
from airflow.utils.decorators import fixup_decorator_warning_stack
from airflow.utils.helpers import exactly_one
+from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.session import NEW_SESSION, provide_session
from airflow.utils.sqlalchemy import UtcDateTime, lock_rows, tuple_in_condition, with_row_locks
from airflow.utils.state import DagRunState, State, TaskInstanceState
from airflow.utils.types import DagRunTriggeredByType, DagRunType
if TYPE_CHECKING:
- from types import ModuleType
-
from pendulum.tz.timezone import FixedTimezone, Timezone
from sqlalchemy.orm.query import Query
from sqlalchemy.orm.session import Session
@@ -345,7 +345,8 @@ DAG_ARGS_EXPECTED_TYPES = {
@functools.total_ordering
-class DAG(TaskSDKDag):
[email protected](hash=False, repr=False)
+class DAG(TaskSDKDag, LoggingMixin):
"""
A dag (directed acyclic graph) is a collection of tasks with directional dependencies.
@@ -459,6 +460,13 @@ class DAG(TaskSDKDag):
:param dag_display_name: The display name of the DAG which appears on the UI.
"""
+ partial: bool = False
+ last_loaded: datetime | None = None
+ on_success_callback: None | DagStateChangeCallback | list[DagStateChangeCallback] = None
+ on_failure_callback: None | DagStateChangeCallback | list[DagStateChangeCallback] = None
+
+ __hash__ = TaskSDKDag.__hash__
+
def validate_executor_field(self):
for task in self.tasks:
if task.executor:
@@ -762,7 +770,7 @@ class DAG(TaskSDKDag):
@property
def default_view(self) -> str:
- return self._default_view
+ return "grid"
@property
def pickle_id(self) -> int | None:
@@ -2566,15 +2574,11 @@ if STATICA_HACK: # pragma: no cover
""":sphinx-autoapi-skip:"""
-class DagContext(airflow.sdk.definitions.contextmanager.DagContext):
+class DagContext(airflow.sdk.definitions.contextmanager.DagContext, share_parent_context=True):
"""
:meta private:
"""
- _context_managed_dags: deque[DAG] = deque()
- autoregistered_dags: set[tuple[DAG, ModuleType]] = set()
- current_autoregister_module_name: str | None = None
-
@classmethod
def push_context_managed_dag(cls, dag: DAG):
cls.push(dag)
diff --git a/airflow/utils/edgemodifier.py b/airflow/utils/edgemodifier.py
index a78e6c64999..e4345cb471b 100644
--- a/airflow/utils/edgemodifier.py
+++ b/airflow/utils/edgemodifier.py
@@ -19,6 +19,8 @@ from __future__ import annotations
from typing import Sequence
from airflow.models.taskmixin import DAGNode, DependencyMixin
+from airflow.sdk.definitions.node import DAGNode as TaskSDKDagNode
+from airflow.sdk.definitions.taskgroup import TaskGroup as TaskSDKTaskGroup
from airflow.utils.task_group import TaskGroup
@@ -68,7 +70,7 @@ class EdgeModifier(DependencyMixin):
from airflow.models.xcom_arg import XComArg
for node in self._make_list(nodes):
- if isinstance(node, (TaskGroup, XComArg, DAGNode)):
+ if isinstance(node, (XComArg, DAGNode, TaskSDKDagNode)):
stream.append(node)
else:
raise TypeError(
@@ -95,10 +97,10 @@ class EdgeModifier(DependencyMixin):
group_ids.add("root")
else:
group_ids.add(node.task_group.group_id)
- elif isinstance(node, TaskGroup):
+ elif isinstance(node, (TaskGroup, TaskSDKTaskGroup)):
group_ids.add(node.group_id)
elif isinstance(node, XComArg):
- if isinstance(node.operator, DAGNode) and node.operator.task_group:
+ if isinstance(node.operator, (DAGNode, TaskSDKDagNode)) and node.operator.task_group:
if node.operator.task_group.is_root:
group_ids.add("root")
else:
@@ -112,7 +114,7 @@ class EdgeModifier(DependencyMixin):
def _convert_stream_to_task_groups(self, stream: Sequence[DependencyMixin]) -> Sequence[DependencyMixin]:
return [
node.task_group
- if isinstance(node, DAGNode) and node.task_group and not node.task_group.is_root
+ if isinstance(node, (DAGNode, TaskSDKDagNode)) and node.task_group and not node.task_group.is_root
else node
for node in stream
]
diff --git a/airflow/utils/task_group.py b/airflow/utils/task_group.py
index 79b6329e44b..acbd8183052 100644
--- a/airflow/utils/task_group.py
+++ b/airflow/utils/task_group.py
@@ -36,6 +36,7 @@ from airflow.exceptions import (
TaskAlreadyInTaskGroup,
)
from airflow.models.taskmixin import DAGNode
+from airflow.sdk.definitions.node import DAGNode as TaskSDKDagNode
from airflow.serialization.enums import DagAttributeTypes
from airflow.utils.helpers import validate_group_key, validate_instance_args
@@ -275,7 +276,7 @@ class TaskGroup(DAGNode):
@property
def group_id(self) -> str | None:
"""group_id of this TaskGroup."""
- if self.task_group and self.task_group.prefix_group_id and self.task_group._group_id:
+ if self.task_group and self.task_group.prefix_group_id and self.task_group.node_id:
# defer to parent whether it adds a prefix
return self.task_group.child_id(self._group_id)
@@ -311,7 +312,7 @@ class TaskGroup(DAGNode):
else:
# Handles setting relationship between a TaskGroup and a task
for task in other.roots:
- if not isinstance(task, DAGNode):
+ if not isinstance(task, (DAGNode, TaskSDKDagNode)):
raise AirflowException(
"Relationships can only be set between TaskGroup "
f"or operators; received {task.__class__.__name__}"
@@ -651,13 +652,13 @@ class MappedTaskGroup(TaskGroup):
super().__exit__(exc_type, exc_val, exc_tb)
-class TaskGroupContext(airflow.sdk.definitions.contextmanager.TaskGroupContext):
+class TaskGroupContext(airflow.sdk.definitions.contextmanager.TaskGroupContext, share_parent_context=True):
"""TaskGroup context is used to keep the current TaskGroup when TaskGroup is used as ContextManager."""
@classmethod
def push_context_managed_task_group(cls, task_group: TaskGroup):
"""Push a TaskGroup into the list of managed TaskGroups."""
- return cls.pusg(task_group)
+ return cls.push(task_group)
@classmethod
def pop_context_managed_task_group(cls) -> TaskGroup | None:
@@ -667,7 +668,7 @@ class TaskGroupContext(airflow.sdk.definitions.contextmanager.TaskGroupContext):
@classmethod
def get_current_task_group(cls, dag: DAG | None) -> TaskGroup | None:
"""Get the current TaskGroup."""
- return cls.get_current()
+ return cls.get_current(dag)
def task_group_to_dict(task_item_or_group):
diff --git a/pyproject.toml b/pyproject.toml
index 3a3bde8e23a..63b2958dad2 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -468,6 +468,7 @@ fixture-parentheses = false
## pytest settings ##
[tool.pytest.ini_options]
addopts = [
+ "--tb=short",
"-rasl",
"--verbosity=2",
# Disable `flaky` plugin for pytest. This plugin conflicts with `rerunfailures` because provide the same marker.
diff --git a/task_sdk/src/airflow/sdk/__init__.py b/task_sdk/src/airflow/sdk/__init__.py
index baf7c85baa9..5fea295d981 100644
--- a/task_sdk/src/airflow/sdk/__init__.py
+++ b/task_sdk/src/airflow/sdk/__init__.py
@@ -23,12 +23,14 @@ __all__ = ["DAG", "BaseOperator", "TaskGroup"]
if TYPE_CHECKING:
from airflow.sdk.definitions.baseoperator import BaseOperator as BaseOperator
from airflow.sdk.definitions.dag import DAG as DAG
+ from airflow.sdk.definitions.edges import EdgeModifier as EdgeModifier
from airflow.sdk.definitions.taskgroup import TaskGroup as TaskGroup
__lazy_imports: dict[str, str] = {
"DAG": ".definitions.dag",
"BaseOperator": ".definitions.baseoperator",
"TaskGroup": ".definitions.taskgroup",
+ "EdgeModifier": ".definitions.edges",
}
diff --git a/task_sdk/src/airflow/sdk/definitions/abstractoperator.py b/task_sdk/src/airflow/sdk/definitions/abstractoperator.py
index 2dd0d488040..56baa258618 100644
--- a/task_sdk/src/airflow/sdk/definitions/abstractoperator.py
+++ b/task_sdk/src/airflow/sdk/definitions/abstractoperator.py
@@ -21,6 +21,7 @@ import datetime
from abc import abstractmethod
from collections.abc import (
Collection,
+ Iterable,
)
from typing import (
TYPE_CHECKING,
@@ -28,16 +29,22 @@ from typing import (
ClassVar,
)
-from .node import DAGNode
+from airflow.sdk.definitions.node import DAGNode
+from airflow.utils.log.secrets_masker import redact
# TaskStateChangeCallback = Callable[[Context], None]
if TYPE_CHECKING:
- from airflow.models.baseoperator import BaseOperator
+ import jinja2 # Slow import.
+
from airflow.models.baseoperatorlink import BaseOperatorLink
+ from airflow.sdk.definitions.baseoperator import BaseOperator
+ from airflow.sdk.definitions.dag import DAG
from airflow.task.priority_strategy import PriorityWeightStrategy
- from .dag import DAG
+ # TODO: Task-SDK
+ Context = dict[str, Any]
+
DEFAULT_OWNER: str = "airflow"
DEFAULT_POOL_SLOTS: int = 1
@@ -47,11 +54,12 @@ DEFAULT_QUEUE: str = "default"
DEFAULT_IGNORE_FIRST_DEPENDS_ON_PAST: bool = False
DEFAULT_WAIT_FOR_PAST_DEPENDS_BEFORE_SKIPPING: bool = False
DEFAULT_RETRIES: int = 0
-DEFAULT_RETRY_DELAY: datetime.timedelta = datetime.timedelta(300)
+DEFAULT_RETRY_DELAY: datetime.timedelta = datetime.timedelta(seconds=300)
MAX_RETRY_DELAY: int = 24 * 60 * 60
+# TODO: Task-SDK
# DEFAULT_TRIGGER_RULE: TriggerRule = TriggerRule.ALL_SUCCESS
-DEFAULT_TRIGGER_RULE = "ALL_SUCCESS" # TriggerRule.ALL_SUCCESS
+DEFAULT_TRIGGER_RULE = "all_success"
DEFAULT_WEIGHT_RULE = "downstream"
DEFAULT_TASK_EXECUTION_TIMEOUT: datetime.timedelta | None = None
@@ -155,3 +163,69 @@ class AbstractOperator(DAGNode):
# "task_group_id.task_id" -> "task_id"
return self.task_id[len(tg.node_id) + 1 :]
return self.task_id
+
+ def get_template_env(self, dag: DAG | None = None) -> jinja2.Environment:
+ """Get the template environment for rendering templates."""
+ if dag is None:
+ dag = self.get_dag()
+ return super().get_template_env(dag=dag)
+
+ def _render(self, template, context, dag: DAG | None = None):
+ if dag is None:
+ dag = self.get_dag()
+ return super()._render(template, context, dag=dag)
+
+ def _do_render_template_fields(
+ self,
+ parent: Any,
+ template_fields: Iterable[str],
+ context: Context,
+ jinja_env: jinja2.Environment,
+ seen_oids: set[int],
+ ) -> None:
+ """Override the base to use custom error logging."""
+ for attr_name in template_fields:
+ try:
+ value = getattr(parent, attr_name)
+ except AttributeError:
+ raise AttributeError(
+ f"{attr_name!r} is configured as a template field "
+ f"but {parent.task_type} does not have this attribute."
+ )
+ try:
+ if not value:
+ continue
+ except Exception:
+ # This may happen if the templated field points to a class which does not support `__bool__`,
+ # such as Pandas DataFrames:
+ # https://github.com/pandas-dev/pandas/blob/9135c3aaf12d26f857fcc787a5b64d521c51e379/pandas/core/generic.py#L1465
+ self.log.info(
+ "Unable to check if the value of type '%s' is False for
task '%s', field '%s'.",
+ type(value).__name__,
+ self.task_id,
+ attr_name,
+ )
+ # We may still want to render custom classes which do not support __bool__
+ pass
+
+ try:
+ if callable(value):
+ rendered_content = value(context=context, jinja_env=jinja_env)
+ else:
+ rendered_content = self.render_template(
+ value,
+ context,
+ jinja_env,
+ seen_oids,
+ )
+ except Exception:
+ value_masked = redact(name=attr_name, value=value)
+ self.log.exception(
+ "Exception rendering Jinja template for task '%s', field
'%s'. Template: %r",
+ self.task_id,
+ attr_name,
+ value_masked,
+ )
+ raise
+ else:
+ setattr(parent, attr_name, rendered_content)
diff --git a/task_sdk/src/airflow/sdk/definitions/baseoperator.py b/task_sdk/src/airflow/sdk/definitions/baseoperator.py
index 6480a897807..3a70f910305 100644
--- a/task_sdk/src/airflow/sdk/definitions/baseoperator.py
+++ b/task_sdk/src/airflow/sdk/definitions/baseoperator.py
@@ -22,6 +22,7 @@ import collections.abc
import contextlib
import copy
import inspect
+import warnings
from collections.abc import Iterable, Sequence
from dataclasses import dataclass
from datetime import datetime, timedelta
@@ -29,6 +30,9 @@ from functools import total_ordering, wraps
from types import FunctionType
from typing import TYPE_CHECKING, Any, ClassVar, TypeVar, cast
+import attrs
+
+from airflow.exceptions import FailStopDagInvalidTriggerRule
from airflow.sdk.definitions.abstractoperator import (
DEFAULT_IGNORE_FIRST_DEPENDS_ON_PAST,
DEFAULT_OWNER,
@@ -47,6 +51,7 @@ from airflow.sdk.definitions.decorators import fixup_decorator_warning_stack
from airflow.sdk.definitions.node import validate_key
from airflow.sdk.types import NOTSET, validate_instance_args
from airflow.task.priority_strategy import PriorityWeightStrategy, validate_and_load_priority_weight_strategy
+from airflow.utils.trigger_rule import TriggerRule
from airflow.utils.types import AttributeRemoved
T = TypeVar("T", bound=FunctionType)
@@ -58,8 +63,9 @@ if TYPE_CHECKING:
class ParamsDict: ...
from airflow.sdk.definitions.dag import DAG
- from airflow.sdk.definitions.taskgroup import TaskGroup
+ from airflow.utils.operator_resources import Resources
+from airflow.sdk.definitions.taskgroup import TaskGroup
# TODO: Task-SDK
AirflowException = RuntimeError
@@ -489,6 +495,8 @@ class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta):
task_id: str
owner: str = DEFAULT_OWNER
email: str | Sequence[str] | None = None
+ email_on_retry: bool = True
+ email_on_failure: bool = True
retries: int | None = DEFAULT_RETRIES
retry_delay: timedelta | float = DEFAULT_RETRY_DELAY
retry_exponential_backoff: bool = False
@@ -500,7 +508,7 @@ class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta):
wait_for_past_depends_before_skipping: bool = DEFAULT_WAIT_FOR_PAST_DEPENDS_BEFORE_SKIPPING
wait_for_downstream: bool = False
dag: DAG | None = None
- params: MutableMapping | None = None
+ params: collections.abc.MutableMapping | None = None
default_args: dict | None = None
priority_weight: int = DEFAULT_PRIORITY_WEIGHT
# TODO:
@@ -618,6 +626,8 @@ class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta):
task_id: str,
owner: str = DEFAULT_OWNER,
email: str | Sequence[str] | None = None,
+ email_on_retry: bool = True,
+ email_on_failure: bool = True,
retries: int | None = DEFAULT_RETRIES,
retry_delay: timedelta | float = DEFAULT_RETRY_DELAY,
retry_exponential_backoff: bool = False,
@@ -670,14 +680,16 @@ class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta):
):
from airflow.sdk.definitions.contextmanager import DagContext, TaskGroupContext
+ dag = dag or DagContext.get_current()
+ task_group = task_group or TaskGroupContext.get_current(dag)
+
self.task_id = task_group.child_id(task_id) if task_group else task_id
if not self.__from_mapped and task_group:
task_group.add(self)
- dag = dag or DagContext.get_current()
- task_group = task_group or TaskGroupContext.get_current(dag)
-
- super().__init__(dag=dag, task_group=task_group)
+ super().__init__()
+ self.dag = dag
+ self.task_group = task_group
kwargs.pop("_airflow_mapped_validation_only", None)
if kwargs:
@@ -689,6 +701,8 @@ class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta):
self.owner = owner
self.email = email
+ self.email_on_retry = email_on_retry
+ self.email_on_failure = email_on_failure
if execution_timeout is not None and not isinstance(execution_timeout, timedelta):
raise ValueError(
@@ -706,10 +720,14 @@ class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta):
# self._post_execute_hook = post_execute
if start_date:
- self.start_date = timezone.convert_to_utc(start_date)
+ # TODO: Task-SDK
+ # self.start_date = timezone.convert_to_utc(start_date)
+ self.start_date = start_date
if end_date:
- self.end_date = timezone.convert_to_utc(end_date)
+ # TODO: Task-SDK
+ # self.end_date = timezone.convert_to_utc(end_date)
+ self.end_date = end_date
if executor:
warnings.warn(
@@ -731,16 +749,14 @@ class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta):
raise ValueError(f"pool slots for {self.task_id}{dag_str} cannot be less than 1")
self.sla = sla
- """
- # if not TriggerRule.is_valid(trigger_rule):
- # raise AirflowException(
- # f"The trigger_rule must be one of
{TriggerRule.all_triggers()},"
- # f"'{dag.dag_id if dag else ''}.{task_id}'; received
'{trigger_rule}'."
- # )
- #
- # self.trigger_rule: TriggerRule = TriggerRule(trigger_rule)
- # FailStopDagInvalidTriggerRule.check(dag=dag,
trigger_rule=self.trigger_rule)
- """
+ if not TriggerRule.is_valid(trigger_rule):
+ raise ValueError(
+ f"The trigger_rule must be one of
{TriggerRule.all_triggers()},"
+ f"'{dag.dag_id if dag else ''}.{task_id}'; received
'{trigger_rule}'."
+ )
+
+ self.trigger_rule: TriggerRule = TriggerRule(trigger_rule)
+ FailStopDagInvalidTriggerRule.check(dag=dag,
trigger_rule=self.trigger_rule)
self.depends_on_past: bool = depends_on_past
self.ignore_first_depends_on_past: bool = ignore_first_depends_on_past
@@ -892,7 +908,7 @@ class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta):
other = [other]
for obj in other:
- if not attr.has(obj):
+ if not attrs.has(obj):
raise TypeError(f"{obj} cannot be an inlet")
self.add_inlets(other)
@@ -947,8 +963,8 @@ class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta):
return parsed_retries
@staticmethod
- def _convert_timedelta(value: float | timedelta) -> timedelta:
- if isinstance(value, timedelta):
+ def _convert_timedelta(value: float | timedelta | None) -> timedelta | None:
+ if value is None or isinstance(value, timedelta):
return value
return timedelta(seconds=value)
@@ -959,6 +975,9 @@ class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta):
def _convert_resources(resources: dict[str, Any] | None) -> Resources | None:
if resources is None:
return None
+
+ from airflow.utils.operator_resources import Resources
+
return Resources(**resources)
@property
diff --git a/task_sdk/src/airflow/sdk/definitions/contextmanager.py b/task_sdk/src/airflow/sdk/definitions/contextmanager.py
index d97339fc265..47edde0b5fe 100644
--- a/task_sdk/src/airflow/sdk/definitions/contextmanager.py
+++ b/task_sdk/src/airflow/sdk/definitions/contextmanager.py
@@ -28,17 +28,16 @@ T = TypeVar("T")
class ContextStack(Generic[T]):
- active: bool = False
_context: deque[T]
- def __init_subclass__(cls) -> None:
- cls._context = deque()
+ def __init_subclass__(cls, /, **kwargs) -> None:
+ if not kwargs.get("share_parent_context", False):
+ cls._context = deque()
return super().__init_subclass__()
@classmethod
def push(cls, obj: T):
cls._context.appendleft(obj)
- cls.active = True
@classmethod
def pop(cls) -> T | None:
@@ -52,7 +51,8 @@ class ContextStack(Generic[T]):
return None
@classmethod
- def is_active(cls) -> bool:
+ @property
+ def active(cls) -> bool:
"""The active property says if any object is currently in scope."""
try:
cls._context[0]
diff --git a/task_sdk/src/airflow/sdk/definitions/dag.py b/task_sdk/src/airflow/sdk/definitions/dag.py
index 5cf76da4882..6e62b4524f7 100644
--- a/task_sdk/src/airflow/sdk/definitions/dag.py
+++ b/task_sdk/src/airflow/sdk/definitions/dag.py
@@ -59,6 +59,7 @@ from airflow.exceptions import (
from airflow.models.param import DagParam
from airflow.sdk.definitions.abstractoperator import AbstractOperator
from airflow.sdk.definitions.baseoperator import BaseOperator
+from airflow.sdk.types import NOTSET
from airflow.stats import Stats
from airflow.timetables.base import Timetable
from airflow.timetables.interval import CronDataIntervalTimetable, DeltaDataIntervalTimetable
@@ -71,7 +72,7 @@ from airflow.timetables.simple import (
from airflow.utils.dag_cycle_tester import check_cycle
from airflow.utils.decorators import fixup_decorator_warning_stack
from airflow.utils.trigger_rule import TriggerRule
-from airflow.utils.types import NOTSET, EdgeInfoType
+from airflow.utils.types import EdgeInfoType
if TYPE_CHECKING:
from airflow.decorators import TaskDecoratorCollection
@@ -130,8 +131,6 @@ DAG_ARGS_EXPECTED_TYPES = {
"max_active_runs": int,
"max_consecutive_failed_dag_runs": int,
"dagrun_timeout": timedelta,
- "default_view": str,
- "orientation": str,
"catchup": bool,
"doc_md": str,
"is_paused_upon_creation": bool,
@@ -161,7 +160,18 @@ def _create_timetable(interval: ScheduleInterval, timezone: Timezone | FixedTime
raise ValueError(f"{interval!r} is not a valid schedule.")
-@attrs.define(kw_only=True)
+def _all_after_dag_id_to_kw_only(cls, fields: list[attrs.Attribute]):
+ i = iter(fields)
+ f = next(i)
+ if f.name != "dag_id":
+ raise RuntimeError("dag_id was not the first field")
+ yield f
+
+ for f in i:
+ yield f.evolve(kw_only=True)
+
+
+@attrs.define(repr=False, field_transformer=_all_after_dag_id_to_kw_only)
class DAG:
"""
A dag (directed acyclic graph) is a collection of tasks with directional
dependencies.
@@ -231,9 +241,6 @@ class DAG:
:param dagrun_timeout: Specify the duration a DagRun should be allowed to
run before it times out or
fails. Task instances that are running when a DagRun is timed out will
be marked as skipped.
:param sla_miss_callback: DEPRECATED - The SLA feature is removed in
Airflow 3.0, to be replaced with a new implementation in 3.1
- :param default_view: Specify DAG default view (grid, graph, duration,
- gantt, landing_times), default grid
- :param orientation: Specify DAG orientation in graph view (LR, TB, RL, BT), default LR
:param catchup: Perform scheduler catchup (or only run latest)? Defaults
to True
:param on_failure_callback: A function or list of functions to be called
when a DagRun of this dag fails.
A context dictionary is passed as a single parameter to this function.
@@ -290,7 +297,7 @@ class DAG:
# NOTE: When updating arguments here, please also keep arguments in @dag()
# below in sync. (Search for 'def dag(' in this file.)
- dag_id: str
+ dag_id: str = attrs.field(kw_only=False)
description: str | None = None
start_date: datetime | None = None
end_date: datetime | None = None
@@ -299,7 +306,10 @@ class DAG:
timetable: Timetable = attrs.field(init=False)
full_filepath: str | None = None
template_searchpath: str | Iterable[str] | None = None
- # template_undefined: type[jinja2.StrictUndefined] = jinja2.StrictUndefined
+ # TODO: Task-SDK: Work out how to not import jinja2 until we need it! It's expensive
+ template_undefined: type[jinja2.StrictUndefined] = jinja2.StrictUndefined
+ user_defined_macros: dict | None = None
+ user_defined_filters: dict | None = None
default_args: dict | None = attrs.field(factory=dict, converter=copy.copy)
concurrency: int | None = None
max_active_tasks: int = 16
@@ -326,6 +336,16 @@ class DAG:
task_group: TaskGroup = attrs.field(on_setattr=attrs.setters.NO_OP)
+ fileloc: str = attrs.field(init=False)
+
+ edge_info: dict[str, dict[str, EdgeInfoType]] = attrs.field(init=False, factory=dict)
+
+ @fileloc.default
+ def _default_fileloc(self) -> str:
+ # Skip over this frame, and the 'attrs generated init'
+ back = sys._getframe().f_back.f_back
+ return back.f_code.co_filename if back else ""
+
@dag_display_name.default
def _default_dag_display_name(self) -> str:
return self.dag_id
@@ -443,7 +463,6 @@ class DAG:
This is called by the DAG bag before bagging the DAG.
"""
- self.validate_schedule_and_params()
self.timetable.validate()
self.validate_setup_teardown()
@@ -850,13 +869,6 @@ class DAG:
args = parser.parse_args()
args.func(args, self)
- def get_default_view(self):
- """Allow backward compatible jinja2 templates."""
- if self.default_view is None:
- return airflow_conf.get("webserver", "dag_default_view").lower()
- else:
- return self.default_view
-
@classmethod
def get_serialized_fields(cls):
"""Stringified DAGs and operators contain exactly these fields."""
@@ -943,8 +955,6 @@ def dag(
),
dagrun_timeout: timedelta | None = None,
sla_miss_callback: Any = None,
- default_view: str = airflow_conf.get_mandatory_value("webserver", "dag_default_view").lower(),
- orientation: str = airflow_conf.get_mandatory_value("webserver", "dag_orientation"),
catchup: bool = airflow_conf.getboolean("scheduler", "catchup_by_default"),
on_success_callback: None | DagStateChangeCallback | list[DagStateChangeCallback] = None,
on_failure_callback: None | DagStateChangeCallback | list[DagStateChangeCallback] = None,
@@ -995,8 +1005,6 @@ def dag(
max_consecutive_failed_dag_runs=max_consecutive_failed_dag_runs,
dagrun_timeout=dagrun_timeout,
sla_miss_callback=sla_miss_callback,
- default_view=default_view,
- orientation=orientation,
catchup=catchup,
on_success_callback=on_success_callback,
on_failure_callback=on_failure_callback,
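The field_transformer used on the DAG class above is easiest to understand in isolation: only the first attrs field stays positional, everything after it is forced to keyword-only. A standalone sketch follows (the Example class is illustrative, not from the commit).

    # Standalone sketch of the attrs field_transformer pattern used for DAG.
    from __future__ import annotations

    import attrs

    def all_after_first_to_kw_only(cls, fields: list[attrs.Attribute]) -> list[attrs.Attribute]:
        first, *rest = fields
        # The first field keeps its settings; every later field becomes
        # keyword-only, mirroring _all_after_dag_id_to_kw_only above.
        return [first, *(f.evolve(kw_only=True) for f in rest)]

    @attrs.define(field_transformer=all_after_first_to_kw_only)
    class Example:
        dag_id: str
        description: str | None = None

    Example("demo", description="ok")  # dag_id may stay positional
    # Example("demo", "oops") would raise TypeError: description is keyword-only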
diff --git a/task_sdk/src/airflow/sdk/definitions/node.py
b/task_sdk/src/airflow/sdk/definitions/node.py
index 40236cac8bd..92527531eb9 100644
--- a/task_sdk/src/airflow/sdk/definitions/node.py
+++ b/task_sdk/src/airflow/sdk/definitions/node.py
@@ -24,7 +24,6 @@ from collections.abc import Iterable, Sequence
from datetime import datetime
from typing import TYPE_CHECKING
-import attrs
import methodtools
import re2
@@ -55,7 +54,6 @@ def validate_key(k: str, max_length: int = 250):
)
-@attrs.define(hash=False, repr=False)
class DAGNode(DependencyMixin, metaclass=ABCMeta):
"""
A base class for a node in the graph of a workflow.
@@ -63,13 +61,17 @@ class DAGNode(DependencyMixin, metaclass=ABCMeta):
A node may be an Operator or a Task Group, either mapped or unmapped.
"""
- dag: DAG | None = attrs.field(kw_only=True, default=None)
- task_group: TaskGroup | None = attrs.field(kw_only=True, default=None)
+ dag: DAG | None
+ task_group: TaskGroup | None
"""The task_group that contains this node"""
- start_date: datetime | None = attrs.field(kw_only=True, default=None)
- end_date: datetime | None = attrs.field(kw_only=True, default=None)
- upstream_task_ids: set[str] = attrs.field(init=False, factory=set)
- downstream_task_ids: set[str] = attrs.field(init=False, factory=set)
+ start_date: datetime | None
+ end_date: datetime | None
+ upstream_task_ids: set[str]
+ downstream_task_ids: set[str]
+
+ def __init__(self):
+ self.upstream_task_ids = set()
+ self.downstream_task_ids = set()
@property
@abstractmethod
@@ -188,6 +190,13 @@ class DAGNode(DependencyMixin, metaclass=ABCMeta):
raise RuntimeError(f"Operator {self} has not been assigned to a DAG yet")
return [self.dag.get_task(tid) for tid in self.upstream_task_ids]
+ def get_direct_relative_ids(self, upstream: bool = False) -> set[str]:
+ """Get set of the direct relative ids to the current task, upstream or downstream."""
+ if upstream:
+ return self.upstream_task_ids
+ else:
+ return self.downstream_task_ids
+
# def serialize_for_task_group(self) -> tuple[DagAttributeTypes, Any]:
# """Serialize a task group's content; used by TaskGroupSerialization."""
# raise NotImplementedError()
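The direction switch in get_direct_relative_ids is small but easy to misread; a toy illustration (TinyNode is hypothetical, the real DAGNode carries far more state):

    # Toy illustration of get_direct_relative_ids; TinyNode is hypothetical.
    class TinyNode:
        def __init__(self) -> None:
            self.upstream_task_ids = {"extract"}
            self.downstream_task_ids = {"load"}

        def get_direct_relative_ids(self, upstream: bool = False) -> set[str]:
            # upstream=True returns parents, otherwise children.
            return self.upstream_task_ids if upstream else self.downstream_task_ids

    node = TinyNode()
    assert node.get_direct_relative_ids(upstream=True) == {"extract"}
    assert node.get_direct_relative_ids() == {"load"}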
diff --git a/task_sdk/src/airflow/sdk/definitions/taskgroup.py
b/task_sdk/src/airflow/sdk/definitions/taskgroup.py
index 6bcda41d2d2..85139efdf0c 100644
--- a/task_sdk/src/airflow/sdk/definitions/taskgroup.py
+++ b/task_sdk/src/airflow/sdk/definitions/taskgroup.py
@@ -153,8 +153,12 @@ class TaskGroup(DAGNode):
return not self.group_id
@property
- def parent_group(self) -> TaskGroup | None:
- return self.task_group
+ def task_group(self) -> TaskGroup | None:
+ return self.parent_group
+
+ @task_group.setter
+ def _set_task_group(self, tg: TaskGroup):
+ self.parent_group = tg
def __iter__(self):
for child in self.children.values():
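One note on the task_group alias introduced in this hunk: for plain attribute assignment to route through a property's setter, the setter function conventionally reuses the property's name. A minimal standalone sketch of that read/write alias pattern (Node is illustrative, not the commit's TaskGroup):

    # Minimal read/write alias: task_group delegates to parent_group.
    class Node:
        def __init__(self) -> None:
            self.parent_group = None

        @property
        def task_group(self):
            return self.parent_group

        @task_group.setter
        def task_group(self, tg) -> None:
            self.parent_group = tg

    node = Node()
    node.task_group = "group-a"   # goes through the setter
    assert node.parent_group == "group-a"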
@@ -173,9 +177,9 @@ class TaskGroup(DAGNode):
from airflow.sdk.definitions.contextmanager import TaskGroupContext
if TaskGroupContext.active:
- if task.task_group and task.task_group != self:
- task.task_group.children.pop(task.node_id, None)
- task.task_group = self
+ if task.parent_group and task.parent_group != self:
+ task.parent_group.children.pop(task.node_id, None)
+ task.parent_group = self
existing_tg = task.task_group
if isinstance(task, AbstractOperator) and existing_tg is not None and existing_tg != self:
raise TaskAlreadyInTaskGroup(task.node_id, existing_tg.node_id, self.node_id)
@@ -213,9 +217,9 @@ class TaskGroup(DAGNode):
@property
def group_id(self) -> str | None:
"""group_id of this TaskGroup."""
- if self.task_group and self.task_group.prefix_group_id and self.task_group._group_id:
+ if self.parent_group and self.parent_group.prefix_group_id and self.parent_group.group_id:
# defer to parent whether it adds a prefix
- return self.task_group.child_id(self._group_id)
+ return self.parent_group.child_id(self._group_id)
return self._group_id
@@ -471,7 +475,7 @@ class TaskGroup(DAGNode):
while tg:
if tg.node_id in graph_unsorted:
break
- tg = tg.task_group
+ tg = tg.parent_group
if tg:
# We are already going to visit that TG
@@ -499,7 +503,7 @@ class TaskGroup(DAGNode):
while group is not None:
if isinstance(group, MappedTaskGroup):
yield group
- group = group.task_group
+ group = group.parent_group
def iter_tasks(self) -> Iterator[AbstractOperator]:
"""Return an iterator of the child tasks."""
diff --git a/task_sdk/tests/defintions/test_baseoperator.py
b/task_sdk/tests/defintions/test_baseoperator.py
index 0222de3ee32..f90891e1e86 100644
--- a/task_sdk/tests/defintions/test_baseoperator.py
+++ b/task_sdk/tests/defintions/test_baseoperator.py
@@ -124,6 +124,27 @@ class TestBaseOperator:
):
BaseOperator(task_id="test", execution_timeout=1)
+ def test_default_resources(self):
+ task = BaseOperator(task_id="default-resources")
+ assert task.resources is None
+
+ def test_custom_resources(self):
+ task = BaseOperator(task_id="custom-resources", resources={"cpus": 1, "ram": 1024})
+ assert task.resources.cpus.qty == 1
+ assert task.resources.ram.qty == 1024
+
+ def test_default_email_on_actions(self):
+ test_task = BaseOperator(task_id="test_default_email_on_actions")
+ assert test_task.email_on_retry is True
+ assert test_task.email_on_failure is True
+
+ def test_email_on_actions(self):
+ test_task = BaseOperator(
+ task_id="test_default_email_on_actions", email_on_retry=False, email_on_failure=True
+ )
+ assert test_task.email_on_retry is False
+ assert test_task.email_on_failure is True
+
def test_incorrect_default_args(self):
default_args = {"test_param": True, "extra_param": True}
op = FakeOperator(default_args=default_args)
@@ -228,6 +249,13 @@ class TestBaseOperator:
with pytest.raises(ValueError, match="can not be changed"):
op1.dag = DAG(dag_id="dag2")
+ def test_invalid_trigger_rule(self):
+ with pytest.raises(
+ ValueError,
+ match=(r"The trigger_rule must be one of .*,'\.op1'; received 'some_rule'\."),
+ ):
+ BaseOperator(task_id="op1", trigger_rule="some_rule")
+
def test_init_subclass_args():
class InitSubclassOp(BaseOperator):
@@ -291,7 +319,13 @@ def test_operator_retries_conversion(retries, expected):
assert op.retries == expected
-def test_dag_level_retry_delay(dag_maker):
+def test_default_retry_delay():
+ task1 = BaseOperator(task_id="test_no_explicit_retry_delay")
+
+ assert task1.retry_delay == timedelta(seconds=300)
+
+
+def test_dag_level_retry_delay():
with DAG(dag_id="test_dag_level_retry_delay", default_args={"retry_delay": timedelta(seconds=100)}):
task1 = BaseOperator(task_id="test_no_explicit_retry_delay")
diff --git a/tests/models/test_baseoperator.py
b/tests/models/test_baseoperator.py
index eea5dae2693..8227a6b7c8f 100644
--- a/tests/models/test_baseoperator.py
+++ b/tests/models/test_baseoperator.py
@@ -21,7 +21,7 @@ import copy
import logging
import uuid
from collections import defaultdict
-from datetime import date, datetime, timedelta
+from datetime import date, datetime
from typing import NamedTuple
from unittest import mock
@@ -316,27 +316,6 @@ class TestBaseOperator:
task.render_template_fields(context={"foo": "whatever", "bar": "whatever"})
assert mock_jinja_env.call_count == 1
- def test_default_resources(self):
- task = BaseOperator(task_id="default-resources")
- assert task.resources is None
-
- def test_custom_resources(self):
- task = BaseOperator(task_id="custom-resources", resources={"cpus": 1, "ram": 1024})
- assert task.resources.cpus.qty == 1
- assert task.resources.ram.qty == 1024
-
- def test_default_email_on_actions(self):
- test_task = BaseOperator(task_id="test_default_email_on_actions")
- assert test_task.email_on_retry is True
- assert test_task.email_on_failure is True
-
- def test_email_on_actions(self):
- test_task = BaseOperator(
- task_id="test_default_email_on_actions", email_on_retry=False, email_on_failure=True
- )
- assert test_task.email_on_retry is False
- assert test_task.email_on_failure is True
-
def test_cross_downstream(self):
"""Test if all dependencies between tasks are all set correctly."""
dag = DAG(dag_id="test_dag", schedule=None, start_date=datetime.now())
@@ -598,16 +577,6 @@ class TestBaseOperator:
assert op_no_dag.start_date.tzinfo
assert op_no_dag.end_date.tzinfo
- def test_invalid_trigger_rule(self):
- with pytest.raises(
- AirflowException,
- match=(
- f"The trigger_rule must be one of {TriggerRule.all_triggers()},"
- "'.op1'; received 'some_rule'."
- ),
- ):
- BaseOperator(task_id="op1", trigger_rule="some_rule")
-
# ensure the default logging config is used for this test, no matter what ran before
@pytest.mark.usefixtures("reset_logging_config")
def test_logging_propogated_by_default(self, caplog):
@@ -619,14 +588,6 @@ class TestBaseOperator:
assert caplog.messages == ["test"]
-@pytest.mark.db_test
-def test_default_retry_delay(dag_maker):
- with dag_maker(dag_id="test_default_retry_delay"):
- task1 = BaseOperator(task_id="test_no_explicit_retry_delay")
-
- assert task1.retry_delay == timedelta(seconds=300)
-
-
def test_deepcopy():
# Test bug when copying an operator attached to a DAG
with DAG("dag0", schedule=None, start_date=DEFAULT_DATE) as dag: