This is an automated email from the ASF dual-hosted git repository.
uranusjr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 3220d376d9d Rewrite db usage in asset decorator operator (#47896)
3220d376d9d is described below
commit 3220d376d9d6c7b83febd201d2d6e97fd45c9a77
Author: Tzu-ping Chung <[email protected]>
AuthorDate: Wed Mar 19 14:34:22 2025 +0800
Rewrite db usage in asset decorator operator (#47896)
This leaves much room for optimization: each asset argument now costs a
separate GetAssetByName round trip to the supervisor. We really should use
a cache shared across argument resolution, outlet_events, and inlet_events.
That refactoring can be done later, though; this works for now.
---
.../airflow/sdk/definitions/asset/decorators.py | 38 ++++++-------
.../task_sdk/definitions/test_asset_decorators.py | 64 ++++++++++------------
2 files changed, 47 insertions(+), 55 deletions(-)
diff --git a/task-sdk/src/airflow/sdk/definitions/asset/decorators.py b/task-sdk/src/airflow/sdk/definitions/asset/decorators.py
index dea0cf1b95b..e875dabb5af 100644
--- a/task-sdk/src/airflow/sdk/definitions/asset/decorators.py
+++ b/task-sdk/src/airflow/sdk/definitions/asset/decorators.py
@@ -23,7 +23,8 @@ from typing import TYPE_CHECKING, Any
import attrs
from airflow.providers.standard.operators.python import PythonOperator
-from airflow.sdk.definitions.asset import Asset, AssetNameRef, AssetRef,
BaseAsset
+from airflow.sdk.definitions.asset import Asset, AssetRef, BaseAsset
+from airflow.sdk.exceptions import AirflowRuntimeError
if TYPE_CHECKING:
from collections.abc import Callable, Collection, Iterator, Mapping
@@ -56,35 +57,34 @@ class _AssetMainOperator(PythonOperator):
definition_name=definition._function.__name__,
)
- def _iter_kwargs(
- self, context: Mapping[str, Any], active_assets: dict[str, Asset]
- ) -> Iterator[tuple[str, Any]]:
+ def _iter_kwargs(self, context: Mapping[str, Any]) -> Iterator[tuple[str,
Any]]:
+ import structlog
+
+ from airflow.sdk.execution_time.comms import ErrorResponse,
GetAssetByName
+ from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS
+
+ log = structlog.get_logger(logger_name=self.__class__.__qualname__)
+
+ def _fetch_asset(name: str) -> Asset:
+ SUPERVISOR_COMMS.send_request(log, GetAssetByName(name=name))
+ if isinstance(msg := SUPERVISOR_COMMS.get_message(),
ErrorResponse):
+ raise AirflowRuntimeError(msg)
+ return Asset(**msg.model_dump(exclude={"type"}))
+
value: Any
for key, param in
inspect.signature(self.python_callable).parameters.items():
if param.default is not inspect.Parameter.empty:
value = param.default
elif key == "self":
- value = active_assets.get(self._definition_name)
+ value = _fetch_asset(self._definition_name)
elif key == "context":
value = context
else:
- value = active_assets.get(key, Asset(name=key))
+ value = _fetch_asset(key)
yield key, value
def determine_kwargs(self, context: Mapping[str, Any]) -> Mapping[str,
Any]:
- from airflow.models.asset import fetch_active_assets_by_name
- from airflow.utils.session import create_session
-
- asset_names = {asset_ref.name for asset_ref in self.inlets if
isinstance(asset_ref, AssetNameRef)}
- if "self" in inspect.signature(self.python_callable).parameters:
- asset_names.add(self._definition_name)
-
- if asset_names:
- with create_session() as session:
- active_assets = fetch_active_assets_by_name(asset_names,
session)
- else:
- active_assets = {}
- return dict(self._iter_kwargs(context, active_assets))
+ return dict(self._iter_kwargs(context))
@attrs.define(kw_only=True)
diff --git a/task-sdk/tests/task_sdk/definitions/test_asset_decorators.py b/task-sdk/tests/task_sdk/definitions/test_asset_decorators.py
index afa94c0a7e3..7201fb130d0 100644
--- a/task-sdk/tests/task_sdk/definitions/test_asset_decorators.py
+++ b/task-sdk/tests/task_sdk/definitions/test_asset_decorators.py
@@ -20,9 +20,9 @@ from unittest import mock
import pytest
-from airflow.models.asset import AssetModel
from airflow.sdk.definitions.asset import Asset
from airflow.sdk.definitions.asset.decorators import _AssetMainOperator, asset
+from airflow.sdk.execution_time.comms import AssetResult, GetAssetByName
@pytest.fixture
@@ -246,30 +246,26 @@ class Test_AssetMainOperator:
assert op.python_callable ==
example_asset_func_with_valid_arg_as_inlet_asset
assert op._definition_name == "example_asset_func"
- @mock.patch("airflow.models.asset.fetch_active_assets_by_name")
- @mock.patch("airflow.utils.session.create_session")
+ @mock.patch("airflow.sdk.execution_time.task_runner.SUPERVISOR_COMMS",
create=True)
def test_determine_kwargs(
self,
- mock_create_session,
- mock_fetch_active_assets_by_name,
+ mock_supervisor_comms,
example_asset_func_with_valid_arg_as_inlet_asset,
):
asset_definition = asset(schedule=None, uri="s3://bucket/object",
group="MLModel", extra={"k": "v"})(
example_asset_func_with_valid_arg_as_inlet_asset
)
- class FakeSession:
- def __enter__(self):
- return self
-
- def __exit__(self, *args, **kwargs):
- pass
-
- mock_create_session.return_value = fake_session = FakeSession()
- mock_fetch_active_assets_by_name.return_value = {
- "example_asset_func": AssetModel.from_public(asset_definition),
- "inlet_asset_1": AssetModel(uri="s3://bucket/object1",
name="inlet_asset_1"),
- }
+ mock_supervisor_comms.get_message.side_effect = [
+ AssetResult(
+ name="example_asset_func",
+ uri="s3://bucket/object",
+ group="MLModel",
+ extra={"k": "v"},
+ ),
+ AssetResult(name="inlet_asset_1", uri="s3://bucket/object1",
group="asset", extra=None),
+ AssetResult(name="inlet_asset_2", uri="inlet_asset_2",
group="asset", extra=None),
+ ]
op = _AssetMainOperator(
task_id="example_asset_func",
@@ -290,31 +286,26 @@ class Test_AssetMainOperator:
"inlet_asset_2": Asset(name="inlet_asset_2"),
}
- assert mock_fetch_active_assets_by_name.mock_calls == [
- mock.call({"example_asset_func", "inlet_asset_1",
"inlet_asset_2"}, fake_session),
+ assert mock_supervisor_comms.mock_calls == [
+ mock.call.send_request(mock.ANY,
GetAssetByName(name="example_asset_func")),
+ mock.call.get_message(),
+ mock.call.send_request(mock.ANY,
GetAssetByName(name="inlet_asset_1")),
+ mock.call.get_message(),
+ mock.call.send_request(mock.ANY,
GetAssetByName(name="inlet_asset_2")),
+ mock.call.get_message(),
]
- @mock.patch("airflow.models.asset.fetch_active_assets_by_name")
- @mock.patch("airflow.utils.session.create_session")
+ @mock.patch("airflow.sdk.execution_time.task_runner.SUPERVISOR_COMMS",
create=True)
def test_determine_kwargs_defaults(
self,
- mock_create_session,
- mock_fetch_active_assets_by_name,
+ mock_supervisor_comms,
example_asset_func_with_valid_arg_as_inlet_asset_and_default,
):
asset_definition =
asset(schedule=None)(example_asset_func_with_valid_arg_as_inlet_asset_and_default)
- class FakeSession:
- def __enter__(self):
- return self
-
- def __exit__(self, *args, **kwargs):
- pass
-
- mock_create_session.return_value = fake_session = FakeSession()
- mock_fetch_active_assets_by_name.return_value = {
- "inlet_asset_1": AssetModel(uri="s3://bucket/object1",
name="inlet_asset_1"),
- }
+ mock_supervisor_comms.get_message.side_effect = [
+ AssetResult(name="inlet_asset_1", uri="s3://bucket/object1",
group="asset", extra=None),
+ ]
op = _AssetMainOperator(
task_id="__main__",
@@ -329,6 +320,7 @@ class Test_AssetMainOperator:
"unknown_name": "default supplied for non-asset argument",
}
- assert mock_fetch_active_assets_by_name.mock_calls == [
- mock.call({"inlet_asset_1"}, fake_session),
+ assert mock_supervisor_comms.mock_calls == [
+ mock.call.send_request(mock.ANY,
GetAssetByName(name="inlet_asset_1")),
+ mock.call.get_message(),
]