This is an automated email from the ASF dual-hosted git repository.
vincbeck pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 3ad8787a3e1 AIP-82 Introduce `BaseEventTrigger` as base class for triggers used with event driven scheduling (#46391)
3ad8787a3e1 is described below
commit 3ad8787a3e13a6733b0cf277ad3800defa74dcee
Author: Vincent <[email protected]>
AuthorDate: Thu Feb 13 09:50:05 2025 -0500
AIP-82 Introduce `BaseEventTrigger` as base class for triggers used with event driven scheduling (#46391)
---
airflow/dag_processing/collection.py | 28 +++-------
.../example_dags/example_asset_with_watchers.py | 26 +++-------
airflow/serialization/serialized_objects.py | 3 +-
airflow/triggers/base.py | 22 ++++++++
.../airflow/providers/standard/triggers/file.py | 60 ++++++++++++++++++++--
.../provider_tests/standard/triggers/test_file.py | 44 +++++++++++++++-
.../src/airflow/sdk/definitions/asset/__init__.py | 22 ++++++--
tests/serialization/test_serialized_objects.py | 4 +-
8 files changed, 158 insertions(+), 51 deletions(-)
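
Before the per-file diffs, a minimal sketch of the pattern this commit enables. The DAG id, watcher name, and the `schedule=[asset]` parameter below are illustrative assumptions, not part of the commit: a trigger subclassing `BaseEventTrigger` is attached to an `Asset` through an `AssetWatcher`, and a DAG scheduled on that asset runs when the trigger fires.

    from airflow.models.dag import DAG
    from airflow.providers.standard.triggers.file import FileDeleteTrigger
    from airflow.sdk import Asset, AssetWatcher

    # FileDeleteTrigger subclasses BaseEventTrigger, so AssetWatcher accepts it.
    trigger = FileDeleteTrigger(filepath="/tmp/test")
    asset = Asset("example_asset", watchers=[AssetWatcher(name="example_watcher", trigger=trigger)])

    # Hypothetical consumer DAG: runs each time the watched file appears (and is deleted).
    with DAG(dag_id="example_consumer", schedule=[asset], catchup=False):
        ...
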
diff --git a/airflow/dag_processing/collection.py b/airflow/dag_processing/collection.py
index 7bd5b316714..e581b02478b 100644
--- a/airflow/dag_processing/collection.py
+++ b/airflow/dag_processing/collection.py
@@ -27,10 +27,9 @@ This should generally only be called by internal methods such as
from __future__ import annotations
-import json
import logging
import traceback
-from typing import TYPE_CHECKING, Any, NamedTuple, cast
+from typing import TYPE_CHECKING, NamedTuple, cast
from sqlalchemy import and_, delete, exists, func, insert, select, tuple_
from sqlalchemy.exc import OperationalError
@@ -53,7 +52,8 @@ from airflow.models.dagwarning import DagWarningType
from airflow.models.errors import ParseImportError
from airflow.models.trigger import Trigger
from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetNameRef, AssetUriRef
-from airflow.serialization.serialized_objects import BaseSerialization, SerializedAssetWatcher
+from airflow.serialization.serialized_objects import SerializedAssetWatcher
+from airflow.triggers.base import BaseEventTrigger
from airflow.utils.retries import MAX_DB_RETRIES, run_with_db_retries
from airflow.utils.sqlalchemy import with_row_locks
from airflow.utils.timezone import utcnow
@@ -758,7 +758,7 @@ class AssetModelOperation(NamedTuple):
else []
)
trigger_hash_to_trigger_dict: dict[int, dict] = {
- self._get_trigger_hash(
+ BaseEventTrigger.hash(
watcher.trigger["classpath"], watcher.trigger["kwargs"]
): watcher.trigger
for watcher in asset_watchers
@@ -768,7 +768,7 @@ class AssetModelOperation(NamedTuple):
asset_model = assets[name_uri]
trigger_hash_from_asset_model: set[int] = {
- self._get_trigger_hash(trigger.classpath, trigger.kwargs) for trigger in asset_model.triggers
+ BaseEventTrigger.hash(trigger.classpath, trigger.kwargs) for trigger in asset_model.triggers
}
# Optimization: no diff between the DB and DAG definitions, no update needed
@@ -796,7 +796,7 @@ class AssetModelOperation(NamedTuple):
for trigger_hash in trigger_hashes
}
orm_triggers: dict[int, Trigger] = {
- self._get_trigger_hash(trigger.classpath, trigger.kwargs): trigger
+ BaseEventTrigger.hash(trigger.classpath, trigger.kwargs): trigger
for trigger in session.scalars(
select(Trigger).where(
tuple_(Trigger.classpath, Trigger.encrypted_kwargs).in_(all_trigger_keys)
@@ -817,7 +817,7 @@ class AssetModelOperation(NamedTuple):
]
session.add_all(new_trigger_models)
orm_triggers.update(
- (self._get_trigger_hash(trigger.classpath, trigger.kwargs), trigger)
+ (BaseEventTrigger.hash(trigger.classpath, trigger.kwargs), trigger)
for trigger in new_trigger_models
)
@@ -835,7 +835,7 @@ class AssetModelOperation(NamedTuple):
asset_model.triggers = [
trigger
for trigger in asset_model.triggers
- if self._get_trigger_hash(trigger.classpath, trigger.kwargs) not in trigger_hashes
+ if BaseEventTrigger.hash(trigger.classpath, trigger.kwargs) not in trigger_hashes
]
# Remove references from assets no longer used
@@ -845,15 +845,3 @@ class AssetModelOperation(NamedTuple):
for asset_model in orphan_assets:
if (asset_model.name, asset_model.uri) not in self.assets:
asset_model.triggers = []
-
- @staticmethod
- def _get_trigger_hash(classpath: str, kwargs: dict[str, Any]) -> int:
- """
- Return the hash of the trigger classpath and kwargs. This is used to uniquely identify a trigger.
-
- We do not want to move this logic in a `__hash__` method in `BaseTrigger` because we do not want to
- make the triggers hashable. The reason being, when the triggerer retrieve the list of triggers, we do
- not want it dedupe them. When used to defer tasks, 2 triggers can have the same classpath and kwargs.
- This is not true for event driven scheduling.
- """
- return hash((classpath, json.dumps(BaseSerialization.serialize(kwargs)).encode("utf-8")))
diff --git a/airflow/example_dags/example_asset_with_watchers.py b/airflow/example_dags/example_asset_with_watchers.py
index 79d012a2ed4..4f65a7f9305 100644
--- a/airflow/example_dags/example_asset_with_watchers.py
+++ b/airflow/example_dags/example_asset_with_watchers.py
@@ -20,30 +20,17 @@ Example DAG for demonstrating the usage of event driven scheduling using assets
from __future__ import annotations
-import os
-
from airflow.decorators import task
from airflow.models.baseoperator import chain
from airflow.models.dag import DAG
-from airflow.providers.standard.triggers.file import FileTrigger
+from airflow.providers.standard.triggers.file import FileDeleteTrigger
from airflow.sdk import Asset, AssetWatcher
file_path = "/tmp/test"
-with DAG(
- dag_id="example_create_file",
- catchup=False,
-):
-
- @task
- def create_file():
- with open(file_path, "w") as file:
- file.write("This is an example file.\n")
-
- chain(create_file())
+trigger = FileDeleteTrigger(filepath=file_path)
+asset = Asset("example_asset",
watchers=[AssetWatcher(name="test_asset_watcher", trigger=trigger)])
-trigger = FileTrigger(filepath=file_path, poke_interval=10)
-asset = Asset("example_asset",
watchers=[AssetWatcher(name="test_file_watcher", trigger=trigger)])
with DAG(
dag_id="example_asset_with_watchers",
@@ -52,8 +39,7 @@ with DAG(
):
@task
- def delete_file():
- if os.path.exists(file_path):
- os.remove(file_path) # Delete the file
+ def test_task():
+ print("Hello world")
- chain(delete_file())
+ chain(test_task())
diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py
index 08d032e873c..edbc6b4e5bd 100644
--- a/airflow/serialization/serialized_objects.py
+++ b/airflow/serialization/serialized_objects.py
@@ -102,6 +102,7 @@ if TYPE_CHECKING:
from airflow.sdk.types import Operator
from airflow.serialization.json_schema import Validator
from airflow.timetables.base import DagRunInfo, DataInterval, Timetable
+ from airflow.triggers.base import BaseEventTrigger
HAS_KUBERNETES: bool
try:
@@ -259,7 +260,7 @@ def encode_asset_condition(var: BaseAsset) -> dict[str, Any]:
"trigger": _encode_trigger(watcher.trigger),
}
- def _encode_trigger(trigger: BaseTrigger | dict):
+ def _encode_trigger(trigger: BaseEventTrigger | dict):
if isinstance(trigger, dict):
return trigger
classpath, kwargs = trigger.serialize()
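
A rough illustration of the serialization path touched in this hunk (a sketch; only the call to `serialize()` is visible here, the surrounding encoding dict sits in the elided context). The expected classpath and kwargs are taken from the `FileDeleteTrigger.serialize()` method added later in this commit.

    from airflow.providers.standard.triggers.file import FileDeleteTrigger

    trigger = FileDeleteTrigger(filepath="/tmp/test")
    classpath, kwargs = trigger.serialize()
    # classpath == "airflow.providers.standard.triggers.file.FileDeleteTrigger"
    # kwargs == {"filepath": "/tmp/test", "poke_interval": 5.0}
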
diff --git a/airflow/triggers/base.py b/airflow/triggers/base.py
index d36dd40b6f8..f5585dcc800 100644
--- a/airflow/triggers/base.py
+++ b/airflow/triggers/base.py
@@ -17,6 +17,7 @@
from __future__ import annotations
import abc
+import json
import logging
from collections.abc import AsyncIterator
from dataclasses import dataclass
@@ -126,6 +127,27 @@ class BaseTrigger(abc.ABC, LoggingMixin):
return self.repr(classpath, kwargs)
+class BaseEventTrigger(BaseTrigger):
+ """
+ Base class for triggers used to schedule DAGs based on external events.
+
+ ``BaseEventTrigger`` is a subclass of ``BaseTrigger`` designed to identify triggers compatible with
+ event-driven scheduling.
+ """
+
+ @staticmethod
+ def hash(classpath: str, kwargs: dict[str, Any]) -> int:
+ """
+ Return the hash of the trigger classpath and kwargs. This is used to uniquely identify a trigger.
+
+ We do not want to have this logic in ``BaseTrigger`` because, when used to defer tasks, two triggers
+ can have the same classpath and kwargs. This is not true for event-driven scheduling.
+ """
+ from airflow.serialization.serialized_objects import BaseSerialization
+
+ return hash((classpath, json.dumps(BaseSerialization.serialize(kwargs)).encode("utf-8")))
+
+
class TriggerEvent:
"""
Something that a trigger can fire when its conditions are met.
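
To illustrate why `hash` lives on `BaseEventTrigger` rather than `BaseTrigger` (a sketch using `FileDeleteTrigger`, which this commit makes an event trigger): two triggers with the same classpath and kwargs hash identically, which lets event-driven scheduling deduplicate them, while trigger instances used for task deferral deliberately remain non-hashable.

    from airflow.providers.standard.triggers.file import FileDeleteTrigger
    from airflow.triggers.base import BaseEventTrigger

    a = FileDeleteTrigger(filepath="/tmp/test").serialize()
    b = FileDeleteTrigger(filepath="/tmp/test").serialize()

    # Same classpath + kwargs -> same hash, so the scheduler can dedupe
    # event triggers across DAG parses.
    assert BaseEventTrigger.hash(a[0], a[1]) == BaseEventTrigger.hash(b[0], b[1])
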
diff --git a/providers/standard/src/airflow/providers/standard/triggers/file.py b/providers/standard/src/airflow/providers/standard/triggers/file.py
index f6a7715a035..6df163a6f7c 100644
--- a/providers/standard/src/airflow/providers/standard/triggers/file.py
+++ b/providers/standard/src/airflow/providers/standard/triggers/file.py
@@ -19,11 +19,20 @@ from __future__ import annotations
import asyncio
import datetime
import os
-import typing
+from collections.abc import AsyncIterator
from glob import glob
from typing import Any
-from airflow.triggers.base import BaseTrigger, TriggerEvent
+from airflow.providers.standard.version_compat import AIRFLOW_V_3_0_PLUS
+
+if AIRFLOW_V_3_0_PLUS:
+ from airflow.triggers.base import BaseEventTrigger, BaseTrigger, TriggerEvent
+else:
+ from airflow.triggers.base import ( # type: ignore
+ BaseTrigger,
+ BaseTrigger as BaseEventTrigger,
+ TriggerEvent,
+ )
class FileTrigger(BaseTrigger):
@@ -60,7 +69,7 @@ class FileTrigger(BaseTrigger):
},
)
- async def run(self) -> typing.AsyncIterator[TriggerEvent]:
+ async def run(self) -> AsyncIterator[TriggerEvent]:
"""Loop until the relevant files are found."""
while True:
for path in glob(self.filepath, recursive=self.recursive):
@@ -75,3 +84,48 @@ class FileTrigger(BaseTrigger):
yield TriggerEvent(True)
return
await asyncio.sleep(self.poke_interval)
+
+
+class FileDeleteTrigger(BaseEventTrigger):
+ """
+ A trigger that fires exactly once after it finds the requested file and then deletes the file.
+
+ The difference between ``FileTrigger`` and ``FileDeleteTrigger`` is that ``FileDeleteTrigger`` can only
+ find a specific file, which it then deletes.
+
+ :param filepath: Path to the file to watch.
+ :param poke_interval: Time that the job should wait in between each try
+ """
+
+ def __init__(
+ self,
+ filepath: str,
+ poke_interval: float = 5.0,
+ **kwargs,
+ ):
+ super().__init__()
+ self.filepath = filepath
+ self.poke_interval = poke_interval
+
+ def serialize(self) -> tuple[str, dict[str, Any]]:
+ """Serialize FileDeleteTrigger arguments and classpath."""
+ return (
+ "airflow.providers.standard.triggers.file.FileDeleteTrigger",
+ {
+ "filepath": self.filepath,
+ "poke_interval": self.poke_interval,
+ },
+ )
+
+ async def run(self) -> AsyncIterator[TriggerEvent]:
+ """Loop until the relevant file is found."""
+ while True:
+ if os.path.isfile(self.filepath):
+ mod_time_f = os.path.getmtime(self.filepath)
+ mod_time = datetime.datetime.fromtimestamp(mod_time_f).strftime("%Y%m%d%H%M%S")
+ self.log.info("Found file %s last modified: %s",
self.filepath, mod_time)
+ os.remove(self.filepath)
+ self.log.info("File %s has been deleted", self.filepath)
+ yield TriggerEvent(True)
+ return
+ await asyncio.sleep(self.poke_interval)
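
A short usage sketch of the new trigger outside a DAG (illustrative only; the asyncio driver below is not part of this commit). Per the `run()` loop above, the trigger yields a single `TriggerEvent` once the file exists, deleting the file in the process.

    import asyncio

    from airflow.providers.standard.triggers.file import FileDeleteTrigger

    async def main():
        trigger = FileDeleteTrigger(filepath="/tmp/test", poke_interval=1.0)
        async for event in trigger.run():
            print(event)  # fires once, after /tmp/test is found and removed

    asyncio.run(main())
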
diff --git a/providers/standard/tests/provider_tests/standard/triggers/test_file.py b/providers/standard/tests/provider_tests/standard/triggers/test_file.py
index baf0dffa80d..b69b5857e2b 100644
--- a/providers/standard/tests/provider_tests/standard/triggers/test_file.py
+++ b/providers/standard/tests/provider_tests/standard/triggers/test_file.py
@@ -20,7 +20,8 @@ import asyncio
import pytest
-from airflow.providers.standard.triggers.file import FileTrigger
+from airflow.providers.standard.triggers.file import FileDeleteTrigger, FileTrigger
+from airflow.providers.standard.version_compat import AIRFLOW_V_3_0_PLUS
class TestFileTrigger:
@@ -62,3 +63,44 @@ class TestFileTrigger:
# Prevents error when task is destroyed while in "pending" state
asyncio.get_event_loop().stop()
+
+
[email protected](not AIRFLOW_V_3_0_PLUS, reason="Skip on Airflow < 3.0")
+class TestFileDeleteTrigger:
+ FILE_PATH = "/files/dags/example_async_file.py"
+
+ def test_serialization(self):
+ """Asserts that the trigger correctly serializes its arguments and
classpath."""
+ trigger = FileDeleteTrigger(filepath=self.FILE_PATH, poll_interval=5)
+ classpath, kwargs = trigger.serialize()
+ assert classpath == "airflow.providers.standard.triggers.file.FileDeleteTrigger"
+ assert kwargs == {
+ "filepath": self.FILE_PATH,
+ "poke_interval": 5,
+ }
+
+ @pytest.mark.asyncio
+ async def test_file_delete_trigger(self, tmp_path):
+ """Asserts that the trigger goes off on or after file is found and
that the files gets deleted."""
+ tmp_dir = tmp_path / "test_dir"
+ tmp_dir.mkdir()
+ p = tmp_dir / "hello.txt"
+
+ trigger = FileDeleteTrigger(
+ filepath=str(p.resolve()),
+ poke_interval=0.2,
+ )
+
+ task = asyncio.create_task(trigger.run().__anext__())
+ await asyncio.sleep(0.5)
+
+ # It should not have produced a result
+ assert task.done() is False
+
+ p.touch()
+
+ await asyncio.sleep(0.5)
+ assert p.exists() is False
+
+ # Prevents error when task is destroyed while in "pending" state
+ asyncio.get_event_loop().stop()
diff --git a/task_sdk/src/airflow/sdk/definitions/asset/__init__.py b/task_sdk/src/airflow/sdk/definitions/asset/__init__.py
index f7a515ea434..f5e42e82e36 100644
--- a/task_sdk/src/airflow/sdk/definitions/asset/__init__.py
+++ b/task_sdk/src/airflow/sdk/definitions/asset/__init__.py
@@ -38,7 +38,7 @@ if TYPE_CHECKING:
from airflow.models.asset import AssetModel
from airflow.serialization.serialized_objects import SerializedAssetWatcher
- from airflow.triggers.base import BaseTrigger
+ from airflow.triggers.base import BaseEventTrigger
AttrsInstance = attrs.AttrsInstance
else:
@@ -254,7 +254,7 @@ class BaseAsset:
raise NotImplementedError
-@attrs.define(frozen=True)
+@attrs.define(init=False)
class AssetWatcher:
"""A representation of an asset watcher. The name uniquely identifies the
watch."""
@@ -263,8 +263,22 @@ class AssetWatcher:
# For a "normal" asset instance loaded from DAG, this holds the trigger
used to monitor an external
# resource. In that case, ``AssetWatcher`` is used directly by users.
# For an asset recreated from a serialized DAG, this holds the serialized data of the trigger. In that
- # case, `SerializedAssetWatcher` is used. We need to keep the two types to make mypy happy.
- trigger: BaseTrigger | dict
+ # case, `SerializedAssetWatcher` is used. We need to keep the two types to make mypy happy because
+ # `SerializedAssetWatcher` is a subclass of `AssetWatcher`.
+ trigger: BaseEventTrigger | dict
+
+ def __init__(
+ self,
+ name: str,
+ trigger: BaseEventTrigger | dict,
+ ) -> None:
+ from airflow.triggers.base import BaseEventTrigger, BaseTrigger
+
+ if isinstance(trigger, BaseTrigger) and not isinstance(trigger, BaseEventTrigger):
+ raise ValueError("The trigger used to watch an asset must inherit
``BaseEventTrigger``")
+
+ self.name = name
+ self.trigger = trigger
@attrs.define(init=False, unsafe_hash=False)
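
A sketch of the validation added above, using the two triggers touched by this commit: a plain `BaseTrigger` subclass is rejected at watcher construction time, while a `BaseEventTrigger` subclass is accepted.

    from airflow.providers.standard.triggers.file import FileDeleteTrigger, FileTrigger
    from airflow.sdk import AssetWatcher

    # FileDeleteTrigger inherits BaseEventTrigger -> accepted.
    AssetWatcher(name="ok", trigger=FileDeleteTrigger(filepath="/tmp/test"))

    # FileTrigger only inherits BaseTrigger -> rejected with ValueError.
    try:
        AssetWatcher(name="bad", trigger=FileTrigger(filepath="/tmp/test"))
    except ValueError as err:
        print(err)
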
diff --git a/tests/serialization/test_serialized_objects.py b/tests/serialization/test_serialized_objects.py
index 945a52e9f4f..ba3b2bca7ae 100644
--- a/tests/serialization/test_serialized_objects.py
+++ b/tests/serialization/test_serialized_objects.py
@@ -41,7 +41,7 @@ from airflow.models.taskinstance import SimpleTaskInstance, TaskInstance
from airflow.models.xcom_arg import XComArg
from airflow.providers.standard.operators.empty import EmptyOperator
from airflow.providers.standard.operators.python import PythonOperator
-from airflow.providers.standard.triggers.file import FileTrigger
+from airflow.providers.standard.triggers.file import FileDeleteTrigger
from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetAliasEvent, AssetUniqueKey, AssetWatcher
from airflow.sdk.definitions.param import Param
from airflow.sdk.execution_time.context import OutletEventAccessor, OutletEventAccessors
@@ -259,7 +259,7 @@ class MockLazySelectSequence(LazySelectSequence):
Asset(
uri="test://asset1",
name="test",
- watchers=[AssetWatcher(name="test", trigger=FileTrigger(filepath="/tmp"))],
+ watchers=[AssetWatcher(name="test", trigger=FileDeleteTrigger(filepath="/tmp"))],
),
DAT.ASSET,
equals,