This is an automated email from the ASF dual-hosted git repository.

ash pushed a commit to branch task-sdk-first-code
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit c5273a837e5ce4b1de38799aa16d3b7c43b57bed
Author: Kaxil Naik <[email protected]>
AuthorDate: Tue Oct 22 22:24:42 2024 +0100

    More fixes to test_dagbag.py

    [skip ci]

    [skip ci]
---
 airflow/models/dag.py                                |  8 --------
 airflow/serialization/schema.json                    |  2 +-
 task_sdk/src/airflow/sdk/definitions/baseoperator.py |  2 +-
 task_sdk/src/airflow/sdk/definitions/dag.py          | 15 ++++++++++++++-
 tests/models/test_dagbag.py                          |  2 +-
 5 files changed, 17 insertions(+), 12 deletions(-)

diff --git a/airflow/models/dag.py b/airflow/models/dag.py
index e26a696ee7a..61e260f2fbc 100644
--- a/airflow/models/dag.py
+++ b/airflow/models/dag.py
@@ -743,14 +743,6 @@ class DAG(TaskSDKDag, LoggingMixin):
     def max_active_tasks(self, value: int):
         self._max_active_tasks = value
 
-    @property
-    def access_control(self):
-        return self._access_control
-
-    @access_control.setter
-    def access_control(self, value):
-        self._access_control = DAG._upgrade_outdated_dag_access_control(value)
-
     @property
     def pickle_id(self) -> int | None:
         return self._pickle_id
diff --git a/airflow/serialization/schema.json b/airflow/serialization/schema.json
index e313e2c7af7..f89bd348776 100644
--- a/airflow/serialization/schema.json
+++ b/airflow/serialization/schema.json
@@ -158,7 +158,7 @@
         },
         "orientation": { "type" : "string"},
         "dag_display_name": { "type" : "string"},
-        "_description": { "type" : "string"},
+        "description": { "type" : "string"},
         "_concurrency": { "type" : "number"},
         "max_active_tasks": { "type" : "number"},
         "max_active_runs": { "type" : "number"},
diff --git a/task_sdk/src/airflow/sdk/definitions/baseoperator.py b/task_sdk/src/airflow/sdk/definitions/baseoperator.py
index bc2c696ab44..d8e07c44b71 100644
--- a/task_sdk/src/airflow/sdk/definitions/baseoperator.py
+++ b/task_sdk/src/airflow/sdk/definitions/baseoperator.py
@@ -563,7 +563,7 @@ class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta):
     template_fields: Collection[str] = ()
     template_ext: Sequence[str] = ()
 
-    template_fields_renderers: dict[str, str] = field(default_factory=dict, init=False)
+    template_fields_renderers: ClassVar[dict[str, str]] = {}
 
     # Defines the color in the UI
     ui_color: str = "#fff"
diff --git a/task_sdk/src/airflow/sdk/definitions/dag.py b/task_sdk/src/airflow/sdk/definitions/dag.py
index ba82f19339e..8224f3a3167 100644
--- a/task_sdk/src/airflow/sdk/definitions/dag.py
+++ b/task_sdk/src/airflow/sdk/definitions/dag.py
@@ -353,7 +353,7 @@ class DAG:
         default=None,
         converter=attrs.Converter(_convert_params, takes_self=True),  # type: ignore[misc, call-overload]
     )
-    access_control: dict | None = None
+    _access_control: dict | None = None
     is_paused_upon_creation: bool | None = None
     jinja_environment_kwargs: dict | None = None
     render_template_as_native_obj: bool = attrs.field(default=False, converter=bool)
@@ -381,6 +381,8 @@ class DAG:
 
         self.start_date = timezone.convert_to_utc(self.start_date)
         self.end_date = timezone.convert_to_utc(self.end_date)
+        # This should trigger the setters for access_control
+        self.access_control = self.access_control
 
     @fileloc.default
     def _default_fileloc(self) -> str:
@@ -684,6 +686,17 @@ class DAG:
         result._log = self._log  # type: ignore[attr-defined]
         return result
 
+    @property
+    def access_control(self):
+        return self._access_control
+
+    @access_control.setter
+    def access_control(self, value):
+        if hasattr(self, "_upgrade_outdated_dag_access_control"):
+            self._access_control = self._upgrade_outdated_dag_access_control(value)
+        else:
+            self._access_control = value
+
     def partial_subset(
         self,
         task_ids_or_regex: str | Pattern | Iterable[str],
diff --git a/tests/models/test_dagbag.py b/tests/models/test_dagbag.py
index 89e899543eb..f563c72f545 100644
--- a/tests/models/test_dagbag.py
+++ b/tests/models/test_dagbag.py
@@ -667,7 +667,7 @@ with airflow.DAG(
         """Test that dagbag.sync_to_db is retried on OperationalError"""
 
         dagbag = DagBag("/dev/null")
-        mock_dag = mock.MagicMock(spec=DAG)
+        mock_dag = mock.MagicMock()
         dagbag.dags["mock_dag"] = mock_dag
         op_error = OperationalError(statement=mock.ANY, params=mock.ANY, orig=mock.ANY)
