This is an automated email from the ASF dual-hosted git repository.

kaxilnaik pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 9439111e73 Add Listener hooks for Datasets (#34418)
9439111e73 is described below

commit 9439111e739e24f0e3751350186b0e2130d2c821
Author: Pasha Yermalovich <142800033+yermalov-h...@users.noreply.github.com>
AuthorDate: Mon Nov 13 15:15:10 2023 +0100

    Add Listener hooks for Datasets (#34418)
    
    This PR creates listener hooks for the following Dataset events:
    * on_dataset_created
    * on_dataset_changed
    
    closes: #34327
---
 airflow/datasets/manager.py                        | 23 ++++++++-
 airflow/listeners/listener.py                      |  3 +-
 airflow/listeners/spec/dataset.py                  | 41 +++++++++++++++
 airflow/models/dag.py                              | 16 +++---
 .../administration-and-deployment/listeners.rst    |  7 +++
 tests/datasets/test_manager.py                     | 35 +++++++++++++
 tests/listeners/dataset_listener.py                | 45 ++++++++++++++++
 tests/listeners/test_dataset_listener.py           | 60 ++++++++++++++++++++++
 8 files changed, 221 insertions(+), 9 deletions(-)
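
For orientation, a minimal listener module consuming the two new hooks might look like the sketch below. The hook names and the hookimpl marker come from this patch; the module itself and the print statements are illustrative only (the patch ships a similar test listener in tests/listeners/dataset_listener.py):

    from __future__ import annotations

    from airflow.datasets import Dataset
    from airflow.listeners import hookimpl


    @hookimpl
    def on_dataset_created(dataset: Dataset):
        # Fired once per newly stored dataset (see DatasetManager.create_datasets below).
        print(f"dataset created: {dataset.uri}")


    @hookimpl
    def on_dataset_changed(dataset: Dataset):
        # Fired when a dataset change is registered (see DatasetManager.register_dataset_change below).
        print(f"dataset changed: {dataset.uri}")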

diff --git a/airflow/datasets/manager.py b/airflow/datasets/manager.py
index 8714ba0658..08871c9f65 100644
--- a/airflow/datasets/manager.py
+++ b/airflow/datasets/manager.py
@@ -22,6 +22,8 @@ from typing import TYPE_CHECKING
 from sqlalchemy import exc, select
 
 from airflow.configuration import conf
+from airflow.datasets import Dataset
+from airflow.listeners.listener import get_listener_manager
 from airflow.models.dataset import DatasetDagRunQueue, DatasetEvent, DatasetModel
 from airflow.stats import Stats
 from airflow.utils.log.logging_mixin import LoggingMixin
@@ -29,7 +31,6 @@ from airflow.utils.log.logging_mixin import LoggingMixin
 if TYPE_CHECKING:
     from sqlalchemy.orm.session import Session
 
-    from airflow.datasets import Dataset
     from airflow.models.taskinstance import TaskInstance
 
 
@@ -44,6 +45,15 @@ class DatasetManager(LoggingMixin):
     def __init__(self, **kwargs):
         super().__init__(**kwargs)
 
+    def create_datasets(self, dataset_models: list[DatasetModel], session: Session) -> None:
+        """Create new datasets."""
+        for dataset_model in dataset_models:
+            session.add(dataset_model)
+        session.flush()
+
+        for dataset_model in dataset_models:
+            self.notify_dataset_created(dataset=Dataset(uri=dataset_model.uri, extra=dataset_model.extra))
+
     def register_dataset_change(
        self, *, task_instance: TaskInstance, dataset: Dataset, extra=None, session: Session, **kwargs
     ) -> None:
@@ -68,11 +78,22 @@ class DatasetManager(LoggingMixin):
             )
         )
         session.flush()
+
+        self.notify_dataset_changed(dataset=dataset)
+
         Stats.incr("dataset.updates")
         if dataset_model.consuming_dags:
             self._queue_dagruns(dataset_model, session)
         session.flush()
 
+    def notify_dataset_created(self, dataset: Dataset):
+        """Run applicable notification actions when a dataset is created."""
+        get_listener_manager().hook.on_dataset_created(dataset=dataset)
+
+    def notify_dataset_changed(self, dataset: Dataset):
+        """Run applicable notification actions when a dataset is changed."""
+        get_listener_manager().hook.on_dataset_changed(dataset=dataset)
+
     def _queue_dagruns(self, dataset: DatasetModel, session: Session) -> None:
         # Possible race condition: if multiple dags or multiple (usually
         # mapped) tasks update the same dataset, this can fail with a unique
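
In this patch, create_datasets is called from DAG serialization (see the dag.py hunk below), which imports the module-level dataset_manager instance. A standalone sketch of the same call, assuming an active SQLAlchemy session is available as `session`:

    from airflow.datasets.manager import dataset_manager
    from airflow.models.dataset import DatasetModel

    # `session` is assumed to be an open SQLAlchemy session managed by the caller.
    dataset_manager.create_datasets(
        dataset_models=[DatasetModel(uri="s3://bucket/object")],
        session=session,
    )
    # Each stored model triggers one on_dataset_created notification.
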
diff --git a/airflow/listeners/listener.py b/airflow/listeners/listener.py
index eb738c3e91..d7944aa4eb 100644
--- a/airflow/listeners/listener.py
+++ b/airflow/listeners/listener.py
@@ -37,11 +37,12 @@ class ListenerManager:
     """Manage listener registration and provides hook property for calling 
them."""
 
     def __init__(self):
-        from airflow.listeners.spec import dagrun, lifecycle, taskinstance
+        from airflow.listeners.spec import dagrun, dataset, lifecycle, taskinstance
 
         self.pm = pluggy.PluginManager("airflow")
         self.pm.add_hookspecs(lifecycle)
         self.pm.add_hookspecs(dagrun)
+        self.pm.add_hookspecs(dataset)
         self.pm.add_hookspecs(taskinstance)
 
     @property
diff --git a/airflow/listeners/spec/dataset.py b/airflow/listeners/spec/dataset.py
new file mode 100644
index 0000000000..214ddad3ff
--- /dev/null
+++ b/airflow/listeners/spec/dataset.py
@@ -0,0 +1,41 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+from pluggy import HookspecMarker
+
+if TYPE_CHECKING:
+    from airflow.datasets import Dataset
+
+hookspec = HookspecMarker("airflow")
+
+
+@hookspec
+def on_dataset_created(
+    dataset: Dataset,
+):
+    """Execute when a new dataset is created."""
+
+
+@hookspec
+def on_dataset_changed(
+    dataset: Dataset,
+):
+    """Execute when dataset change is registered."""
diff --git a/airflow/models/dag.py b/airflow/models/dag.py
index 9a09b706e5..428f79e14e 100644
--- a/airflow/models/dag.py
+++ b/airflow/models/dag.py
@@ -78,6 +78,7 @@ import airflow.templates
 from airflow import settings, utils
 from airflow.api_internal.internal_api_call import internal_api_call
 from airflow.configuration import conf as airflow_conf, secrets_backend_list
+from airflow.datasets.manager import dataset_manager
 from airflow.exceptions import (
     AirflowDagInconsistent,
     AirflowException,
@@ -3137,8 +3138,8 @@ class DAG(LoggingMixin):
         dag_references = collections.defaultdict(set)
         outlet_references = collections.defaultdict(set)
         # We can't use a set here as we want to preserve order
-        outlet_datasets: dict[Dataset, None] = {}
-        input_datasets: dict[Dataset, None] = {}
+        outlet_datasets: dict[DatasetModel, None] = {}
+        input_datasets: dict[DatasetModel, None] = {}
 
         # here we go through dags and tasks to check for dataset references
         # if there are now None and previously there were some, we delete them
@@ -3171,7 +3172,8 @@ class DAG(LoggingMixin):
         all_datasets.update(input_datasets)
 
         # store datasets
-        stored_datasets = {}
+        stored_datasets: dict[str, DatasetModel] = {}
+        new_datasets: list[DatasetModel] = []
         for dataset in all_datasets:
             stored_dataset = session.scalar(
                select(DatasetModel).where(DatasetModel.uri == dataset.uri).limit(1)
@@ -3183,11 +3185,11 @@ class DAG(LoggingMixin):
                 stored_dataset.is_orphaned = expression.false()
                 stored_datasets[stored_dataset.uri] = stored_dataset
             else:
-                session.add(dataset)
-                stored_datasets[dataset.uri] = dataset
-
-        session.flush()  # this is required to ensure each dataset has its PK loaded
+                new_datasets.append(dataset)
+        dataset_manager.create_datasets(dataset_models=new_datasets, session=session)
+        stored_datasets.update({dataset.uri: dataset for dataset in new_datasets})
 
+        del new_datasets
         del all_datasets
 
         # reconcile dag-schedule-on-dataset references
diff --git a/docs/apache-airflow/administration-and-deployment/listeners.rst b/docs/apache-airflow/administration-and-deployment/listeners.rst
index 4182d135a1..0672e07779 100644
--- a/docs/apache-airflow/administration-and-deployment/listeners.rst
+++ b/docs/apache-airflow/administration-and-deployment/listeners.rst
@@ -50,6 +50,13 @@ TaskInstance State Change Events
 TaskInstance state change events occur when a ``TaskInstance`` changes state.
 You can use these events to react to ``LocalTaskJob`` state changes.
 
+Dataset Events
+--------------
+
+- ``on_dataset_created``
+- ``on_dataset_changed``
+
+Dataset events occur when Dataset management operations are run.
 
 Usage
 -----
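
The Usage section referenced above covers registration in full; for completeness, a sketch of enabling such a listener through the plugin interface, assuming the @hookimpl functions live in an importable module named my_dataset_listener (a hypothetical name):

    from airflow.plugins_manager import AirflowPlugin

    import my_dataset_listener  # hypothetical module defining the @hookimpl functions


    class DatasetListenerPlugin(AirflowPlugin):
        name = "dataset_listener_plugin"
        listeners = [my_dataset_listener]
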
diff --git a/tests/datasets/test_manager.py b/tests/datasets/test_manager.py
index 19b6b1ed45..514ed8877a 100644
--- a/tests/datasets/test_manager.py
+++ b/tests/datasets/test_manager.py
@@ -24,8 +24,10 @@ import pytest
 
 from airflow.datasets import Dataset
 from airflow.datasets.manager import DatasetManager
+from airflow.listeners.listener import get_listener_manager
 from airflow.models.dag import DagModel
 from airflow.models.dataset import DagScheduleDatasetReference, DatasetDagRunQueue, DatasetEvent, DatasetModel
+from tests.listeners import dataset_listener
 
 pytestmark = pytest.mark.db_test
 
@@ -96,3 +98,36 @@ class TestDatasetManager:
         # Ensure we've created a dataset
         assert session.query(DatasetEvent).filter_by(dataset_id=dsm.id).count() == 1
         assert session.query(DatasetDagRunQueue).count() == 0
+
+    def test_register_dataset_change_notifies_dataset_listener(self, session, mock_task_instance):
+        dsem = DatasetManager()
+        dataset_listener.clear()
+        get_listener_manager().add_listener(dataset_listener)
+
+        ds = Dataset(uri="test_dataset_uri")
+        dag1 = DagModel(dag_id="dag1")
+        session.add_all([dag1])
+
+        dsm = DatasetModel(uri="test_dataset_uri")
+        session.add(dsm)
+        dsm.consuming_dags = [DagScheduleDatasetReference(dag_id=dag1.dag_id)]
+        session.flush()
+
+        dsem.register_dataset_change(task_instance=mock_task_instance, dataset=ds, session=session)
+
+        # Ensure the listener was notified
+        assert len(dataset_listener.changed) == 1
+        assert dataset_listener.changed[0].uri == ds.uri
+
+    def test_create_datasets_notifies_dataset_listener(self, session):
+        dsem = DatasetManager()
+        dataset_listener.clear()
+        get_listener_manager().add_listener(dataset_listener)
+
+        dsm = DatasetModel(uri="test_dataset_uri")
+
+        dsem.create_datasets([dsm], session)
+
+        # Ensure the listener was notified
+        assert len(dataset_listener.created) == 1
+        assert dataset_listener.created[0].uri == dsm.uri
diff --git a/tests/listeners/dataset_listener.py b/tests/listeners/dataset_listener.py
new file mode 100644
index 0000000000..0e4b768c69
--- /dev/null
+++ b/tests/listeners/dataset_listener.py
@@ -0,0 +1,45 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import copy
+import typing
+
+from airflow.listeners import hookimpl
+
+if typing.TYPE_CHECKING:
+    from airflow.datasets import Dataset
+
+
+changed: list[Dataset] = []
+created: list[Dataset] = []
+
+
+@hookimpl
+def on_dataset_changed(dataset):
+    changed.append(copy.deepcopy(dataset))
+
+
+@hookimpl
+def on_dataset_created(dataset):
+    created.append(copy.deepcopy(dataset))
+
+
+def clear():
+    global changed, created
+    changed, created = [], []
diff --git a/tests/listeners/test_dataset_listener.py b/tests/listeners/test_dataset_listener.py
new file mode 100644
index 0000000000..d17f079e2d
--- /dev/null
+++ b/tests/listeners/test_dataset_listener.py
@@ -0,0 +1,60 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import pytest
+
+from airflow.datasets import Dataset
+from airflow.listeners.listener import get_listener_manager
+from airflow.models.dataset import DatasetModel
+from airflow.operators.empty import EmptyOperator
+from airflow.utils.session import provide_session
+from tests.listeners import dataset_listener
+
+
+@pytest.fixture(autouse=True)
+def clean_listener_manager():
+    lm = get_listener_manager()
+    lm.clear()
+    lm.add_listener(dataset_listener)
+    yield
+    lm = get_listener_manager()
+    lm.clear()
+    dataset_listener.clear()
+
+
+@pytest.mark.db_test
+@provide_session
+def test_dataset_listener_on_dataset_changed_gets_calls(create_task_instance_of_operator, session):
+    dataset_uri = "test_dataset_uri"
+    ds = Dataset(uri=dataset_uri)
+    ds_model = DatasetModel(uri=dataset_uri)
+    session.add(ds_model)
+
+    session.flush()
+
+    ti = create_task_instance_of_operator(
+        operator_class=EmptyOperator,
+        dag_id="producing_dag",
+        task_id="test_task",
+        session=session,
+        outlets=[ds],
+    )
+    ti.run()
+
+    assert len(dataset_listener.changed) == 1
+    assert dataset_listener.changed[0].uri == dataset_uri
