This is an automated email from the ASF dual-hosted git repository.
rom pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new bdab7dc28d Use set instead of list for dags' tags (#41695)
bdab7dc28d is described below
commit bdab7dc28d15aec230b6aa91f5b43d45a9a1dcea
Author: Avihais12344 <[email protected]>
AuthorDate: Sat Sep 21 12:30:42 2024 +0300
Use set instead of list for dags' tags (#41695)
* Started working on DAG tags, moved the tags to a set, and added a test to
check for duplicates.
* Fixed more areas of the constructor.
* Fixed test of dag tags.
* Added a check to see if the tags are mutable.
* Added newsfragment.
* Removed unnecessary check.
* Removed the str specification from the type, for compatibility with Python 3.8.
* Removed more type specifications as part of compatibility with Python 3.8.
* Fixed the newsfragment.
* Added missing word.
* Used `` for code segments in the rst file.
* Reformatted the file.
* Fixed wrong method for adding tag.
* Added type hinting at the dag bag.
* Deserialized the tags into a set.
* Adjusted the tests for the set type.
* Added type hinting.
* Sorting the tags by name.
* Changed to typing.
* Update newsfragments/41420.significant.rst
Co-authored-by: Jens Scheffler <[email protected]>
* Update newsfragments/41420.significant.rst
Co-authored-by: Jens Scheffler <[email protected]>
* Removed the generic specification at the dag args expected types, as it
raises the error: Subscripted generics cannot be used with class and instance
checks.
* Added tags to the expected serialized DAG.
* Sorted the tag entries by the name key.
* Fixed sorting tags by name to use `sorted` instead of `.sort`
* Fixed the tags comparison, as it's now a set, not a list.
---------
Co-authored-by: Jens Scheffler <[email protected]>
---
airflow/models/dag.py | 9 +++---
airflow/models/dagbag.py | 4 +--
airflow/serialization/serialized_objects.py | 2 ++
newsfragments/41420.significant.rst | 11 +++++++
tests/api_connexion/schemas/test_dag_schema.py | 18 ++++++++++--
tests/models/test_dag.py | 40 ++++++++++++++++++++++++--
tests/models/test_dagbag.py | 8 +++---
tests/models/test_serialized_dag.py | 4 +--
tests/serialization/test_dag_serialization.py | 3 +-
9 files changed, 82 insertions(+), 17 deletions(-)
diff --git a/airflow/models/dag.py b/airflow/models/dag.py
index 388f322d2d..00820585b6 100644
--- a/airflow/models/dag.py
+++ b/airflow/models/dag.py
@@ -42,6 +42,7 @@ from typing import (
Iterable,
Iterator,
List,
+ MutableSet,
Pattern,
Sequence,
Union,
@@ -351,7 +352,7 @@ DAG_ARGS_EXPECTED_TYPES = {
"doc_md": str,
"is_paused_upon_creation": bool,
"render_template_as_native_obj": bool,
- "tags": list,
+ "tags": Collection,
"auto_register": bool,
"fail_stop": bool,
"dag_display_name": str,
@@ -528,7 +529,7 @@ class DAG(LoggingMixin):
is_paused_upon_creation: bool | None = None,
jinja_environment_kwargs: dict | None = None,
render_template_as_native_obj: bool = False,
- tags: list[str] | None = None,
+ tags: Collection[str] | None = None,
owner_links: dict[str, str] | None = None,
auto_register: bool = True,
fail_stop: bool = False,
@@ -678,7 +679,7 @@ class DAG(LoggingMixin):
self.doc_md = self.get_doc_md(doc_md)
- self.tags = tags or []
+ self.tags: MutableSet[str] = set(tags or [])
self._task_group = TaskGroup.create_root(self)
self.validate_schedule_and_params()
wrong_links = dict(self.iter_invalid_owner_links())
@@ -3311,7 +3312,7 @@ def dag(
is_paused_upon_creation: bool | None = None,
jinja_environment_kwargs: dict | None = None,
render_template_as_native_obj: bool = False,
- tags: list[str] | None = None,
+ tags: Collection[str] | None = None,
owner_links: dict[str, str] | None = None,
auto_register: bool = True,
fail_stop: bool = False,
diff --git a/airflow/models/dagbag.py b/airflow/models/dagbag.py
index 8b155e7b52..b2d45a1331 100644
--- a/airflow/models/dagbag.py
+++ b/airflow/models/dagbag.py
@@ -266,7 +266,7 @@ class DagBag(LoggingMixin):
"""Add DAG to DagBag from DB."""
from airflow.models.serialized_dag import SerializedDagModel
- row = SerializedDagModel.get(dag_id, session)
+ row: SerializedDagModel | None = SerializedDagModel.get(dag_id,
session)
if not row:
return None
@@ -457,7 +457,7 @@ class DagBag(LoggingMixin):
found_dags.append(dag)
return found_dags
- def bag_dag(self, dag):
+ def bag_dag(self, dag: DAG):
"""
Add the DAG into the bag.
diff --git a/airflow/serialization/serialized_objects.py
b/airflow/serialization/serialized_objects.py
index 998b5ba3f4..12310685ec 100644
--- a/airflow/serialization/serialized_objects.py
+++ b/airflow/serialization/serialized_objects.py
@@ -1673,6 +1673,8 @@ class SerializedDAG(DAG, BaseSerialization):
v = cls.deserialize(v)
elif k == "params":
v = cls._deserialize_params_dict(v)
+ elif k == "tags":
+ v = set(v)
# else use v as it is
setattr(dag, k, v)
diff --git a/newsfragments/41420.significant.rst
b/newsfragments/41420.significant.rst
new file mode 100644
index 0000000000..361b8c7ea9
--- /dev/null
+++ b/newsfragments/41420.significant.rst
@@ -0,0 +1,11 @@
+**Breaking Change**
+
+Replaced Python's ``list`` with ``MutableSet`` for the property ``DAG.tags``.
+
+In the constructor you can still pass a ``list``;
+in fact, any data structure that implements the
+``Collection`` interface is accepted.
+
+The ``tags`` property of the ``DAG`` model is now of type
+``MutableSet`` instead of ``list``,
+as tags never contain duplicates.
diff --git a/tests/api_connexion/schemas/test_dag_schema.py
b/tests/api_connexion/schemas/test_dag_schema.py
index 858c62815f..a4a86bc05c 100644
--- a/tests/api_connexion/schemas/test_dag_schema.py
+++ b/tests/api_connexion/schemas/test_dag_schema.py
@@ -185,7 +185,10 @@ def
test_serialize_test_dag_detail_schema(url_safe_serializer):
}
},
"start_date": "2020-06-19T00:00:00+00:00",
- "tags": [{"name": "example1"}, {"name": "example2"}],
+ "tags": sorted(
+ [{"name": "example1"}, {"name": "example2"}],
+ key=lambda val: val["name"],
+ ),
"template_searchpath": None,
"timetable_summary": "1 day, 0:00:00",
"timezone": UTC_JSON_REPR,
@@ -198,6 +201,10 @@ def
test_serialize_test_dag_detail_schema(url_safe_serializer):
}
obj = schema.dump(dag)
expected.update({"last_parsed": obj["last_parsed"]})
+ obj["tags"] = sorted(
+ obj["tags"],
+ key=lambda val: val["name"],
+ )
assert obj == expected
@@ -243,7 +250,10 @@ def
test_serialize_test_dag_with_dataset_schedule_detail_schema(url_safe_seriali
}
},
"start_date": "2020-06-19T00:00:00+00:00",
- "tags": [{"name": "example1"}, {"name": "example2"}],
+ "tags": sorted(
+ [{"name": "example1"}, {"name": "example2"}],
+ key=lambda val: val["name"],
+ ),
"template_searchpath": None,
"timetable_summary": "Dataset",
"timezone": UTC_JSON_REPR,
@@ -256,4 +266,8 @@ def
test_serialize_test_dag_with_dataset_schedule_detail_schema(url_safe_seriali
}
obj = schema.dump(dag)
expected.update({"last_parsed": obj["last_parsed"]})
+ obj["tags"] = sorted(
+ obj["tags"],
+ key=lambda val: val["name"],
+ )
assert obj == expected
diff --git a/tests/models/test_dag.py b/tests/models/test_dag.py
index d5fd2ad729..90d956caeb 100644
--- a/tests/models/test_dag.py
+++ b/tests/models/test_dag.py
@@ -805,7 +805,7 @@ class TestDag:
DAG.bulk_write_to_db(dags)
# Adding tags
for dag in dags:
- dag.tags.append("test-dag2")
+ dag.tags.add("test-dag2")
with assert_queries_count(9):
DAG.bulk_write_to_db(dags)
with create_session() as session:
@@ -843,7 +843,7 @@ class TestDag:
# Removing all tags
for dag in dags:
- dag.tags = None
+ dag.tags = set()
with assert_queries_count(9):
DAG.bulk_write_to_db(dags)
with create_session() as session:
@@ -3383,6 +3383,42 @@ def test__tags_length(tags: list[str], should_pass:
bool):
DAG("test-dag", schedule=None, tags=tags)
+@pytest.mark.parametrize(
+ "input_tags, expected_result",
+ [
+ pytest.param([], set(), id="empty tags"),
+ pytest.param(
+ ["a normal tag"],
+ {"a normal tag"},
+ id="one tag",
+ ),
+ pytest.param(
+ ["a normal tag", "another normal tag"],
+ {"a normal tag", "another normal tag"},
+ id="two different tags",
+ ),
+ pytest.param(
+ ["a", "a"],
+ {"a"},
+ id="two same tags",
+ ),
+ ],
+)
+def test__tags_duplicates(input_tags: list[str], expected_result: set[str]):
+ result = DAG("test-dag", tags=input_tags)
+ assert result.tags == expected_result
+
+
+def test__tags_mutable():
+ expected_tags = {"6", "7"}
+ test_dag = DAG("test-dag")
+ test_dag.tags.add("6")
+ test_dag.tags.add("7")
+ test_dag.tags.add("8")
+ test_dag.tags.remove("8")
+ assert test_dag.tags == expected_tags
+
+
@pytest.mark.need_serialized_dag
def test_get_dataset_triggered_next_run_info(dag_maker, clear_datasets):
dataset1 = Dataset(uri="ds1")
diff --git a/tests/models/test_dagbag.py b/tests/models/test_dagbag.py
index a5ec740df2..0179fa8652 100644
--- a/tests/models/test_dagbag.py
+++ b/tests/models/test_dagbag.py
@@ -830,11 +830,11 @@ with airflow.DAG(
# from DB
with time_machine.travel((tz.datetime(2020, 1, 5, 0, 0, 4)),
tick=False):
with assert_queries_count(0):
- assert dag_bag.get_dag("example_bash_operator").tags ==
["example", "example2"]
+ assert dag_bag.get_dag("example_bash_operator").tags ==
{"example", "example2"}
# Make a change in the DAG and write Serialized DAG to the DB
with time_machine.travel((tz.datetime(2020, 1, 5, 0, 0, 6)),
tick=False):
- example_bash_op_dag.tags += ["new_tag"]
+ example_bash_op_dag.tags.add("new_tag")
SerializedDagModel.write_dag(dag=example_bash_op_dag)
# Since min_serialized_dag_fetch_interval is passed verify that
calling 'dag_bag.get_dag'
@@ -869,7 +869,7 @@ with airflow.DAG(
ser_dag = dag_bag.get_dag("example_bash_operator")
ser_dag_update_time =
dag_bag.dags_last_fetched["example_bash_operator"]
- assert ser_dag.tags == ["example", "example2"]
+ assert ser_dag.tags == {"example", "example2"}
assert ser_dag_update_time == tz.datetime(2020, 1, 5, 1, 0, 10)
with create_session() as session:
@@ -883,7 +883,7 @@ with airflow.DAG(
# Note the date *before* the deserialize step above, simulating a
serialization happening
# long before the transaction is committed
with time_machine.travel((tz.datetime(2020, 1, 5, 1, 0, 0)),
tick=False):
- example_bash_op_dag.tags += ["new_tag"]
+ example_bash_op_dag.tags.add("new_tag")
SerializedDagModel.write_dag(dag=example_bash_op_dag)
# Since min_serialized_dag_fetch_interval is passed verify that
calling 'dag_bag.get_dag'
diff --git a/tests/models/test_serialized_dag.py
b/tests/models/test_serialized_dag.py
index f86aa1b904..9f83280f8e 100644
--- a/tests/models/test_serialized_dag.py
+++ b/tests/models/test_serialized_dag.py
@@ -109,8 +109,8 @@ class TestSerializedDagModel:
assert dag_updated is False
# Update DAG
- example_bash_op_dag.tags += ["new_tag"]
- assert set(example_bash_op_dag.tags) == {"example", "example2",
"new_tag"}
+ example_bash_op_dag.tags.add("new_tag")
+ assert example_bash_op_dag.tags == {"example", "example2",
"new_tag"}
dag_updated = SDM.write_dag(dag=example_bash_op_dag)
s_dag_2 = session.get(SDM, example_bash_op_dag.dag_id)
diff --git a/tests/serialization/test_dag_serialization.py
b/tests/serialization/test_dag_serialization.py
index 8311aa77e7..7dfe57054c 100644
--- a/tests/serialization/test_dag_serialization.py
+++ b/tests/serialization/test_dag_serialization.py
@@ -260,6 +260,7 @@ serialized_simple_dag_ground_truth = {
"edge_info": {},
"dag_dependencies": [],
"params": [],
+ "tags": [],
},
}
@@ -587,7 +588,7 @@ class TestStringifiedDAGs:
roundtripped = SerializedDAG.from_json(SerializedDAG.to_json(dag))
self.validate_deserialized_dag(roundtripped, dag)
- def validate_deserialized_dag(self, serialized_dag, dag):
+ def validate_deserialized_dag(self, serialized_dag: DAG, dag: DAG):
"""
Verify that all example DAGs work with DAG Serialization by
checking fields between Serialized Dags & non-Serialized Dags