This is an automated email from the ASF dual-hosted git repository.

vincbeck pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 3ad8787a3e1 AIP-82 Introduce `BaseEventTrigger` as base class for 
triggers used with event driven scheduling (#46391)
3ad8787a3e1 is described below

commit 3ad8787a3e13a6733b0cf277ad3800defa74dcee
Author: Vincent <[email protected]>
AuthorDate: Thu Feb 13 09:50:05 2025 -0500

    AIP-82 Introduce `BaseEventTrigger` as base class for triggers used with 
event driven scheduling (#46391)
---
 airflow/dag_processing/collection.py               | 28 +++-------
 .../example_dags/example_asset_with_watchers.py    | 26 +++-------
 airflow/serialization/serialized_objects.py        |  3 +-
 airflow/triggers/base.py                           | 22 ++++++++
 .../airflow/providers/standard/triggers/file.py    | 60 ++++++++++++++++++++--
 .../provider_tests/standard/triggers/test_file.py  | 44 +++++++++++++++-
 .../src/airflow/sdk/definitions/asset/__init__.py  | 22 ++++++--
 tests/serialization/test_serialized_objects.py     |  4 +-
 8 files changed, 158 insertions(+), 51 deletions(-)

diff --git a/airflow/dag_processing/collection.py 
b/airflow/dag_processing/collection.py
index 7bd5b316714..e581b02478b 100644
--- a/airflow/dag_processing/collection.py
+++ b/airflow/dag_processing/collection.py
@@ -27,10 +27,9 @@ This should generally only be called by internal methods 
such as
 
 from __future__ import annotations
 
-import json
 import logging
 import traceback
-from typing import TYPE_CHECKING, Any, NamedTuple, cast
+from typing import TYPE_CHECKING, NamedTuple, cast
 
 from sqlalchemy import and_, delete, exists, func, insert, select, tuple_
 from sqlalchemy.exc import OperationalError
@@ -53,7 +52,8 @@ from airflow.models.dagwarning import DagWarningType
 from airflow.models.errors import ParseImportError
 from airflow.models.trigger import Trigger
 from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetNameRef, 
AssetUriRef
-from airflow.serialization.serialized_objects import BaseSerialization, 
SerializedAssetWatcher
+from airflow.serialization.serialized_objects import SerializedAssetWatcher
+from airflow.triggers.base import BaseEventTrigger
 from airflow.utils.retries import MAX_DB_RETRIES, run_with_db_retries
 from airflow.utils.sqlalchemy import with_row_locks
 from airflow.utils.timezone import utcnow
@@ -758,7 +758,7 @@ class AssetModelOperation(NamedTuple):
                 else []
             )
             trigger_hash_to_trigger_dict: dict[int, dict] = {
-                self._get_trigger_hash(
+                BaseEventTrigger.hash(
                     watcher.trigger["classpath"], watcher.trigger["kwargs"]
                 ): watcher.trigger
                 for watcher in asset_watchers
@@ -768,7 +768,7 @@ class AssetModelOperation(NamedTuple):
 
             asset_model = assets[name_uri]
             trigger_hash_from_asset_model: set[int] = {
-                self._get_trigger_hash(trigger.classpath, trigger.kwargs) for 
trigger in asset_model.triggers
+                BaseEventTrigger.hash(trigger.classpath, trigger.kwargs) for 
trigger in asset_model.triggers
             }
 
             # Optimization: no diff between the DB and DAG definitions, no 
update needed
@@ -796,7 +796,7 @@ class AssetModelOperation(NamedTuple):
                 for trigger_hash in trigger_hashes
             }
             orm_triggers: dict[int, Trigger] = {
-                self._get_trigger_hash(trigger.classpath, trigger.kwargs): 
trigger
+                BaseEventTrigger.hash(trigger.classpath, trigger.kwargs): 
trigger
                 for trigger in session.scalars(
                     select(Trigger).where(
                         tuple_(Trigger.classpath, 
Trigger.encrypted_kwargs).in_(all_trigger_keys)
@@ -817,7 +817,7 @@ class AssetModelOperation(NamedTuple):
             ]
             session.add_all(new_trigger_models)
             orm_triggers.update(
-                (self._get_trigger_hash(trigger.classpath, trigger.kwargs), 
trigger)
+                (BaseEventTrigger.hash(trigger.classpath, trigger.kwargs), 
trigger)
                 for trigger in new_trigger_models
             )
 
@@ -835,7 +835,7 @@ class AssetModelOperation(NamedTuple):
                 asset_model.triggers = [
                     trigger
                     for trigger in asset_model.triggers
-                    if self._get_trigger_hash(trigger.classpath, 
trigger.kwargs) not in trigger_hashes
+                    if BaseEventTrigger.hash(trigger.classpath, 
trigger.kwargs) not in trigger_hashes
                 ]
 
         # Remove references from assets no longer used
@@ -845,15 +845,3 @@ class AssetModelOperation(NamedTuple):
         for asset_model in orphan_assets:
             if (asset_model.name, asset_model.uri) not in self.assets:
                 asset_model.triggers = []
-
-    @staticmethod
-    def _get_trigger_hash(classpath: str, kwargs: dict[str, Any]) -> int:
-        """
-        Return the hash of the trigger classpath and kwargs. This is used to 
uniquely identify a trigger.
-
-        We do not want to move this logic in a `__hash__` method in 
`BaseTrigger` because we do not want to
-        make the triggers hashable. The reason being, when the triggerer 
retrieve the list of triggers, we do
-        not want it dedupe them. When used to defer tasks, 2 triggers can have 
the same classpath and kwargs.
-        This is not true for event driven scheduling.
-        """
-        return hash((classpath, 
json.dumps(BaseSerialization.serialize(kwargs)).encode("utf-8")))
diff --git a/airflow/example_dags/example_asset_with_watchers.py 
b/airflow/example_dags/example_asset_with_watchers.py
index 79d012a2ed4..4f65a7f9305 100644
--- a/airflow/example_dags/example_asset_with_watchers.py
+++ b/airflow/example_dags/example_asset_with_watchers.py
@@ -20,30 +20,17 @@ Example DAG for demonstrating the usage of event driven 
scheduling using assets
 
 from __future__ import annotations
 
-import os
-
 from airflow.decorators import task
 from airflow.models.baseoperator import chain
 from airflow.models.dag import DAG
-from airflow.providers.standard.triggers.file import FileTrigger
+from airflow.providers.standard.triggers.file import FileDeleteTrigger
 from airflow.sdk import Asset, AssetWatcher
 
 file_path = "/tmp/test"
 
-with DAG(
-    dag_id="example_create_file",
-    catchup=False,
-):
-
-    @task
-    def create_file():
-        with open(file_path, "w") as file:
-            file.write("This is an example file.\n")
-
-    chain(create_file())
+trigger = FileDeleteTrigger(filepath=file_path)
+asset = Asset("example_asset", 
watchers=[AssetWatcher(name="test_asset_watcher", trigger=trigger)])
 
-trigger = FileTrigger(filepath=file_path, poke_interval=10)
-asset = Asset("example_asset", 
watchers=[AssetWatcher(name="test_file_watcher", trigger=trigger)])
 
 with DAG(
     dag_id="example_asset_with_watchers",
@@ -52,8 +39,7 @@ with DAG(
 ):
 
     @task
-    def delete_file():
-        if os.path.exists(file_path):
-            os.remove(file_path)  # Delete the file
+    def test_task():
+        print("Hello world")
 
-    chain(delete_file())
+    chain(test_task())
diff --git a/airflow/serialization/serialized_objects.py 
b/airflow/serialization/serialized_objects.py
index 08d032e873c..edbc6b4e5bd 100644
--- a/airflow/serialization/serialized_objects.py
+++ b/airflow/serialization/serialized_objects.py
@@ -102,6 +102,7 @@ if TYPE_CHECKING:
     from airflow.sdk.types import Operator
     from airflow.serialization.json_schema import Validator
     from airflow.timetables.base import DagRunInfo, DataInterval, Timetable
+    from airflow.triggers.base import BaseEventTrigger
 
     HAS_KUBERNETES: bool
     try:
@@ -259,7 +260,7 @@ def encode_asset_condition(var: BaseAsset) -> dict[str, 
Any]:
                 "trigger": _encode_trigger(watcher.trigger),
             }
 
-        def _encode_trigger(trigger: BaseTrigger | dict):
+        def _encode_trigger(trigger: BaseEventTrigger | dict):
             if isinstance(trigger, dict):
                 return trigger
             classpath, kwargs = trigger.serialize()
diff --git a/airflow/triggers/base.py b/airflow/triggers/base.py
index d36dd40b6f8..f5585dcc800 100644
--- a/airflow/triggers/base.py
+++ b/airflow/triggers/base.py
@@ -17,6 +17,7 @@
 from __future__ import annotations
 
 import abc
+import json
 import logging
 from collections.abc import AsyncIterator
 from dataclasses import dataclass
@@ -126,6 +127,27 @@ class BaseTrigger(abc.ABC, LoggingMixin):
         return self.repr(classpath, kwargs)
 
 
+class BaseEventTrigger(BaseTrigger):
+    """
+    Base class for triggers used to schedule DAGs based on external events.
+
+    ``BaseEventTrigger`` is a subclass of ``BaseTrigger`` designed to identify 
triggers compatible with
+    event-driven scheduling.
+    """
+
+    @staticmethod
+    def hash(classpath: str, kwargs: dict[str, Any]) -> int:
+        """
+        Return the hash of the trigger classpath and kwargs. This is used to 
uniquely identify a trigger.
+
+        We do not want to have this logic in ``BaseTrigger`` because, when 
used to defer tasks, 2 triggers
+        can have the same classpath and kwargs. This is not true for event 
driven scheduling.
+        """
+        from airflow.serialization.serialized_objects import BaseSerialization
+
+        return hash((classpath, 
json.dumps(BaseSerialization.serialize(kwargs)).encode("utf-8")))
+
+
 class TriggerEvent:
     """
     Something that a trigger can fire when its conditions are met.
diff --git a/providers/standard/src/airflow/providers/standard/triggers/file.py 
b/providers/standard/src/airflow/providers/standard/triggers/file.py
index f6a7715a035..6df163a6f7c 100644
--- a/providers/standard/src/airflow/providers/standard/triggers/file.py
+++ b/providers/standard/src/airflow/providers/standard/triggers/file.py
@@ -19,11 +19,20 @@ from __future__ import annotations
 import asyncio
 import datetime
 import os
-import typing
+from collections.abc import AsyncIterator
 from glob import glob
 from typing import Any
 
-from airflow.triggers.base import BaseTrigger, TriggerEvent
+from airflow.providers.standard.version_compat import AIRFLOW_V_3_0_PLUS
+
+if AIRFLOW_V_3_0_PLUS:
+    from airflow.triggers.base import BaseEventTrigger, BaseTrigger, 
TriggerEvent
+else:
+    from airflow.triggers.base import (  # type: ignore
+        BaseTrigger,
+        BaseTrigger as BaseEventTrigger,
+        TriggerEvent,
+    )
 
 
 class FileTrigger(BaseTrigger):
@@ -60,7 +69,7 @@ class FileTrigger(BaseTrigger):
             },
         )
 
-    async def run(self) -> typing.AsyncIterator[TriggerEvent]:
+    async def run(self) -> AsyncIterator[TriggerEvent]:
         """Loop until the relevant files are found."""
         while True:
             for path in glob(self.filepath, recursive=self.recursive):
@@ -75,3 +84,48 @@ class FileTrigger(BaseTrigger):
                         yield TriggerEvent(True)
                         return
             await asyncio.sleep(self.poke_interval)
+
+
+class FileDeleteTrigger(BaseEventTrigger):
+    """
+    A trigger that fires exactly once after it finds the requested file and 
then deletes the file.
+
+    The difference between ``FileTrigger`` and ``FileDeleteTrigger`` is that 
``FileDeleteTrigger`` can only find a
+    specific file.
+
+    :param filepath: Path to the file to watch for and delete once found.
+    :param poke_interval: Time that the job should wait in between each try
+    """
+
+    def __init__(
+        self,
+        filepath: str,
+        poke_interval: float = 5.0,
+        **kwargs,
+    ):
+        super().__init__()
+        self.filepath = filepath
+        self.poke_interval = poke_interval
+
+    def serialize(self) -> tuple[str, dict[str, Any]]:
+        """Serialize FileDeleteTrigger arguments and classpath."""
+        return (
+            "airflow.providers.standard.triggers.file.FileDeleteTrigger",
+            {
+                "filepath": self.filepath,
+                "poke_interval": self.poke_interval,
+            },
+        )
+
+    async def run(self) -> AsyncIterator[TriggerEvent]:
+        """Loop until the relevant file is found."""
+        while True:
+            if os.path.isfile(self.filepath):
+                mod_time_f = os.path.getmtime(self.filepath)
+                mod_time = 
datetime.datetime.fromtimestamp(mod_time_f).strftime("%Y%m%d%H%M%S")
+                self.log.info("Found file %s last modified: %s", 
self.filepath, mod_time)
+                os.remove(self.filepath)
+                self.log.info("File %s has been deleted", self.filepath)
+                yield TriggerEvent(True)
+                return
+            await asyncio.sleep(self.poke_interval)
diff --git 
a/providers/standard/tests/provider_tests/standard/triggers/test_file.py 
b/providers/standard/tests/provider_tests/standard/triggers/test_file.py
index baf0dffa80d..b69b5857e2b 100644
--- a/providers/standard/tests/provider_tests/standard/triggers/test_file.py
+++ b/providers/standard/tests/provider_tests/standard/triggers/test_file.py
@@ -20,7 +20,8 @@ import asyncio
 
 import pytest
 
-from airflow.providers.standard.triggers.file import FileTrigger
+from airflow.providers.standard.triggers.file import FileDeleteTrigger, 
FileTrigger
+from airflow.providers.standard.version_compat import AIRFLOW_V_3_0_PLUS
 
 
 class TestFileTrigger:
@@ -62,3 +63,44 @@ class TestFileTrigger:
 
         # Prevents error when task is destroyed while in "pending" state
         asyncio.get_event_loop().stop()
+
+
[email protected](not AIRFLOW_V_3_0_PLUS, reason="Skip on Airflow < 3.0")
+class TestFileDeleteTrigger:
+    FILE_PATH = "/files/dags/example_async_file.py"
+
+    def test_serialization(self):
+        """Asserts that the trigger correctly serializes its arguments and 
classpath."""
+        trigger = FileDeleteTrigger(filepath=self.FILE_PATH, poke_interval=5)
+        classpath, kwargs = trigger.serialize()
+        assert classpath == 
"airflow.providers.standard.triggers.file.FileDeleteTrigger"
+        assert kwargs == {
+            "filepath": self.FILE_PATH,
+            "poke_interval": 5,
+        }
+
+    @pytest.mark.asyncio
+    async def test_file_delete_trigger(self, tmp_path):
+        """Asserts that the trigger goes off on or after the file is found and 
that the file gets deleted."""
+        tmp_dir = tmp_path / "test_dir"
+        tmp_dir.mkdir()
+        p = tmp_dir / "hello.txt"
+
+        trigger = FileDeleteTrigger(
+            filepath=str(p.resolve()),
+            poke_interval=0.2,
+        )
+
+        task = asyncio.create_task(trigger.run().__anext__())
+        await asyncio.sleep(0.5)
+
+        # It should not have produced a result
+        assert task.done() is False
+
+        p.touch()
+
+        await asyncio.sleep(0.5)
+        assert p.exists() is False
+
+        # Prevents error when task is destroyed while in "pending" state
+        asyncio.get_event_loop().stop()
diff --git a/task_sdk/src/airflow/sdk/definitions/asset/__init__.py 
b/task_sdk/src/airflow/sdk/definitions/asset/__init__.py
index f7a515ea434..f5e42e82e36 100644
--- a/task_sdk/src/airflow/sdk/definitions/asset/__init__.py
+++ b/task_sdk/src/airflow/sdk/definitions/asset/__init__.py
@@ -38,7 +38,7 @@ if TYPE_CHECKING:
 
     from airflow.models.asset import AssetModel
     from airflow.serialization.serialized_objects import SerializedAssetWatcher
-    from airflow.triggers.base import BaseTrigger
+    from airflow.triggers.base import BaseEventTrigger
 
     AttrsInstance = attrs.AttrsInstance
 else:
@@ -254,7 +254,7 @@ class BaseAsset:
         raise NotImplementedError
 
 
[email protected](frozen=True)
[email protected](init=False)
 class AssetWatcher:
     """A representation of an asset watcher. The name uniquely identifies the 
watch."""
 
@@ -263,8 +263,22 @@ class AssetWatcher:
     # For a "normal" asset instance loaded from DAG, this holds the trigger 
used to monitor an external
     # resource. In that case, ``AssetWatcher`` is used directly by users.
     # For an asset recreated from a serialized DAG, this holds the serialized 
data of the trigger. In that
-    # case, `SerializedAssetWatcher` is used. We need to keep the two types to 
make mypy happy.
-    trigger: BaseTrigger | dict
+    # case, `SerializedAssetWatcher` is used. We need to keep the two types to 
make mypy happy because
+    # `SerializedAssetWatcher` is a subclass of `AssetWatcher`.
+    trigger: BaseEventTrigger | dict
+
+    def __init__(
+        self,
+        name: str,
+        trigger: BaseEventTrigger | dict,
+    ) -> None:
+        from airflow.triggers.base import BaseEventTrigger, BaseTrigger
+
+        if isinstance(trigger, BaseTrigger) and not isinstance(trigger, 
BaseEventTrigger):
+            raise ValueError("The trigger used to watch an asset must inherit 
``BaseEventTrigger``")
+
+        self.name = name
+        self.trigger = trigger
 
 
 @attrs.define(init=False, unsafe_hash=False)
diff --git a/tests/serialization/test_serialized_objects.py 
b/tests/serialization/test_serialized_objects.py
index 945a52e9f4f..ba3b2bca7ae 100644
--- a/tests/serialization/test_serialized_objects.py
+++ b/tests/serialization/test_serialized_objects.py
@@ -41,7 +41,7 @@ from airflow.models.taskinstance import SimpleTaskInstance, 
TaskInstance
 from airflow.models.xcom_arg import XComArg
 from airflow.providers.standard.operators.empty import EmptyOperator
 from airflow.providers.standard.operators.python import PythonOperator
-from airflow.providers.standard.triggers.file import FileTrigger
+from airflow.providers.standard.triggers.file import FileDeleteTrigger
 from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetAliasEvent, 
AssetUniqueKey, AssetWatcher
 from airflow.sdk.definitions.param import Param
 from airflow.sdk.execution_time.context import OutletEventAccessor, 
OutletEventAccessors
@@ -259,7 +259,7 @@ class MockLazySelectSequence(LazySelectSequence):
             Asset(
                 uri="test://asset1",
                 name="test",
-                watchers=[AssetWatcher(name="test", 
trigger=FileTrigger(filepath="/tmp"))],
+                watchers=[AssetWatcher(name="test", 
trigger=FileDeleteTrigger(filepath="/tmp"))],
             ),
             DAT.ASSET,
             equals,

Reply via email to