This is an automated email from the ASF dual-hosted git repository.
uranusjr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 4b08a0fc67 Refactor _register_dataset_changes (#42343)
4b08a0fc67 is described below
commit 4b08a0fc67386bb0b81250b0521bfc3a0678955e
Author: Tzu-ping Chung <[email protected]>
AuthorDate: Tue Sep 24 18:50:43 2024 -0700
Refactor _register_dataset_changes (#42343)
---
 airflow/dag_processing/collection.py            | 37 +++++------
 airflow/datasets/manager.py                     | 77 +++++++++++++++++-----
 airflow/listeners/spec/dataset.py               |  9 ++-
 airflow/models/dataset.py                       |  6 ++
 airflow/models/taskinstance.py                  | 51 +++++++-------
 .../administration-and-deployment/listeners.rst |  1 +
 newsfragments/42343.feature.rst                 |  1 +
 newsfragments/42343.significant.rst             |  7 ++
 tests/datasets/test_manager.py                  |  7 +-
 9 files changed, 128 insertions(+), 68 deletions(-)
diff --git a/airflow/dag_processing/collection.py b/airflow/dag_processing/collection.py
index 5d54d17b87..bcac479d87 100644
--- a/airflow/dag_processing/collection.py
+++ b/airflow/dag_processing/collection.py
@@ -299,21 +299,15 @@ class DatasetModelOperation(NamedTuple):
             dm.uri: dm
             for dm in session.scalars(select(DatasetModel).where(DatasetModel.uri.in_(self.datasets)))
         }
-
-        def _resolve_dataset_addition() -> Iterator[DatasetModel]:
-            for uri, dataset in self.datasets.items():
-                try:
-                    dm = orm_datasets[uri]
-                except KeyError:
-                    dm = orm_datasets[uri] = DatasetModel.from_public(dataset)
-                    yield dm
-                else:
-                    # The orphaned flag was bulk-set to True before parsing, so we
-                    # don't need to handle rows in the db without a public entry.
-                    dm.is_orphaned = expression.false()
-                    dm.extra = dataset.extra
-
-        dataset_manager.create_datasets(list(_resolve_dataset_addition()), session=session)
+        for model in orm_datasets.values():
+            model.is_orphaned = expression.false()
+        orm_datasets.update(
+            (model.uri, model)
+            for model in dataset_manager.create_datasets(
+                [dataset for uri, dataset in self.datasets.items() if uri not in orm_datasets],
+                session=session,
+            )
+        )
         return orm_datasets
 
     def add_dataset_aliases(self, *, session: Session) -> dict[str, DatasetAliasModel]:
@@ -326,12 +320,13 @@ class DatasetModelOperation(NamedTuple):
                 select(DatasetAliasModel).where(DatasetAliasModel.name.in_(self.dataset_aliases))
             )
         }
-        for name, alias in self.dataset_aliases.items():
-            try:
-                da = orm_aliases[name]
-            except KeyError:
-                da = orm_aliases[name] = DatasetAliasModel.from_public(alias)
-                session.add(da)
+        orm_aliases.update(
+            (model.name, model)
+            for model in dataset_manager.create_dataset_aliases(
+                [alias for name, alias in self.dataset_aliases.items() if name not in orm_aliases],
+                session=session,
+            )
+        )
         return orm_aliases
 
     def add_dag_dataset_references(
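
For readers following along: both hunks above replace a hand-rolled try/except
upsert loop with the same fetch-existing / create-missing shape. A minimal,
generic sketch of that pattern, with purely illustrative names (not Airflow
API):

    from typing import Callable, TypeVar

    K = TypeVar("K")
    V = TypeVar("V")

    def upsert_by_key(
        wanted: dict[K, object],
        existing: dict[K, V],
        create: Callable[[list[object]], list[V]],
        key_of: Callable[[V], K],
    ) -> dict[K, V]:
        # Create models only for keys we did not fetch, then merge them in.
        created = create([value for key, value in wanted.items() if key not in existing])
        existing.update((key_of(model), model) for model in created)
        return existing
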
diff --git a/airflow/datasets/manager.py b/airflow/datasets/manager.py
index 19f6913fff..c5ebb2e6d7 100644
--- a/airflow/datasets/manager.py
+++ b/airflow/datasets/manager.py
@@ -17,7 +17,7 @@
 # under the License.
 from __future__ import annotations
 
-from collections.abc import Iterable
+from collections.abc import Collection, Iterable
 from typing import TYPE_CHECKING
 
 from sqlalchemy import exc, select
@@ -25,7 +25,6 @@ from sqlalchemy.orm import joinedload
 from airflow.api_internal.internal_api_call import internal_api_call
 from airflow.configuration import conf
-from airflow.datasets import Dataset
 from airflow.listeners.listener import get_listener_manager
 from airflow.models.dagbag import DagPriorityParsingRequest
 from airflow.models.dataset import (
@@ -43,6 +42,7 @@ from airflow.utils.session import NEW_SESSION, provide_session
 if TYPE_CHECKING:
     from sqlalchemy.orm.session import Session
 
+    from airflow.datasets import Dataset, DatasetAlias
     from airflow.models.dag import DagModel
     from airflow.models.taskinstance import TaskInstance
@@ -58,12 +58,51 @@ class DatasetManager(LoggingMixin):
     def __init__(self, **kwargs):
         super().__init__(**kwargs)
 
-    def create_datasets(self, dataset_models: list[DatasetModel], session: Session) -> None:
+    def create_datasets(self, datasets: list[Dataset], *, session: Session) -> list[DatasetModel]:
         """Create new datasets."""
-        for dataset_model in dataset_models:
-            session.add(dataset_model)
-        for dataset_model in dataset_models:
-            self.notify_dataset_created(dataset=Dataset(uri=dataset_model.uri, extra=dataset_model.extra))
+
+        def _add_one(dataset: Dataset) -> DatasetModel:
+            model = DatasetModel.from_public(dataset)
+            session.add(model)
+            self.notify_dataset_created(dataset=dataset)
+            return model
+
+        return [_add_one(d) for d in datasets]
+
+    def create_dataset_aliases(
+        self,
+        dataset_aliases: list[DatasetAlias],
+        *,
+        session: Session,
+    ) -> list[DatasetAliasModel]:
+        """Create new dataset aliases."""
+
+        def _add_one(dataset_alias: DatasetAlias) -> DatasetAliasModel:
+            model = DatasetAliasModel.from_public(dataset_alias)
+            session.add(model)
+            self.notify_dataset_alias_created(dataset_alias=dataset_alias)
+            return model
+
+        return [_add_one(a) for a in dataset_aliases]
+
+    @classmethod
+    def _add_dataset_alias_association(
+        cls,
+        alias_names: Collection[str],
+        dataset: DatasetModel,
+        *,
+        session: Session,
+    ) -> None:
+        already_related = {m.name for m in dataset.aliases}
+        existing_aliases = {
+            m.name: m
+            for m in session.scalars(select(DatasetAliasModel).where(DatasetAliasModel.name.in_(alias_names)))
+        }
+        dataset.aliases.extend(
+            existing_aliases.get(name, DatasetAliasModel(name=name))
+            for name in alias_names
+            if name not in already_related
+        )
 
     @classmethod
     @internal_api_call
@@ -74,8 +113,9 @@ class DatasetManager(LoggingMixin):
         task_instance: TaskInstance | None = None,
         dataset: Dataset,
         extra=None,
-        session: Session = NEW_SESSION,
+        aliases: Collection[DatasetAlias] = (),
         source_alias_names: Iterable[str] | None = None,
+        session: Session = NEW_SESSION,
         **kwargs,
     ) -> DatasetEvent | None:
         """
@@ -88,24 +128,27 @@ class DatasetManager(LoggingMixin):
         dataset_model = session.scalar(
             select(DatasetModel)
             .where(DatasetModel.uri == dataset.uri)
-            .options(joinedload(DatasetModel.consuming_dags).joinedload(DagScheduleDatasetReference.dag))
+            .options(
+                joinedload(DatasetModel.aliases),
+                joinedload(DatasetModel.consuming_dags).joinedload(DagScheduleDatasetReference.dag),
+            )
         )
         if not dataset_model:
             cls.logger().warning("DatasetModel %s not found", dataset)
             return None
 
+        cls._add_dataset_alias_association({alias.name for alias in aliases}, dataset_model, session=session)
+
         event_kwargs = {
             "dataset_id": dataset_model.id,
             "extra": extra,
         }
         if task_instance:
             event_kwargs.update(
-                {
-                    "source_task_id": task_instance.task_id,
-                    "source_dag_id": task_instance.dag_id,
-                    "source_run_id": task_instance.run_id,
-                    "source_map_index": task_instance.map_index,
-                }
+                source_task_id=task_instance.task_id,
+                source_dag_id=task_instance.dag_id,
+                source_run_id=task_instance.run_id,
+                source_map_index=task_instance.map_index,
             )
 
         dataset_event = DatasetEvent(**event_kwargs)
@@ -155,6 +198,10 @@ class DatasetManager(LoggingMixin):
"""Run applicable notification actions when a dataset is created."""
get_listener_manager().hook.on_dataset_created(dataset=dataset)
+ def notify_dataset_alias_created(self, dataset_alias: DatasetAlias):
+ """Run applicable notification actions when a dataset alias is
created."""
+
get_listener_manager().hook.on_dataset_alias_created(dataset_alias=dataset_alias)
+
@classmethod
def notify_dataset_changed(cls, dataset: Dataset):
"""Run applicable notification actions when a dataset is changed."""
diff --git a/airflow/listeners/spec/dataset.py b/airflow/listeners/spec/dataset.py
index 214ddad3ff..eee1a10dd7 100644
--- a/airflow/listeners/spec/dataset.py
+++ b/airflow/listeners/spec/dataset.py
@@ -22,7 +22,7 @@ from typing import TYPE_CHECKING
 from pluggy import HookspecMarker
 
 if TYPE_CHECKING:
-    from airflow.datasets import Dataset
+    from airflow.datasets import Dataset, DatasetAlias
 
 
 hookspec = HookspecMarker("airflow")
@@ -34,6 +34,13 @@ def on_dataset_created(
"""Execute when a new dataset is created."""
+@hookspec
+def on_dataset_alias_created(
+ dataset_alias: DatasetAlias,
+):
+ """Execute when a new dataset alias is created."""
+
+
@hookspec
def on_dataset_changed(
dataset: Dataset,
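
A minimal listener sketch for the new hookspec (a hypothetical plugin module;
hookimpl is the standard Airflow listener marker, and the listener is
registered the same way as the existing dataset listeners):

    from airflow.listeners import hookimpl

    @hookimpl
    def on_dataset_alias_created(dataset_alias):
        # Fired once per alias passed to DatasetManager.create_dataset_aliases.
        print(f"dataset alias created: {dataset_alias.name}")
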
diff --git a/airflow/models/dataset.py b/airflow/models/dataset.py
index 5033da48a3..489d6b68a6 100644
--- a/airflow/models/dataset.py
+++ b/airflow/models/dataset.py
@@ -138,6 +138,9 @@ class DatasetAliasModel(Base):
         else:
             return NotImplemented
 
+    def to_public(self) -> DatasetAlias:
+        return DatasetAlias(name=self.name)
+
 
 class DatasetModel(Base):
     """
@@ -200,6 +203,9 @@ class DatasetModel(Base):
     def __repr__(self):
         return f"{self.__class__.__name__}(uri={self.uri!r}, extra={self.extra!r})"
 
+    def to_public(self) -> Dataset:
+        return Dataset(uri=self.uri, extra=self.extra)
+
 
 class DagScheduleDatasetAliasReference(Base):
     """References from a DAG to a dataset alias of which it is a consumer."""
diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index 954e5ed4d0..d3300207ab 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -31,7 +31,7 @@ from collections import defaultdict
 from contextlib import nullcontext
 from datetime import timedelta
 from enum import Enum
-from typing import TYPE_CHECKING, Any, Callable, Collection, Generator, Iterable, Mapping, Tuple
+from typing import TYPE_CHECKING, Any, Callable, Collection, Dict, Generator, Iterable, Mapping, Tuple
 from urllib.parse import quote
 
 import dill
@@ -89,7 +89,7 @@ from airflow.exceptions import (
 from airflow.listeners.listener import get_listener_manager
 from airflow.models.base import Base, StringID, TaskInstanceDependencies, _sentinel
 from airflow.models.dagbag import DagBag
-from airflow.models.dataset import DatasetAliasModel, DatasetModel
+from airflow.models.dataset import DatasetModel
 from airflow.models.log import Log
 from airflow.models.param import process_params
 from airflow.models.renderedtifields import get_serialized_template_fields
@@ -2893,7 +2893,7 @@ class TaskInstance(Base, LoggingMixin):
         # One task only triggers one dataset event for each dataset with the same extra.
         # This tuple[dataset uri, extra] to sets alias names mapping is used to find whether
         # there're datasets with same uri but different extra that we need to emit more than one dataset events.
-        dataset_tuple_to_alias_names_mapping: dict[tuple[str, frozenset], set[str]] = defaultdict(set)
+        dataset_alias_names: dict[tuple[str, frozenset], set[str]] = defaultdict(set)
         for obj in self.task.outlets or []:
             self.log.debug("outlet obj %s", obj)
             # Lineage can have other types of objects besides datasets
@@ -2908,33 +2908,27 @@ class TaskInstance(Base, LoggingMixin):
                 for dataset_alias_event in events[obj].dataset_alias_events:
                     dataset_alias_name = dataset_alias_event["source_alias_name"]
                     dataset_uri = dataset_alias_event["dest_dataset_uri"]
-                    extra = dataset_alias_event["extra"]
-                    frozen_extra = frozenset(extra.items())
+                    frozen_extra = frozenset(dataset_alias_event["extra"].items())
+                    dataset_alias_names[(dataset_uri, frozen_extra)].add(dataset_alias_name)
 
-                    dataset_tuple_to_alias_names_mapping[(dataset_uri, frozen_extra)].add(dataset_alias_name)
+        class _DatasetModelCache(Dict[str, DatasetModel]):
+            log = self.log
 
-        dataset_objs_cache: dict[str, DatasetModel] = {}
-        for (uri, extra_items), alias_names in dataset_tuple_to_alias_names_mapping.items():
-            if uri not in dataset_objs_cache:
-                dataset_obj = session.scalar(select(DatasetModel).where(DatasetModel.uri == uri).limit(1))
-                dataset_objs_cache[uri] = dataset_obj
-            else:
-                dataset_obj = dataset_objs_cache[uri]
-
-            if not dataset_obj:
-                dataset_obj = DatasetModel(uri=uri)
-                dataset_manager.create_datasets(dataset_models=[dataset_obj], session=session)
-                self.log.warning("Created a new %r as it did not exist.", dataset_obj)
+            def __missing__(self, key: str) -> DatasetModel:
+                (dataset_obj,) = dataset_manager.create_datasets([Dataset(uri=key)], session=session)
                 session.flush()
-                dataset_objs_cache[uri] = dataset_obj
-
-            for alias in alias_names:
-                alias_obj = session.scalar(
-                    select(DatasetAliasModel).where(DatasetAliasModel.name == alias).limit(1)
-                )
-                dataset_obj.aliases.append(alias_obj)
+                self.log.warning("Created a new %r as it did not exist.", dataset_obj)
+                self[key] = dataset_obj
+                return dataset_obj
 
-            extra = {k: v for k, v in extra_items}
+        dataset_objs_cache = _DatasetModelCache(
+            (dataset_obj.uri, dataset_obj)
+            for dataset_obj in session.scalars(
+                select(DatasetModel).where(DatasetModel.uri.in_(uri for uri, _ in dataset_alias_names))
+            )
+        )
+        for (uri, extra_items), alias_names in dataset_alias_names.items():
+            dataset_obj = dataset_objs_cache[uri]
             self.log.info(
                 'Creating event for %r through aliases "%s"',
                 dataset_obj,
@@ -2942,8 +2936,9 @@ class TaskInstance(Base, LoggingMixin):
             )
             dataset_manager.register_dataset_change(
                 task_instance=self,
-                dataset=dataset_obj,
-                extra=extra,
+                dataset=dataset_obj.to_public(),
+                aliases=[DatasetAlias(name) for name in alias_names],
+                extra=dict(extra_items),
                 session=session,
                 source_alias_names=alias_names,
             )
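
The _DatasetModelCache above leans on dict.__missing__, which runs only when a
key lookup falls through; a generic, Airflow-free illustration of the pattern:

    class CreateOnMiss(dict):
        def __missing__(self, key):
            # Stand-in for creating and flushing a new DatasetModel.
            value = self[key] = f"created:{key}"
            return value

    cache = CreateOnMiss({"known": "preloaded"})
    assert cache["known"] == "preloaded"  # plain hit, __missing__ not called
    assert cache["new"] == "created:new"  # miss: created once, then cached
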
diff --git a/docs/apache-airflow/administration-and-deployment/listeners.rst b/docs/apache-airflow/administration-and-deployment/listeners.rst
index 34909e225a..4926b12ed6 100644
--- a/docs/apache-airflow/administration-and-deployment/listeners.rst
+++ b/docs/apache-airflow/administration-and-deployment/listeners.rst
@@ -95,6 +95,7 @@ Dataset Events
 --------------
 
 - ``on_dataset_created``
+- ``on_dataset_alias_created``
 - ``on_dataset_changed``
 
 Dataset events occur when Dataset management operations are run.
diff --git a/newsfragments/42343.feature.rst b/newsfragments/42343.feature.rst
new file mode 100644
index 0000000000..8a7cdf335a
--- /dev/null
+++ b/newsfragments/42343.feature.rst
@@ -0,0 +1 @@
+New function ``create_dataset_aliases`` added to DatasetManager for DatasetAlias creation.
diff --git a/newsfragments/42343.significant.rst b/newsfragments/42343.significant.rst
new file mode 100644
index 0000000000..d9e1ba6b12
--- /dev/null
+++ b/newsfragments/42343.significant.rst
@@ -0,0 +1,7 @@
+``DatasetManager.create_datasets`` now takes ``Dataset`` objects
+
+This function previously accepted a list of ``DatasetModel`` objects. It now
+receives ``Dataset`` objects instead; a list of ``DatasetModel`` objects is
+created inside the function and returned.
+
+Also, the ``session`` argument is now keyword-only.
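
For callers migrating across this change, a hedged before/after sketch
(session handling is illustrative):

    from airflow.datasets import Dataset
    from airflow.datasets.manager import DatasetManager
    from airflow.utils.session import create_session

    manager = DatasetManager()
    with create_session() as session:
        # Before: manager.create_datasets([DatasetModel(uri="s3://b/k")], session)
        # After: pass public Dataset objects; the created DatasetModel rows
        # are returned, and session must be passed as a keyword.
        models = manager.create_datasets([Dataset(uri="s3://b/k")], session=session)
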
diff --git a/tests/datasets/test_manager.py b/tests/datasets/test_manager.py
index 1e7b4fda40..d3013aef60 100644
--- a/tests/datasets/test_manager.py
+++ b/tests/datasets/test_manager.py
@@ -169,10 +169,11 @@ class TestDatasetManager:
         dataset_listener.clear()
         get_listener_manager().add_listener(dataset_listener)
 
-        dsm = DatasetModel(uri="test_dataset_uri_3")
+        ds = Dataset(uri="test_dataset_uri_3")
 
-        dsem.create_datasets([dsm], session)
+        dsms = dsem.create_datasets([ds], session=session)
 
         # Ensure the listener was notified
         assert len(dataset_listener.created) == 1
-        assert dataset_listener.created[0].uri == dsm.uri
+        assert len(dsms) == 1
+        assert dataset_listener.created[0].uri == ds.uri == dsms[0].uri