This is an automated email from the ASF dual-hosted git repository.
mobuchowski pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 5942c98ea7d limit amount of assets collected by hook lineage collector (#45798)
5942c98ea7d is described below
commit 5942c98ea7d2aa4bc34b999ff863a5ba55cc05db
Author: Maciej Obuchowski <[email protected]>
AuthorDate: Tue Jan 21 13:45:34 2025 +0100
limit amount of assets collected by hook lineage collector (#45798)
Signed-off-by: Maciej Obuchowski <[email protected]>
---
airflow/lineage/hook.py | 20 +++++++++++--
tests/lineage/test_hook.py | 75 ++++++++++++++++++++++++++--------------------
2 files changed, 60 insertions(+), 35 deletions(-)
diff --git a/airflow/lineage/hook.py b/airflow/lineage/hook.py
index 62a2c7a5493..796deb8466b 100644
--- a/airflow/lineage/hook.py
+++ b/airflow/lineage/hook.py
@@ -38,6 +38,11 @@ if TYPE_CHECKING:
_hook_lineage_collector: HookLineageCollector | None = None
+# Cap on distinct input entries (and, separately, output entries) kept by one HookLineageCollector.
+# Once reached, further add_input_asset / add_output_asset calls are skipped, including repeats.
+MAX_COLLECTED_ASSETS = 100
+
+
@attr.define
class AssetLineageInfo:
"""
@@ -81,6 +86,7 @@ class HookLineageCollector(LoggingMixin):
self._outputs: dict[str, tuple[Asset, LineageContext]] = {}
self._input_counts: dict[str, int] = defaultdict(int)
self._output_counts: dict[str, int] = defaultdict(int)
+ self._asset_factories = ProvidersManager().asset_factories
def _generate_key(self, asset: Asset, context: LineageContext) -> str:
"""
@@ -136,7 +142,7 @@ class HookLineageCollector(LoggingMixin):
)
return None
- asset_factory = ProvidersManager().asset_factories.get(scheme)
+ asset_factory = self._asset_factories.get(scheme)
if not asset_factory:
self.log.debug("Unsupported scheme: %s. Please provide a valid URI to create an asset.", scheme)
return None
@@ -159,6 +165,9 @@ class HookLineageCollector(LoggingMixin):
asset_extra: dict | None = None,
):
"""Add the input asset and its corresponding hook execution context to the collector."""
+ if len(self._inputs) >= MAX_COLLECTED_ASSETS:
+ self.log.debug("Maximum number of asset inputs exceeded. Skipping.")
+ return
asset = self.create_asset(
scheme=scheme, uri=uri, name=name, group=group, asset_kwargs=asset_kwargs, asset_extra=asset_extra
)
@@ -167,6 +176,8 @@ class HookLineageCollector(LoggingMixin):
if key not in self._inputs:
self._inputs[key] = (asset, context)
self._input_counts[key] += 1
+ if len(self._inputs) == MAX_COLLECTED_ASSETS:
+ self.log.warning("Maximum number of asset inputs reached. Skipping subsequent inputs.")
def add_output_asset(
self,
@@ -179,6 +190,9 @@ class HookLineageCollector(LoggingMixin):
asset_extra: dict | None = None,
):
"""Add the output asset and its corresponding hook execution context to the collector."""
+ if len(self._outputs) >= MAX_COLLECTED_ASSETS:
+ self.log.debug("Maximum number of asset outputs exceeded. Skipping.")
+ return
asset = self.create_asset(
scheme=scheme, uri=uri, name=name, group=group, asset_kwargs=asset_kwargs, asset_extra=asset_extra
)
@@ -187,6 +201,8 @@ class HookLineageCollector(LoggingMixin):
if key not in self._outputs:
self._outputs[key] = (asset, context)
self._output_counts[key] += 1
+ if len(self._outputs) == MAX_COLLECTED_ASSETS:
+ self.log.warning("Maximum number of asset outputs reached. Skipping subsequent outputs.")
@property
def collected_assets(self) -> HookLineage:
@@ -225,7 +241,7 @@ class NoOpCollector(HookLineageCollector):
def collected_assets(
self,
) -> HookLineage:
- self.log.warning(
+ self.log.debug(
"Data lineage tracking is disabled. Register a hook lineage reader to start tracking hook lineage."
)
return HookLineage([], [])
diff --git a/tests/lineage/test_hook.py b/tests/lineage/test_hook.py
index f66f6c2bf9f..4586e59ea76 100644
--- a/tests/lineage/test_hook.py
+++ b/tests/lineage/test_hook.py
@@ -38,20 +38,21 @@ from tests_common.test_utils.mock_plugins import mock_plugin_manager
class TestHookLineageCollector:
- def setup_method(self):
- self.collector = HookLineageCollector()
+ @pytest.fixture
+ def collector(self):
+ return HookLineageCollector()
- def test_are_assets_collected(self):
- assert self.collector is not None
- assert self.collector.collected_assets == HookLineage()
+ def test_are_assets_collected(self, collector):
+ assert collector is not None
+ assert collector.collected_assets == HookLineage()
input_hook = BaseHook()
output_hook = BaseHook()
- self.collector.add_input_asset(input_hook, uri="s3://in_bucket/file", name="asset-1", group="test")
- self.collector.add_output_asset(
+ collector.add_input_asset(input_hook, uri="s3://in_bucket/file", name="asset-1", group="test")
+ collector.add_output_asset(
output_hook,
uri="postgres://example.com:5432/database/default/table",
)
- assert self.collector.collected_assets == HookLineage(
+ assert collector.collected_assets == HookLineage(
[
AssetLineageInfo(
asset=Asset(uri="s3://in_bucket/file", name="asset-1", group="test"),
@@ -73,27 +74,27 @@ class TestHookLineageCollector:
)
@patch("airflow.lineage.hook.Asset")
- def test_add_input_asset(self, mock_asset):
+ def test_add_input_asset(self, mock_asset, collector):
asset = MagicMock(spec=Asset, extra={})
mock_asset.return_value = asset
hook = MagicMock()
- self.collector.add_input_asset(hook, uri="test_uri")
+ collector.add_input_asset(hook, uri="test_uri")
- assert next(iter(self.collector._inputs.values())) == (asset, hook)
+ assert next(iter(collector._inputs.values())) == (asset, hook)
mock_asset.assert_called_once_with(uri="test_uri")
- def test_grouping_assets(self):
+ def test_grouping_assets(self, collector):
hook_1 = MagicMock()
hook_2 = MagicMock()
uri = "test://uri/"
- self.collector.add_input_asset(context=hook_1, uri=uri)
- self.collector.add_input_asset(context=hook_2, uri=uri)
- self.collector.add_input_asset(context=hook_1, uri=uri, asset_extra={"key": "value"})
+ collector.add_input_asset(context=hook_1, uri=uri)
+ collector.add_input_asset(context=hook_2, uri=uri)
+ collector.add_input_asset(context=hook_1, uri=uri, asset_extra={"key": "value"})
- collected_inputs = self.collector.collected_assets.inputs
+ collected_inputs = collector.collected_assets.inputs
assert len(collected_inputs) == 3
assert collected_inputs[0].asset.uri == "test://uri/"
@@ -105,15 +106,14 @@ class TestHookLineageCollector:
assert collected_inputs[2].count == 1
assert collected_inputs[2].asset.extra == {"key": "value"}
- @patch("airflow.lineage.hook.ProvidersManager")
- def test_create_asset(self, mock_providers_manager):
+ def test_create_asset(self, collector):
def create_asset(arg1, arg2="default", extra=None):
return Asset(
uri=f"myscheme://{arg1}/{arg2}", name=f"asset-{arg1}", group="test", extra=extra or {}
)
- mock_providers_manager.return_value.asset_factories = {"myscheme": create_asset}
- assert self.collector.create_asset(
+ collector._asset_factories = {"myscheme": create_asset}
+ assert collector.create_asset(
scheme="myscheme",
uri=None,
name=None,
@@ -121,7 +121,7 @@ class TestHookLineageCollector:
asset_kwargs={"arg1": "value_1"},
asset_extra=None,
) == Asset(uri="myscheme://value_1/default", name="asset-value_1", group="test")
- assert self.collector.create_asset(
+ assert collector.create_asset(
scheme="myscheme",
uri=None,
name=None,
@@ -133,14 +133,13 @@ class TestHookLineageCollector:
)
- @patch("airflow.lineage.hook.ProvidersManager")
- def test_create_asset_no_factory(self, mock_providers_manager):
+ def test_create_asset_no_factory(self, collector):
test_scheme = "myscheme"
- mock_providers_manager.return_value.asset_factories = {}
+ collector._asset_factories = {}
test_kwargs = {"arg1": "value_1"}
assert (
- self.collector.create_asset(
+ collector.create_asset(
scheme=test_scheme,
uri=None,
name=None,
@@ -152,7 +151,7 @@ class TestHookLineageCollector:
)
@patch("airflow.lineage.hook.ProvidersManager")
- def test_create_asset_factory_exception(self, mock_providers_manager):
+ def test_create_asset_factory_exception(self, mock_providers_manager, collector):
def create_asset(extra=None, **kwargs):
raise RuntimeError("Factory error")
@@ -162,20 +161,20 @@ class TestHookLineageCollector:
test_kwargs = {"arg1": "value_1"}
+ # The collector fixture is built before @patch activates and caches the factories in __init__.
+ collector._asset_factories = mock_providers_manager.return_value.asset_factories
assert (
- self.collector.create_asset(
- scheme=test_scheme, uri=None, asset_kwargs=test_kwargs, asset_extra=None
- )
+ collector.create_asset(scheme=test_scheme, uri=None, asset_kwargs=test_kwargs, asset_extra=None)
is None
)
- def test_collected_assets(self):
+ def test_collected_assets(self, collector):
context_input = MagicMock()
context_output = MagicMock()
- self.collector.add_input_asset(context_input, uri="test://input")
- self.collector.add_output_asset(context_output, uri="test://output")
+ collector.add_input_asset(context_input, uri="test://input")
+ collector.add_output_asset(context_output, uri="test://output")
- hook_lineage = self.collector.collected_assets
+ hook_lineage = collector.collected_assets
assert len(hook_lineage.inputs) == 1
assert hook_lineage.inputs[0].asset.uri == "test://input/"
assert hook_lineage.inputs[0].context == context_input
@@ -183,13 +182,23 @@ class TestHookLineageCollector:
assert len(hook_lineage.outputs) == 1
assert hook_lineage.outputs[0].asset.uri == "test://output/"
- def test_has_collected(self):
- collector = HookLineageCollector()
+ def test_has_collected(self, collector):
assert not collector.has_collected
collector._inputs = {"unique_key": (MagicMock(spec=Asset), MagicMock())}
assert collector.has_collected
+ def test_hooks_limit_input_output_assets(self, collector):
+ assert not collector.has_collected
+
+ for i in range(1000):
+ collector.add_input_asset(MagicMock(), uri=f"test://input/{i}")
+ collector.add_output_asset(MagicMock(), uri=f"test://output/{i}")
+
+ assert collector.has_collected
+ assert len(collector._inputs) == 100
+ assert len(collector._outputs) == 100
+
class FakePlugin(plugins_manager.AirflowPlugin):
name = "FakePluginHavingHookLineageCollector"