This is an automated email from the ASF dual-hosted git repository.
uranusjr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 8d816fb8c4 Refactor bulk_save_to_db (#42245)
8d816fb8c4 is described below
commit 8d816fb8c455e81b8a186da37c6009316839bfcd
Author: Tzu-ping Chung <[email protected]>
AuthorDate: Thu Sep 19 23:45:43 2024 -0700
Refactor bulk_save_to_db (#42245)
Co-authored-by: Ephraim Anierobi <[email protected]>
---
airflow/dag_processing/collection.py | 408 ++++++++++++++++++++++++++++++++
airflow/datasets/__init__.py | 16 +-
airflow/datasets/manager.py | 2 -
airflow/models/dag.py | 342 ++------------------------
airflow/models/taskinstance.py | 1 +
airflow/timetables/base.py | 5 +-
tests/dag_processing/test_collection.py | 64 +++++
tests/datasets/test_dataset.py | 2 +-
tests/models/test_dag.py | 41 ----
9 files changed, 515 insertions(+), 366 deletions(-)
diff --git a/airflow/dag_processing/collection.py b/airflow/dag_processing/collection.py
new file mode 100644
index 0000000000..3f75e0b23b
--- /dev/null
+++ b/airflow/dag_processing/collection.py
@@ -0,0 +1,408 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""
+Utility code that writes DAGs in bulk into the database.
+
+This should generally only be called by internal methods such as
+``DagBag._sync_to_db``, ``DAG.bulk_write_to_db``.
+
+:meta private:
+"""
+
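+# A rough sketch of the expected call order, mirroring the refactored
+# DAG.bulk_write_to_db further down in this commit (illustration only, not an
+# additional public API):
+#
+#     orm_dags = collect_orm_dags(dags_by_ids, session=session)
+#     # create_orm_dag() fills in any DAG IDs not returned above
+#     update_orm_dags(dags_by_ids, orm_dags, session=session)
+#     dataset_op = DatasetModelOperation.collect(dags_by_ids)
+#     orm_datasets = dataset_op.add_datasets(session=session)
+#     orm_aliases = dataset_op.add_dataset_aliases(session=session)
+#     session.flush()  # populate ids before creating reference rows
+#     dataset_op.add_dag_dataset_references(orm_dags, orm_datasets, session=session)
+#     dataset_op.add_dag_dataset_alias_references(orm_dags, orm_aliases, session=session)
+#     dataset_op.add_task_dataset_references(orm_dags, orm_datasets, session=session)
+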
+from __future__ import annotations
+
+import itertools
+import logging
+from typing import TYPE_CHECKING, NamedTuple
+
+from sqlalchemy import func, select
+from sqlalchemy.orm import joinedload, load_only
+from sqlalchemy.sql import expression
+
+from airflow.datasets import Dataset, DatasetAlias
+from airflow.datasets.manager import dataset_manager
+from airflow.models.dag import DAG, DagModel, DagOwnerAttributes, DagTag
+from airflow.models.dagrun import DagRun
+from airflow.models.dataset import (
+ DagScheduleDatasetAliasReference,
+ DagScheduleDatasetReference,
+ DatasetAliasModel,
+ DatasetModel,
+ TaskOutletDatasetReference,
+)
+from airflow.utils.sqlalchemy import with_row_locks
+from airflow.utils.timezone import utcnow
+from airflow.utils.types import DagRunType
+
+if TYPE_CHECKING:
+ from collections.abc import Collection, Iterable, Iterator
+
+ from sqlalchemy.orm import Session
+ from sqlalchemy.sql import Select
+
+ from airflow.typing_compat import Self
+
+log = logging.getLogger(__name__)
+
+
+def collect_orm_dags(dags: dict[str, DAG], *, session: Session) -> dict[str, DagModel]:
+ """
+ Collect DagModel objects from DAG objects.
+
+ An existing DagModel is fetched if there's a matching ID in the database.
+ Otherwise, a new DagModel is created and added to the session.
+ """
+ stmt = (
+ select(DagModel)
+ .options(joinedload(DagModel.tags, innerjoin=False))
+ .where(DagModel.dag_id.in_(dags))
+ .options(joinedload(DagModel.schedule_dataset_references))
+ .options(joinedload(DagModel.schedule_dataset_alias_references))
+ .options(joinedload(DagModel.task_outlet_dataset_references))
+ )
+ stmt = with_row_locks(stmt, of=DagModel, session=session)
+ existing_orm_dags = {dm.dag_id: dm for dm in session.scalars(stmt).unique()}
+
+ for dag_id, dag in dags.items():
+ if dag_id in existing_orm_dags:
+ continue
+ orm_dag = DagModel(dag_id=dag_id)
+ if dag.is_paused_upon_creation is not None:
+ orm_dag.is_paused = dag.is_paused_upon_creation
+ orm_dag.tags = []
+ log.info("Creating ORM DAG for %s", dag_id)
+ session.add(orm_dag)
+ existing_orm_dags[dag_id] = orm_dag
+
+ return existing_orm_dags
+
+
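+# Build a fresh DagModel for a single DAG that is not yet in the database; the
+# refactored DAG.bulk_write_to_db calls this for IDs collect_orm_dags() did not return.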
+def create_orm_dag(dag: DAG, session: Session) -> DagModel:
+ orm_dag = DagModel(dag_id=dag.dag_id)
+ if dag.is_paused_upon_creation is not None:
+ orm_dag.is_paused = dag.is_paused_upon_creation
+ orm_dag.tags = []
+ log.info("Creating ORM DAG for %s", dag.dag_id)
+ session.add(orm_dag)
+ return orm_dag
+
+
+def _get_latest_runs_stmt(dag_ids: Collection[str]) -> Select:
+ """Build a select statement to retrieve the last automated run for each dag."""
+ if len(dag_ids) == 1: # Index optimized fast path to avoid more complicated & slower groupby queryplan.
+ (dag_id,) = dag_ids
+ last_automated_runs_subq = (
+ select(func.max(DagRun.execution_date).label("max_execution_date"))
+ .where(
+ DagRun.dag_id == dag_id,
+ DagRun.run_type.in_((DagRunType.BACKFILL_JOB, DagRunType.SCHEDULED)),
+ )
+ .scalar_subquery()
+ )
+ query = select(DagRun).where(
+ DagRun.dag_id == dag_id,
+ DagRun.execution_date == last_automated_runs_subq,
+ )
+ else:
+ last_automated_runs_subq = (
+ select(DagRun.dag_id, func.max(DagRun.execution_date).label("max_execution_date"))
+ .where(
+ DagRun.dag_id.in_(dag_ids),
+ DagRun.run_type.in_((DagRunType.BACKFILL_JOB, DagRunType.SCHEDULED)),
+ )
+ .group_by(DagRun.dag_id)
+ .subquery()
+ )
+ query = select(DagRun).where(
+ DagRun.dag_id == last_automated_runs_subq.c.dag_id,
+ DagRun.execution_date == last_automated_runs_subq.c.max_execution_date,
+ )
+ return query.options(
+ load_only(
+ DagRun.dag_id,
+ DagRun.execution_date,
+ DagRun.data_interval_start,
+ DagRun.data_interval_end,
+ )
+ )
+
+
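+# Pre-computed run information consumed by update_orm_dags(): the latest automated
+# (scheduled or backfill) run, and the number of currently active runs, per DAG ID.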
+class _RunInfo(NamedTuple):
+ latest_runs: dict[str, DagRun]
+ num_active_runs: dict[str, int]
+
+ @classmethod
+ def calculate(cls, dags: dict[str, DAG], *, session: Session) -> Self:
+ # Skip these queries entirely if no DAGs can be scheduled to save time.
+ if not any(dag.timetable.can_be_scheduled for dag in dags.values()):
+ return cls({}, {})
+ return cls(
+ {run.dag_id: run for run in session.scalars(_get_latest_runs_stmt(dag_ids=dags))},
+ DagRun.active_runs_of_dags(dag_ids=dags, session=session),
+ )
+
+
+def update_orm_dags(
+ source_dags: dict[str, DAG],
+ target_dags: dict[str, DagModel],
+ *,
+ processor_subdir: str | None = None,
+ session: Session,
+) -> None:
+ """
+ Apply DAG attributes to DagModel objects.
+
+ Objects in ``target_dags`` are modified in-place.
+ """
+ run_info = _RunInfo.calculate(source_dags, session=session)
+
+ for dag_id, dm in sorted(target_dags.items()):
+ dag = source_dags[dag_id]
+ dm.fileloc = dag.fileloc
+ dm.owners = dag.owner
+ dm.is_active = True
+ dm.has_import_errors = False
+ dm.last_parsed_time = utcnow()
+ dm.default_view = dag.default_view
+ dm._dag_display_property_value = dag._dag_display_property_value
+ dm.description = dag.description
+ dm.max_active_tasks = dag.max_active_tasks
+ dm.max_active_runs = dag.max_active_runs
+ dm.max_consecutive_failed_dag_runs = dag.max_consecutive_failed_dag_runs
+ dm.has_task_concurrency_limits = any(
+ t.max_active_tis_per_dag is not None or t.max_active_tis_per_dagrun is not None for t in dag.tasks
+ )
+ dm.timetable_summary = dag.timetable.summary
+ dm.timetable_description = dag.timetable.description
+ dm.dataset_expression = dag.timetable.dataset_condition.as_expression()
+ dm.processor_subdir = processor_subdir
+
+ last_automated_run: DagRun | None = run_info.latest_runs.get(dag.dag_id)
+ if last_automated_run is None:
+ last_automated_data_interval = None
+ else:
+ last_automated_data_interval = dag.get_run_data_interval(last_automated_run)
+ if run_info.num_active_runs.get(dag.dag_id, 0) >= dm.max_active_runs:
+ dm.next_dagrun_create_after = None
+ else:
+ dm.calculate_dagrun_date_fields(dag, last_automated_data_interval)
+
+ if not dag.timetable.dataset_condition:
+ dm.schedule_dataset_references = []
+ dm.schedule_dataset_alias_references = []
+ # FIXME: STORE NEW REFERENCES.
+
+ dag_tags = set(dag.tags or ())
+ for orm_tag in (dm_tags := list(dm.tags or [])):
+ if orm_tag.name not in dag_tags:
+ session.delete(orm_tag)
+ dm.tags.remove(orm_tag)
+ orm_tag_names = {t.name for t in dm_tags}
+ for dag_tag in dag_tags:
+ if dag_tag not in orm_tag_names:
+ dag_tag_orm = DagTag(name=dag_tag, dag_id=dag.dag_id)
+ dm.tags.append(dag_tag_orm)
+ session.add(dag_tag_orm)
+
+ dm_links = dm.dag_owner_links or []
+ for dm_link in dm_links:
+ if dm_link not in dag.owner_links:
+ session.delete(dm_link)
+ for owner_name, owner_link in dag.owner_links.items():
+ dag_owner_orm = DagOwnerAttributes(dag_id=dag.dag_id, owner=owner_name, link=owner_link)
+ session.add(dag_owner_orm)
+
+
+def _find_all_datasets(dags: Iterable[DAG]) -> Iterator[Dataset]:
+ for dag in dags:
+ for _, dataset in dag.timetable.dataset_condition.iter_datasets():
+ yield dataset
+ for task in dag.task_dict.values():
+ for obj in itertools.chain(task.inlets, task.outlets):
+ if isinstance(obj, Dataset):
+ yield obj
+
+
+def _find_all_dataset_aliases(dags: Iterable[DAG]) -> Iterator[DatasetAlias]:
+ for dag in dags:
+ for _, alias in dag.timetable.dataset_condition.iter_dataset_aliases():
+ yield alias
+ for task in dag.task_dict.values():
+ for obj in itertools.chain(task.inlets, task.outlets):
+ if isinstance(obj, DatasetAlias):
+ yield obj
+
+
+class DatasetModelOperation(NamedTuple):
+ """Collect dataset/alias objects from DAGs and perform database operations for them."""
+
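+ # The three reference mappings are keyed by dag_id; datasets and dataset_aliases
+ # are keyed by URI and name respectively, so duplicates across DAGs collapse.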
+ schedule_dataset_references: dict[str, list[Dataset]]
+ schedule_dataset_alias_references: dict[str, list[DatasetAlias]]
+ outlet_references: dict[str, list[tuple[str, Dataset]]]
+ datasets: dict[str, Dataset]
+ dataset_aliases: dict[str, DatasetAlias]
+
+ @classmethod
+ def collect(cls, dags: dict[str, DAG]) -> Self:
+ coll = cls(
+ schedule_dataset_references={
+ dag_id: [dataset for _, dataset in dag.timetable.dataset_condition.iter_datasets()]
+ for dag_id, dag in dags.items()
+ },
+ schedule_dataset_alias_references={
+ dag_id: [alias for _, alias in dag.timetable.dataset_condition.iter_dataset_aliases()]
+ for dag_id, dag in dags.items()
+ },
+ outlet_references={
+ dag_id: [
+ (task_id, outlet)
+ for task_id, task in dag.task_dict.items()
+ for outlet in task.outlets
+ if isinstance(outlet, Dataset)
+ ]
+ for dag_id, dag in dags.items()
+ },
+ datasets={dataset.uri: dataset for dataset in _find_all_datasets(dags.values())},
+ dataset_aliases={alias.name: alias for alias in _find_all_dataset_aliases(dags.values())},
+ )
+ return coll
+
+ def add_datasets(self, *, session: Session) -> dict[str, DatasetModel]:
+ # Optimization: skip all database calls if no datasets were collected.
+ if not self.datasets:
+ return {}
+ orm_datasets: dict[str, DatasetModel] = {
+ dm.uri: dm
+ for dm in session.scalars(select(DatasetModel).where(DatasetModel.uri.in_(self.datasets)))
+ }
+
+ def _resolve_dataset_addition() -> Iterator[DatasetModel]:
+ for uri, dataset in self.datasets.items():
+ try:
+ dm = orm_datasets[uri]
+ except KeyError:
+ dm = orm_datasets[uri] = DatasetModel.from_public(dataset)
+ yield dm
+ else:
+ # The orphaned flag was bulk-set to True before parsing, so we
+ # don't need to handle rows in the db without a public entry.
+ dm.is_orphaned = expression.false()
+ dm.extra = dataset.extra
+
+ dataset_manager.create_datasets(list(_resolve_dataset_addition()), session=session)
+ return orm_datasets
+
+ def add_dataset_aliases(self, *, session: Session) -> dict[str, DatasetAliasModel]:
+ # Optimization: skip all database calls if no dataset aliases were collected.
+ if not self.dataset_aliases:
+ return {}
+ orm_aliases: dict[str, DatasetAliasModel] = {
+ da.name: da
+ for da in session.scalars(
+ select(DatasetAliasModel).where(DatasetAliasModel.name.in_(self.dataset_aliases))
+ )
+ }
+ for name, alias in self.dataset_aliases.items():
+ try:
+ da = orm_aliases[name]
+ except KeyError:
+ da = orm_aliases[name] = DatasetAliasModel.from_public(alias)
+ session.add(da)
+ return orm_aliases
+
+ def add_dag_dataset_references(
+ self,
+ dags: dict[str, DagModel],
+ datasets: dict[str, DatasetModel],
+ *,
+ session: Session,
+ ) -> None:
+ # Optimization: No datasets means there are no references to update.
+ if not datasets:
+ return
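+ # For each DAG: drop reference rows whose dataset is no longer referenced,
+ # then bulk-insert rows for any newly referenced datasets.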
+ for dag_id, references in self.schedule_dataset_references.items():
+ # Optimization: no references at all; this is faster than repeated delete().
+ if not references:
+ dags[dag_id].schedule_dataset_references = []
+ continue
+ referenced_dataset_ids = {dataset.id for dataset in (datasets[r.uri] for r in references)}
+ orm_refs = {r.dataset_id: r for r in dags[dag_id].schedule_dataset_references}
+ for dataset_id, ref in orm_refs.items():
+ if dataset_id not in referenced_dataset_ids:
+ session.delete(ref)
+ session.bulk_save_objects(
+ DagScheduleDatasetReference(dataset_id=dataset_id, dag_id=dag_id)
+ for dataset_id in referenced_dataset_ids
+ if dataset_id not in orm_refs
+ )
+
+ def add_dag_dataset_alias_references(
+ self,
+ dags: dict[str, DagModel],
+ aliases: dict[str, DatasetAliasModel],
+ *,
+ session: Session,
+ ) -> None:
+ # Optimization: No aliases means there are no references to update.
+ if not aliases:
+ return
+ for dag_id, references in self.schedule_dataset_alias_references.items():
+ # Optimization: no references at all; this is faster than repeated delete().
+ if not references:
+ dags[dag_id].schedule_dataset_alias_references = []
+ continue
+ referenced_alias_ids = {alias.id for alias in (aliases[r.name] for r in references)}
+ orm_refs = {a.alias_id: a for a in dags[dag_id].schedule_dataset_alias_references}
+ for alias_id, ref in orm_refs.items():
+ if alias_id not in referenced_alias_ids:
+ session.delete(ref)
+ session.bulk_save_objects(
+ DagScheduleDatasetAliasReference(alias_id=alias_id, dag_id=dag_id)
+ for alias_id in referenced_alias_ids
+ if alias_id not in orm_refs
+ )
+
+ def add_task_dataset_references(
+ self,
+ dags: dict[str, DagModel],
+ datasets: dict[str, DatasetModel],
+ *,
+ session: Session,
+ ) -> None:
+ # Optimization: No datasets means there are no references to update.
+ if not datasets:
+ return
+ for dag_id, references in self.outlet_references.items():
+ # Optimization: no references at all; this is faster than repeated delete().
+ if not references:
+ dags[dag_id].task_outlet_dataset_references = []
+ continue
+ referenced_outlets = {
+ (task_id, dataset.id)
+ for task_id, dataset in ((task_id, datasets[d.uri]) for task_id, d in references)
+ }
+ orm_refs = {(r.task_id, r.dataset_id): r for r in dags[dag_id].task_outlet_dataset_references}
+ for key, ref in orm_refs.items():
+ if key not in referenced_outlets:
+ session.delete(ref)
+ session.bulk_save_objects(
+ TaskOutletDatasetReference(dataset_id=dataset_id, dag_id=dag_id, task_id=task_id)
+ for task_id, dataset_id in referenced_outlets
+ if (task_id, dataset_id) not in orm_refs
+ )
diff --git a/airflow/datasets/__init__.py b/airflow/datasets/__init__.py
index cd57078095..6f7ae99ff7 100644
--- a/airflow/datasets/__init__.py
+++ b/airflow/datasets/__init__.py
@@ -194,7 +194,7 @@ class BaseDataset:
def iter_datasets(self) -> Iterator[tuple[str, Dataset]]:
raise NotImplementedError
- def iter_dataset_aliases(self) -> Iterator[DatasetAlias]:
+ def iter_dataset_aliases(self) -> Iterator[tuple[str, DatasetAlias]]:
raise NotImplementedError
def iter_dag_dependencies(self, *, source: str, target: str) -> Iterator[DagDependency]:
@@ -212,6 +212,12 @@ class DatasetAlias(BaseDataset):
name: str
+ def iter_datasets(self) -> Iterator[tuple[str, Dataset]]:
+ return iter(())
+
+ def iter_dataset_aliases(self) -> Iterator[tuple[str, DatasetAlias]]:
+ yield self.name, self
+
def iter_dag_dependencies(self, *, source: str, target: str) -> Iterator[DagDependency]:
"""
Iterate a dataset alias as dag dependency.
@@ -294,7 +300,7 @@ class Dataset(os.PathLike, BaseDataset):
def iter_datasets(self) -> Iterator[tuple[str, Dataset]]:
yield self.uri, self
- def iter_dataset_aliases(self) -> Iterator[DatasetAlias]:
+ def iter_dataset_aliases(self) -> Iterator[tuple[str, DatasetAlias]]:
return iter(())
def evaluate(self, statuses: dict[str, bool]) -> bool:
@@ -339,7 +345,7 @@ class _DatasetBooleanCondition(BaseDataset):
yield k, v
seen.add(k)
- def iter_dataset_aliases(self) -> Iterator[DatasetAlias]:
+ def iter_dataset_aliases(self) -> Iterator[tuple[str, DatasetAlias]]:
"""Filter dataest aliases in the condition."""
for o in self.objects:
yield from o.iter_dataset_aliases()
@@ -399,8 +405,8 @@ class _DatasetAliasCondition(DatasetAny):
"""
return {"alias": self.name}
- def iter_dataset_aliases(self) -> Iterator[DatasetAlias]:
- yield DatasetAlias(self.name)
+ def iter_dataset_aliases(self) -> Iterator[tuple[str, DatasetAlias]]:
+ yield self.name, DatasetAlias(self.name)
def iter_dag_dependencies(self, *, source: str = "", target: str = "") -> Iterator[DagDependency]:
"""
diff --git a/airflow/datasets/manager.py b/airflow/datasets/manager.py
index 058eef6ab8..19f6913fff 100644
--- a/airflow/datasets/manager.py
+++ b/airflow/datasets/manager.py
@@ -62,8 +62,6 @@ class DatasetManager(LoggingMixin):
"""Create new datasets."""
for dataset_model in dataset_models:
session.add(dataset_model)
- session.flush()
-
for dataset_model in dataset_models:
self.notify_dataset_created(dataset=Dataset(uri=dataset_model.uri, extra=dataset_model.extra))
diff --git a/airflow/models/dag.py b/airflow/models/dag.py
index 6545293ccf..6447f7be15 100644
--- a/airflow/models/dag.py
+++ b/airflow/models/dag.py
@@ -74,7 +74,7 @@ from sqlalchemy import (
)
from sqlalchemy.ext.associationproxy import association_proxy
from sqlalchemy.ext.hybrid import hybrid_property
-from sqlalchemy.orm import backref, joinedload, load_only, relationship
+from sqlalchemy.orm import backref, relationship
from sqlalchemy.sql import Select, expression
import airflow.templates
@@ -82,7 +82,6 @@ from airflow import settings, utils
from airflow.api_internal.internal_api_call import internal_api_call
from airflow.configuration import conf as airflow_conf, secrets_backend_list
from airflow.datasets import BaseDataset, Dataset, DatasetAlias, DatasetAll
-from airflow.datasets.manager import dataset_manager
from airflow.exceptions import (
AirflowException,
DuplicateTaskIdFound,
@@ -100,11 +99,7 @@ from airflow.models.baseoperator import BaseOperator
from airflow.models.dagcode import DagCode
from airflow.models.dagpickle import DagPickle
from airflow.models.dagrun import RUN_ID_REGEX, DagRun
-from airflow.models.dataset import (
- DatasetAliasModel,
- DatasetDagRunQueue,
- DatasetModel,
-)
+from airflow.models.dataset import DatasetDagRunQueue
from airflow.models.param import DagParam, ParamsDict
from airflow.models.taskinstance import (
Context,
@@ -2637,7 +2632,7 @@ class DAG(LoggingMixin):
cls,
dags: Collection[DAG],
processor_subdir: str | None = None,
- session=NEW_SESSION,
+ session: Session = NEW_SESSION,
):
"""
Ensure the DagModel rows for the given dags are up-to-date in the dag table in the DB.
@@ -2648,323 +2643,38 @@ class DAG(LoggingMixin):
if not dags:
return
- log.info("Sync %s DAGs", len(dags))
- dag_by_ids = {dag.dag_id: dag for dag in dags}
-
- dag_ids = set(dag_by_ids)
- query = (
- select(DagModel)
- .options(joinedload(DagModel.tags, innerjoin=False))
- .where(DagModel.dag_id.in_(dag_ids))
- .options(joinedload(DagModel.schedule_dataset_references))
- .options(joinedload(DagModel.schedule_dataset_alias_references))
- .options(joinedload(DagModel.task_outlet_dataset_references))
- )
- query = with_row_locks(query, of=DagModel, session=session)
- orm_dags: list[DagModel] = session.scalars(query).unique().all()
- existing_dags: dict[str, DagModel] = {x.dag_id: x for x in orm_dags}
- missing_dag_ids = dag_ids.difference(existing_dags.keys())
-
- for missing_dag_id in missing_dag_ids:
- orm_dag = DagModel(dag_id=missing_dag_id)
- dag = dag_by_ids[missing_dag_id]
- if dag.is_paused_upon_creation is not None:
- orm_dag.is_paused = dag.is_paused_upon_creation
- orm_dag.tags = []
- log.info("Creating ORM DAG for %s", dag.dag_id)
- session.add(orm_dag)
- orm_dags.append(orm_dag)
-
- latest_runs: dict[str, DagRun] = {}
- num_active_runs: dict[str, int] = {}
- # Skip these queries entirely if no DAGs can be scheduled to save time.
- if any(dag.timetable.can_be_scheduled for dag in dags):
- # Get the latest automated dag run for each existing dag as a single query (avoid n+1 query)
- query = cls._get_latest_runs_stmt(dags=list(existing_dags.keys()))
- latest_runs = {run.dag_id: run for run in session.scalars(query)}
-
- # Get number of active dagruns for all dags we are processing as a single query.
- num_active_runs = DagRun.active_runs_of_dags(dag_ids=existing_dags, session=session)
-
- filelocs = []
-
- for orm_dag in sorted(orm_dags, key=lambda d: d.dag_id):
- dag = dag_by_ids[orm_dag.dag_id]
- filelocs.append(dag.fileloc)
- orm_dag.fileloc = dag.fileloc
- orm_dag.owners = dag.owner
- orm_dag.is_active = True
- orm_dag.has_import_errors = False
- orm_dag.last_parsed_time = timezone.utcnow()
- orm_dag.default_view = dag.default_view
- orm_dag._dag_display_property_value = dag._dag_display_property_value
- orm_dag.description = dag.description
- orm_dag.max_active_tasks = dag.max_active_tasks
- orm_dag.max_active_runs = dag.max_active_runs
- orm_dag.max_consecutive_failed_dag_runs = dag.max_consecutive_failed_dag_runs
- orm_dag.has_task_concurrency_limits = any(
- t.max_active_tis_per_dag is not None or t.max_active_tis_per_dagrun is not None
- for t in dag.tasks
- )
- orm_dag.timetable_summary = dag.timetable.summary
- orm_dag.timetable_description = dag.timetable.description
- orm_dag.dataset_expression = dag.timetable.dataset_condition.as_expression()
-
- orm_dag.processor_subdir = processor_subdir
-
- last_automated_run: DagRun | None = latest_runs.get(dag.dag_id)
- if last_automated_run is None:
- last_automated_data_interval = None
- else:
- last_automated_data_interval = dag.get_run_data_interval(last_automated_run)
- if num_active_runs.get(dag.dag_id, 0) >= orm_dag.max_active_runs:
- orm_dag.next_dagrun_create_after = None
- else:
- orm_dag.calculate_dagrun_date_fields(dag, last_automated_data_interval)
-
- dag_tags = set(dag.tags or {})
- orm_dag_tags = list(orm_dag.tags or [])
- for orm_tag in orm_dag_tags:
- if orm_tag.name not in dag_tags:
- session.delete(orm_tag)
- orm_dag.tags.remove(orm_tag)
- orm_tag_names = {t.name for t in orm_dag_tags}
- for dag_tag in dag_tags:
- if dag_tag not in orm_tag_names:
- dag_tag_orm = DagTag(name=dag_tag, dag_id=dag.dag_id)
- orm_dag.tags.append(dag_tag_orm)
- session.add(dag_tag_orm)
-
- orm_dag_links = orm_dag.dag_owner_links or []
- for orm_dag_link in orm_dag_links:
- if orm_dag_link not in dag.owner_links:
- session.delete(orm_dag_link)
- for owner_name, owner_link in dag.owner_links.items():
- dag_owner_orm = DagOwnerAttributes(dag_id=dag.dag_id, owner=owner_name, link=owner_link)
- session.add(dag_owner_orm)
-
- DagCode.bulk_sync_to_db(filelocs, session=session)
-
- from airflow.datasets import Dataset
- from airflow.models.dataset import (
- DagScheduleDatasetAliasReference,
- DagScheduleDatasetReference,
- DatasetModel,
- TaskOutletDatasetReference,
+ from airflow.dag_processing.collection import (
+ DatasetModelOperation,
+ collect_orm_dags,
+ create_orm_dag,
+ update_orm_dags,
)
- dag_references: dict[str, set[tuple[Literal["dataset", "dataset-alias"], str]]] = defaultdict(set)
- outlet_references = defaultdict(set)
- # We can't use a set here as we want to preserve order
- outlet_dataset_models: dict[DatasetModel, None] = {}
- input_dataset_models: dict[DatasetModel, None] = {}
- outlet_dataset_alias_models: set[DatasetAliasModel] = set()
- input_dataset_alias_models: set[DatasetAliasModel] = set()
-
- # here we go through dags and tasks to check for dataset references
- # if there are now None and previously there were some, we delete them
- # if there are now *any*, we add them to the above data structures, and
- # later we'll persist them to the database.
- for dag in dags:
- curr_orm_dag = existing_dags.get(dag.dag_id)
- if not (dataset_condition := dag.timetable.dataset_condition):
- if curr_orm_dag:
- if curr_orm_dag.schedule_dataset_references:
- curr_orm_dag.schedule_dataset_references = []
- if curr_orm_dag.schedule_dataset_alias_references:
- curr_orm_dag.schedule_dataset_alias_references = []
- else:
- for _, dataset in dataset_condition.iter_datasets():
- dag_references[dag.dag_id].add(("dataset", dataset.uri))
- input_dataset_models[DatasetModel.from_public(dataset)] = None
-
- for dataset_alias in dataset_condition.iter_dataset_aliases():
- dag_references[dag.dag_id].add(("dataset-alias", dataset_alias.name))
- input_dataset_alias_models.add(DatasetAliasModel.from_public(dataset_alias))
-
- curr_outlet_references = curr_orm_dag and curr_orm_dag.task_outlet_dataset_references
- for task in dag.tasks:
- dataset_outlets: list[Dataset] = []
- dataset_alias_outlets: list[DatasetAlias] = []
- for outlet in task.outlets:
- if isinstance(outlet, Dataset):
- dataset_outlets.append(outlet)
- elif isinstance(outlet, DatasetAlias):
- dataset_alias_outlets.append(outlet)
-
- if not dataset_outlets:
- if curr_outlet_references:
- this_task_outlet_refs = [
- x
- for x in curr_outlet_references
- if x.dag_id == dag.dag_id and x.task_id == task.task_id
- ]
- for ref in this_task_outlet_refs:
- curr_outlet_references.remove(ref)
-
- for d in dataset_outlets:
- outlet_dataset_models[DatasetModel.from_public(d)] = None
- outlet_references[(task.dag_id, task.task_id)].add(d.uri)
-
- for d_a in dataset_alias_outlets:
- outlet_dataset_alias_models.add(DatasetAliasModel.from_public(d_a))
-
- all_dataset_models = outlet_dataset_models
- all_dataset_models.update(input_dataset_models)
-
- # store datasets
- stored_dataset_models: dict[str, DatasetModel] = {}
- new_dataset_models: list[DatasetModel] = []
- for dataset in all_dataset_models:
- stored_dataset_model = session.scalar(
- select(DatasetModel).where(DatasetModel.uri == dataset.uri).limit(1)
- )
- if stored_dataset_model:
- # Some datasets may have been previously unreferenced, and therefore orphaned by the
- # scheduler. But if we're here, then we have found that dataset again in our DAGs, which
- # means that it is no longer an orphan, so set is_orphaned to False.
- stored_dataset_model.is_orphaned = expression.false()
- stored_dataset_models[stored_dataset_model.uri] = stored_dataset_model
- else:
- new_dataset_models.append(dataset)
- dataset_manager.create_datasets(dataset_models=new_dataset_models, session=session)
- stored_dataset_models.update(
- {dataset_model.uri: dataset_model for dataset_model in new_dataset_models}
- )
-
- del new_dataset_models
- del all_dataset_models
-
- # store dataset aliases
- all_datasets_alias_models = input_dataset_alias_models | outlet_dataset_alias_models
- stored_dataset_alias_models: dict[str, DatasetAliasModel] = {}
- new_dataset_alias_models: set[DatasetAliasModel] = set()
- if all_datasets_alias_models:
- all_dataset_alias_names = {
- dataset_alias_model.name for dataset_alias_model in all_datasets_alias_models
- }
-
- stored_dataset_alias_models = {
- dsa_m.name: dsa_m
- for dsa_m in session.scalars(
- select(DatasetAliasModel).where(DatasetAliasModel.name.in_(all_dataset_alias_names))
- ).fetchall()
- }
-
- if stored_dataset_alias_models:
- new_dataset_alias_models = {
- dataset_alias_model
- for dataset_alias_model in all_datasets_alias_models
- if dataset_alias_model.name not in stored_dataset_alias_models.keys()
- }
- else:
- new_dataset_alias_models = all_datasets_alias_models
-
- session.add_all(new_dataset_alias_models)
- session.flush()
- stored_dataset_alias_models.update(
- {
- dataset_alias_model.name: dataset_alias_model
- for dataset_alias_model in new_dataset_alias_models
- }
+ log.info("Sync %s DAGs", len(dags))
+ dags_by_ids = {dag.dag_id: dag for dag in dags}
+ del dags
+
+ orm_dags = collect_orm_dags(dags_by_ids, session=session)
+ orm_dags.update(
+ (dag_id, create_orm_dag(dag, session=session))
+ for dag_id, dag in dags_by_ids.items()
+ if dag_id not in orm_dags
)
- del new_dataset_alias_models
- del all_datasets_alias_models
+ update_orm_dags(dags_by_ids, orm_dags, processor_subdir=processor_subdir, session=session)
+ DagCode.bulk_sync_to_db((dag.fileloc for dag in dags_by_ids.values()), session=session)
- # reconcile dag-schedule-on-dataset and dag-schedule-on-dataset-alias references
- for dag_id, base_dataset_list in dag_references.items():
- dag_refs_needed = {
- DagScheduleDatasetReference(
- dataset_id=stored_dataset_models[base_dataset_identifier].id, dag_id=dag_id
- )
- if base_dataset_type == "dataset"
- else DagScheduleDatasetAliasReference(
- alias_id=stored_dataset_alias_models[base_dataset_identifier].id, dag_id=dag_id
- )
- for base_dataset_type, base_dataset_identifier in base_dataset_list
- }
+ dataset_op = DatasetModelOperation.collect(dags_by_ids)
- # if isinstance(base_dataset, Dataset)
+ orm_datasets = dataset_op.add_datasets(session=session)
+ orm_dataset_aliases = dataset_op.add_dataset_aliases(session=session)
+ session.flush() # This populates id so we can create fks in later calls.
- dag_refs_stored = (
- set(existing_dags.get(dag_id).schedule_dataset_references) # type: ignore
- | set(existing_dags.get(dag_id).schedule_dataset_alias_references) # type: ignore
- if existing_dags.get(dag_id)
- else set()
- )
- dag_refs_to_add = dag_refs_needed - dag_refs_stored
- session.bulk_save_objects(dag_refs_to_add)
- for obj in dag_refs_stored - dag_refs_needed:
- session.delete(obj)
-
- existing_task_outlet_refs_dict = defaultdict(set)
- for dag_id, orm_dag in existing_dags.items():
- for todr in orm_dag.task_outlet_dataset_references:
- existing_task_outlet_refs_dict[(dag_id, todr.task_id)].add(todr)
-
- # reconcile task-outlet-dataset references
- for (dag_id, task_id), uri_list in outlet_references.items():
- task_refs_needed = {
- TaskOutletDatasetReference(
- dataset_id=stored_dataset_models[uri].id, dag_id=dag_id, task_id=task_id
- )
- for uri in uri_list
- }
- task_refs_stored = existing_task_outlet_refs_dict[(dag_id, task_id)]
- task_refs_to_add = {x for x in task_refs_needed if x not in task_refs_stored}
- session.bulk_save_objects(task_refs_to_add)
- for obj in task_refs_stored - task_refs_needed:
- session.delete(obj)
-
- # Issue SQL/finish "Unit of Work", but let @provide_session commit (or if passed a session, let caller
- # decide when to commit
+ dataset_op.add_dag_dataset_references(orm_dags, orm_datasets, session=session)
+ dataset_op.add_dag_dataset_alias_references(orm_dags, orm_dataset_aliases, session=session)
+ dataset_op.add_task_dataset_references(orm_dags, orm_datasets, session=session)
session.flush()
- @classmethod
- def _get_latest_runs_stmt(cls, dags: list[str]) -> Select:
- """
- Build a select statement for retrieve the last automated run for each dag.
-
- :param dags: dags to query
- """
- if len(dags) == 1:
- # Index optimized fast path to avoid more complicated & slower groupby queryplan
- existing_dag_id = dags[0]
- last_automated_runs_subq = (
- select(func.max(DagRun.execution_date).label("max_execution_date"))
- .where(
- DagRun.dag_id == existing_dag_id,
- DagRun.run_type.in_((DagRunType.BACKFILL_JOB, DagRunType.SCHEDULED)),
- )
- .scalar_subquery()
- )
- query = select(DagRun).where(
- DagRun.dag_id == existing_dag_id, DagRun.execution_date == last_automated_runs_subq
- )
- else:
- last_automated_runs_subq = (
- select(DagRun.dag_id, func.max(DagRun.execution_date).label("max_execution_date"))
- .where(
- DagRun.dag_id.in_(dags),
- DagRun.run_type.in_((DagRunType.BACKFILL_JOB, DagRunType.SCHEDULED)),
- )
- .group_by(DagRun.dag_id)
- .subquery()
- )
- query = select(DagRun).where(
- DagRun.dag_id == last_automated_runs_subq.c.dag_id,
- DagRun.execution_date == last_automated_runs_subq.c.max_execution_date,
- )
- return query.options(
- load_only(
- DagRun.dag_id,
- DagRun.execution_date,
- DagRun.data_interval_start,
- DagRun.data_interval_end,
- )
- )
-
@provide_session
def sync_to_db(self, processor_subdir: str | None = None, session=NEW_SESSION):
"""
diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index f93c90b638..954e5ed4d0 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -2925,6 +2925,7 @@ class TaskInstance(Base, LoggingMixin):
dataset_obj = DatasetModel(uri=uri)
dataset_manager.create_datasets(dataset_models=[dataset_obj], session=session)
self.log.warning("Created a new %r as it did not exist.", dataset_obj)
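+ # create_datasets() no longer flushes (see the manager.py change above); flush here
+ # so the freshly added dataset row is written before it is used below.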
+ session.flush()
dataset_objs_cache[uri] = dataset_obj
for alias in alias_names:
diff --git a/airflow/timetables/base.py b/airflow/timetables/base.py
index ce701794a4..5d97591856 100644
--- a/airflow/timetables/base.py
+++ b/airflow/timetables/base.py
@@ -24,7 +24,7 @@ from airflow.typing_compat import Protocol, runtime_checkable
if TYPE_CHECKING:
from pendulum import DateTime
- from airflow.datasets import Dataset
+ from airflow.datasets import Dataset, DatasetAlias
from airflow.serialization.dag_dependency import DagDependency
from airflow.utils.types import DagRunType
@@ -57,6 +57,9 @@ class _NullDataset(BaseDataset):
def iter_datasets(self) -> Iterator[tuple[str, Dataset]]:
return iter(())
+ def iter_dataset_aliases(self) -> Iterator[tuple[str, DatasetAlias]]:
+ return iter(())
+
def iter_dag_dependencies(self, source, target) -> Iterator[DagDependency]:
return iter(())
diff --git a/tests/dag_processing/test_collection.py b/tests/dag_processing/test_collection.py
new file mode 100644
index 0000000000..4d5a6736ad
--- /dev/null
+++ b/tests/dag_processing/test_collection.py
@@ -0,0 +1,64 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+import warnings
+
+from sqlalchemy.exc import SAWarning
+
+from airflow.dag_processing.collection import _get_latest_runs_stmt
+
+
+def test_statement_latest_runs_one_dag():
+ with warnings.catch_warnings():
+ warnings.simplefilter("error", category=SAWarning)
+
+ stmt = _get_latest_runs_stmt(["fake-dag"])
+ compiled_stmt = str(stmt.compile())
+ actual = [x.strip() for x in compiled_stmt.splitlines()]
+ expected = [
+ "SELECT dag_run.logical_date, dag_run.id, dag_run.dag_id, "
+ "dag_run.data_interval_start, dag_run.data_interval_end",
+ "FROM dag_run",
+ "WHERE dag_run.dag_id = :dag_id_1 AND dag_run.logical_date = ("
+ "SELECT max(dag_run.logical_date) AS max_execution_date",
+ "FROM dag_run",
+ "WHERE dag_run.dag_id = :dag_id_2 AND dag_run.run_type IN (__[POSTCOMPILE_run_type_1]))",
+ ]
+ assert actual == expected, compiled_stmt
+
+
+def test_statement_latest_runs_many_dag():
+ with warnings.catch_warnings():
+ warnings.simplefilter("error", category=SAWarning)
+
+ stmt = _get_latest_runs_stmt(["fake-dag-1", "fake-dag-2"])
+ compiled_stmt = str(stmt.compile())
+ actual = [x.strip() for x in compiled_stmt.splitlines()]
+ expected = [
+ "SELECT dag_run.logical_date, dag_run.id, dag_run.dag_id, "
+ "dag_run.data_interval_start, dag_run.data_interval_end",
+ "FROM dag_run, (SELECT dag_run.dag_id AS dag_id, "
+ "max(dag_run.logical_date) AS max_execution_date",
+ "FROM dag_run",
+ "WHERE dag_run.dag_id IN (__[POSTCOMPILE_dag_id_1]) "
+ "AND dag_run.run_type IN (__[POSTCOMPILE_run_type_1]) GROUP BY dag_run.dag_id) AS anon_1",
+ "WHERE dag_run.dag_id = anon_1.dag_id AND dag_run.logical_date = anon_1.max_execution_date",
+ ]
+ assert actual == expected, compiled_stmt
diff --git a/tests/datasets/test_dataset.py b/tests/datasets/test_dataset.py
index 940e445669..8221a5aea8 100644
--- a/tests/datasets/test_dataset.py
+++ b/tests/datasets/test_dataset.py
@@ -145,7 +145,7 @@ def test_dataset_iter_dataset_aliases():
DatasetAll(DatasetAlias("example-alias-5"), Dataset("5")),
)
assert list(base_dataset.iter_dataset_aliases()) == [
- DatasetAlias(f"example-alias-{i}") for i in range(1, 6)
+ (f"example-alias-{i}", DatasetAlias(f"example-alias-{i}")) for i in range(1, 6)
]
diff --git a/tests/models/test_dag.py b/tests/models/test_dag.py
index cb6d4d4ed4..093fdcae2f 100644
--- a/tests/models/test_dag.py
+++ b/tests/models/test_dag.py
@@ -23,7 +23,6 @@ import logging
import os
import pickle
import re
-import warnings
import weakref
from datetime import timedelta
from importlib import reload
@@ -37,7 +36,6 @@ import pendulum
import pytest
import time_machine
from sqlalchemy import inspect, select
-from sqlalchemy.exc import SAWarning
from airflow import settings
from airflow.configuration import conf
@@ -3992,42 +3990,3 @@ class TestTaskClearingSetupTeardownBehavior:
Exception, match="Setup tasks must be followed with trigger rule ALL_SUCCESS."
):
dag.validate_setup_teardown()
-
-
-def test_statement_latest_runs_one_dag():
- with warnings.catch_warnings():
- warnings.simplefilter("error", category=SAWarning)
-
- stmt = DAG._get_latest_runs_stmt(dags=["fake-dag"])
- compiled_stmt = str(stmt.compile())
- actual = [x.strip() for x in compiled_stmt.splitlines()]
- expected = [
- "SELECT dag_run.logical_date, dag_run.id, dag_run.dag_id, "
- "dag_run.data_interval_start, dag_run.data_interval_end",
- "FROM dag_run",
- "WHERE dag_run.dag_id = :dag_id_1 AND dag_run.logical_date = ("
- "SELECT max(dag_run.logical_date) AS max_execution_date",
- "FROM dag_run",
- "WHERE dag_run.dag_id = :dag_id_2 AND dag_run.run_type IN (__[POSTCOMPILE_run_type_1]))",
- ]
- assert actual == expected, compiled_stmt
-
-
-def test_statement_latest_runs_many_dag():
- with warnings.catch_warnings():
- warnings.simplefilter("error", category=SAWarning)
-
- stmt = DAG._get_latest_runs_stmt(dags=["fake-dag-1", "fake-dag-2"])
- compiled_stmt = str(stmt.compile())
- actual = [x.strip() for x in compiled_stmt.splitlines()]
- expected = [
- "SELECT dag_run.logical_date, dag_run.id, dag_run.dag_id, "
- "dag_run.data_interval_start, dag_run.data_interval_end",
- "FROM dag_run, (SELECT dag_run.dag_id AS dag_id, "
- "max(dag_run.logical_date) AS max_execution_date",
- "FROM dag_run",
- "WHERE dag_run.dag_id IN (__[POSTCOMPILE_dag_id_1]) "
- "AND dag_run.run_type IN (__[POSTCOMPILE_run_type_1]) GROUP BY dag_run.dag_id) AS anon_1",
- "WHERE dag_run.dag_id = anon_1.dag_id AND dag_run.logical_date = anon_1.max_execution_date",
- ]
- assert actual == expected, compiled_stmt