This is an automated email from the ASF dual-hosted git repository.
ash pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new de5fb662504 Correctly set `owners` on DagModel after TaskSDK parsing
change (#45094)
de5fb662504 is described below
commit de5fb662504735f8654e7292588a661c8a96aede
Author: Ash Berlin-Taylor <[email protected]>
AuthorDate: Thu Dec 19 22:08:19 2024 +0000
Correctly set `owners` on DagModel after TaskSDK parsing change (#45094)
* Correctly set `owners` on DagModel after TaskSDK parsing change
It turns out this behaviour wasn't adequately tested, at least not using
serialized DAGs
This test fixes this case, and adds/moves some of the tests from
models/test_dag.py into test_collection which more accurately represents
where
the code it is testing now lives.
---
airflow/dag_processing/collection.py | 2 +-
airflow/serialization/serialized_objects.py | 6 +++
tests/dag_processing/test_collection.py | 71 +++++++++++++++++++++++++++++
tests/models/test_dag.py | 56 -----------------------
4 files changed, 78 insertions(+), 57 deletions(-)
diff --git a/airflow/dag_processing/collection.py
b/airflow/dag_processing/collection.py
index 1babf238184..f3e3b8322ca 100644
--- a/airflow/dag_processing/collection.py
+++ b/airflow/dag_processing/collection.py
@@ -383,7 +383,7 @@ class DagModelOperation(NamedTuple):
for dag_id, dm in sorted(orm_dags.items()):
dag = self.dags[dag_id]
dm.fileloc = dag.fileloc
- dm.owners = dag.owner
+ dm.owners = dag.owner or conf.get("operators", "default_owner")
dm.is_active = True
dm.has_import_errors = False
dm.last_parsed_time = utcnow()
diff --git a/airflow/serialization/serialized_objects.py
b/airflow/serialization/serialized_objects.py
index f1aa274d44b..514d32b822e 100644
--- a/airflow/serialization/serialized_objects.py
+++ b/airflow/serialization/serialized_objects.py
@@ -1930,6 +1930,12 @@ class LazyDeserializedDAG(pydantic.BaseModel):
def has_task_concurrency_limits(self) -> bool:
return any(task.get("max_active_tis_per_dag") is not None for task in
self.data["dag"]["tasks"])
+ @property
+ def owner(self) -> str:
+ return ", ".join(
+ set(filter(None, (task[Encoding.VAR].get("owner") for task in
self.data["dag"]["tasks"])))
+ )
+
def get_task_assets(
self,
inlets: bool = True,
diff --git a/tests/dag_processing/test_collection.py
b/tests/dag_processing/test_collection.py
index 3c539bb50ef..ca435cc1a4f 100644
--- a/tests/dag_processing/test_collection.py
+++ b/tests/dag_processing/test_collection.py
@@ -31,6 +31,7 @@ from sqlalchemy import func, select
from sqlalchemy.exc import OperationalError, SAWarning
import airflow.dag_processing.collection
+from airflow.configuration import conf
from airflow.dag_processing.collection import (
AssetModelOperation,
_get_latest_runs_stmt,
@@ -50,6 +51,7 @@ from airflow.models.serialized_dag import SerializedDagModel
from airflow.operators.empty import EmptyOperator
from airflow.providers.standard.triggers.temporal import TimeDeltaTrigger
from airflow.sdk.definitions.asset import Asset
+from airflow.serialization.serialized_objects import LazyDeserializedDAG,
SerializedDAG
from airflow.utils import timezone as tz
from airflow.utils.session import create_session
@@ -177,6 +179,10 @@ class TestUpdateDagParsingResults:
get_listener_manager().clear()
dag_import_error_listener.clear()
+ def dag_to_lazy_serdag(self, dag: DAG) -> LazyDeserializedDAG:
+ ser_dict = SerializedDAG.to_dict(dag)
+ return LazyDeserializedDAG(data=ser_dict)
+
@pytest.mark.usefixtures("clean_db") # sync_perms in fab has bad session
commit hygiene
def test_sync_perms_syncs_dag_specific_perms_on_update(
self, monkeypatch, spy_agency: SpyAgency, session, time_machine
@@ -414,3 +420,68 @@ class TestUpdateDagParsingResults:
spy_agency.assert_spy_called_with(
spy, dag.dag_id, access_control={"Public": {"DAGs": {"can_read"},
"DAG Runs": {"can_create"}}}
)
+
+ @pytest.mark.parametrize(
+ ("attrs", "expected"),
+ [
+ pytest.param(
+ {
+ "_tasks_": [
+ EmptyOperator(task_id="task", owner="owner1"),
+ EmptyOperator(task_id="task2", owner="owner2"),
+ EmptyOperator(task_id="task3"),
+ EmptyOperator(task_id="task4", owner="owner2"),
+ ]
+ },
+ {
+ "default_view": conf.get("webserver",
"dag_default_view").lower(),
+ "owners": ["owner1", "owner2"],
+ },
+ id="tasks-multiple-owners",
+ ),
+ pytest.param(
+ {"is_paused_upon_creation": True},
+ {"is_paused": True},
+ id="is_paused_upon_creation",
+ ),
+ pytest.param(
+ {},
+ {"owners": ["airflow"]},
+ id="default-owner",
+ ),
+ ],
+ )
+ @pytest.mark.usefixtures("clean_db")
+ def test_dagmodel_properties(self, attrs, expected, session, time_machine):
+ """Test that properties on the dag model are correctly set when
dealing with a LazySerializedDag"""
+ dt = tz.datetime(2020, 1, 5, 0, 0, 0)
+ time_machine.move_to(dt, tick=False)
+
+ tasks = attrs.pop("_tasks_", None)
+ dag = DAG("dag", **attrs)
+ if tasks:
+ dag.add_tasks(tasks)
+
+ update_dag_parsing_results_in_db([self.dag_to_lazy_serdag(dag)], {},
set(), session)
+
+ orm_dag = session.get(DagModel, ("dag",))
+
+ for attrname, expected_value in expected.items():
+ if attrname == "owners":
+ assert sorted(orm_dag.owners.split(", ")) == expected_value
+ else:
+ assert getattr(orm_dag, attrname) == expected_value
+
+ assert orm_dag.last_parsed_time == dt
+
+ def test_existing_dag_is_paused_upon_creation(self, session):
+ dag = DAG("dag_paused", schedule=None)
+ update_dag_parsing_results_in_db([self.dag_to_lazy_serdag(dag)], {},
set(), session)
+ orm_dag = session.get(DagModel, ("dag_paused",))
+ assert orm_dag.is_paused is False
+
+ dag = DAG("dag_paused", schedule=None, is_paused_upon_creation=True)
+ update_dag_parsing_results_in_db([self.dag_to_lazy_serdag(dag)], {},
set(), session)
+ # Since the dag existed before, it should not follow the pause flag
upon creation
+ orm_dag = session.get(DagModel, ("dag_paused",))
+ assert orm_dag.is_paused is False
diff --git a/tests/models/test_dag.py b/tests/models/test_dag.py
index 632b73349ba..582ef1b27a0 100644
--- a/tests/models/test_dag.py
+++ b/tests/models/test_dag.py
@@ -995,54 +995,6 @@ class TestDag:
assert asset_alias_3_orm.name == "asset_alias_3"
assert len(stored_asset_alias_models) == 3
- def test_sync_to_db(self):
- dag = DAG("dag", start_date=DEFAULT_DATE, schedule=None)
- with dag:
- EmptyOperator(task_id="task", owner="owner1")
- EmptyOperator(task_id="task2", owner="owner2")
- session = settings.Session()
- dag.sync_to_db(session=session)
-
- orm_dag = session.query(DagModel).filter(DagModel.dag_id ==
"dag").one()
- assert set(orm_dag.owners.split(", ")) == {"owner1", "owner2"}
- assert orm_dag.is_active
- assert orm_dag.default_view is not None
- assert orm_dag.default_view == conf.get("webserver",
"dag_default_view").lower()
- assert orm_dag.safe_dag_id == "dag"
- session.close()
-
- def test_sync_to_db_default_view(self):
- dag = DAG("dag", schedule=None, start_date=DEFAULT_DATE,
default_view="graph")
- with dag:
- EmptyOperator(task_id="task", owner="owner1")
- session = settings.Session()
- dag.sync_to_db(session=session)
-
- orm_dag = session.query(DagModel).filter(DagModel.dag_id ==
"dag").one()
- assert orm_dag.default_view is not None
- assert orm_dag.default_view == "graph"
- session.close()
-
- def test_existing_dag_is_paused_upon_creation(self):
- dag = DAG("dag_paused", schedule=None)
- dag.sync_to_db()
- assert not dag.get_is_paused()
-
- dag = DAG("dag_paused", schedule=None, is_paused_upon_creation=True)
- dag.sync_to_db()
- # Since the dag existed before, it should not follow the pause flag
upon creation
- assert not dag.get_is_paused()
-
- def test_new_dag_is_paused_upon_creation(self):
- dag = DAG("new_nonexisting_dag", schedule=None,
is_paused_upon_creation=True)
- session = settings.Session()
- dag.sync_to_db(session=session)
-
- orm_dag = session.query(DagModel).filter(DagModel.dag_id ==
"new_nonexisting_dag").one()
- # Since the dag didn't exist before, it should follow the pause flag
upon creation
- assert orm_dag.is_paused
- session.close()
-
@mock.patch.dict(
os.environ,
{
@@ -1096,14 +1048,6 @@ class TestDag:
dag.clear()
self._clean_up(dag_id)
- def test_existing_dag_default_view(self):
- with create_session() as session:
- session.add(DagModel(dag_id="dag_default_view_old",
default_view=None))
- session.commit()
- orm_dag = session.query(DagModel).filter(DagModel.dag_id ==
"dag_default_view_old").one()
- assert orm_dag.default_view is None
- assert orm_dag.get_default_view() == conf.get("webserver",
"dag_default_view").lower()
-
def test_dag_is_deactivated_upon_dagfile_deletion(self):
dag_id = "old_existing_dag"
dag_fileloc = "/usr/local/airflow/dags/non_existing_path.py"