This is an automated email from the ASF dual-hosted git repository. ash pushed a commit to branch task-sdk-first-code in repository https://gitbox.apache.org/repos/asf/airflow.git
commit 762f0212c10350c6445ba6f354af8dd357325b24 Author: Ash Berlin-Taylor <[email protected]> AuthorDate: Sun Oct 13 15:53:28 2024 +0100 Start replacing DAG class with the task-sdk DAG class instead I'm not sure if this approach is right, but it is _a_ path forward that lets us keep main working. So far this has been tested just with `uv run pytest tests/models/test_dag.py -k 'test_dag_as_context_manager'` and no further --- airflow/models/baseoperator.py | 27 +- airflow/models/dag.py | 826 +-------------------- airflow/utils/task_group.py | 30 +- task_sdk/pyproject.toml | 1 - task_sdk/src/airflow/sdk/__init__.py | 28 +- .../src/airflow/sdk/definitions/baseoperator.py | 2 +- .../src/airflow/sdk/definitions/contextmanager.py | 2 +- task_sdk/src/airflow/sdk/definitions/dag.py | 39 +- task_sdk/src/airflow/sdk/definitions/node.py | 9 +- .../definitions/{task_group.py => taskgroup.py} | 0 task_sdk/src/airflow/sdk/types.py | 4 +- task_sdk/tests/test_hello.py | 23 - uv.lock | 11 - 13 files changed, 85 insertions(+), 917 deletions(-) diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py index 0c3d119be19..39d56187371 100644 --- a/airflow/models/baseoperator.py +++ b/airflow/models/baseoperator.py @@ -277,14 +277,13 @@ def partial( allow_nested_operators: bool = True, **kwargs, ) -> OperatorPartial: - from airflow.models.dag import DagContext - from airflow.utils.task_group import TaskGroupContext + from airflow.sdk.definitions.contextmanager import DagContext, TaskGroupContext validate_mapping_kwargs(operator_class, "partial", kwargs) - dag = dag or DagContext.get_current_dag() + dag = dag or DagContext.get_current() if dag: - task_group = task_group or TaskGroupContext.get_current_task_group(dag) + task_group = task_group or TaskGroupContext.get_current(dag) if task_group: task_id = task_group.child_id(task_id) @@ -453,8 +452,7 @@ class BaseOperatorMeta(abc.ABCMeta): @wraps(func) def apply_defaults(self: BaseOperator, *args: Any, **kwargs: Any) -> Any: - from airflow.models.dag import DagContext - from airflow.utils.task_group import TaskGroupContext + from airflow.sdk.definitions.contextmanager import DagContext, TaskGroupContext if args: raise AirflowException("Use keyword arguments when initializing operators") @@ -464,10 +462,10 @@ class BaseOperatorMeta(abc.ABCMeta): getattr(self, "_BaseOperator__from_mapped", False), ) - dag: DAG | None = kwargs.get("dag") or DagContext.get_current_dag() + dag: DAG | None = kwargs.get("dag") or DagContext.get_current() task_group: TaskGroup | None = kwargs.get("task_group") if dag and not task_group: - task_group = TaskGroupContext.get_current_task_group(dag) + task_group = TaskGroupContext.get_current(dag) default_args, merged_params = get_merged_defaults( dag=dag, @@ -922,8 +920,7 @@ class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta): allow_nested_operators: bool = True, **kwargs, ): - from airflow.models.dag import DagContext - from airflow.utils.task_group import TaskGroupContext + from airflow.sdk.definitions.contextmanager import DagContext, TaskGroupContext self.__init_kwargs = {} @@ -937,8 +934,8 @@ class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta): ) validate_key(task_id) - dag = dag or DagContext.get_current_dag() - task_group = task_group or TaskGroupContext.get_current_task_group(dag) + dag = dag or DagContext.get_current() + task_group = task_group or TaskGroupContext.get_current(dag) self.task_id = task_group.child_id(task_id) if task_group else task_id if not self.__from_mapped and task_group: @@ -1662,13 +1659,13 @@ class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta): def get_serialized_fields(cls): """Stringified DAGs and operators contain exactly these fields.""" if not cls.__serialized_fields: - from airflow.models.dag import DagContext + from airflow.sdk.definitions.contextmanager import DagContext # make sure the following dummy task is not added to current active # dag in context, otherwise, it will result in # `RuntimeError: dictionary changed size during iteration` # Exception in SerializedDAG.serialize_dag() call. - DagContext.push_context_managed_dag(None) + DagContext.push(None) cls.__serialized_fields = frozenset( vars(BaseOperator(task_id="test")).keys() - { @@ -1704,7 +1701,7 @@ class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta): "start_from_trigger", } ) - DagContext.pop_context_managed_dag() + DagContext.pop() return cls.__serialized_fields diff --git a/airflow/models/dag.py b/airflow/models/dag.py index f5def92ea92..290f42ca387 100644 --- a/airflow/models/dag.py +++ b/airflow/models/dag.py @@ -20,15 +20,12 @@ from __future__ import annotations import asyncio import copy import functools -import itertools import logging -import os import pathlib import pickle import sys import time import traceback -import weakref from collections import abc, defaultdict, deque from contextlib import ExitStack from datetime import datetime, timedelta @@ -40,15 +37,12 @@ from typing import ( Collection, Container, Iterable, - Iterator, - MutableSet, Pattern, Sequence, Union, cast, overload, ) -from urllib.parse import urlsplit import jinja2 import pendulum @@ -77,22 +71,19 @@ from sqlalchemy.ext.hybrid import hybrid_property from sqlalchemy.orm import backref, relationship from sqlalchemy.sql import Select, expression +import airflow.sdk.definitions.contextmanager import airflow.templates from airflow import settings, utils from airflow.api_internal.internal_api_call import internal_api_call -from airflow.assets import Asset, AssetAlias, AssetAll, BaseAsset +from airflow.assets import Asset, AssetAlias, BaseAsset from airflow.configuration import conf as airflow_conf, secrets_backend_list from airflow.exceptions import ( AirflowException, - DuplicateTaskIdFound, - FailStopDagInvalidTriggerRule, - ParamValidationError, TaskDeferred, - TaskNotFound, UnknownExecutorException, ) from airflow.executors.executor_loader import ExecutorLoader -from airflow.models.abstractoperator import AbstractOperator, TaskStateChangeCallback +from airflow.models.abstractoperator import TaskStateChangeCallback from airflow.models.asset import ( AssetDagRunQueue, AssetModel, @@ -102,7 +93,6 @@ from airflow.models.baseoperator import BaseOperator from airflow.models.dagcode import DagCode from airflow.models.dagpickle import DagPickle from airflow.models.dagrun import RUN_ID_REGEX, DagRun -from airflow.models.param import DagParam, ParamsDict from airflow.models.taskinstance import ( Context, TaskInstance, @@ -110,10 +100,10 @@ from airflow.models.taskinstance import ( clear_task_instances, ) from airflow.models.tasklog import LogTemplate +from airflow.sdk import DAG as TaskSDKDag from airflow.secrets.local_filesystem import LocalFilesystemBackend from airflow.security import permissions from airflow.settings import json -from airflow.stats import Stats from airflow.timetables.base import DagRunInfo, DataInterval, TimeRestriction, Timetable from airflow.timetables.interval import CronDataIntervalTimetable, DeltaDataIntervalTimetable from airflow.timetables.simple import ( @@ -126,13 +116,11 @@ from airflow.timetables.trigger import CronTriggerTimetable from airflow.utils import timezone from airflow.utils.dag_cycle_tester import check_cycle from airflow.utils.decorators import fixup_decorator_warning_stack -from airflow.utils.helpers import exactly_one, validate_instance_args, validate_key -from airflow.utils.log.logging_mixin import LoggingMixin +from airflow.utils.helpers import exactly_one from airflow.utils.session import NEW_SESSION, provide_session from airflow.utils.sqlalchemy import UtcDateTime, lock_rows, tuple_in_condition, with_row_locks from airflow.utils.state import DagRunState, State, TaskInstanceState -from airflow.utils.trigger_rule import TriggerRule -from airflow.utils.types import NOTSET, DagRunTriggeredByType, DagRunType, EdgeInfoType +from airflow.utils.types import DagRunTriggeredByType, DagRunType if TYPE_CHECKING: from types import ModuleType @@ -141,13 +129,11 @@ if TYPE_CHECKING: from sqlalchemy.orm.query import Query from sqlalchemy.orm.session import Session - from airflow.decorators import TaskDecoratorCollection from airflow.models.dagbag import DagBag from airflow.models.operator import Operator from airflow.serialization.pydantic.dag import DagModelPydantic from airflow.serialization.pydantic.dag_run import DagRunPydantic from airflow.typing_compat import Literal - from airflow.utils.task_group import TaskGroup log = logging.getLogger(__name__) @@ -359,7 +345,7 @@ DAG_ARGS_EXPECTED_TYPES = { @functools.total_ordering -class DAG(LoggingMixin): +class DAG(TaskSDKDag): """ A dag (directed acyclic graph) is a collection of tasks with directional dependencies. @@ -473,251 +459,6 @@ class DAG(LoggingMixin): :param dag_display_name: The display name of the DAG which appears on the UI. """ - _comps = { - "dag_id", - "task_ids", - "start_date", - "end_date", - "fileloc", - "template_searchpath", - "last_loaded", - } - - __serialized_fields: frozenset[str] | None = None - - fileloc: str - """ - File path that needs to be imported to load this DAG. - - This may not be an actual file on disk in the case when this DAG is loaded - from a ZIP file or other DAG distribution format. - """ - - # NOTE: When updating arguments here, please also keep arguments in @dag() - # below in sync. (Search for 'def dag(' in this file.) - def __init__( - self, - dag_id: str, - description: str | None = None, - schedule: ScheduleArg = None, - start_date: datetime | None = None, - end_date: datetime | None = None, - template_searchpath: str | Iterable[str] | None = None, - template_undefined: type[jinja2.StrictUndefined] = jinja2.StrictUndefined, - user_defined_macros: dict | None = None, - user_defined_filters: dict | None = None, - default_args: dict | None = None, - max_active_tasks: int = airflow_conf.getint("core", "max_active_tasks_per_dag"), - max_active_runs: int = airflow_conf.getint("core", "max_active_runs_per_dag"), - max_consecutive_failed_dag_runs: int = airflow_conf.getint( - "core", "max_consecutive_failed_dag_runs_per_dag" - ), - dagrun_timeout: timedelta | None = None, - sla_miss_callback: Any = None, - default_view: str = airflow_conf.get_mandatory_value("webserver", "dag_default_view").lower(), - orientation: str = airflow_conf.get_mandatory_value("webserver", "dag_orientation"), - catchup: bool = airflow_conf.getboolean("scheduler", "catchup_by_default"), - on_success_callback: None | DagStateChangeCallback | list[DagStateChangeCallback] = None, - on_failure_callback: None | DagStateChangeCallback | list[DagStateChangeCallback] = None, - doc_md: str | None = None, - params: abc.MutableMapping | None = None, - access_control: dict[str, dict[str, Collection[str]]] | dict[str, Collection[str]] | None = None, - is_paused_upon_creation: bool | None = None, - jinja_environment_kwargs: dict | None = None, - render_template_as_native_obj: bool = False, - tags: Collection[str] | None = None, - owner_links: dict[str, str] | None = None, - auto_register: bool = True, - fail_stop: bool = False, - dag_display_name: str | None = None, - ): - from airflow.utils.task_group import TaskGroup - - if tags and any(len(tag) > TAG_MAX_LEN for tag in tags): - raise AirflowException(f"tag cannot be longer than {TAG_MAX_LEN} characters") - - self.owner_links = owner_links or {} - self.user_defined_macros = user_defined_macros - self.user_defined_filters = user_defined_filters - if default_args and not isinstance(default_args, dict): - raise TypeError("default_args must be a dict") - self.default_args = copy.deepcopy(default_args or {}) - params = params or {} - - # merging potentially conflicting default_args['params'] into params - if "params" in self.default_args: - params.update(self.default_args["params"]) - del self.default_args["params"] - - # check self.params and convert them into ParamsDict - self.params = ParamsDict(params) - - validate_key(dag_id) - - self._dag_id = dag_id - self._dag_display_property_value = dag_display_name - - self._max_active_tasks = max_active_tasks - self._pickle_id: int | None = None - - self._description = description - # set file location to caller source path - back = sys._getframe().f_back - self.fileloc = back.f_code.co_filename if back else "" - self.task_dict: dict[str, Operator] = {} - - # set timezone from start_date - tz = None - if start_date and start_date.tzinfo: - tzinfo = None if start_date.tzinfo else settings.TIMEZONE - tz = pendulum.instance(start_date, tz=tzinfo).timezone - elif date := self.default_args.get("start_date"): - if not isinstance(date, datetime): - date = timezone.parse(date) - self.default_args["start_date"] = date - start_date = date - - tzinfo = None if date.tzinfo else settings.TIMEZONE - tz = pendulum.instance(date, tz=tzinfo).timezone - self.timezone: Timezone | FixedTimezone = tz or settings.TIMEZONE - - # Apply the timezone we settled on to end_date if it wasn't supplied - if isinstance(_end_date := self.default_args.get("end_date"), str): - self.default_args["end_date"] = timezone.parse(_end_date, timezone=self.timezone) - - self.start_date = timezone.convert_to_utc(start_date) - self.end_date = timezone.convert_to_utc(end_date) - - # also convert tasks - if "start_date" in self.default_args: - self.default_args["start_date"] = timezone.convert_to_utc(self.default_args["start_date"]) - if "end_date" in self.default_args: - self.default_args["end_date"] = timezone.convert_to_utc(self.default_args["end_date"]) - - if isinstance(schedule, Timetable): - self.timetable = schedule - elif isinstance(schedule, BaseAsset): - self.timetable = AssetTriggeredTimetable(schedule) - elif isinstance(schedule, Collection) and not isinstance(schedule, str): - if not all(isinstance(x, (Asset, AssetAlias)) for x in schedule): - raise ValueError("All elements in 'schedule' should be assets or asset aliases") - self.timetable = AssetTriggeredTimetable(AssetAll(*schedule)) - else: - self.timetable = create_timetable(schedule, self.timezone) - - requires_automatic_backfilling = self.timetable.can_be_scheduled and catchup - if requires_automatic_backfilling and not ("start_date" in self.default_args or self.start_date): - raise ValueError("start_date is required when catchup=True") - - if isinstance(template_searchpath, str): - template_searchpath = [template_searchpath] - self.template_searchpath = template_searchpath - self.template_undefined = template_undefined - self.last_loaded: datetime = timezone.utcnow() - self.safe_dag_id = dag_id.replace(".", "__dot__") - self.max_active_runs = max_active_runs - self.max_consecutive_failed_dag_runs = max_consecutive_failed_dag_runs - if self.max_consecutive_failed_dag_runs == 0: - self.max_consecutive_failed_dag_runs = airflow_conf.getint( - "core", "max_consecutive_failed_dag_runs_per_dag" - ) - if self.max_consecutive_failed_dag_runs < 0: - raise AirflowException( - f"Invalid max_consecutive_failed_dag_runs: {self.max_consecutive_failed_dag_runs}." - f"Requires max_consecutive_failed_dag_runs >= 0" - ) - if self.timetable.active_runs_limit is not None: - if self.timetable.active_runs_limit < self.max_active_runs: - raise AirflowException( - f"Invalid max_active_runs: {type(self.timetable)} " - f"requires max_active_runs <= {self.timetable.active_runs_limit}" - ) - self.dagrun_timeout = dagrun_timeout - if sla_miss_callback: - log.warning( - "The SLA feature is removed in Airflow 3.0, to be replaced with a new implementation in 3.1" - ) - if default_view in DEFAULT_VIEW_PRESETS: - self._default_view: str = default_view - else: - raise AirflowException( - f"Invalid values of dag.default_view: only support " - f"{DEFAULT_VIEW_PRESETS}, but get {default_view}" - ) - if orientation in ORIENTATION_PRESETS: - self.orientation = orientation - else: - raise AirflowException( - f"Invalid values of dag.orientation: only support " - f"{ORIENTATION_PRESETS}, but get {orientation}" - ) - self.catchup: bool = catchup - - self.partial: bool = False - self.on_success_callback = on_success_callback - self.on_failure_callback = on_failure_callback - - # Keeps track of any extra edge metadata (sparse; will not contain all - # edges, so do not iterate over it for that). Outer key is upstream - # task ID, inner key is downstream task ID. - self.edge_info: dict[str, dict[str, EdgeInfoType]] = {} - - # To keep it in parity with Serialized DAGs - # and identify if DAG has on_*_callback without actually storing them in Serialized JSON - self.has_on_success_callback: bool = self.on_success_callback is not None - self.has_on_failure_callback: bool = self.on_failure_callback is not None - - self._access_control = DAG._upgrade_outdated_dag_access_control(access_control) - self.is_paused_upon_creation = is_paused_upon_creation - self.auto_register = auto_register - - self.fail_stop: bool = fail_stop - - self.jinja_environment_kwargs = jinja_environment_kwargs - self.render_template_as_native_obj = render_template_as_native_obj - - self.doc_md = self.get_doc_md(doc_md) - - self.tags: MutableSet[str] = set(tags or []) - self._task_group = TaskGroup.create_root(self) - self.validate_schedule_and_params() - wrong_links = dict(self.iter_invalid_owner_links()) - if wrong_links: - raise AirflowException( - "Wrong link format was used for the owner. Use a valid link \n" - f"Bad formatted links are: {wrong_links}" - ) - - # this will only be set at serialization time - # it's only use is for determining the relative - # fileloc based only on the serialize dag - self._processor_dags_folder = None - - validate_instance_args(self, DAG_ARGS_EXPECTED_TYPES) - - def get_doc_md(self, doc_md: str | None) -> str | None: - if doc_md is None: - return doc_md - - if doc_md.endswith(".md"): - try: - return open(doc_md).read() - except FileNotFoundError: - return doc_md - - return doc_md - - def validate(self): - """ - Validate the DAG has a coherent setup. - - This is called by the DAG bag before bagging the DAG. - """ - self.validate_executor_field() - self.validate_schedule_and_params() - self.timetable.validate() - self.validate_setup_teardown() - def validate_executor_field(self): for task in self.tasks: if task.executor: @@ -730,63 +471,6 @@ class DAG(LoggingMixin): "update the executor configuration for this task." ) - def validate_setup_teardown(self): - """ - Validate that setup and teardown tasks are configured properly. - - :meta private: - """ - for task in self.tasks: - if task.is_setup: - for down_task in task.downstream_list: - if not down_task.is_teardown and down_task.trigger_rule != TriggerRule.ALL_SUCCESS: - # todo: we can relax this to allow out-of-scope tasks to have other trigger rules - # this is required to ensure consistent behavior of dag - # when clearing an indirect setup - raise ValueError("Setup tasks must be followed with trigger rule ALL_SUCCESS.") - FailStopDagInvalidTriggerRule.check(dag=self, trigger_rule=task.trigger_rule) - - def __repr__(self): - return f"<DAG: {self.dag_id}>" - - def __eq__(self, other): - if type(self) is type(other): - # Use getattr() instead of __dict__ as __dict__ doesn't return - # correct values for properties. - return all(getattr(self, c, None) == getattr(other, c, None) for c in self._comps) - return False - - def __ne__(self, other): - return not self == other - - def __lt__(self, other): - return self.dag_id < other.dag_id - - def __hash__(self): - hash_components = [type(self)] - for c in self._comps: - # task_ids returns a list and lists can't be hashed - if c == "task_ids": - val = tuple(self.task_dict) - else: - val = getattr(self, c, None) - try: - hash(val) - hash_components.append(val) - except TypeError: - hash_components.append(repr(val)) - return hash(tuple(hash_components)) - - # Context Manager ----------------------------------------------- - def __enter__(self): - DagContext.push_context_managed_dag(self) - return self - - def __exit__(self, _type, _value, _tb): - DagContext.pop_context_managed_dag() - - # /Context Manager ---------------------------------------------- - @staticmethod def _upgrade_outdated_dag_access_control(access_control=None): """Look for outdated dag level actions in DAG access_controls and replace them with updated actions.""" @@ -1076,14 +760,6 @@ class DAG(LoggingMixin): def access_control(self, value): self._access_control = DAG._upgrade_outdated_dag_access_control(value) - @property - def dag_display_name(self) -> str: - return self._dag_display_property_value or self._dag_id - - @property - def description(self) -> str | None: - return self._description - @property def default_view(self) -> str: return self._default_view @@ -1096,41 +772,6 @@ class DAG(LoggingMixin): def pickle_id(self, value: int) -> None: self._pickle_id = value - def param(self, name: str, default: Any = NOTSET) -> DagParam: - """ - Return a DagParam object for current dag. - - :param name: dag parameter name. - :param default: fallback value for dag parameter. - :return: DagParam instance for specified name and current dag. - """ - return DagParam(current_dag=self, name=name, default=default) - - @property - def tasks(self) -> list[Operator]: - return list(self.task_dict.values()) - - @tasks.setter - def tasks(self, val): - raise AttributeError("DAG.tasks can not be modified. Use dag.add_task() instead.") - - @property - def task_ids(self) -> list[str]: - return list(self.task_dict) - - @property - def teardowns(self) -> list[Operator]: - return [task for task in self.tasks if getattr(task, "is_teardown", None)] - - @property - def tasks_upstream_of_teardowns(self) -> list[Operator]: - upstream_tasks = [t.upstream_list for t in self.teardowns] - return [val for sublist in upstream_tasks for val in sublist if not getattr(val, "is_teardown", None)] - - @property - def task_group(self) -> TaskGroup: - return self._task_group - @property def relative_fileloc(self) -> pathlib.Path: """File location of the importable dag 'file' relative to the configured DAGs folder.""" @@ -1145,24 +786,6 @@ class DAG(LoggingMixin): # Not relative to DAGS_FOLDER. return path - @property - def folder(self) -> str: - """Folder location of where the DAG object is instantiated.""" - return os.path.dirname(self.fileloc) - - @property - def owner(self) -> str: - """ - Return list of all owners found in DAG tasks. - - :return: Comma separated list of owners in DAG tasks - """ - return ", ".join({t.owner for t in self.tasks}) - - @property - def allow_future_exec_dates(self) -> bool: - return settings.ALLOW_FUTURE_EXEC_DATES and not self.timetable.can_be_scheduled - @provide_session def get_concurrency_reached(self, session=NEW_SESSION) -> bool: """Return a boolean indicating whether the max_active_tasks limit for this DAG has been reached.""" @@ -1251,24 +874,6 @@ class DAG(LoggingMixin): DAG.execute_callback(callbacks, context, self.dag_id) - @classmethod - def execute_callback(cls, callbacks: list[Callable] | None, context: Context | None, dag_id: str): - """ - Triggers the callbacks with the given context. - - :param callbacks: List of callbacks to call - :param context: Context to pass to all callbacks - :param dag_id: The dag_id of the DAG to find. - """ - if callbacks and context: - for callback in callbacks: - cls.logger().info("Executing dag callback function: %s", callback) - try: - callback(context) - except Exception: - cls.logger().exception("failed to invoke dag state update callback") - Stats.incr("dag.callback_exceptions", tags={"dag_id": dag_id}) - def get_active_runs(self): """ Return a list of dag run execution dates currently running. @@ -1368,45 +973,6 @@ class DAG(LoggingMixin): """Return the latest date for which at least one dag run exists.""" return session.scalar(select(func.max(DagRun.execution_date)).where(DagRun.dag_id == self.dag_id)) - def resolve_template_files(self): - for t in self.tasks: - t.resolve_template_files() - - def get_template_env(self, *, force_sandboxed: bool = False) -> jinja2.Environment: - """Build a Jinja2 environment.""" - # Collect directories to search for template files - searchpath = [self.folder] - if self.template_searchpath: - searchpath += self.template_searchpath - - # Default values (for backward compatibility) - jinja_env_options = { - "loader": jinja2.FileSystemLoader(searchpath), - "undefined": self.template_undefined, - "extensions": ["jinja2.ext.do"], - "cache_size": 0, - } - if self.jinja_environment_kwargs: - jinja_env_options.update(self.jinja_environment_kwargs) - env: jinja2.Environment - if self.render_template_as_native_obj and not force_sandboxed: - env = airflow.templates.NativeEnvironment(**jinja_env_options) - else: - env = airflow.templates.SandboxedEnvironment(**jinja_env_options) - - # Add any user defined items. Safe to edit globals as long as no templates are rendered yet. - # http://jinja.pocoo.org/docs/2.10/api/#jinja2.Environment.globals - if self.user_defined_macros: - env.globals.update(self.user_defined_macros) - if self.user_defined_filters: - env.filters.update(self.user_defined_filters) - - return env - - def set_dependency(self, upstream_task_id, downstream_task_id): - """Set dependency between two tasks that already have been added to the DAG using add_task().""" - self.get_task(upstream_task_id).set_downstream(self.get_task(downstream_task_id)) - @provide_session def get_task_instances_before( self, @@ -1871,33 +1437,6 @@ class DAG(LoggingMixin): return altered - @property - def roots(self) -> list[Operator]: - """Return nodes with no parents. These are first to execute and are called roots or root nodes.""" - return [task for task in self.tasks if not task.upstream_list] - - @property - def leaves(self) -> list[Operator]: - """Return nodes with no children. These are last to execute and are called leaves or leaf nodes.""" - return [task for task in self.tasks if not task.downstream_list] - - def topological_sort(self): - """ - Sorts tasks in topographical order, such that a task comes after any of its upstream dependencies. - - Deprecated in place of ``task_group.topological_sort`` - """ - from airflow.utils.task_group import TaskGroup - - def nested_topo(group): - for node in group.topological_sort(): - if isinstance(node, TaskGroup): - yield from nested_topo(node) - else: - yield node - - return tuple(nested_topo(self.task_group)) - @provide_session def clear( self, @@ -2031,169 +1570,6 @@ class DAG(LoggingMixin): print("Cancelled, nothing was cleared.") return count - def __deepcopy__(self, memo): - # Switcharoo to go around deepcopying objects coming through the - # backdoor - cls = self.__class__ - result = cls.__new__(cls) - memo[id(self)] = result - for k, v in self.__dict__.items(): - if k not in ("user_defined_macros", "user_defined_filters", "_log"): - setattr(result, k, copy.deepcopy(v, memo)) - - result.user_defined_macros = self.user_defined_macros - result.user_defined_filters = self.user_defined_filters - if hasattr(self, "_log"): - result._log = self._log - return result - - def partial_subset( - self, - task_ids_or_regex: str | Pattern | Iterable[str], - include_downstream=False, - include_upstream=True, - include_direct_upstream=False, - ): - """ - Return a subset of the current dag based on regex matching one or more tasks. - - Returns a subset of the current dag as a deep copy of the current dag - based on a regex that should match one or many tasks, and includes - upstream and downstream neighbours based on the flag passed. - - :param task_ids_or_regex: Either a list of task_ids, or a regex to - match against task ids (as a string, or compiled regex pattern). - :param include_downstream: Include all downstream tasks of matched - tasks, in addition to matched tasks. - :param include_upstream: Include all upstream tasks of matched tasks, - in addition to matched tasks. - :param include_direct_upstream: Include all tasks directly upstream of matched - and downstream (if include_downstream = True) tasks - """ - from airflow.models.baseoperator import BaseOperator - from airflow.models.mappedoperator import MappedOperator - - # deep-copying self.task_dict and self._task_group takes a long time, and we don't want all - # the tasks anyway, so we copy the tasks manually later - memo = {id(self.task_dict): None, id(self._task_group): None} - dag = copy.deepcopy(self, memo) # type: ignore - - if isinstance(task_ids_or_regex, (str, Pattern)): - matched_tasks = [t for t in self.tasks if re2.findall(task_ids_or_regex, t.task_id)] - else: - matched_tasks = [t for t in self.tasks if t.task_id in task_ids_or_regex] - - also_include_ids: set[str] = set() - for t in matched_tasks: - if include_downstream: - for rel in t.get_flat_relatives(upstream=False): - also_include_ids.add(rel.task_id) - if rel not in matched_tasks: # if it's in there, we're already processing it - # need to include setups and teardowns for tasks that are in multiple - # non-collinear setup/teardown paths - if not rel.is_setup and not rel.is_teardown: - also_include_ids.update( - x.task_id for x in rel.get_upstreams_only_setups_and_teardowns() - ) - if include_upstream: - also_include_ids.update(x.task_id for x in t.get_upstreams_follow_setups()) - else: - if not t.is_setup and not t.is_teardown: - also_include_ids.update(x.task_id for x in t.get_upstreams_only_setups_and_teardowns()) - if t.is_setup and not include_downstream: - also_include_ids.update(x.task_id for x in t.downstream_list if x.is_teardown) - - also_include: list[Operator] = [self.task_dict[x] for x in also_include_ids] - direct_upstreams: list[Operator] = [] - if include_direct_upstream: - for t in itertools.chain(matched_tasks, also_include): - upstream = (u for u in t.upstream_list if isinstance(u, (BaseOperator, MappedOperator))) - direct_upstreams.extend(upstream) - - # Compiling the unique list of tasks that made the cut - # Make sure to not recursively deepcopy the dag or task_group while copying the task. - # task_group is reset later - def _deepcopy_task(t) -> Operator: - memo.setdefault(id(t.task_group), None) - return copy.deepcopy(t, memo) - - dag.task_dict = { - t.task_id: _deepcopy_task(t) - for t in itertools.chain(matched_tasks, also_include, direct_upstreams) - } - - def filter_task_group(group, parent_group): - """Exclude tasks not included in the subdag from the given TaskGroup.""" - # We want to deepcopy _most but not all_ attributes of the task group, so we create a shallow copy - # and then manually deep copy the instances. (memo argument to deepcopy only works for instances - # of classes, not "native" properties of an instance) - copied = copy.copy(group) - - memo[id(group.children)] = {} - if parent_group: - memo[id(group.parent_group)] = parent_group - for attr, value in copied.__dict__.items(): - if id(value) in memo: - value = memo[id(value)] - else: - value = copy.deepcopy(value, memo) - copied.__dict__[attr] = value - - proxy = weakref.proxy(copied) - - for child in group.children.values(): - if isinstance(child, AbstractOperator): - if child.task_id in dag.task_dict: - task = copied.children[child.task_id] = dag.task_dict[child.task_id] - task.task_group = proxy - else: - copied.used_group_ids.discard(child.task_id) - else: - filtered_child = filter_task_group(child, proxy) - - # Only include this child TaskGroup if it is non-empty. - if filtered_child.children: - copied.children[child.group_id] = filtered_child - - return copied - - dag._task_group = filter_task_group(self.task_group, None) - - # Removing upstream/downstream references to tasks and TaskGroups that did not make - # the cut. - subdag_task_groups = dag.task_group.get_task_group_dict() - for group in subdag_task_groups.values(): - group.upstream_group_ids.intersection_update(subdag_task_groups) - group.downstream_group_ids.intersection_update(subdag_task_groups) - group.upstream_task_ids.intersection_update(dag.task_dict) - group.downstream_task_ids.intersection_update(dag.task_dict) - - for t in dag.tasks: - # Removing upstream/downstream references to tasks that did not - # make the cut - t.upstream_task_ids.intersection_update(dag.task_dict) - t.downstream_task_ids.intersection_update(dag.task_dict) - - if len(dag.tasks) < len(self.tasks): - dag.partial = True - - return dag - - def has_task(self, task_id: str): - return task_id in self.task_dict - - def has_task_group(self, task_group_id: str) -> bool: - return task_group_id in self.task_group_dict - - @functools.cached_property - def task_group_dict(self): - return {k: v for k, v in self._task_group.get_task_group_dict().items() if k is not None} - - def get_task(self, task_id: str) -> Operator: - if task_id in self.task_dict: - return self.task_dict[task_id] - raise TaskNotFound(f"Task {task_id} not found") - def pickle_info(self): d = {} d["is_picklable"] = True @@ -2223,76 +1599,6 @@ class DAG(LoggingMixin): return dp - @property - def task(self) -> TaskDecoratorCollection: - from airflow.decorators import task - - return cast("TaskDecoratorCollection", functools.partial(task, dag=self)) - - def add_task(self, task: Operator) -> None: - """ - Add a task to the DAG. - - :param task: the task you want to add - """ - FailStopDagInvalidTriggerRule.check(dag=self, trigger_rule=task.trigger_rule) - - from airflow.utils.task_group import TaskGroupContext - - # if the task has no start date, assign it the same as the DAG - if not task.start_date: - task.start_date = self.start_date - # otherwise, the task will start on the later of its own start date and - # the DAG's start date - elif self.start_date: - task.start_date = max(task.start_date, self.start_date) - - # if the task has no end date, assign it the same as the dag - if not task.end_date: - task.end_date = self.end_date - # otherwise, the task will end on the earlier of its own end date and - # the DAG's end date - elif task.end_date and self.end_date: - task.end_date = min(task.end_date, self.end_date) - - task_id = task.task_id - if not task.task_group: - task_group = TaskGroupContext.get_current_task_group(self) - if task_group: - task_id = task_group.child_id(task_id) - task_group.add(task) - - if ( - task_id in self.task_dict and self.task_dict[task_id] is not task - ) or task_id in self._task_group.used_group_ids: - raise DuplicateTaskIdFound(f"Task id '{task_id}' has already been added to the DAG") - else: - self.task_dict[task_id] = task - task.dag = self - # Add task_id to used_group_ids to prevent group_id and task_id collisions. - self._task_group.used_group_ids.add(task_id) - - self.task_count = len(self.task_dict) - - def add_tasks(self, tasks: Iterable[Operator]) -> None: - """ - Add a list of tasks to the DAG. - - :param tasks: a lit of tasks you want to add - """ - for task in tasks: - self.add_task(task) - - def _remove_task(self, task_id: str) -> None: - # This is "private" as removing could leave a hole in dependencies if done incorrectly, and this - # doesn't guard against that - task = self.task_dict.pop(task_id) - tg = getattr(task, "task_group", None) - if tg: - tg._remove(task) - - self.task_count = len(self.task_dict) - def cli(self): """Exposes a CLI specific to this DAG.""" check_cycle(self) @@ -2692,88 +1998,6 @@ class DAG(LoggingMixin): qry = qry.where(TaskInstance.state.in_(states)) return session.scalar(qry) - @classmethod - def get_serialized_fields(cls): - """Stringified DAGs and operators contain exactly these fields.""" - if not cls.__serialized_fields: - exclusion_list = { - "schedule_dataset_references", - "schedule_dataset_alias_references", - "task_outlet_dataset_references", - "_old_context_manager_dags", - "safe_dag_id", - "last_loaded", - "user_defined_filters", - "user_defined_macros", - "partial", - "params", - "_pickle_id", - "_log", - "task_dict", - "template_searchpath", - "sla_miss_callback", - "on_success_callback", - "on_failure_callback", - "template_undefined", - "jinja_environment_kwargs", - # has_on_*_callback are only stored if the value is True, as the default is False - "has_on_success_callback", - "has_on_failure_callback", - "auto_register", - "fail_stop", - } - cls.__serialized_fields = frozenset(vars(DAG(dag_id="test", schedule=None))) - exclusion_list - return cls.__serialized_fields - - def get_edge_info(self, upstream_task_id: str, downstream_task_id: str) -> EdgeInfoType: - """Return edge information for the given pair of tasks or an empty edge if there is no information.""" - # Note - older serialized DAGs may not have edge_info being a dict at all - empty = cast(EdgeInfoType, {}) - if self.edge_info: - return self.edge_info.get(upstream_task_id, {}).get(downstream_task_id, empty) - else: - return empty - - def set_edge_info(self, upstream_task_id: str, downstream_task_id: str, info: EdgeInfoType): - """ - Set the given edge information on the DAG. - - Note that this will overwrite, rather than merge with, existing info. - """ - self.edge_info.setdefault(upstream_task_id, {})[downstream_task_id] = info - - def validate_schedule_and_params(self): - """ - Validate Param values when the DAG has schedule defined. - - Raise exception if there are any Params which can not be resolved by their schema definition. - """ - if not self.timetable.can_be_scheduled: - return - - try: - self.params.validate() - except ParamValidationError as pverr: - raise AirflowException( - "DAG is not allowed to define a Schedule, " - "if there are any required params without default values or default values are not valid." - ) from pverr - - def iter_invalid_owner_links(self) -> Iterator[tuple[str, str]]: - """ - Parse a given link, and verifies if it's a valid URL, or a 'mailto' link. - - Returns an iterator of invalid (owner, link) pairs. - """ - for owner, link in self.owner_links.items(): - result = urlsplit(link) - if result.scheme == "mailto": - # netloc is not existing for 'mailto' link, so we are checking that the path is parsed - if not result.path: - yield result.path, link - elif not result.scheme or not result.netloc: - yield owner, link - class DagTag(Base): """A tag name per dag, to allow quick filtering in the DAG view.""" @@ -3342,25 +2566,9 @@ if STATICA_HACK: # pragma: no cover """:sphinx-autoapi-skip:""" -class DagContext: +class DagContext(airflow.sdk.definitions.contextmanager.DagContext): """ - DAG context is used to keep the current DAG when DAG is used as ContextManager. - - You can use DAG as context: - - .. code-block:: python - - with DAG( - dag_id="example_dag", - default_args=default_args, - schedule="0 0 * * *", - dagrun_timeout=timedelta(minutes=60), - ) as dag: - ... - - If you do this the context stores the DAG and whenever new task is created, it will use - such stored DAG as the parent DAG. - + :meta private: """ _context_managed_dags: deque[DAG] = deque() @@ -3369,25 +2577,15 @@ class DagContext: @classmethod def push_context_managed_dag(cls, dag: DAG): - cls._context_managed_dags.appendleft(dag) + cls.push(dag) @classmethod def pop_context_managed_dag(cls) -> DAG | None: - dag = cls._context_managed_dags.popleft() - - # In a few cases around serialization we explicitly push None in to the stack - if cls.current_autoregister_module_name is not None and dag and dag.auto_register: - mod = sys.modules[cls.current_autoregister_module_name] - cls.autoregistered_dags.add((dag, mod)) - - return dag + return cls.pop() @classmethod def get_current_dag(cls) -> DAG | None: - try: - return cls._context_managed_dags[0] - except IndexError: - return None + return cls.get_current() def _run_inline_trigger(trigger): diff --git a/airflow/utils/task_group.py b/airflow/utils/task_group.py index 69a5d015bd4..79b6329e44b 100644 --- a/airflow/utils/task_group.py +++ b/airflow/utils/task_group.py @@ -28,6 +28,7 @@ from typing import TYPE_CHECKING, Any, Generator, Iterator, Sequence import methodtools import re2 +import airflow.sdk.definitions.contextmanager from airflow.exceptions import ( AirflowDagCycleException, AirflowException, @@ -650,44 +651,23 @@ class MappedTaskGroup(TaskGroup): super().__exit__(exc_type, exc_val, exc_tb) -class TaskGroupContext: +class TaskGroupContext(airflow.sdk.definitions.contextmanager.TaskGroupContext): """TaskGroup context is used to keep the current TaskGroup when TaskGroup is used as ContextManager.""" - active: bool = False - _context_managed_task_group: TaskGroup | None = None - _previous_context_managed_task_groups: list[TaskGroup] = [] - @classmethod def push_context_managed_task_group(cls, task_group: TaskGroup): """Push a TaskGroup into the list of managed TaskGroups.""" - if cls._context_managed_task_group: - cls._previous_context_managed_task_groups.append(cls._context_managed_task_group) - cls._context_managed_task_group = task_group - cls.active = True + return cls.pusg(task_group) @classmethod def pop_context_managed_task_group(cls) -> TaskGroup | None: """Pops the last TaskGroup from the list of managed TaskGroups and update the current TaskGroup.""" - old_task_group = cls._context_managed_task_group - if cls._previous_context_managed_task_groups: - cls._context_managed_task_group = cls._previous_context_managed_task_groups.pop() - else: - cls._context_managed_task_group = None - cls.active = False - return old_task_group + return cls.pop() @classmethod def get_current_task_group(cls, dag: DAG | None) -> TaskGroup | None: """Get the current TaskGroup.""" - from airflow.models.dag import DagContext - - if not cls._context_managed_task_group: - dag = dag or DagContext.get_current_dag() - if dag: - # If there's currently a DAG but no TaskGroup, return the root TaskGroup of the dag. - return dag.task_group - - return cls._context_managed_task_group + return cls.get_current() def task_group_to_dict(task_item_or_group): diff --git a/task_sdk/pyproject.toml b/task_sdk/pyproject.toml index b0a912d9f04..f83a4b7ec25 100644 --- a/task_sdk/pyproject.toml +++ b/task_sdk/pyproject.toml @@ -25,7 +25,6 @@ dependencies = [ "attrs>=24.2.0", "google-re2>=1.1.20240702", "methodtools>=0.4.7", - "structlog>=24.4.0", ] [build-system] diff --git a/task_sdk/src/airflow/sdk/__init__.py b/task_sdk/src/airflow/sdk/__init__.py index 2a3e01b64bc..baf7c85baa9 100644 --- a/task_sdk/src/airflow/sdk/__init__.py +++ b/task_sdk/src/airflow/sdk/__init__.py @@ -16,6 +16,30 @@ # under the License. from __future__ import annotations +from typing import TYPE_CHECKING -def hello() -> str: - return "Hello from task-sdk!" +__all__ = ["DAG", "BaseOperator", "TaskGroup"] + +if TYPE_CHECKING: + from airflow.sdk.definitions.baseoperator import BaseOperator as BaseOperator + from airflow.sdk.definitions.dag import DAG as DAG + from airflow.sdk.definitions.taskgroup import TaskGroup as TaskGroup + +__lazy_imports: dict[str, str] = { + "DAG": ".definitions.dag", + "BaseOperator": ".definitions.baseoperator", + "TaskGroup": ".definitions.taskgroup", +} + + +def __getattr__(name: str): + if module_path := __lazy_imports.get(name): + import importlib + + mod = importlib.import_module(module_path, __name__) + val = getattr(mod, name) + + # Store for next time + globals()[name] = val + return val + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/task_sdk/src/airflow/sdk/definitions/baseoperator.py b/task_sdk/src/airflow/sdk/definitions/baseoperator.py index c466f158cc9..d52802db21f 100644 --- a/task_sdk/src/airflow/sdk/definitions/baseoperator.py +++ b/task_sdk/src/airflow/sdk/definitions/baseoperator.py @@ -54,7 +54,7 @@ if TYPE_CHECKING: class ParamsDict: ... from airflow.sdk.definitions.dag import DAG - from airflow.sdk.definitions.task_group import TaskGroup + from airflow.sdk.definitions.taskgroup import TaskGroup # TODO: Task-SDK diff --git a/task_sdk/src/airflow/sdk/definitions/contextmanager.py b/task_sdk/src/airflow/sdk/definitions/contextmanager.py index fb160e06b35..d97339fc265 100644 --- a/task_sdk/src/airflow/sdk/definitions/contextmanager.py +++ b/task_sdk/src/airflow/sdk/definitions/contextmanager.py @@ -22,7 +22,7 @@ from types import ModuleType from typing import Generic, TypeVar from airflow.sdk.definitions.dag import DAG -from airflow.sdk.definitions.task_group import TaskGroup +from airflow.sdk.definitions.taskgroup import TaskGroup T = TypeVar("T") diff --git a/task_sdk/src/airflow/sdk/definitions/dag.py b/task_sdk/src/airflow/sdk/definitions/dag.py index 8a2799d0f51..f80d8f7f71b 100644 --- a/task_sdk/src/airflow/sdk/definitions/dag.py +++ b/task_sdk/src/airflow/sdk/definitions/dag.py @@ -58,9 +58,6 @@ from airflow.exceptions import ( TaskNotFound, ) from airflow.models.param import DagParam -from airflow.models.taskinstance import ( - Context, -) from airflow.sdk.definitions.abstractoperator import AbstractOperator from airflow.sdk.definitions.baseoperator import BaseOperator from airflow.stats import Stats @@ -73,7 +70,7 @@ from airflow.utils.types import NOTSET, EdgeInfoType if TYPE_CHECKING: from airflow.decorators import TaskDecoratorCollection from airflow.models.operator import Operator - from airflow.utils.task_group import TaskGroup + from airflow.utils.taskgroup import TaskGroup log = logging.getLogger(__name__) @@ -82,6 +79,11 @@ ORIENTATION_PRESETS = ["LR", "TB", "RL", "BT"] TAG_MAX_LEN = 100 + +# TODO: Task-SDK +class Context: ... + + DagStateChangeCallback = Callable[[Context], None] ScheduleInterval = Union[None, str, timedelta, relativedelta] @@ -93,13 +95,6 @@ ScheduleArg = Union[ ] -# Defined as a function so that we can avoid import cycles -def _create_root_group(dag: DAG) -> TaskGroup: - from airflow.sdk.definitions.task_group import TaskGroup - - return TaskGroup(group_id=None, dag=dag) - - _DAG_HASH_ATTRS = frozenset( { "dag_id", @@ -299,19 +294,27 @@ class DAG: owner_links: dict[str, str] | None = None auto_register: bool = True fail_stop: bool = False - dag_display_name: str | None = None + dag_display_name: str = attrs.field() task_dict: dict[str, DAGNode] = attrs.field(factory=dict, init=False) - task_group: TaskGroup = attrs.field( - default=attrs.Factory(_create_root_group, takes_self=True), on_setattr=attrs.setters.NO_OP - ) + task_group: TaskGroup = attrs.field(on_setattr=attrs.setters.NO_OP) + + @dag_display_name.default + def _default_dag_display_name(self) -> str: + return self.dag_id + + @task_group.default + def _default_task_group(self) -> TaskGroup: + from airflow.sdk.definitions.taskgroup import TaskGroup + + return TaskGroup.create_root(dag=self) def __repr__(self): return f"<DAG: {self.dag_id}>" def __eq__(self, other: Self | Any): - if type(self) != type(other): + if not isinstance(other, type(self)): return NotImplemented return all(getattr(self, c, None) == getattr(other, c, None) for c in _DAG_HASH_ATTRS) @@ -337,7 +340,7 @@ class DAG: return hash(tuple(hash_components)) def __enter__(self) -> Self: - from .contextmanager import DagContext + from airflow.sdk.definitions.contextmanager import DagContext DagContext.push(self) return self @@ -521,7 +524,7 @@ class DAG: Deprecated in place of ``task_group.topological_sort`` """ - from airflow.utils.task_group import TaskGroup + from airflow.utils.taskgroup import TaskGroup def nested_topo(group): for node in group.topological_sort(): diff --git a/task_sdk/src/airflow/sdk/definitions/node.py b/task_sdk/src/airflow/sdk/definitions/node.py index 0abb0a83c16..1d980284677 100644 --- a/task_sdk/src/airflow/sdk/definitions/node.py +++ b/task_sdk/src/airflow/sdk/definitions/node.py @@ -17,6 +17,8 @@ from __future__ import annotations +import logging +import re from abc import ABCMeta, abstractmethod from collections.abc import Iterable, Sequence from datetime import datetime @@ -25,20 +27,19 @@ from typing import TYPE_CHECKING import attrs import methodtools import re2 -import structlog from airflow.sdk.definitions.mixins import DependencyMixin if TYPE_CHECKING: from airflow.sdk.definitions.dag import DAG from airflow.sdk.definitions.edges import EdgeModifier - from airflow.sdk.definitions.task_group import TaskGroup + from airflow.sdk.definitions.taskgroup import TaskGroup from airflow.sdk.types import Logger KEY_REGEX = re2.compile(r"^[\w.-]+$") GROUP_KEY_REGEX = re2.compile(r"^[\w-]+$") -CAMELCASE_TO_SNAKE_CASE_REGEX = re2.compile(r"(?!^)([A-Z]+)") +CAMELCASE_TO_SNAKE_CASE_REGEX = re.compile(r"(?!^)([A-Z]+)") def validate_key(k: str, max_length: int = 250): @@ -98,7 +99,7 @@ class DAGNode(DependencyMixin, metaclass=ABCMeta): def log(self) -> Logger: typ = type(self) name = f"{typ.__module__}.{typ.__qualname__}" - return structlog.get_logger(logger_name=name) + return logging.getLogger(name) def _set_relatives( self, diff --git a/task_sdk/src/airflow/sdk/definitions/task_group.py b/task_sdk/src/airflow/sdk/definitions/taskgroup.py similarity index 100% rename from task_sdk/src/airflow/sdk/definitions/task_group.py rename to task_sdk/src/airflow/sdk/definitions/taskgroup.py diff --git a/task_sdk/src/airflow/sdk/types.py b/task_sdk/src/airflow/sdk/types.py index 172ff733a14..505ee4cb191 100644 --- a/task_sdk/src/airflow/sdk/types.py +++ b/task_sdk/src/airflow/sdk/types.py @@ -41,9 +41,9 @@ NOTSET = ArgNotSet() if TYPE_CHECKING: - import structlog + import logging - Logger = structlog.typing.FilteringBoundLogger + Logger = logging.Logger else: class Logger: ... diff --git a/task_sdk/tests/test_hello.py b/task_sdk/tests/test_hello.py deleted file mode 100644 index 62cfdc069ca..00000000000 --- a/task_sdk/tests/test_hello.py +++ /dev/null @@ -1,23 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -from __future__ import annotations - -from airflow.sdk import hello - - -def test_hello(): - assert hello() == "Hello from task-sdk!" diff --git a/uv.lock b/uv.lock index d3508fac220..3a3c2db04b7 100644 --- a/uv.lock +++ b/uv.lock @@ -1369,7 +1369,6 @@ dependencies = [ { name = "attrs" }, { name = "google-re2" }, { name = "methodtools" }, - { name = "structlog" }, ] [package.dev-dependencies] @@ -1384,7 +1383,6 @@ requires-dist = [ { name = "attrs", specifier = ">=24.2.0" }, { name = "google-re2", specifier = ">=1.1.20240702" }, { name = "methodtools", specifier = ">=0.4.7" }, - { name = "structlog", specifier = ">=24.4.0" }, ] [package.metadata.requires-dev] @@ -3688,15 +3686,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/b7/9c/93f7bc03ff03199074e81974cc148908ead60dcf189f68ba1761a0ee35cf/starlette-0.38.6-py3-none-any.whl", hash = "sha256:4517a1409e2e73ee4951214ba012052b9e16f60e90d73cfb06192c19203bbb05", size = 71451 }, ] -[[package]] -name = "structlog" -version = "24.4.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/78/a3/e811a94ac3853826805253c906faa99219b79951c7d58605e89c79e65768/structlog-24.4.0.tar.gz", hash = "sha256:b27bfecede327a6d2da5fbc96bd859f114ecc398a6389d664f62085ee7ae6fc4", size = 1348634 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/bf/65/813fc133609ebcb1299be6a42e5aea99d6344afb35ccb43f67e7daaa3b92/structlog-24.4.0-py3-none-any.whl", hash = "sha256:597f61e80a91cc0749a9fd2a098ed76715a1c8a01f73e336b746504d1aad7610", size = 67180 }, -] - [[package]] name = "tabulate" version = "0.9.0"
