This is an automated email from the ASF dual-hosted git repository.

weilee pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new b3362f841f3 Add "@asset" to decorate a function as a DAG and an asset 
(#41325)
b3362f841f3 is described below

commit b3362f841f31b5629f09f619cfa2d69e7199493a
Author: Tzu-ping Chung <[email protected]>
AuthorDate: Thu Nov 14 17:14:45 2024 +0800

    Add "@asset" to decorate a function as a DAG and an asset (#41325)
    
    * Implement asset definition creating a DAG
    
    * Basic inlet dependency
    
    * Make AssetDefinition subclass Asset
    
    This seems to be the best way for 'schedule' dependencies to work. Still
    not entirely sure; we'll revisit this.
    
    * style: fix mypy error
    
    * feat(asset): allow uri to be None
    
    * fix: temporarily serialize AssetDefinition into a string
    
    * feat(decorators/assets): rewrite how asset definition is serialized
    
    * test(decorators/assets): add test cases to check whether asset decorator 
generates the right asset definition
    
    * test(decorators/assets): add test cases to AssetDefinition
    
    * test(decorators/asset): add test cases to Test_AssetMainOperator
    
    * test(decorators/assets): remove unused fixtures
    
    * docs(example_dag): add example dag for asset_decorator
    
    * feat(decorators/assets): allow passing self and context into asset
    
    * feat(decorators/assets): return actual asset in asset decorator
    
    * refactor(decorators/assets): extract active assets fetching logic as 
_fetch_active_assets_by_name
    
    * feat(decorators/assets): allow fetching inlet events through AssetRef
    
    * feat(decorators/assets): reorder import paths
    
    * docs: update asset decorator example dag
    
    * test: fix tests
    
    * test(decorators/assets): extend test_determine_kwargs to cover active 
asset
    
    * fix: address easy to fix comments
    
    * fix: fix asset serialization
    
    * refactor(decorators/assets): postpone the attribute check to 
AssetDefinition instead of asset decorator
    
    * Simplify group validators
    
    The validate_identifier validator already checks the length, so we don't
    need an extra one doing that.
    
    * style(dag): remove _wrapped_definition
    
    * style(decorators/assets): change types.FunctionType to Callable
    
    * refactor(decorators/assets): make session in _fetch_active_assets_by_name 
required
    
    * fix(decorators/assets): remove DAG.bulk_write_to_db and remove self 
handling
    
    * feat(utils/context): fetch asset_refs all at once
    
    ---------
    
    Co-authored-by: Wei Lee <[email protected]>
---
 airflow/assets/__init__.py                         |  34 ++--
 airflow/decorators/assets.py                       | 131 +++++++++++++++
 airflow/example_dags/example_asset_decorator.py    |  52 ++++++
 airflow/models/asset.py                            |  22 +++
 airflow/models/dag.py                              |   4 +-
 airflow/serialization/enums.py                     |   1 +
 airflow/serialization/schema.json                  |  17 +-
 airflow/serialization/serialized_objects.py        |   9 +-
 airflow/utils/context.py                           |  21 ++-
 airflow/utils/file.py                              |   4 +-
 .../api_fastapi/core_api/routes/ui/test_assets.py  |   2 +-
 tests/decorators/test_assets.py                    | 177 +++++++++++++++++++++
 tests/timetables/test_assets_timetable.py          |   4 +-
 13 files changed, 451 insertions(+), 27 deletions(-)

diff --git a/airflow/assets/__init__.py b/airflow/assets/__init__.py
index 59e0b866844..f1d36ac12b7 100644
--- a/airflow/assets/__init__.py
+++ b/airflow/assets/__init__.py
@@ -23,7 +23,7 @@ import urllib.parse
 import warnings
 from typing import TYPE_CHECKING, Any, Callable, ClassVar, Iterable, Iterator, 
cast, overload
 
-import attr
+import attrs
 from sqlalchemy import select
 
 from airflow.api_internal.internal_api_call import internal_api_call
@@ -123,6 +123,13 @@ def _validate_non_empty_identifier(instance, attribute, 
value):
     return value
 
 
+def _validate_asset_name(instance, attribute, value):
+    _validate_non_empty_identifier(instance, attribute, value)
+    if value == "self" or value == "context":
+        raise ValueError(f"prohibited name for asset: {value}")
+    return value
+
+
 def extract_event_key(value: str | Asset | AssetAlias) -> str:
     """
     Extract the key of an inlet or an outlet event.
@@ -158,6 +165,13 @@ def expand_alias_to_assets(alias: str | AssetAlias, *, 
session: Session = NEW_SE
     return []
 
 
[email protected](kw_only=True)
+class AssetRef:
+    """Reference to an asset."""
+
+    name: str
+
+
 class BaseAsset:
     """
     Protocol for all asset triggers to use in ``DAG(schedule=...)``.
@@ -207,16 +221,12 @@ class BaseAsset:
         raise NotImplementedError
 
 
[email protected](unsafe_hash=False)
[email protected](unsafe_hash=False)
 class AssetAlias(BaseAsset):
     """A represeation of asset alias which is used to create asset during the 
runtime."""
 
-    name: str = attr.field(validator=_validate_non_empty_identifier)
-    group: str = attr.field(
-        kw_only=True,
-        default="",
-        validator=[attr.validators.max_len(1500), _validate_identifier],
-    )
+    name: str = attrs.field(validator=_validate_non_empty_identifier)
+    group: str = attrs.field(kw_only=True, default="", 
validator=_validate_identifier)
 
     def iter_assets(self) -> Iterator[tuple[str, Asset]]:
         return iter(())
@@ -258,7 +268,7 @@ def _set_extra_default(extra: dict | None) -> dict:
     return extra
 
 
[email protected](init=False, unsafe_hash=False)
[email protected](init=False, unsafe_hash=False)
 class Asset(os.PathLike, BaseAsset):
     """A representation of data asset dependencies between workflows."""
 
@@ -267,7 +277,7 @@ class Asset(os.PathLike, BaseAsset):
     group: str
     extra: dict[str, Any]
 
-    asset_type: ClassVar[str] = ""
+    asset_type: ClassVar[str] = "asset"
     __version__: ClassVar[int] = 1
 
     @overload
@@ -296,8 +306,8 @@ class Asset(os.PathLike, BaseAsset):
             name = uri
         elif uri is None:
             uri = name
-        fields = attr.fields_dict(Asset)
-        self.name = _validate_non_empty_identifier(self, fields["name"], name)
+        fields = attrs.fields_dict(Asset)
+        self.name = _validate_asset_name(self, fields["name"], name)
         self.uri = _sanitize_uri(_validate_non_empty_identifier(self, 
fields["uri"], uri))
         self.group = _validate_identifier(self, fields["group"], group) if 
group else self.asset_type
         self.extra = _set_extra_default(extra)
diff --git a/airflow/decorators/assets.py b/airflow/decorators/assets.py
new file mode 100644
index 00000000000..2f5052c2d5c
--- /dev/null
+++ b/airflow/decorators/assets.py
@@ -0,0 +1,131 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+import inspect
+from typing import TYPE_CHECKING, Any, Callable, Iterator, Mapping
+
+import attrs
+
+from airflow.assets import Asset, AssetRef
+from airflow.models.asset import _fetch_active_assets_by_name
+from airflow.models.dag import DAG, ScheduleArg
+from airflow.providers.standard.operators.python import PythonOperator
+from airflow.utils.session import create_session
+
+if TYPE_CHECKING:
+    from airflow.io.path import ObjectStoragePath
+
+
+class _AssetMainOperator(PythonOperator):
+    def __init__(self, *, definition_name: str, uri: str | None = None, 
**kwargs) -> None:
+        super().__init__(**kwargs)
+        self._definition_name = definition_name
+        self._uri = uri
+
+    def _iter_kwargs(
+        self, context: Mapping[str, Any], active_assets: dict[str, Asset]
+    ) -> Iterator[tuple[str, Any]]:
+        value: Any
+        for key in inspect.signature(self.python_callable).parameters:
+            if key == "self":
+                value = active_assets.get(self._definition_name)
+            elif key == "context":
+                value = context
+            else:
+                value = active_assets.get(key, Asset(name=key))
+            yield key, value
+
+    def determine_kwargs(self, context: Mapping[str, Any]) -> Mapping[str, 
Any]:
+        active_assets: dict[str, Asset] = {}
+        asset_names = [asset_ref.name for asset_ref in self.inlets if 
isinstance(asset_ref, AssetRef)]
+        if "self" in inspect.signature(self.python_callable).parameters:
+            asset_names.append(self._definition_name)
+
+        if asset_names:
+            with create_session() as session:
+                active_assets = _fetch_active_assets_by_name(asset_names, 
session)
+        return dict(self._iter_kwargs(context, active_assets))
+
+
[email protected](kw_only=True)
+class AssetDefinition(Asset):
+    """
+    Asset representation from decorating a function with ``@asset``.
+
+    :meta private:
+    """
+
+    function: Callable
+    schedule: ScheduleArg
+
+    def __attrs_post_init__(self) -> None:
+        parameters = inspect.signature(self.function).parameters
+
+        with DAG(dag_id=self.name, schedule=self.schedule, auto_register=True):
+            _AssetMainOperator(
+                task_id="__main__",
+                inlets=[
+                    AssetRef(name=inlet_asset_name)
+                    for inlet_asset_name in parameters
+                    if inlet_asset_name not in ("self", "context")
+                ],
+                outlets=[self.to_asset()],
+                python_callable=self.function,
+                definition_name=self.name,
+                uri=self.uri,
+            )
+
+    def to_asset(self) -> Asset:
+        return Asset(
+            name=self.name,
+            uri=self.uri,
+            group=self.group,
+            extra=self.extra,
+        )
+
+    def serialize(self):
+        return {
+            "uri": self.uri,
+            "name": self.name,
+            "group": self.group,
+            "extra": self.extra,
+        }
+
+
[email protected](kw_only=True)
+class asset:
+    """Create an asset by decorating a materialization function."""
+
+    schedule: ScheduleArg
+    uri: str | ObjectStoragePath | None = None
+    group: str = ""
+    extra: dict[str, Any] = attrs.field(factory=dict)
+
+    def __call__(self, f: Callable) -> AssetDefinition:
+        if (name := f.__name__) != f.__qualname__:
+            raise ValueError("nested function not supported")
+
+        return AssetDefinition(
+            name=name,
+            uri=name if self.uri is None else str(self.uri),
+            group=self.group,
+            extra=self.extra,
+            function=f,
+            schedule=self.schedule,
+        )
diff --git a/airflow/example_dags/example_asset_decorator.py 
b/airflow/example_dags/example_asset_decorator.py
new file mode 100644
index 00000000000..b4de09c2314
--- /dev/null
+++ b/airflow/example_dags/example_asset_decorator.py
@@ -0,0 +1,52 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import pendulum
+
+from airflow.assets import Asset
+from airflow.decorators import dag, task
+from airflow.decorators.assets import asset
+
+
+@asset(uri="s3://bucket/asset1_producer", schedule=None)
+def asset1_producer():
+    pass
+
+
+@asset(uri="s3://bucket/object", schedule=None)
+def asset2_producer(self, context, asset1_producer):
+    print(self)
+    print(context["inlet_events"][asset1_producer])
+
+
+@dag(
+    schedule=Asset(uri="s3://bucket/asset1_producer", name="asset1_producer")
+    | Asset(uri="s3://bucket/object", name="asset2_producer"),
+    start_date=pendulum.datetime(2021, 1, 1, tz="UTC"),
+    catchup=False,
+    tags=["consumes", "asset-scheduled"],
+)
+def consumes_asset_decorator():
+    @task(outlets=[Asset(name="process_nothing")])
+    def process_nothing():
+        pass
+
+    process_nothing()
+
+
+consumes_asset_decorator()
diff --git a/airflow/models/asset.py b/airflow/models/asset.py
index 8ade71bd0b1..50914d51650 100644
--- a/airflow/models/asset.py
+++ b/airflow/models/asset.py
@@ -17,6 +17,7 @@
 # under the License.
 from __future__ import annotations
 
+from typing import TYPE_CHECKING
 from urllib.parse import urlsplit
 
 import sqlalchemy_jsonfield
@@ -29,6 +30,7 @@ from sqlalchemy import (
     PrimaryKeyConstraint,
     String,
     Table,
+    select,
     text,
 )
 from sqlalchemy.orm import relationship
@@ -39,6 +41,26 @@ from airflow.settings import json
 from airflow.utils import timezone
 from airflow.utils.sqlalchemy import UtcDateTime
 
+if TYPE_CHECKING:
+    from typing import Sequence
+
+    from sqlalchemy.orm import Session
+
+
+def _fetch_active_assets_by_name(
+    names: Sequence[str],
+    session: Session,
+) -> dict[str, Asset]:
+    return {
+        asset_model[0].name: asset_model[0].to_public()
+        for asset_model in session.execute(
+            select(AssetModel)
+            .join(AssetActive, AssetActive.name == AssetModel.name)
+            .where(AssetActive.name.in_(name for name in names))
+        )
+    }
+
+
 alias_association_table = Table(
     "asset_alias_asset",
     Base.metadata,
diff --git a/airflow/models/dag.py b/airflow/models/dag.py
index e6a67c6ad7e..e48ec0a9a9c 100644
--- a/airflow/models/dag.py
+++ b/airflow/models/dag.py
@@ -777,9 +777,7 @@ class DAG(TaskSDKDag, LoggingMixin):
     @classmethod
     def get_serialized_fields(cls):
         """Stringified DAGs and operators contain exactly these fields."""
-        return TaskSDKDag.get_serialized_fields() | {
-            "_processor_dags_folder",
-        }
+        return TaskSDKDag.get_serialized_fields() | {"_processor_dags_folder"}
 
     @staticmethod
     @internal_api_call
diff --git a/airflow/serialization/enums.py b/airflow/serialization/enums.py
index dd63366b8a9..d1b946b38ef 100644
--- a/airflow/serialization/enums.py
+++ b/airflow/serialization/enums.py
@@ -59,6 +59,7 @@ class DagAttributeTypes(str, Enum):
     ASSET_ALIAS = "asset_alias"
     ASSET_ANY = "asset_any"
     ASSET_ALL = "asset_all"
+    ASSET_REF = "asset_ref"
     SIMPLE_TASK_INSTANCE = "simple_task_instance"
     BASE_JOB = "Job"
     TASK_INSTANCE = "task_instance"
diff --git a/airflow/serialization/schema.json 
b/airflow/serialization/schema.json
index b26b5933981..1e7232aa81a 100644
--- a/airflow/serialization/schema.json
+++ b/airflow/serialization/schema.json
@@ -53,6 +53,21 @@
         { "type": "integer" }
       ]
     },
+    "asset_definition": {
+      "type": "object",
+      "properties": {
+        "uri": { "type": "string" },
+        "name": { "type": "string" },
+        "group": { "type": "string" },
+        "extra": {
+            "anyOf": [
+                {"type": "null"},
+                { "$ref": "#/definitions/dict" }
+            ]
+        }
+      },
+      "required": [ "uri", "extra" ]
+    },
     "asset": {
       "type": "object",
       "properties": {
@@ -153,7 +168,7 @@
         "_processor_dags_folder": {
             "anyOf": [
                 { "type": "null" },
-                {"type": "string"}
+                { "type": "string" }
             ]
         },
         "dag_display_name": { "type" : "string"},
diff --git a/airflow/serialization/serialized_objects.py 
b/airflow/serialization/serialized_objects.py
index 52b0bcb1530..4b7ee6d0871 100644
--- a/airflow/serialization/serialized_objects.py
+++ b/airflow/serialization/serialized_objects.py
@@ -40,6 +40,7 @@ from airflow.assets import (
     AssetAlias,
     AssetAll,
     AssetAny,
+    AssetRef,
     BaseAsset,
     _AssetAliasCondition,
 )
@@ -254,7 +255,7 @@ def encode_asset_condition(var: BaseAsset) -> dict[str, 
Any]:
     :meta private:
     """
     if isinstance(var, Asset):
-        return {"__type": DAT.ASSET, "uri": var.uri, "extra": var.extra}
+        return {"__type": DAT.ASSET, "name": var.name, "uri": var.uri, 
"extra": var.extra}
     if isinstance(var, AssetAlias):
         return {"__type": DAT.ASSET_ALIAS, "name": var.name}
     if isinstance(var, AssetAll):
@@ -272,7 +273,7 @@ def decode_asset_condition(var: dict[str, Any]) -> 
BaseAsset:
     """
     dat = var["__type"]
     if dat == DAT.ASSET:
-        return Asset(var["uri"], extra=var["extra"])
+        return Asset(uri=var["uri"], name=var["name"], extra=var["extra"])
     if dat == DAT.ASSET_ALL:
         return AssetAll(*(decode_asset_condition(x) for x in var["objects"]))
     if dat == DAT.ASSET_ANY:
@@ -743,6 +744,8 @@ class BaseSerialization:
         elif isinstance(var, BaseAsset):
             serialized_asset = encode_asset_condition(var)
             return cls._encode(serialized_asset, 
type_=serialized_asset.pop("__type"))
+        elif isinstance(var, AssetRef):
+            return cls._encode({"name": var.name}, type_=DAT.ASSET_REF)
         elif isinstance(var, SimpleTaskInstance):
             return cls._encode(
                 cls.serialize(var.__dict__, strict=strict, 
use_pydantic_models=use_pydantic_models),
@@ -876,6 +879,8 @@ class BaseSerialization:
             return AssetAny(*(decode_asset_condition(x) for x in 
var["objects"]))
         elif type_ == DAT.ASSET_ALL:
             return AssetAll(*(decode_asset_condition(x) for x in 
var["objects"]))
+        elif type_ == DAT.ASSET_REF:
+            return AssetRef(name=var["name"])
         elif type_ == DAT.SIMPLE_TASK_INSTANCE:
             return SimpleTaskInstance(**cls.deserialize(var))
         elif type_ == DAT.CONNECTION:
diff --git a/airflow/utils/context.py b/airflow/utils/context.py
index 9af1486f914..b28559999e5 100644
--- a/airflow/utils/context.py
+++ b/airflow/utils/context.py
@@ -44,10 +44,11 @@ from airflow.assets import (
     Asset,
     AssetAlias,
     AssetAliasEvent,
+    AssetRef,
     extract_event_key,
 )
 from airflow.exceptions import RemovedInAirflow3Warning
-from airflow.models.asset import AssetAliasModel, AssetEvent, AssetModel
+from airflow.models.asset import AssetAliasModel, AssetEvent, AssetModel, 
_fetch_active_assets_by_name
 from airflow.utils.db import LazySelectSequence
 from airflow.utils.types import NOTSET
 
@@ -257,11 +258,18 @@ class InletEventsAccessors(Mapping[str, 
LazyAssetEventSelectSequence]):
         self._assets = {}
         self._asset_aliases = {}
 
+        _asset_ref_names: list[str] = []
         for inlet in inlets:
             if isinstance(inlet, Asset):
-                self._assets[inlet.uri] = inlet
+                self._assets[inlet.name] = inlet
             elif isinstance(inlet, AssetAlias):
                 self._asset_aliases[inlet.name] = inlet
+            elif isinstance(inlet, AssetRef):
+                _asset_ref_names.append(inlet.name)
+
+        if _asset_ref_names:
+            for asset_name, asset in 
_fetch_active_assets_by_name(_asset_ref_names, self._session).items():
+                self._assets[asset_name] = asset
 
     def __iter__(self) -> Iterator[str]:
         return iter(self._inlets)
@@ -272,7 +280,7 @@ class InletEventsAccessors(Mapping[str, 
LazyAssetEventSelectSequence]):
     def __getitem__(self, key: int | str | Asset | AssetAlias) -> 
LazyAssetEventSelectSequence:
         if isinstance(key, int):  # Support index access; it's easier for 
trivial cases.
             obj = self._inlets[key]
-            if not isinstance(obj, (Asset, AssetAlias)):
+            if not isinstance(obj, (Asset, AssetAlias, AssetRef)):
                 raise IndexError(key)
         else:
             obj = key
@@ -281,10 +289,13 @@ class InletEventsAccessors(Mapping[str, 
LazyAssetEventSelectSequence]):
             asset_alias = self._asset_aliases[obj.name]
             join_clause = AssetEvent.source_aliases
             where_clause = AssetAliasModel.name == asset_alias.name
-        elif isinstance(obj, (Asset, str)):
+        elif isinstance(obj, (Asset, AssetRef)):
+            join_clause = AssetEvent.asset
+            where_clause = AssetModel.name == self._assets[obj.name].name
+        elif isinstance(obj, str):
             asset = self._assets[extract_event_key(obj)]
             join_clause = AssetEvent.asset
-            where_clause = AssetModel.uri == asset.uri
+            where_clause = AssetModel.name == asset.name
         else:
             raise ValueError(key)
 
diff --git a/airflow/utils/file.py b/airflow/utils/file.py
index 86b7a7891ca..962f97c8fcf 100644
--- a/airflow/utils/file.py
+++ b/airflow/utils/file.py
@@ -328,7 +328,9 @@ def might_contain_dag_via_default_heuristic(file_path: str, 
zip_file: zipfile.Zi
         with open(file_path, "rb") as dag_file:
             content = dag_file.read()
     content = content.lower()
-    return all(s in content for s in (b"dag", b"airflow"))
+    if b"airflow" not in content:
+        return False
+    return any(s in content for s in (b"dag", b"asset"))
 
 
 def _find_imported_modules(module: ast.Module) -> Generator[str, None, None]:
diff --git a/tests/api_fastapi/core_api/routes/ui/test_assets.py 
b/tests/api_fastapi/core_api/routes/ui/test_assets.py
index b71d80ae9d3..b5c85b98ba6 100644
--- a/tests/api_fastapi/core_api/routes/ui/test_assets.py
+++ b/tests/api_fastapi/core_api/routes/ui/test_assets.py
@@ -47,5 +47,5 @@ def test_next_run_assets(test_client, dag_maker):
     assert response.status_code == 200
     assert response.json() == {
         "asset_expression": {"all": ["s3://bucket/key/1"]},
-        "events": [{"id": 17, "uri": "s3://bucket/key/1", "lastUpdate": None}],
+        "events": [{"id": 20, "uri": "s3://bucket/key/1", "lastUpdate": None}],
     }
diff --git a/tests/decorators/test_assets.py b/tests/decorators/test_assets.py
new file mode 100644
index 00000000000..a3821140e54
--- /dev/null
+++ b/tests/decorators/test_assets.py
@@ -0,0 +1,177 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from unittest import mock
+from unittest.mock import ANY
+
+import pytest
+
+from airflow.assets import Asset
+from airflow.decorators.assets import AssetRef, _AssetMainOperator, asset
+from airflow.models.asset import AssetActive, AssetModel
+
+pytestmark = pytest.mark.db_test
+
+
[email protected]
+def example_asset_func(request):
+    name = "example_asset_func"
+    if getattr(request, "param", None) is not None:
+        name = request.param
+
+    def _example_asset_func():
+        return "This is example_asset"
+
+    _example_asset_func.__name__ = name
+    _example_asset_func.__qualname__ = name
+    return _example_asset_func
+
+
[email protected]
+def example_asset_definition(example_asset_func):
+    return asset(schedule=None, uri="s3://bucket/object", group="MLModel", 
extra={"k": "v"})(
+        example_asset_func
+    )
+
+
[email protected]
+def example_asset_func_with_valid_arg_as_inlet_asset():
+    def _example_asset_func(self, context, inlet_asset_1, inlet_asset_2):
+        return "This is example_asset"
+
+    _example_asset_func.__name__ = "example_asset_func"
+    _example_asset_func.__qualname__ = "example_asset_func"
+    return _example_asset_func
+
+
+class TestAssetDecorator:
+    def test_without_uri(self, example_asset_func):
+        asset_definition = asset(schedule=None)(example_asset_func)
+
+        assert asset_definition.name == "example_asset_func"
+        assert asset_definition.uri == "example_asset_func"
+        assert asset_definition.group == ""
+        assert asset_definition.extra == {}
+        assert asset_definition.function == example_asset_func
+        assert asset_definition.schedule is None
+
+    def test_with_uri(self, example_asset_func):
+        asset_definition = asset(schedule=None, 
uri="s3://bucket/object")(example_asset_func)
+
+        assert asset_definition.name == "example_asset_func"
+        assert asset_definition.uri == "s3://bucket/object"
+        assert asset_definition.group == ""
+        assert asset_definition.extra == {}
+        assert asset_definition.function == example_asset_func
+        assert asset_definition.schedule is None
+
+    def test_with_group_and_extra(self, example_asset_func):
+        asset_definition = asset(schedule=None, uri="s3://bucket/object", 
group="MLModel", extra={"k": "v"})(
+            example_asset_func
+        )
+        assert asset_definition.name == "example_asset_func"
+        assert asset_definition.uri == "s3://bucket/object"
+        assert asset_definition.group == "MLModel"
+        assert asset_definition.extra == {"k": "v"}
+        assert asset_definition.function == example_asset_func
+        assert asset_definition.schedule is None
+
+    def test_nested_function(self):
+        def root_func():
+            @asset(schedule=None)
+            def asset_func():
+                pass
+
+        with pytest.raises(ValueError) as err:
+            root_func()
+
+        assert err.value.args[0] == "nested function not supported"
+
+    @pytest.mark.parametrize("example_asset_func", ("self", "context"), 
indirect=True)
+    def test_with_invalid_asset_name(self, example_asset_func):
+        with pytest.raises(ValueError) as err:
+            asset(schedule=None)(example_asset_func)
+
+        assert err.value.args[0].startswith("prohibited name for asset: ")
+
+
+class TestAssetDefinition:
+    def test_serialzie(self, example_asset_definition):
+        assert example_asset_definition.serialize() == {
+            "extra": {"k": "v"},
+            "group": "MLModel",
+            "name": "example_asset_func",
+            "uri": "s3://bucket/object",
+        }
+
+    @mock.patch("airflow.decorators.assets._AssetMainOperator")
+    @mock.patch("airflow.decorators.assets.DAG")
+    def test__attrs_post_init__(
+        self, DAG, _AssetMainOperator, 
example_asset_func_with_valid_arg_as_inlet_asset
+    ):
+        asset_definition = asset(schedule=None, uri="s3://bucket/object", 
group="MLModel", extra={"k": "v"})(
+            example_asset_func_with_valid_arg_as_inlet_asset
+        )
+
+        DAG.assert_called_once_with(dag_id="example_asset_func", 
schedule=None, auto_register=True)
+        _AssetMainOperator.assert_called_once_with(
+            task_id="__main__",
+            inlets=[
+                AssetRef(name="inlet_asset_1"),
+                AssetRef(name="inlet_asset_2"),
+            ],
+            outlets=[asset_definition.to_asset()],
+            python_callable=ANY,
+            definition_name="example_asset_func",
+            uri="s3://bucket/object",
+        )
+
+        python_callable = 
_AssetMainOperator.call_args.kwargs["python_callable"]
+        assert python_callable == 
example_asset_func_with_valid_arg_as_inlet_asset
+
+
+class Test_AssetMainOperator:
+    def test_determine_kwargs(self, 
example_asset_func_with_valid_arg_as_inlet_asset, session):
+        example_asset_model = AssetModel(uri="s3://bucket/object1", 
name="inlet_asset_1")
+        asset_definition = asset(schedule=None, uri="s3://bucket/object", 
group="MLModel", extra={"k": "v"})(
+            example_asset_func_with_valid_arg_as_inlet_asset
+        )
+
+        ad_asset_model = AssetModel.from_public(asset_definition)
+
+        session.add(example_asset_model)
+        session.add(ad_asset_model)
+        session.add(AssetActive.for_asset(example_asset_model))
+        session.add(AssetActive.for_asset(ad_asset_model))
+        session.commit()
+
+        op = _AssetMainOperator(
+            task_id="__main__",
+            inlets=[AssetRef(name="inlet_asset_1"), 
AssetRef(name="inlet_asset_2")],
+            outlets=[asset_definition],
+            python_callable=example_asset_func_with_valid_arg_as_inlet_asset,
+            definition_name="example_asset_func",
+        )
+        assert op.determine_kwargs(context={"k": "v"}) == {
+            "self": Asset(
+                name="example_asset_func", uri="s3://bucket/object", 
group="MLModel", extra={"k": "v"}
+            ),
+            "context": {"k": "v"},
+            "inlet_asset_1": Asset(name="inlet_asset_1", 
uri="s3://bucket/object1"),
+            "inlet_asset_2": Asset(name="inlet_asset_2"),
+        }
diff --git a/tests/timetables/test_assets_timetable.py 
b/tests/timetables/test_assets_timetable.py
index f461afa31c8..bb942a4a01d 100644
--- a/tests/timetables/test_assets_timetable.py
+++ b/tests/timetables/test_assets_timetable.py
@@ -134,7 +134,7 @@ def test_serialization(asset_timetable: 
AssetOrTimeSchedule, monkeypatch: Any) -
         "timetable": "mock_serialized_timetable",
         "asset_condition": {
             "__type": "asset_all",
-            "objects": [{"__type": "asset", "uri": "test_asset", "extra": {}}],
+            "objects": [{"__type": "asset", "uri": "test_asset", "name": 
"test_asset", "extra": {}}],
         },
     }
 
@@ -152,7 +152,7 @@ def test_deserialization(monkeypatch: Any) -> None:
         "timetable": "mock_serialized_timetable",
         "asset_condition": {
             "__type": "asset_all",
-            "objects": [{"__type": "asset", "uri": "test_asset", "extra": 
None}],
+            "objects": [{"__type": "asset", "name": "test_asset", "uri": 
"test_asset", "extra": None}],
         },
     }
     deserialized = AssetOrTimeSchedule.deserialize(mock_serialized_data)

Reply via email to