This is an automated email from the ASF dual-hosted git repository.

uranusjr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new cd78078e4e4 Move asset evaluation logic out of SDK (#47484)
cd78078e4e4 is described below

commit cd78078e4e4f94889d8930853229f054f06538d0
Author: Tzu-ping Chung <[email protected]>
AuthorDate: Fri Mar 7 19:47:30 2025 +0800

    Move asset evaluation logic out of SDK (#47484)
    
    Asset evaluation is only done in the scheduler and requires the
    database in various cases, so it is better to split it out into a
    dedicated class in Airflow core.
    
    Dependency resolution requires database calls for indirect asset
    references, but we don't want that to happen in the SDK. This removes
    the eager resolution code, so asset aliases and refs are no longer
    resolved; each one only keeps a marker for itself.
    
    An additional PR will be submitted later to do the resolution. This is
    part of the asset UI project, which involves API changes anyway.
---
 airflow/assets/evaluation.py                       |  78 ++++++++
 airflow/models/asset.py                            |   6 +-
 airflow/models/dag.py                              |   5 +-
 .../src/airflow/sdk/definitions/asset/__init__.py  | 102 +++--------
 .../airflow/sdk/definitions/asset/decorators.py    |   9 +-
 task-sdk/tests/task_sdk/definitions/test_asset.py  | 156 +---------------
 tests/assets/test_evaluation.py                    | 199 +++++++++++++++++++++
 tests/models/test_asset.py                         |   4 +-
 tests/timetables/test_assets_timetable.py          |   5 +-
 9 files changed, 319 insertions(+), 245 deletions(-)

diff --git a/airflow/assets/evaluation.py b/airflow/assets/evaluation.py
new file mode 100644
index 00000000000..b1a877d4d54
--- /dev/null
+++ b/airflow/assets/evaluation.py
@@ -0,0 +1,78 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+import functools
+from typing import TYPE_CHECKING
+
+import attrs
+
+from airflow.models.asset import expand_alias_to_assets, resolve_ref_to_asset
+from airflow.sdk.definitions.asset import (
+    Asset,
+    AssetAlias,
+    AssetBooleanCondition,
+    AssetRef,
+    AssetUniqueKey,
+    BaseAsset,
+)
+from airflow.sdk.definitions.asset.decorators import MultiAssetDefinition
+
+if TYPE_CHECKING:
+    from sqlalchemy.orm import Session
+
+
[email protected]
+class AssetEvaluator:
+    """Evaluates whether an asset-like object has been satisfied."""
+
+    _session: Session
+
+    def _resolve_asset_ref(self, o: AssetRef) -> Asset | None:
+        asset = resolve_ref_to_asset(**attrs.asdict(o), session=self._session)
+        return asset.to_public() if asset else None
+
+    def _resolve_asset_alias(self, o: AssetAlias) -> list[Asset]:
+        asset_models = expand_alias_to_assets(o.name, session=self._session)
+        return [m.to_public() for m in asset_models]
+
+    @functools.singledispatchmethod
+    def run(self, o: BaseAsset, statuses: dict[AssetUniqueKey, bool]) -> bool:
+        raise NotImplementedError(f"can not evaluate {o!r}")
+
+    @run.register
+    def _(self, o: Asset, statuses: dict[AssetUniqueKey, bool]) -> bool:
+        return statuses.get(AssetUniqueKey.from_asset(o), False)
+
+    @run.register
+    def _(self, o: AssetRef, statuses: dict[AssetUniqueKey, bool]) -> bool:
+        if asset := self._resolve_asset_ref(o):
+            return self.run(asset, statuses)
+        return False
+
+    @run.register
+    def _(self, o: AssetAlias, statuses: dict[AssetUniqueKey, bool]) -> bool:
+        return any(self.run(x, statuses) for x in self._resolve_asset_alias(o))
+
+    @run.register
+    def _(self, o: AssetBooleanCondition, statuses: dict[AssetUniqueKey, 
bool]) -> bool:
+        return o.agg_func(self.run(x, statuses) for x in o.objects)
+
+    @run.register
+    def _(self, o: MultiAssetDefinition, statuses: dict[AssetUniqueKey, bool]) 
-> bool:
+        return all(self.run(x, statuses) for x in o.iter_outlets())
diff --git a/airflow/models/asset.py b/airflow/models/asset.py
index 212a0b3a84c..5ec0b2e977c 100644
--- a/airflow/models/asset.py
+++ b/airflow/models/asset.py
@@ -70,14 +70,14 @@ def fetch_active_assets_by_uri(uris: Iterable[str], 
session: Session) -> dict[st
     }
 
 
-def expand_alias_to_assets(alias_name: str, session: Session) -> 
Iterable[AssetModel]:
+def expand_alias_to_assets(alias_name: str, *, session: Session) -> 
Iterable[AssetModel]:
     """Expand asset alias to resolved assets."""
     asset_alias_obj = session.scalar(
         select(AssetAliasModel).where(AssetAliasModel.name == 
alias_name).limit(1)
     )
     if asset_alias_obj:
-        return list(asset_alias_obj.assets)
-    return []
+        return iter(asset_alias_obj.assets)
+    return iter(())
 
 
 def resolve_ref_to_asset(
diff --git a/airflow/models/dag.py b/airflow/models/dag.py
index 16cbf28d647..cbb73e45f84 100644
--- a/airflow/models/dag.py
+++ b/airflow/models/dag.py
@@ -68,6 +68,7 @@ from sqlalchemy.orm import backref, load_only, relationship
 from sqlalchemy.sql import Select, expression
 
 from airflow import settings, utils
+from airflow.assets.evaluation import AssetEvaluator
 from airflow.configuration import conf as airflow_conf, secrets_backend_list
 from airflow.exceptions import (
     AirflowException,
@@ -2323,12 +2324,14 @@ class DagModel(Base):
         """
         from airflow.models.serialized_dag import SerializedDagModel
 
+        evaluator = AssetEvaluator(session)
+
         def dag_ready(dag_id: str, cond: BaseAsset, statuses: 
dict[AssetUniqueKey, bool]) -> bool | None:
             # if dag was serialized before 2.9 and we *just* upgraded,
             # we may be dealing with old version.  In that case,
             # just wait for the dag to be reserialized.
             try:
-                return cond.evaluate(statuses, session=session)
+                return evaluator.run(cond, statuses)
             except AttributeError:
                 log.warning("dag '%s' has old serialization; skipping DAG run 
creation.", dag_id)
                 return None
diff --git a/task-sdk/src/airflow/sdk/definitions/asset/__init__.py 
b/task-sdk/src/airflow/sdk/definitions/asset/__init__.py
index 11b7a1b2dae..94559988c9d 100644
--- a/task-sdk/src/airflow/sdk/definitions/asset/__init__.py
+++ b/task-sdk/src/airflow/sdk/definitions/asset/__init__.py
@@ -17,7 +17,6 @@
 
 from __future__ import annotations
 
-import contextlib
 import logging
 import operator
 import os
@@ -34,8 +33,6 @@ if TYPE_CHECKING:
     from collections.abc import Iterable, Iterator
     from urllib.parse import SplitResult
 
-    from sqlalchemy.orm import Session
-
     from airflow.models.asset import AssetModel
     from airflow.serialization.serialized_objects import SerializedAssetWatcher
     from airflow.triggers.base import BaseEventTrigger
@@ -233,9 +230,6 @@ class BaseAsset:
         """
         raise NotImplementedError
 
-    def evaluate(self, statuses: dict[AssetUniqueKey, bool], *, session: 
Session | None = None) -> bool:
-        raise NotImplementedError
-
     def iter_assets(self) -> Iterator[tuple[AssetUniqueKey, Asset]]:
         raise NotImplementedError
 
@@ -442,9 +436,6 @@ class Asset(os.PathLike, BaseAsset):
     def iter_asset_refs(self) -> Iterator[AssetRef]:
         return iter(())
 
-    def evaluate(self, statuses: dict[AssetUniqueKey, bool], *, session: 
Session | None = None) -> bool:
-        return statuses.get(AssetUniqueKey.from_asset(self), False)
-
     def iter_dag_dependencies(self, *, source: str, target: str) -> 
Iterator[DagDependency]:
         """
         Iterate an asset as dag dependency.
@@ -489,35 +480,14 @@ class AssetRef(BaseAsset, AttrsInstance):
     def iter_asset_refs(self) -> Iterator[AssetRef]:
         yield self
 
-    def _resolve_asset(self, *, session: Session | None = None) -> Asset | 
None:
-        from airflow.models.asset import resolve_ref_to_asset
-        from airflow.utils.session import create_session
-
-        with contextlib.nullcontext(session) if session else create_session() 
as session:
-            asset = resolve_ref_to_asset(**attrs.asdict(self), session=session)
-        return asset.to_public() if asset else None
-
-    def evaluate(self, statuses: dict[AssetUniqueKey, bool], *, session: 
Session | None = None) -> bool:
-        if asset := self._resolve_asset(session=session):
-            return asset.evaluate(statuses=statuses, session=session)
-        return False
-
     def iter_dag_dependencies(self, *, source: str = "", target: str = "") -> 
Iterator[DagDependency]:
         (dependency_id,) = attrs.astuple(self)
-        if asset := self._resolve_asset():
-            yield DagDependency(
-                source=f"asset-ref:{dependency_id}" if source else "asset",
-                target="asset" if source else f"asset-ref:{dependency_id}",
-                dependency_type="asset",
-                dependency_id=asset.name,
-            )
-        else:
-            yield DagDependency(
-                source=source or "asset-ref",
-                target=target or "asset-ref",
-                dependency_type="asset-ref",
-                dependency_id=dependency_id,
-            )
+        yield DagDependency(
+            source=source or "asset-ref",
+            target=target or "asset-ref",
+            dependency_type="asset-ref",
+            dependency_id=dependency_id,
+        )
 
 
 @attrs.define(hash=True)
@@ -553,14 +523,6 @@ class AssetAlias(BaseAsset):
     name: str = attrs.field(validator=_validate_non_empty_identifier)
     group: str = attrs.field(kw_only=True, default="asset", 
validator=_validate_identifier)
 
-    def _resolve_assets(self, session: Session | None = None) -> list[Asset]:
-        from airflow.models.asset import expand_alias_to_assets
-        from airflow.utils.session import create_session
-
-        with contextlib.nullcontext(session) if session else create_session() 
as session:
-            asset_models = expand_alias_to_assets(self.name, session)
-        return [m.to_public() for m in asset_models]
-
     def as_expression(self) -> Any:
         """
         Serialize the asset alias into its scheduling expression.
@@ -569,9 +531,6 @@ class AssetAlias(BaseAsset):
         """
         return {"alias": {"name": self.name, "group": self.group}}
 
-    def evaluate(self, statuses: dict[AssetUniqueKey, bool], *, session: 
Session | None = None) -> bool:
-        return any(x.evaluate(statuses=statuses, session=session) for x in 
self._resolve_assets(session))
-
     def iter_assets(self) -> Iterator[tuple[AssetUniqueKey, Asset]]:
         return iter(())
 
@@ -587,34 +546,20 @@ class AssetAlias(BaseAsset):
 
         :meta private:
         """
-        if not (resolved_assets := self._resolve_assets()):
-            yield DagDependency(
-                source=source or "asset-alias",
-                target=target or "asset-alias",
-                dependency_type="asset-alias",
-                dependency_id=self.name,
-            )
-            return
-        for asset in resolved_assets:
-            asset_name = asset.name
-            # asset
-            yield DagDependency(
-                source=f"asset-alias:{self.name}" if source else "asset",
-                target="asset" if source else f"asset-alias:{self.name}",
-                dependency_type="asset",
-                dependency_id=asset_name,
-            )
-            # asset alias
-            yield DagDependency(
-                source=source or f"asset:{asset_name}",
-                target=target or f"asset:{asset_name}",
-                dependency_type="asset-alias",
-                dependency_id=self.name,
-            )
-
-
-class _AssetBooleanCondition(BaseAsset):
-    """Base class for asset boolean logic."""
+        yield DagDependency(
+            source=source or "asset-alias",
+            target=target or "asset-alias",
+            dependency_type="asset-alias",
+            dependency_id=self.name,
+        )
+
+
+class AssetBooleanCondition(BaseAsset):
+    """
+    Base class for asset boolean logic.
+
+    :meta private:
+    """
 
     agg_func: Callable[[Iterable], bool]
 
@@ -623,9 +568,6 @@ class _AssetBooleanCondition(BaseAsset):
             raise TypeError("expect asset expressions in condition")
         self.objects = objects
 
-    def evaluate(self, statuses: dict[AssetUniqueKey, bool], *, session: 
Session | None = None) -> bool:
-        return self.agg_func(x.evaluate(statuses=statuses, session=session) 
for x in self.objects)
-
     def iter_assets(self) -> Iterator[tuple[AssetUniqueKey, Asset]]:
         for o in self.objects:
             yield from o.iter_assets()
@@ -648,7 +590,7 @@ class _AssetBooleanCondition(BaseAsset):
             yield from obj.iter_dag_dependencies(source=source, target=target)
 
 
-class AssetAny(_AssetBooleanCondition):
+class AssetAny(AssetBooleanCondition):
     """Use to combine assets schedule references in an "or" relationship."""
 
     agg_func = any
@@ -671,7 +613,7 @@ class AssetAny(_AssetBooleanCondition):
         return {"any": [o.as_expression() for o in self.objects]}
 
 
-class AssetAll(_AssetBooleanCondition):
+class AssetAll(AssetBooleanCondition):
     """Use to combine assets schedule references in an "and" relationship."""
 
     agg_func = all
diff --git a/task-sdk/src/airflow/sdk/definitions/asset/decorators.py 
b/task-sdk/src/airflow/sdk/definitions/asset/decorators.py
index c70a224858b..77ab57074bb 100644
--- a/task-sdk/src/airflow/sdk/definitions/asset/decorators.py
+++ b/task-sdk/src/airflow/sdk/definitions/asset/decorators.py
@@ -28,8 +28,6 @@ from airflow.sdk.definitions.asset import Asset, 
AssetNameRef, AssetRef, BaseAss
 if TYPE_CHECKING:
     from collections.abc import Callable, Collection, Iterator, Mapping
 
-    from sqlalchemy.orm import Session
-
     from airflow.io.path import ObjectStoragePath
     from airflow.sdk.definitions.asset import AssetAlias, AssetUniqueKey
     from airflow.sdk.definitions.dag import DAG, DagStateChangeCallback, 
ScheduleArg
@@ -122,9 +120,6 @@ class MultiAssetDefinition(BaseAsset):
         with self._source.create_dag(dag_id=self._function.__name__):
             _AssetMainOperator.from_definition(self)
 
-    def evaluate(self, statuses: dict[AssetUniqueKey, bool], *, session: 
Session | None = None) -> bool:
-        return all(o.evaluate(statuses=statuses, session=session) for o in 
self._source.outlets)
-
     def iter_assets(self) -> Iterator[tuple[AssetUniqueKey, Asset]]:
         for o in self._source.outlets:
             yield from o.iter_assets()
@@ -141,6 +136,10 @@ class MultiAssetDefinition(BaseAsset):
         for obj in self._source.outlets:
             yield from obj.iter_dag_dependencies(source=source, target=target)
 
+    def iter_outlets(self) -> Iterator[BaseAsset]:
+        """For asset evaluation in the scheduler."""
+        return iter(self._source.outlets)
+
 
 @attrs.define(kw_only=True)
 class _DAGFactory:
diff --git a/task-sdk/tests/task_sdk/definitions/test_asset.py 
b/task-sdk/tests/task_sdk/definitions/test_asset.py
index 767cd9e1be7..1637cffac61 100644
--- a/task-sdk/tests/task_sdk/definitions/test_asset.py
+++ b/task-sdk/tests/task_sdk/definitions/test_asset.py
@@ -37,7 +37,7 @@ from airflow.sdk.definitions.asset import (
     _sanitize_uri,
 )
 from airflow.sdk.definitions.dag import DAG
-from airflow.serialization.serialized_objects import BaseSerialization, 
SerializedDAG
+from airflow.serialization.serialized_objects import SerializedDAG
 
 ASSET_MODULE_PATH = "airflow.sdk.definitions.asset"
 
@@ -185,18 +185,6 @@ def test_asset_iter_asset_aliases():
     ]
 
 
[email protected](
-    "statuses, result",
-    [
-        ({AssetUniqueKey.from_asset(asset1): True}, True),
-        ({AssetUniqueKey.from_asset(asset1): False}, False),
-        ({}, False),
-    ],
-)
-def test_asset_evaluate(statuses, result):
-    assert asset1.evaluate(statuses) is result
-
-
 def test_asset_any_operations():
     result_or = (asset1 | asset2) | asset3
     assert isinstance(result_or, AssetAny)
@@ -212,116 +200,6 @@ def test_asset_all_operations():
     assert isinstance(result_and, AssetAll)
 
 
[email protected](
-    "condition, statuses, result",
-    [
-        (
-            AssetAny(asset1, asset2),
-            {AssetUniqueKey.from_asset(asset1): False, 
AssetUniqueKey.from_asset(asset2): True},
-            True,
-        ),
-        (
-            AssetAll(asset1, asset2),
-            {AssetUniqueKey.from_asset(asset1): True, 
AssetUniqueKey.from_asset(asset2): False},
-            False,
-        ),
-    ],
-)
-def test_assset_boolean_condition_evaluate_iter(condition, statuses, result):
-    """
-    Tests _AssetBooleanCondition's evaluate and iter_assets methods through 
AssetAny and AssetAll.
-    Ensures AssetAny evaluate returns True with any true condition, AssetAll 
evaluate returns False if
-    any condition is false, and both classes correctly iterate over assets 
without duplication.
-    """
-    assert condition.evaluate(statuses) is result
-    assert dict(condition.iter_assets()) == {
-        AssetUniqueKey("asset-1", "s3://bucket1/data1"): asset1,
-        AssetUniqueKey("asset-2", "s3://bucket2/data2"): asset2,
-    }
-
-
[email protected](
-    "inputs, scenario, expected",
-    [
-        # Scenarios for AssetAny
-        ((True, True, True), "any", True),
-        ((True, True, False), "any", True),
-        ((True, False, True), "any", True),
-        ((True, False, False), "any", True),
-        ((False, False, True), "any", True),
-        ((False, True, False), "any", True),
-        ((False, True, True), "any", True),
-        ((False, False, False), "any", False),
-        # Scenarios for AssetAll
-        ((True, True, True), "all", True),
-        ((True, True, False), "all", False),
-        ((True, False, True), "all", False),
-        ((True, False, False), "all", False),
-        ((False, False, True), "all", False),
-        ((False, True, False), "all", False),
-        ((False, True, True), "all", False),
-        ((False, False, False), "all", False),
-    ],
-)
-def test_asset_logical_conditions_evaluation_and_serialization(inputs, 
scenario, expected):
-    class_ = AssetAny if scenario == "any" else AssetAll
-    assets = [Asset(uri=f"s3://abc/{i}", name=f"asset_{i}") for i in 
range(123, 126)]
-    condition = class_(*assets)
-
-    statuses = {AssetUniqueKey.from_asset(asset): status for asset, status in 
zip(assets, inputs)}
-    assert (
-        condition.evaluate(statuses) == expected
-    ), f"Condition evaluation failed for inputs {inputs} and scenario 
'{scenario}'"
-
-    # Serialize and deserialize the condition to test persistence
-    serialized = BaseSerialization.serialize(condition)
-    deserialized = BaseSerialization.deserialize(serialized)
-    assert deserialized.evaluate(statuses) == expected, "Serialization 
round-trip failed"
-
-
[email protected](
-    "status_values, expected_evaluation",
-    [
-        (
-            (False, True, True),
-            False,
-        ),  # AssetAll requires all conditions to be True, but asset1 is False
-        ((True, True, True), True),  # All conditions are True
-        (
-            (True, False, True),
-            True,
-        ),  # asset1 is True, and AssetAny condition (asset2 or asset3 being 
True) is met
-        (
-            (True, False, False),
-            False,
-        ),  # asset1 is True, but neither asset2 nor asset3 meet the AssetAny 
condition
-    ],
-)
-def test_nested_asset_conditions_with_serialization(status_values, 
expected_evaluation):
-    # Define assets
-    asset1 = Asset(uri="s3://abc/123")
-    asset2 = Asset(uri="s3://abc/124")
-    asset3 = Asset(uri="s3://abc/125")
-
-    # Create a nested condition: AssetAll with asset1 and AssetAny with asset2 
and asset3
-    nested_condition = AssetAll(asset1, AssetAny(asset2, asset3))
-
-    statuses = {
-        AssetUniqueKey.from_asset(asset1): status_values[0],
-        AssetUniqueKey.from_asset(asset2): status_values[1],
-        AssetUniqueKey.from_asset(asset3): status_values[2],
-    }
-
-    assert nested_condition.evaluate(statuses) == expected_evaluation, 
"Initial evaluation mismatch"
-
-    serialized_condition = BaseSerialization.serialize(nested_condition)
-    deserialized_condition = 
BaseSerialization.deserialize(serialized_condition)
-
-    assert (
-        deserialized_condition.evaluate(statuses) == expected_evaluation
-    ), "Post-serialization evaluation mismatch"
-
-
 @pytest.fixture
 def create_test_assets():
     """Fixture to create test assets and corresponding models."""
@@ -500,38 +378,10 @@ def 
test_normalize_uri_valid_uri(mock_get_normalized_scheme):
 
 
 class TestAssetAlias:
-    @pytest.fixture
-    def asset(self):
-        """Example asset links to asset alias resolved_asset_alias_2."""
-        return Asset(uri="test://asset1/", name="test_name", group="asset")
-
-    @pytest.fixture
-    def asset_alias_1(self):
-        """Example asset alias links to no assets."""
-        asset_alias_1 = AssetAlias(name="test_name", group="test")
-        with mock.patch.object(asset_alias_1, "_resolve_assets", 
return_value=[]):
-            yield asset_alias_1
-
-    @pytest.fixture
-    def resolved_asset_alias_2(self, asset):
-        """Example asset alias links to asset."""
-        asset_alias_2 = AssetAlias(name="test_name_2")
-        with mock.patch.object(asset_alias_2, "_resolve_assets", 
return_value=[asset]):
-            yield asset_alias_2
-
-    @pytest.mark.parametrize("alias_fixture_name", ["asset_alias_1", 
"resolved_asset_alias_2"])
-    def test_as_expression(self, request: pytest.FixtureRequest, 
alias_fixture_name):
-        alias = request.getfixturevalue(alias_fixture_name)
+    def test_as_expression(self):
+        alias = AssetAlias(name="test_name", group="test")
         assert alias.as_expression() == {"alias": {"name": alias.name, 
"group": alias.group}}
 
-    def test_evalute_empty(self, asset_alias_1, asset):
-        assert asset_alias_1.evaluate({AssetUniqueKey.from_asset(asset): 
True}) is False
-        assert asset_alias_1._resolve_assets.mock_calls == [mock.call(None)]
-
-    def test_evalute_resolved(self, resolved_asset_alias_2, asset):
-        assert 
resolved_asset_alias_2.evaluate({AssetUniqueKey.from_asset(asset): True}) is 
True
-        assert resolved_asset_alias_2._resolve_assets.mock_calls == 
[mock.call(None)]
-
 
 class TestAssetSubclasses:
     @pytest.mark.parametrize("subcls, group", ((Model, "model"), (Dataset, 
"dataset")))
diff --git a/tests/assets/test_evaluation.py b/tests/assets/test_evaluation.py
new file mode 100644
index 00000000000..1c8e909eee1
--- /dev/null
+++ b/tests/assets/test_evaluation.py
@@ -0,0 +1,199 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+import pytest
+
+from airflow.assets.evaluation import AssetEvaluator
+from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetAll, 
AssetAny, AssetUniqueKey
+from airflow.serialization.serialized_objects import BaseSerialization
+
+pytestmark = pytest.mark.db_test
+
+asset1 = Asset(uri="s3://bucket1/data1", name="asset-1")
+asset2 = Asset(uri="s3://bucket2/data2", name="asset-2")
+
+
[email protected]
+def evaluator(session):
+    return AssetEvaluator(session)
+
+
[email protected](
+    "statuses, result",
+    [
+        ({AssetUniqueKey.from_asset(asset1): True}, True),
+        ({AssetUniqueKey.from_asset(asset1): False}, False),
+        ({}, False),
+    ],
+)
+def test_asset_evaluate(evaluator, statuses, result):
+    assert evaluator.run(asset1, statuses) is result
+
+
[email protected](
+    "condition, statuses, result",
+    [
+        (
+            AssetAny(asset1, asset2),
+            {AssetUniqueKey.from_asset(asset1): False, 
AssetUniqueKey.from_asset(asset2): True},
+            True,
+        ),
+        (
+            AssetAll(asset1, asset2),
+            {AssetUniqueKey.from_asset(asset1): True, 
AssetUniqueKey.from_asset(asset2): False},
+            False,
+        ),
+    ],
+)
+def test_assset_boolean_condition_evaluate_iter(evaluator, condition, 
statuses, result):
+    """
+    Tests _AssetBooleanCondition's evaluate and iter_assets methods through 
AssetAny and AssetAll.
+
+    Ensures AssetAny evaluate returns True with any true condition, AssetAll 
evaluate returns False if
+    any condition is false, and both classes correctly iterate over assets 
without duplication.
+    """
+    assert evaluator.run(condition, statuses) is result
+    assert dict(condition.iter_assets()) == {
+        AssetUniqueKey("asset-1", "s3://bucket1/data1"): asset1,
+        AssetUniqueKey("asset-2", "s3://bucket2/data2"): asset2,
+    }
+
+
[email protected](
+    "inputs, scenario, expected",
+    [
+        # Scenarios for AssetAny
+        ((True, True, True), "any", True),
+        ((True, True, False), "any", True),
+        ((True, False, True), "any", True),
+        ((True, False, False), "any", True),
+        ((False, False, True), "any", True),
+        ((False, True, False), "any", True),
+        ((False, True, True), "any", True),
+        ((False, False, False), "any", False),
+        # Scenarios for AssetAll
+        ((True, True, True), "all", True),
+        ((True, True, False), "all", False),
+        ((True, False, True), "all", False),
+        ((True, False, False), "all", False),
+        ((False, False, True), "all", False),
+        ((False, True, False), "all", False),
+        ((False, True, True), "all", False),
+        ((False, False, False), "all", False),
+    ],
+)
+def test_asset_logical_conditions_evaluation_and_serialization(evaluator, 
inputs, scenario, expected):
+    class_ = AssetAny if scenario == "any" else AssetAll
+    assets = [Asset(uri=f"s3://abc/{i}", name=f"asset_{i}") for i in 
range(123, 126)]
+    condition = class_(*assets)
+
+    statuses = {AssetUniqueKey.from_asset(asset): status for asset, status in 
zip(assets, inputs)}
+    assert (
+        evaluator.run(condition, statuses) == expected
+    ), f"Condition evaluation failed for inputs {inputs} and scenario 
'{scenario}'"
+
+    # Serialize and deserialize the condition to test persistence
+    serialized = BaseSerialization.serialize(condition)
+    deserialized = BaseSerialization.deserialize(serialized)
+    assert evaluator.run(deserialized, statuses) == expected, "Serialization 
round-trip failed"
+
+
[email protected](
+    "status_values, expected_evaluation",
+    [
+        pytest.param(
+            (False, True, True),
+            False,
+            id="f & (t | t)",
+        ),  # AssetAll requires all conditions to be True, but asset1 is False
+        pytest.param(
+            (True, True, True),
+            True,
+            id="t & (t | t)",
+        ),  # All conditions are True
+        pytest.param(
+            (True, False, True),
+            True,
+            id="t & (f | t)",
+        ),  # asset1 is True, and AssetAny condition (asset2 or asset3 being 
True) is met
+        pytest.param(
+            (True, False, False),
+            False,
+            id="t & (f | f)",
+        ),  # asset1 is True, but neither asset2 nor asset3 meet the AssetAny 
condition
+    ],
+)
+def test_nested_asset_conditions_with_serialization(evaluator, status_values, 
expected_evaluation):
+    # Define assets
+    asset1 = Asset(uri="s3://abc/123")
+    asset2 = Asset(uri="s3://abc/124")
+    asset3 = Asset(uri="s3://abc/125")
+
+    # Create a nested condition: AssetAll with asset1 and AssetAny with asset2 
and asset3
+    nested_condition = AssetAll(asset1, AssetAny(asset2, asset3))
+
+    statuses = {
+        AssetUniqueKey.from_asset(asset1): status_values[0],
+        AssetUniqueKey.from_asset(asset2): status_values[1],
+        AssetUniqueKey.from_asset(asset3): status_values[2],
+    }
+
+    assert evaluator.run(nested_condition, statuses) == expected_evaluation, 
"Initial evaluation mismatch"
+
+    serialized_condition = BaseSerialization.serialize(nested_condition)
+    deserialized_condition = 
BaseSerialization.deserialize(serialized_condition)
+
+    assert (
+        evaluator.run(deserialized_condition, statuses) == expected_evaluation
+    ), "Post-serialization evaluation mismatch"
+
+
+class TestAssetAlias:
+    @pytest.fixture
+    def asset(self):
+        """Example asset links to asset alias resolved_asset_alias_2."""
+        return Asset(uri="test://asset1/", name="test_name", group="asset")
+
+    @pytest.fixture
+    def asset_alias_1(self):
+        """Example asset alias links to no assets."""
+        return AssetAlias(name="test_name", group="test")
+
+    @pytest.fixture
+    def resolved_asset_alias_2(self):
+        """Example asset alias links to asset."""
+        return AssetAlias(name="test_name_2")
+
+    @pytest.fixture
+    def evaluator(self, session, asset_alias_1, resolved_asset_alias_2, asset):
+        class _AssetEvaluator(AssetEvaluator):  # Can't use mock because 
AssetEvaluator sets __slots__.
+            def _resolve_asset_alias(self, o):
+                if o is asset_alias_1:
+                    return []
+                elif o is resolved_asset_alias_2:
+                    return [asset]
+                return super()._resolve_asset_alias(o)
+
+        return _AssetEvaluator(session)
+
+    def test_evaluate_empty(self, evaluator, asset_alias_1, asset):
+        assert evaluator.run(asset_alias_1, {AssetUniqueKey.from_asset(asset): 
True}) is False
+
+    def test_evalute_resolved(self, evaluator, resolved_asset_alias_2, asset):
+        assert evaluator.run(resolved_asset_alias_2, 
{AssetUniqueKey.from_asset(asset): True}) is True
diff --git a/tests/models/test_asset.py b/tests/models/test_asset.py
index 1e21252c502..5608d4c8d59 100644
--- a/tests/models/test_asset.py
+++ b/tests/models/test_asset.py
@@ -72,7 +72,7 @@ class TestAssetAliasModel:
         return asset_alias_2
 
     def test_expand_alias_to_assets_empty(self, session, asset_alias_1):
-        assert expand_alias_to_assets(asset_alias_1.name, session) == []
+        assert list(expand_alias_to_assets(asset_alias_1.name, 
session=session)) == []
 
     def test_expand_alias_to_assets_resolved(self, session, 
resolved_asset_alias_2, asset_model):
-        assert expand_alias_to_assets(resolved_asset_alias_2.name, session) == 
[asset_model]
+        assert list(expand_alias_to_assets(resolved_asset_alias_2.name, 
session=session)) == [asset_model]
diff --git a/tests/timetables/test_assets_timetable.py 
b/tests/timetables/test_assets_timetable.py
index 9892b5805bd..d8386bc27f5 100644
--- a/tests/timetables/test_assets_timetable.py
+++ b/tests/timetables/test_assets_timetable.py
@@ -273,8 +273,11 @@ class TestAssetConditionWithTimetable:
         return [Asset(uri=f"test://asset{i}", name=f"hello{i}") for i in 
range(1, 3)]
 
     def test_asset_dag_run_queue_processing(self, session, dag_maker, 
create_test_assets):
+        from airflow.assets.evaluation import AssetEvaluator
+
         assets = create_test_assets
         asset_models = session.query(AssetModel).all()
+        evaluator = AssetEvaluator(session)
 
         with dag_maker(schedule=AssetAny(*assets)) as dag:
             EmptyOperator(task_id="hello")
@@ -298,7 +301,7 @@ class TestAssetConditionWithTimetable:
             dag = SerializedDAG.deserialize(serialized_dag.data)
             for asset_uri, status in dag_statuses[dag.dag_id].items():
                 cond = dag.timetable.asset_condition
-                assert cond.evaluate({asset_uri: status}), "DAG trigger 
evaluation failed"
+                assert evaluator.run(cond, {asset_uri: status}), "DAG trigger 
evaluation failed"
 
     def test_dag_with_complex_asset_condition(self, session, dag_maker):
         # Create Asset instances

Reply via email to