This is an automated email from the ASF dual-hosted git repository.

mobuchowski pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 5942c98ea7d limit amount of assets collected by hook lineage collector (#45798)
5942c98ea7d is described below

commit 5942c98ea7d2aa4bc34b999ff863a5ba55cc05db
Author: Maciej Obuchowski <[email protected]>
AuthorDate: Tue Jan 21 13:45:34 2025 +0100

    limit amount of assets collected by hook lineage collector (#45798)
    
    Signed-off-by: Maciej Obuchowski <[email protected]>
---
 airflow/lineage/hook.py    | 20 +++++++++++--
 tests/lineage/test_hook.py | 75 ++++++++++++++++++++++++++--------------------
 2 files changed, 60 insertions(+), 35 deletions(-)

diff --git a/airflow/lineage/hook.py b/airflow/lineage/hook.py
index 62a2c7a5493..796deb8466b 100644
--- a/airflow/lineage/hook.py
+++ b/airflow/lineage/hook.py
@@ -38,6 +38,11 @@ if TYPE_CHECKING:
 _hook_lineage_collector: HookLineageCollector | None = None
 
 
+# Maximum number of assets input or output that can be collected in a single hook execution.
+# Input assets and output assets are collected separately.
+MAX_COLLECTED_ASSETS = 100
+
+
 @attr.define
 class AssetLineageInfo:
     """
@@ -81,6 +86,7 @@ class HookLineageCollector(LoggingMixin):
         self._outputs: dict[str, tuple[Asset, LineageContext]] = {}
         self._input_counts: dict[str, int] = defaultdict(int)
         self._output_counts: dict[str, int] = defaultdict(int)
+        self._asset_factories = ProvidersManager().asset_factories
 
     def _generate_key(self, asset: Asset, context: LineageContext) -> str:
         """
@@ -136,7 +142,7 @@ class HookLineageCollector(LoggingMixin):
             )
             return None
 
-        asset_factory = ProvidersManager().asset_factories.get(scheme)
+        asset_factory = self._asset_factories.get(scheme)
         if not asset_factory:
             self.log.debug("Unsupported scheme: %s. Please provide a valid URI 
to create an asset.", scheme)
             return None
@@ -159,6 +165,9 @@ class HookLineageCollector(LoggingMixin):
         asset_extra: dict | None = None,
     ):
         """Add the input asset and its corresponding hook execution context to 
the collector."""
+        if len(self._inputs) >= MAX_COLLECTED_ASSETS:
+            self.log.debug("Maximum number of asset inputs exceeded. 
Skipping.")
+            return
         asset = self.create_asset(
             scheme=scheme, uri=uri, name=name, group=group, 
asset_kwargs=asset_kwargs, asset_extra=asset_extra
         )
@@ -167,6 +176,8 @@ class HookLineageCollector(LoggingMixin):
             if key not in self._inputs:
                 self._inputs[key] = (asset, context)
             self._input_counts[key] += 1
+        if len(self._inputs) == MAX_COLLECTED_ASSETS:
+            self.log.warning("Maximum number of asset inputs exceeded. 
Skipping subsequent inputs.")
 
     def add_output_asset(
         self,
@@ -179,6 +190,9 @@ class HookLineageCollector(LoggingMixin):
         asset_extra: dict | None = None,
     ):
         """Add the output asset and its corresponding hook execution context 
to the collector."""
+        if len(self._outputs) >= MAX_COLLECTED_ASSETS:
+            self.log.debug("Maximum number of asset outputs exceeded. 
Skipping.")
+            return
         asset = self.create_asset(
             scheme=scheme, uri=uri, name=name, group=group, 
asset_kwargs=asset_kwargs, asset_extra=asset_extra
         )
@@ -187,6 +201,8 @@ class HookLineageCollector(LoggingMixin):
             if key not in self._outputs:
                 self._outputs[key] = (asset, context)
             self._output_counts[key] += 1
+        if len(self._outputs) == MAX_COLLECTED_ASSETS:
+            self.log.warning("Maximum number of asset outputs exceeded. 
Skipping subsequent inputs.")
 
     @property
     def collected_assets(self) -> HookLineage:
@@ -225,7 +241,7 @@ class NoOpCollector(HookLineageCollector):
     def collected_assets(
         self,
     ) -> HookLineage:
-        self.log.warning(
+        self.log.debug(
             "Data lineage tracking is disabled. Register a hook lineage reader 
to start tracking hook lineage."
         )
         return HookLineage([], [])
diff --git a/tests/lineage/test_hook.py b/tests/lineage/test_hook.py
index f66f6c2bf9f..4586e59ea76 100644
--- a/tests/lineage/test_hook.py
+++ b/tests/lineage/test_hook.py
@@ -38,20 +38,21 @@ from tests_common.test_utils.mock_plugins import mock_plugin_manager
 
 
 class TestHookLineageCollector:
-    def setup_method(self):
-        self.collector = HookLineageCollector()
+    @pytest.fixture
+    def collector(self, scope="method"):
+        return HookLineageCollector()
 
-    def test_are_assets_collected(self):
-        assert self.collector is not None
-        assert self.collector.collected_assets == HookLineage()
+    def test_are_assets_collected(self, collector):
+        assert collector is not None
+        assert collector.collected_assets == HookLineage()
         input_hook = BaseHook()
         output_hook = BaseHook()
-        self.collector.add_input_asset(input_hook, uri="s3://in_bucket/file", 
name="asset-1", group="test")
-        self.collector.add_output_asset(
+        collector.add_input_asset(input_hook, uri="s3://in_bucket/file", 
name="asset-1", group="test")
+        collector.add_output_asset(
             output_hook,
             uri="postgres://example.com:5432/database/default/table",
         )
-        assert self.collector.collected_assets == HookLineage(
+        assert collector.collected_assets == HookLineage(
             [
                 AssetLineageInfo(
                     asset=Asset(uri="s3://in_bucket/file", name="asset-1", 
group="test"),
@@ -73,27 +74,27 @@ class TestHookLineageCollector:
         )
 
     @patch("airflow.lineage.hook.Asset")
-    def test_add_input_asset(self, mock_asset):
+    def test_add_input_asset(self, mock_asset, collector):
         asset = MagicMock(spec=Asset, extra={})
         mock_asset.return_value = asset
 
         hook = MagicMock()
-        self.collector.add_input_asset(hook, uri="test_uri")
+        collector.add_input_asset(hook, uri="test_uri")
 
-        assert next(iter(self.collector._inputs.values())) == (asset, hook)
+        assert next(iter(collector._inputs.values())) == (asset, hook)
         mock_asset.assert_called_once_with(uri="test_uri")
 
-    def test_grouping_assets(self):
+    def test_grouping_assets(self, collector):
         hook_1 = MagicMock()
         hook_2 = MagicMock()
 
         uri = "test://uri/"
 
-        self.collector.add_input_asset(context=hook_1, uri=uri)
-        self.collector.add_input_asset(context=hook_2, uri=uri)
-        self.collector.add_input_asset(context=hook_1, uri=uri, 
asset_extra={"key": "value"})
+        collector.add_input_asset(context=hook_1, uri=uri)
+        collector.add_input_asset(context=hook_2, uri=uri)
+        collector.add_input_asset(context=hook_1, uri=uri, asset_extra={"key": 
"value"})
 
-        collected_inputs = self.collector.collected_assets.inputs
+        collected_inputs = collector.collected_assets.inputs
 
         assert len(collected_inputs) == 3
         assert collected_inputs[0].asset.uri == "test://uri/"
@@ -105,15 +106,14 @@ class TestHookLineageCollector:
         assert collected_inputs[2].count == 1
         assert collected_inputs[2].asset.extra == {"key": "value"}
 
-    @patch("airflow.lineage.hook.ProvidersManager")
-    def test_create_asset(self, mock_providers_manager):
+    def test_create_asset(self, collector):
         def create_asset(arg1, arg2="default", extra=None):
             return Asset(
                 uri=f"myscheme://{arg1}/{arg2}", name=f"asset-{arg1}", 
group="test", extra=extra or {}
             )
 
-        mock_providers_manager.return_value.asset_factories = {"myscheme": 
create_asset}
-        assert self.collector.create_asset(
+        collector._asset_factories = {"myscheme": create_asset}
+        assert collector.create_asset(
             scheme="myscheme",
             uri=None,
             name=None,
@@ -121,7 +121,7 @@ class TestHookLineageCollector:
             asset_kwargs={"arg1": "value_1"},
             asset_extra=None,
         ) == Asset(uri="myscheme://value_1/default", name="asset-value_1", 
group="test")
-        assert self.collector.create_asset(
+        assert collector.create_asset(
             scheme="myscheme",
             uri=None,
             name=None,
@@ -133,14 +133,14 @@ class TestHookLineageCollector:
         )
 
     @patch("airflow.lineage.hook.ProvidersManager")
-    def test_create_asset_no_factory(self, mock_providers_manager):
+    def test_create_asset_no_factory(self, mock_providers_manager, collector):
         test_scheme = "myscheme"
         mock_providers_manager.return_value.asset_factories = {}
 
         test_kwargs = {"arg1": "value_1"}
 
         assert (
-            self.collector.create_asset(
+            collector.create_asset(
                 scheme=test_scheme,
                 uri=None,
                 name=None,
@@ -152,7 +152,7 @@ class TestHookLineageCollector:
         )
 
     @patch("airflow.lineage.hook.ProvidersManager")
-    def test_create_asset_factory_exception(self, mock_providers_manager):
+    def test_create_asset_factory_exception(self, mock_providers_manager, 
collector):
         def create_asset(extra=None, **kwargs):
             raise RuntimeError("Factory error")
 
@@ -162,20 +162,18 @@ class TestHookLineageCollector:
         test_kwargs = {"arg1": "value_1"}
 
         assert (
-            self.collector.create_asset(
-                scheme=test_scheme, uri=None, asset_kwargs=test_kwargs, 
asset_extra=None
-            )
+            collector.create_asset(scheme=test_scheme, uri=None, 
asset_kwargs=test_kwargs, asset_extra=None)
             is None
         )
 
-    def test_collected_assets(self):
+    def test_collected_assets(self, collector):
         context_input = MagicMock()
         context_output = MagicMock()
 
-        self.collector.add_input_asset(context_input, uri="test://input")
-        self.collector.add_output_asset(context_output, uri="test://output")
+        collector.add_input_asset(context_input, uri="test://input")
+        collector.add_output_asset(context_output, uri="test://output")
 
-        hook_lineage = self.collector.collected_assets
+        hook_lineage = collector.collected_assets
         assert len(hook_lineage.inputs) == 1
         assert hook_lineage.inputs[0].asset.uri == "test://input/"
         assert hook_lineage.inputs[0].context == context_input
@@ -183,13 +181,24 @@ class TestHookLineageCollector:
         assert len(hook_lineage.outputs) == 1
         assert hook_lineage.outputs[0].asset.uri == "test://output/"
 
-    def test_has_collected(self):
-        collector = HookLineageCollector()
+    def test_has_collected(self, collector):
         assert not collector.has_collected
 
         collector._inputs = {"unique_key": (MagicMock(spec=Asset), 
MagicMock())}
         assert collector.has_collected
 
+    def test_hooks_limit_input_output_assets(self):
+        collector = HookLineageCollector()
+        assert not collector.has_collected
+
+        for i in range(1000):
+            collector.add_input_asset(MagicMock(), uri=f"test://input/{i}")
+            collector.add_output_asset(MagicMock(), uri=f"test://output/{i}")
+
+        assert collector.has_collected
+        assert len(collector._inputs) == 100
+        assert len(collector._outputs) == 100
+
 
 class FakePlugin(plugins_manager.AirflowPlugin):
     name = "FakePluginHavingHookLineageCollector"

Reply via email to