This is an automated email from the ASF dual-hosted git repository.
vincbeck pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 53e1723f0b2 AIP-82 Handle trigger serialization (#45562)
53e1723f0b2 is described below
commit 53e1723f0b25fda86aa5594c244b55057632222e
Author: Vincent <[email protected]>
AuthorDate: Fri Jan 31 18:57:21 2025 -0500
AIP-82 Handle trigger serialization (#45562)
---
airflow/dag_processing/collection.py | 39 +++++++++---------
.../example_dags/example_asset_with_watchers.py | 7 ++--
airflow/serialization/schema.json | 12 ++++++
airflow/serialization/serialized_objects.py | 47 ++++++++++++++++++++--
task_sdk/src/airflow/sdk/__init__.py | 7 +++-
.../src/airflow/sdk/definitions/asset/__init__.py | 25 +++++++++---
tests/dag_processing/test_collection.py | 8 +++-
tests/serialization/test_serialized_objects.py | 12 +++++-
8 files changed, 122 insertions(+), 35 deletions(-)
diff --git a/airflow/dag_processing/collection.py
b/airflow/dag_processing/collection.py
index c8bac5cef96..7bd5b316714 100644
--- a/airflow/dag_processing/collection.py
+++ b/airflow/dag_processing/collection.py
@@ -30,7 +30,7 @@ from __future__ import annotations
import json
import logging
import traceback
-from typing import TYPE_CHECKING, Any, NamedTuple
+from typing import TYPE_CHECKING, Any, NamedTuple, cast
from sqlalchemy import and_, delete, exists, func, insert, select, tuple_
from sqlalchemy.exc import OperationalError
@@ -53,8 +53,7 @@ from airflow.models.dagwarning import DagWarningType
from airflow.models.errors import ParseImportError
from airflow.models.trigger import Trigger
from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetNameRef,
AssetUriRef
-from airflow.serialization.serialized_objects import BaseSerialization
-from airflow.triggers.base import BaseTrigger
+from airflow.serialization.serialized_objects import BaseSerialization,
SerializedAssetWatcher
from airflow.utils.retries import MAX_DB_RETRIES, run_with_db_retries
from airflow.utils.sqlalchemy import with_row_locks
from airflow.utils.timezone import utcnow
@@ -68,7 +67,6 @@ if TYPE_CHECKING:
from airflow.models.dagwarning import DagWarning
from airflow.serialization.serialized_objects import MaybeSerializedDAG
- from airflow.triggers.base import BaseTrigger
from airflow.typing_compat import Self
log = logging.getLogger(__name__)
@@ -747,16 +745,23 @@ class AssetModelOperation(NamedTuple):
# Update references from assets being used
refs_to_add: dict[tuple[str, str], set[int]] = {}
refs_to_remove: dict[tuple[str, str], set[int]] = {}
- triggers: dict[int, BaseTrigger] = {}
+ triggers: dict[int, dict] = {}
# Optimization: if no asset collected, skip fetching active assets
active_assets = _find_active_assets(self.assets.keys(),
session=session) if self.assets else {}
for name_uri, asset in self.assets.items():
# If the asset belong to a DAG not active or paused, consider
there is no watcher associated to it
- asset_watchers = asset.watchers if name_uri in active_assets else
[]
- trigger_hash_to_trigger_dict: dict[int, BaseTrigger] = {
- self._get_base_trigger_hash(trigger): trigger for trigger in
asset_watchers
+ asset_watchers: list[SerializedAssetWatcher] = (
+ [cast(SerializedAssetWatcher, watcher) for watcher in
asset.watchers]
+ if name_uri in active_assets
+ else []
+ )
+ trigger_hash_to_trigger_dict: dict[int, dict] = {
+ self._get_trigger_hash(
+ watcher.trigger["classpath"], watcher.trigger["kwargs"]
+ ): watcher.trigger
+ for watcher in asset_watchers
}
triggers.update(trigger_hash_to_trigger_dict)
trigger_hash_from_asset: set[int] =
set(trigger_hash_to_trigger_dict.keys())
@@ -783,7 +788,10 @@ class AssetModelOperation(NamedTuple):
}
all_trigger_keys: set[tuple[str, str]] = {
- self._encrypt_trigger_kwargs(triggers[trigger_hash])
+ (
+ triggers[trigger_hash]["classpath"],
+ Trigger.encrypt_kwargs(triggers[trigger_hash]["kwargs"]),
+ )
for trigger_hashes in refs_to_add.values()
for trigger_hash in trigger_hashes
}
@@ -800,7 +808,9 @@ class AssetModelOperation(NamedTuple):
new_trigger_models = [
trigger
for trigger in [
- Trigger.from_object(triggers[trigger_hash])
+ Trigger(
+ classpath=triggers[trigger_hash]["classpath"],
kwargs=triggers[trigger_hash]["kwargs"]
+ )
for trigger_hash in all_trigger_hashes
if trigger_hash not in orm_triggers
]
@@ -836,11 +846,6 @@ class AssetModelOperation(NamedTuple):
if (asset_model.name, asset_model.uri) not in self.assets:
asset_model.triggers = []
- @staticmethod
- def _encrypt_trigger_kwargs(trigger: BaseTrigger) -> tuple[str, str]:
- classpath, kwargs = trigger.serialize()
- return classpath, Trigger.encrypt_kwargs(kwargs)
-
@staticmethod
def _get_trigger_hash(classpath: str, kwargs: dict[str, Any]) -> int:
"""
@@ -852,7 +857,3 @@ class AssetModelOperation(NamedTuple):
This is not true for event driven scheduling.
"""
return hash((classpath,
json.dumps(BaseSerialization.serialize(kwargs)).encode("utf-8")))
-
- def _get_base_trigger_hash(self, trigger: BaseTrigger) -> int:
- classpath, kwargs = trigger.serialize()
- return self._get_trigger_hash(classpath, kwargs)
diff --git a/airflow/example_dags/example_asset_with_watchers.py
b/airflow/example_dags/example_asset_with_watchers.py
index 32bbaa6bdd2..79d012a2ed4 100644
--- a/airflow/example_dags/example_asset_with_watchers.py
+++ b/airflow/example_dags/example_asset_with_watchers.py
@@ -21,15 +21,14 @@ Example DAG for demonstrating the usage of event driven
scheduling using assets
from __future__ import annotations
import os
-import tempfile
from airflow.decorators import task
from airflow.models.baseoperator import chain
from airflow.models.dag import DAG
from airflow.providers.standard.triggers.file import FileTrigger
-from airflow.sdk.definitions.asset import Asset
+from airflow.sdk import Asset, AssetWatcher
-file_path = tempfile.NamedTemporaryFile().name
+file_path = "/tmp/test"
with DAG(
dag_id="example_create_file",
@@ -44,7 +43,7 @@ with DAG(
chain(create_file())
trigger = FileTrigger(filepath=file_path, poke_interval=10)
-asset = Asset("example_asset", watchers=[trigger])
+asset = Asset("example_asset",
watchers=[AssetWatcher(name="test_file_watcher", trigger=trigger)])
with DAG(
dag_id="example_asset_with_watchers",
diff --git a/airflow/serialization/schema.json
b/airflow/serialization/schema.json
index 292415ce11e..82fb5f908d3 100644
--- a/airflow/serialization/schema.json
+++ b/airflow/serialization/schema.json
@@ -64,6 +64,10 @@
{"type": "null"},
{ "$ref": "#/definitions/dict" }
]
+ },
+ "watchers": {
+ "type": "array",
+ "items": { "$ref": "#/definitions/trigger" }
}
},
"required": [ "uri", "extra" ]
@@ -126,6 +130,14 @@
],
"additionalProperties": false
},
+ "trigger": {
+ "type": "object",
+ "properties": {
+ "classpath": { "type": "string" },
+ "kwargs": { "$ref": "#/definitions/dict" }
+ },
+ "required": [ "classpath", "kwargs" ]
+ },
"dict": {
"description": "A python dictionary containing values of any type",
"type": "object"
diff --git a/airflow/serialization/serialized_objects.py
b/airflow/serialization/serialized_objects.py
index 9c5f43c0c0b..89ea668f02f 100644
--- a/airflow/serialization/serialized_objects.py
+++ b/airflow/serialization/serialized_objects.py
@@ -59,6 +59,7 @@ from airflow.sdk.definitions.asset import (
AssetAny,
AssetRef,
AssetUniqueKey,
+ AssetWatcher,
BaseAsset,
)
from airflow.sdk.definitions.baseoperator import BaseOperator as
TaskSDKBaseOperator
@@ -251,13 +252,34 @@ def encode_asset_condition(var: BaseAsset) -> dict[str,
Any]:
:meta private:
"""
if isinstance(var, Asset):
- return {
+
+ def _encode_watcher(watcher: AssetWatcher):
+ return {
+ "name": watcher.name,
+ "trigger": _encode_trigger(watcher.trigger),
+ }
+
+ def _encode_trigger(trigger: BaseTrigger | dict):
+ if isinstance(trigger, dict):
+ return trigger
+ classpath, kwargs = trigger.serialize()
+ return {
+ "classpath": classpath,
+ "kwargs": kwargs,
+ }
+
+ asset = {
"__type": DAT.ASSET,
"name": var.name,
"uri": var.uri,
"group": var.group,
"extra": var.extra,
}
+
+ if len(var.watchers) > 0:
+ asset["watchers"] = [_encode_watcher(watcher) for watcher in
var.watchers]
+
+ return asset
if isinstance(var, AssetAlias):
return {"__type": DAT.ASSET_ALIAS, "name": var.name, "group":
var.group}
if isinstance(var, AssetAll):
@@ -283,7 +305,7 @@ def decode_asset_condition(var: dict[str, Any]) ->
BaseAsset:
"""
dat = var["__type"]
if dat == DAT.ASSET:
- return Asset(name=var["name"], uri=var["uri"], group=var["group"],
extra=var["extra"])
+ return decode_asset(var)
if dat == DAT.ASSET_ALL:
return AssetAll(*(decode_asset_condition(x) for x in var["objects"]))
if dat == DAT.ASSET_ANY:
@@ -295,6 +317,19 @@ def decode_asset_condition(var: dict[str, Any]) ->
BaseAsset:
raise ValueError(f"deserialization not implemented for DAT {dat!r}")
+def decode_asset(var: dict[str, Any]):
+ watchers = var.get("watchers", [])
+ return Asset(
+ name=var["name"],
+ uri=var["uri"],
+ group=var["group"],
+ extra=var["extra"],
+ watchers=[
+ SerializedAssetWatcher(name=watcher["name"],
trigger=watcher["trigger"]) for watcher in watchers
+ ],
+ )
+
+
def encode_outlet_event_accessor(var: OutletEventAccessor) -> dict[str, Any]:
key = var.key
return {
@@ -874,7 +909,7 @@ class BaseSerialization:
elif type_ == DAT.XCOM_REF:
return _XComRef(var) # Delay deserializing XComArg objects until
we have the entire DAG.
elif type_ == DAT.ASSET:
- return Asset(**var)
+ return decode_asset(var)
elif type_ == DAT.ASSET_ALIAS:
return AssetAlias(**var)
elif type_ == DAT.ASSET_ANY:
@@ -1810,6 +1845,12 @@ class TaskGroupSerialization(BaseSerialization):
return group
+class SerializedAssetWatcher(AssetWatcher):
+ """JSON serializable representation of an asset watcher."""
+
+ trigger: dict
+
+
def _has_kubernetes() -> bool:
global HAS_KUBERNETES
if "HAS_KUBERNETES" in globals():
diff --git a/task_sdk/src/airflow/sdk/__init__.py
b/task_sdk/src/airflow/sdk/__init__.py
index 6762f43ef8a..95b08be37aa 100644
--- a/task_sdk/src/airflow/sdk/__init__.py
+++ b/task_sdk/src/airflow/sdk/__init__.py
@@ -19,6 +19,9 @@ from __future__ import annotations
from typing import TYPE_CHECKING
__all__ = [
+ "__version__",
+ "Asset",
+ "AssetWatcher",
"BaseOperator",
"Connection",
"DAG",
@@ -27,7 +30,6 @@ __all__ = [
"MappedOperator",
"TaskGroup",
"XComArg",
- "__version__",
"dag",
"get_current_context",
"get_parsing_context",
@@ -36,6 +38,7 @@ __all__ = [
__version__ = "1.0.0.alpha1"
if TYPE_CHECKING:
+ from airflow.sdk.definitions.asset import Asset, AssetWatcher
from airflow.sdk.definitions.baseoperator import BaseOperator
from airflow.sdk.definitions.connection import Connection
from airflow.sdk.definitions.context import get_current_context,
get_parsing_context
@@ -60,6 +63,8 @@ __lazy_imports: dict[str, str] = {
"dag": ".definitions.dag",
"get_current_context": ".definitions.context",
"get_parsing_context": ".definitions.context",
+ "Asset": ".definitions.asset",
+ "AssetWatcher": ".definitions.asset",
}
diff --git a/task_sdk/src/airflow/sdk/definitions/asset/__init__.py
b/task_sdk/src/airflow/sdk/definitions/asset/__init__.py
index b976bb8c156..f7a515ea434 100644
--- a/task_sdk/src/airflow/sdk/definitions/asset/__init__.py
+++ b/task_sdk/src/airflow/sdk/definitions/asset/__init__.py
@@ -37,6 +37,7 @@ if TYPE_CHECKING:
from sqlalchemy.orm import Session
from airflow.models.asset import AssetModel
+ from airflow.serialization.serialized_objects import SerializedAssetWatcher
from airflow.triggers.base import BaseTrigger
AttrsInstance = attrs.AttrsInstance
@@ -54,6 +55,7 @@ __all__ = [
"AssetNameRef",
"AssetRef",
"AssetUriRef",
+ "AssetWatcher",
]
@@ -252,6 +254,19 @@ class BaseAsset:
raise NotImplementedError
+@attrs.define(frozen=True)
+class AssetWatcher:
+ """A representation of an asset watcher. The name uniquely identifies the
watch."""
+
+ name: str
+ # This attribute serves double purpose.
+ # For a "normal" asset instance loaded from DAG, this holds the trigger
used to monitor an external
+ # resource. In that case, ``AssetWatcher`` is used directly by users.
+ # For an asset recreated from a serialized DAG, this holds the serialized
data of the trigger. In that
+ # case, `SerializedAssetWatcher` is used. We need to keep the two types to
make mypy happy.
+ trigger: BaseTrigger | dict
+
+
@attrs.define(init=False, unsafe_hash=False)
class Asset(os.PathLike, BaseAsset):
"""A representation of data asset dependencies between workflows."""
@@ -271,7 +286,7 @@ class Asset(os.PathLike, BaseAsset):
factory=dict,
converter=_set_extra_default,
)
- watchers: list[BaseTrigger] = attrs.field(
+ watchers: list[AssetWatcher | SerializedAssetWatcher] = attrs.field(
factory=list,
)
@@ -286,7 +301,7 @@ class Asset(os.PathLike, BaseAsset):
*,
group: str = ...,
extra: dict | None = None,
- watchers: list[BaseTrigger] = ...,
+ watchers: list[AssetWatcher | SerializedAssetWatcher] = ...,
) -> None:
"""Canonical; both name and uri are provided."""
@@ -297,7 +312,7 @@ class Asset(os.PathLike, BaseAsset):
*,
group: str = ...,
extra: dict | None = None,
- watchers: list[BaseTrigger] = ...,
+ watchers: list[AssetWatcher | SerializedAssetWatcher] = ...,
) -> None:
"""It's possible to only provide the name, either by keyword or as the
only positional argument."""
@@ -308,7 +323,7 @@ class Asset(os.PathLike, BaseAsset):
uri: str,
group: str = ...,
extra: dict | None = None,
- watchers: list[BaseTrigger] = ...,
+ watchers: list[AssetWatcher | SerializedAssetWatcher] = ...,
) -> None:
"""It's possible to only provide the URI as a keyword argument."""
@@ -319,7 +334,7 @@ class Asset(os.PathLike, BaseAsset):
*,
group: str | None = None,
extra: dict | None = None,
- watchers: list[BaseTrigger] | None = None,
+ watchers: list[AssetWatcher | SerializedAssetWatcher] | None = None,
) -> None:
if name is None and uri is None:
raise TypeError("Asset() requires either 'name' or 'uri'")
diff --git a/tests/dag_processing/test_collection.py
b/tests/dag_processing/test_collection.py
index 2a0a40a634a..ac2bfe141c3 100644
--- a/tests/dag_processing/test_collection.py
+++ b/tests/dag_processing/test_collection.py
@@ -50,7 +50,7 @@ from airflow.models.errors import ParseImportError
from airflow.models.serialized_dag import SerializedDagModel
from airflow.operators.empty import EmptyOperator
from airflow.providers.standard.triggers.temporal import TimeDeltaTrigger
-from airflow.sdk.definitions.asset import Asset
+from airflow.sdk.definitions.asset import Asset, AssetWatcher
from airflow.serialization.serialized_objects import LazyDeserializedDAG,
SerializedDAG
from airflow.utils import timezone as tz
from airflow.utils.session import create_session
@@ -131,7 +131,11 @@ class TestAssetModelOperation:
)
def test_add_asset_trigger_references(self, is_active, is_paused,
expected_num_triggers, dag_maker):
trigger = TimeDeltaTrigger(timedelta(seconds=0))
- asset = Asset("test_add_asset_trigger_references_asset",
watchers=[trigger])
+ classpath, kwargs = trigger.serialize()
+ asset = Asset(
+ "test_add_asset_trigger_references_asset",
+ watchers=[AssetWatcher(name="test", trigger={"classpath":
classpath, "kwargs": kwargs})],
+ )
with dag_maker(dag_id="test_add_asset_trigger_references_dag",
schedule=[asset]) as dag:
EmptyOperator(task_id="mytask")
diff --git a/tests/serialization/test_serialized_objects.py
b/tests/serialization/test_serialized_objects.py
index ca6cb78a627..b07ebece4f3 100644
--- a/tests/serialization/test_serialized_objects.py
+++ b/tests/serialization/test_serialized_objects.py
@@ -41,7 +41,8 @@ from airflow.models.taskinstance import SimpleTaskInstance,
TaskInstance
from airflow.models.xcom_arg import XComArg
from airflow.operators.empty import EmptyOperator
from airflow.providers.standard.operators.python import PythonOperator
-from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetAliasEvent,
AssetUniqueKey
+from airflow.providers.standard.triggers.file import FileTrigger
+from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetAliasEvent,
AssetUniqueKey, AssetWatcher
from airflow.sdk.definitions.param import Param
from airflow.sdk.execution_time.context import OutletEventAccessor,
OutletEventAccessors
from airflow.serialization.enums import DagAttributeTypes as DAT, Encoding
@@ -254,6 +255,15 @@ class MockLazySelectSequence(LazySelectSequence):
lambda a, b: len(a) == len(b) and isinstance(b, list),
),
(Asset(uri="test://asset1", name="test"), DAT.ASSET, equals),
+ (
+ Asset(
+ uri="test://asset1",
+ name="test",
+ watchers=[AssetWatcher(name="test",
trigger=FileTrigger(filepath="/tmp"))],
+ ),
+ DAT.ASSET,
+ equals,
+ ),
(SimpleTaskInstance.from_ti(ti=TI), DAT.SIMPLE_TASK_INSTANCE, equals),
(
Connection(conn_id="TEST_ID", uri="mysql://"),