This is an automated email from the ASF dual-hosted git repository.

uranusjr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 3220d376d9d Rewrite db usage in asset decorator operator (#47896)
3220d376d9d is described below

commit 3220d376d9d6c7b83febd201d2d6e97fd45c9a77
Author: Tzu-ping Chung <[email protected]>
AuthorDate: Wed Mar 19 14:34:22 2025 +0800

    Rewrite db usage in asset decorator operator (#47896)
    
    This leaves much room for optimization. We really should use a cache
    across argument resolution, outlet_events, and inlet_events. Refactoring
    can be done later though; this works for now.
---
 .../airflow/sdk/definitions/asset/decorators.py    | 38 ++++++-------
 .../task_sdk/definitions/test_asset_decorators.py  | 64 ++++++++++------------
 2 files changed, 47 insertions(+), 55 deletions(-)

diff --git a/task-sdk/src/airflow/sdk/definitions/asset/decorators.py b/task-sdk/src/airflow/sdk/definitions/asset/decorators.py
index dea0cf1b95b..e875dabb5af 100644
--- a/task-sdk/src/airflow/sdk/definitions/asset/decorators.py
+++ b/task-sdk/src/airflow/sdk/definitions/asset/decorators.py
@@ -23,7 +23,8 @@ from typing import TYPE_CHECKING, Any
 import attrs
 
 from airflow.providers.standard.operators.python import PythonOperator
-from airflow.sdk.definitions.asset import Asset, AssetNameRef, AssetRef, BaseAsset
+from airflow.sdk.definitions.asset import Asset, AssetRef, BaseAsset
+from airflow.sdk.exceptions import AirflowRuntimeError
 
 if TYPE_CHECKING:
     from collections.abc import Callable, Collection, Iterator, Mapping
@@ -56,35 +57,34 @@ class _AssetMainOperator(PythonOperator):
             definition_name=definition._function.__name__,
         )
 
-    def _iter_kwargs(
-        self, context: Mapping[str, Any], active_assets: dict[str, Asset]
-    ) -> Iterator[tuple[str, Any]]:
+    def _iter_kwargs(self, context: Mapping[str, Any]) -> Iterator[tuple[str, Any]]:
+        import structlog
+
+        from airflow.sdk.execution_time.comms import ErrorResponse, GetAssetByName
+        from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS
+
+        log = structlog.get_logger(logger_name=self.__class__.__qualname__)
+
+        def _fetch_asset(name: str) -> Asset:
+            SUPERVISOR_COMMS.send_request(log, GetAssetByName(name=name))
+            if isinstance(msg := SUPERVISOR_COMMS.get_message(), ErrorResponse):
+                raise AirflowRuntimeError(msg)
+            return Asset(**msg.model_dump(exclude={"type"}))
+
         value: Any
         for key, param in inspect.signature(self.python_callable).parameters.items():
             if param.default is not inspect.Parameter.empty:
                 value = param.default
             elif key == "self":
-                value = active_assets.get(self._definition_name)
+                value = _fetch_asset(self._definition_name)
             elif key == "context":
                 value = context
             else:
-                value = active_assets.get(key, Asset(name=key))
+                value = _fetch_asset(key)
             yield key, value
 
     def determine_kwargs(self, context: Mapping[str, Any]) -> Mapping[str, Any]:
-        from airflow.models.asset import fetch_active_assets_by_name
-        from airflow.utils.session import create_session
-
-        asset_names = {asset_ref.name for asset_ref in self.inlets if isinstance(asset_ref, AssetNameRef)}
-        if "self" in inspect.signature(self.python_callable).parameters:
-            asset_names.add(self._definition_name)
-
-        if asset_names:
-            with create_session() as session:
-                active_assets = fetch_active_assets_by_name(asset_names, session)
-        else:
-            active_assets = {}
-        return dict(self._iter_kwargs(context, active_assets))
+        return dict(self._iter_kwargs(context))
 
 
 @attrs.define(kw_only=True)
diff --git a/task-sdk/tests/task_sdk/definitions/test_asset_decorators.py b/task-sdk/tests/task_sdk/definitions/test_asset_decorators.py
index afa94c0a7e3..7201fb130d0 100644
--- a/task-sdk/tests/task_sdk/definitions/test_asset_decorators.py
+++ b/task-sdk/tests/task_sdk/definitions/test_asset_decorators.py
@@ -20,9 +20,9 @@ from unittest import mock
 
 import pytest
 
-from airflow.models.asset import AssetModel
 from airflow.sdk.definitions.asset import Asset
 from airflow.sdk.definitions.asset.decorators import _AssetMainOperator, asset
+from airflow.sdk.execution_time.comms import AssetResult, GetAssetByName
 
 
 @pytest.fixture
@@ -246,30 +246,26 @@ class Test_AssetMainOperator:
         assert op.python_callable == example_asset_func_with_valid_arg_as_inlet_asset
         assert op._definition_name == "example_asset_func"
 
-    @mock.patch("airflow.models.asset.fetch_active_assets_by_name")
-    @mock.patch("airflow.utils.session.create_session")
+    @mock.patch("airflow.sdk.execution_time.task_runner.SUPERVISOR_COMMS", create=True)
     def test_determine_kwargs(
         self,
-        mock_create_session,
-        mock_fetch_active_assets_by_name,
+        mock_supervisor_comms,
         example_asset_func_with_valid_arg_as_inlet_asset,
     ):
         asset_definition = asset(schedule=None, uri="s3://bucket/object", group="MLModel", extra={"k": "v"})(
             example_asset_func_with_valid_arg_as_inlet_asset
         )
 
-        class FakeSession:
-            def __enter__(self):
-                return self
-
-            def __exit__(self, *args, **kwargs):
-                pass
-
-        mock_create_session.return_value = fake_session = FakeSession()
-        mock_fetch_active_assets_by_name.return_value = {
-            "example_asset_func": AssetModel.from_public(asset_definition),
-            "inlet_asset_1": AssetModel(uri="s3://bucket/object1", name="inlet_asset_1"),
-        }
+        mock_supervisor_comms.get_message.side_effect = [
+            AssetResult(
+                name="example_asset_func",
+                uri="s3://bucket/object",
+                group="MLModel",
+                extra={"k": "v"},
+            ),
+            AssetResult(name="inlet_asset_1", uri="s3://bucket/object1", group="asset", extra=None),
+            AssetResult(name="inlet_asset_2", uri="inlet_asset_2", group="asset", extra=None),
+        ]
 
         op = _AssetMainOperator(
             task_id="example_asset_func",
@@ -290,31 +286,26 @@ class Test_AssetMainOperator:
             "inlet_asset_2": Asset(name="inlet_asset_2"),
         }
 
-        assert mock_fetch_active_assets_by_name.mock_calls == [
-            mock.call({"example_asset_func", "inlet_asset_1", "inlet_asset_2"}, fake_session),
+        assert mock_supervisor_comms.mock_calls == [
+            mock.call.send_request(mock.ANY, GetAssetByName(name="example_asset_func")),
+            mock.call.get_message(),
+            mock.call.send_request(mock.ANY, GetAssetByName(name="inlet_asset_1")),
+            mock.call.get_message(),
+            mock.call.send_request(mock.ANY, GetAssetByName(name="inlet_asset_2")),
+            mock.call.get_message(),
         ]
 
-    @mock.patch("airflow.models.asset.fetch_active_assets_by_name")
-    @mock.patch("airflow.utils.session.create_session")
+    @mock.patch("airflow.sdk.execution_time.task_runner.SUPERVISOR_COMMS", create=True)
     def test_determine_kwargs_defaults(
         self,
-        mock_create_session,
-        mock_fetch_active_assets_by_name,
+        mock_supervisor_comms,
         example_asset_func_with_valid_arg_as_inlet_asset_and_default,
     ):
         asset_definition = asset(schedule=None)(example_asset_func_with_valid_arg_as_inlet_asset_and_default)
 
-        class FakeSession:
-            def __enter__(self):
-                return self
-
-            def __exit__(self, *args, **kwargs):
-                pass
-
-        mock_create_session.return_value = fake_session = FakeSession()
-        mock_fetch_active_assets_by_name.return_value = {
-            "inlet_asset_1": AssetModel(uri="s3://bucket/object1", name="inlet_asset_1"),
-        }
+        mock_supervisor_comms.get_message.side_effect = [
+            AssetResult(name="inlet_asset_1", uri="s3://bucket/object1", group="asset", extra=None),
+        ]
 
         op = _AssetMainOperator(
             task_id="__main__",
@@ -329,6 +320,7 @@ class Test_AssetMainOperator:
             "unknown_name": "default supplied for non-asset argument",
         }
 
-        assert mock_fetch_active_assets_by_name.mock_calls == [
-            mock.call({"inlet_asset_1"}, fake_session),
+        assert mock_supervisor_comms.mock_calls == [
+            mock.call.send_request(mock.ANY, GetAssetByName(name="inlet_asset_1")),
+            mock.call.get_message(),
         ]

Reply via email to