This is an automated email from the ASF dual-hosted git repository.
uranusjr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 1af5b93157a Add 'name' and 'group' to public Asset class (#42812)
1af5b93157a is described below
commit 1af5b93157ab8310e8b61c7ef7f923c03a69aa2e
Author: Tzu-ping Chung <[email protected]>
AuthorDate: Wed Oct 16 10:23:29 2024 +0800
Add 'name' and 'group' to public Asset class (#42812)
---
airflow/assets/__init__.py | 73 ++++++++++++++++++++------
airflow/models/asset.py | 20 ++++---
tests/assets/{tests_asset.py => test_asset.py} | 41 ++++++++++++++-
tests/models/test_dag.py | 6 +--
tests/serialization/test_serialized_objects.py | 2 +-
5 files changed, 115 insertions(+), 27 deletions(-)
diff --git a/airflow/assets/__init__.py b/airflow/assets/__init__.py
index e11b9c49df3..15805418472 100644
--- a/airflow/assets/__init__.py
+++ b/airflow/assets/__init__.py
@@ -21,7 +21,7 @@ import logging
import os
import urllib.parse
import warnings
-from typing import TYPE_CHECKING, Any, Callable, ClassVar, Iterable, Iterator, cast
+from typing import TYPE_CHECKING, Any, Callable, ClassVar, Iterable, Iterator, cast, overload
import attr
from sqlalchemy import select
@@ -74,12 +74,6 @@ def _sanitize_uri(uri: str) -> str:
This checks for URI validity, and normalizes the URI if needed. A fully
normalized URI is returned.
"""
- if not uri:
- raise ValueError("Asset URI cannot be empty")
- if uri.isspace():
- raise ValueError("Asset URI cannot be just whitespace")
- if not uri.isascii():
- raise ValueError("Asset URI must only consist of ASCII characters")
parsed = urllib.parse.urlsplit(uri)
if not parsed.scheme and not parsed.netloc: # Does not look like a URI.
return uri
@@ -126,6 +120,24 @@ def _sanitize_uri(uri: str) -> str:
return urllib.parse.urlunsplit(parsed)
+def _validate_identifier(instance, attribute, value):
+ if not isinstance(value, str):
+ raise ValueError(f"{type(instance).__name__} {attribute.name} must be a string")
+ if len(value) > 1500:
+ raise ValueError(f"{type(instance).__name__} {attribute.name} cannot exceed 1500 characters")
+ if value.isspace():
+ raise ValueError(f"{type(instance).__name__} {attribute.name} cannot be just whitespace")
+ if not value.isascii():
+ raise ValueError(f"{type(instance).__name__} {attribute.name} must only consist of ASCII characters")
+ return value
+
+
+def _validate_non_empty_identifier(instance, attribute, value):
+ if not _validate_identifier(instance, attribute, value):
+ raise ValueError(f"{type(instance).__name__} {attribute.name} cannot be empty")
+ return value
+
+
def extract_event_key(value: str | Asset | AssetAlias) -> str:
"""
Extract the key of an inlet or an outlet event.
@@ -157,7 +169,7 @@ def expand_alias_to_assets(alias: str | AssetAlias, *, session: Session = NEW_SE
select(AssetAliasModel).where(AssetAliasModel.name == alias_name).limit(1)
)
if asset_alias_obj:
- return [Asset(uri=asset.uri, extra=asset.extra) for asset in asset_alias_obj.datasets]
+ return [asset.to_public() for asset in asset_alias_obj.datasets]
return []
@@ -214,7 +226,7 @@ class BaseAsset:
class AssetAlias(BaseAsset):
"""A represeation of asset alias which is used to create asset during the runtime."""
- name: str
+ name: str = attr.field(validator=_validate_non_empty_identifier)
def iter_assets(self) -> Iterator[tuple[str, Asset]]:
return iter(())
@@ -256,18 +268,49 @@ def _set_extra_default(extra: dict | None) -> dict:
return extra
[email protected](unsafe_hash=False)
[email protected](init=False, unsafe_hash=False)
class Asset(os.PathLike, BaseAsset):
"""A representation of data dependencies between workflows."""
- uri: str = attr.field(
- converter=_sanitize_uri,
- validator=[attr.validators.min_len(1), attr.validators.max_len(1500)],
- )
- extra: dict[str, Any] = attr.field(factory=dict, converter=_set_extra_default)
+ name: str = attr.field()
+ uri: str = attr.field()
+ group: str = attr.field()
+ extra: dict[str, Any] = attr.field()
__version__: ClassVar[int] = 1
+ @overload
+ def __init__(self, name: str, uri: str, *, group: str = "", extra: dict | None = None) -> None:
+ """Canonical; both name and uri are provided."""
+
+ @overload
+ def __init__(self, name: str, *, group: str = "", extra: dict | None = None) -> None:
+ """It's possible to only provide the name, either by keyword or as the only positional argument."""
+
+ @overload
+ def __init__(self, *, uri: str, group: str = "", extra: dict | None = None) -> None:
+ """It's possible to only provide the URI as a keyword argument."""
+
+ def __init__(
+ self,
+ name: str | None = None,
+ uri: str | None = None,
+ *,
+ group: str = "",
+ extra: dict | None = None,
+ ) -> None:
+ if name is None and uri is None:
+ raise TypeError("Asset() requires either 'name' or 'uri'")
+ elif name is None:
+ name = uri
+ elif uri is None:
+ uri = name
+ fields = attr.fields_dict(Asset)
+ self.name = _validate_non_empty_identifier(self, fields["name"], name)
+ self.uri = _sanitize_uri(_validate_non_empty_identifier(self, fields["uri"], uri))
+ self.group = _validate_identifier(self, fields["group"], group)
+ self.extra = _set_extra_default(extra)
+
def __fspath__(self) -> str:
return self.uri
diff --git a/airflow/models/asset.py b/airflow/models/asset.py
index d5ca0ea513f..b565c9a100e 100644
--- a/airflow/models/asset.py
+++ b/airflow/models/asset.py
@@ -205,17 +205,23 @@ class AssetModel(Base):
@classmethod
def from_public(cls, obj: Asset) -> AssetModel:
- return cls(uri=obj.uri, extra=obj.extra)
-
- def __init__(self, uri: str, **kwargs):
+ return cls(name=obj.name, uri=obj.uri, group=obj.group, extra=obj.extra)
+
+ def __init__(self, name: str = "", uri: str = "", **kwargs):
+ if not name and not uri:
+ raise TypeError("must provide either 'name' or 'uri'")
+ elif not name:
+ name = uri
+ elif not uri:
+ uri = name
try:
uri.encode("ascii")
except UnicodeEncodeError:
- raise ValueError("URI must be ascii")
+ raise ValueError("URI must be ascii") from None
parsed = urlsplit(uri)
if parsed.scheme and parsed.scheme.lower() == "airflow":
- raise ValueError("Scheme `airflow` is reserved.")
- super().__init__(name=uri, uri=uri, **kwargs)
+ raise ValueError("Scheme 'airflow' is reserved.")
+ super().__init__(name=name, uri=uri, **kwargs)
def __eq__(self, other):
if isinstance(other, (self.__class__, Asset)):
@@ -229,7 +235,7 @@ class AssetModel(Base):
return f"{self.__class__.__name__}(uri={self.uri!r}, extra={self.extra!r})"
def to_public(self) -> Asset:
- return Asset(uri=self.uri, extra=self.extra)
+ return Asset(name=self.name, uri=self.uri, group=self.group, extra=self.extra)
class AssetActive(Base):
diff --git a/tests/assets/tests_asset.py b/tests/assets/test_asset.py
similarity index 95%
rename from tests/assets/tests_asset.py
rename to tests/assets/test_asset.py
index 0bcfb83e88a..4d3466b90c1 100644
--- a/tests/assets/tests_asset.py
+++ b/tests/assets/test_asset.py
@@ -50,12 +50,26 @@ def clear_assets():
clear_db_assets()
[email protected](
+ ["name"],
+ [
+ pytest.param("", id="empty"),
+ pytest.param("\n\t", id="whitespace"),
+ pytest.param("a" * 1501, id="too_long"),
+ pytest.param("😊", id="non-ascii"),
+ ],
+)
+def test_invalid_names(name):
+ with pytest.raises(ValueError):
+ Asset(name=name)
+
+
@pytest.mark.parametrize(
["uri"],
[
pytest.param("", id="empty"),
pytest.param("\n\t", id="whitespace"),
- pytest.param("a" * 3001, id="too_long"),
+ pytest.param("a" * 1501, id="too_long"),
pytest.param("airflow://xcom/dag/task", id="reserved_scheme"),
pytest.param("😊", id="non-ascii"),
],
@@ -65,6 +79,31 @@ def test_invalid_uris(uri):
Asset(uri=uri)
+def test_only_name():
+ asset = Asset(name="foobar")
+ assert asset.name == "foobar"
+ assert asset.uri == "foobar"
+
+
+def test_only_uri():
+ asset = Asset(uri="s3://bucket/key/path")
+ assert asset.name == "s3://bucket/key/path"
+ assert asset.uri == "s3://bucket/key/path"
+
+
[email protected]("arg", ["foobar", "s3://bucket/key/path"])
+def test_only_posarg(arg):
+ asset = Asset(arg)
+ assert asset.name == arg
+ assert asset.uri == arg
+
+
+def test_both_name_and_uri():
+ asset = Asset("foobar", "s3://bucket/key/path")
+ assert asset.name == "foobar"
+ assert asset.uri == "s3://bucket/key/path"
+
+
@pytest.mark.parametrize(
"uri, normalized",
[
diff --git a/tests/models/test_dag.py b/tests/models/test_dag.py
index b439487b016..e6f8042253e 100644
--- a/tests/models/test_dag.py
+++ b/tests/models/test_dag.py
@@ -2715,10 +2715,10 @@ class TestDagModel:
dag = DAG(
dag_id="test_dag_asset_expression",
schedule=AssetAny(
- Asset("s3://dag1/output_1.txt", {"hi": "bye"}),
+ Asset("s3://dag1/output_1.txt", extra={"hi": "bye"}),
AssetAll(
- Asset("s3://dag2/output_1.txt", {"hi": "bye"}),
- Asset("s3://dag3/output_3.txt", {"hi": "bye"}),
+ Asset("s3://dag2/output_1.txt", extra={"hi": "bye"}),
+ Asset("s3://dag3/output_3.txt", extra={"hi": "bye"}),
),
AssetAlias(name="test_name"),
),
diff --git a/tests/serialization/test_serialized_objects.py b/tests/serialization/test_serialized_objects.py
index 5d35278d89b..56a31d4d38b 100644
--- a/tests/serialization/test_serialized_objects.py
+++ b/tests/serialization/test_serialized_objects.py
@@ -327,7 +327,7 @@ sample_objects = {
id=1, filename="test_file", elasticsearch_id="test_id", created_at=datetime.now()
),
DagTagPydantic: DagTag(),
- AssetPydantic: Asset("uri", {}),
+ AssetPydantic: Asset("uri", extra={}),
AssetEventPydantic: AssetEvent(),
}