This is an automated email from the ASF dual-hosted git repository.
weilee pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new b3362f841f3 Add "@asset" to decorate a function as a DAG and an asset
(#41325)
b3362f841f3 is described below
commit b3362f841f31b5629f09f619cfa2d69e7199493a
Author: Tzu-ping Chung <[email protected]>
AuthorDate: Thu Nov 14 17:14:45 2024 +0800
Add "@asset" to decorate a function as a DAG and an asset (#41325)
* Implement asset definition creating a DAG
* Basic inlet dependency
* Make AssetDefinition subclass Asset
This seems to be the best way for 'schedule' dependencies to work. Still
not entirely sure; we'll revisit this.
* style: fix mypy error
* feat(asset): allow uri to be None
* fix: temporarily serialize AssetDefinition into a string
* feat(decorators/assets): rewrite how asset definition is serialized
* test(decorators/assets): add test cases to check whether asset decorator
generates the right asset definition
* test(decorators/assets): add test cases to AssetDefinition
* test(decorators/asset): add test cases to Test_AssetMainOperator
* test(decorators/assets): remove unused fixtures
* docs(example_dag): add example dag for asset_decorator
* feat(decorators/assets): allow passing self and context into asset
* feat(decorators/assets): return actual asset in asset decorator
* refactor(decorators/assets): extract active assets fetching logic as
_fetch_active_assets_by_name
* feat(decorators/assets): allow fetching inlet events through AssetRef
* feat(decorators/assets): reorder import paths
* docs: update asset decorator example dag
* test: fix tests
* test(decorators/assets): extend test_determine_kwargs to cover active
asset
* fix: address easy to fix comments
* fix: fix asset serialization
* refactor(decorators/assets): postpone the attribute check to
AssetDefinition instead of asset decorator
* Simplify group validators
The validate_identifier validator already checks the length, so we don't
need an extra one doing that.
* style(dag): remove _wrapped_definition
* style(decorators/assets): change types.FunctionType to Callable
* refactor(decorators/assets): make session in _fetch_active_assets_by_name
required
* fix(decorators/assets): remove DAG.bulk_write_to_db and remove self
handling
* feat(utils/context): fetch asset_refs all at once
---------
Co-authored-by: Wei Lee <[email protected]>
---
airflow/assets/__init__.py | 34 ++--
airflow/decorators/assets.py | 131 +++++++++++++++
airflow/example_dags/example_asset_decorator.py | 52 ++++++
airflow/models/asset.py | 22 +++
airflow/models/dag.py | 4 +-
airflow/serialization/enums.py | 1 +
airflow/serialization/schema.json | 17 +-
airflow/serialization/serialized_objects.py | 9 +-
airflow/utils/context.py | 21 ++-
airflow/utils/file.py | 4 +-
.../api_fastapi/core_api/routes/ui/test_assets.py | 2 +-
tests/decorators/test_assets.py | 177 +++++++++++++++++++++
tests/timetables/test_assets_timetable.py | 4 +-
13 files changed, 451 insertions(+), 27 deletions(-)
diff --git a/airflow/assets/__init__.py b/airflow/assets/__init__.py
index 59e0b866844..f1d36ac12b7 100644
--- a/airflow/assets/__init__.py
+++ b/airflow/assets/__init__.py
@@ -23,7 +23,7 @@ import urllib.parse
import warnings
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Iterable, Iterator,
cast, overload
-import attr
+import attrs
from sqlalchemy import select
from airflow.api_internal.internal_api_call import internal_api_call
@@ -123,6 +123,13 @@ def _validate_non_empty_identifier(instance, attribute,
value):
return value
+def _validate_asset_name(instance, attribute, value):
+ _validate_non_empty_identifier(instance, attribute, value)
+ if value == "self" or value == "context":
+ raise ValueError(f"prohibited name for asset: {value}")
+ return value
+
+
def extract_event_key(value: str | Asset | AssetAlias) -> str:
"""
Extract the key of an inlet or an outlet event.
@@ -158,6 +165,13 @@ def expand_alias_to_assets(alias: str | AssetAlias, *,
session: Session = NEW_SE
return []
[email protected](kw_only=True)
+class AssetRef:
+ """Reference to an asset."""
+
+ name: str
+
+
class BaseAsset:
"""
Protocol for all asset triggers to use in ``DAG(schedule=...)``.
@@ -207,16 +221,12 @@ class BaseAsset:
raise NotImplementedError
[email protected](unsafe_hash=False)
[email protected](unsafe_hash=False)
class AssetAlias(BaseAsset):
"""A represeation of asset alias which is used to create asset during the
runtime."""
- name: str = attr.field(validator=_validate_non_empty_identifier)
- group: str = attr.field(
- kw_only=True,
- default="",
- validator=[attr.validators.max_len(1500), _validate_identifier],
- )
+ name: str = attrs.field(validator=_validate_non_empty_identifier)
+ group: str = attrs.field(kw_only=True, default="",
validator=_validate_identifier)
def iter_assets(self) -> Iterator[tuple[str, Asset]]:
return iter(())
@@ -258,7 +268,7 @@ def _set_extra_default(extra: dict | None) -> dict:
return extra
[email protected](init=False, unsafe_hash=False)
[email protected](init=False, unsafe_hash=False)
class Asset(os.PathLike, BaseAsset):
"""A representation of data asset dependencies between workflows."""
@@ -267,7 +277,7 @@ class Asset(os.PathLike, BaseAsset):
group: str
extra: dict[str, Any]
- asset_type: ClassVar[str] = ""
+ asset_type: ClassVar[str] = "asset"
__version__: ClassVar[int] = 1
@overload
@@ -296,8 +306,8 @@ class Asset(os.PathLike, BaseAsset):
name = uri
elif uri is None:
uri = name
- fields = attr.fields_dict(Asset)
- self.name = _validate_non_empty_identifier(self, fields["name"], name)
+ fields = attrs.fields_dict(Asset)
+ self.name = _validate_asset_name(self, fields["name"], name)
self.uri = _sanitize_uri(_validate_non_empty_identifier(self,
fields["uri"], uri))
self.group = _validate_identifier(self, fields["group"], group) if
group else self.asset_type
self.extra = _set_extra_default(extra)
diff --git a/airflow/decorators/assets.py b/airflow/decorators/assets.py
new file mode 100644
index 00000000000..2f5052c2d5c
--- /dev/null
+++ b/airflow/decorators/assets.py
@@ -0,0 +1,131 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+import inspect
+from typing import TYPE_CHECKING, Any, Callable, Iterator, Mapping
+
+import attrs
+
+from airflow.assets import Asset, AssetRef
+from airflow.models.asset import _fetch_active_assets_by_name
+from airflow.models.dag import DAG, ScheduleArg
+from airflow.providers.standard.operators.python import PythonOperator
+from airflow.utils.session import create_session
+
+if TYPE_CHECKING:
+ from airflow.io.path import ObjectStoragePath
+
+
+class _AssetMainOperator(PythonOperator):
+ def __init__(self, *, definition_name: str, uri: str | None = None,
**kwargs) -> None:
+ super().__init__(**kwargs)
+ self._definition_name = definition_name
+ self._uri = uri
+
+ def _iter_kwargs(
+ self, context: Mapping[str, Any], active_assets: dict[str, Asset]
+ ) -> Iterator[tuple[str, Any]]:
+ value: Any
+ for key in inspect.signature(self.python_callable).parameters:
+ if key == "self":
+ value = active_assets.get(self._definition_name)
+ elif key == "context":
+ value = context
+ else:
+ value = active_assets.get(key, Asset(name=key))
+ yield key, value
+
+ def determine_kwargs(self, context: Mapping[str, Any]) -> Mapping[str,
Any]:
+ active_assets: dict[str, Asset] = {}
+ asset_names = [asset_ref.name for asset_ref in self.inlets if
isinstance(asset_ref, AssetRef)]
+ if "self" in inspect.signature(self.python_callable).parameters:
+ asset_names.append(self._definition_name)
+
+ if asset_names:
+ with create_session() as session:
+ active_assets = _fetch_active_assets_by_name(asset_names,
session)
+ return dict(self._iter_kwargs(context, active_assets))
+
+
[email protected](kw_only=True)
+class AssetDefinition(Asset):
+ """
+ Asset representation from decorating a function with ``@asset``.
+
+ :meta private:
+ """
+
+ function: Callable
+ schedule: ScheduleArg
+
+ def __attrs_post_init__(self) -> None:
+ parameters = inspect.signature(self.function).parameters
+
+ with DAG(dag_id=self.name, schedule=self.schedule, auto_register=True):
+ _AssetMainOperator(
+ task_id="__main__",
+ inlets=[
+ AssetRef(name=inlet_asset_name)
+ for inlet_asset_name in parameters
+ if inlet_asset_name not in ("self", "context")
+ ],
+ outlets=[self.to_asset()],
+ python_callable=self.function,
+ definition_name=self.name,
+ uri=self.uri,
+ )
+
+ def to_asset(self) -> Asset:
+ return Asset(
+ name=self.name,
+ uri=self.uri,
+ group=self.group,
+ extra=self.extra,
+ )
+
+ def serialize(self):
+ return {
+ "uri": self.uri,
+ "name": self.name,
+ "group": self.group,
+ "extra": self.extra,
+ }
+
+
[email protected](kw_only=True)
+class asset:
+ """Create an asset by decorating a materialization function."""
+
+ schedule: ScheduleArg
+ uri: str | ObjectStoragePath | None = None
+ group: str = ""
+ extra: dict[str, Any] = attrs.field(factory=dict)
+
+ def __call__(self, f: Callable) -> AssetDefinition:
+ if (name := f.__name__) != f.__qualname__:
+ raise ValueError("nested function not supported")
+
+ return AssetDefinition(
+ name=name,
+ uri=name if self.uri is None else str(self.uri),
+ group=self.group,
+ extra=self.extra,
+ function=f,
+ schedule=self.schedule,
+ )
diff --git a/airflow/example_dags/example_asset_decorator.py
b/airflow/example_dags/example_asset_decorator.py
new file mode 100644
index 00000000000..b4de09c2314
--- /dev/null
+++ b/airflow/example_dags/example_asset_decorator.py
@@ -0,0 +1,52 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import pendulum
+
+from airflow.assets import Asset
+from airflow.decorators import dag, task
+from airflow.decorators.assets import asset
+
+
+@asset(uri="s3://bucket/asset1_producer", schedule=None)
+def asset1_producer():
+ pass
+
+
+@asset(uri="s3://bucket/object", schedule=None)
+def asset2_producer(self, context, asset1_producer):
+ print(self)
+ print(context["inlet_events"][asset1_producer])
+
+
+@dag(
+ schedule=Asset(uri="s3://bucket/asset1_producer", name="asset1_producer")
+ | Asset(uri="s3://bucket/object", name="asset2_producer"),
+ start_date=pendulum.datetime(2021, 1, 1, tz="UTC"),
+ catchup=False,
+ tags=["consumes", "asset-scheduled"],
+)
+def consumes_asset_decorator():
+ @task(outlets=[Asset(name="process_nothing")])
+ def process_nothing():
+ pass
+
+ process_nothing()
+
+
+consumes_asset_decorator()
diff --git a/airflow/models/asset.py b/airflow/models/asset.py
index 8ade71bd0b1..50914d51650 100644
--- a/airflow/models/asset.py
+++ b/airflow/models/asset.py
@@ -17,6 +17,7 @@
# under the License.
from __future__ import annotations
+from typing import TYPE_CHECKING
from urllib.parse import urlsplit
import sqlalchemy_jsonfield
@@ -29,6 +30,7 @@ from sqlalchemy import (
PrimaryKeyConstraint,
String,
Table,
+ select,
text,
)
from sqlalchemy.orm import relationship
@@ -39,6 +41,26 @@ from airflow.settings import json
from airflow.utils import timezone
from airflow.utils.sqlalchemy import UtcDateTime
+if TYPE_CHECKING:
+ from typing import Sequence
+
+ from sqlalchemy.orm import Session
+
+
+def _fetch_active_assets_by_name(
+ names: Sequence[str],
+ session: Session,
+) -> dict[str, Asset]:
+ return {
+ asset_model[0].name: asset_model[0].to_public()
+ for asset_model in session.execute(
+ select(AssetModel)
+ .join(AssetActive, AssetActive.name == AssetModel.name)
+ .where(AssetActive.name.in_(name for name in names))
+ )
+ }
+
+
alias_association_table = Table(
"asset_alias_asset",
Base.metadata,
diff --git a/airflow/models/dag.py b/airflow/models/dag.py
index e6a67c6ad7e..e48ec0a9a9c 100644
--- a/airflow/models/dag.py
+++ b/airflow/models/dag.py
@@ -777,9 +777,7 @@ class DAG(TaskSDKDag, LoggingMixin):
@classmethod
def get_serialized_fields(cls):
"""Stringified DAGs and operators contain exactly these fields."""
- return TaskSDKDag.get_serialized_fields() | {
- "_processor_dags_folder",
- }
+ return TaskSDKDag.get_serialized_fields() | {"_processor_dags_folder"}
@staticmethod
@internal_api_call
diff --git a/airflow/serialization/enums.py b/airflow/serialization/enums.py
index dd63366b8a9..d1b946b38ef 100644
--- a/airflow/serialization/enums.py
+++ b/airflow/serialization/enums.py
@@ -59,6 +59,7 @@ class DagAttributeTypes(str, Enum):
ASSET_ALIAS = "asset_alias"
ASSET_ANY = "asset_any"
ASSET_ALL = "asset_all"
+ ASSET_REF = "asset_ref"
SIMPLE_TASK_INSTANCE = "simple_task_instance"
BASE_JOB = "Job"
TASK_INSTANCE = "task_instance"
diff --git a/airflow/serialization/schema.json
b/airflow/serialization/schema.json
index b26b5933981..1e7232aa81a 100644
--- a/airflow/serialization/schema.json
+++ b/airflow/serialization/schema.json
@@ -53,6 +53,21 @@
{ "type": "integer" }
]
},
+ "asset_definition": {
+ "type": "object",
+ "properties": {
+ "uri": { "type": "string" },
+ "name": { "type": "string" },
+ "group": { "type": "string" },
+ "extra": {
+ "anyOf": [
+ {"type": "null"},
+ { "$ref": "#/definitions/dict" }
+ ]
+ }
+ },
+ "required": [ "uri", "extra" ]
+ },
"asset": {
"type": "object",
"properties": {
@@ -153,7 +168,7 @@
"_processor_dags_folder": {
"anyOf": [
{ "type": "null" },
- {"type": "string"}
+ { "type": "string" }
]
},
"dag_display_name": { "type" : "string"},
diff --git a/airflow/serialization/serialized_objects.py
b/airflow/serialization/serialized_objects.py
index 52b0bcb1530..4b7ee6d0871 100644
--- a/airflow/serialization/serialized_objects.py
+++ b/airflow/serialization/serialized_objects.py
@@ -40,6 +40,7 @@ from airflow.assets import (
AssetAlias,
AssetAll,
AssetAny,
+ AssetRef,
BaseAsset,
_AssetAliasCondition,
)
@@ -254,7 +255,7 @@ def encode_asset_condition(var: BaseAsset) -> dict[str,
Any]:
:meta private:
"""
if isinstance(var, Asset):
- return {"__type": DAT.ASSET, "uri": var.uri, "extra": var.extra}
+ return {"__type": DAT.ASSET, "name": var.name, "uri": var.uri,
"extra": var.extra}
if isinstance(var, AssetAlias):
return {"__type": DAT.ASSET_ALIAS, "name": var.name}
if isinstance(var, AssetAll):
@@ -272,7 +273,7 @@ def decode_asset_condition(var: dict[str, Any]) ->
BaseAsset:
"""
dat = var["__type"]
if dat == DAT.ASSET:
- return Asset(var["uri"], extra=var["extra"])
+ return Asset(uri=var["uri"], name=var["name"], extra=var["extra"])
if dat == DAT.ASSET_ALL:
return AssetAll(*(decode_asset_condition(x) for x in var["objects"]))
if dat == DAT.ASSET_ANY:
@@ -743,6 +744,8 @@ class BaseSerialization:
elif isinstance(var, BaseAsset):
serialized_asset = encode_asset_condition(var)
return cls._encode(serialized_asset,
type_=serialized_asset.pop("__type"))
+ elif isinstance(var, AssetRef):
+ return cls._encode({"name": var.name}, type_=DAT.ASSET_REF)
elif isinstance(var, SimpleTaskInstance):
return cls._encode(
cls.serialize(var.__dict__, strict=strict,
use_pydantic_models=use_pydantic_models),
@@ -876,6 +879,8 @@ class BaseSerialization:
return AssetAny(*(decode_asset_condition(x) for x in
var["objects"]))
elif type_ == DAT.ASSET_ALL:
return AssetAll(*(decode_asset_condition(x) for x in
var["objects"]))
+ elif type_ == DAT.ASSET_REF:
+ return AssetRef(name=var["name"])
elif type_ == DAT.SIMPLE_TASK_INSTANCE:
return SimpleTaskInstance(**cls.deserialize(var))
elif type_ == DAT.CONNECTION:
diff --git a/airflow/utils/context.py b/airflow/utils/context.py
index 9af1486f914..b28559999e5 100644
--- a/airflow/utils/context.py
+++ b/airflow/utils/context.py
@@ -44,10 +44,11 @@ from airflow.assets import (
Asset,
AssetAlias,
AssetAliasEvent,
+ AssetRef,
extract_event_key,
)
from airflow.exceptions import RemovedInAirflow3Warning
-from airflow.models.asset import AssetAliasModel, AssetEvent, AssetModel
+from airflow.models.asset import AssetAliasModel, AssetEvent, AssetModel,
_fetch_active_assets_by_name
from airflow.utils.db import LazySelectSequence
from airflow.utils.types import NOTSET
@@ -257,11 +258,18 @@ class InletEventsAccessors(Mapping[str,
LazyAssetEventSelectSequence]):
self._assets = {}
self._asset_aliases = {}
+ _asset_ref_names: list[str] = []
for inlet in inlets:
if isinstance(inlet, Asset):
- self._assets[inlet.uri] = inlet
+ self._assets[inlet.name] = inlet
elif isinstance(inlet, AssetAlias):
self._asset_aliases[inlet.name] = inlet
+ elif isinstance(inlet, AssetRef):
+ _asset_ref_names.append(inlet.name)
+
+ if _asset_ref_names:
+ for asset_name, asset in
_fetch_active_assets_by_name(_asset_ref_names, self._session).items():
+ self._assets[asset_name] = asset
def __iter__(self) -> Iterator[str]:
return iter(self._inlets)
@@ -272,7 +280,7 @@ class InletEventsAccessors(Mapping[str,
LazyAssetEventSelectSequence]):
def __getitem__(self, key: int | str | Asset | AssetAlias) ->
LazyAssetEventSelectSequence:
if isinstance(key, int): # Support index access; it's easier for
trivial cases.
obj = self._inlets[key]
- if not isinstance(obj, (Asset, AssetAlias)):
+ if not isinstance(obj, (Asset, AssetAlias, AssetRef)):
raise IndexError(key)
else:
obj = key
@@ -281,10 +289,13 @@ class InletEventsAccessors(Mapping[str,
LazyAssetEventSelectSequence]):
asset_alias = self._asset_aliases[obj.name]
join_clause = AssetEvent.source_aliases
where_clause = AssetAliasModel.name == asset_alias.name
- elif isinstance(obj, (Asset, str)):
+ elif isinstance(obj, (Asset, AssetRef)):
+ join_clause = AssetEvent.asset
+ where_clause = AssetModel.name == self._assets[obj.name].name
+ elif isinstance(obj, str):
asset = self._assets[extract_event_key(obj)]
join_clause = AssetEvent.asset
- where_clause = AssetModel.uri == asset.uri
+ where_clause = AssetModel.name == asset.name
else:
raise ValueError(key)
diff --git a/airflow/utils/file.py b/airflow/utils/file.py
index 86b7a7891ca..962f97c8fcf 100644
--- a/airflow/utils/file.py
+++ b/airflow/utils/file.py
@@ -328,7 +328,9 @@ def might_contain_dag_via_default_heuristic(file_path: str,
zip_file: zipfile.Zi
with open(file_path, "rb") as dag_file:
content = dag_file.read()
content = content.lower()
- return all(s in content for s in (b"dag", b"airflow"))
+ if b"airflow" not in content:
+ return False
+ return any(s in content for s in (b"dag", b"asset"))
def _find_imported_modules(module: ast.Module) -> Generator[str, None, None]:
diff --git a/tests/api_fastapi/core_api/routes/ui/test_assets.py
b/tests/api_fastapi/core_api/routes/ui/test_assets.py
index b71d80ae9d3..b5c85b98ba6 100644
--- a/tests/api_fastapi/core_api/routes/ui/test_assets.py
+++ b/tests/api_fastapi/core_api/routes/ui/test_assets.py
@@ -47,5 +47,5 @@ def test_next_run_assets(test_client, dag_maker):
assert response.status_code == 200
assert response.json() == {
"asset_expression": {"all": ["s3://bucket/key/1"]},
- "events": [{"id": 17, "uri": "s3://bucket/key/1", "lastUpdate": None}],
+ "events": [{"id": 20, "uri": "s3://bucket/key/1", "lastUpdate": None}],
}
diff --git a/tests/decorators/test_assets.py b/tests/decorators/test_assets.py
new file mode 100644
index 00000000000..a3821140e54
--- /dev/null
+++ b/tests/decorators/test_assets.py
@@ -0,0 +1,177 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from unittest import mock
+from unittest.mock import ANY
+
+import pytest
+
+from airflow.assets import Asset
+from airflow.decorators.assets import AssetRef, _AssetMainOperator, asset
+from airflow.models.asset import AssetActive, AssetModel
+
+pytestmark = pytest.mark.db_test
+
+
[email protected]
+def example_asset_func(request):
+ name = "example_asset_func"
+ if getattr(request, "param", None) is not None:
+ name = request.param
+
+ def _example_asset_func():
+ return "This is example_asset"
+
+ _example_asset_func.__name__ = name
+ _example_asset_func.__qualname__ = name
+ return _example_asset_func
+
+
[email protected]
+def example_asset_definition(example_asset_func):
+ return asset(schedule=None, uri="s3://bucket/object", group="MLModel",
extra={"k": "v"})(
+ example_asset_func
+ )
+
+
[email protected]
+def example_asset_func_with_valid_arg_as_inlet_asset():
+ def _example_asset_func(self, context, inlet_asset_1, inlet_asset_2):
+ return "This is example_asset"
+
+ _example_asset_func.__name__ = "example_asset_func"
+ _example_asset_func.__qualname__ = "example_asset_func"
+ return _example_asset_func
+
+
+class TestAssetDecorator:
+ def test_without_uri(self, example_asset_func):
+ asset_definition = asset(schedule=None)(example_asset_func)
+
+ assert asset_definition.name == "example_asset_func"
+ assert asset_definition.uri == "example_asset_func"
+ assert asset_definition.group == ""
+ assert asset_definition.extra == {}
+ assert asset_definition.function == example_asset_func
+ assert asset_definition.schedule is None
+
+ def test_with_uri(self, example_asset_func):
+ asset_definition = asset(schedule=None,
uri="s3://bucket/object")(example_asset_func)
+
+ assert asset_definition.name == "example_asset_func"
+ assert asset_definition.uri == "s3://bucket/object"
+ assert asset_definition.group == ""
+ assert asset_definition.extra == {}
+ assert asset_definition.function == example_asset_func
+ assert asset_definition.schedule is None
+
+ def test_with_group_and_extra(self, example_asset_func):
+ asset_definition = asset(schedule=None, uri="s3://bucket/object",
group="MLModel", extra={"k": "v"})(
+ example_asset_func
+ )
+ assert asset_definition.name == "example_asset_func"
+ assert asset_definition.uri == "s3://bucket/object"
+ assert asset_definition.group == "MLModel"
+ assert asset_definition.extra == {"k": "v"}
+ assert asset_definition.function == example_asset_func
+ assert asset_definition.schedule is None
+
+ def test_nested_function(self):
+ def root_func():
+ @asset(schedule=None)
+ def asset_func():
+ pass
+
+ with pytest.raises(ValueError) as err:
+ root_func()
+
+ assert err.value.args[0] == "nested function not supported"
+
+ @pytest.mark.parametrize("example_asset_func", ("self", "context"),
indirect=True)
+ def test_with_invalid_asset_name(self, example_asset_func):
+ with pytest.raises(ValueError) as err:
+ asset(schedule=None)(example_asset_func)
+
+ assert err.value.args[0].startswith("prohibited name for asset: ")
+
+
+class TestAssetDefinition:
+ def test_serialzie(self, example_asset_definition):
+ assert example_asset_definition.serialize() == {
+ "extra": {"k": "v"},
+ "group": "MLModel",
+ "name": "example_asset_func",
+ "uri": "s3://bucket/object",
+ }
+
+ @mock.patch("airflow.decorators.assets._AssetMainOperator")
+ @mock.patch("airflow.decorators.assets.DAG")
+ def test__attrs_post_init__(
+ self, DAG, _AssetMainOperator,
example_asset_func_with_valid_arg_as_inlet_asset
+ ):
+ asset_definition = asset(schedule=None, uri="s3://bucket/object",
group="MLModel", extra={"k": "v"})(
+ example_asset_func_with_valid_arg_as_inlet_asset
+ )
+
+ DAG.assert_called_once_with(dag_id="example_asset_func",
schedule=None, auto_register=True)
+ _AssetMainOperator.assert_called_once_with(
+ task_id="__main__",
+ inlets=[
+ AssetRef(name="inlet_asset_1"),
+ AssetRef(name="inlet_asset_2"),
+ ],
+ outlets=[asset_definition.to_asset()],
+ python_callable=ANY,
+ definition_name="example_asset_func",
+ uri="s3://bucket/object",
+ )
+
+ python_callable =
_AssetMainOperator.call_args.kwargs["python_callable"]
+ assert python_callable ==
example_asset_func_with_valid_arg_as_inlet_asset
+
+
+class Test_AssetMainOperator:
+ def test_determine_kwargs(self,
example_asset_func_with_valid_arg_as_inlet_asset, session):
+ example_asset_model = AssetModel(uri="s3://bucket/object1",
name="inlet_asset_1")
+ asset_definition = asset(schedule=None, uri="s3://bucket/object",
group="MLModel", extra={"k": "v"})(
+ example_asset_func_with_valid_arg_as_inlet_asset
+ )
+
+ ad_asset_model = AssetModel.from_public(asset_definition)
+
+ session.add(example_asset_model)
+ session.add(ad_asset_model)
+ session.add(AssetActive.for_asset(example_asset_model))
+ session.add(AssetActive.for_asset(ad_asset_model))
+ session.commit()
+
+ op = _AssetMainOperator(
+ task_id="__main__",
+ inlets=[AssetRef(name="inlet_asset_1"),
AssetRef(name="inlet_asset_2")],
+ outlets=[asset_definition],
+ python_callable=example_asset_func_with_valid_arg_as_inlet_asset,
+ definition_name="example_asset_func",
+ )
+ assert op.determine_kwargs(context={"k": "v"}) == {
+ "self": Asset(
+ name="example_asset_func", uri="s3://bucket/object",
group="MLModel", extra={"k": "v"}
+ ),
+ "context": {"k": "v"},
+ "inlet_asset_1": Asset(name="inlet_asset_1",
uri="s3://bucket/object1"),
+ "inlet_asset_2": Asset(name="inlet_asset_2"),
+ }
diff --git a/tests/timetables/test_assets_timetable.py
b/tests/timetables/test_assets_timetable.py
index f461afa31c8..bb942a4a01d 100644
--- a/tests/timetables/test_assets_timetable.py
+++ b/tests/timetables/test_assets_timetable.py
@@ -134,7 +134,7 @@ def test_serialization(asset_timetable:
AssetOrTimeSchedule, monkeypatch: Any) -
"timetable": "mock_serialized_timetable",
"asset_condition": {
"__type": "asset_all",
- "objects": [{"__type": "asset", "uri": "test_asset", "extra": {}}],
+ "objects": [{"__type": "asset", "uri": "test_asset", "name":
"test_asset", "extra": {}}],
},
}
@@ -152,7 +152,7 @@ def test_deserialization(monkeypatch: Any) -> None:
"timetable": "mock_serialized_timetable",
"asset_condition": {
"__type": "asset_all",
- "objects": [{"__type": "asset", "uri": "test_asset", "extra":
None}],
+ "objects": [{"__type": "asset", "name": "test_asset", "uri":
"test_asset", "extra": None}],
},
}
deserialized = AssetOrTimeSchedule.deserialize(mock_serialized_data)