This is an automated email from the ASF dual-hosted git repository.
ash pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 3f30adf570b feat: Adjust compat HookLevelLineage for new add_extra
method (#58057)
3f30adf570b is described below
commit 3f30adf570ba83fafb70473e854d3bc09ebff2ed
Author: Kacper Muda <[email protected]>
AuthorDate: Wed Nov 26 14:31:43 2025 +0100
feat: Adjust compat HookLevelLineage for new add_extra method (#58057)
Extends the current common.compat compatibility layer with support for the
new Hook Lineage `add_extra` method, ensuring full backward compatibility and
consistent hook-level lineage behavior across all Airflow 2.11+ versions.
---
.../providers/common/compat/lineage/hook.py | 198 ++++-
.../tests/unit/common/compat/lineage/test_hook.py | 977 ++++++++++++++++++++-
2 files changed, 1137 insertions(+), 38 deletions(-)
diff --git a/providers/common/compat/src/airflow/providers/common/compat/lineage/hook.py b/providers/common/compat/src/airflow/providers/common/compat/lineage/hook.py
index 3d4de69d503..2eb07f446f3 100644
--- a/providers/common/compat/src/airflow/providers/common/compat/lineage/hook.py
+++ b/providers/common/compat/src/airflow/providers/common/compat/lineage/hook.py
@@ -16,22 +16,166 @@
# under the License.
from __future__ import annotations
-from airflow.providers.common.compat.version_compat import AIRFLOW_V_3_0_PLUS
+from typing import TYPE_CHECKING
+if TYPE_CHECKING:
+ from typing import Any
-def _get_asset_compat_hook_lineage_collector():
- from airflow.lineage.hook import get_hook_lineage_collector
+ from airflow.lineage.hook import LineageContext
- collector = get_hook_lineage_collector()
- if all(
- getattr(collector, asset_method_name, None)
- for asset_method_name in ("add_input_asset", "add_output_asset",
"collected_assets")
+def _lacks_asset_methods(collector):
+ """Return True if the collector is missing any asset-related methods or
properties."""
+ if ( # lazy evaluation, early return
+ hasattr(collector, "add_input_asset") # method
+ and hasattr(collector, "add_output_asset") # method
+ and hasattr(collector, "create_asset") # method
+ # If below we called hasattr(collector, "collected_assets") we'd call
the property unnecessarily
+ and hasattr(type(collector), "collected_assets") # property
):
- return collector
+ return False
- # dataset is renamed as asset in Airflow 3.0
+ return True
+
+def _lacks_add_extra_method(collector):
+ """Return True if the collector does not define an 'add_extra' method."""
+ # Method may be on class and attribute may be dynamically set on instance
+ if hasattr(collector, "add_extra") and hasattr(collector, "_extra"):
+ return False
+ return True
+
+
+def _add_extra_polyfill(collector):
+ """
+ Add support for `add_extra` method to collector that may be lacking it
(e.g., Airflow versions < 3.2.).
+
+ This polyfill adds the `add_extra` method to a class, modifies
`collected_assets` and `has_collected`
+ properties and sets `_extra` and `_extra_counts` attributes on instance if
not already there.
+
+ This function should be called after renaming on collectors that have
`collected_assets` method,
+ so f.e. for Airflow 2 it should happen after renaming from dataset to
asset.
+ """
+ import hashlib
+ import json
+ from collections import defaultdict
+
+ import attr
+
+ from airflow.lineage.hook import HookLineage as _BaseHookLineage
+
+ # Add `extra` to HookLineage returned by `collected_assets` property
+ @attr.define
+ class ExtraLineageInfo:
+ """
+ Holds lineage information for arbitrary non-asset metadata.
+
+ This class represents additional lineage context captured during a
hook execution that is not
+ associated with a specific asset. It includes the metadata payload
itself, the count of
+ how many times it has been encountered, and the context in which it
was encountered.
+ """
+
+ key: str
+ value: Any
+ count: int
+ context: LineageContext
+
+ @attr.define
+ class HookLineage(_BaseHookLineage):
+ # mypy is not happy, as base class is using other ExtraLineageInfo,
but this code will never
+ # run on AF3.2, where this other one is used, so this is fine - we can
ignore.
+ extra: list[ExtraLineageInfo] = attr.field(factory=list) # type:
ignore[assignment]
+
+ # Initialize extra tracking attributes on this collector instance
+ collector._extra = {}
+ collector._extra_counts = defaultdict(int)
+
+ # Overwrite the `collected_assets` property on a class
+ _original_collected_assets = collector.__class__.collected_assets
+
+ def _compat_collected_assets(self) -> HookLineage:
+ """Get the collected hook lineage information."""
+ # Defensive check since we patch the class property, but initialized
_extra only on this instance.
+ if not hasattr(self, "_extra"):
+ self._extra = {}
+ if not hasattr(self, "_extra_counts"):
+ self._extra_counts = defaultdict(int)
+
+ # call the original `collected_assets` getter
+ lineage = _original_collected_assets.fget(self)
+ extra_list = [
+ ExtraLineageInfo(
+ key=key,
+ value=value,
+ count=self._extra_counts[count_key],
+ context=context,
+ )
+ for count_key, (key, value, context) in self._extra.items()
+ ]
+ return HookLineage(
+ inputs=lineage.inputs,
+ outputs=lineage.outputs,
+ extra=extra_list,
+ )
+
+ type(collector).collected_assets = property(_compat_collected_assets)
+
+ # Overwrite the `has_collected` property on a class
+ _original_has_collected = collector.__class__.has_collected
+
+ def _compat_has_collected(self) -> bool:
+ # Defensive check since we patch the class property, but initialized
_extra only on this instance.
+ if not hasattr(self, "_extra"):
+ self._extra = {}
+ # call the original `has_collected` getter
+ has_collected = _original_has_collected.fget(self)
+ return bool(has_collected or self._extra)
+
+ type(collector).has_collected = property(_compat_has_collected)
+
+ # Add `add_extra` method on the class
+ def _compat_add_extra(self, context, key, value):
+ """Add extra information for older Airflow versions."""
+ _max_collected_extra = 200
+
+ if len(self._extra) >= _max_collected_extra:
+ if hasattr(self, "log"):
+ self.log.debug("Maximum number of extra exceeded. Skipping.")
+ return
+
+ if not key or not value:
+ if hasattr(self, "log"):
+ self.log.debug("Missing required parameter: both 'key' and
'value' must be provided.")
+ return
+
+ # Defensive check since we patch the class property, but initialized
_extra only on this instance.
+ if not hasattr(self, "_extra"):
+ self._extra = {}
+ if not hasattr(self, "_extra_counts"):
+ self._extra_counts = defaultdict(int)
+
+ extra_str = json.dumps(value, sort_keys=True, default=str)
+ value_hash = hashlib.md5(extra_str.encode()).hexdigest()
+ entry_id = f"{key}_{value_hash}_{id(context)}"
+ if entry_id not in self._extra:
+ self._extra[entry_id] = (key, value, context)
+ self._extra_counts[entry_id] += 1
+
+ if len(self._extra) == _max_collected_extra:
+ if hasattr(self, "log"):
+ self.log.warning("Maximum number of extra exceeded. Skipping
subsequent inputs.")
+
+ type(collector).add_extra = _compat_add_extra
+ return collector
+
+
+def _add_asset_naming_compatibility_layer(collector):
+ """
+ Handle AF 2.x compatibility for dataset -> asset terminology rename.
+
+ This is only called for AF 2.x where we need to provide asset-named methods
+ that wrap the underlying dataset methods.
+ """
from functools import wraps
from airflow.lineage.hook import DatasetLineageInfo, HookLineage
@@ -55,9 +199,9 @@ def _get_asset_compat_hook_lineage_collector():
collector.add_input_asset =
rename_asset_kwargs_to_dataset_kwargs(collector.add_input_dataset)
collector.add_output_asset =
rename_asset_kwargs_to_dataset_kwargs(collector.add_output_dataset)
- def collected_assets_compat(collector) -> HookLineage:
+ def _compat_collected_assets(self) -> HookLineage:
"""Get the collected hook lineage information."""
- lineage = collector.collected_datasets
+ lineage = self.collected_datasets
return HookLineage(
[
DatasetLineageInfo(dataset=item.dataset, count=item.count,
context=item.context)
@@ -69,20 +213,30 @@ def _get_asset_compat_hook_lineage_collector():
],
)
- setattr(
- collector.__class__,
- "collected_assets",
- property(lambda collector: collected_assets_compat(collector)),
- )
-
+ type(collector).collected_assets = property(_compat_collected_assets)
return collector
def get_hook_lineage_collector():
- # Dataset has been renamed as Asset in 3.0
- if AIRFLOW_V_3_0_PLUS:
- from airflow.lineage.hook import get_hook_lineage_collector
+ """
+ Return a hook lineage collector with all required compatibility layers
applied.
+
+ Compatibility is determined by inspecting the collector's available
methods and
+ properties (duck typing), rather than relying on the Airflow version
number.
+
+ Behavior by example:
+ Airflow 2: Collector is missing asset-based methods and `add_extra` -
apply both layers.
+ Airflow 3.0–3.1: Collector has asset-based methods but lacks `add_extra` -
apply single layer.
+ Airflow 3.2+: Collector has asset-based methods and `add_extra` support -
no action required.
+ """
+ from airflow.lineage.hook import get_hook_lineage_collector as
get_global_collector
+
+ global_collector = get_global_collector()
+
+ if _lacks_asset_methods(global_collector):
+ global_collector =
_add_asset_naming_compatibility_layer(global_collector)
- return get_hook_lineage_collector()
+ if _lacks_add_extra_method(global_collector):
+ global_collector = _add_extra_polyfill(global_collector)
- return _get_asset_compat_hook_lineage_collector()
+ return global_collector
diff --git a/providers/common/compat/tests/unit/common/compat/lineage/test_hook.py b/providers/common/compat/tests/unit/common/compat/lineage/test_hook.py
index 7b4759af844..9045512531e 100644
--- a/providers/common/compat/tests/unit/common/compat/lineage/test_hook.py
+++ b/providers/common/compat/tests/unit/common/compat/lineage/test_hook.py
@@ -16,37 +16,982 @@
# under the License.
from __future__ import annotations
+from unittest import mock
+
import pytest
-from airflow.providers.common.compat.lineage.hook import
get_hook_lineage_collector
+from airflow.providers.common.compat.lineage.hook import
_lacks_add_extra_method, _lacks_asset_methods
from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS
-def test_that_compat_does_not_raise():
+@pytest.fixture
+def collector():
+ from airflow.lineage.hook import HookLineageCollector
+
+ # Patch the "inner" function that the compat version will call
+ with mock.patch(
+ "airflow.lineage.hook.get_hook_lineage_collector",
+ return_value=HookLineageCollector(),
+ ):
+ from airflow.providers.common.compat.lineage.hook import
get_hook_lineage_collector
+
+ yield get_hook_lineage_collector()
+
+
+@pytest.fixture
+def noop_collector():
+ from airflow.lineage.hook import NoOpCollector
+
+ # Patch the "inner" function that the compat version will call
+ with mock.patch(
+ "airflow.lineage.hook.get_hook_lineage_collector",
+ return_value=NoOpCollector(),
+ ):
+ from airflow.providers.common.compat.lineage.hook import
get_hook_lineage_collector
+
+ yield get_hook_lineage_collector()
+
+
+@pytest.fixture(params=["collector", "noop_collector"])
+def any_collector(request):
+ return request.getfixturevalue(request.param)
+
+
+def test_lacks_asset_methods_all_present():
+ class Collector:
+ def add_input_asset(self):
+ pass
+
+ def add_output_asset(self):
+ pass
+
+ @property
+ def collected_assets(self):
+ return "<HookLineage object usually>"
+
+ def create_asset(self):
+ pass
+
+ assert _lacks_asset_methods(Collector()) is False
+
+
+def test_lacks_asset_methods_missing_few():
+ class Collector:
+ def add_input_asset(self):
+ pass
+
+ @property
+ def collected_assets(self):
+ return "<HookLineage object usually>"
+
+ assert _lacks_asset_methods(Collector()) is True
+
+
+def test_lacks_asset_methods_none_present():
+ class Collector:
+ def add_input_dataset(self):
+ pass
+
+ def add_output_dataset(self):
+ pass
+
+ assert _lacks_asset_methods(Collector()) is True
+
+
+def test_lacks_add_extra_method_present():
+ class Collector:
+ def __init__(self):
+ self._extra = {}
+
+ def add_extra(self):
+ pass
+
+ assert _lacks_add_extra_method(Collector()) is False
+
+
+def test_lacks_add_extra_method_missing():
+ class Collector:
+ pass
+
+ assert _lacks_add_extra_method(Collector()) is True
+
+
+def test_retrieval_does_not_raise(): # do not use fixture here
+ from airflow.providers.common.compat.lineage.hook import
get_hook_lineage_collector
+
# On compat tests this goes into ImportError code path
assert get_hook_lineage_collector() is not None
assert get_hook_lineage_collector() is not None
-@pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Test requires Airflow
3.0+")
-def test_compat_has_only_asset_methods():
- hook_lienage_collector = get_hook_lineage_collector()
+def test_global_collector_is_reused(): # do not use fixture here
+ from airflow.providers.common.compat.lineage.hook import
get_hook_lineage_collector
+
+ c1 = get_hook_lineage_collector()
+ c2 = get_hook_lineage_collector()
+
+ assert c1 is c2
+
- assert hook_lienage_collector.add_input_asset is not None
- assert hook_lienage_collector.add_output_asset is not None
+def test_all_required_methods_exist(any_collector):
+ """Test that all required methods exist regardless of version."""
+ # Core methods that should always exist
+ assert hasattr(any_collector, "add_input_asset")
+ assert hasattr(any_collector, "add_output_asset")
+ assert hasattr(any_collector, "add_extra")
+ assert hasattr(any_collector, "collected_assets")
+ assert hasattr(any_collector, "create_asset")
+
+ # Verify they're callable
+ assert callable(any_collector.add_input_asset)
+ assert callable(any_collector.add_output_asset)
+ assert callable(any_collector.add_extra)
+ assert callable(any_collector.create_asset)
+
+
+def test_empty_collector(any_collector):
+ """Test that empty collector returns empty lineage."""
+ lineage = any_collector.collected_assets
+
+ assert lineage.inputs == []
+ assert lineage.outputs == []
+ assert lineage.extra == []
+
+
+@pytest.mark.skipif(AIRFLOW_V_3_0_PLUS, reason="Test requires Airflow < 3.0")
+def test_af2_collector_has_dataset_methods(any_collector):
+ """Test that AF 2.x also has dataset methods."""
+
+ assert hasattr(any_collector, "add_input_dataset")
+ assert hasattr(any_collector, "add_output_dataset")
+ assert hasattr(any_collector, "collected_datasets")
+ assert hasattr(any_collector, "create_dataset")
+
+ assert callable(any_collector.add_input_dataset)
+ assert callable(any_collector.add_output_dataset)
+ assert callable(any_collector.create_dataset)
+
+
+@pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Test requires Airflow
3.0+")
+def test_af3_collector_do_not_have_dataset_methods(any_collector):
+ with pytest.raises(AttributeError):
+ any_collector.add_input_dataset
+ with pytest.raises(AttributeError):
+ any_collector.add_output_dataset
with pytest.raises(AttributeError):
- hook_lienage_collector.add_input_dataset
+ any_collector.collected_datasets
with pytest.raises(AttributeError):
- hook_lienage_collector.add_output_dataset
+ any_collector.create_dataset
-@pytest.mark.skipif(AIRFLOW_V_3_0_PLUS, reason="Test requires Airflow < 3.0")
-def test_compat_has_asset_and_dataset_methods():
- hook_lienage_collector = get_hook_lineage_collector()
+class TestCollectorAddExtra:
+ def test_add_extra_basic_functionality(self, collector):
+ """Test basic add_extra functionality."""
+ mock_context = mock.MagicMock()
+ collector.add_extra(mock_context, "test_key", {"data": "value"})
+
+ assert collector.has_collected
+
+ lineage = collector.collected_assets
+ assert hasattr(lineage, "extra")
+ assert len(lineage.extra) == 1
+ assert lineage.extra[0].key == "test_key"
+ assert lineage.extra[0].value == {"data": "value"}
+ assert lineage.extra[0].count == 1
+ assert lineage.extra[0].context == mock_context
+
+ def test_add_extra_count_tracking(self, collector):
+ """Test that duplicate extra entries are counted correctly."""
+ mock_context = mock.MagicMock()
+
+ # Add same extra multiple times
+ collector.add_extra(mock_context, "test_key", {"data": "value"})
+ collector.add_extra(mock_context, "test_key", {"data": "value"})
+ collector.add_extra(mock_context, "test_key", {"data": "value"})
+
+ assert collector.has_collected
+
+ lineage = collector.collected_assets
+ assert len(lineage.extra) == 1
+ assert lineage.extra[0].count == 3
+
+ def test_add_extra_different_values(self, collector):
+ """Test that different values are tracked separately."""
+ mock_context = mock.MagicMock()
+
+ # Add different values
+ collector.add_extra(mock_context, "key1", {"data": "value1"})
+ collector.add_extra(mock_context, "key2", {"data": "value2"})
+ collector.add_extra(mock_context, "key1", {"data": "value3"})
+
+ assert collector.has_collected
+
+ lineage = collector.collected_assets
+ assert len(lineage.extra) == 3
+ assert lineage.extra[0].key == "key1"
+ assert lineage.extra[0].value == {"data": "value1"}
+ assert lineage.extra[0].count == 1
+ assert lineage.extra[0].context == mock_context
+ assert lineage.extra[1].key == "key2"
+ assert lineage.extra[1].value == {"data": "value2"}
+ assert lineage.extra[1].count == 1
+ assert lineage.extra[1].context == mock_context
+ assert lineage.extra[2].key == "key1"
+ assert lineage.extra[2].value == {"data": "value3"}
+ assert lineage.extra[2].count == 1
+ assert lineage.extra[2].context == mock_context
+
+ def test_add_extra_different_contexts(self, collector):
+ """Test that different contexts are tracked separately."""
+ mock_context1 = mock.MagicMock()
+ mock_context2 = mock.MagicMock()
+
+ # Add same key/value with different contexts
+ collector.add_extra(mock_context1, "test_key", {"data": "value"})
+ collector.add_extra(mock_context2, "test_key", {"data": "value"})
+
+ assert collector.has_collected
+
+ lineage = collector.collected_assets
+ assert lineage.extra[0].key == "test_key"
+ assert lineage.extra[0].value == {"data": "value"}
+ assert lineage.extra[0].count == 1
+ assert lineage.extra[0].context == mock_context1
+ assert lineage.extra[1].key == "test_key"
+ assert lineage.extra[1].value == {"data": "value"}
+ assert lineage.extra[1].count == 1
+ assert lineage.extra[1].context == mock_context2
+
+ def test_add_extra_missing_key(self, collector):
+ """Test that add_extra handles missing key gracefully."""
+ mock_context = mock.MagicMock()
+
+ # Try to add with empty key
+ collector.add_extra(mock_context, "", {"data": "value"})
+ collector.add_extra(mock_context, None, {"data": "value"})
+
+ assert not collector.has_collected
+
+ lineage = collector.collected_assets
+ assert len(lineage.extra) == 0
+
+ def test_add_extra_missing_value(self, collector):
+ """Test that add_extra handles missing value gracefully."""
+ mock_context = mock.MagicMock()
+
+ # Try to add with empty/None value
+ collector.add_extra(mock_context, "key", "")
+ collector.add_extra(mock_context, "key", None)
+
+ assert not collector.has_collected
+
+ lineage = collector.collected_assets
+ assert len(lineage.extra) == 0
+
+ def test_add_extra_max_limit(self, collector):
+ """Test that add_extra respects maximum limit."""
+ mock_context = mock.MagicMock()
+ max_limit = 200
+
+ # Add more than max allowed
+ for i in range(max_limit + 10):
+ collector.add_extra(mock_context, f"key_{i}", {"data":
f"value_{i}"})
+
+ assert collector.has_collected
+
+ lineage = collector.collected_assets
+ assert len(lineage.extra) == max_limit
+
+ def test_add_extra_complex_values(self, collector):
+ """Test that add_extra handles complex JSON-serializable values."""
+ mock_context = mock.MagicMock()
+
+ # Add various complex types
+ collector.add_extra(mock_context, "dict", {"nested": {"data":
"value"}})
+ collector.add_extra(mock_context, "list", [1, 2, 3, "test"])
+ collector.add_extra(mock_context, "number", 42)
+ collector.add_extra(mock_context, "string", "simple string")
+
+ assert collector.has_collected
+
+ lineage = collector.collected_assets
+ assert len(lineage.extra) == 4
+ assert lineage.extra[0].key == "dict"
+ assert lineage.extra[0].value == {"nested": {"data": "value"}}
+ assert lineage.extra[0].count == 1
+ assert lineage.extra[0].context == mock_context
+ assert lineage.extra[1].key == "list"
+ assert lineage.extra[1].value == [1, 2, 3, "test"]
+ assert lineage.extra[1].count == 1
+ assert lineage.extra[1].context == mock_context
+ assert lineage.extra[2].key == "number"
+ assert lineage.extra[2].value == 42
+ assert lineage.extra[2].count == 1
+ assert lineage.extra[2].context == mock_context
+ assert lineage.extra[3].key == "string"
+ assert lineage.extra[3].value == "simple string"
+ assert lineage.extra[3].count == 1
+ assert lineage.extra[3].context == mock_context
+
+
+class TestCollectorAddAssets:
+ def test_add_asset_basic_functionality(self, collector):
+ """Test basic add_input_asset and add_output_asset functionality."""
+ mock_context = mock.MagicMock()
+
+ collector.add_input_asset(mock_context, uri="s3://bucket/input-file")
+ collector.add_output_asset(mock_context, uri="s3://bucket/output-file")
+
+ assert collector.has_collected
+
+ lineage = collector.collected_assets
+ assert len(lineage.inputs) == 1
+ assert lineage.inputs[0].asset.uri == "s3://bucket/input-file"
+ assert lineage.inputs[0].count == 1
+ assert lineage.inputs[0].context == mock_context
+
+ assert len(lineage.outputs) == 1
+ assert lineage.outputs[0].asset.uri == "s3://bucket/output-file"
+ assert lineage.outputs[0].count == 1
+ assert lineage.outputs[0].context == mock_context
+
+ def test_add_asset_count_tracking(self, collector):
+ """Test that duplicate assets are counted correctly."""
+ mock_context = mock.MagicMock()
+
+ # Add same input multiple times
+ collector.add_input_asset(mock_context, uri="s3://bucket/input")
+ collector.add_input_asset(mock_context, uri="s3://bucket/input")
+ collector.add_input_asset(mock_context, uri="s3://bucket/input")
+
+ # Add same output multiple times
+ collector.add_output_asset(mock_context, uri="s3://bucket/output")
+ collector.add_output_asset(mock_context, uri="s3://bucket/output")
+
+ assert collector.has_collected
+
+ lineage = collector.collected_assets
+
+ assert len(lineage.inputs) == 1
+ assert lineage.inputs[0].asset.uri == "s3://bucket/input"
+ assert lineage.inputs[0].count == 3
+ assert lineage.inputs[0].context == mock_context
+
+ assert len(lineage.outputs) == 1
+ assert lineage.outputs[0].asset.uri == "s3://bucket/output"
+ assert lineage.outputs[0].count == 2
+ assert lineage.outputs[0].context == mock_context
+
+ def test_add_asset_different_uris(self, collector):
+ """Test that different URIs are tracked separately."""
+ mock_context = mock.MagicMock()
+
+ # Add different input URIs
+ collector.add_input_asset(mock_context, uri="s3://bucket/file1")
+ collector.add_input_asset(mock_context, uri="s3://bucket/file2")
+ collector.add_input_asset(mock_context,
uri="postgres://example.com:5432/database/default/table")
+
+ # Add different output URIs
+ collector.add_output_asset(mock_context, uri="s3://output/file1")
+ collector.add_output_asset(mock_context, uri="s3://output/file2")
+
+ assert collector.has_collected
+
+ lineage = collector.collected_assets
+ assert len(lineage.inputs) == 3
+ assert lineage.inputs[0].asset.uri == "s3://bucket/file1"
+ assert lineage.inputs[1].asset.uri == "s3://bucket/file2"
+ assert lineage.inputs[2].asset.uri ==
"postgres://example.com:5432/database/default/table"
+
+ assert len(lineage.outputs) == 2
+ assert lineage.outputs[0].asset.uri == "s3://output/file1"
+ assert lineage.outputs[1].asset.uri == "s3://output/file2"
+
+ def test_add_asset_different_contexts(self, collector):
+ """Test that different contexts are tracked separately."""
+ mock_context1 = mock.MagicMock()
+ mock_context2 = mock.MagicMock()
+
+ # Add same URI with different contexts
+ collector.add_input_asset(mock_context1, uri="s3://bucket/file")
+ collector.add_input_asset(mock_context2, uri="s3://bucket/file")
+
+ collector.add_output_asset(mock_context1, uri="s3://output/file")
+ collector.add_output_asset(mock_context2, uri="s3://output/file")
+
+ assert collector.has_collected
+
+ lineage = collector.collected_assets
+ assert len(lineage.inputs) == 2
+ assert lineage.inputs[0].asset.uri == "s3://bucket/file"
+ assert lineage.inputs[0].context == mock_context1
+ assert lineage.inputs[0].count == 1
+ assert lineage.inputs[1].asset.uri == "s3://bucket/file"
+ assert lineage.inputs[1].context == mock_context2
+ assert lineage.inputs[1].count == 1
+
+ assert len(lineage.outputs) == 2
+ assert lineage.outputs[0].asset.uri == "s3://output/file"
+ assert lineage.outputs[0].context == mock_context1
+ assert lineage.outputs[0].count == 1
+ assert lineage.outputs[1].asset.uri == "s3://output/file"
+ assert lineage.outputs[1].context == mock_context2
+ assert lineage.outputs[1].count == 1
+
+ @pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Test requires Airflow
3.0+")
+ def test_add_asset_with_name_and_group(self, collector):
+ """Test adding assets with name and group parameters."""
+ mock_context = mock.MagicMock()
+
+ collector.add_input_asset(mock_context, uri="s3://bucket/file",
name="my-input", group="input-group")
+ collector.add_output_asset(
+ mock_context, uri="s3://output/file", name="my-output",
group="output-group"
+ )
+
+ assert collector.has_collected
+
+ lineage = collector.collected_assets
+ assert len(lineage.inputs) == 1
+ assert lineage.inputs[0].asset.uri == "s3://bucket/file"
+ assert lineage.inputs[0].asset.name == "my-input"
+ assert lineage.inputs[0].asset.group == "input-group"
+ assert lineage.inputs[0].count == 1
+ assert lineage.inputs[0].context == mock_context
+
+ assert len(lineage.outputs) == 1
+ assert lineage.outputs[0].asset.uri == "s3://output/file"
+ assert lineage.outputs[0].asset.name == "my-output"
+ assert lineage.outputs[0].asset.group == "output-group"
+ assert lineage.outputs[0].count == 1
+ assert lineage.outputs[0].context == mock_context
+
+ def test_add_asset_with_extra_metadata(self, collector):
+ """Test adding assets with extra metadata."""
+ mock_context = mock.MagicMock()
+
+ collector.add_input_asset(
+ mock_context,
+ uri="postgres://example.com:5432/database/default/table",
+ asset_extra={"schema": "public", "table": "users"},
+ )
+ collector.add_output_asset(
+ mock_context,
+ uri="postgres://example.com:5432/database/default/table",
+ asset_extra={"schema": "public", "table": "results"},
+ )
+
+ assert collector.has_collected
+
+ lineage = collector.collected_assets
+ assert len(lineage.inputs) == 1
+ assert lineage.inputs[0].asset.uri ==
"postgres://example.com:5432/database/default/table"
+ assert lineage.inputs[0].asset.extra == {"schema": "public", "table":
"users"}
+ assert lineage.inputs[0].count == 1
+ assert lineage.inputs[0].context == mock_context
+
+ assert len(lineage.outputs) == 1
+ assert lineage.outputs[0].asset.uri ==
"postgres://example.com:5432/database/default/table"
+ assert lineage.outputs[0].asset.extra == {"schema": "public", "table":
"results"}
+ assert lineage.outputs[0].count == 1
+ assert lineage.outputs[0].context == mock_context
+
+ def test_add_asset_different_extra_values(self, collector):
+ """Test that assets with different extra values are tracked
separately."""
+ mock_context = mock.MagicMock()
+
+ # Same URI but different extra metadata
+ collector.add_input_asset(mock_context, uri="s3://bucket/file",
asset_extra={"version": "1"})
+ collector.add_input_asset(mock_context, uri="s3://bucket/file",
asset_extra={"version": "2"})
+
+ collector.add_output_asset(mock_context, uri="s3://output/file",
asset_extra={"format": "parquet"})
+ collector.add_output_asset(mock_context, uri="s3://output/file",
asset_extra={"format": "csv"})
+
+ assert collector.has_collected
+
+ lineage = collector.collected_assets
+ assert len(lineage.inputs) == 2
+ assert lineage.inputs[0].asset.uri == "s3://bucket/file"
+ assert lineage.inputs[0].asset.extra == {"version": "1"}
+ assert lineage.inputs[0].count == 1
+ assert lineage.inputs[0].context == mock_context
+ assert lineage.inputs[1].asset.uri == "s3://bucket/file"
+ assert lineage.inputs[1].asset.extra == {"version": "2"}
+ assert lineage.inputs[1].count == 1
+ assert lineage.inputs[1].context == mock_context
+
+ assert len(lineage.outputs) == 2
+ assert lineage.outputs[0].asset.uri == "s3://output/file"
+ assert lineage.outputs[0].asset.extra == {"format": "parquet"}
+ assert lineage.outputs[0].count == 1
+ assert lineage.outputs[0].context == mock_context
+ assert lineage.outputs[1].asset.uri == "s3://output/file"
+ assert lineage.outputs[1].asset.extra == {"format": "csv"}
+ assert lineage.outputs[1].count == 1
+ assert lineage.outputs[1].context == mock_context
+
+ @pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Test requires Airflow
3.0+")
+ def test_add_asset_max_limit_af3(self, collector):
+ """Test that asset operations respect maximum limit."""
+ mock_context = mock.MagicMock()
+ max_limit = 100
+ added_assets = max_limit + 50
+
+ # Limitation on collected assets was added in AF3 #45798
+ expected_number = max_limit
+
+ # Add more than max allowed inputs
+ for i in range(added_assets):
+ collector.add_input_asset(mock_context,
uri=f"s3://bucket/input-{i}")
+
+ # Add more than max allowed outputs
+ for i in range(added_assets):
+ collector.add_output_asset(mock_context,
uri=f"s3://bucket/output-{i}")
+
+ assert collector.has_collected
+
+ lineage = collector.collected_assets
+ assert len(lineage.inputs) == expected_number
+ assert len(lineage.outputs) == expected_number
+
+ @pytest.mark.skipif(AIRFLOW_V_3_0_PLUS, reason="Test requires < Airflow
3.0")
+ def test_add_asset_max_limit_af2(self, collector):
+ """Test that asset operations do not respect maximum limit."""
+ mock_context = mock.MagicMock()
+ max_limit = 100
+ added_assets = max_limit + 50
+
+ # Limitation on collected assets was added in AF3 #45798
+ expected_number = added_assets
+
+ # Add more than max allowed inputs
+ for i in range(added_assets):
+ collector.add_input_asset(mock_context,
uri=f"s3://bucket/input-{i}")
+
+ # Add more than max allowed outputs
+ for i in range(added_assets):
+ collector.add_output_asset(mock_context,
uri=f"s3://bucket/output-{i}")
+
+ assert collector.has_collected
+
+ lineage = collector.collected_assets
+ assert len(lineage.inputs) == expected_number
+ assert len(lineage.outputs) == expected_number
+
+
+class TestEdgeCases:
+ """Test edge cases and error conditions to ensure collector never fails."""
+
+ @pytest.mark.parametrize("uri", ["", None])
+ def test_invalid_uri_none(self, collector, uri):
+ """Test handling of None URI - should not raise."""
+ mock_context = mock.MagicMock()
+
+ # Should not raise exceptions
+ collector.add_input_asset(mock_context, uri=uri)
+ collector.add_output_asset(mock_context, uri=uri)
+
+ # Collector should handle gracefully and not collect invalid URIs
+ assert not collector.has_collected
+
+ lineage = collector.collected_assets
+ assert len(lineage.inputs) == 0
+ assert len(lineage.outputs) == 0
+ assert len(lineage.extra) == 0
+
+ def test_malformed_uri(self, collector):
+ """Test handling of malformed URIs - should not raise."""
+ mock_context = mock.MagicMock()
+
+ # Various malformed URIs should not cause crashes
+ collector.add_input_asset(mock_context, uri="not-a-valid-uri")
+ collector.add_input_asset(mock_context, uri="://missing-scheme")
+ collector.add_input_asset(mock_context, uri="scheme:")
+ collector.add_output_asset(mock_context, uri="//no-scheme")
+
+ assert collector.has_collected
+
+ lineage = collector.collected_assets
+ assert len(lineage.inputs) == 3
+ assert lineage.inputs[0].asset.uri == "not-a-valid-uri"
+ assert lineage.inputs[0].count == 1
+ assert lineage.inputs[0].context == mock_context
+ assert lineage.inputs[1].asset.uri == "://missing-scheme"
+ assert lineage.inputs[1].count == 1
+ assert lineage.inputs[1].context == mock_context
+ assert lineage.inputs[2].asset.uri == "scheme:/"
+ assert lineage.inputs[2].count == 1
+ assert lineage.inputs[2].context == mock_context
+
+ assert len(lineage.outputs) == 1
+ assert lineage.outputs[0].asset.uri == "//no-scheme"
+ assert lineage.outputs[0].count == 1
+ assert lineage.outputs[0].context == mock_context
+
+ @pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Test requires Airflow 3.0+")
+ def test_very_long_uri(self, collector):
+ """Test handling of very long URIs - 1000 chars OK, 2000 chars raises ValueError."""
+ mock_context = mock.MagicMock()
+
+ # Create very long URI (1000 chars - should work)
+ long_path = "a" * 1000
+ long_uri = f"s3://bucket/{long_path}"
+
+ # Create too long URI (2000 chars - should raise)
+ too_long_uri = f"s3://bucket/{long_path * 2}"
+
+ collector.add_input_asset(mock_context, uri=long_uri)
+
+ # Too long URI should raise ValueError
+ with pytest.raises(ValueError, match="Asset name cannot exceed"):
+ collector.add_output_asset(mock_context, uri=too_long_uri)
+
+ assert collector.has_collected
+
+ lineage = collector.collected_assets
+ assert len(lineage.inputs) == 1
+ assert lineage.inputs[0].asset.uri == long_uri
+ assert lineage.inputs[0].count == 1
+ assert lineage.inputs[0].context == mock_context
+
+ assert len(lineage.outputs) == 0
+ assert len(lineage.extra) == 0
+
+ def test_none_context(self, collector):
+ """Test handling of None context - should not raise."""
+ # Should not raise exceptions
+ collector.add_input_asset(None, uri="s3://bucket/input")
+ collector.add_output_asset(None, uri="s3://bucket/output")
+ collector.add_extra(None, "key", "value")
+
+ assert collector.has_collected
+
+ lineage = collector.collected_assets
+ assert len(lineage.inputs) == 1
+ assert lineage.inputs[0].asset.uri == "s3://bucket/input"
+ assert lineage.inputs[0].count == 1
+ assert lineage.inputs[0].context is None
+
+ assert len(lineage.outputs) == 1
+ assert lineage.outputs[0].asset.uri == "s3://bucket/output"
+ assert lineage.outputs[0].count == 1
+ assert lineage.outputs[0].context is None
+
+ assert len(lineage.extra) == 1
+ assert lineage.extra[0].key == "key"
+ assert lineage.extra[0].value == "value"
+ assert lineage.extra[0].count == 1
+ assert lineage.extra[0].context is None
+
+ def test_special_characters_in_extra_key(self, collector):
+ """Test that extra keys with special characters work."""
+ mock_context = mock.MagicMock()
+
+ collector.add_extra(mock_context, "key-with-dashes", {"data": "value"})
+ collector.add_extra(mock_context, "key.with.dots", {"data": "value"})
+ collector.add_extra(mock_context, "key_with_underscores", {"data": "value"})
+ collector.add_extra(mock_context, "key/with/slashes", {"data": "value"})
+ collector.add_extra(mock_context, "key:with:colons", {"data": "value"})
+
+ assert collector.has_collected
+
+ lineage = collector.collected_assets
+ assert len(lineage.extra) == 5
+ assert lineage.extra[0].key == "key-with-dashes"
+ assert lineage.extra[1].key == "key.with.dots"
+ assert lineage.extra[2].key == "key_with_underscores"
+ assert lineage.extra[3].key == "key/with/slashes"
+ assert lineage.extra[4].key == "key:with:colons"
+
+ def test_unicode_in_extra_key_and_value(self, collector):
+ """Test that unicode characters in extra work correctly."""
+ mock_context = mock.MagicMock()
+
+ collector.add_extra(mock_context, "clé_française", {"données": "valeur"})
+ collector.add_extra(mock_context, "中文键", {"中文": "值"})
+ collector.add_extra(mock_context, "مفتاح", {"بيانات": "قيمة"})
+ collector.add_extra(mock_context, "emoji_🚀", {"status": "✅"})
+
+ assert collector.has_collected
+
+ lineage = collector.collected_assets
+ assert len(lineage.extra) == 4
+ assert lineage.extra[0].key == "clé_française"
+ assert lineage.extra[0].value == {"données": "valeur"}
+ assert lineage.extra[0].count == 1
+ assert lineage.extra[0].context == mock_context
+
+ assert lineage.extra[1].key == "中文键"
+ assert lineage.extra[1].value == {"中文": "值"}
+ assert lineage.extra[1].count == 1
+ assert lineage.extra[1].context == mock_context
+
+ assert lineage.extra[2].key == "مفتاح"
+ assert lineage.extra[2].value == {"بيانات": "قيمة"}
+ assert lineage.extra[2].count == 1
+ assert lineage.extra[2].context == mock_context
+
+ assert lineage.extra[3].key == "emoji_🚀"
+ assert lineage.extra[3].value == {"status": "✅"}
+ assert lineage.extra[3].count == 1
+ assert lineage.extra[3].context == mock_context
+
+ def test_very_large_extra_value(self, collector):
+ """Test that large extra values are handled."""
+ mock_context = mock.MagicMock()
+
+ # Create a large value
+ large_value = {"data": "x" * 10000, "list": list(range(1000))}
+
+ collector.add_extra(mock_context, "large_key", large_value)
+
+ assert collector.has_collected
+
+ lineage = collector.collected_assets
+ assert len(lineage.extra) == 1
+ assert lineage.extra[0].key == "large_key"
+ assert lineage.extra[0].value == large_value
+ assert lineage.extra[0].count == 1
+ assert lineage.extra[0].context == mock_context
+
+ assert len(lineage.inputs) == 0
+ assert len(lineage.outputs) == 0
+
+ def test_deeply_nested_extra_value(self, collector):
+ """Test that deeply nested data structures in extra are handled."""
+ mock_context = mock.MagicMock()
+
+ # Create deeply nested structure
+ nested_value = {"level1": {"level2": {"level3": {"level4": {"level5": {"data": "deep"}}}}}}
+
+ collector.add_extra(mock_context, "nested", nested_value)
+
+ assert collector.has_collected
+
+ lineage = collector.collected_assets
+ assert len(lineage.extra) == 1
+ assert lineage.extra[0].key == "nested"
+ assert lineage.extra[0].value == nested_value
+ assert lineage.extra[0].count == 1
+ assert lineage.extra[0].context == mock_context
+
+ assert len(lineage.inputs) == 0
+ assert len(lineage.outputs) == 0
+
+ def test_extra_value_with_various_types(self, collector):
+ """Test that extra can handle various data types."""
+ mock_context = mock.MagicMock()
+
+ collector.add_extra(mock_context, "string", "text")
+ collector.add_extra(mock_context, "integer", 42)
+ collector.add_extra(mock_context, "float", 3.14)
+ collector.add_extra(mock_context, "boolean", True)
+ collector.add_extra(mock_context, "list", [1, 2, 3])
+ collector.add_extra(mock_context, "dict", {"key": "value"})
+ collector.add_extra(mock_context, "null", None)
+
+ assert collector.has_collected
+
+ # None value should not be collected (based on validation)
+ lineage = collector.collected_assets
+ assert len(lineage.extra) == 6 # None is filtered out
+
+ assert lineage.extra[0].key == "string"
+ assert lineage.extra[0].value == "text"
+ assert lineage.extra[0].count == 1
+
+ assert lineage.extra[1].key == "integer"
+ assert lineage.extra[1].value == 42
+ assert lineage.extra[1].count == 1
+
+ assert lineage.extra[2].key == "float"
+ assert lineage.extra[2].value == 3.14
+ assert lineage.extra[2].count == 1
+
+ assert lineage.extra[3].key == "boolean"
+ assert lineage.extra[3].value is True
+ assert lineage.extra[3].count == 1
+
+ assert lineage.extra[4].key == "list"
+ assert lineage.extra[4].value == [1, 2, 3]
+ assert lineage.extra[4].count == 1
+
+ assert lineage.extra[5].key == "dict"
+ assert lineage.extra[5].value == {"key": "value"}
+ assert lineage.extra[5].count == 1
+
+ assert len(lineage.inputs) == 0
+ assert len(lineage.outputs) == 0
+
+ def test_non_json_serializable_value_in_extra(self, collector):
+ """Test that non-JSON-serializable values are handled gracefully."""
+ mock_context = mock.MagicMock()
+
+ # Create a non-serializable object
+ class CustomObject:
+ def __str__(self):
+ return "custom_object"
+
+ # Should not raise - collector should handle via str conversion or skip
+ collector.add_extra(mock_context, "custom_key", CustomObject())
+
+ # May or may not be collected depending on implementation
+ lineage = collector.collected_assets
+ # Just verify it doesn't crash
+ assert isinstance(lineage.extra, list)
+ assert len(lineage.inputs) == 0
+ assert len(lineage.outputs) == 0
+
+ def test_empty_asset_extra(self, collector):
+ """Test that empty asset_extra is handled correctly."""
+ mock_context = mock.MagicMock()
+
+ collector.add_input_asset(mock_context, uri="s3://bucket/file", asset_extra={})
+ collector.add_output_asset(mock_context, uri="s3://bucket/file", asset_extra={})
+
+ assert collector.has_collected
+
+ lineage = collector.collected_assets
+ assert len(lineage.inputs) == 1
+ assert lineage.inputs[0].asset.extra == {}
+ assert len(lineage.outputs) == 1
+ assert lineage.outputs[0].asset.extra == {}
+
+ @pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Test requires Airflow 3.0+")
+ def test_asset_with_all_optional_parameters(self, collector):
+ """Test asset creation with all optional parameters provided."""
+ mock_context = mock.MagicMock()
+
+ collector.add_input_asset(
+ mock_context,
+ uri="s3://bucket/file",
+ name="custom-name",
+ group="custom-group",
+ asset_extra={"key": "value"},
+ )
+
+ assert collector.has_collected
+
+ lineage = collector.collected_assets
+ assert len(lineage.inputs) == 1
+ assert lineage.inputs[0].asset.uri == "s3://bucket/file"
+ assert lineage.inputs[0].asset.name == "custom-name"
+ assert lineage.inputs[0].asset.group == "custom-group"
+ assert lineage.inputs[0].asset.extra == {"key": "value"}
+
+ def test_rapid_repeated_calls(self, collector):
+ """Test that rapid repeated calls don't cause issues."""
+ mock_context = mock.MagicMock()
+
+ # Simulate rapid repeated calls
+ for _ in range(50):
+ collector.add_input_asset(mock_context, uri="s3://bucket/file")
+ collector.add_output_asset(mock_context, uri="s3://bucket/output")
+ collector.add_extra(mock_context, "key", "value")
+
+ assert collector.has_collected
+
+ lineage = collector.collected_assets
+ # Should have counted properly
+ assert len(lineage.inputs) == 1
+ assert lineage.inputs[0].count == 50
+ assert len(lineage.outputs) == 1
+ assert lineage.outputs[0].count == 50
+ assert len(lineage.extra) == 1
+ assert lineage.extra[0].count == 50
+
+ def test_mixed_valid_invalid_operations(self, collector):
+ """Test mixing valid and invalid operations."""
+ mock_context = mock.MagicMock()
+
+ # Mix valid and invalid calls
+ collector.add_input_asset(mock_context, uri="s3://bucket/valid")
+ collector.add_input_asset(mock_context, uri=None) # Invalid - should not be collected
+ collector.add_input_asset(mock_context, uri="") # Invalid - should not be collected
+ collector.add_input_asset(mock_context, uri="s3://bucket/another-valid")
+
+ collector.add_extra(mock_context, "valid_key", "valid_value")
+ collector.add_extra(mock_context, "", "invalid_key") # Invalid key - should not be collected
+ collector.add_extra(mock_context, "another_key", "another_value")
+
+ assert collector.has_collected
+
+ # Should collect only valid items
+ lineage = collector.collected_assets
+ assert len(lineage.inputs) == 2
+ assert lineage.inputs[0].asset.uri == "s3://bucket/valid"
+ assert lineage.inputs[0].count == 1
+ assert lineage.inputs[0].context == mock_context
+ assert lineage.inputs[1].asset.uri == "s3://bucket/another-valid"
+ assert lineage.inputs[1].count == 1
+ assert lineage.inputs[1].context == mock_context
+
+ assert len(lineage.extra) == 2
+ assert lineage.extra[0].key == "valid_key"
+ assert lineage.extra[0].value == "valid_value"
+ assert lineage.extra[0].count == 1
+ assert lineage.extra[1].key == "another_key"
+ assert lineage.extra[1].value == "another_value"
+ assert lineage.extra[1].count == 1
+
+ assert len(lineage.outputs) == 0
+
+ def test_collector_collected_assets_called_multiple_times(self, collector):
+ """Test that collected_assets property can be called multiple times."""
+ mock_context = mock.MagicMock()
+
+ collector.add_input_asset(mock_context, uri="s3://bucket/file")
+
+ assert collector.has_collected
+
+ # Call multiple times - should return same data
+ lineage1 = collector.collected_assets
+ lineage2 = collector.collected_assets
+ lineage3 = collector.collected_assets
+
+ assert lineage1.inputs == lineage2.inputs == lineage3.inputs
+ assert len(lineage1.inputs) == 1
+ assert lineage1.inputs[0].asset.uri == "s3://bucket/file"
+ assert lineage1.inputs[0].count == 1
+ assert lineage1.inputs[0].context == mock_context
+
+ @pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Test requires Airflow 3.0+")
+ def test_empty_name_and_group(self, collector):
+ """Test that empty strings for name and group are handled."""
+ mock_context = mock.MagicMock()
+
+ # Empty strings for optional parameters
+ collector.add_input_asset(mock_context, uri="s3://bucket/file", name="", group="")
+
+ assert collector.has_collected
+
+ lineage = collector.collected_assets
+ assert len(lineage.inputs) == 1
+ assert lineage.inputs[0].asset.uri == "s3://bucket/file"
+ assert lineage.inputs[0].asset.name == "s3://bucket/file"
+ assert lineage.inputs[0].asset.group == "asset"
+ assert lineage.inputs[0].count == 1
+ assert lineage.inputs[0].context == mock_context
+
+ assert len(lineage.outputs) == 0
+ assert len(lineage.extra) == 0
+
+ def test_extremely_long_extra_key(self, collector):
+ """Test that extremely long extra keys are handled."""
+ mock_context = mock.MagicMock()
+
+ long_key = "k" * 10000
+ collector.add_extra(mock_context, long_key, "value")
+
+ assert collector.has_collected
+
+ lineage = collector.collected_assets
+ assert len(lineage.extra) == 1
+ assert lineage.extra[0].key == long_key
+ assert lineage.extra[0].value == "value"
+ assert lineage.extra[0].count == 1
+ assert lineage.extra[0].context == mock_context
- assert hook_lienage_collector.add_input_asset is not None
- assert hook_lienage_collector.add_output_asset is not None
- assert hook_lienage_collector.add_input_dataset is not None
- assert hook_lienage_collector.add_output_dataset is not None
+ assert len(lineage.inputs) == 0
+ assert len(lineage.outputs) == 0