This is an automated email from the ASF dual-hosted git repository.

uranusjr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 8d816fb8c4 Refactor bulk_save_to_db (#42245)
8d816fb8c4 is described below

commit 8d816fb8c455e81b8a186da37c6009316839bfcd
Author: Tzu-ping Chung <[email protected]>
AuthorDate: Thu Sep 19 23:45:43 2024 -0700

    Refactor bulk_save_to_db (#42245)
    
    Co-authored-by: Ephraim Anierobi <[email protected]>
---
 airflow/dag_processing/collection.py    | 408 ++++++++++++++++++++++++++++++++
 airflow/datasets/__init__.py            |  16 +-
 airflow/datasets/manager.py             |   2 -
 airflow/models/dag.py                   | 342 ++------------------------
 airflow/models/taskinstance.py          |   1 +
 airflow/timetables/base.py              |   5 +-
 tests/dag_processing/test_collection.py |  64 +++++
 tests/datasets/test_dataset.py          |   2 +-
 tests/models/test_dag.py                |  41 ----
 9 files changed, 515 insertions(+), 366 deletions(-)

diff --git a/airflow/dag_processing/collection.py b/airflow/dag_processing/collection.py
new file mode 100644
index 0000000000..3f75e0b23b
--- /dev/null
+++ b/airflow/dag_processing/collection.py
@@ -0,0 +1,408 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
"""
Utility code that writes DAGs in bulk into the database.

This should generally only be called by internal methods such as
``DagBag._sync_to_db``, ``DAG.bulk_write_to_db``.

:meta private:
"""
+
+from __future__ import annotations
+
+import itertools
+import logging
+from typing import TYPE_CHECKING, NamedTuple
+
+from sqlalchemy import func, select
+from sqlalchemy.orm import joinedload, load_only
+from sqlalchemy.sql import expression
+
+from airflow.datasets import Dataset, DatasetAlias
+from airflow.datasets.manager import dataset_manager
+from airflow.models.dag import DAG, DagModel, DagOwnerAttributes, DagTag
+from airflow.models.dagrun import DagRun
+from airflow.models.dataset import (
+    DagScheduleDatasetAliasReference,
+    DagScheduleDatasetReference,
+    DatasetAliasModel,
+    DatasetModel,
+    TaskOutletDatasetReference,
+)
+from airflow.utils.sqlalchemy import with_row_locks
+from airflow.utils.timezone import utcnow
+from airflow.utils.types import DagRunType
+
+if TYPE_CHECKING:
+    from collections.abc import Collection, Iterable, Iterator
+
+    from sqlalchemy.orm import Session
+    from sqlalchemy.sql import Select
+
+    from airflow.typing_compat import Self
+
# Module-level logger, named after this module.
log = logging.getLogger(name=__name__)
+
+
def collect_orm_dags(dags: dict[str, DAG], *, session: Session) -> dict[str, DagModel]:
    """
    Collect DagModel objects from DAG objects.

    An existing DagModel is fetched if there's a matching ID in the database; the
    query eagerly loads its tags and dataset/alias/outlet references and requests
    row locks through ``with_row_locks``. Otherwise, a new DagModel is created with
    :func:`create_orm_dag`, which also adds it to the session.

    :param dags: DAG objects keyed by DAG ID. Keys are expected to equal each
        DAG's ``dag_id`` (new rows are created from ``dag.dag_id``).
    :param session: ORM session used for the lookup query and for adding new rows.
    :return: DagModel objects keyed by DAG ID, one entry for every key in ``dags``.
    """
    stmt = (
        select(DagModel)
        .where(DagModel.dag_id.in_(dags))
        .options(joinedload(DagModel.tags, innerjoin=False))
        .options(joinedload(DagModel.schedule_dataset_references))
        .options(joinedload(DagModel.schedule_dataset_alias_references))
        .options(joinedload(DagModel.task_outlet_dataset_references))
    )
    stmt = with_row_locks(stmt, of=DagModel, session=session)
    # unique() is required because the joined eager loads multiply result rows.
    existing_orm_dags = {dm.dag_id: dm for dm in session.scalars(stmt).unique()}

    for dag_id, dag in dags.items():
        if dag_id not in existing_orm_dags:
            # Reuse the single creation helper instead of duplicating its logic here.
            existing_orm_dags[dag_id] = create_orm_dag(dag, session)

    return existing_orm_dags
+
+
def create_orm_dag(dag: DAG, session: Session) -> DagModel:
    """
    Build a new DagModel for ``dag`` and add it to ``session``.

    ``is_paused`` is copied from ``dag.is_paused_upon_creation`` only when that
    value is not None; otherwise it is left unset on the new object. The tag
    collection starts out empty.

    :param dag: Parsed DAG the new row is created for.
    :param session: ORM session the new DagModel is added to.
    :return: The newly created (pending) DagModel.
    """
    new_orm_dag = DagModel(dag_id=dag.dag_id)
    paused_on_create = dag.is_paused_upon_creation
    if paused_on_create is not None:
        new_orm_dag.is_paused = paused_on_create
    new_orm_dag.tags = []
    log.info("Creating ORM DAG for %s", dag.dag_id)
    session.add(new_orm_dag)
    return new_orm_dag
+
+
def _get_latest_runs_stmt(dag_ids: Collection[str]) -> Select:
    """
    Build a SELECT for the latest automated run of each DAG in ``dag_ids``.

    "Automated" means a run of type backfill or scheduled; "latest" means the
    maximum ``execution_date``. Only the columns needed to derive data intervals
    are loaded on the returned DagRun rows.
    """
    automated_run_types = (DagRunType.BACKFILL_JOB, DagRunType.SCHEDULED)

    if len(dag_ids) != 1:
        # General case: one grouped subquery computes the latest date per DAG.
        latest_per_dag = (
            select(DagRun.dag_id, func.max(DagRun.execution_date).label("max_execution_date"))
            .where(
                DagRun.dag_id.in_(dag_ids),
                DagRun.run_type.in_(automated_run_types),
            )
            .group_by(DagRun.dag_id)
            .subquery()
        )
        latest_runs = select(DagRun).where(
            DagRun.dag_id == latest_per_dag.c.dag_id,
            DagRun.execution_date == latest_per_dag.c.max_execution_date,
        )
    else:
        # Index-optimized fast path: a scalar subquery avoids the slower GROUP BY plan.
        [only_dag_id] = dag_ids
        latest_date = (
            select(func.max(DagRun.execution_date).label("max_execution_date"))
            .where(
                DagRun.dag_id == only_dag_id,
                DagRun.run_type.in_(automated_run_types),
            )
            .scalar_subquery()
        )
        latest_runs = select(DagRun).where(
            DagRun.dag_id == only_dag_id,
            DagRun.execution_date == latest_date,
        )

    needed_columns = (
        DagRun.dag_id,
        DagRun.execution_date,
        DagRun.data_interval_start,
        DagRun.data_interval_end,
    )
    return latest_runs.options(load_only(*needed_columns))
+
+
+class _RunInfo(NamedTuple):
+    latest_runs: dict[str, DagRun]
+    num_active_runs: dict[str, int]
+
+    @classmethod
+    def calculate(cls, dags: dict[str, DAG], *, session: Session) -> Self:
+        # Skip these queries entirely if no DAGs can be scheduled to save time.
+        if not any(dag.timetable.can_be_scheduled for dag in dags.values()):
+            return cls({}, {})
+        return cls(
+            {run.dag_id: run for run in 
session.scalars(_get_latest_runs_stmt(dag_ids=dags))},
+            DagRun.active_runs_of_dags(dag_ids=dags, session=session),
+        )
+
+
def update_orm_dags(
    source_dags: dict[str, DAG],
    target_dags: dict[str, DagModel],
    *,
    processor_subdir: str | None = None,
    session: Session,
) -> None:
    """
    Apply DAG attributes to DagModel objects.

    Objects in ``target_dags`` are modified in-place, in DAG-ID order. Every key
    in ``target_dags`` must also be a key of ``source_dags``.

    :param source_dags: Parsed DAG objects keyed by DAG ID.
    :param target_dags: ORM DAG rows keyed by DAG ID; updated in place.
    :param processor_subdir: Value written to each row's ``processor_subdir``.
    :param session: ORM session used for run queries and for adding/deleting
        tag and owner-link rows.
    """
    run_info = _RunInfo.calculate(source_dags, session=session)

    for dag_id, dm in sorted(target_dags.items()):
        dag = source_dags[dag_id]
        dm.fileloc = dag.fileloc
        dm.owners = dag.owner
        dm.is_active = True
        dm.has_import_errors = False
        dm.last_parsed_time = utcnow()
        dm.default_view = dag.default_view
        dm._dag_display_property_value = dag._dag_display_property_value
        dm.description = dag.description
        dm.max_active_tasks = dag.max_active_tasks
        dm.max_active_runs = dag.max_active_runs
        dm.max_consecutive_failed_dag_runs = dag.max_consecutive_failed_dag_runs
        dm.has_task_concurrency_limits = any(
            t.max_active_tis_per_dag is not None or t.max_active_tis_per_dagrun is not None for t in dag.tasks
        )
        dm.timetable_summary = dag.timetable.summary
        dm.timetable_description = dag.timetable.description
        dm.dataset_expression = dag.timetable.dataset_condition.as_expression()
        dm.processor_subdir = processor_subdir

        last_automated_run: DagRun | None = run_info.latest_runs.get(dag.dag_id)
        if last_automated_run is None:
            last_automated_data_interval = None
        else:
            last_automated_data_interval = dag.get_run_data_interval(last_automated_run)
        # A DAG already at its active-run limit must not get a next-run creation time.
        if run_info.num_active_runs.get(dag.dag_id, 0) >= dm.max_active_runs:
            dm.next_dagrun_create_after = None
        else:
            dm.calculate_dagrun_date_fields(dag, last_automated_data_interval)

        # Only stale schedule references are cleared here; this function does not
        # store new ones.
        # FIXME: STORE NEW REFERENCES.
        if not dag.timetable.dataset_condition:
            dm.schedule_dataset_references = []
            dm.schedule_dataset_alias_references = []

        _sync_dag_tags(dag, dm, session=session)
        _sync_dag_owner_links(dag, dm, session=session)


def _sync_dag_tags(dag: DAG, dm: DagModel, *, session: Session) -> None:
    """Make ``dm.tags`` match ``dag.tags``: delete stale tag rows and add missing ones."""
    dag_tags = set(dag.tags or ())
    # Iterate over a copy because stale tags are removed from dm.tags in the loop.
    dm_tags = list(dm.tags or [])
    for orm_tag in dm_tags:
        if orm_tag.name not in dag_tags:
            session.delete(orm_tag)
            dm.tags.remove(orm_tag)
    orm_tag_names = {t.name for t in dm_tags}
    for dag_tag in dag_tags:
        if dag_tag not in orm_tag_names:
            dag_tag_orm = DagTag(name=dag_tag, dag_id=dag.dag_id)
            dm.tags.append(dag_tag_orm)
            session.add(dag_tag_orm)


def _sync_dag_owner_links(dag: DAG, dm: DagModel, *, session: Session) -> None:
    """
    Make the DAG's owner-link rows match ``dag.owner_links`` (owner name -> link).

    Rows are matched by owner name: owners no longer listed are deleted, kept
    owners have their link updated in place, and new owners get a new row.
    """
    existing_links = {link.owner: link for link in dm.dag_owner_links or []}
    for owner_name, orm_link in existing_links.items():
        # Compare by owner name; the ORM object itself is never a key of owner_links.
        if owner_name not in dag.owner_links:
            session.delete(orm_link)
    for owner_name, owner_link in dag.owner_links.items():
        orm_link = existing_links.get(owner_name)
        if orm_link is None:
            session.add(DagOwnerAttributes(dag_id=dag.dag_id, owner=owner_name, link=owner_link))
        else:
            orm_link.link = owner_link
+
+
+def _find_all_datasets(dags: Iterable[DAG]) -> Iterator[Dataset]:
+    for dag in dags:
+        for _, dataset in dag.timetable.dataset_condition.iter_datasets():
+            yield dataset
+        for task in dag.task_dict.values():
+            for obj in itertools.chain(task.inlets, task.outlets):
+                if isinstance(obj, Dataset):
+                    yield obj
+
+
+def _find_all_dataset_aliases(dags: Iterable[DAG]) -> Iterator[DatasetAlias]:
+    for dag in dags:
+        for _, alias in dag.timetable.dataset_condition.iter_dataset_aliases():
+            yield alias
+        for task in dag.task_dict.values():
+            for obj in itertools.chain(task.inlets, task.outlets):
+                if isinstance(obj, DatasetAlias):
+                    yield obj
+
+
class DatasetModelOperation(NamedTuple):
    """
    Collect dataset/alias objects from DAGs and perform database operations for them.

    Build an instance with :meth:`collect`, then use the ``add_*`` methods to
    write datasets, aliases, and the DAG/task references to them.
    """

    # DAG ID -> datasets referenced by the DAG's schedule condition.
    schedule_dataset_references: dict[str, list[Dataset]]
    # DAG ID -> dataset aliases referenced by the DAG's schedule condition.
    schedule_dataset_alias_references: dict[str, list[DatasetAlias]]
    # DAG ID -> (task ID, dataset) pairs for every Dataset among task outlets.
    outlet_references: dict[str, list[tuple[str, Dataset]]]
    # Dataset URI -> dataset, over schedule conditions and task inlets/outlets.
    datasets: dict[str, Dataset]
    # Alias name -> alias, over schedule conditions and task inlets/outlets.
    dataset_aliases: dict[str, DatasetAlias]

    @classmethod
    def collect(cls, dags: dict[str, DAG]) -> Self:
        """Gather datasets, aliases, and references from ``dags`` (keyed by DAG ID)."""
        coll = cls(
            schedule_dataset_references={
                dag_id: [dataset for _, dataset in dag.timetable.dataset_condition.iter_datasets()]
                for dag_id, dag in dags.items()
            },
            schedule_dataset_alias_references={
                dag_id: [alias for _, alias in dag.timetable.dataset_condition.iter_dataset_aliases()]
                for dag_id, dag in dags.items()
            },
            outlet_references={
                dag_id: [
                    (task_id, outlet)
                    for task_id, task in dag.task_dict.items()
                    for outlet in task.outlets
                    if isinstance(outlet, Dataset)
                ]
                for dag_id, dag in dags.items()
            },
            datasets={dataset.uri: dataset for dataset in _find_all_datasets(dags.values())},
            dataset_aliases={alias.name: alias for alias in _find_all_dataset_aliases(dags.values())},
        )
        return coll

    def add_datasets(self, *, session: Session) -> dict[str, DatasetModel]:
        """
        Ensure a DatasetModel exists for every collected dataset.

        Existing rows are marked as no longer orphaned. New rows are created via
        ``dataset_manager.create_datasets``. ``extra`` is refreshed from the
        collected dataset in both cases.

        :return: Dataset models keyed by URI, covering every collected dataset.
        """
        # Optimization: skip all database calls if no datasets were collected.
        if not self.datasets:
            return {}
        orm_datasets: dict[str, DatasetModel] = {
            dm.uri: dm
            for dm in session.scalars(select(DatasetModel).where(DatasetModel.uri.in_(self.datasets)))
        }

        new_datasets: list[DatasetModel] = []
        for uri, dataset in self.datasets.items():
            dm = orm_datasets.get(uri)
            if dm is None:
                dm = orm_datasets[uri] = DatasetModel.from_public(dataset)
                new_datasets.append(dm)
            else:
                # The orphaned flag was bulk-set to True before parsing, so we
                # don't need to handle rows in the db without a public entry.
                dm.is_orphaned = expression.false()
            dm.extra = dataset.extra

        dataset_manager.create_datasets(new_datasets, session=session)
        return orm_datasets

    def add_dataset_aliases(self, *, session: Session) -> dict[str, DatasetAliasModel]:
        """
        Ensure a DatasetAliasModel exists for every collected alias.

        :return: Alias models keyed by name, covering every collected alias.
        """
        # Optimization: skip all database calls if no dataset aliases were collected.
        if not self.dataset_aliases:
            return {}
        orm_aliases: dict[str, DatasetAliasModel] = {
            da.name: da
            for da in session.scalars(
                select(DatasetAliasModel).where(DatasetAliasModel.name.in_(self.dataset_aliases))
            )
        }
        for name, alias in self.dataset_aliases.items():
            if name not in orm_aliases:
                new_alias = orm_aliases[name] = DatasetAliasModel.from_public(alias)
                session.add(new_alias)
        return orm_aliases

    def add_dag_dataset_references(
        self,
        dags: dict[str, DagModel],
        datasets: dict[str, DatasetModel],
        *,
        session: Session,
    ) -> None:
        """
        Sync each DAG's schedule dataset references with what was collected.

        Stale references are deleted and missing ones bulk-inserted. DAGs whose
        schedule references no dataset have their references cleared, even when
        no datasets were collected at all.

        :param dags: ORM DAGs keyed by DAG ID, covering every collected DAG.
        :param datasets: Models keyed by URI, as returned by :meth:`add_datasets`.
            Their ``id`` must already be populated (e.g. by a flush).
        """
        for dag_id, references in self.schedule_dataset_references.items():
            # No references at all: clearing is faster than repeated delete(). This
            # must also run when ``datasets`` is empty so stale references go away.
            if not references:
                dags[dag_id].schedule_dataset_references = []
                continue
            referenced_dataset_ids = {datasets[r.uri].id for r in references}
            orm_refs = {r.dataset_id: r for r in dags[dag_id].schedule_dataset_references}
            for dataset_id, ref in orm_refs.items():
                if dataset_id not in referenced_dataset_ids:
                    session.delete(ref)
            session.bulk_save_objects(
                DagScheduleDatasetReference(dataset_id=dataset_id, dag_id=dag_id)
                for dataset_id in referenced_dataset_ids
                if dataset_id not in orm_refs
            )

    def add_dag_dataset_alias_references(
        self,
        dags: dict[str, DagModel],
        aliases: dict[str, DatasetAliasModel],
        *,
        session: Session,
    ) -> None:
        """
        Sync each DAG's schedule alias references with what was collected.

        Stale references are deleted and missing ones bulk-inserted. DAGs whose
        schedule references no alias have their references cleared, even when no
        aliases were collected at all.

        :param dags: ORM DAGs keyed by DAG ID, covering every collected DAG.
        :param aliases: Models keyed by name, as returned by :meth:`add_dataset_aliases`.
            Their ``id`` must already be populated (e.g. by a flush).
        """
        for dag_id, references in self.schedule_dataset_alias_references.items():
            # No references at all: clearing is faster than repeated delete(). This
            # must also run when ``aliases`` is empty so stale references go away.
            if not references:
                dags[dag_id].schedule_dataset_alias_references = []
                continue
            referenced_alias_ids = {aliases[r.name].id for r in references}
            orm_refs = {a.alias_id: a for a in dags[dag_id].schedule_dataset_alias_references}
            for alias_id, ref in orm_refs.items():
                if alias_id not in referenced_alias_ids:
                    session.delete(ref)
            session.bulk_save_objects(
                DagScheduleDatasetAliasReference(alias_id=alias_id, dag_id=dag_id)
                for alias_id in referenced_alias_ids
                if alias_id not in orm_refs
            )

    def add_task_dataset_references(
        self,
        dags: dict[str, DagModel],
        datasets: dict[str, DatasetModel],
        *,
        session: Session,
    ) -> None:
        """
        Sync each DAG's task outlet dataset references with what was collected.

        References are keyed by (task ID, dataset ID). Stale ones are deleted and
        missing ones bulk-inserted. DAGs with no dataset outlets have their
        references cleared, even when no datasets were collected at all.

        :param dags: ORM DAGs keyed by DAG ID, covering every collected DAG.
        :param datasets: Models keyed by URI, as returned by :meth:`add_datasets`.
            Their ``id`` must already be populated (e.g. by a flush).
        """
        for dag_id, references in self.outlet_references.items():
            # No references at all: clearing is faster than repeated delete(). This
            # must also run when ``datasets`` is empty so stale references go away.
            if not references:
                dags[dag_id].task_outlet_dataset_references = []
                continue
            referenced_outlets = {(task_id, datasets[d.uri].id) for task_id, d in references}
            orm_refs = {(r.task_id, r.dataset_id): r for r in dags[dag_id].task_outlet_dataset_references}
            for key, ref in orm_refs.items():
                if key not in referenced_outlets:
                    session.delete(ref)
            session.bulk_save_objects(
                TaskOutletDatasetReference(dataset_id=dataset_id, dag_id=dag_id, task_id=task_id)
                for task_id, dataset_id in referenced_outlets
                if (task_id, dataset_id) not in orm_refs
            )
diff --git a/airflow/datasets/__init__.py b/airflow/datasets/__init__.py
index cd57078095..6f7ae99ff7 100644
--- a/airflow/datasets/__init__.py
+++ b/airflow/datasets/__init__.py
@@ -194,7 +194,7 @@ class BaseDataset:
     def iter_datasets(self) -> Iterator[tuple[str, Dataset]]:
         raise NotImplementedError
 
-    def iter_dataset_aliases(self) -> Iterator[DatasetAlias]:
+    def iter_dataset_aliases(self) -> Iterator[tuple[str, DatasetAlias]]:
         raise NotImplementedError
 
     def iter_dag_dependencies(self, *, source: str, target: str) -> 
Iterator[DagDependency]:
@@ -212,6 +212,12 @@ class DatasetAlias(BaseDataset):
 
     name: str
 
+    def iter_datasets(self) -> Iterator[tuple[str, Dataset]]:
+        return iter(())
+
+    def iter_dataset_aliases(self) -> Iterator[tuple[str, DatasetAlias]]:
+        yield self.name, self
+
     def iter_dag_dependencies(self, *, source: str, target: str) -> 
Iterator[DagDependency]:
         """
         Iterate a dataset alias as dag dependency.
@@ -294,7 +300,7 @@ class Dataset(os.PathLike, BaseDataset):
     def iter_datasets(self) -> Iterator[tuple[str, Dataset]]:
         yield self.uri, self
 
-    def iter_dataset_aliases(self) -> Iterator[DatasetAlias]:
+    def iter_dataset_aliases(self) -> Iterator[tuple[str, DatasetAlias]]:
         return iter(())
 
     def evaluate(self, statuses: dict[str, bool]) -> bool:
@@ -339,7 +345,7 @@ class _DatasetBooleanCondition(BaseDataset):
                 yield k, v
                 seen.add(k)
 
-    def iter_dataset_aliases(self) -> Iterator[DatasetAlias]:
+    def iter_dataset_aliases(self) -> Iterator[tuple[str, DatasetAlias]]:
         """Filter dataest aliases in the condition."""
         for o in self.objects:
             yield from o.iter_dataset_aliases()
@@ -399,8 +405,8 @@ class _DatasetAliasCondition(DatasetAny):
         """
         return {"alias": self.name}
 
-    def iter_dataset_aliases(self) -> Iterator[DatasetAlias]:
-        yield DatasetAlias(self.name)
+    def iter_dataset_aliases(self) -> Iterator[tuple[str, DatasetAlias]]:
+        yield self.name, DatasetAlias(self.name)
 
     def iter_dag_dependencies(self, *, source: str = "", target: str = "") -> 
Iterator[DagDependency]:
         """
diff --git a/airflow/datasets/manager.py b/airflow/datasets/manager.py
index 058eef6ab8..19f6913fff 100644
--- a/airflow/datasets/manager.py
+++ b/airflow/datasets/manager.py
@@ -62,8 +62,6 @@ class DatasetManager(LoggingMixin):
         """Create new datasets."""
         for dataset_model in dataset_models:
             session.add(dataset_model)
-        session.flush()
-
         for dataset_model in dataset_models:
             self.notify_dataset_created(dataset=Dataset(uri=dataset_model.uri, 
extra=dataset_model.extra))
 
diff --git a/airflow/models/dag.py b/airflow/models/dag.py
index 6545293ccf..6447f7be15 100644
--- a/airflow/models/dag.py
+++ b/airflow/models/dag.py
@@ -74,7 +74,7 @@ from sqlalchemy import (
 )
 from sqlalchemy.ext.associationproxy import association_proxy
 from sqlalchemy.ext.hybrid import hybrid_property
-from sqlalchemy.orm import backref, joinedload, load_only, relationship
+from sqlalchemy.orm import backref, relationship
 from sqlalchemy.sql import Select, expression
 
 import airflow.templates
@@ -82,7 +82,6 @@ from airflow import settings, utils
 from airflow.api_internal.internal_api_call import internal_api_call
 from airflow.configuration import conf as airflow_conf, secrets_backend_list
 from airflow.datasets import BaseDataset, Dataset, DatasetAlias, DatasetAll
-from airflow.datasets.manager import dataset_manager
 from airflow.exceptions import (
     AirflowException,
     DuplicateTaskIdFound,
@@ -100,11 +99,7 @@ from airflow.models.baseoperator import BaseOperator
 from airflow.models.dagcode import DagCode
 from airflow.models.dagpickle import DagPickle
 from airflow.models.dagrun import RUN_ID_REGEX, DagRun
-from airflow.models.dataset import (
-    DatasetAliasModel,
-    DatasetDagRunQueue,
-    DatasetModel,
-)
+from airflow.models.dataset import DatasetDagRunQueue
 from airflow.models.param import DagParam, ParamsDict
 from airflow.models.taskinstance import (
     Context,
@@ -2637,7 +2632,7 @@ class DAG(LoggingMixin):
         cls,
         dags: Collection[DAG],
         processor_subdir: str | None = None,
-        session=NEW_SESSION,
+        session: Session = NEW_SESSION,
     ):
         """
         Ensure the DagModel rows for the given dags are up-to-date in the dag 
table in the DB.
@@ -2648,323 +2643,38 @@ class DAG(LoggingMixin):
         if not dags:
             return
 
-        log.info("Sync %s DAGs", len(dags))
-        dag_by_ids = {dag.dag_id: dag for dag in dags}
-
-        dag_ids = set(dag_by_ids)
-        query = (
-            select(DagModel)
-            .options(joinedload(DagModel.tags, innerjoin=False))
-            .where(DagModel.dag_id.in_(dag_ids))
-            .options(joinedload(DagModel.schedule_dataset_references))
-            .options(joinedload(DagModel.schedule_dataset_alias_references))
-            .options(joinedload(DagModel.task_outlet_dataset_references))
-        )
-        query = with_row_locks(query, of=DagModel, session=session)
-        orm_dags: list[DagModel] = session.scalars(query).unique().all()
-        existing_dags: dict[str, DagModel] = {x.dag_id: x for x in orm_dags}
-        missing_dag_ids = dag_ids.difference(existing_dags.keys())
-
-        for missing_dag_id in missing_dag_ids:
-            orm_dag = DagModel(dag_id=missing_dag_id)
-            dag = dag_by_ids[missing_dag_id]
-            if dag.is_paused_upon_creation is not None:
-                orm_dag.is_paused = dag.is_paused_upon_creation
-            orm_dag.tags = []
-            log.info("Creating ORM DAG for %s", dag.dag_id)
-            session.add(orm_dag)
-            orm_dags.append(orm_dag)
-
-        latest_runs: dict[str, DagRun] = {}
-        num_active_runs: dict[str, int] = {}
-        # Skip these queries entirely if no DAGs can be scheduled to save time.
-        if any(dag.timetable.can_be_scheduled for dag in dags):
-            # Get the latest automated dag run for each existing dag as a 
single query (avoid n+1 query)
-            query = cls._get_latest_runs_stmt(dags=list(existing_dags.keys()))
-            latest_runs = {run.dag_id: run for run in session.scalars(query)}
-
-            # Get number of active dagruns for all dags we are processing as a 
single query.
-            num_active_runs = 
DagRun.active_runs_of_dags(dag_ids=existing_dags, session=session)
-
-        filelocs = []
-
-        for orm_dag in sorted(orm_dags, key=lambda d: d.dag_id):
-            dag = dag_by_ids[orm_dag.dag_id]
-            filelocs.append(dag.fileloc)
-            orm_dag.fileloc = dag.fileloc
-            orm_dag.owners = dag.owner
-            orm_dag.is_active = True
-            orm_dag.has_import_errors = False
-            orm_dag.last_parsed_time = timezone.utcnow()
-            orm_dag.default_view = dag.default_view
-            orm_dag._dag_display_property_value = 
dag._dag_display_property_value
-            orm_dag.description = dag.description
-            orm_dag.max_active_tasks = dag.max_active_tasks
-            orm_dag.max_active_runs = dag.max_active_runs
-            orm_dag.max_consecutive_failed_dag_runs = 
dag.max_consecutive_failed_dag_runs
-            orm_dag.has_task_concurrency_limits = any(
-                t.max_active_tis_per_dag is not None or 
t.max_active_tis_per_dagrun is not None
-                for t in dag.tasks
-            )
-            orm_dag.timetable_summary = dag.timetable.summary
-            orm_dag.timetable_description = dag.timetable.description
-            orm_dag.dataset_expression = 
dag.timetable.dataset_condition.as_expression()
-
-            orm_dag.processor_subdir = processor_subdir
-
-            last_automated_run: DagRun | None = latest_runs.get(dag.dag_id)
-            if last_automated_run is None:
-                last_automated_data_interval = None
-            else:
-                last_automated_data_interval = 
dag.get_run_data_interval(last_automated_run)
-            if num_active_runs.get(dag.dag_id, 0) >= orm_dag.max_active_runs:
-                orm_dag.next_dagrun_create_after = None
-            else:
-                orm_dag.calculate_dagrun_date_fields(dag, 
last_automated_data_interval)
-
-            dag_tags = set(dag.tags or {})
-            orm_dag_tags = list(orm_dag.tags or [])
-            for orm_tag in orm_dag_tags:
-                if orm_tag.name not in dag_tags:
-                    session.delete(orm_tag)
-                    orm_dag.tags.remove(orm_tag)
-            orm_tag_names = {t.name for t in orm_dag_tags}
-            for dag_tag in dag_tags:
-                if dag_tag not in orm_tag_names:
-                    dag_tag_orm = DagTag(name=dag_tag, dag_id=dag.dag_id)
-                    orm_dag.tags.append(dag_tag_orm)
-                    session.add(dag_tag_orm)
-
-            orm_dag_links = orm_dag.dag_owner_links or []
-            for orm_dag_link in orm_dag_links:
-                if orm_dag_link not in dag.owner_links:
-                    session.delete(orm_dag_link)
-            for owner_name, owner_link in dag.owner_links.items():
-                dag_owner_orm = DagOwnerAttributes(dag_id=dag.dag_id, 
owner=owner_name, link=owner_link)
-                session.add(dag_owner_orm)
-
-        DagCode.bulk_sync_to_db(filelocs, session=session)
-
-        from airflow.datasets import Dataset
-        from airflow.models.dataset import (
-            DagScheduleDatasetAliasReference,
-            DagScheduleDatasetReference,
-            DatasetModel,
-            TaskOutletDatasetReference,
+        from airflow.dag_processing.collection import (
+            DatasetModelOperation,
+            collect_orm_dags,
+            create_orm_dag,
+            update_orm_dags,
         )
 
-        dag_references: dict[str, set[tuple[Literal["dataset", 
"dataset-alias"], str]]] = defaultdict(set)
-        outlet_references = defaultdict(set)
-        # We can't use a set here as we want to preserve order
-        outlet_dataset_models: dict[DatasetModel, None] = {}
-        input_dataset_models: dict[DatasetModel, None] = {}
-        outlet_dataset_alias_models: set[DatasetAliasModel] = set()
-        input_dataset_alias_models: set[DatasetAliasModel] = set()
-
-        # here we go through dags and tasks to check for dataset references
-        # if there are now None and previously there were some, we delete them
-        # if there are now *any*, we add them to the above data structures, and
-        # later we'll persist them to the database.
-        for dag in dags:
-            curr_orm_dag = existing_dags.get(dag.dag_id)
-            if not (dataset_condition := dag.timetable.dataset_condition):
-                if curr_orm_dag:
-                    if curr_orm_dag.schedule_dataset_references:
-                        curr_orm_dag.schedule_dataset_references = []
-                    if curr_orm_dag.schedule_dataset_alias_references:
-                        curr_orm_dag.schedule_dataset_alias_references = []
-            else:
-                for _, dataset in dataset_condition.iter_datasets():
-                    dag_references[dag.dag_id].add(("dataset", dataset.uri))
-                    input_dataset_models[DatasetModel.from_public(dataset)] = 
None
-
-                for dataset_alias in dataset_condition.iter_dataset_aliases():
-                    dag_references[dag.dag_id].add(("dataset-alias", 
dataset_alias.name))
-                    
input_dataset_alias_models.add(DatasetAliasModel.from_public(dataset_alias))
-
-            curr_outlet_references = curr_orm_dag and 
curr_orm_dag.task_outlet_dataset_references
-            for task in dag.tasks:
-                dataset_outlets: list[Dataset] = []
-                dataset_alias_outlets: list[DatasetAlias] = []
-                for outlet in task.outlets:
-                    if isinstance(outlet, Dataset):
-                        dataset_outlets.append(outlet)
-                    elif isinstance(outlet, DatasetAlias):
-                        dataset_alias_outlets.append(outlet)
-
-                if not dataset_outlets:
-                    if curr_outlet_references:
-                        this_task_outlet_refs = [
-                            x
-                            for x in curr_outlet_references
-                            if x.dag_id == dag.dag_id and x.task_id == 
task.task_id
-                        ]
-                        for ref in this_task_outlet_refs:
-                            curr_outlet_references.remove(ref)
-
-                for d in dataset_outlets:
-                    outlet_dataset_models[DatasetModel.from_public(d)] = None
-                    outlet_references[(task.dag_id, task.task_id)].add(d.uri)
-
-                for d_a in dataset_alias_outlets:
-                    
outlet_dataset_alias_models.add(DatasetAliasModel.from_public(d_a))
-
-        all_dataset_models = outlet_dataset_models
-        all_dataset_models.update(input_dataset_models)
-
-        # store datasets
-        stored_dataset_models: dict[str, DatasetModel] = {}
-        new_dataset_models: list[DatasetModel] = []
-        for dataset in all_dataset_models:
-            stored_dataset_model = session.scalar(
-                select(DatasetModel).where(DatasetModel.uri == 
dataset.uri).limit(1)
-            )
-            if stored_dataset_model:
-                # Some datasets may have been previously unreferenced, and 
therefore orphaned by the
-                # scheduler. But if we're here, then we have found that 
dataset again in our DAGs, which
-                # means that it is no longer an orphan, so set is_orphaned to 
False.
-                stored_dataset_model.is_orphaned = expression.false()
-                stored_dataset_models[stored_dataset_model.uri] = 
stored_dataset_model
-            else:
-                new_dataset_models.append(dataset)
-        dataset_manager.create_datasets(dataset_models=new_dataset_models, 
session=session)
-        stored_dataset_models.update(
-            {dataset_model.uri: dataset_model for dataset_model in 
new_dataset_models}
-        )
-
-        del new_dataset_models
-        del all_dataset_models
-
-        # store dataset aliases
-        all_datasets_alias_models = input_dataset_alias_models | 
outlet_dataset_alias_models
-        stored_dataset_alias_models: dict[str, DatasetAliasModel] = {}
-        new_dataset_alias_models: set[DatasetAliasModel] = set()
-        if all_datasets_alias_models:
-            all_dataset_alias_names = {
-                dataset_alias_model.name for dataset_alias_model in 
all_datasets_alias_models
-            }
-
-            stored_dataset_alias_models = {
-                dsa_m.name: dsa_m
-                for dsa_m in session.scalars(
-                    
select(DatasetAliasModel).where(DatasetAliasModel.name.in_(all_dataset_alias_names))
-                ).fetchall()
-            }
-
-            if stored_dataset_alias_models:
-                new_dataset_alias_models = {
-                    dataset_alias_model
-                    for dataset_alias_model in all_datasets_alias_models
-                    if dataset_alias_model.name not in 
stored_dataset_alias_models.keys()
-                }
-            else:
-                new_dataset_alias_models = all_datasets_alias_models
-
-            session.add_all(new_dataset_alias_models)
-        session.flush()
-        stored_dataset_alias_models.update(
-            {
-                dataset_alias_model.name: dataset_alias_model
-                for dataset_alias_model in new_dataset_alias_models
-            }
+        log.info("Sync %s DAGs", len(dags))
+        dags_by_ids = {dag.dag_id: dag for dag in dags}
+        del dags
+
+        orm_dags = collect_orm_dags(dags_by_ids, session=session)
+        orm_dags.update(
+            (dag_id, create_orm_dag(dag, session=session))
+            for dag_id, dag in dags_by_ids.items()
+            if dag_id not in orm_dags
         )
 
-        del new_dataset_alias_models
-        del all_datasets_alias_models
+        update_orm_dags(dags_by_ids, orm_dags, 
processor_subdir=processor_subdir, session=session)
+        DagCode.bulk_sync_to_db((dag.fileloc for dag in dags_by_ids.values()), 
session=session)
 
-        # reconcile dag-schedule-on-dataset and dag-schedule-on-dataset-alias 
references
-        for dag_id, base_dataset_list in dag_references.items():
-            dag_refs_needed = {
-                DagScheduleDatasetReference(
-                    
dataset_id=stored_dataset_models[base_dataset_identifier].id, dag_id=dag_id
-                )
-                if base_dataset_type == "dataset"
-                else DagScheduleDatasetAliasReference(
-                    
alias_id=stored_dataset_alias_models[base_dataset_identifier].id, dag_id=dag_id
-                )
-                for base_dataset_type, base_dataset_identifier in 
base_dataset_list
-            }
+        dataset_op = DatasetModelOperation.collect(dags_by_ids)
 
-            # if isinstance(base_dataset, Dataset)
+        orm_datasets = dataset_op.add_datasets(session=session)
+        orm_dataset_aliases = dataset_op.add_dataset_aliases(session=session)
+        session.flush()  # This populates id so we can create fks in later 
calls.
 
-            dag_refs_stored = (
-                set(existing_dags.get(dag_id).schedule_dataset_references)  # 
type: ignore
-                | 
set(existing_dags.get(dag_id).schedule_dataset_alias_references)  # type: ignore
-                if existing_dags.get(dag_id)
-                else set()
-            )
-            dag_refs_to_add = dag_refs_needed - dag_refs_stored
-            session.bulk_save_objects(dag_refs_to_add)
-            for obj in dag_refs_stored - dag_refs_needed:
-                session.delete(obj)
-
-        existing_task_outlet_refs_dict = defaultdict(set)
-        for dag_id, orm_dag in existing_dags.items():
-            for todr in orm_dag.task_outlet_dataset_references:
-                existing_task_outlet_refs_dict[(dag_id, 
todr.task_id)].add(todr)
-
-        # reconcile task-outlet-dataset references
-        for (dag_id, task_id), uri_list in outlet_references.items():
-            task_refs_needed = {
-                TaskOutletDatasetReference(
-                    dataset_id=stored_dataset_models[uri].id, dag_id=dag_id, 
task_id=task_id
-                )
-                for uri in uri_list
-            }
-            task_refs_stored = existing_task_outlet_refs_dict[(dag_id, 
task_id)]
-            task_refs_to_add = {x for x in task_refs_needed if x not in 
task_refs_stored}
-            session.bulk_save_objects(task_refs_to_add)
-            for obj in task_refs_stored - task_refs_needed:
-                session.delete(obj)
-
-        # Issue SQL/finish "Unit of Work", but let @provide_session commit (or 
if passed a session, let caller
-        # decide when to commit
+        dataset_op.add_dag_dataset_references(orm_dags, orm_datasets, 
session=session)
+        dataset_op.add_dag_dataset_alias_references(orm_dags, 
orm_dataset_aliases, session=session)
+        dataset_op.add_task_dataset_references(orm_dags, orm_datasets, 
session=session)
         session.flush()
 
-    @classmethod
-    def _get_latest_runs_stmt(cls, dags: list[str]) -> Select:
-        """
-        Build a select statement for retrieve the last automated run for each 
dag.
-
-        :param dags: dags to query
-        """
-        if len(dags) == 1:
-            # Index optimized fast path to avoid more complicated & slower 
groupby queryplan
-            existing_dag_id = dags[0]
-            last_automated_runs_subq = (
-                
select(func.max(DagRun.execution_date).label("max_execution_date"))
-                .where(
-                    DagRun.dag_id == existing_dag_id,
-                    DagRun.run_type.in_((DagRunType.BACKFILL_JOB, 
DagRunType.SCHEDULED)),
-                )
-                .scalar_subquery()
-            )
-            query = select(DagRun).where(
-                DagRun.dag_id == existing_dag_id, DagRun.execution_date == 
last_automated_runs_subq
-            )
-        else:
-            last_automated_runs_subq = (
-                select(DagRun.dag_id, 
func.max(DagRun.execution_date).label("max_execution_date"))
-                .where(
-                    DagRun.dag_id.in_(dags),
-                    DagRun.run_type.in_((DagRunType.BACKFILL_JOB, 
DagRunType.SCHEDULED)),
-                )
-                .group_by(DagRun.dag_id)
-                .subquery()
-            )
-            query = select(DagRun).where(
-                DagRun.dag_id == last_automated_runs_subq.c.dag_id,
-                DagRun.execution_date == 
last_automated_runs_subq.c.max_execution_date,
-            )
-        return query.options(
-            load_only(
-                DagRun.dag_id,
-                DagRun.execution_date,
-                DagRun.data_interval_start,
-                DagRun.data_interval_end,
-            )
-        )
-
     @provide_session
     def sync_to_db(self, processor_subdir: str | None = None, 
session=NEW_SESSION):
         """
diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index f93c90b638..954e5ed4d0 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -2925,6 +2925,7 @@ class TaskInstance(Base, LoggingMixin):
                 dataset_obj = DatasetModel(uri=uri)
                 dataset_manager.create_datasets(dataset_models=[dataset_obj], 
session=session)
                 self.log.warning("Created a new %r as it did not exist.", 
dataset_obj)
+                session.flush()
                 dataset_objs_cache[uri] = dataset_obj
 
             for alias in alias_names:
diff --git a/airflow/timetables/base.py b/airflow/timetables/base.py
index ce701794a4..5d97591856 100644
--- a/airflow/timetables/base.py
+++ b/airflow/timetables/base.py
@@ -24,7 +24,7 @@ from airflow.typing_compat import Protocol, runtime_checkable
 if TYPE_CHECKING:
     from pendulum import DateTime
 
-    from airflow.datasets import Dataset
+    from airflow.datasets import Dataset, DatasetAlias
     from airflow.serialization.dag_dependency import DagDependency
     from airflow.utils.types import DagRunType
 
@@ -57,6 +57,9 @@ class _NullDataset(BaseDataset):
     def iter_datasets(self) -> Iterator[tuple[str, Dataset]]:
         return iter(())
 
+    def iter_dataset_aliases(self) -> Iterator[tuple[str, DatasetAlias]]:
+        return iter(())
+
     def iter_dag_dependencies(self, source, target) -> Iterator[DagDependency]:
         return iter(())
 
diff --git a/tests/dag_processing/test_collection.py 
b/tests/dag_processing/test_collection.py
new file mode 100644
index 0000000000..4d5a6736ad
--- /dev/null
+++ b/tests/dag_processing/test_collection.py
@@ -0,0 +1,64 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+import warnings
+
+from sqlalchemy.exc import SAWarning
+
+from airflow.dag_processing.collection import _get_latest_runs_stmt
+
+
+def test_statement_latest_runs_one_dag():
+    with warnings.catch_warnings():
+        warnings.simplefilter("error", category=SAWarning)
+
+        stmt = _get_latest_runs_stmt(["fake-dag"])
+        compiled_stmt = str(stmt.compile())
+        actual = [x.strip() for x in compiled_stmt.splitlines()]
+        expected = [
+            "SELECT dag_run.logical_date, dag_run.id, dag_run.dag_id, "
+            "dag_run.data_interval_start, dag_run.data_interval_end",
+            "FROM dag_run",
+            "WHERE dag_run.dag_id = :dag_id_1 AND dag_run.logical_date = ("
+            "SELECT max(dag_run.logical_date) AS max_execution_date",
+            "FROM dag_run",
+            "WHERE dag_run.dag_id = :dag_id_2 AND dag_run.run_type IN 
(__[POSTCOMPILE_run_type_1]))",
+        ]
+        assert actual == expected, compiled_stmt
+
+
+def test_statement_latest_runs_many_dag():
+    with warnings.catch_warnings():
+        warnings.simplefilter("error", category=SAWarning)
+
+        stmt = _get_latest_runs_stmt(["fake-dag-1", "fake-dag-2"])
+        compiled_stmt = str(stmt.compile())
+        actual = [x.strip() for x in compiled_stmt.splitlines()]
+        expected = [
+            "SELECT dag_run.logical_date, dag_run.id, dag_run.dag_id, "
+            "dag_run.data_interval_start, dag_run.data_interval_end",
+            "FROM dag_run, (SELECT dag_run.dag_id AS dag_id, "
+            "max(dag_run.logical_date) AS max_execution_date",
+            "FROM dag_run",
+            "WHERE dag_run.dag_id IN (__[POSTCOMPILE_dag_id_1]) "
+            "AND dag_run.run_type IN (__[POSTCOMPILE_run_type_1]) GROUP BY 
dag_run.dag_id) AS anon_1",
+            "WHERE dag_run.dag_id = anon_1.dag_id AND dag_run.logical_date = 
anon_1.max_execution_date",
+        ]
+        assert actual == expected, compiled_stmt
diff --git a/tests/datasets/test_dataset.py b/tests/datasets/test_dataset.py
index 940e445669..8221a5aea8 100644
--- a/tests/datasets/test_dataset.py
+++ b/tests/datasets/test_dataset.py
@@ -145,7 +145,7 @@ def test_dataset_iter_dataset_aliases():
         DatasetAll(DatasetAlias("example-alias-5"), Dataset("5")),
     )
     assert list(base_dataset.iter_dataset_aliases()) == [
-        DatasetAlias(f"example-alias-{i}") for i in range(1, 6)
+        (f"example-alias-{i}", DatasetAlias(f"example-alias-{i}")) for i in 
range(1, 6)
     ]
 
 
diff --git a/tests/models/test_dag.py b/tests/models/test_dag.py
index cb6d4d4ed4..093fdcae2f 100644
--- a/tests/models/test_dag.py
+++ b/tests/models/test_dag.py
@@ -23,7 +23,6 @@ import logging
 import os
 import pickle
 import re
-import warnings
 import weakref
 from datetime import timedelta
 from importlib import reload
@@ -37,7 +36,6 @@ import pendulum
 import pytest
 import time_machine
 from sqlalchemy import inspect, select
-from sqlalchemy.exc import SAWarning
 
 from airflow import settings
 from airflow.configuration import conf
@@ -3992,42 +3990,3 @@ class TestTaskClearingSetupTeardownBehavior:
                 Exception, match="Setup tasks must be followed with trigger 
rule ALL_SUCCESS."
             ):
                 dag.validate_setup_teardown()
-
-
-def test_statement_latest_runs_one_dag():
-    with warnings.catch_warnings():
-        warnings.simplefilter("error", category=SAWarning)
-
-        stmt = DAG._get_latest_runs_stmt(dags=["fake-dag"])
-        compiled_stmt = str(stmt.compile())
-        actual = [x.strip() for x in compiled_stmt.splitlines()]
-        expected = [
-            "SELECT dag_run.logical_date, dag_run.id, dag_run.dag_id, "
-            "dag_run.data_interval_start, dag_run.data_interval_end",
-            "FROM dag_run",
-            "WHERE dag_run.dag_id = :dag_id_1 AND dag_run.logical_date = ("
-            "SELECT max(dag_run.logical_date) AS max_execution_date",
-            "FROM dag_run",
-            "WHERE dag_run.dag_id = :dag_id_2 AND dag_run.run_type IN 
(__[POSTCOMPILE_run_type_1]))",
-        ]
-        assert actual == expected, compiled_stmt
-
-
-def test_statement_latest_runs_many_dag():
-    with warnings.catch_warnings():
-        warnings.simplefilter("error", category=SAWarning)
-
-        stmt = DAG._get_latest_runs_stmt(dags=["fake-dag-1", "fake-dag-2"])
-        compiled_stmt = str(stmt.compile())
-        actual = [x.strip() for x in compiled_stmt.splitlines()]
-        expected = [
-            "SELECT dag_run.logical_date, dag_run.id, dag_run.dag_id, "
-            "dag_run.data_interval_start, dag_run.data_interval_end",
-            "FROM dag_run, (SELECT dag_run.dag_id AS dag_id, "
-            "max(dag_run.logical_date) AS max_execution_date",
-            "FROM dag_run",
-            "WHERE dag_run.dag_id IN (__[POSTCOMPILE_dag_id_1]) "
-            "AND dag_run.run_type IN (__[POSTCOMPILE_run_type_1]) GROUP BY 
dag_run.dag_id) AS anon_1",
-            "WHERE dag_run.dag_id = anon_1.dag_id AND dag_run.logical_date = 
anon_1.max_execution_date",
-        ]
-        assert actual == expected, compiled_stmt


Reply via email to