This is an automated email from the ASF dual-hosted git repository.

uranusjr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new cccc9334e31 Get rid of AssetAliasCondition (#44708)
cccc9334e31 is described below

commit cccc9334e3123423f678c7d237c544b45a76743e
Author: Tzu-ping Chung <[email protected]>
AuthorDate: Fri Dec 6 11:05:55 2024 +0800

    Get rid of AssetAliasCondition (#44708)
    
    * Get rid of AssetAliasCondition
    
    Instead of having a separate class for condition evaluation, we can
    use the main AssetAlias class directly. While it technically makes sense
    to subclass AssetAny, AssetAliasCondition does not reuse much of its
    implementation, so we can simply implement the missing methods on
    AssetAlias itself. Whether the class is actually an AssetAny makes
    little practical difference.
    
    This allows us to simplify a fair amount of code (including tests),
    since we no longer need to convert between AssetAlias and
    AssetAliasCondition back and forth.
    
    * Fix serialization test
    
    * Does not need this call
    
    * Remove resolution-dependent timetable summary
---
 airflow/serialization/serialized_objects.py        |   5 +-
 airflow/timetables/simple.py                       |  12 +-
 providers/tests/openlineage/plugins/test_utils.py  |  14 +-
 .../src/airflow/sdk/definitions/asset/__init__.py  | 142 +++++++--------------
 task_sdk/tests/defintions/test_asset.py            |  89 ++++---------
 tests/models/test_dag.py                           |   1 -
 tests/timetables/test_assets_timetable.py          |  28 +---
 7 files changed, 87 insertions(+), 204 deletions(-)

diff --git a/airflow/serialization/serialized_objects.py 
b/airflow/serialization/serialized_objects.py
index 66b24b0ad40..754f0830f62 100644
--- a/airflow/serialization/serialized_objects.py
+++ b/airflow/serialization/serialized_objects.py
@@ -59,7 +59,6 @@ from airflow.providers_manager import ProvidersManager
 from airflow.sdk.definitions.asset import (
     Asset,
     AssetAlias,
-    AssetAliasCondition,
     AssetAll,
     AssetAny,
     AssetRef,
@@ -1108,9 +1107,7 @@ class DependencyDetector:
                     )
                 )
             elif isinstance(obj, AssetAlias):
-                cond = AssetAliasCondition(name=obj.name, group=obj.group)
-
-                deps.extend(cond.iter_dag_dependencies(source=task.dag_id, 
target=""))
+                deps.extend(obj.iter_dag_dependencies(source=task.dag_id, 
target=""))
         return deps
 
     @staticmethod
diff --git a/airflow/timetables/simple.py b/airflow/timetables/simple.py
index 57eec884b55..20e8085fe0d 100644
--- a/airflow/timetables/simple.py
+++ b/airflow/timetables/simple.py
@@ -19,7 +19,6 @@ from __future__ import annotations
 from collections.abc import Collection, Sequence
 from typing import TYPE_CHECKING, Any
 
-from airflow.sdk.definitions.asset import AssetAlias, AssetAliasCondition
 from airflow.timetables.base import DagRunInfo, DataInterval, Timetable
 from airflow.utils import timezone
 
@@ -162,20 +161,11 @@ class AssetTriggeredTimetable(_TrivialTimetable):
     :meta private:
     """
 
-    UNRESOLVED_ALIAS_SUMMARY = "Unresolved AssetAlias"
-
     description: str = "Triggered by assets"
 
     def __init__(self, assets: BaseAsset) -> None:
         super().__init__()
         self.asset_condition = assets
-        if isinstance(self.asset_condition, AssetAlias):
-            self.asset_condition = 
AssetAliasCondition.from_asset_alias(self.asset_condition)
-
-        if not next(self.asset_condition.iter_assets(), False):
-            self._summary = AssetTriggeredTimetable.UNRESOLVED_ALIAS_SUMMARY
-        else:
-            self._summary = "Asset"
 
     @classmethod
     def deserialize(cls, data: dict[str, Any]) -> Timetable:
@@ -185,7 +175,7 @@ class AssetTriggeredTimetable(_TrivialTimetable):
 
     @property
     def summary(self) -> str:
-        return self._summary
+        return "Asset"
 
     def serialize(self) -> dict[str, Any]:
         from airflow.serialization.serialized_objects import 
encode_asset_condition
diff --git a/providers/tests/openlineage/plugins/test_utils.py 
b/providers/tests/openlineage/plugins/test_utils.py
index 3d41e87cf01..046f836bb36 100644
--- a/providers/tests/openlineage/plugins/test_utils.py
+++ b/providers/tests/openlineage/plugins/test_utils.py
@@ -337,7 +337,7 @@ def test_serialize_timetable():
         Asset(name="2", uri="test://2", group="test-group"),
         AssetAlias(name="example-alias", group="test-group"),
         Asset(name="3", uri="test://3", group="test-group"),
-        AssetAll(AssetAlias("this-should-not-be-seen"), Asset("4")),
+        AssetAll(AssetAlias("another"), Asset("4")),
     )
     dag = MagicMock()
     dag.timetable = AssetTriggeredTimetable(asset)
@@ -354,7 +354,11 @@ def test_serialize_timetable():
                     "name": "2",
                     "group": "test-group",
                 },
-                {"__type": DagAttributeTypes.ASSET_ANY, "objects": []},
+                {
+                    "__type": DagAttributeTypes.ASSET_ALIAS,
+                    "name": "example-alias",
+                    "group": "test-group",
+                },
                 {
                     "__type": DagAttributeTypes.ASSET,
                     "extra": {},
@@ -365,7 +369,11 @@ def test_serialize_timetable():
                 {
                     "__type": DagAttributeTypes.ASSET_ALL,
                     "objects": [
-                        {"__type": DagAttributeTypes.ASSET_ANY, "objects": []},
+                        {
+                            "__type": DagAttributeTypes.ASSET_ALIAS,
+                            "name": "another",
+                            "group": "",
+                        },
                         {
                             "__type": DagAttributeTypes.ASSET,
                             "extra": {},
diff --git a/task_sdk/src/airflow/sdk/definitions/asset/__init__.py 
b/task_sdk/src/airflow/sdk/definitions/asset/__init__.py
index ee5ca25c39e..787757637a6 100644
--- a/task_sdk/src/airflow/sdk/definitions/asset/__init__.py
+++ b/task_sdk/src/airflow/sdk/definitions/asset/__init__.py
@@ -17,21 +17,12 @@
 
 from __future__ import annotations
 
-import functools
 import logging
 import operator
 import os
 import urllib.parse
 import warnings
-from typing import (
-    TYPE_CHECKING,
-    Any,
-    Callable,
-    ClassVar,
-    NamedTuple,
-    cast,
-    overload,
-)
+from typing import TYPE_CHECKING, Any, Callable, ClassVar, NamedTuple, overload
 
 import attrs
 
@@ -51,7 +42,6 @@ __all__ = [
     "Model",
     "AssetRef",
     "AssetAlias",
-    "AssetAliasCondition",
     "AssetAll",
     "AssetAny",
 ]
@@ -407,24 +397,61 @@ class AssetAlias(BaseAsset):
     name: str = attrs.field(validator=_validate_non_empty_identifier)
     group: str = attrs.field(kw_only=True, default="", 
validator=_validate_identifier)
 
+    def _resolve_assets(self) -> list[Asset]:
+        from airflow.models.asset import expand_alias_to_assets
+        from airflow.utils.session import create_session
+
+        with create_session() as session:
+            asset_models = expand_alias_to_assets(self.name, session)
+        return [m.to_public() for m in asset_models]
+
+    def as_expression(self) -> Any:
+        """
+        Serialize the asset alias into its scheduling expression.
+
+        :meta private:
+        """
+        return {"alias": {"name": self.name, "group": self.group}}
+
+    def evaluate(self, statuses: dict[str, bool]) -> bool:
+        return any(x.evaluate(statuses=statuses) for x in 
self._resolve_assets())
+
     def iter_assets(self) -> Iterator[tuple[AssetUniqueKey, Asset]]:
         return iter(())
 
     def iter_asset_aliases(self) -> Iterator[tuple[str, AssetAlias]]:
         yield self.name, self
 
-    def iter_dag_dependencies(self, *, source: str, target: str) -> 
Iterator[DagDependency]:
+    def iter_dag_dependencies(self, *, source: str = "", target: str = "") -> 
Iterator[DagDependency]:
         """
-        Iterate an asset alias as dag dependency.
+        Iterate an asset alias and its resolved assets as dag dependency.
 
         :meta private:
         """
-        yield DagDependency(
-            source=source or "asset-alias",
-            target=target or "asset-alias",
-            dependency_type="asset-alias",
-            dependency_id=self.name,
-        )
+        if not (resolved_assets := self._resolve_assets()):
+            yield DagDependency(
+                source=source or "asset-alias",
+                target=target or "asset-alias",
+                dependency_type="asset-alias",
+                dependency_id=self.name,
+            )
+            return
+        for asset in resolved_assets:
+            asset_name = asset.name
+            # asset
+            yield DagDependency(
+                source=f"asset-alias:{self.name}" if source else "asset",
+                target="asset" if source else f"asset-alias:{self.name}",
+                dependency_type="asset",
+                dependency_id=asset_name,
+            )
+            # asset alias
+            yield DagDependency(
+                source=source or f"asset:{asset_name}",
+                target=target or f"asset:{asset_name}",
+                dependency_type="asset-alias",
+                dependency_id=self.name,
+            )
 
 
 class AssetAliasEvent(TypedDict):
@@ -443,11 +470,7 @@ class _AssetBooleanCondition(BaseAsset):
     def __init__(self, *objects: BaseAsset) -> None:
         if not all(isinstance(o, BaseAsset) for o in objects):
             raise TypeError("expect asset expressions in condition")
-
-        self.objects = [
-            AssetAliasCondition.from_asset_alias(obj) if isinstance(obj, 
AssetAlias) else obj
-            for obj in objects
-        ]
+        self.objects = objects
 
     def evaluate(self, statuses: dict[str, bool]) -> bool:
         return self.agg_func(x.evaluate(statuses=statuses) for x in 
self.objects)
@@ -499,77 +522,6 @@ class AssetAny(_AssetBooleanCondition):
         return {"any": [o.as_expression() for o in self.objects]}
 
 
-class AssetAliasCondition(AssetAny):
-    """
-    Use to expand AssetAlias as AssetAny of its resolved Assets.
-
-    :meta private:
-    """
-
-    def __init__(self, name: str, group: str) -> None:
-        self.name = name
-        self.group = group
-
-    def __repr__(self) -> str:
-        return f"AssetAliasCondition({', '.join(map(str, self.objects))})"
-
-    @functools.cached_property
-    def objects(self) -> list[BaseAsset]:  # type: ignore[override]
-        from airflow.models.asset import expand_alias_to_assets
-        from airflow.utils.session import create_session
-
-        with create_session() as session:
-            asset_models = expand_alias_to_assets(self.name, session)
-        return [m.to_public() for m in asset_models]
-
-    def as_expression(self) -> Any:
-        """
-        Serialize the asset alias into its scheduling expression.
-
-        :meta private:
-        """
-        return {"alias": {"name": self.name, "group": self.group}}
-
-    def iter_asset_aliases(self) -> Iterator[tuple[str, AssetAlias]]:
-        yield self.name, AssetAlias(self.name)
-
-    def iter_dag_dependencies(self, *, source: str = "", target: str = "") -> 
Iterator[DagDependency]:
-        """
-        Iterate an asset alias and its resolved assets as dag dependency.
-
-        :meta private:
-        """
-        if self.objects:
-            for obj in self.objects:
-                asset = cast(Asset, obj)
-                asset_name = asset.name
-                # asset
-                yield DagDependency(
-                    source=f"asset-alias:{self.name}" if source else "asset",
-                    target="asset" if source else f"asset-alias:{self.name}",
-                    dependency_type="asset",
-                    dependency_id=asset_name,
-                )
-                # asset alias
-                yield DagDependency(
-                    source=source or f"asset:{asset_name}",
-                    target=target or f"asset:{asset_name}",
-                    dependency_type="asset-alias",
-                    dependency_id=self.name,
-                )
-        else:
-            yield DagDependency(
-                source=source or "asset-alias",
-                target=target or "asset-alias",
-                dependency_type="asset-alias",
-                dependency_id=self.name,
-            )
-
-    @staticmethod
-    def from_asset_alias(asset_alias: AssetAlias) -> AssetAliasCondition:
-        return AssetAliasCondition(name=asset_alias.name, 
group=asset_alias.group)
-
-
 class AssetAll(_AssetBooleanCondition):
     """Use to combine assets schedule references in an "or" relationship."""
 
diff --git a/task_sdk/tests/defintions/test_asset.py 
b/task_sdk/tests/defintions/test_asset.py
index 6f6d40fdbe9..55439fb1bcd 100644
--- a/task_sdk/tests/defintions/test_asset.py
+++ b/task_sdk/tests/defintions/test_asset.py
@@ -27,7 +27,6 @@ from airflow.operators.empty import EmptyOperator
 from airflow.sdk.definitions.asset import (
     Asset,
     AssetAlias,
-    AssetAliasCondition,
     AssetAll,
     AssetAny,
     BaseAsset,
@@ -487,78 +486,38 @@ def 
test_normalize_uri_valid_uri(mock_get_normalized_scheme):
     assert asset.normalized_uri == "valid_aip60_uri"
 
 
-class FakeSession:
-    def __enter__(self):
-        return self
-
-    def __exit__(self, *args, **kwargs):
-        pass
-
-
-FAKE_SESSION = FakeSession()
-
-
-class TestAssetAliasCondition:
+class TestAssetAlias:
     @pytest.fixture
-    def asset_model(self):
+    def asset(self):
         """Example asset links to asset alias resolved_asset_alias_2."""
-        from airflow.models.asset import AssetModel
-
-        return AssetModel(
-            id=1,
-            uri="test://asset1/",
-            name="test_name",
-            group="asset",
-        )
+        return Asset(uri="test://asset1/", name="test_name", group="asset")
 
     @pytest.fixture
     def asset_alias_1(self):
         """Example asset alias links to no assets."""
-        from airflow.models.asset import AssetAliasModel
-
-        return AssetAliasModel(name="test_name", group="test")
+        asset_alias_1 = AssetAlias(name="test_name", group="test")
+        with mock.patch.object(asset_alias_1, "_resolve_assets", 
return_value=[]):
+            yield asset_alias_1
 
     @pytest.fixture
-    def resolved_asset_alias_2(self, asset_model):
-        """Example asset alias links to asset asset_alias_1."""
-        from airflow.models.asset import AssetAliasModel
-
-        asset_alias_2 = AssetAliasModel(name="test_name_2")
-        asset_alias_2.assets.append(asset_model)
-        return asset_alias_2
-
-    def test_as_expression(self, asset_alias_1, resolved_asset_alias_2):
-        for asset_alias in (asset_alias_1, resolved_asset_alias_2):
-            cond = AssetAliasCondition.from_asset_alias(asset_alias)
-            assert cond.as_expression() == {"alias": {"name": 
asset_alias.name, "group": asset_alias.group}}
-
-    @mock.patch("airflow.models.asset.expand_alias_to_assets")
-    @mock.patch("airflow.utils.session.create_session", 
return_value=FAKE_SESSION)
-    def test_evalute_empty(
-        self, mock_create_session, mock_expand_alias_to_assets, asset_alias_1, 
asset_model
-    ):
-        mock_expand_alias_to_assets.return_value = []
-
-        cond = AssetAliasCondition.from_asset_alias(asset_alias_1)
-        assert cond.evaluate({asset_model.uri: True}) is False
-
-        assert mock_expand_alias_to_assets.mock_calls == 
[mock.call(asset_alias_1.name, FAKE_SESSION)]
-        assert mock_create_session.mock_calls == [mock.call()]
-
-    @mock.patch("airflow.models.asset.expand_alias_to_assets")
-    @mock.patch("airflow.utils.session.create_session", 
return_value=FAKE_SESSION)
-    def test_evalute_resolved(
-        self, mock_create_session, mock_expand_alias_to_assets, 
resolved_asset_alias_2, asset_model
-    ):
-        mock_expand_alias_to_assets.return_value = [asset_model]
-
-        cond = AssetAliasCondition.from_asset_alias(resolved_asset_alias_2)
-        assert cond.evaluate({asset_model.uri: True}) is True
-
-        assert mock_expand_alias_to_assets.mock_calls == [
-            mock.call(resolved_asset_alias_2.name, FAKE_SESSION),
-        ]
-        assert mock_create_session.mock_calls == [mock.call()]
+    def resolved_asset_alias_2(self, asset):
+        """Example asset alias links to asset."""
+        asset_alias_2 = AssetAlias(name="test_name_2")
+        with mock.patch.object(asset_alias_2, "_resolve_assets", 
return_value=[asset]):
+            yield asset_alias_2
+
+    @pytest.mark.parametrize("alias_fixture_name", ["asset_alias_1", 
"resolved_asset_alias_2"])
+    def test_as_expression(self, request: pytest.FixtureRequest, 
alias_fixture_name):
+        alias = request.getfixturevalue(alias_fixture_name)
+        assert alias.as_expression() == {"alias": {"name": alias.name, 
"group": alias.group}}
+
+    def test_evalute_empty(self, asset_alias_1, asset):
+        assert asset_alias_1.evaluate({asset.uri: True}) is False
+        assert asset_alias_1._resolve_assets.mock_calls == [mock.call()]
+
+    def test_evalute_resolved(self, resolved_asset_alias_2, asset):
+        assert resolved_asset_alias_2.evaluate({asset.uri: True}) is True
+        assert resolved_asset_alias_2._resolve_assets.mock_calls == 
[mock.call()]
 
 
 class TestAssetSubclasses:
diff --git a/tests/models/test_dag.py b/tests/models/test_dag.py
index 384d76c7548..104e3c90494 100644
--- a/tests/models/test_dag.py
+++ b/tests/models/test_dag.py
@@ -2250,7 +2250,6 @@ class TestDagModel:
 
         # add queue records so we'll need a run
         dag_model = dag_maker.dag_model
-        asset_model: AssetModel = dag_model.schedule_assets[0]
         session.add(AssetDagRunQueue(asset_id=asset_model.id, 
target_dag_id=dag_model.dag_id))
         session.flush()
         query, _ = DagModel.dags_needing_dagruns(session)
diff --git a/tests/timetables/test_assets_timetable.py 
b/tests/timetables/test_assets_timetable.py
index d456bf058bc..1dc8e6428d9 100644
--- a/tests/timetables/test_assets_timetable.py
+++ b/tests/timetables/test_assets_timetable.py
@@ -19,24 +19,21 @@
 from __future__ import annotations
 
 from collections import defaultdict
-from typing import TYPE_CHECKING, Any
+from typing import Any
 
 import pytest
 from pendulum import DateTime
 from sqlalchemy.sql import select
 
-from airflow.models.asset import AssetAliasModel, AssetDagRunQueue, 
AssetEvent, AssetModel
+from airflow.models.asset import AssetDagRunQueue, AssetEvent, AssetModel
 from airflow.models.serialized_dag import SerializedDAG, SerializedDagModel
 from airflow.operators.empty import EmptyOperator
-from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetAll, AssetAny
+from airflow.sdk.definitions.asset import Asset, AssetAll, AssetAny
 from airflow.timetables.assets import AssetOrTimeSchedule
 from airflow.timetables.base import DagRunInfo, DataInterval, TimeRestriction, 
Timetable
 from airflow.timetables.simple import AssetTriggeredTimetable
 from airflow.utils.types import DagRunType
 
-if TYPE_CHECKING:
-    from sqlalchemy import Session
-
 
 class MockTimetable(Timetable):
     """
@@ -274,25 +271,6 @@ def test_run_ordering_inheritance(asset_timetable: 
AssetOrTimeSchedule) -> None:
     assert asset_timetable.run_ordering == parent_run_ordering, "run_ordering 
does not match the parent class"
 
 
[email protected]_test
-def test_summary(session: Session) -> None:
-    asset_model = AssetModel(uri="test_asset")
-    asset_alias_model = AssetAliasModel(name="test_asset_alias")
-    session.add_all([asset_model, asset_alias_model])
-    session.commit()
-
-    asset_alias = AssetAlias("test_asset_alias")
-    table = AssetTriggeredTimetable(asset_alias)
-    assert table.summary == "Unresolved AssetAlias"
-
-    asset_alias_model.assets.append(asset_model)
-    session.add(asset_alias_model)
-    session.commit()
-
-    table = AssetTriggeredTimetable(asset_alias)
-    assert table.summary == "Asset"
-
-
 @pytest.mark.db_test
 class TestAssetConditionWithTimetable:
     @pytest.fixture(autouse=True)

Reply via email to