This is an automated email from the ASF dual-hosted git repository.
weilee pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 268a400984 Refactor ORM DAG insertion logic (#42358)
268a400984 is described below
commit 268a40098418ae6cf177d89e0654dd1dc157ea3f
Author: Tzu-ping Chung <[email protected]>
AuthorDate: Mon Sep 23 01:34:48 2024 -0700
Refactor ORM DAG insertion logic (#42358)
* Refactor ORM DAG insertion logic
This basically uses the same pattern as the dataset inserts, and makes
FK relation updates more readable (see the sketch after this message).
The SQL queries CAN be made more efficient, but I decided to keep things
as-is (no worse than before); these are likely not that big a deal.
* Use set operation for readability
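
For illustration, a minimal sketch of the find-then-create pattern this
commit adopts, using a hypothetical standalone SQLAlchemy model (this is
not the actual Airflow code; "Model", "find_existing", "create_missing",
and "add_all" are made-up names):

    from __future__ import annotations

    from collections.abc import Iterable, Iterator

    from sqlalchemy import create_engine, select
    from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


    class Base(DeclarativeBase):
        pass


    class Model(Base):
        """Hypothetical stand-in for DagModel, keyed by a string id."""

        __tablename__ = "model"

        id: Mapped[str] = mapped_column(primary_key=True)


    def find_existing(ids: Iterable[str], *, session: Session) -> dict[str, Model]:
        # Fetch the rows that already exist, keyed by id.
        stmt = select(Model).where(Model.id.in_(list(ids)))
        return {m.id: m for m in session.scalars(stmt)}


    def create_missing(ids: Iterable[str], *, session: Session) -> Iterator[Model]:
        # Create (and register with the session) one new row per missing id.
        for id_ in ids:
            obj = Model(id=id_)
            session.add(obj)
            yield obj


    def add_all(ids: set[str], *, session: Session) -> dict[str, Model]:
        # Find-then-create: reuse existing rows; a set difference picks
        # out only the ids that still need inserting.
        models = find_existing(ids, session=session)
        models.update((m.id, m) for m in create_missing(ids - models.keys(), session=session))
        return models


    if __name__ == "__main__":
        engine = create_engine("sqlite://")
        Base.metadata.create_all(engine)
        with Session(engine) as session:
            add_all({"a", "b"}, session=session)
            session.commit()

The same set-difference idea shows up in the new _update_dag_tags helper
in the diff below.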
---
airflow/dag_processing/collection.py | 200 ++++++++++++++++++-----------------
airflow/models/dag.py | 24 ++---
2 files changed, 110 insertions(+), 114 deletions(-)
diff --git a/airflow/dag_processing/collection.py b/airflow/dag_processing/collection.py
index 3f75e0b23b..5d54d17b87 100644
--- a/airflow/dag_processing/collection.py
+++ b/airflow/dag_processing/collection.py
@@ -61,46 +61,28 @@ if TYPE_CHECKING:
log = logging.getLogger(__name__)
-def collect_orm_dags(dags: dict[str, DAG], *, session: Session) -> dict[str, DagModel]:
- """
- Collect DagModel objects from DAG objects.
-
- An existing DagModel is fetched if there's a matching ID in the database.
- Otherwise, a new DagModel is created and added to the session.
- """
+def _find_orm_dags(dag_ids: Iterable[str], *, session: Session) -> dict[str, DagModel]:
+ """Find existing DagModel objects from DAG objects."""
stmt = (
select(DagModel)
.options(joinedload(DagModel.tags, innerjoin=False))
- .where(DagModel.dag_id.in_(dags))
+ .where(DagModel.dag_id.in_(dag_ids))
.options(joinedload(DagModel.schedule_dataset_references))
.options(joinedload(DagModel.schedule_dataset_alias_references))
.options(joinedload(DagModel.task_outlet_dataset_references))
)
stmt = with_row_locks(stmt, of=DagModel, session=session)
-    existing_orm_dags = {dm.dag_id: dm for dm in session.scalars(stmt).unique()}
+ return {dm.dag_id: dm for dm in session.scalars(stmt).unique()}
+
- for dag_id, dag in dags.items():
- if dag_id in existing_orm_dags:
- continue
- orm_dag = DagModel(dag_id=dag_id)
+def _create_orm_dags(dags: Iterable[DAG], *, session: Session) -> Iterator[DagModel]:
+ for dag in dags:
+ orm_dag = DagModel(dag_id=dag.dag_id)
if dag.is_paused_upon_creation is not None:
orm_dag.is_paused = dag.is_paused_upon_creation
- orm_dag.tags = []
- log.info("Creating ORM DAG for %s", dag_id)
+ log.info("Creating ORM DAG for %s", dag.dag_id)
session.add(orm_dag)
- existing_orm_dags[dag_id] = orm_dag
-
- return existing_orm_dags
-
-
-def create_orm_dag(dag: DAG, session: Session) -> DagModel:
- orm_dag = DagModel(dag_id=dag.dag_id)
- if dag.is_paused_upon_creation is not None:
- orm_dag.is_paused = dag.is_paused_upon_creation
- orm_dag.tags = []
- log.info("Creating ORM DAG for %s", dag.dag_id)
- session.add(orm_dag)
- return orm_dag
+ yield orm_dag
def _get_latest_runs_stmt(dag_ids: Collection[str]) -> Select:
@@ -158,75 +140,101 @@ class _RunInfo(NamedTuple):
)
-def update_orm_dags(
- source_dags: dict[str, DAG],
- target_dags: dict[str, DagModel],
- *,
- processor_subdir: str | None = None,
- session: Session,
-) -> None:
- """
- Apply DAG attributes to DagModel objects.
-
- Objects in ``target_dags`` are modified in-place.
- """
- run_info = _RunInfo.calculate(source_dags, session=session)
-
- for dag_id, dm in sorted(target_dags.items()):
- dag = source_dags[dag_id]
- dm.fileloc = dag.fileloc
- dm.owners = dag.owner
- dm.is_active = True
- dm.has_import_errors = False
- dm.last_parsed_time = utcnow()
- dm.default_view = dag.default_view
- dm._dag_display_property_value = dag._dag_display_property_value
- dm.description = dag.description
- dm.max_active_tasks = dag.max_active_tasks
- dm.max_active_runs = dag.max_active_runs
-        dm.max_consecutive_failed_dag_runs = dag.max_consecutive_failed_dag_runs
- dm.has_task_concurrency_limits = any(
-            t.max_active_tis_per_dag is not None or t.max_active_tis_per_dagrun is not None for t in dag.tasks
- )
- dm.timetable_summary = dag.timetable.summary
- dm.timetable_description = dag.timetable.description
- dm.dataset_expression = dag.timetable.dataset_condition.as_expression()
- dm.processor_subdir = processor_subdir
-
-        last_automated_run: DagRun | None = run_info.latest_runs.get(dag.dag_id)
- if last_automated_run is None:
- last_automated_data_interval = None
- else:
-            last_automated_data_interval = dag.get_run_data_interval(last_automated_run)
- if run_info.num_active_runs.get(dag.dag_id, 0) >= dm.max_active_runs:
- dm.next_dagrun_create_after = None
+def _update_dag_tags(tag_names: set[str], dm: DagModel, *, session: Session) -> None:
+ orm_tags = {t.name: t for t in dm.tags}
+ for name, orm_tag in orm_tags.items():
+ if name not in tag_names:
+ session.delete(orm_tag)
+    dm.tags.extend(DagTag(name=name, dag_id=dm.dag_id) for name in tag_names.difference(orm_tags))
+
+
+def _update_dag_owner_links(dag_owner_links: dict[str, str], dm: DagModel, *, session: Session) -> None:
+ orm_dag_owner_attributes = {obj.owner: obj for obj in dm.dag_owner_links}
+ for owner, obj in orm_dag_owner_attributes.items():
+ try:
+ link = dag_owner_links[owner]
+ except KeyError:
+ session.delete(obj)
+        else:
- dm.calculate_dagrun_date_fields(dag, last_automated_data_interval)
-
- if not dag.timetable.dataset_condition:
- dm.schedule_dataset_references = []
- dm.schedule_dataset_alias_references = []
- # FIXME: STORE NEW REFERENCES.
-
- dag_tags = set(dag.tags or ())
- for orm_tag in (dm_tags := list(dm.tags or [])):
- if orm_tag.name not in dag_tags:
- session.delete(orm_tag)
- dm.tags.remove(orm_tag)
- orm_tag_names = {t.name for t in dm_tags}
- for dag_tag in dag_tags:
- if dag_tag not in orm_tag_names:
- dag_tag_orm = DagTag(name=dag_tag, dag_id=dag.dag_id)
- dm.tags.append(dag_tag_orm)
- session.add(dag_tag_orm)
-
- dm_links = dm.dag_owner_links or []
- for dm_link in dm_links:
- if dm_link not in dag.owner_links:
- session.delete(dm_link)
- for owner_name, owner_link in dag.owner_links.items():
-        dag_owner_orm = DagOwnerAttributes(dag_id=dag.dag_id, owner=owner_name, link=owner_link)
- session.add(dag_owner_orm)
+ if obj.link != link:
+ obj.link = link
+ dm.dag_owner_links.extend(
+ DagOwnerAttributes(dag_id=dm.dag_id, owner=owner, link=link)
+ for owner, link in dag_owner_links.items()
+ if owner not in orm_dag_owner_attributes
+ )
+
+
+class DagModelOperation(NamedTuple):
+ """Collect DAG objects and perform database operations for them."""
+
+ dags: dict[str, DAG]
+
+ def add_dags(self, *, session: Session) -> dict[str, DagModel]:
+ orm_dags = _find_orm_dags(self.dags, session=session)
+ orm_dags.update(
+ (model.dag_id, model)
+ for model in _create_orm_dags(
+                (dag for dag_id, dag in self.dags.items() if dag_id not in orm_dags),
+ session=session,
+ )
+ )
+ return orm_dags
+
+ def update_dags(
+ self,
+ orm_dags: dict[str, DagModel],
+ *,
+ processor_subdir: str | None = None,
+ session: Session,
+ ) -> None:
+ run_info = _RunInfo.calculate(self.dags, session=session)
+
+ for dag_id, dm in sorted(orm_dags.items()):
+ dag = self.dags[dag_id]
+ dm.fileloc = dag.fileloc
+ dm.owners = dag.owner
+ dm.is_active = True
+ dm.has_import_errors = False
+ dm.last_parsed_time = utcnow()
+ dm.default_view = dag.default_view
+ dm._dag_display_property_value = dag._dag_display_property_value
+ dm.description = dag.description
+ dm.max_active_tasks = dag.max_active_tasks
+ dm.max_active_runs = dag.max_active_runs
+            dm.max_consecutive_failed_dag_runs = dag.max_consecutive_failed_dag_runs
+ dm.has_task_concurrency_limits = any(
+                t.max_active_tis_per_dag is not None or t.max_active_tis_per_dagrun is not None
+ for t in dag.tasks
+ )
+ dm.timetable_summary = dag.timetable.summary
+ dm.timetable_description = dag.timetable.description
+            dm.dataset_expression = dag.timetable.dataset_condition.as_expression()
+ dm.processor_subdir = processor_subdir
+
+            last_automated_run: DagRun | None = run_info.latest_runs.get(dag.dag_id)
+ if last_automated_run is None:
+ last_automated_data_interval = None
+ else:
+                last_automated_data_interval = dag.get_run_data_interval(last_automated_run)
+            if run_info.num_active_runs.get(dag.dag_id, 0) >= dm.max_active_runs:
+ dm.next_dagrun_create_after = None
+ else:
+                dm.calculate_dagrun_date_fields(dag, last_automated_data_interval)
+
+ if not dag.timetable.dataset_condition:
+ dm.schedule_dataset_references = []
+ dm.schedule_dataset_alias_references = []
+ # FIXME: STORE NEW REFERENCES.
+
+ if dag.tags:
+ _update_dag_tags(set(dag.tags), dm, session=session)
+ else: # Optimization: no references at all, just clear everything.
+ dm.tags = []
+ if dag.owner_links:
+ _update_dag_owner_links(dag.owner_links, dm, session=session)
+ else: # Optimization: no references at all, just clear everything.
+ dm.dag_owner_links = []
def _find_all_datasets(dags: Iterable[DAG]) -> Iterator[Dataset]:
diff --git a/airflow/models/dag.py b/airflow/models/dag.py
index 00820585b6..c95d11f3ef 100644
--- a/airflow/models/dag.py
+++ b/airflow/models/dag.py
@@ -2644,28 +2644,16 @@ class DAG(LoggingMixin):
if not dags:
return
- from airflow.dag_processing.collection import (
- DatasetModelOperation,
- collect_orm_dags,
- create_orm_dag,
- update_orm_dags,
- )
+    from airflow.dag_processing.collection import DagModelOperation, DatasetModelOperation
log.info("Sync %s DAGs", len(dags))
- dags_by_ids = {dag.dag_id: dag for dag in dags}
- del dags
-
- orm_dags = collect_orm_dags(dags_by_ids, session=session)
- orm_dags.update(
- (dag_id, create_orm_dag(dag, session=session))
- for dag_id, dag in dags_by_ids.items()
- if dag_id not in orm_dags
- )
+ dag_op = DagModelOperation({dag.dag_id: dag for dag in dags})
-    update_orm_dags(dags_by_ids, orm_dags, processor_subdir=processor_subdir, session=session)
-    DagCode.bulk_sync_to_db((dag.fileloc for dag in dags_by_ids.values()), session=session)
+ orm_dags = dag_op.add_dags(session=session)
+    dag_op.update_dags(orm_dags, processor_subdir=processor_subdir, session=session)
+ DagCode.bulk_sync_to_db((dag.fileloc for dag in dags), session=session)
- dataset_op = DatasetModelOperation.collect(dags_by_ids)
+ dataset_op = DatasetModelOperation.collect(dag_op.dags)
orm_datasets = dataset_op.add_datasets(session=session)
orm_dataset_aliases = dataset_op.add_dataset_aliases(session=session)