This is an automated email from the ASF dual-hosted git repository. ash pushed a commit to branch task-sdk-first-code in repository https://gitbox.apache.org/repos/asf/airflow.git
commit c0973ae27b06314faf49c35d9db5c3249b1e4b77 Author: Ash Berlin-Taylor <[email protected]> AuthorDate: Mon Oct 28 18:02:06 2024 +0000 Fix serialization --- airflow/dag_processing/collection.py | 2 +- airflow/decorators/base.py | 28 +++---- airflow/decorators/bash.py | 4 +- airflow/decorators/sensor.py | 4 +- airflow/models/baseoperator.py | 6 +- airflow/models/dag.py | 40 ++++++--- airflow/models/mappedoperator.py | 3 +- airflow/serialization/schema.json | 8 +- airflow/serialization/serialized_objects.py | 30 +++---- dev/breeze/doc/images/output_build-docs.txt | 2 +- dev/breeze/doc/images/output_prod-image.txt | 2 +- dev/breeze/doc/images/output_prod-image_build.txt | 2 +- dev/breeze/doc/images/output_setup.txt | 2 +- .../doc/images/output_setup_autocomplete.txt | 2 +- dev/breeze/doc/images/output_setup_config.txt | 2 +- dev/breeze/doc/images/output_start-airflow.txt | 2 +- .../providers/amazon/aws/operators/comprehend.py | 6 +- .../airflow/providers/amazon/aws/operators/dms.py | 6 +- .../amazon/aws/operators/kinesis_analytics.py | 6 +- .../providers/amazon/aws/operators/sagemaker.py | 4 +- .../fab/auth_manager/security_manager/override.py | 9 +- .../src/airflow/sdk/definitions/baseoperator.py | 23 +++-- task_sdk/src/airflow/sdk/definitions/dag.py | 5 +- task_sdk/src/airflow/sdk/definitions/taskgroup.py | 21 +++-- tests/serialization/test_dag_serialization.py | 98 ++++++++++++---------- 25 files changed, 180 insertions(+), 137 deletions(-) diff --git a/airflow/dag_processing/collection.py b/airflow/dag_processing/collection.py index b06821ff6db..f608900ee76 100644 --- a/airflow/dag_processing/collection.py +++ b/airflow/dag_processing/collection.py @@ -213,7 +213,7 @@ class DagModelOperation(NamedTuple): dm.default_view = dag.default_view if hasattr(dag, "_dag_display_property_value"): dm._dag_display_property_value = dag._dag_display_property_value - else: + elif dag.dag_display_name != dag.dag_id: dm._dag_display_property_value = dag.dag_display_name dm.description = dag.description dm.max_active_tasks = dag.max_active_tasks diff --git a/airflow/decorators/base.py b/airflow/decorators/base.py index c9e4cf170f9..1c9e441190a 100644 --- a/airflow/decorators/base.py +++ b/airflow/decorators/base.py @@ -41,7 +41,6 @@ import re2 import typing_extensions from airflow.assets import Asset -from airflow.models.abstractoperator import DEFAULT_RETRIES, DEFAULT_RETRY_DELAY from airflow.models.baseoperator import ( BaseOperator, coerce_resources, @@ -56,7 +55,6 @@ from airflow.models.expandinput import ( is_mappable, ) from airflow.models.mappedoperator import MappedOperator, ensure_xcomarg_return_value -from airflow.models.pool import Pool from airflow.models.xcom_arg import XComArg from airflow.sdk.definitions.baseoperator import BaseOperator as TaskSDKBaseOperator from airflow.sdk.definitions.contextmanager import DagContext, TaskGroupContext @@ -460,32 +458,26 @@ class _TaskDecorator(ExpandableFactory, Generic[FParams, FReturn, OperatorSubcla task_id = task_group.child_id(task_id) # Logic here should be kept in sync with BaseOperatorMeta.partial(). - if "task_concurrency" in partial_kwargs: - raise TypeError("unexpected argument: task_concurrency") if partial_kwargs.get("wait_for_downstream"): partial_kwargs["depends_on_past"] = True start_date = timezone.convert_to_utc(partial_kwargs.pop("start_date", None)) end_date = timezone.convert_to_utc(partial_kwargs.pop("end_date", None)) - if partial_kwargs.get("pool") is None: - partial_kwargs["pool"] = Pool.DEFAULT_POOL_NAME if "pool_slots" in partial_kwargs: if partial_kwargs["pool_slots"] < 1: dag_str = "" if dag: dag_str = f" in dag {dag.dag_id}" raise ValueError(f"pool slots for {task_id}{dag_str} cannot be less than 1") - partial_kwargs["retries"] = parse_retries(partial_kwargs.get("retries", DEFAULT_RETRIES)) - partial_kwargs["retry_delay"] = coerce_timedelta( - partial_kwargs.get("retry_delay", DEFAULT_RETRY_DELAY), - key="retry_delay", - ) - max_retry_delay = partial_kwargs.get("max_retry_delay") - partial_kwargs["max_retry_delay"] = ( - max_retry_delay - if max_retry_delay is None - else coerce_timedelta(max_retry_delay, key="max_retry_delay") - ) - partial_kwargs["resources"] = coerce_resources(partial_kwargs.get("resources")) + + for fld, convert in ( + ("retries", parse_retries), + ("retry_delay", coerce_timedelta), + ("max_retry_delay", coerce_timedelta), + ("resources", coerce_resources), + ): + if (v := partial_kwargs.get(fld, NOTSET)) is not NOTSET: + partial_kwargs[fld] = convert(v) # type: ignore[operator] + partial_kwargs.setdefault("executor_config", {}) partial_kwargs.setdefault("op_args", []) partial_kwargs.setdefault("op_kwargs", {}) diff --git a/airflow/decorators/bash.py b/airflow/decorators/bash.py index 44738492da0..e4dc19745e0 100644 --- a/airflow/decorators/bash.py +++ b/airflow/decorators/bash.py @@ -18,7 +18,7 @@ from __future__ import annotations import warnings -from typing import Any, Callable, Collection, Mapping, Sequence +from typing import Any, Callable, ClassVar, Collection, Mapping, Sequence from airflow.decorators.base import DecoratedOperator, TaskDecorator, task_decorator_factory from airflow.providers.standard.operators.bash import BashOperator @@ -39,7 +39,7 @@ class _BashDecoratedOperator(DecoratedOperator, BashOperator): """ template_fields: Sequence[str] = (*DecoratedOperator.template_fields, *BashOperator.template_fields) - template_fields_renderers: dict[str, str] = { + template_fields_renderers: ClassVar[dict[str, str]] = { **DecoratedOperator.template_fields_renderers, **BashOperator.template_fields_renderers, } diff --git a/airflow/decorators/sensor.py b/airflow/decorators/sensor.py index c332a78f95c..9ee4eeb2a79 100644 --- a/airflow/decorators/sensor.py +++ b/airflow/decorators/sensor.py @@ -17,7 +17,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Callable, Sequence +from typing import TYPE_CHECKING, Callable, ClassVar, Sequence from airflow.decorators.base import get_unique_task_id, task_decorator_factory from airflow.sensors.python import PythonSensor @@ -42,7 +42,7 @@ class DecoratedSensorOperator(PythonSensor): """ template_fields: Sequence[str] = ("op_args", "op_kwargs") - template_fields_renderers: dict[str, str] = {"op_args": "py", "op_kwargs": "py"} + template_fields_renderers: ClassVar[dict[str, str]] = {"op_args": "py", "op_kwargs": "py"} custom_operator_name = "@task.sensor" diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py index f0d1ee6f965..e0ef0764406 100644 --- a/airflow/models/baseoperator.py +++ b/airflow/models/baseoperator.py @@ -137,10 +137,12 @@ def parse_retries(retries: Any) -> int | None: return parsed_retries -def coerce_timedelta(value: float | timedelta, *, key: str) -> timedelta: +def coerce_timedelta(value: float | timedelta, *, key: str | None = None) -> timedelta: if isinstance(value, timedelta): return value - logger.debug("%s isn't a timedelta object, assuming secs", key) + # TODO: remove this log here + if key: + logger.debug("%s isn't a timedelta object, assuming secs", key) return timedelta(seconds=value) diff --git a/airflow/models/dag.py b/airflow/models/dag.py index d133eb43e68..b9620e52202 100644 --- a/airflow/models/dag.py +++ b/airflow/models/dag.py @@ -44,6 +44,7 @@ from typing import ( ) import attrs +import methodtools import pendulum import re2 import sqlalchemy_jsonfield @@ -301,6 +302,17 @@ else: return task_sdk_dag_decorator(dag_id, __DAG_class=DAG, __warnings_stacklevel_delta=3, **kwargs) +def _convert_max_consecutive_failed_dag_runs(val: int) -> int: + if val == 0: + val = airflow_conf.getint("core", "max_consecutive_failed_dag_runs_per_dag") + if val < 0: + raise ValueError( + f"Invalid max_consecutive_failed_dag_runs: {val}." + f"Requires max_consecutive_failed_dag_runs >= 0" + ) + return val + + @functools.total_ordering @attrs.define(hash=False, repr=False, eq=False) class DAG(TaskSDKDag, LoggingMixin): @@ -428,11 +440,15 @@ class DAG(TaskSDKDag, LoggingMixin): _processor_dags_folder: str | None = attrs.field(init=False, default=None) # Override the default from parent class to use config - max_consecutive_failed_dag_runs: int = attrs.field() + max_consecutive_failed_dag_runs: int = attrs.field( + default=0, + converter=_convert_max_consecutive_failed_dag_runs, + validator=attrs.validators.instance_of(int), + ) - @max_consecutive_failed_dag_runs.default - def _max_consecutive_failed_dag_runs_default(self): - return airflow_conf.getint("core", "max_consecutive_failed_dag_runs_per_dag") + @property + def safe_dag_id(self): + return self.dag_id.replace(".", "__dot__") def validate(self): super().validate() @@ -723,14 +739,6 @@ class DAG(TaskSDKDag, LoggingMixin): def timetable_summary(self) -> str: return self.timetable.summary - @property - def max_active_tasks(self) -> int: - return self._max_active_tasks - - @max_active_tasks.setter - def max_active_tasks(self, value: int): - self._max_active_tasks = value - @property def pickle_id(self) -> int | None: return self._pickle_id @@ -775,6 +783,14 @@ class DAG(TaskSDKDag, LoggingMixin): """Return a boolean indicating whether this DAG is paused.""" return session.scalar(select(DagModel.is_paused).where(DagModel.dag_id == self.dag_id)) + @methodtools.lru_cache(maxsize=None) + @classmethod + def get_serialized_fields(cls): + """Stringified DAGs and operators contain exactly these fields.""" + return TaskSDKDag.get_serialized_fields() | { + "_processor_dags_folder", + } + @staticmethod @internal_api_call @provide_session diff --git a/airflow/models/mappedoperator.py b/airflow/models/mappedoperator.py index 52a08bce027..925acfc16f0 100644 --- a/airflow/models/mappedoperator.py +++ b/airflow/models/mappedoperator.py @@ -333,7 +333,8 @@ class MappedOperator(AbstractOperator): @classmethod def get_serialized_fields(cls): # Not using 'cls' here since we only want to serialize base fields. - return frozenset(attr.fields_dict(MappedOperator)) - { + return (frozenset(attr.fields_dict(MappedOperator)) | {"task_type"}) - { + "_task_type", "dag", "deps", "expand_input", # This is needed to be able to accept XComArg. diff --git a/airflow/serialization/schema.json b/airflow/serialization/schema.json index f89bd348776..32ccd3dfff9 100644 --- a/airflow/serialization/schema.json +++ b/airflow/serialization/schema.json @@ -156,7 +156,6 @@ {"type": "string"} ] }, - "orientation": { "type" : "string"}, "dag_display_name": { "type" : "string"}, "description": { "type" : "string"}, "_concurrency": { "type" : "number"}, @@ -168,8 +167,7 @@ "end_date": { "$ref": "#/definitions/datetime" }, "dagrun_timeout": { "$ref": "#/definitions/timedelta" }, "doc_md": { "type" : "string"}, - "_default_view": { "type" : "string"}, - "_access_control": {"$ref": "#/definitions/dict" }, + "access_control": {"$ref": "#/definitions/dict" }, "is_paused_upon_creation": { "type": "boolean" }, "has_on_success_callback": { "type": "boolean" }, "has_on_failure_callback": { "type": "boolean" }, @@ -219,7 +217,7 @@ "$comment": "A task/operator in a DAG", "type": "object", "required": [ - "_task_type", + "task_type", "_task_module", "task_id", "ui_color", @@ -227,7 +225,7 @@ "template_fields" ], "properties": { - "_task_type": { "type": "string" }, + "task_type": { "type": "string" }, "_task_module": { "type": "string" }, "_operator_extra_links": { "$ref": "#/definitions/extra_links" }, "task_id": { "type": "string" }, diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py index e595ef17449..79403860f5f 100644 --- a/airflow/serialization/serialized_objects.py +++ b/airflow/serialization/serialized_objects.py @@ -21,7 +21,6 @@ from __future__ import annotations import collections.abc import datetime import enum -import inspect import itertools import logging import weakref @@ -599,7 +598,7 @@ class BaseSerialization: if key == "_operator_name": # when operator_name matches task_type, we can remove # it to reduce the JSON payload - task_type = getattr(object_to_serialize, "_task_type", None) + task_type = getattr(object_to_serialize, "task_type", None) if value != task_type: serialized_object[key] = cls.serialize(value) elif key in decorated_fields: @@ -1157,9 +1156,9 @@ class SerializedBaseOperator(BaseOperator, BaseSerialization): """Serialize operator into a JSON object.""" serialize_op = cls.serialize_to_json(op, cls._decorated_fields) - serialize_op["_task_type"] = getattr(op, "_task_type", type(op).__name__) + serialize_op["task_type"] = getattr(op, "task_type", type(op).__name__) serialize_op["_task_module"] = getattr(op, "_task_module", type(op).__module__) - if op.operator_name != serialize_op["_task_type"]: + if op.operator_name != serialize_op["task_type"]: serialize_op["_operator_name"] = op.operator_name # Used to determine if an Operator is inherited from EmptyOperator @@ -1183,7 +1182,7 @@ class SerializedBaseOperator(BaseOperator, BaseSerialization): # Store all template_fields as they are if there are JSON Serializable # If not, store them as strings # And raise an exception if the field is not templateable - forbidden_fields = set(inspect.signature(BaseOperator.__init__).parameters.keys()) + forbidden_fields = set(SerializedBaseOperator._CONSTRUCTOR_PARAMS.keys()) # Though allow some of the BaseOperator fields to be templated anyway forbidden_fields.difference_update({"email"}) if op.template_fields: @@ -1248,7 +1247,7 @@ class SerializedBaseOperator(BaseOperator, BaseSerialization): op_extra_links_from_plugin = {} if "_operator_name" not in encoded_op: - encoded_op["_operator_name"] = encoded_op["_task_type"] + encoded_op["_operator_name"] = encoded_op["task_type"] # We don't want to load Extra Operator links in Scheduler if cls._load_operator_extra_links: @@ -1262,7 +1261,7 @@ class SerializedBaseOperator(BaseOperator, BaseSerialization): for ope in plugins_manager.operator_extra_links: for operator in ope.operators: if ( - operator.__name__ == encoded_op["_task_type"] + operator.__name__ == encoded_op["task_type"] and operator.__module__ == encoded_op["_task_module"] ): op_extra_links_from_plugin.update({ope.name: ope}) @@ -1278,6 +1277,8 @@ class SerializedBaseOperator(BaseOperator, BaseSerialization): if k in ("_outlets", "_inlets"): # `_outlets` -> `outlets` k = k[1:] + elif k == "task_type": + k = "_task_type" if k == "_downstream_task_ids": # Upgrade from old format/name k = "downstream_task_ids" @@ -1389,7 +1390,7 @@ class SerializedBaseOperator(BaseOperator, BaseSerialization): try: operator_name = encoded_op["_operator_name"] except KeyError: - operator_name = encoded_op["_task_type"] + operator_name = encoded_op["task_type"] op = MappedOperator( operator_class=op_data, @@ -1406,7 +1407,7 @@ class SerializedBaseOperator(BaseOperator, BaseSerialization): ui_fgcolor=BaseOperator.ui_fgcolor, is_empty=False, task_module=encoded_op["_task_module"], - task_type=encoded_op["_task_type"], + task_type=encoded_op["task_type"], operator_name=operator_name, dag=None, task_group=None, @@ -1582,16 +1583,13 @@ class SerializedDAG(DAG, BaseSerialization): not pickle-able. SerializedDAG works for all DAGs. """ - _decorated_fields = {"default_args", "_access_control"} + _decorated_fields = {"default_args", "access_control"} @staticmethod def __get_constructor_defaults(): param_to_attr = { - "max_active_tasks": "_max_active_tasks", - "dag_display_name": "_dag_display_property_value", "description": "_description", "default_view": "_default_view", - "access_control": "_access_control", } return { param_to_attr.get(k, k): v.default @@ -1688,7 +1686,7 @@ class SerializedDAG(DAG, BaseSerialization): else: # This must be old data that had no task_group. Create a root TaskGroup and add # all tasks to it. - dag.task_group = TaskGroup.create_root(dag) + object.__setattr__(dag, "task_group", TaskGroup.create_root(dag)) for task in dag.tasks: dag.task_group.add(task) @@ -1711,8 +1709,10 @@ class SerializedDAG(DAG, BaseSerialization): def _is_excluded(cls, var: Any, attrname: str, op: DAGNode): # {} is explicitly different from None in the case of DAG-level access control # and as a result we need to preserve empty dicts through serialization for this field - if attrname == "_access_control" and var is not None: + if attrname == "access_control" and var is not None: return False + if attrname == "dag_display_name" and var == op.dag_id: + return True return super()._is_excluded(var, attrname, op) @classmethod diff --git a/dev/breeze/doc/images/output_build-docs.txt b/dev/breeze/doc/images/output_build-docs.txt index 4bebe1e1639..644e9bf7676 100644 --- a/dev/breeze/doc/images/output_build-docs.txt +++ b/dev/breeze/doc/images/output_build-docs.txt @@ -1 +1 @@ -03dd58933b63fc368157f716b1852e1b +b349182dab04b6ff58acd122a403a5a4 diff --git a/dev/breeze/doc/images/output_prod-image.txt b/dev/breeze/doc/images/output_prod-image.txt index 4e4ac97bd60..c767ee09d4f 100644 --- a/dev/breeze/doc/images/output_prod-image.txt +++ b/dev/breeze/doc/images/output_prod-image.txt @@ -1 +1 @@ -55030fe0d7718eb668fa1a37128647b0 +d91bcc76b14f186e749efe2c6aaa8682 diff --git a/dev/breeze/doc/images/output_prod-image_build.txt b/dev/breeze/doc/images/output_prod-image_build.txt index 7799e6f009e..9c6c509a669 100644 --- a/dev/breeze/doc/images/output_prod-image_build.txt +++ b/dev/breeze/doc/images/output_prod-image_build.txt @@ -1 +1 @@ -d0214e8e95fcb56c91e0e416690eb24f +fe048412f9fc1527a30eaaf0a986fa16 diff --git a/dev/breeze/doc/images/output_setup.txt b/dev/breeze/doc/images/output_setup.txt index b8f9048b91f..274751197da 100644 --- a/dev/breeze/doc/images/output_setup.txt +++ b/dev/breeze/doc/images/output_setup.txt @@ -1 +1 @@ -d4a4f1b405f912fa234ff4116068290a +08c78d9dddd037a2ade6b751c5a22ff9 diff --git a/dev/breeze/doc/images/output_setup_autocomplete.txt b/dev/breeze/doc/images/output_setup_autocomplete.txt index 185feef0264..144c2613cd6 100644 --- a/dev/breeze/doc/images/output_setup_autocomplete.txt +++ b/dev/breeze/doc/images/output_setup_autocomplete.txt @@ -1 +1 @@ -fffcd49e102e09ccd69b3841a9e3ea8e +ec3b4541a478afe5cb86a6f1c48f50f5 diff --git a/dev/breeze/doc/images/output_setup_config.txt b/dev/breeze/doc/images/output_setup_config.txt index 6bc958bebb2..09ae63f0968 100644 --- a/dev/breeze/doc/images/output_setup_config.txt +++ b/dev/breeze/doc/images/output_setup_config.txt @@ -1 +1 @@ -ee2e731f011b1d93dcbfbcaebf6482a6 +235af93483ea83592052476479757683 diff --git a/dev/breeze/doc/images/output_start-airflow.txt b/dev/breeze/doc/images/output_start-airflow.txt index 5811c7ec666..37e9d00ae51 100644 --- a/dev/breeze/doc/images/output_start-airflow.txt +++ b/dev/breeze/doc/images/output_start-airflow.txt @@ -1 +1 @@ -e63a3289a1be34b82c28b606dee0c472 +c880eabfdc882c26e7a2cc4c47a58be3 diff --git a/providers/src/airflow/providers/amazon/aws/operators/comprehend.py b/providers/src/airflow/providers/amazon/aws/operators/comprehend.py index 880440726c6..88bbbf9bf46 100644 --- a/providers/src/airflow/providers/amazon/aws/operators/comprehend.py +++ b/providers/src/airflow/providers/amazon/aws/operators/comprehend.py @@ -17,7 +17,7 @@ from __future__ import annotations from functools import cached_property -from typing import TYPE_CHECKING, Any, Sequence +from typing import TYPE_CHECKING, Any, ClassVar, Sequence from airflow.configuration import conf from airflow.exceptions import AirflowException @@ -55,7 +55,7 @@ class ComprehendBaseOperator(AwsBaseOperator[ComprehendHook]): "input_data_config", "output_data_config", "data_access_role_arn", "language_code" ) - template_fields_renderers: dict = {"input_data_config": "json", "output_data_config": "json"} + template_fields_renderers: ClassVar[dict] = {"input_data_config": "json", "output_data_config": "json"} def __init__( self, @@ -248,7 +248,7 @@ class ComprehendCreateDocumentClassifierOperator(AwsBaseOperator[ComprehendHook] "document_classifier_kwargs", ) - template_fields_renderers: dict = { + template_fields_renderers: ClassVar[dict] = { "input_data_config": "json", "output_data_config": "json", "document_classifier_kwargs": "json", diff --git a/providers/src/airflow/providers/amazon/aws/operators/dms.py b/providers/src/airflow/providers/amazon/aws/operators/dms.py index c564f802185..6d64de85e26 100644 --- a/providers/src/airflow/providers/amazon/aws/operators/dms.py +++ b/providers/src/airflow/providers/amazon/aws/operators/dms.py @@ -17,7 +17,7 @@ # under the License. from __future__ import annotations -from typing import TYPE_CHECKING, Sequence +from typing import TYPE_CHECKING, ClassVar, Sequence from airflow.providers.amazon.aws.hooks.dms import DmsHook from airflow.providers.amazon.aws.operators.base_aws import AwsBaseOperator @@ -64,7 +64,7 @@ class DmsCreateTaskOperator(AwsBaseOperator[DmsHook]): "migration_type", "create_task_kwargs", ) - template_fields_renderers = { + template_fields_renderers: ClassVar[dict] = { "table_mappings": "json", "create_task_kwargs": "json", } @@ -173,7 +173,7 @@ class DmsDescribeTasksOperator(AwsBaseOperator[DmsHook]): aws_hook_class = DmsHook template_fields: Sequence[str] = aws_template_fields("describe_tasks_kwargs") - template_fields_renderers: dict[str, str] = {"describe_tasks_kwargs": "json"} + template_fields_renderers: ClassVar[dict[str, str]] = {"describe_tasks_kwargs": "json"} def __init__(self, *, describe_tasks_kwargs: dict | None = None, **kwargs): super().__init__(**kwargs) diff --git a/providers/src/airflow/providers/amazon/aws/operators/kinesis_analytics.py b/providers/src/airflow/providers/amazon/aws/operators/kinesis_analytics.py index 727aa714c61..93f8bc6b805 100644 --- a/providers/src/airflow/providers/amazon/aws/operators/kinesis_analytics.py +++ b/providers/src/airflow/providers/amazon/aws/operators/kinesis_analytics.py @@ -16,7 +16,7 @@ # under the License. from __future__ import annotations -from typing import TYPE_CHECKING, Any, Sequence +from typing import TYPE_CHECKING, Any, ClassVar, Sequence from botocore.exceptions import ClientError @@ -70,7 +70,7 @@ class KinesisAnalyticsV2CreateApplicationOperator(AwsBaseOperator[KinesisAnalyti "create_application_kwargs", "application_description", ) - template_fields_renderers: dict = { + template_fields_renderers: ClassVar[dict] = { "create_application_kwargs": "json", } @@ -149,7 +149,7 @@ class KinesisAnalyticsV2StartApplicationOperator(AwsBaseOperator[KinesisAnalytic "application_name", "run_configuration", ) - template_fields_renderers: dict = { + template_fields_renderers: ClassVar[dict] = { "run_configuration": "json", } diff --git a/providers/src/airflow/providers/amazon/aws/operators/sagemaker.py b/providers/src/airflow/providers/amazon/aws/operators/sagemaker.py index 57a91945262..e2fad21f09d 100644 --- a/providers/src/airflow/providers/amazon/aws/operators/sagemaker.py +++ b/providers/src/airflow/providers/amazon/aws/operators/sagemaker.py @@ -20,7 +20,7 @@ import datetime import json import time from functools import cached_property -from typing import TYPE_CHECKING, Any, Callable, Sequence +from typing import TYPE_CHECKING, Any, Callable, ClassVar, Sequence from botocore.exceptions import ClientError @@ -65,7 +65,7 @@ class SageMakerBaseOperator(BaseOperator): template_fields: Sequence[str] = ("config",) template_ext: Sequence[str] = () - template_fields_renderers: dict = {"config": "json"} + template_fields_renderers: ClassVar[dict] = {"config": "json"} ui_color: str = "#ededed" integer_fields: list[list[Any]] = [] diff --git a/providers/src/airflow/providers/fab/auth_manager/security_manager/override.py b/providers/src/airflow/providers/fab/auth_manager/security_manager/override.py index 97e13398ffb..9c49845b420 100644 --- a/providers/src/airflow/providers/fab/auth_manager/security_manager/override.py +++ b/providers/src/airflow/providers/fab/auth_manager/security_manager/override.py @@ -17,6 +17,7 @@ # under the License. from __future__ import annotations +import copy import datetime import itertools import logging @@ -24,7 +25,7 @@ import os import random import uuid import warnings -from typing import TYPE_CHECKING, Any, Callable, Collection, Container, Iterable, Sequence +from typing import TYPE_CHECKING, Any, Callable, Collection, Container, Iterable, Mapping, Sequence import jwt import packaging.version @@ -1107,7 +1108,7 @@ class FabAirflowSecurityManagerOverride(AirflowSecurityManagerV2): def sync_perm_for_dag( self, dag_id: str, - access_control: dict[str, dict[str, Collection[str]] | Collection[str]] | None = None, + access_control: Mapping[str, Mapping[str, Collection[str]] | Collection[str]] | None = None, ) -> None: """ Sync permissions for given dag id. @@ -1128,7 +1129,7 @@ class FabAirflowSecurityManagerOverride(AirflowSecurityManagerV2): if access_control is not None: self.log.debug("Syncing DAG-level permissions for DAG '%s'", dag_id) - self._sync_dag_view_permissions(dag_id, access_control.copy()) + self._sync_dag_view_permissions(dag_id, copy.copy(access_control)) else: self.log.debug( "Not syncing DAG-level permissions for DAG '%s' as access control is unset.", @@ -1149,7 +1150,7 @@ class FabAirflowSecurityManagerOverride(AirflowSecurityManagerV2): def _sync_dag_view_permissions( self, dag_id: str, - access_control: dict[str, dict[str, Collection[str]] | Collection[str]], + access_control: Mapping[str, Mapping[str, Collection[str]] | Collection[str]], ) -> None: """ Set the access policy on the given DAG's ViewModel. diff --git a/task_sdk/src/airflow/sdk/definitions/baseoperator.py b/task_sdk/src/airflow/sdk/definitions/baseoperator.py index e99ad835fb7..fc16682a63c 100644 --- a/task_sdk/src/airflow/sdk/definitions/baseoperator.py +++ b/task_sdk/src/airflow/sdk/definitions/baseoperator.py @@ -565,7 +565,7 @@ class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta): template_fields: Collection[str] = () template_ext: Sequence[str] = () - template_fields_renderers: dict[str, str] = field(default_factory=dict, init=False) + template_fields_renderers: ClassVar[dict[str, str]] = {} # Defines the color in the UI ui_color: str = "#fff" @@ -720,8 +720,6 @@ class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta): task_group.add(self) super().__init__() - if dag is not None: - self.dag = dag self.task_group = task_group kwargs.pop("_airflow_mapped_validation_only", None) @@ -864,6 +862,11 @@ class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta): if SetupTeardownContext.active: SetupTeardownContext.update_context_map(self) + # We set self.dag right at the end as `_convert_dag` calls `dag.add_task` for us, and we need all the + # other properties to be set at that point + if dag is not None: + self.dag = dag + validate_instance_args(self, BASEOPERATOR_ARGS_EXPECTED_TYPES) def __eq__(self, other): @@ -1050,6 +1053,9 @@ class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta): from airflow.utils.operator_resources import Resources + if isinstance(resources, Resources): + return resources + return Resources(**resources) def _convert_is_setup(self, value: bool) -> bool: @@ -1177,14 +1183,14 @@ class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta): "_BaseOperator__instantiated", "_BaseOperator__init_kwargs", "_BaseOperator__from_mapped", - "_on_failure_fail_dagrun", + "on_failure_fail_dagrun", + "task_group", + "_task_type", } - | { # Class level defaults need to be added to this list + | { # Class level defaults, or `@property` need to be added to this list "start_date", "end_date", - "_task_type", - "_operator_name", - "subdag", + "task_type", "ui_color", "ui_fgcolor", "template_ext", @@ -1198,6 +1204,7 @@ class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta): "start_trigger_args", "_needs_expansion", "start_from_trigger", + "max_retry_delay", } ) DagContext.pop() diff --git a/task_sdk/src/airflow/sdk/definitions/dag.py b/task_sdk/src/airflow/sdk/definitions/dag.py index 9cc24828458..da7a2efa6ff 100644 --- a/task_sdk/src/airflow/sdk/definitions/dag.py +++ b/task_sdk/src/airflow/sdk/definitions/dag.py @@ -382,7 +382,6 @@ class DAG: timezone: FixedTimezone | Timezone = attrs.field(init=False) schedule: ScheduleArg = attrs.field(default=None, on_setattr=attrs.setters.frozen) timetable: Timetable = attrs.field(init=False) - full_filepath: str | None = None template_searchpath: str | Iterable[str] | None = attrs.field( default=None, converter=_convert_str_to_tuple ) @@ -448,6 +447,10 @@ class DAG: self.start_date = timezone.convert_to_utc(self.start_date) self.end_date = timezone.convert_to_utc(self.end_date) + if "start_date" in self.default_args: + self.default_args["start_date"] = self.start_date + if "end_date" in self.default_args: + self.default_args["end_date"] = self.end_date @params.validator def _validate_params(self, _, params: ParamsDict): diff --git a/task_sdk/src/airflow/sdk/definitions/taskgroup.py b/task_sdk/src/airflow/sdk/definitions/taskgroup.py index 26b1f6c45e4..70f4537aa1a 100644 --- a/task_sdk/src/airflow/sdk/definitions/taskgroup.py +++ b/task_sdk/src/airflow/sdk/definitions/taskgroup.py @@ -69,6 +69,12 @@ def _default_parent_group() -> TaskGroup | None: return TaskGroupContext.get_current() +def _parent_used_group_ids(tg: TaskGroup) -> set: + if tg.parent_group: + return tg.parent_group.used_group_ids + return set() + + # This could be achieved with `@dag.default` and make this a method, but for some unknown reason when we do # that it makes Mypy (1.9.0 and 1.13.0 tested) seem to entirely loose track that this is an Attrs class. So # we've gone with this and moved on with our lives, mypy is to much of a dark beast to battle over this. @@ -124,7 +130,11 @@ class TaskGroup(DAGNode): upstream_task_ids: set[str] = attrs.field(factory=set, init=False) downstream_task_ids: set[str] = attrs.field(factory=set, init=False) - used_group_ids: set[str] = attrs.field(factory=set, init=False, on_setattr=attrs.setters.frozen) + used_group_ids: set[str] = attrs.field( + default=attrs.Factory(_parent_used_group_ids, takes_self=True), + init=False, + on_setattr=attrs.setters.frozen, + ) ui_color: str = "CornflowerBlue" ui_fgcolor: str = "#000" @@ -142,7 +152,6 @@ class TaskGroup(DAGNode): self._check_for_group_id_collisions(self.add_suffix_on_collision) if self.parent_group: - object.__setattr__(self, "used_group_ids", self.parent_group.used_group_ids) self.parent_group.add(self) if self.parent_group.default_args: self.default_args = {**self.parent_group.default_args, **self.default_args} @@ -157,7 +166,7 @@ class TaskGroup(DAGNode): return # if given group_id already used assign suffix by incrementing largest used suffix integer # Example : task_group ==> task_group__1 -> task_group__2 -> task_group__3 - if self._group_id in self.used_group_ids: + if self.group_id in self.used_group_ids: if not add_suffix_on_collision: raise DuplicateTaskIdFound(f"group_id '{self._group_id}' has already been added to the DAG") base = re2.split(r"__\d+$", self._group_id)[0] @@ -178,7 +187,7 @@ class TaskGroup(DAGNode): @property def node_id(self): - return self._group_id + return self.group_id @property def is_root(self) -> bool: @@ -250,9 +259,9 @@ class TaskGroup(DAGNode): @property def group_id(self) -> str | None: """group_id of this TaskGroup.""" - if self.parent_group and self.parent_group.prefix_group_id and self.parent_group.node_id: + if self.parent_group and self.parent_group.prefix_group_id and self.parent_group._group_id: # defer to parent whether it adds a prefix - return self.parent_group.child_id(self.group_id) + return self.parent_group.child_id(self._group_id) return self._group_id diff --git a/tests/serialization/test_dag_serialization.py b/tests/serialization/test_dag_serialization.py index a012d3dbc37..783b04152fd 100644 --- a/tests/serialization/test_dag_serialization.py +++ b/tests/serialization/test_dag_serialization.py @@ -20,6 +20,7 @@ from __future__ import annotations import copy +import dataclasses import importlib import importlib.util import json @@ -124,7 +125,7 @@ serialized_simple_dag_ground_truth = { "delta": 86400.0, }, }, - "_task_group": { + "task_group": { "_group_id": None, "prefix_group_id": True, "children": {"bash_task": ("operator", "bash_task"), "custom_task": ("operator", "custom_task")}, @@ -137,7 +138,7 @@ serialized_simple_dag_ground_truth = { "downstream_task_ids": [], }, "is_paused_upon_creation": False, - "_dag_id": "simple_dag", + "dag_id": "simple_dag", "doc_md": "### DAG Tutorial Documentation", "fileloc": None, "_processor_dags_folder": f"{repo_root}/tests/dags", @@ -146,7 +147,6 @@ serialized_simple_dag_ground_truth = { "__type": "operator", "__var": { "task_id": "bash_task", - "owner": "airflow", "retries": 1, "retry_delay": 300.0, "max_retry_delay": 600.0, @@ -158,7 +158,7 @@ serialized_simple_dag_ground_truth = { "template_fields": ["bash_command", "env", "cwd"], "template_fields_renderers": {"bash_command": "bash", "env": "json"}, "bash_command": "echo {{ task.task_id }}", - "_task_type": "BashOperator", + "task_type": "BashOperator", "_task_module": "airflow.providers.standard.operators.bash", "pool": "default_pool", "is_setup": False, @@ -174,7 +174,6 @@ serialized_simple_dag_ground_truth = { }, }, "doc_md": "### Task Tutorial Documentation", - "_log_config_logger_name": "airflow.task.operators", "_needs_expansion": False, "weight_rule": "downstream", "start_trigger_args": None, @@ -196,14 +195,13 @@ serialized_simple_dag_ground_truth = { "template_ext": [], "template_fields": ["bash_command"], "template_fields_renderers": {}, - "_task_type": "CustomOperator", + "task_type": "CustomOperator", "_operator_name": "@custom", "_task_module": "tests_common.test_utils.mock_operators", "pool": "default_pool", "is_setup": False, "is_teardown": False, "on_failure_fail_dagrun": False, - "_log_config_logger_name": "airflow.task.operators", "_needs_expansion": False, "weight_rule": "downstream", "start_trigger_args": None, @@ -212,7 +210,7 @@ serialized_simple_dag_ground_truth = { }, ], "timezone": "UTC", - "_access_control": { + "access_control": { "__type": "dict", "__var": { "test_role": { @@ -456,7 +454,7 @@ class TestStringifiedDAGs: serialized_dag = SerializedDAG.to_dict(dag) SerializedDAG.validate_schema(serialized_dag) - assert serialized_dag["dag"]["_access_control"] == {"__type": "dict", "__var": {}} + assert serialized_dag["dag"]["access_control"] == {"__type": "dict", "__var": {}} @pytest.mark.db_test def test_dag_serialization_unregistered_custom_timetable(self): @@ -491,8 +489,8 @@ class TestStringifiedDAGs: task["__var"] = dict(sorted(task["__var"].items(), key=lambda x: x[0])) tasks.append(task) dag_dict["dag"]["tasks"] = tasks - dag_dict["dag"]["_access_control"]["__var"]["test_role"]["__var"] = sorted( - dag_dict["dag"]["_access_control"]["__var"]["test_role"]["__var"] + dag_dict["dag"]["access_control"]["__var"]["test_role"]["__var"] = sorted( + dag_dict["dag"]["access_control"]["__var"]["test_role"]["__var"] ) return dag_dict @@ -567,7 +565,7 @@ class TestStringifiedDAGs: "timezone", # Need to check fields in it, to exclude functions. "default_args", - "_task_group", + "task_group", "params", "_processor_dags_folder", } @@ -613,7 +611,7 @@ class TestStringifiedDAGs: assert isinstance(serialized_task, SerializedBaseOperator) fields_to_check = task.get_serialized_fields() - { # Checked separately - "_task_type", + "task_type", "_operator_name", # Type is excluded, so don't check it "_log", @@ -675,7 +673,7 @@ class TestStringifiedDAGs: # MappedOperator.operator_class holds a backup of the serialized # data; checking its entirety basically duplicates this validation # function, so we just do some sanity checks. - serialized_task.operator_class["_task_type"] == type(task).__name__ + serialized_task.operator_class["task_type"] == type(task).__name__ if isinstance(serialized_task.operator_class, DecoratedOperator): serialized_task.operator_class["_operator_name"] == task._operator_name @@ -804,7 +802,7 @@ class TestStringifiedDAGs: "__version": 1, "dag": { "default_args": {"__type": "dict", "__var": {}}, - "_dag_id": "simple_dag", + "dag_id": "simple_dag", "fileloc": __file__, "tasks": [], "timezone": "UTC", @@ -820,7 +818,7 @@ class TestStringifiedDAGs: "__version": 1, "dag": { "default_args": {"__type": "dict", "__var": {}}, - "_dag_id": "simple_dag", + "dag_id": "simple_dag", "fileloc": __file__, "tasks": [], "timezone": "UTC", @@ -1060,6 +1058,7 @@ class TestStringifiedDAGs: link = simple_task.get_extra_links(ti, GoogleLink.name) assert "https://www.google.com" == link + @pytest.mark.usefixtures("clear_all_logger_handlers") def test_extra_operator_links_logs_error_for_non_registered_extra_links(self, caplog): """ Assert OperatorLinks not registered via Plugins and if it is not an inbuilt Operator Link, @@ -1221,14 +1220,20 @@ class TestStringifiedDAGs: This test verifies that there are no new fields added to BaseOperator. And reminds that tests should be added for it. """ + from airflow.utils.trigger_rule import TriggerRule + base_operator = BaseOperator(task_id="10") - fields = {k: v for (k, v) in vars(base_operator).items() if k in BaseOperator.get_serialized_fields()} + # Return the name of any annotated class property, or anything explicitly listed in serialized fields + field_names = { + fld.name + for fld in dataclasses.fields(BaseOperator) + if fld.name in BaseOperator.get_serialized_fields() + } | BaseOperator.get_serialized_fields() + fields = {k: getattr(base_operator, k) for k in field_names} assert fields == { "_logger_name": None, - "_log_config_logger_name": "airflow.task.operators", - "_post_execute_hook": None, - "_pre_execute_hook": None, - "_task_display_property_value": None, + "_needs_expansion": None, + "_task_display_name": None, "allow_nested_operators": True, "depends_on_past": False, "do_xcom_push": True, @@ -1238,23 +1243,22 @@ class TestStringifiedDAGs: "doc_rst": None, "doc_yaml": None, "downstream_task_ids": set(), + "end_date": None, "email": None, "email_on_failure": True, "email_on_retry": True, "execution_timeout": None, "executor": None, "executor_config": {}, - "ignore_first_depends_on_past": True, + "ignore_first_depends_on_past": False, + "is_setup": False, + "is_teardown": False, "inlets": [], "map_index_template": None, "max_active_tis_per_dag": None, "max_active_tis_per_dagrun": None, "max_retry_delay": None, - "on_execute_callback": None, - "on_failure_callback": None, - "on_retry_callback": None, - "on_skipped_callback": None, - "on_success_callback": None, + "on_failure_fail_dagrun": False, "outlets": [], "owner": "airflow", "params": {}, @@ -1267,8 +1271,18 @@ class TestStringifiedDAGs: "retry_delay": timedelta(0, 300), "retry_exponential_backoff": False, "run_as_user": None, + "sla": None, + "start_date": None, + "start_from_trigger": False, + "start_trigger_args": None, "task_id": "10", - "trigger_rule": "all_success", + "task_type": "BaseOperator", + "template_ext": (), + "template_fields": (), + "template_fields_renderers": {}, + "trigger_rule": TriggerRule.ALL_SUCCESS, + "ui_color": "#fff", + "ui_fgcolor": "#000", "wait_for_downstream": False, "wait_for_past_depends_before_skipping": False, "weight_rule": _DownstreamPriorityWeightStrategy(), @@ -1294,7 +1308,7 @@ class TestStringifiedDAGs: "template_ext": [], "template_fields": ["bash_command"], "template_fields_renderers": {}, - "_task_type": "CustomOperator", + "task_type": "CustomOperator", "_task_module": "tests_common.test_utils.mock_operators", "pool": "default_pool", "ui_color": "#fff", @@ -2052,7 +2066,7 @@ class TestStringifiedDAGs: serialized = { "__version": 1, "dag": { - "_dag_id": "simple_dag", + "dag_id": "simple_dag", "fileloc": "/path/to/file.py", "tasks": [], "timezone": "UTC", @@ -2071,7 +2085,7 @@ class TestStringifiedDAGs: serialized = { "__version": 1, "dag": { - "_dag_id": "simple_dag", + "dag_id": "simple_dag", "fileloc": "/path/to/file.py", "tasks": [], "timezone": "UTC", @@ -2091,7 +2105,7 @@ class TestStringifiedDAGs: serialized = { "__version": 1, "dag": { - "_dag_id": "simple_dag", + "dag_id": "simple_dag", "fileloc": "/path/to/file.py", "tasks": [], "timezone": "UTC", @@ -2108,7 +2122,7 @@ class TestStringifiedDAGs: serialized = { "__version": 1, "dag": { - "_dag_id": "simple_dag", + "dag_id": "simple_dag", "fileloc": "/path/to/file.py", "tasks": [], "timezone": "UTC", @@ -2290,7 +2304,7 @@ def test_operator_expand_serde(): "_is_mapped": True, "_needs_expansion": True, "_task_module": "airflow.providers.standard.operators.bash", - "_task_type": "BashOperator", + "task_type": "BashOperator", "start_trigger_args": None, "start_from_trigger": False, "downstream_task_ids": [], @@ -2323,7 +2337,7 @@ def test_operator_expand_serde(): assert op.deps is MappedOperator.deps_for(BaseOperator) assert op.operator_class == { - "_task_type": "BashOperator", + "task_type": "BashOperator", "_needs_expansion": True, "start_trigger_args": None, "start_from_trigger": False, @@ -2353,7 +2367,7 @@ def test_operator_expand_xcomarg_serde(): "_is_mapped": True, "_needs_expansion": True, "_task_module": "tests_common.test_utils.mock_operators", - "_task_type": "MockOperator", + "task_type": "MockOperator", "downstream_task_ids": [], "expand_input": { "type": "dict-of-lists", @@ -2408,7 +2422,7 @@ def test_operator_expand_kwargs_literal_serde(strict): "_is_mapped": True, "_needs_expansion": True, "_task_module": "tests_common.test_utils.mock_operators", - "_task_type": "MockOperator", + "task_type": "MockOperator", "downstream_task_ids": [], "expand_input": { "type": "list-of-dicts", @@ -2463,7 +2477,7 @@ def test_operator_expand_kwargs_xcomarg_serde(strict): "_is_mapped": True, "_needs_expansion": True, "_task_module": "tests_common.test_utils.mock_operators", - "_task_type": "MockOperator", + "task_type": "MockOperator", "downstream_task_ids": [], "expand_input": { "type": "list-of-dicts", @@ -2575,7 +2589,7 @@ def test_taskflow_expand_serde(): "_is_mapped": True, "_needs_expansion": True, "_task_module": "airflow.decorators.python", - "_task_type": "_PythonDecoratedOperator", + "task_type": "_PythonDecoratedOperator", "_operator_name": "@task", "downstream_task_ids": [], "partial_kwargs": { @@ -2677,7 +2691,7 @@ def test_taskflow_expand_kwargs_serde(strict): "_is_mapped": True, "_needs_expansion": True, "_task_module": "airflow.decorators.python", - "_task_type": "_PythonDecoratedOperator", + "task_type": "_PythonDecoratedOperator", "_operator_name": "@task", "start_trigger_args": None, "start_from_trigger": False, @@ -2771,7 +2785,7 @@ def test_mapped_task_group_serde(): tg.expand(a=[".", ".."]) ser_dag = SerializedBaseOperator.serialize(dag) - assert ser_dag[Encoding.VAR]["_task_group"]["children"]["tg"] == ( + assert ser_dag[Encoding.VAR]["task_group"]["children"]["tg"] == ( "taskgroup", { "_group_id": "tg", @@ -2831,7 +2845,7 @@ def test_mapped_task_with_operator_extra_links_property(): "template_ext": [], "template_fields": [], "template_fields_renderers": {}, - "_task_type": "_DummyOperator", + "task_type": "_DummyOperator", "_task_module": "tests.serialization.test_dag_serialization", "_is_empty": False, "_is_mapped": True,
