This is an automated email from the ASF dual-hosted git repository. kaxilnaik pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push: new 9439111e73 Add Listener hooks for Datasets (#34418) 9439111e73 is described below commit 9439111e739e24f0e3751350186b0e2130d2c821 Author: Pasha Yermalovich <142800033+yermalov-h...@users.noreply.github.com> AuthorDate: Mon Nov 13 15:15:10 2023 +0100 Add Listener hooks for Datasets (#34418) This PR creates listener hooks for the following Dataset events * on_dataset_created * on_dataset_changed closes: #34327 --- airflow/datasets/manager.py | 23 ++++++++- airflow/listeners/listener.py | 3 +- airflow/listeners/spec/dataset.py | 41 +++++++++++++++ airflow/models/dag.py | 16 +++--- .../administration-and-deployment/listeners.rst | 7 +++ tests/datasets/test_manager.py | 35 +++++++++++++ tests/listeners/dataset_listener.py | 45 ++++++++++++++++ tests/listeners/test_dataset_listener.py | 60 ++++++++++++++++++++++ 8 files changed, 221 insertions(+), 9 deletions(-) diff --git a/airflow/datasets/manager.py b/airflow/datasets/manager.py index 8714ba0658..08871c9f65 100644 --- a/airflow/datasets/manager.py +++ b/airflow/datasets/manager.py @@ -22,6 +22,8 @@ from typing import TYPE_CHECKING from sqlalchemy import exc, select from airflow.configuration import conf +from airflow.datasets import Dataset +from airflow.listeners.listener import get_listener_manager from airflow.models.dataset import DatasetDagRunQueue, DatasetEvent, DatasetModel from airflow.stats import Stats from airflow.utils.log.logging_mixin import LoggingMixin @@ -29,7 +31,6 @@ from airflow.utils.log.logging_mixin import LoggingMixin if TYPE_CHECKING: from sqlalchemy.orm.session import Session - from airflow.datasets import Dataset from airflow.models.taskinstance import TaskInstance @@ -44,6 +45,15 @@ class DatasetManager(LoggingMixin): def __init__(self, **kwargs): super().__init__(**kwargs) + def create_datasets(self, dataset_models: list[DatasetModel], session: Session) -> None: + """Create new datasets.""" + for dataset_model in dataset_models: + session.add(dataset_model) + session.flush() + + for dataset_model in dataset_models: + self.notify_dataset_created(dataset=Dataset(uri=dataset_model.uri, extra=dataset_model.extra)) + def register_dataset_change( self, *, task_instance: TaskInstance, dataset: Dataset, extra=None, session: Session, **kwargs ) -> None: @@ -68,11 +78,22 @@ class DatasetManager(LoggingMixin): ) ) session.flush() + + self.notify_dataset_changed(dataset=dataset) + Stats.incr("dataset.updates") if dataset_model.consuming_dags: self._queue_dagruns(dataset_model, session) session.flush() + def notify_dataset_created(self, dataset: Dataset): + """Run applicable notification actions when a dataset is created.""" + get_listener_manager().hook.on_dataset_created(dataset=dataset) + + def notify_dataset_changed(self, dataset: Dataset): + """Run applicable notification actions when a dataset is changed.""" + get_listener_manager().hook.on_dataset_changed(dataset=dataset) + def _queue_dagruns(self, dataset: DatasetModel, session: Session) -> None: # Possible race condition: if multiple dags or multiple (usually # mapped) tasks update the same dataset, this can fail with a unique diff --git a/airflow/listeners/listener.py b/airflow/listeners/listener.py index eb738c3e91..d7944aa4eb 100644 --- a/airflow/listeners/listener.py +++ b/airflow/listeners/listener.py @@ -37,11 +37,12 @@ class ListenerManager: """Manage listener registration and provides hook property for calling them.""" def __init__(self): - from airflow.listeners.spec import dagrun, lifecycle, taskinstance + from airflow.listeners.spec import dagrun, dataset, lifecycle, taskinstance self.pm = pluggy.PluginManager("airflow") self.pm.add_hookspecs(lifecycle) self.pm.add_hookspecs(dagrun) + self.pm.add_hookspecs(dataset) self.pm.add_hookspecs(taskinstance) @property diff --git a/airflow/listeners/spec/dataset.py b/airflow/listeners/spec/dataset.py new file mode 100644 index 0000000000..214ddad3ff --- /dev/null +++ b/airflow/listeners/spec/dataset.py @@ -0,0 +1,41 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from typing import TYPE_CHECKING + +from pluggy import HookspecMarker + +if TYPE_CHECKING: + from airflow.datasets import Dataset + +hookspec = HookspecMarker("airflow") + + +@hookspec +def on_dataset_created( + dataset: Dataset, +): + """Execute when a new dataset is created.""" + + +@hookspec +def on_dataset_changed( + dataset: Dataset, +): + """Execute when dataset change is registered.""" diff --git a/airflow/models/dag.py b/airflow/models/dag.py index 9a09b706e5..428f79e14e 100644 --- a/airflow/models/dag.py +++ b/airflow/models/dag.py @@ -78,6 +78,7 @@ import airflow.templates from airflow import settings, utils from airflow.api_internal.internal_api_call import internal_api_call from airflow.configuration import conf as airflow_conf, secrets_backend_list +from airflow.datasets.manager import dataset_manager from airflow.exceptions import ( AirflowDagInconsistent, AirflowException, @@ -3137,8 +3138,8 @@ class DAG(LoggingMixin): dag_references = collections.defaultdict(set) outlet_references = collections.defaultdict(set) # We can't use a set here as we want to preserve order - outlet_datasets: dict[Dataset, None] = {} - input_datasets: dict[Dataset, None] = {} + outlet_datasets: dict[DatasetModel, None] = {} + input_datasets: dict[DatasetModel, None] = {} # here we go through dags and tasks to check for dataset references # if there are now None and previously there were some, we delete them @@ -3171,7 +3172,8 @@ class DAG(LoggingMixin): all_datasets.update(input_datasets) # store datasets - stored_datasets = {} + stored_datasets: dict[str, DatasetModel] = {} + new_datasets: list[DatasetModel] = [] for dataset in all_datasets: stored_dataset = session.scalar( select(DatasetModel).where(DatasetModel.uri == dataset.uri).limit(1) @@ -3183,11 +3185,11 @@ class DAG(LoggingMixin): stored_dataset.is_orphaned = expression.false() stored_datasets[stored_dataset.uri] = stored_dataset else: - session.add(dataset) - stored_datasets[dataset.uri] = dataset - - session.flush() # this is required to ensure each dataset has its PK loaded + new_datasets.append(dataset) + dataset_manager.create_datasets(dataset_models=new_datasets, session=session) + stored_datasets.update({dataset.uri: dataset for dataset in new_datasets}) + del new_datasets del all_datasets # reconcile dag-schedule-on-dataset references diff --git a/docs/apache-airflow/administration-and-deployment/listeners.rst b/docs/apache-airflow/administration-and-deployment/listeners.rst index 4182d135a1..0672e07779 100644 --- a/docs/apache-airflow/administration-and-deployment/listeners.rst +++ b/docs/apache-airflow/administration-and-deployment/listeners.rst @@ -50,6 +50,13 @@ TaskInstance State Change Events TaskInstance state change events occur when a ``TaskInstance`` changes state. You can use these events to react to ``LocalTaskJob`` state changes. +Dataset Events +-------------- + +- ``on_dataset_created`` +- ``on_dataset_changed`` + +Dataset events occur when Dataset management operations are run. Usage ----- diff --git a/tests/datasets/test_manager.py b/tests/datasets/test_manager.py index 19b6b1ed45..514ed8877a 100644 --- a/tests/datasets/test_manager.py +++ b/tests/datasets/test_manager.py @@ -24,8 +24,10 @@ import pytest from airflow.datasets import Dataset from airflow.datasets.manager import DatasetManager +from airflow.listeners.listener import get_listener_manager from airflow.models.dag import DagModel from airflow.models.dataset import DagScheduleDatasetReference, DatasetDagRunQueue, DatasetEvent, DatasetModel +from tests.listeners import dataset_listener pytestmark = pytest.mark.db_test @@ -96,3 +98,36 @@ class TestDatasetManager: # Ensure we've created a dataset assert session.query(DatasetEvent).filter_by(dataset_id=dsm.id).count() == 1 assert session.query(DatasetDagRunQueue).count() == 0 + + def test_register_dataset_change_notifies_dataset_listener(self, session, mock_task_instance): + dsem = DatasetManager() + dataset_listener.clear() + get_listener_manager().add_listener(dataset_listener) + + ds = Dataset(uri="test_dataset_uri") + dag1 = DagModel(dag_id="dag1") + session.add_all([dag1]) + + dsm = DatasetModel(uri="test_dataset_uri") + session.add(dsm) + dsm.consuming_dags = [DagScheduleDatasetReference(dag_id=dag1.dag_id)] + session.flush() + + dsem.register_dataset_change(task_instance=mock_task_instance, dataset=ds, session=session) + + # Ensure the listener was notified + assert len(dataset_listener.changed) == 1 + assert dataset_listener.changed[0].uri == ds.uri + + def test_create_datasets_notifies_dataset_listener(self, session): + dsem = DatasetManager() + dataset_listener.clear() + get_listener_manager().add_listener(dataset_listener) + + dsm = DatasetModel(uri="test_dataset_uri") + + dsem.create_datasets([dsm], session) + + # Ensure the listener was notified + assert len(dataset_listener.created) == 1 + assert dataset_listener.created[0].uri == dsm.uri diff --git a/tests/listeners/dataset_listener.py b/tests/listeners/dataset_listener.py new file mode 100644 index 0000000000..0e4b768c69 --- /dev/null +++ b/tests/listeners/dataset_listener.py @@ -0,0 +1,45 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import copy +import typing + +from airflow.listeners import hookimpl + +if typing.TYPE_CHECKING: + from airflow.datasets import Dataset + + +changed: list[Dataset] = [] +created: list[Dataset] = [] + + +@hookimpl +def on_dataset_changed(dataset): + changed.append(copy.deepcopy(dataset)) + + +@hookimpl +def on_dataset_created(dataset): + created.append(copy.deepcopy(dataset)) + + +def clear(): + global changed, created + changed, created = [], [] diff --git a/tests/listeners/test_dataset_listener.py b/tests/listeners/test_dataset_listener.py new file mode 100644 index 0000000000..d17f079e2d --- /dev/null +++ b/tests/listeners/test_dataset_listener.py @@ -0,0 +1,60 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import pytest + +from airflow.datasets import Dataset +from airflow.listeners.listener import get_listener_manager +from airflow.models.dataset import DatasetModel +from airflow.operators.empty import EmptyOperator +from airflow.utils.session import provide_session +from tests.listeners import dataset_listener + + +@pytest.fixture(autouse=True) +def clean_listener_manager(): + lm = get_listener_manager() + lm.clear() + lm.add_listener(dataset_listener) + yield + lm = get_listener_manager() + lm.clear() + dataset_listener.clear() + + +@pytest.mark.db_test +@provide_session +def test_dataset_listener_on_dataset_changed_gets_calls(create_task_instance_of_operator, session): + dataset_uri = "test_dataset_uri" + ds = Dataset(uri=dataset_uri) + ds_model = DatasetModel(uri=dataset_uri) + session.add(ds_model) + + session.flush() + + ti = create_task_instance_of_operator( + operator_class=EmptyOperator, + dag_id="producing_dag", + task_id="test_task", + session=session, + outlets=[ds], + ) + ti.run() + + assert len(dataset_listener.changed) == 1 + assert dataset_listener.changed[0].uri == dataset_uri