This is an automated email from the ASF dual-hosted git repository.

weilee pushed a commit to branch v2-10-test
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/v2-10-test by this push:
     new a02325f7b8 Rewrite how DAG to dataset / dataset alias are stored (#41987) (#42055)
a02325f7b8 is described below

commit a02325f7b827894fcd73333294d11dbd6a656908
Author: Wei Lee <[email protected]>
AuthorDate: Fri Sep 6 14:41:30 2024 +0800

    Rewrite how DAG to dataset / dataset alias are stored (#41987) (#42055)
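
    For readers skimming the diff below: the core of the change is that
    dag_references no longer holds a mixed set of Dataset and DatasetAlias
    objects, but plain ("dataset" | "dataset-alias", identifier) tuples, so
    the later reconciliation step can branch on the tuple tag instead of
    isinstance checks. A minimal, self-contained sketch of that bookkeeping
    (simplified stand-in values, not the actual Airflow models):

        from collections import defaultdict
        from typing import Literal

        # Keyed by dag_id; each entry tags the identifier with its kind so
        # dataset and dataset-alias references stay distinguishable.
        dag_references: dict[
            str, set[tuple[Literal["dataset", "dataset-alias"], str]]
        ] = defaultdict(set)

        dag_references["example_dag"].add(("dataset", "s3://bucket/key"))
        dag_references["example_dag"].add(("dataset-alias", "my-alias"))

        for dag_id, refs in dag_references.items():
            for ref_type, identifier in refs:
                if ref_type == "dataset":
                    print(f"{dag_id} schedules on dataset {identifier}")
                else:
                    print(f"{dag_id} schedules on alias {identifier}")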
---
 airflow/models/dag.py | 88 +++++++++++++++++++++++++++++----------------------
 1 file changed, 50 insertions(+), 38 deletions(-)

diff --git a/airflow/models/dag.py b/airflow/models/dag.py
index f848346780..58213efeec 100644
--- a/airflow/models/dag.py
+++ b/airflow/models/dag.py
@@ -3236,8 +3236,6 @@ class DAG(LoggingMixin):
         if not dags:
             return
 
-        from airflow.models.dataset import DagScheduleDatasetAliasReference
-
         log.info("Sync %s DAGs", len(dags))
         dag_by_ids = {dag.dag_id: dag for dag in dags}
 
@@ -3344,18 +3342,19 @@ class DAG(LoggingMixin):
 
         from airflow.datasets import Dataset
         from airflow.models.dataset import (
+            DagScheduleDatasetAliasReference,
             DagScheduleDatasetReference,
             DatasetModel,
             TaskOutletDatasetReference,
         )
 
-        dag_references: dict[str, set[Dataset | DatasetAlias]] = defaultdict(set)
+        dag_references: dict[str, set[tuple[Literal["dataset", "dataset-alias"], str]]] = defaultdict(set)
         outlet_references = defaultdict(set)
         # We can't use a set here as we want to preserve order
-        outlet_datasets: dict[DatasetModel, None] = {}
-        input_datasets: dict[DatasetModel, None] = {}
+        outlet_dataset_models: dict[DatasetModel, None] = {}
+        input_dataset_models: dict[DatasetModel, None] = {}
         outlet_dataset_alias_models: set[DatasetAliasModel] = set()
-        input_dataset_aliases: set[DatasetAliasModel] = set()
+        input_dataset_alias_models: set[DatasetAliasModel] = set()
 
         # here we go through dags and tasks to check for dataset references
         # if there are now none and previously there were some, we delete them
@@ -3371,12 +3370,12 @@ class DAG(LoggingMixin):
                         curr_orm_dag.schedule_dataset_alias_references = []
             else:
                 for _, dataset in dataset_condition.iter_datasets():
-                    dag_references[dag.dag_id].add(Dataset(uri=dataset.uri))
-                    input_datasets[DatasetModel.from_public(dataset)] = None
+                    dag_references[dag.dag_id].add(("dataset", dataset.uri))
+                    input_dataset_models[DatasetModel.from_public(dataset)] = None
 
                 for dataset_alias in dataset_condition.iter_dataset_aliases():
-                    dag_references[dag.dag_id].add(dataset_alias)
-                    input_dataset_aliases.add(DatasetAliasModel.from_public(dataset_alias))
+                    dag_references[dag.dag_id].add(("dataset-alias", dataset_alias.name))
+                    input_dataset_alias_models.add(DatasetAliasModel.from_public(dataset_alias))
 
             curr_outlet_references = curr_orm_dag and curr_orm_dag.task_outlet_dataset_references
             for task in dag.tasks:
@@ -3399,63 +3398,70 @@ class DAG(LoggingMixin):
                             curr_outlet_references.remove(ref)
 
                 for d in dataset_outlets:
+                    outlet_dataset_models[DatasetModel.from_public(d)] = None
                     outlet_references[(task.dag_id, task.task_id)].add(d.uri)
-                    outlet_datasets[DatasetModel.from_public(d)] = None
 
                 for d_a in dataset_alias_outlets:
                     outlet_dataset_alias_models.add(DatasetAliasModel.from_public(d_a))
 
-        all_datasets = outlet_datasets
-        all_datasets.update(input_datasets)
+        all_dataset_models = outlet_dataset_models
+        all_dataset_models.update(input_dataset_models)
 
         # store datasets
-        stored_datasets: dict[str, DatasetModel] = {}
-        new_datasets: list[DatasetModel] = []
-        for dataset in all_datasets:
-            stored_dataset = session.scalar(
+        stored_dataset_models: dict[str, DatasetModel] = {}
+        new_dataset_models: list[DatasetModel] = []
+        for dataset in all_dataset_models:
+            stored_dataset_model = session.scalar(
                 select(DatasetModel).where(DatasetModel.uri == dataset.uri).limit(1)
             )
-            if stored_dataset:
+            if stored_dataset_model:
                 # Some datasets may have been previously unreferenced, and therefore orphaned by the
                 # scheduler. But if we're here, then we have found that dataset again in our DAGs, which
                 # means that it is no longer an orphan, so set is_orphaned to False.
-                stored_dataset.is_orphaned = expression.false()
-                stored_datasets[stored_dataset.uri] = stored_dataset
+                stored_dataset_model.is_orphaned = expression.false()
+                stored_dataset_models[stored_dataset_model.uri] = stored_dataset_model
             else:
-                new_datasets.append(dataset)
-        dataset_manager.create_datasets(dataset_models=new_datasets, session=session)
-        stored_datasets.update({dataset.uri: dataset for dataset in new_datasets})
+                new_dataset_models.append(dataset)
+        dataset_manager.create_datasets(dataset_models=new_dataset_models, session=session)
+        stored_dataset_models.update(
+            {dataset_model.uri: dataset_model for dataset_model in new_dataset_models}
+        )
 
-        del new_datasets
-        del all_datasets
+        del new_dataset_models
+        del all_dataset_models
 
         # store dataset aliases
-        all_datasets_alias_models = input_dataset_aliases | outlet_dataset_alias_models
-        stored_dataset_aliases: dict[str, DatasetAliasModel] = {}
+        all_datasets_alias_models = input_dataset_alias_models | outlet_dataset_alias_models
+        stored_dataset_alias_models: dict[str, DatasetAliasModel] = {}
         new_dataset_alias_models: set[DatasetAliasModel] = set()
         if all_datasets_alias_models:
-            all_dataset_alias_names = {dataset_alias.name for dataset_alias in all_datasets_alias_models}
+            all_dataset_alias_names = {
+                dataset_alias_model.name for dataset_alias_model in all_datasets_alias_models
+            }
 
-            stored_dataset_aliases = {
+            stored_dataset_alias_models = {
                 dsa_m.name: dsa_m
                 for dsa_m in session.scalars(
                     select(DatasetAliasModel).where(DatasetAliasModel.name.in_(all_dataset_alias_names))
                 ).fetchall()
             }
 
-            if stored_dataset_aliases:
+            if stored_dataset_alias_models:
                 new_dataset_alias_models = {
                     dataset_alias_model
                     for dataset_alias_model in all_datasets_alias_models
-                    if dataset_alias_model.name not in stored_dataset_aliases.keys()
+                    if dataset_alias_model.name not in stored_dataset_alias_models.keys()
                 }
             else:
                 new_dataset_alias_models = all_datasets_alias_models
 
             session.add_all(new_dataset_alias_models)
         session.flush()
-        stored_dataset_aliases.update(
-            {dataset_alias.name: dataset_alias for dataset_alias in new_dataset_alias_models}
+        stored_dataset_alias_models.update(
+            {
+                dataset_alias_model.name: dataset_alias_model
+                for dataset_alias_model in new_dataset_alias_models
+            }
         )
 
         del new_dataset_alias_models
@@ -3464,14 +3470,18 @@ class DAG(LoggingMixin):
         # reconcile dag-schedule-on-dataset and dag-schedule-on-dataset-alias references
         for dag_id, base_dataset_list in dag_references.items():
             dag_refs_needed = {
-                DagScheduleDatasetReference(dataset_id=stored_datasets[base_dataset.uri].id, dag_id=dag_id)
-                if isinstance(base_dataset, Dataset)
+                DagScheduleDatasetReference(
+                    dataset_id=stored_dataset_models[base_dataset_identifier].id, dag_id=dag_id
+                )
+                if base_dataset_type == "dataset"
                 else DagScheduleDatasetAliasReference(
-                    alias_id=stored_dataset_aliases[base_dataset.name].id, dag_id=dag_id
+                    alias_id=stored_dataset_alias_models[base_dataset_identifier].id, dag_id=dag_id
                 )
-                for base_dataset in base_dataset_list
+                for base_dataset_type, base_dataset_identifier in base_dataset_list
             }
 
+            # if isinstance(base_dataset, Dataset)
+
             dag_refs_stored = (
                 set(existing_dags.get(dag_id).schedule_dataset_references)  # type: ignore
                 | set(existing_dags.get(dag_id).schedule_dataset_alias_references)  # type: ignore
@@ -3491,7 +3501,9 @@ class DAG(LoggingMixin):
         # reconcile task-outlet-dataset references
         for (dag_id, task_id), uri_list in outlet_references.items():
             task_refs_needed = {
-                TaskOutletDatasetReference(dataset_id=stored_datasets[uri].id, dag_id=dag_id, task_id=task_id)
+                TaskOutletDatasetReference(
+                    dataset_id=stored_dataset_models[uri].id, dag_id=dag_id, task_id=task_id
+                )
                 for uri in uri_list
             }
             task_refs_stored = existing_task_outlet_refs_dict[(dag_id, task_id)]
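
The reconciliation pattern in both hunks above is the same: build a "needed"
set from the DAG files currently being synced, load the "stored" set from the
database, then drop stored-but-no-longer-needed rows and persist
needed-but-not-yet-stored ones. A minimal sketch of that set arithmetic (plain
strings stand in for the ORM reference objects; hypothetical, not the actual
Airflow API):

    def reconcile(needed: set[str], stored: set[str]) -> tuple[set[str], set[str]]:
        # Rows in the DB that no current DAG declares -> delete.
        to_delete = stored - needed
        # References declared by DAGs but missing from the DB -> insert.
        to_add = needed - stored
        return to_delete, to_add

    to_delete, to_add = reconcile(
        needed={"dataset:s3://bucket/key", "dataset-alias:my-alias"},
        stored={"dataset:s3://bucket/key", "dataset:s3://old/key"},
    )
    assert to_delete == {"dataset:s3://old/key"}
    assert to_add == {"dataset-alias:my-alias"}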
