This is an automated email from the ASF dual-hosted git repository.

weilee pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 268a400984 Refactor ORM DAG insertion logic (#42358)
268a400984 is described below

commit 268a40098418ae6cf177d89e0654dd1dc157ea3f
Author: Tzu-ping Chung <[email protected]>
AuthorDate: Mon Sep 23 01:34:48 2024 -0700

    Refactor ORM DAG insertion logic (#42358)
    
    * Refactor ORM DAG insertion logic
    
    This basically uses the same pattern as the dataset inserts, and makes
    the FK relation updates more readable.
    
    The SQL queries could be made more efficient, but I decided to keep
    things as-is (no worse than before); the difference is likely not a
    big deal.
    
    * Use set operation for readability
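    
    A minimal sketch of the resulting call pattern, taken from the dag.py
    changes below (the surrounding "dags" and "session" names are assumed
    to already exist in the caller):
    
        from airflow.dag_processing.collection import DagModelOperation
    
        # Wrap the parsed DAG objects in the new operation helper.
        dag_op = DagModelOperation({dag.dag_id: dag for dag in dags})
        # Fetch existing DagModel rows and create any that are missing.
        orm_dags = dag_op.add_dags(session=session)
        # Copy DAG attributes (tags, owner links, run fields) onto the rows.
        dag_op.update_dags(orm_dags, processor_subdir=None, session=session)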
---
 airflow/dag_processing/collection.py | 200 ++++++++++++++++++-----------------
 airflow/models/dag.py                |  24 ++---
 2 files changed, 110 insertions(+), 114 deletions(-)

diff --git a/airflow/dag_processing/collection.py b/airflow/dag_processing/collection.py
index 3f75e0b23b..5d54d17b87 100644
--- a/airflow/dag_processing/collection.py
+++ b/airflow/dag_processing/collection.py
@@ -61,46 +61,28 @@ if TYPE_CHECKING:
 log = logging.getLogger(__name__)
 
 
-def collect_orm_dags(dags: dict[str, DAG], *, session: Session) -> dict[str, DagModel]:
-    """
-    Collect DagModel objects from DAG objects.
-
-    An existing DagModel is fetched if there's a matching ID in the database.
-    Otherwise, a new DagModel is created and added to the session.
-    """
+def _find_orm_dags(dag_ids: Iterable[str], *, session: Session) -> dict[str, DagModel]:
+    """Find existing DagModel objects from DAG objects."""
     stmt = (
         select(DagModel)
         .options(joinedload(DagModel.tags, innerjoin=False))
-        .where(DagModel.dag_id.in_(dags))
+        .where(DagModel.dag_id.in_(dag_ids))
         .options(joinedload(DagModel.schedule_dataset_references))
         .options(joinedload(DagModel.schedule_dataset_alias_references))
         .options(joinedload(DagModel.task_outlet_dataset_references))
     )
     stmt = with_row_locks(stmt, of=DagModel, session=session)
-    existing_orm_dags = {dm.dag_id: dm for dm in session.scalars(stmt).unique()}
+    return {dm.dag_id: dm for dm in session.scalars(stmt).unique()}
+
 
-    for dag_id, dag in dags.items():
-        if dag_id in existing_orm_dags:
-            continue
-        orm_dag = DagModel(dag_id=dag_id)
+def _create_orm_dags(dags: Iterable[DAG], *, session: Session) -> Iterator[DagModel]:
+    for dag in dags:
+        orm_dag = DagModel(dag_id=dag.dag_id)
         if dag.is_paused_upon_creation is not None:
             orm_dag.is_paused = dag.is_paused_upon_creation
-        orm_dag.tags = []
-        log.info("Creating ORM DAG for %s", dag_id)
+        log.info("Creating ORM DAG for %s", dag.dag_id)
         session.add(orm_dag)
-        existing_orm_dags[dag_id] = orm_dag
-
-    return existing_orm_dags
-
-
-def create_orm_dag(dag: DAG, session: Session) -> DagModel:
-    orm_dag = DagModel(dag_id=dag.dag_id)
-    if dag.is_paused_upon_creation is not None:
-        orm_dag.is_paused = dag.is_paused_upon_creation
-    orm_dag.tags = []
-    log.info("Creating ORM DAG for %s", dag.dag_id)
-    session.add(orm_dag)
-    return orm_dag
+        yield orm_dag
 
 
 def _get_latest_runs_stmt(dag_ids: Collection[str]) -> Select:
@@ -158,75 +140,101 @@ class _RunInfo(NamedTuple):
         )
 
 
-def update_orm_dags(
-    source_dags: dict[str, DAG],
-    target_dags: dict[str, DagModel],
-    *,
-    processor_subdir: str | None = None,
-    session: Session,
-) -> None:
-    """
-    Apply DAG attributes to DagModel objects.
-
-    Objects in ``target_dags`` are modified in-place.
-    """
-    run_info = _RunInfo.calculate(source_dags, session=session)
-
-    for dag_id, dm in sorted(target_dags.items()):
-        dag = source_dags[dag_id]
-        dm.fileloc = dag.fileloc
-        dm.owners = dag.owner
-        dm.is_active = True
-        dm.has_import_errors = False
-        dm.last_parsed_time = utcnow()
-        dm.default_view = dag.default_view
-        dm._dag_display_property_value = dag._dag_display_property_value
-        dm.description = dag.description
-        dm.max_active_tasks = dag.max_active_tasks
-        dm.max_active_runs = dag.max_active_runs
-        dm.max_consecutive_failed_dag_runs = dag.max_consecutive_failed_dag_runs
-        dm.has_task_concurrency_limits = any(
-            t.max_active_tis_per_dag is not None or t.max_active_tis_per_dagrun is not None for t in dag.tasks
-        )
-        dm.timetable_summary = dag.timetable.summary
-        dm.timetable_description = dag.timetable.description
-        dm.dataset_expression = dag.timetable.dataset_condition.as_expression()
-        dm.processor_subdir = processor_subdir
-
-        last_automated_run: DagRun | None = run_info.latest_runs.get(dag.dag_id)
-        if last_automated_run is None:
-            last_automated_data_interval = None
-        else:
-            last_automated_data_interval = dag.get_run_data_interval(last_automated_run)
-        if run_info.num_active_runs.get(dag.dag_id, 0) >= dm.max_active_runs:
-            dm.next_dagrun_create_after = None
+def _update_dag_tags(tag_names: set[str], dm: DagModel, *, session: Session) -> None:
+    orm_tags = {t.name: t for t in dm.tags}
+    for name, orm_tag in orm_tags.items():
+        if name not in tag_names:
+            session.delete(orm_tag)
+    dm.tags.extend(DagTag(name=name, dag_id=dm.dag_id) for name in tag_names.difference(orm_tags))
+
+
+def _update_dag_owner_links(dag_owner_links: dict[str, str], dm: DagModel, *, session: Session) -> None:
+    orm_dag_owner_attributes = {obj.owner: obj for obj in dm.dag_owner_links}
+    for owner, obj in orm_dag_owner_attributes.items():
+        try:
+            link = dag_owner_links[owner]
+        except KeyError:
+            session.delete(obj)
         else:
-            dm.calculate_dagrun_date_fields(dag, last_automated_data_interval)
-
-        if not dag.timetable.dataset_condition:
-            dm.schedule_dataset_references = []
-            dm.schedule_dataset_alias_references = []
-        # FIXME: STORE NEW REFERENCES.
-
-        dag_tags = set(dag.tags or ())
-        for orm_tag in (dm_tags := list(dm.tags or [])):
-            if orm_tag.name not in dag_tags:
-                session.delete(orm_tag)
-                dm.tags.remove(orm_tag)
-        orm_tag_names = {t.name for t in dm_tags}
-        for dag_tag in dag_tags:
-            if dag_tag not in orm_tag_names:
-                dag_tag_orm = DagTag(name=dag_tag, dag_id=dag.dag_id)
-                dm.tags.append(dag_tag_orm)
-                session.add(dag_tag_orm)
-
-        dm_links = dm.dag_owner_links or []
-        for dm_link in dm_links:
-            if dm_link not in dag.owner_links:
-                session.delete(dm_link)
-        for owner_name, owner_link in dag.owner_links.items():
-            dag_owner_orm = DagOwnerAttributes(dag_id=dag.dag_id, owner=owner_name, link=owner_link)
-            session.add(dag_owner_orm)
+            if obj.link != link:
+                obj.link = link
+    dm.dag_owner_links.extend(
+        DagOwnerAttributes(dag_id=dm.dag_id, owner=owner, link=link)
+        for owner, link in dag_owner_links.items()
+        if owner not in orm_dag_owner_attributes
+    )
+
+
+class DagModelOperation(NamedTuple):
+    """Collect DAG objects and perform database operations for them."""
+
+    dags: dict[str, DAG]
+
+    def add_dags(self, *, session: Session) -> dict[str, DagModel]:
+        orm_dags = _find_orm_dags(self.dags, session=session)
+        orm_dags.update(
+            (model.dag_id, model)
+            for model in _create_orm_dags(
+                (dag for dag_id, dag in self.dags.items() if dag_id not in orm_dags),
+                session=session,
+            )
+        )
+        return orm_dags
+
+    def update_dags(
+        self,
+        orm_dags: dict[str, DagModel],
+        *,
+        processor_subdir: str | None = None,
+        session: Session,
+    ) -> None:
+        run_info = _RunInfo.calculate(self.dags, session=session)
+
+        for dag_id, dm in sorted(orm_dags.items()):
+            dag = self.dags[dag_id]
+            dm.fileloc = dag.fileloc
+            dm.owners = dag.owner
+            dm.is_active = True
+            dm.has_import_errors = False
+            dm.last_parsed_time = utcnow()
+            dm.default_view = dag.default_view
+            dm._dag_display_property_value = dag._dag_display_property_value
+            dm.description = dag.description
+            dm.max_active_tasks = dag.max_active_tasks
+            dm.max_active_runs = dag.max_active_runs
+            dm.max_consecutive_failed_dag_runs = dag.max_consecutive_failed_dag_runs
+            dm.has_task_concurrency_limits = any(
+                t.max_active_tis_per_dag is not None or t.max_active_tis_per_dagrun is not None
+                for t in dag.tasks
+            )
+            dm.timetable_summary = dag.timetable.summary
+            dm.timetable_description = dag.timetable.description
+            dm.dataset_expression = dag.timetable.dataset_condition.as_expression()
+            dm.processor_subdir = processor_subdir
+
+            last_automated_run: DagRun | None = run_info.latest_runs.get(dag.dag_id)
+            if last_automated_run is None:
+                last_automated_data_interval = None
+            else:
+                last_automated_data_interval = dag.get_run_data_interval(last_automated_run)
+            if run_info.num_active_runs.get(dag.dag_id, 0) >= dm.max_active_runs:
+                dm.next_dagrun_create_after = None
+            else:
+                dm.calculate_dagrun_date_fields(dag, last_automated_data_interval)
+
+            if not dag.timetable.dataset_condition:
+                dm.schedule_dataset_references = []
+                dm.schedule_dataset_alias_references = []
+            # FIXME: STORE NEW REFERENCES.
+
+            if dag.tags:
+                _update_dag_tags(set(dag.tags), dm, session=session)
+            else:  # Optimization: no references at all, just clear everything.
+                dm.tags = []
+            if dag.owner_links:
+                _update_dag_owner_links(dag.owner_links, dm, session=session)
+            else:  # Optimization: no references at all, just clear everything.
+                dm.dag_owner_links = []
 
 
 def _find_all_datasets(dags: Iterable[DAG]) -> Iterator[Dataset]:
diff --git a/airflow/models/dag.py b/airflow/models/dag.py
index 00820585b6..c95d11f3ef 100644
--- a/airflow/models/dag.py
+++ b/airflow/models/dag.py
@@ -2644,28 +2644,16 @@ class DAG(LoggingMixin):
         if not dags:
             return
 
-        from airflow.dag_processing.collection import (
-            DatasetModelOperation,
-            collect_orm_dags,
-            create_orm_dag,
-            update_orm_dags,
-        )
+        from airflow.dag_processing.collection import DagModelOperation, DatasetModelOperation
 
         log.info("Sync %s DAGs", len(dags))
-        dags_by_ids = {dag.dag_id: dag for dag in dags}
-        del dags
-
-        orm_dags = collect_orm_dags(dags_by_ids, session=session)
-        orm_dags.update(
-            (dag_id, create_orm_dag(dag, session=session))
-            for dag_id, dag in dags_by_ids.items()
-            if dag_id not in orm_dags
-        )
+        dag_op = DagModelOperation({dag.dag_id: dag for dag in dags})
 
-        update_orm_dags(dags_by_ids, orm_dags, processor_subdir=processor_subdir, session=session)
-        DagCode.bulk_sync_to_db((dag.fileloc for dag in dags_by_ids.values()), session=session)
+        orm_dags = dag_op.add_dags(session=session)
+        dag_op.update_dags(orm_dags, processor_subdir=processor_subdir, session=session)
+        DagCode.bulk_sync_to_db((dag.fileloc for dag in dags), session=session)
 
-        dataset_op = DatasetModelOperation.collect(dags_by_ids)
+        dataset_op = DatasetModelOperation.collect(dag_op.dags)
 
         orm_datasets = dataset_op.add_datasets(session=session)
         orm_dataset_aliases = dataset_op.add_dataset_aliases(session=session)
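
For reference, the "set operation" called out in the commit message is the
tag sync in _update_dag_tags above: compare the desired tag names against
the ones already stored and delete or create accordingly. A standalone
sketch of the same idea (illustrative names only, not Airflow code):

    def sync_tag_names(desired: set[str], existing: set[str]) -> tuple[set[str], set[str]]:
        """Return (names to create, names to delete) for a tag sync."""
        return desired - existing, existing - desired

    # Example: the DAG declares {"etl", "daily"}; the DB currently has {"etl", "legacy"}.
    to_create, to_delete = sync_tag_names({"etl", "daily"}, {"etl", "legacy"})
    assert to_create == {"daily"} and to_delete == {"legacy"}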
