This is an automated email from the ASF dual-hosted git repository.

vincbeck pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 53e1723f0b2 AIP-82 Handle trigger serialization (#45562)
53e1723f0b2 is described below

commit 53e1723f0b25fda86aa5594c244b55057632222e
Author: Vincent <[email protected]>
AuthorDate: Fri Jan 31 18:57:21 2025 -0500

    AIP-82 Handle trigger serialization (#45562)
---
 airflow/dag_processing/collection.py               | 39 +++++++++---------
 .../example_dags/example_asset_with_watchers.py    |  7 ++--
 airflow/serialization/schema.json                  | 12 ++++++
 airflow/serialization/serialized_objects.py        | 47 ++++++++++++++++++++--
 task_sdk/src/airflow/sdk/__init__.py               |  7 +++-
 .../src/airflow/sdk/definitions/asset/__init__.py  | 25 +++++++++---
 tests/dag_processing/test_collection.py            |  8 +++-
 tests/serialization/test_serialized_objects.py     | 12 +++++-
 8 files changed, 122 insertions(+), 35 deletions(-)

diff --git a/airflow/dag_processing/collection.py b/airflow/dag_processing/collection.py
index c8bac5cef96..7bd5b316714 100644
--- a/airflow/dag_processing/collection.py
+++ b/airflow/dag_processing/collection.py
@@ -30,7 +30,7 @@ from __future__ import annotations
 import json
 import logging
 import traceback
-from typing import TYPE_CHECKING, Any, NamedTuple
+from typing import TYPE_CHECKING, Any, NamedTuple, cast
 
 from sqlalchemy import and_, delete, exists, func, insert, select, tuple_
 from sqlalchemy.exc import OperationalError
@@ -53,8 +53,7 @@ from airflow.models.dagwarning import DagWarningType
 from airflow.models.errors import ParseImportError
 from airflow.models.trigger import Trigger
 from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetNameRef, AssetUriRef
-from airflow.serialization.serialized_objects import BaseSerialization
-from airflow.triggers.base import BaseTrigger
+from airflow.serialization.serialized_objects import BaseSerialization, SerializedAssetWatcher
 from airflow.utils.retries import MAX_DB_RETRIES, run_with_db_retries
 from airflow.utils.sqlalchemy import with_row_locks
 from airflow.utils.timezone import utcnow
@@ -68,7 +67,6 @@ if TYPE_CHECKING:
 
     from airflow.models.dagwarning import DagWarning
     from airflow.serialization.serialized_objects import MaybeSerializedDAG
-    from airflow.triggers.base import BaseTrigger
     from airflow.typing_compat import Self
 
 log = logging.getLogger(__name__)
@@ -747,16 +745,23 @@ class AssetModelOperation(NamedTuple):
         # Update references from assets being used
         refs_to_add: dict[tuple[str, str], set[int]] = {}
         refs_to_remove: dict[tuple[str, str], set[int]] = {}
-        triggers: dict[int, BaseTrigger] = {}
+        triggers: dict[int, dict] = {}
 
         # Optimization: if no asset collected, skip fetching active assets
         active_assets = _find_active_assets(self.assets.keys(), session=session) if self.assets else {}
 
         for name_uri, asset in self.assets.items():
             # If the asset belong to a DAG not active or paused, consider there is no watcher associated to it
-            asset_watchers = asset.watchers if name_uri in active_assets else []
-            trigger_hash_to_trigger_dict: dict[int, BaseTrigger] = {
-                self._get_base_trigger_hash(trigger): trigger for trigger in asset_watchers
+            asset_watchers: list[SerializedAssetWatcher] = (
+                [cast(SerializedAssetWatcher, watcher) for watcher in asset.watchers]
+                if name_uri in active_assets
+                else []
+            )
+            trigger_hash_to_trigger_dict: dict[int, dict] = {
+                self._get_trigger_hash(
+                    watcher.trigger["classpath"], watcher.trigger["kwargs"]
+                ): watcher.trigger
+                for watcher in asset_watchers
             }
             triggers.update(trigger_hash_to_trigger_dict)
             trigger_hash_from_asset: set[int] = set(trigger_hash_to_trigger_dict.keys())
@@ -783,7 +788,10 @@ class AssetModelOperation(NamedTuple):
             }
 
             all_trigger_keys: set[tuple[str, str]] = {
-                self._encrypt_trigger_kwargs(triggers[trigger_hash])
+                (
+                    triggers[trigger_hash]["classpath"],
+                    Trigger.encrypt_kwargs(triggers[trigger_hash]["kwargs"]),
+                )
                 for trigger_hashes in refs_to_add.values()
                 for trigger_hash in trigger_hashes
             }
@@ -800,7 +808,9 @@ class AssetModelOperation(NamedTuple):
             new_trigger_models = [
                 trigger
                 for trigger in [
-                    Trigger.from_object(triggers[trigger_hash])
+                    Trigger(
+                        classpath=triggers[trigger_hash]["classpath"], kwargs=triggers[trigger_hash]["kwargs"]
+                    )
                     for trigger_hash in all_trigger_hashes
                     if trigger_hash not in orm_triggers
                 ]
@@ -836,11 +846,6 @@ class AssetModelOperation(NamedTuple):
             if (asset_model.name, asset_model.uri) not in self.assets:
                 asset_model.triggers = []
 
-    @staticmethod
-    def _encrypt_trigger_kwargs(trigger: BaseTrigger) -> tuple[str, str]:
-        classpath, kwargs = trigger.serialize()
-        return classpath, Trigger.encrypt_kwargs(kwargs)
-
     @staticmethod
     def _get_trigger_hash(classpath: str, kwargs: dict[str, Any]) -> int:
         """
@@ -852,7 +857,3 @@ class AssetModelOperation(NamedTuple):
         This is not true for event driven scheduling.
         """
         return hash((classpath, json.dumps(BaseSerialization.serialize(kwargs)).encode("utf-8")))
-
-    def _get_base_trigger_hash(self, trigger: BaseTrigger) -> int:
-        classpath, kwargs = trigger.serialize()
-        return self._get_trigger_hash(classpath, kwargs)
diff --git a/airflow/example_dags/example_asset_with_watchers.py b/airflow/example_dags/example_asset_with_watchers.py
index 32bbaa6bdd2..79d012a2ed4 100644
--- a/airflow/example_dags/example_asset_with_watchers.py
+++ b/airflow/example_dags/example_asset_with_watchers.py
@@ -21,15 +21,14 @@ Example DAG for demonstrating the usage of event driven scheduling using assets
 from __future__ import annotations
 
 import os
-import tempfile
 
 from airflow.decorators import task
 from airflow.models.baseoperator import chain
 from airflow.models.dag import DAG
 from airflow.providers.standard.triggers.file import FileTrigger
-from airflow.sdk.definitions.asset import Asset
+from airflow.sdk import Asset, AssetWatcher
 
-file_path = tempfile.NamedTemporaryFile().name
+file_path = "/tmp/test"
 
 with DAG(
     dag_id="example_create_file",
@@ -44,7 +43,7 @@ with DAG(
     chain(create_file())
 
 trigger = FileTrigger(filepath=file_path, poke_interval=10)
-asset = Asset("example_asset", watchers=[trigger])
+asset = Asset("example_asset", watchers=[AssetWatcher(name="test_file_watcher", trigger=trigger)])
 
 with DAG(
     dag_id="example_asset_with_watchers",
diff --git a/airflow/serialization/schema.json b/airflow/serialization/schema.json
index 292415ce11e..82fb5f908d3 100644
--- a/airflow/serialization/schema.json
+++ b/airflow/serialization/schema.json
@@ -64,6 +64,10 @@
                 {"type": "null"},
                 { "$ref": "#/definitions/dict" }
             ]
+        },
+        "watchers": {
+            "type": "array",
+            "items": { "$ref": "#/definitions/trigger" }
         }
       },
       "required": [ "uri", "extra" ]
@@ -126,6 +130,14 @@
         ],
         "additionalProperties": false
     },
+    "trigger": {
+      "type": "object",
+      "properties": {
+        "classpath": { "type": "string" },
+        "kwargs": { "$ref": "#/definitions/dict" }
+      },
+      "required": [ "classpath", "kwargs" ]
+    },
     "dict": {
       "description": "A python dictionary containing values of any type",
       "type": "object"
diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py
index 9c5f43c0c0b..89ea668f02f 100644
--- a/airflow/serialization/serialized_objects.py
+++ b/airflow/serialization/serialized_objects.py
@@ -59,6 +59,7 @@ from airflow.sdk.definitions.asset import (
     AssetAny,
     AssetRef,
     AssetUniqueKey,
+    AssetWatcher,
     BaseAsset,
 )
 from airflow.sdk.definitions.baseoperator import BaseOperator as TaskSDKBaseOperator
@@ -251,13 +252,34 @@ def encode_asset_condition(var: BaseAsset) -> dict[str, Any]:
     :meta private:
     """
     if isinstance(var, Asset):
-        return {
+
+        def _encode_watcher(watcher: AssetWatcher):
+            return {
+                "name": watcher.name,
+                "trigger": _encode_trigger(watcher.trigger),
+            }
+
+        def _encode_trigger(trigger: BaseTrigger | dict):
+            if isinstance(trigger, dict):
+                return trigger
+            classpath, kwargs = trigger.serialize()
+            return {
+                "classpath": classpath,
+                "kwargs": kwargs,
+            }
+
+        asset = {
             "__type": DAT.ASSET,
             "name": var.name,
             "uri": var.uri,
             "group": var.group,
             "extra": var.extra,
         }
+
+        if len(var.watchers) > 0:
+            asset["watchers"] = [_encode_watcher(watcher) for watcher in var.watchers]
+
+        return asset
     if isinstance(var, AssetAlias):
         return {"__type": DAT.ASSET_ALIAS, "name": var.name, "group": var.group}
     if isinstance(var, AssetAll):
@@ -283,7 +305,7 @@ def decode_asset_condition(var: dict[str, Any]) -> BaseAsset:
     """
     dat = var["__type"]
     if dat == DAT.ASSET:
-        return Asset(name=var["name"], uri=var["uri"], group=var["group"], extra=var["extra"])
+        return decode_asset(var)
     if dat == DAT.ASSET_ALL:
         return AssetAll(*(decode_asset_condition(x) for x in var["objects"]))
     if dat == DAT.ASSET_ANY:
@@ -295,6 +317,19 @@ def decode_asset_condition(var: dict[str, Any]) -> BaseAsset:
     raise ValueError(f"deserialization not implemented for DAT {dat!r}")
 
 
+def decode_asset(var: dict[str, Any]):
+    watchers = var.get("watchers", [])
+    return Asset(
+        name=var["name"],
+        uri=var["uri"],
+        group=var["group"],
+        extra=var["extra"],
+        watchers=[
+            SerializedAssetWatcher(name=watcher["name"], trigger=watcher["trigger"]) for watcher in watchers
+        ],
+    )
+
+
 def encode_outlet_event_accessor(var: OutletEventAccessor) -> dict[str, Any]:
     key = var.key
     return {
@@ -874,7 +909,7 @@ class BaseSerialization:
         elif type_ == DAT.XCOM_REF:
             return _XComRef(var)  # Delay deserializing XComArg objects until we have the entire DAG.
         elif type_ == DAT.ASSET:
-            return Asset(**var)
+            return decode_asset(var)
         elif type_ == DAT.ASSET_ALIAS:
             return AssetAlias(**var)
         elif type_ == DAT.ASSET_ANY:
@@ -1810,6 +1845,12 @@ class TaskGroupSerialization(BaseSerialization):
         return group
 
 
+class SerializedAssetWatcher(AssetWatcher):
+    """JSON serializable representation of an asset watcher."""
+
+    trigger: dict
+
+
 def _has_kubernetes() -> bool:
     global HAS_KUBERNETES
     if "HAS_KUBERNETES" in globals():
diff --git a/task_sdk/src/airflow/sdk/__init__.py b/task_sdk/src/airflow/sdk/__init__.py
index 6762f43ef8a..95b08be37aa 100644
--- a/task_sdk/src/airflow/sdk/__init__.py
+++ b/task_sdk/src/airflow/sdk/__init__.py
@@ -19,6 +19,9 @@ from __future__ import annotations
 from typing import TYPE_CHECKING
 
 __all__ = [
+    "__version__",
+    "Asset",
+    "AssetWatcher",
     "BaseOperator",
     "Connection",
     "DAG",
@@ -27,7 +30,6 @@ __all__ = [
     "MappedOperator",
     "TaskGroup",
     "XComArg",
-    "__version__",
     "dag",
     "get_current_context",
     "get_parsing_context",
@@ -36,6 +38,7 @@ __all__ = [
 __version__ = "1.0.0.alpha1"
 
 if TYPE_CHECKING:
+    from airflow.sdk.definitions.asset import Asset, AssetWatcher
     from airflow.sdk.definitions.baseoperator import BaseOperator
     from airflow.sdk.definitions.connection import Connection
     from airflow.sdk.definitions.context import get_current_context, get_parsing_context
@@ -60,6 +63,8 @@ __lazy_imports: dict[str, str] = {
     "dag": ".definitions.dag",
     "get_current_context": ".definitions.context",
     "get_parsing_context": ".definitions.context",
+    "Asset": ".definitions.asset",
+    "AssetWatcher": ".definitions.asset",
 }
 
 
diff --git a/task_sdk/src/airflow/sdk/definitions/asset/__init__.py b/task_sdk/src/airflow/sdk/definitions/asset/__init__.py
index b976bb8c156..f7a515ea434 100644
--- a/task_sdk/src/airflow/sdk/definitions/asset/__init__.py
+++ b/task_sdk/src/airflow/sdk/definitions/asset/__init__.py
@@ -37,6 +37,7 @@ if TYPE_CHECKING:
     from sqlalchemy.orm import Session
 
     from airflow.models.asset import AssetModel
+    from airflow.serialization.serialized_objects import SerializedAssetWatcher
     from airflow.triggers.base import BaseTrigger
 
     AttrsInstance = attrs.AttrsInstance
@@ -54,6 +55,7 @@ __all__ = [
     "AssetNameRef",
     "AssetRef",
     "AssetUriRef",
+    "AssetWatcher",
 ]
 
 
@@ -252,6 +254,19 @@ class BaseAsset:
         raise NotImplementedError
 
 
+@attrs.define(frozen=True)
+class AssetWatcher:
+    """A representation of an asset watcher. The name uniquely identifies the watch."""
+
+    name: str
+    # This attribute serves double purpose.
+    # For a "normal" asset instance loaded from DAG, this holds the trigger used to monitor an external
+    # resource. In that case, ``AssetWatcher`` is used directly by users.
+    # For an asset recreated from a serialized DAG, this holds the serialized data of the trigger. In that
+    # case, `SerializedAssetWatcher` is used. We need to keep the two types to make mypy happy.
+    trigger: BaseTrigger | dict
+
+
 @attrs.define(init=False, unsafe_hash=False)
 class Asset(os.PathLike, BaseAsset):
     """A representation of data asset dependencies between workflows."""
@@ -271,7 +286,7 @@ class Asset(os.PathLike, BaseAsset):
         factory=dict,
         converter=_set_extra_default,
     )
-    watchers: list[BaseTrigger] = attrs.field(
+    watchers: list[AssetWatcher | SerializedAssetWatcher] = attrs.field(
         factory=list,
     )
 
@@ -286,7 +301,7 @@ class Asset(os.PathLike, BaseAsset):
         *,
         group: str = ...,
         extra: dict | None = None,
-        watchers: list[BaseTrigger] = ...,
+        watchers: list[AssetWatcher | SerializedAssetWatcher] = ...,
     ) -> None:
         """Canonical; both name and uri are provided."""
 
@@ -297,7 +312,7 @@ class Asset(os.PathLike, BaseAsset):
         *,
         group: str = ...,
         extra: dict | None = None,
-        watchers: list[BaseTrigger] = ...,
+        watchers: list[AssetWatcher | SerializedAssetWatcher] = ...,
     ) -> None:
         """It's possible to only provide the name, either by keyword or as the only positional argument."""
 
@@ -308,7 +323,7 @@ class Asset(os.PathLike, BaseAsset):
         uri: str,
         group: str = ...,
         extra: dict | None = None,
-        watchers: list[BaseTrigger] = ...,
+        watchers: list[AssetWatcher | SerializedAssetWatcher] = ...,
     ) -> None:
         """It's possible to only provide the URI as a keyword argument."""
 
@@ -319,7 +334,7 @@ class Asset(os.PathLike, BaseAsset):
         *,
         group: str | None = None,
         extra: dict | None = None,
-        watchers: list[BaseTrigger] | None = None,
+        watchers: list[AssetWatcher | SerializedAssetWatcher] | None = None,
     ) -> None:
         if name is None and uri is None:
             raise TypeError("Asset() requires either 'name' or 'uri'")
diff --git a/tests/dag_processing/test_collection.py b/tests/dag_processing/test_collection.py
index 2a0a40a634a..ac2bfe141c3 100644
--- a/tests/dag_processing/test_collection.py
+++ b/tests/dag_processing/test_collection.py
@@ -50,7 +50,7 @@ from airflow.models.errors import ParseImportError
 from airflow.models.serialized_dag import SerializedDagModel
 from airflow.operators.empty import EmptyOperator
 from airflow.providers.standard.triggers.temporal import TimeDeltaTrigger
-from airflow.sdk.definitions.asset import Asset
+from airflow.sdk.definitions.asset import Asset, AssetWatcher
 from airflow.serialization.serialized_objects import LazyDeserializedDAG, SerializedDAG
 from airflow.utils import timezone as tz
 from airflow.utils.session import create_session
@@ -131,7 +131,11 @@ class TestAssetModelOperation:
     )
     def test_add_asset_trigger_references(self, is_active, is_paused, expected_num_triggers, dag_maker):
         trigger = TimeDeltaTrigger(timedelta(seconds=0))
-        asset = Asset("test_add_asset_trigger_references_asset", watchers=[trigger])
+        classpath, kwargs = trigger.serialize()
+        asset = Asset(
+            "test_add_asset_trigger_references_asset",
+            watchers=[AssetWatcher(name="test", trigger={"classpath": classpath, "kwargs": kwargs})],
+        )
 
         with dag_maker(dag_id="test_add_asset_trigger_references_dag", schedule=[asset]) as dag:
             EmptyOperator(task_id="mytask")
diff --git a/tests/serialization/test_serialized_objects.py b/tests/serialization/test_serialized_objects.py
index ca6cb78a627..b07ebece4f3 100644
--- a/tests/serialization/test_serialized_objects.py
+++ b/tests/serialization/test_serialized_objects.py
@@ -41,7 +41,8 @@ from airflow.models.taskinstance import SimpleTaskInstance, TaskInstance
 from airflow.models.xcom_arg import XComArg
 from airflow.operators.empty import EmptyOperator
 from airflow.providers.standard.operators.python import PythonOperator
-from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetAliasEvent, AssetUniqueKey
+from airflow.providers.standard.triggers.file import FileTrigger
+from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetAliasEvent, AssetUniqueKey, AssetWatcher
 from airflow.sdk.definitions.param import Param
 from airflow.sdk.execution_time.context import OutletEventAccessor, OutletEventAccessors
 from airflow.serialization.enums import DagAttributeTypes as DAT, Encoding
@@ -254,6 +255,15 @@ class MockLazySelectSequence(LazySelectSequence):
             lambda a, b: len(a) == len(b) and isinstance(b, list),
         ),
         (Asset(uri="test://asset1", name="test"), DAT.ASSET, equals),
+        (
+            Asset(
+                uri="test://asset1",
+                name="test",
+                watchers=[AssetWatcher(name="test", trigger=FileTrigger(filepath="/tmp"))],
+            ),
+            DAT.ASSET,
+            equals,
+        ),
         (SimpleTaskInstance.from_ti(ti=TI), DAT.SIMPLE_TASK_INSTANCE, equals),
         (
             Connection(conn_id="TEST_ID", uri="mysql://"),

Reply via email to