This is an automated email from the ASF dual-hosted git repository.

uranusjr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 4b08a0fc67 Refactor _register_dataset_changes (#42343)
4b08a0fc67 is described below

commit 4b08a0fc67386bb0b81250b0521bfc3a0678955e
Author: Tzu-ping Chung <[email protected]>
AuthorDate: Tue Sep 24 18:50:43 2024 -0700

    Refactor _register_dataset_changes (#42343)
---
 airflow/dag_processing/collection.py               | 37 +++++------
 airflow/datasets/manager.py                        | 77 +++++++++++++++++-----
 airflow/listeners/spec/dataset.py                  |  9 ++-
 airflow/models/dataset.py                          |  6 ++
 airflow/models/taskinstance.py                     | 51 +++++++-------
 .../administration-and-deployment/listeners.rst    |  1 +
 newsfragments/42343.feature.rst                    |  1 +
 newsfragments/42343.significant.rst                |  7 ++
 tests/datasets/test_manager.py                     |  7 +-
 9 files changed, 128 insertions(+), 68 deletions(-)

diff --git a/airflow/dag_processing/collection.py b/airflow/dag_processing/collection.py
index 5d54d17b87..bcac479d87 100644
--- a/airflow/dag_processing/collection.py
+++ b/airflow/dag_processing/collection.py
@@ -299,21 +299,15 @@ class DatasetModelOperation(NamedTuple):
             dm.uri: dm
             for dm in session.scalars(select(DatasetModel).where(DatasetModel.uri.in_(self.datasets)))
         }
-
-        def _resolve_dataset_addition() -> Iterator[DatasetModel]:
-            for uri, dataset in self.datasets.items():
-                try:
-                    dm = orm_datasets[uri]
-                except KeyError:
-                    dm = orm_datasets[uri] = DatasetModel.from_public(dataset)
-                    yield dm
-                else:
-                    # The orphaned flag was bulk-set to True before parsing, so we
-                    # don't need to handle rows in the db without a public entry.
-                    dm.is_orphaned = expression.false()
-                dm.extra = dataset.extra
-
-        dataset_manager.create_datasets(list(_resolve_dataset_addition()), session=session)
+        for model in orm_datasets.values():
+            model.is_orphaned = expression.false()
+        orm_datasets.update(
+            (model.uri, model)
+            for model in dataset_manager.create_datasets(
+                [dataset for uri, dataset in self.datasets.items() if uri not in orm_datasets],
+                session=session,
+            )
+        )
         return orm_datasets
 
     def add_dataset_aliases(self, *, session: Session) -> dict[str, DatasetAliasModel]:
@@ -326,12 +320,13 @@ class DatasetModelOperation(NamedTuple):
                 select(DatasetAliasModel).where(DatasetAliasModel.name.in_(self.dataset_aliases))
             )
         }
-        for name, alias in self.dataset_aliases.items():
-            try:
-                da = orm_aliases[name]
-            except KeyError:
-                da = orm_aliases[name] = DatasetAliasModel.from_public(alias)
-                session.add(da)
+        orm_aliases.update(
+            (model.name, model)
+            for model in dataset_manager.create_dataset_aliases(
+                [alias for name, alias in self.dataset_aliases.items() if name not in orm_aliases],
+                session=session,
+            )
+        )
         return orm_aliases
 
     def add_dag_dataset_references(
diff --git a/airflow/datasets/manager.py b/airflow/datasets/manager.py
index 19f6913fff..c5ebb2e6d7 100644
--- a/airflow/datasets/manager.py
+++ b/airflow/datasets/manager.py
@@ -17,7 +17,7 @@
 # under the License.
 from __future__ import annotations
 
-from collections.abc import Iterable
+from collections.abc import Collection, Iterable
 from typing import TYPE_CHECKING
 
 from sqlalchemy import exc, select
@@ -25,7 +25,6 @@ from sqlalchemy.orm import joinedload
 
 from airflow.api_internal.internal_api_call import internal_api_call
 from airflow.configuration import conf
-from airflow.datasets import Dataset
 from airflow.listeners.listener import get_listener_manager
 from airflow.models.dagbag import DagPriorityParsingRequest
 from airflow.models.dataset import (
@@ -43,6 +42,7 @@ from airflow.utils.session import NEW_SESSION, provide_session
 if TYPE_CHECKING:
     from sqlalchemy.orm.session import Session
 
+    from airflow.datasets import Dataset, DatasetAlias
     from airflow.models.dag import DagModel
     from airflow.models.taskinstance import TaskInstance
 
@@ -58,12 +58,51 @@ class DatasetManager(LoggingMixin):
     def __init__(self, **kwargs):
         super().__init__(**kwargs)
 
-    def create_datasets(self, dataset_models: list[DatasetModel], session: Session) -> None:
+    def create_datasets(self, datasets: list[Dataset], *, session: Session) -> list[DatasetModel]:
         """Create new datasets."""
-        for dataset_model in dataset_models:
-            session.add(dataset_model)
-        for dataset_model in dataset_models:
-            self.notify_dataset_created(dataset=Dataset(uri=dataset_model.uri, extra=dataset_model.extra))
+
+        def _add_one(dataset: Dataset) -> DatasetModel:
+            model = DatasetModel.from_public(dataset)
+            session.add(model)
+            self.notify_dataset_created(dataset=dataset)
+            return model
+
+        return [_add_one(d) for d in datasets]
+
+    def create_dataset_aliases(
+        self,
+        dataset_aliases: list[DatasetAlias],
+        *,
+        session: Session,
+    ) -> list[DatasetAliasModel]:
+        """Create new dataset aliases."""
+
+        def _add_one(dataset_alias: DatasetAlias) -> DatasetAliasModel:
+            model = DatasetAliasModel.from_public(dataset_alias)
+            session.add(model)
+            self.notify_dataset_alias_created(dataset_alias=dataset_alias)
+            return model
+
+        return [_add_one(a) for a in dataset_aliases]
+
+    @classmethod
+    def _add_dataset_alias_association(
+        cls,
+        alias_names: Collection[str],
+        dataset: DatasetModel,
+        *,
+        session: Session,
+    ) -> None:
+        already_related = {m.name for m in dataset.aliases}
+        existing_aliases = {
+            m.name: m
+            for m in session.scalars(select(DatasetAliasModel).where(DatasetAliasModel.name.in_(alias_names)))
+        }
+        dataset.aliases.extend(
+            existing_aliases.get(name, DatasetAliasModel(name=name))
+            for name in alias_names
+            if name not in already_related
+        )
 
     @classmethod
     @internal_api_call
@@ -74,8 +113,9 @@ class DatasetManager(LoggingMixin):
         task_instance: TaskInstance | None = None,
         dataset: Dataset,
         extra=None,
-        session: Session = NEW_SESSION,
+        aliases: Collection[DatasetAlias] = (),
         source_alias_names: Iterable[str] | None = None,
+        session: Session = NEW_SESSION,
         **kwargs,
     ) -> DatasetEvent | None:
         """
@@ -88,24 +128,27 @@ class DatasetManager(LoggingMixin):
         dataset_model = session.scalar(
             select(DatasetModel)
             .where(DatasetModel.uri == dataset.uri)
-            .options(joinedload(DatasetModel.consuming_dags).joinedload(DagScheduleDatasetReference.dag))
+            .options(
+                joinedload(DatasetModel.aliases),
+                joinedload(DatasetModel.consuming_dags).joinedload(DagScheduleDatasetReference.dag),
+            )
         )
         if not dataset_model:
             cls.logger().warning("DatasetModel %s not found", dataset)
             return None
 
+        cls._add_dataset_alias_association({alias.name for alias in aliases}, dataset_model, session=session)
+
         event_kwargs = {
             "dataset_id": dataset_model.id,
             "extra": extra,
         }
         if task_instance:
             event_kwargs.update(
-                {
-                    "source_task_id": task_instance.task_id,
-                    "source_dag_id": task_instance.dag_id,
-                    "source_run_id": task_instance.run_id,
-                    "source_map_index": task_instance.map_index,
-                }
+                source_task_id=task_instance.task_id,
+                source_dag_id=task_instance.dag_id,
+                source_run_id=task_instance.run_id,
+                source_map_index=task_instance.map_index,
             )
 
         dataset_event = DatasetEvent(**event_kwargs)
@@ -155,6 +198,10 @@ class DatasetManager(LoggingMixin):
         """Run applicable notification actions when a dataset is created."""
         get_listener_manager().hook.on_dataset_created(dataset=dataset)
 
+    def notify_dataset_alias_created(self, dataset_alias: DatasetAlias):
+        """Run applicable notification actions when a dataset alias is created."""
+        get_listener_manager().hook.on_dataset_alias_created(dataset_alias=dataset_alias)
+
     @classmethod
     def notify_dataset_changed(cls, dataset: Dataset):
         """Run applicable notification actions when a dataset is changed."""
diff --git a/airflow/listeners/spec/dataset.py b/airflow/listeners/spec/dataset.py
index 214ddad3ff..eee1a10dd7 100644
--- a/airflow/listeners/spec/dataset.py
+++ b/airflow/listeners/spec/dataset.py
@@ -22,7 +22,7 @@ from typing import TYPE_CHECKING
 from pluggy import HookspecMarker
 
 if TYPE_CHECKING:
-    from airflow.datasets import Dataset
+    from airflow.datasets import Dataset, DatasetAlias
 
 hookspec = HookspecMarker("airflow")
 
@@ -34,6 +34,13 @@ def on_dataset_created(
     """Execute when a new dataset is created."""
 
 
+@hookspec
+def on_dataset_alias_created(
+    dataset_alias: DatasetAlias,
+):
+    """Execute when a new dataset alias is created."""
+
+
 @hookspec
 def on_dataset_changed(
     dataset: Dataset,
diff --git a/airflow/models/dataset.py b/airflow/models/dataset.py
index 5033da48a3..489d6b68a6 100644
--- a/airflow/models/dataset.py
+++ b/airflow/models/dataset.py
@@ -138,6 +138,9 @@ class DatasetAliasModel(Base):
         else:
             return NotImplemented
 
+    def to_public(self) -> DatasetAlias:
+        return DatasetAlias(name=self.name)
+
 
 class DatasetModel(Base):
     """
@@ -200,6 +203,9 @@ class DatasetModel(Base):
     def __repr__(self):
         return f"{self.__class__.__name__}(uri={self.uri!r}, 
extra={self.extra!r})"
 
+    def to_public(self) -> Dataset:
+        return Dataset(uri=self.uri, extra=self.extra)
+
 
 class DagScheduleDatasetAliasReference(Base):
     """References from a DAG to a dataset alias of which it is a consumer."""
diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index 954e5ed4d0..d3300207ab 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -31,7 +31,7 @@ from collections import defaultdict
 from contextlib import nullcontext
 from datetime import timedelta
 from enum import Enum
-from typing import TYPE_CHECKING, Any, Callable, Collection, Generator, Iterable, Mapping, Tuple
+from typing import TYPE_CHECKING, Any, Callable, Collection, Dict, Generator, Iterable, Mapping, Tuple
 from urllib.parse import quote
 
 import dill
@@ -89,7 +89,7 @@ from airflow.exceptions import (
 from airflow.listeners.listener import get_listener_manager
 from airflow.models.base import Base, StringID, TaskInstanceDependencies, _sentinel
 from airflow.models.dagbag import DagBag
-from airflow.models.dataset import DatasetAliasModel, DatasetModel
+from airflow.models.dataset import DatasetModel
 from airflow.models.log import Log
 from airflow.models.param import process_params
 from airflow.models.renderedtifields import get_serialized_template_fields
@@ -2893,7 +2893,7 @@ class TaskInstance(Base, LoggingMixin):
         # One task only triggers one dataset event for each dataset with the same extra.
         # This tuple[dataset uri, extra] to sets alias names mapping is used to find whether
         # there're datasets with same uri but different extra that we need to emit more than one dataset events.
-        dataset_tuple_to_alias_names_mapping: dict[tuple[str, frozenset], set[str]] = defaultdict(set)
+        dataset_alias_names: dict[tuple[str, frozenset], set[str]] = defaultdict(set)
         for obj in self.task.outlets or []:
             self.log.debug("outlet obj %s", obj)
             # Lineage can have other types of objects besides datasets
@@ -2908,33 +2908,27 @@ class TaskInstance(Base, LoggingMixin):
                 for dataset_alias_event in events[obj].dataset_alias_events:
                     dataset_alias_name = dataset_alias_event["source_alias_name"]
                     dataset_uri = dataset_alias_event["dest_dataset_uri"]
-                    extra = dataset_alias_event["extra"]
-                    frozen_extra = frozenset(extra.items())
+                    frozen_extra = frozenset(dataset_alias_event["extra"].items())
+                    dataset_alias_names[(dataset_uri, frozen_extra)].add(dataset_alias_name)
 
-                    dataset_tuple_to_alias_names_mapping[(dataset_uri, frozen_extra)].add(dataset_alias_name)
+        class _DatasetModelCache(Dict[str, DatasetModel]):
+            log = self.log
 
-        dataset_objs_cache: dict[str, DatasetModel] = {}
-        for (uri, extra_items), alias_names in dataset_tuple_to_alias_names_mapping.items():
-            if uri not in dataset_objs_cache:
-                dataset_obj = session.scalar(select(DatasetModel).where(DatasetModel.uri == uri).limit(1))
-                dataset_objs_cache[uri] = dataset_obj
-            else:
-                dataset_obj = dataset_objs_cache[uri]
-
-            if not dataset_obj:
-                dataset_obj = DatasetModel(uri=uri)
-                dataset_manager.create_datasets(dataset_models=[dataset_obj], session=session)
-                self.log.warning("Created a new %r as it did not exist.", dataset_obj)
+            def __missing__(self, key: str) -> DatasetModel:
+                (dataset_obj,) = dataset_manager.create_datasets([Dataset(uri=key)], session=session)
                 session.flush()
-                dataset_objs_cache[uri] = dataset_obj
-
-            for alias in alias_names:
-                alias_obj = session.scalar(
-                    select(DatasetAliasModel).where(DatasetAliasModel.name == alias).limit(1)
-                )
-                dataset_obj.aliases.append(alias_obj)
+                self.log.warning("Created a new %r as it did not exist.", dataset_obj)
+                self[key] = dataset_obj
+                return dataset_obj
 
-            extra = {k: v for k, v in extra_items}
+        dataset_objs_cache = _DatasetModelCache(
+            (dataset_obj.uri, dataset_obj)
+            for dataset_obj in session.scalars(
+                select(DatasetModel).where(DatasetModel.uri.in_(uri for uri, _ in dataset_alias_names))
+            )
+        )
+        for (uri, extra_items), alias_names in dataset_alias_names.items():
+            dataset_obj = dataset_objs_cache[uri]
             self.log.info(
                 'Creating event for %r through aliases "%s"',
                 dataset_obj,
@@ -2942,8 +2936,9 @@ class TaskInstance(Base, LoggingMixin):
             )
             dataset_manager.register_dataset_change(
                 task_instance=self,
-                dataset=dataset_obj,
-                extra=extra,
+                dataset=dataset_obj.to_public(),
+                aliases=[DatasetAlias(name) for name in alias_names],
+                extra=dict(extra_items),
                 session=session,
                 source_alias_names=alias_names,
             )
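
The _DatasetModelCache above leans on dict.__missing__: a dict subclass whose
__missing__ hook is invoked by __getitem__ on a cache miss, letting it create
the row, warn, and memoize in one place. A generic, self-contained sketch of
the pattern (names are illustrative, not from the commit):

    class CreatingCache(dict):
        """dict that lazily creates missing entries via a factory."""

        def __init__(self, factory, *args):
            super().__init__(*args)
            self._factory = factory

        def __missing__(self, key):
            # Called by dict.__getitem__ only when `key` is absent.
            value = self._factory(key)
            self[key] = value  # memoize so the factory runs once per key
            return value

    cache = CreatingCache(lambda uri: f"created:{uri}")
    assert cache["a"] == "created:a"  # factory invoked on first access
    assert cache["a"] == "created:a"  # second access served from the dict
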
diff --git a/docs/apache-airflow/administration-and-deployment/listeners.rst b/docs/apache-airflow/administration-and-deployment/listeners.rst
index 34909e225a..4926b12ed6 100644
--- a/docs/apache-airflow/administration-and-deployment/listeners.rst
+++ b/docs/apache-airflow/administration-and-deployment/listeners.rst
@@ -95,6 +95,7 @@ Dataset Events
 --------------
 
 - ``on_dataset_created``
+- ``on_dataset_alias_created``
 - ``on_dataset_changed``
 
 Dataset events occur when Dataset management operations are run.
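
A listener that wants the new event only needs a matching hookimpl; a minimal
sketch of a listener module (hypothetical example, not part of the commit):

    from airflow.datasets import DatasetAlias
    from airflow.listeners import hookimpl

    @hookimpl
    def on_dataset_alias_created(dataset_alias: DatasetAlias):
        # Runs when DatasetManager.create_dataset_aliases registers a new alias.
        print(f"dataset alias created: {dataset_alias.name}")
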
diff --git a/newsfragments/42343.feature.rst b/newsfragments/42343.feature.rst
new file mode 100644
index 0000000000..8a7cdf335a
--- /dev/null
+++ b/newsfragments/42343.feature.rst
@@ -0,0 +1 @@
+New function ``create_dataset_aliases`` added to DatasetManager for DatasetAlias creation.
diff --git a/newsfragments/42343.significant.rst b/newsfragments/42343.significant.rst
new file mode 100644
index 0000000000..d9e1ba6b12
--- /dev/null
+++ b/newsfragments/42343.significant.rst
@@ -0,0 +1,7 @@
+``DatasetManager.create_datasets`` now takes ``Dataset`` objects
+
+This function previously accepted a list of ``DatasetModel`` objects. It now
+receives ``Dataset`` objects instead. A list of ``DatasetModel`` objects is
+created inside, and returned by the function.
+
+Also, the ``session`` argument is now keyword-only.
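
For downstream code that calls this API directly, the migration is mechanical;
a hedged before/after (the pre-change caller is assumed to have built the ORM
model itself):

    # Before: caller constructs ORM models; nothing is returned.
    dataset_manager.create_datasets([DatasetModel(uri="s3://bucket/obj")], session)

    # After: pass public Dataset objects, session becomes keyword-only,
    # and the created DatasetModel rows are returned.
    (model,) = dataset_manager.create_datasets([Dataset(uri="s3://bucket/obj")], session=session)
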
diff --git a/tests/datasets/test_manager.py b/tests/datasets/test_manager.py
index 1e7b4fda40..d3013aef60 100644
--- a/tests/datasets/test_manager.py
+++ b/tests/datasets/test_manager.py
@@ -169,10 +169,11 @@ class TestDatasetManager:
         dataset_listener.clear()
         get_listener_manager().add_listener(dataset_listener)
 
-        dsm = DatasetModel(uri="test_dataset_uri_3")
+        ds = Dataset(uri="test_dataset_uri_3")
 
-        dsem.create_datasets([dsm], session)
+        dsms = dsem.create_datasets([ds], session=session)
 
         # Ensure the listener was notified
         assert len(dataset_listener.created) == 1
-        assert dataset_listener.created[0].uri == dsm.uri
+        assert len(dsms) == 1
+        assert dataset_listener.created[0].uri == ds.uri == dsms[0].uri
