This is an automated email from the ASF dual-hosted git repository.
uranusjr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new cccc9334e31 Get rid of AssetAliasCondition (#44708)
cccc9334e31 is described below
commit cccc9334e3123423f678c7d237c544b45a76743e
Author: Tzu-ping Chung <[email protected]>
AuthorDate: Fri Dec 6 11:05:55 2024 +0800
Get rid of AssetAliasCondition (#44708)
* Get rid of AssetAliasCondition
Instead of having a separate class for condition evaluation, we can just
use the main AssetAlias class directly. While it technically makes sense
to subclass AssetAny, AssetAliasCondition does not really reuse much of
its implementation, and we can just implement the missing methods
ourselves instead. Whether the class actually is an AssetAny does not
really make much of a difference.
This actually allows us to simplify quite some code (including tests) a
bit since we don't need to rewrap AssetAlias back and forth.
* Fix serialization test
* Does not need this call
* Remove resolution-dependent timetable summary
---
airflow/serialization/serialized_objects.py | 5 +-
airflow/timetables/simple.py | 12 +-
providers/tests/openlineage/plugins/test_utils.py | 14 +-
.../src/airflow/sdk/definitions/asset/__init__.py | 142 +++++++--------------
task_sdk/tests/defintions/test_asset.py | 89 ++++---------
tests/models/test_dag.py | 1 -
tests/timetables/test_assets_timetable.py | 28 +---
7 files changed, 87 insertions(+), 204 deletions(-)
diff --git a/airflow/serialization/serialized_objects.py
b/airflow/serialization/serialized_objects.py
index 66b24b0ad40..754f0830f62 100644
--- a/airflow/serialization/serialized_objects.py
+++ b/airflow/serialization/serialized_objects.py
@@ -59,7 +59,6 @@ from airflow.providers_manager import ProvidersManager
from airflow.sdk.definitions.asset import (
Asset,
AssetAlias,
- AssetAliasCondition,
AssetAll,
AssetAny,
AssetRef,
@@ -1108,9 +1107,7 @@ class DependencyDetector:
)
)
elif isinstance(obj, AssetAlias):
- cond = AssetAliasCondition(name=obj.name, group=obj.group)
-
- deps.extend(cond.iter_dag_dependencies(source=task.dag_id,
target=""))
+ deps.extend(obj.iter_dag_dependencies(source=task.dag_id,
target=""))
return deps
@staticmethod
diff --git a/airflow/timetables/simple.py b/airflow/timetables/simple.py
index 57eec884b55..20e8085fe0d 100644
--- a/airflow/timetables/simple.py
+++ b/airflow/timetables/simple.py
@@ -19,7 +19,6 @@ from __future__ import annotations
from collections.abc import Collection, Sequence
from typing import TYPE_CHECKING, Any
-from airflow.sdk.definitions.asset import AssetAlias, AssetAliasCondition
from airflow.timetables.base import DagRunInfo, DataInterval, Timetable
from airflow.utils import timezone
@@ -162,20 +161,11 @@ class AssetTriggeredTimetable(_TrivialTimetable):
:meta private:
"""
- UNRESOLVED_ALIAS_SUMMARY = "Unresolved AssetAlias"
-
description: str = "Triggered by assets"
def __init__(self, assets: BaseAsset) -> None:
super().__init__()
self.asset_condition = assets
- if isinstance(self.asset_condition, AssetAlias):
- self.asset_condition =
AssetAliasCondition.from_asset_alias(self.asset_condition)
-
- if not next(self.asset_condition.iter_assets(), False):
- self._summary = AssetTriggeredTimetable.UNRESOLVED_ALIAS_SUMMARY
- else:
- self._summary = "Asset"
@classmethod
def deserialize(cls, data: dict[str, Any]) -> Timetable:
@@ -185,7 +175,7 @@ class AssetTriggeredTimetable(_TrivialTimetable):
@property
def summary(self) -> str:
- return self._summary
+ return "Asset"
def serialize(self) -> dict[str, Any]:
from airflow.serialization.serialized_objects import
encode_asset_condition
diff --git a/providers/tests/openlineage/plugins/test_utils.py
b/providers/tests/openlineage/plugins/test_utils.py
index 3d41e87cf01..046f836bb36 100644
--- a/providers/tests/openlineage/plugins/test_utils.py
+++ b/providers/tests/openlineage/plugins/test_utils.py
@@ -337,7 +337,7 @@ def test_serialize_timetable():
Asset(name="2", uri="test://2", group="test-group"),
AssetAlias(name="example-alias", group="test-group"),
Asset(name="3", uri="test://3", group="test-group"),
- AssetAll(AssetAlias("this-should-not-be-seen"), Asset("4")),
+ AssetAll(AssetAlias("another"), Asset("4")),
)
dag = MagicMock()
dag.timetable = AssetTriggeredTimetable(asset)
@@ -354,7 +354,11 @@ def test_serialize_timetable():
"name": "2",
"group": "test-group",
},
- {"__type": DagAttributeTypes.ASSET_ANY, "objects": []},
+ {
+ "__type": DagAttributeTypes.ASSET_ALIAS,
+ "name": "example-alias",
+ "group": "test-group",
+ },
{
"__type": DagAttributeTypes.ASSET,
"extra": {},
@@ -365,7 +369,11 @@ def test_serialize_timetable():
{
"__type": DagAttributeTypes.ASSET_ALL,
"objects": [
- {"__type": DagAttributeTypes.ASSET_ANY, "objects": []},
+ {
+ "__type": DagAttributeTypes.ASSET_ALIAS,
+ "name": "another",
+ "group": "",
+ },
{
"__type": DagAttributeTypes.ASSET,
"extra": {},
diff --git a/task_sdk/src/airflow/sdk/definitions/asset/__init__.py
b/task_sdk/src/airflow/sdk/definitions/asset/__init__.py
index ee5ca25c39e..787757637a6 100644
--- a/task_sdk/src/airflow/sdk/definitions/asset/__init__.py
+++ b/task_sdk/src/airflow/sdk/definitions/asset/__init__.py
@@ -17,21 +17,12 @@
from __future__ import annotations
-import functools
import logging
import operator
import os
import urllib.parse
import warnings
-from typing import (
- TYPE_CHECKING,
- Any,
- Callable,
- ClassVar,
- NamedTuple,
- cast,
- overload,
-)
+from typing import TYPE_CHECKING, Any, Callable, ClassVar, NamedTuple, overload
import attrs
@@ -51,7 +42,6 @@ __all__ = [
"Model",
"AssetRef",
"AssetAlias",
- "AssetAliasCondition",
"AssetAll",
"AssetAny",
]
@@ -407,24 +397,61 @@ class AssetAlias(BaseAsset):
name: str = attrs.field(validator=_validate_non_empty_identifier)
group: str = attrs.field(kw_only=True, default="",
validator=_validate_identifier)
+ def _resolve_assets(self) -> list[Asset]:
+ from airflow.models.asset import expand_alias_to_assets
+ from airflow.utils.session import create_session
+
+ with create_session() as session:
+ asset_models = expand_alias_to_assets(self.name, session)
+ return [m.to_public() for m in asset_models]
+
+ def as_expression(self) -> Any:
+ """
+ Serialize the asset alias into its scheduling expression.
+
+ :meta private:
+ """
+ return {"alias": {"name": self.name, "group": self.group}}
+
+ def evaluate(self, statuses: dict[str, bool]) -> bool:
+ return any(x.evaluate(statuses=statuses) for x in
self._resolve_assets())
+
def iter_assets(self) -> Iterator[tuple[AssetUniqueKey, Asset]]:
return iter(())
def iter_asset_aliases(self) -> Iterator[tuple[str, AssetAlias]]:
yield self.name, self
- def iter_dag_dependencies(self, *, source: str, target: str) ->
Iterator[DagDependency]:
+ def iter_dag_dependencies(self, *, source: str = "", target: str = "") ->
Iterator[DagDependency]:
"""
- Iterate an asset alias as dag dependency.
+ Iterate an asset alias and its resolved assets as dag dependency.
:meta private:
"""
- yield DagDependency(
- source=source or "asset-alias",
- target=target or "asset-alias",
- dependency_type="asset-alias",
- dependency_id=self.name,
- )
+ if not (resolved_assets := self._resolve_assets()):
+ yield DagDependency(
+ source=source or "asset-alias",
+ target=target or "asset-alias",
+ dependency_type="asset-alias",
+ dependency_id=self.name,
+ )
+ return
+ for asset in resolved_assets:
+ asset_name = asset.name
+ # asset
+ yield DagDependency(
+ source=f"asset-alias:{self.name}" if source else "asset",
+ target="asset" if source else f"asset-alias:{self.name}",
+ dependency_type="asset",
+ dependency_id=asset_name,
+ )
+ # asset alias
+ yield DagDependency(
+ source=source or f"asset:{asset_name}",
+ target=target or f"asset:{asset_name}",
+ dependency_type="asset-alias",
+ dependency_id=self.name,
+ )
class AssetAliasEvent(TypedDict):
@@ -443,11 +470,7 @@ class _AssetBooleanCondition(BaseAsset):
def __init__(self, *objects: BaseAsset) -> None:
if not all(isinstance(o, BaseAsset) for o in objects):
raise TypeError("expect asset expressions in condition")
-
- self.objects = [
- AssetAliasCondition.from_asset_alias(obj) if isinstance(obj,
AssetAlias) else obj
- for obj in objects
- ]
+ self.objects = objects
def evaluate(self, statuses: dict[str, bool]) -> bool:
return self.agg_func(x.evaluate(statuses=statuses) for x in
self.objects)
@@ -499,77 +522,6 @@ class AssetAny(_AssetBooleanCondition):
return {"any": [o.as_expression() for o in self.objects]}
-class AssetAliasCondition(AssetAny):
- """
- Use to expand AssetAlias as AssetAny of its resolved Assets.
-
- :meta private:
- """
-
- def __init__(self, name: str, group: str) -> None:
- self.name = name
- self.group = group
-
- def __repr__(self) -> str:
- return f"AssetAliasCondition({', '.join(map(str, self.objects))})"
-
- @functools.cached_property
- def objects(self) -> list[BaseAsset]: # type: ignore[override]
- from airflow.models.asset import expand_alias_to_assets
- from airflow.utils.session import create_session
-
- with create_session() as session:
- asset_models = expand_alias_to_assets(self.name, session)
- return [m.to_public() for m in asset_models]
-
- def as_expression(self) -> Any:
- """
- Serialize the asset alias into its scheduling expression.
-
- :meta private:
- """
- return {"alias": {"name": self.name, "group": self.group}}
-
- def iter_asset_aliases(self) -> Iterator[tuple[str, AssetAlias]]:
- yield self.name, AssetAlias(self.name)
-
- def iter_dag_dependencies(self, *, source: str = "", target: str = "") ->
Iterator[DagDependency]:
- """
- Iterate an asset alias and its resolved assets as dag dependency.
-
- :meta private:
- """
- if self.objects:
- for obj in self.objects:
- asset = cast(Asset, obj)
- asset_name = asset.name
- # asset
- yield DagDependency(
- source=f"asset-alias:{self.name}" if source else "asset",
- target="asset" if source else f"asset-alias:{self.name}",
- dependency_type="asset",
- dependency_id=asset_name,
- )
- # asset alias
- yield DagDependency(
- source=source or f"asset:{asset_name}",
- target=target or f"asset:{asset_name}",
- dependency_type="asset-alias",
- dependency_id=self.name,
- )
- else:
- yield DagDependency(
- source=source or "asset-alias",
- target=target or "asset-alias",
- dependency_type="asset-alias",
- dependency_id=self.name,
- )
-
- @staticmethod
- def from_asset_alias(asset_alias: AssetAlias) -> AssetAliasCondition:
- return AssetAliasCondition(name=asset_alias.name,
group=asset_alias.group)
-
-
class AssetAll(_AssetBooleanCondition):
"""Use to combine assets schedule references in an "or" relationship."""
diff --git a/task_sdk/tests/defintions/test_asset.py
b/task_sdk/tests/defintions/test_asset.py
index 6f6d40fdbe9..55439fb1bcd 100644
--- a/task_sdk/tests/defintions/test_asset.py
+++ b/task_sdk/tests/defintions/test_asset.py
@@ -27,7 +27,6 @@ from airflow.operators.empty import EmptyOperator
from airflow.sdk.definitions.asset import (
Asset,
AssetAlias,
- AssetAliasCondition,
AssetAll,
AssetAny,
BaseAsset,
@@ -487,78 +486,38 @@ def
test_normalize_uri_valid_uri(mock_get_normalized_scheme):
assert asset.normalized_uri == "valid_aip60_uri"
-class FakeSession:
- def __enter__(self):
- return self
-
- def __exit__(self, *args, **kwargs):
- pass
-
-
-FAKE_SESSION = FakeSession()
-
-
-class TestAssetAliasCondition:
+class TestAssetAlias:
@pytest.fixture
- def asset_model(self):
+ def asset(self):
"""Example asset links to asset alias resolved_asset_alias_2."""
- from airflow.models.asset import AssetModel
-
- return AssetModel(
- id=1,
- uri="test://asset1/",
- name="test_name",
- group="asset",
- )
+ return Asset(uri="test://asset1/", name="test_name", group="asset")
@pytest.fixture
def asset_alias_1(self):
"""Example asset alias links to no assets."""
- from airflow.models.asset import AssetAliasModel
-
- return AssetAliasModel(name="test_name", group="test")
+ asset_alias_1 = AssetAlias(name="test_name", group="test")
+ with mock.patch.object(asset_alias_1, "_resolve_assets",
return_value=[]):
+ yield asset_alias_1
@pytest.fixture
- def resolved_asset_alias_2(self, asset_model):
- """Example asset alias links to asset asset_alias_1."""
- from airflow.models.asset import AssetAliasModel
-
- asset_alias_2 = AssetAliasModel(name="test_name_2")
- asset_alias_2.assets.append(asset_model)
- return asset_alias_2
-
- def test_as_expression(self, asset_alias_1, resolved_asset_alias_2):
- for asset_alias in (asset_alias_1, resolved_asset_alias_2):
- cond = AssetAliasCondition.from_asset_alias(asset_alias)
- assert cond.as_expression() == {"alias": {"name":
asset_alias.name, "group": asset_alias.group}}
-
- @mock.patch("airflow.models.asset.expand_alias_to_assets")
- @mock.patch("airflow.utils.session.create_session",
return_value=FAKE_SESSION)
- def test_evalute_empty(
- self, mock_create_session, mock_expand_alias_to_assets, asset_alias_1,
asset_model
- ):
- mock_expand_alias_to_assets.return_value = []
-
- cond = AssetAliasCondition.from_asset_alias(asset_alias_1)
- assert cond.evaluate({asset_model.uri: True}) is False
-
- assert mock_expand_alias_to_assets.mock_calls ==
[mock.call(asset_alias_1.name, FAKE_SESSION)]
- assert mock_create_session.mock_calls == [mock.call()]
-
- @mock.patch("airflow.models.asset.expand_alias_to_assets")
- @mock.patch("airflow.utils.session.create_session",
return_value=FAKE_SESSION)
- def test_evalute_resolved(
- self, mock_create_session, mock_expand_alias_to_assets,
resolved_asset_alias_2, asset_model
- ):
- mock_expand_alias_to_assets.return_value = [asset_model]
-
- cond = AssetAliasCondition.from_asset_alias(resolved_asset_alias_2)
- assert cond.evaluate({asset_model.uri: True}) is True
-
- assert mock_expand_alias_to_assets.mock_calls == [
- mock.call(resolved_asset_alias_2.name, FAKE_SESSION),
- ]
- assert mock_create_session.mock_calls == [mock.call()]
+ def resolved_asset_alias_2(self, asset):
+ """Example asset alias links to asset."""
+ asset_alias_2 = AssetAlias(name="test_name_2")
+ with mock.patch.object(asset_alias_2, "_resolve_assets",
return_value=[asset]):
+ yield asset_alias_2
+
+ @pytest.mark.parametrize("alias_fixture_name", ["asset_alias_1",
"resolved_asset_alias_2"])
+ def test_as_expression(self, request: pytest.FixtureRequest,
alias_fixture_name):
+ alias = request.getfixturevalue(alias_fixture_name)
+ assert alias.as_expression() == {"alias": {"name": alias.name,
"group": alias.group}}
+
+ def test_evalute_empty(self, asset_alias_1, asset):
+ assert asset_alias_1.evaluate({asset.uri: True}) is False
+ assert asset_alias_1._resolve_assets.mock_calls == [mock.call()]
+
+ def test_evalute_resolved(self, resolved_asset_alias_2, asset):
+ assert resolved_asset_alias_2.evaluate({asset.uri: True}) is True
+ assert resolved_asset_alias_2._resolve_assets.mock_calls ==
[mock.call()]
class TestAssetSubclasses:
diff --git a/tests/models/test_dag.py b/tests/models/test_dag.py
index 384d76c7548..104e3c90494 100644
--- a/tests/models/test_dag.py
+++ b/tests/models/test_dag.py
@@ -2250,7 +2250,6 @@ class TestDagModel:
# add queue records so we'll need a run
dag_model = dag_maker.dag_model
- asset_model: AssetModel = dag_model.schedule_assets[0]
session.add(AssetDagRunQueue(asset_id=asset_model.id,
target_dag_id=dag_model.dag_id))
session.flush()
query, _ = DagModel.dags_needing_dagruns(session)
diff --git a/tests/timetables/test_assets_timetable.py
b/tests/timetables/test_assets_timetable.py
index d456bf058bc..1dc8e6428d9 100644
--- a/tests/timetables/test_assets_timetable.py
+++ b/tests/timetables/test_assets_timetable.py
@@ -19,24 +19,21 @@
from __future__ import annotations
from collections import defaultdict
-from typing import TYPE_CHECKING, Any
+from typing import Any
import pytest
from pendulum import DateTime
from sqlalchemy.sql import select
-from airflow.models.asset import AssetAliasModel, AssetDagRunQueue,
AssetEvent, AssetModel
+from airflow.models.asset import AssetDagRunQueue, AssetEvent, AssetModel
from airflow.models.serialized_dag import SerializedDAG, SerializedDagModel
from airflow.operators.empty import EmptyOperator
-from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetAll, AssetAny
+from airflow.sdk.definitions.asset import Asset, AssetAll, AssetAny
from airflow.timetables.assets import AssetOrTimeSchedule
from airflow.timetables.base import DagRunInfo, DataInterval, TimeRestriction,
Timetable
from airflow.timetables.simple import AssetTriggeredTimetable
from airflow.utils.types import DagRunType
-if TYPE_CHECKING:
- from sqlalchemy import Session
-
class MockTimetable(Timetable):
"""
@@ -274,25 +271,6 @@ def test_run_ordering_inheritance(asset_timetable:
AssetOrTimeSchedule) -> None:
assert asset_timetable.run_ordering == parent_run_ordering, "run_ordering
does not match the parent class"
[email protected]_test
-def test_summary(session: Session) -> None:
- asset_model = AssetModel(uri="test_asset")
- asset_alias_model = AssetAliasModel(name="test_asset_alias")
- session.add_all([asset_model, asset_alias_model])
- session.commit()
-
- asset_alias = AssetAlias("test_asset_alias")
- table = AssetTriggeredTimetable(asset_alias)
- assert table.summary == "Unresolved AssetAlias"
-
- asset_alias_model.assets.append(asset_model)
- session.add(asset_alias_model)
- session.commit()
-
- table = AssetTriggeredTimetable(asset_alias)
- assert table.summary == "Asset"
-
-
@pytest.mark.db_test
class TestAssetConditionWithTimetable:
@pytest.fixture(autouse=True)