This is an automated email from the ASF dual-hosted git repository.

mobuchowski pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 072f3b95ad Make HookLineageCollector group datasets by (#41034)
072f3b95ad is described below

commit 072f3b95ad2de17fc871d319ac96ba77201ef801
Author: Jakub Dardzinski <[email protected]>
AuthorDate: Mon Jul 29 11:59:00 2024 +0200

    Make HookLineageCollector group datasets by (#41034)
    
    URI + hash of extra + id of hook object.
    
    This is to avoid running out of memory when sending
    large numbers of datasets to the collector.
    
    
    
    Add a `DatasetLineageInfo` class that is returned
    within `HookLineage`.
    
    It is used to store count and context information about datasets
    collected at the hook lineage level.
    
    Signed-off-by: Jakub Dardzinski <[email protected]>
---
 airflow/lineage/hook.py                     | 72 +++++++++++++++++++++----
 tests/io/test_path.py                       | 22 ++++----
 tests/io/test_wrapper.py                    | 16 +++---
 tests/lineage/test_hook.py                  | 81 ++++++++++++++++++++---------
 tests/providers/amazon/aws/hooks/test_s3.py | 14 ++---
 5 files changed, 148 insertions(+), 57 deletions(-)
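
For illustration, a minimal standalone sketch of the grouping key described
above (the helper name generate_key is hypothetical; the real logic is the
_generate_key method added to airflow/lineage/hook.py in the diff below):

    from __future__ import annotations

    import hashlib
    import json

    def generate_key(uri: str, extra: dict | None, context: object) -> str:
        # Serialize extra deterministically so equal dicts hash identically.
        extra_hash = hashlib.md5(json.dumps(extra, sort_keys=True).encode()).hexdigest()
        # id(context) ties the key to one concrete hook (LineageContext) instance.
        return f"{uri}_{extra_hash}_{id(context)}"

    hook = object()
    # Same URI, same extra, same hook instance -> same key, so repeated
    # reports collapse into one counted entry instead of growing an unbounded list.
    assert generate_key("s3://bucket/key", {"k": "v"}, hook) == generate_key(
        "s3://bucket/key", {"k": "v"}, hook
    )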

diff --git a/airflow/lineage/hook.py b/airflow/lineage/hook.py
index 744db3fb38..4ff35e4d9c 100644
--- a/airflow/lineage/hook.py
+++ b/airflow/lineage/hook.py
@@ -17,6 +17,9 @@
 # under the License.
 from __future__ import annotations
 
+import hashlib
+import json
+from collections import defaultdict
 from typing import TYPE_CHECKING, Union
 
 import attr
@@ -35,12 +38,32 @@ if TYPE_CHECKING:
 _hook_lineage_collector: HookLineageCollector | None = None
 
 
+@attr.define
+class DatasetLineageInfo:
+    """
+    Holds lineage information for a single dataset.
+
+    This class represents the lineage information for a single dataset, including the dataset itself,
+    the count of how many times it has been encountered, and the context in which it was encountered.
+    """
+
+    dataset: Dataset
+    count: int
+    context: LineageContext
+
+
 @attr.define
 class HookLineage:
-    """Holds lineage collected by HookLineageCollector."""
+    """
+    Holds lineage collected by HookLineageCollector.
+
+    This class represents the lineage information collected by the `HookLineageCollector`. It stores
+    the input and output datasets, each with an associated count indicating how many times the dataset
+    has been encountered during the hook execution.
+    """
 
-    inputs: list[tuple[Dataset, LineageContext]] = attr.ib(factory=list)
-    outputs: list[tuple[Dataset, LineageContext]] = attr.ib(factory=list)
+    inputs: list[DatasetLineageInfo] = attr.ib(factory=list)
+    outputs: list[DatasetLineageInfo] = attr.ib(factory=list)
 
 
 class HookLineageCollector(LoggingMixin):
@@ -52,8 +75,24 @@ class HookLineageCollector(LoggingMixin):
 
     def __init__(self, **kwargs):
         super().__init__(**kwargs)
-        self.inputs: list[tuple[Dataset, LineageContext]] = []
-        self.outputs: list[tuple[Dataset, LineageContext]] = []
+        # Dictionary to store input datasets, counted by unique key (dataset URI, MD5 hash of extra
+        # dictionary, and LineageContext's unique identifier)
+        self._inputs: dict[str, tuple[Dataset, LineageContext]] = {}
+        self._outputs: dict[str, tuple[Dataset, LineageContext]] = {}
+        self._input_counts: dict[str, int] = defaultdict(int)
+        self._output_counts: dict[str, int] = defaultdict(int)
+
+    def _generate_key(self, dataset: Dataset, context: LineageContext) -> str:
+        """
+        Generate a unique key for the given dataset and context.
+
+        This method creates a unique key by combining the dataset URI, the MD5 hash of the dataset's extra
+        dictionary, and the LineageContext's unique identifier. This ensures that the generated key is
+        unique for each combination of dataset and context.
+        """
+        extra_str = json.dumps(dataset.extra, sort_keys=True)
+        extra_hash = hashlib.md5(extra_str.encode()).hexdigest()
+        return f"{dataset.uri}_{extra_hash}_{id(context)}"
 
     def create_dataset(
         self, scheme: str | None, uri: str | None, dataset_kwargs: dict | None, dataset_extra: dict | None
@@ -106,7 +145,10 @@ class HookLineageCollector(LoggingMixin):
             scheme=scheme, uri=uri, dataset_kwargs=dataset_kwargs, dataset_extra=dataset_extra
         )
         if dataset:
-            self.inputs.append((dataset, context))
+            key = self._generate_key(dataset, context)
+            if key not in self._inputs:
+                self._inputs[key] = (dataset, context)
+            self._input_counts[key] += 1
 
     def add_output_dataset(
         self,
@@ -121,17 +163,29 @@ class HookLineageCollector(LoggingMixin):
             scheme=scheme, uri=uri, dataset_kwargs=dataset_kwargs, dataset_extra=dataset_extra
         )
         if dataset:
-            self.outputs.append((dataset, context))
+            key = self._generate_key(dataset, context)
+            if key not in self._outputs:
+                self._outputs[key] = (dataset, context)
+            self._output_counts[key] += 1
 
     @property
     def collected_datasets(self) -> HookLineage:
         """Get the collected hook lineage information."""
-        return HookLineage(self.inputs, self.outputs)
+        return HookLineage(
+            [
+                DatasetLineageInfo(dataset=dataset, count=self._input_counts[key], context=context)
+                for key, (dataset, context) in self._inputs.items()
+            ],
+            [
+                DatasetLineageInfo(dataset=dataset, count=self._output_counts[key], context=context)
+                for key, (dataset, context) in self._outputs.items()
+            ],
+        )
 
     @property
     def has_collected(self) -> bool:
         """Check if any datasets have been collected."""
-        return len(self.inputs) != 0 or len(self.outputs) != 0
+        return len(self._inputs) != 0 or len(self._outputs) != 0
 
 
 class NoOpCollector(HookLineageCollector):
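
For reference, a minimal usage sketch of the grouping behavior (assumes an
Airflow tree that includes this change; a BaseHook instance stands in as the
LineageContext):

    from airflow.hooks.base import BaseHook
    from airflow.lineage.hook import HookLineageCollector

    collector = HookLineageCollector()
    hook = BaseHook()

    # Reporting the same dataset twice through the same hook stores it once...
    collector.add_input_dataset(hook, uri="s3://in_bucket/file")
    collector.add_input_dataset(hook, uri="s3://in_bucket/file")

    lineage = collector.collected_datasets
    assert len(lineage.inputs) == 1      # ...keeping memory bounded,
    assert lineage.inputs[0].count == 2  # while the count records both reports.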
diff --git a/tests/io/test_path.py b/tests/io/test_path.py
index c9c5a99bc4..195c2423b1 100644
--- a/tests/io/test_path.py
+++ b/tests/io/test_path.py
@@ -280,10 +280,12 @@ class TestFs:
 
         _to.unlink()
 
-        assert len(hook_lineage_collector.collected_datasets.inputs) == 1
-        assert len(hook_lineage_collector.collected_datasets.outputs) == 1
-        assert hook_lineage_collector.collected_datasets.inputs[0][0] == Dataset(uri=_from_path)
-        assert hook_lineage_collector.collected_datasets.outputs[0][0] == Dataset(uri=_to_path)
+        collected_datasets = hook_lineage_collector.collected_datasets
+
+        assert len(collected_datasets.inputs) == 1
+        assert len(collected_datasets.outputs) == 1
+        assert collected_datasets.inputs[0].dataset == Dataset(uri=_from_path)
+        assert collected_datasets.outputs[0].dataset == Dataset(uri=_to_path)
 
     def test_move_remote(self, hook_lineage_collector):
         attach("fakefs", fs=FakeRemoteFileSystem())
@@ -301,10 +303,12 @@ class TestFs:
 
         _to.unlink()
 
-        assert len(hook_lineage_collector.collected_datasets.inputs) == 1
-        assert len(hook_lineage_collector.collected_datasets.outputs) == 1
-        assert hook_lineage_collector.collected_datasets.inputs[0][0] == Dataset(uri=str(_from))
-        assert hook_lineage_collector.collected_datasets.outputs[0][0] == Dataset(uri=str(_to))
+        collected_datasets = hook_lineage_collector.collected_datasets
+
+        assert len(collected_datasets.inputs) == 1
+        assert len(collected_datasets.outputs) == 1
+        assert collected_datasets.inputs[0].dataset == Dataset(uri=str(_from))
+        assert collected_datasets.outputs[0].dataset == Dataset(uri=str(_to))
 
     def test_copy_remote_remote(self, hook_lineage_collector):
         attach("ffs", fs=FakeRemoteFileSystem(skip_instance_cache=True))
@@ -335,7 +339,7 @@ class TestFs:
         _to.rmdir(recursive=True)
 
         assert len(hook_lineage_collector.collected_datasets.inputs) == 1
-        assert hook_lineage_collector.collected_datasets.inputs[0][0] == Dataset(uri=str(_from_file))
+        assert hook_lineage_collector.collected_datasets.inputs[0].dataset == Dataset(uri=str(_from_file))
 
         # Empty file - shutil.copyfileobj does nothing
         assert len(hook_lineage_collector.collected_datasets.outputs) == 0
diff --git a/tests/io/test_wrapper.py b/tests/io/test_wrapper.py
index dab0bee6ec..e00c5ab22b 100644
--- a/tests/io/test_wrapper.py
+++ b/tests/io/test_wrapper.py
@@ -32,8 +32,8 @@ def test_wrapper_catches_reads_writes(providers_manager, hook_lineage_collector)
     file.write("aaa")
     file.close()
 
-    assert len(hook_lineage_collector.outputs) == 1
-    assert hook_lineage_collector.outputs[0][0] == Dataset(uri=uri)
+    assert len(hook_lineage_collector._outputs) == 1
+    assert next(iter(hook_lineage_collector._outputs.values()))[0] == Dataset(uri=uri)
 
     file = path.open("r")
     file.read()
@@ -41,8 +41,8 @@ def test_wrapper_catches_reads_writes(providers_manager, hook_lineage_collector)
 
     path.unlink(missing_ok=True)
 
-    assert len(hook_lineage_collector.inputs) == 1
-    assert hook_lineage_collector.inputs[0][0] == Dataset(uri=uri)
+    assert len(hook_lineage_collector._inputs) == 1
+    assert next(iter(hook_lineage_collector._inputs.values()))[0] == Dataset(uri=uri)
 
 
 @patch("airflow.providers_manager.ProvidersManager")
@@ -53,12 +53,12 @@ def test_wrapper_works_with_contextmanager(providers_manager, hook_lineage_colle
     with path.open("w") as file:
         file.write("asdf")
 
-    assert len(hook_lineage_collector.outputs) == 1
-    assert hook_lineage_collector.outputs[0][0] == Dataset(uri=uri)
+    assert len(hook_lineage_collector._outputs) == 1
+    assert next(iter(hook_lineage_collector._outputs.values()))[0] == Dataset(uri=uri)
 
     with path.open("r") as file:
         file.read()
     path.unlink(missing_ok=True)
 
-    assert len(hook_lineage_collector.inputs) == 1
-    assert hook_lineage_collector.inputs[0][0] == Dataset(uri=uri)
+    assert len(hook_lineage_collector._inputs) == 1
+    assert next(iter(hook_lineage_collector._inputs.values()))[0] == Dataset(uri=uri)
diff --git a/tests/lineage/test_hook.py b/tests/lineage/test_hook.py
index 15d4d6c1e4..97e160e1f9 100644
--- a/tests/lineage/test_hook.py
+++ b/tests/lineage/test_hook.py
@@ -26,6 +26,7 @@ from airflow.datasets import Dataset
 from airflow.hooks.base import BaseHook
 from airflow.lineage import hook
 from airflow.lineage.hook import (
+    DatasetLineageInfo,
     HookLineage,
     HookLineageCollector,
     HookLineageReader,
@@ -36,44 +37,72 @@ from tests.test_utils.mock_plugins import 
mock_plugin_manager
 
 
 class TestHookLineageCollector:
+    def setup_method(self):
+        self.collector = HookLineageCollector()
+
     def test_are_datasets_collected(self):
-        lineage_collector = HookLineageCollector()
-        assert lineage_collector is not None
-        assert lineage_collector.collected_datasets == HookLineage()
+        assert self.collector is not None
+        assert self.collector.collected_datasets == HookLineage()
         input_hook = BaseHook()
         output_hook = BaseHook()
-        lineage_collector.add_input_dataset(input_hook, uri="s3://in_bucket/file")
-        lineage_collector.add_output_dataset(
+        self.collector.add_input_dataset(input_hook, uri="s3://in_bucket/file")
+        self.collector.add_output_dataset(
             output_hook, uri="postgres://example.com:5432/database/default/table"
         )
-        assert lineage_collector.collected_datasets == HookLineage(
-            [(Dataset("s3://in_bucket/file"), input_hook)],
-            [(Dataset("postgres://example.com:5432/database/default/table"), 
output_hook)],
+        assert self.collector.collected_datasets == HookLineage(
+            [DatasetLineageInfo(dataset=Dataset("s3://in_bucket/file"), 
count=1, context=input_hook)],
+            [
+                DatasetLineageInfo(
+                    dataset=Dataset("postgres://example.com:5432/database/default/table"),
+                    count=1,
+                    context=output_hook,
+                )
+            ],
         )
 
     @patch("airflow.lineage.hook.Dataset")
     def test_add_input_dataset(self, mock_dataset):
-        collector = HookLineageCollector()
-        dataset = MagicMock(spec=Dataset)
+        dataset = MagicMock(spec=Dataset, extra={})
         mock_dataset.return_value = dataset
 
         hook = MagicMock()
-        collector.add_input_dataset(hook, uri="test_uri")
+        self.collector.add_input_dataset(hook, uri="test_uri")
 
-        assert collector.inputs == [(dataset, hook)]
+        assert next(iter(self.collector._inputs.values())) == (dataset, hook)
         mock_dataset.assert_called_once_with(uri="test_uri", extra=None)
 
+    def test_grouping_datasets(self):
+        hook_1 = MagicMock()
+        hook_2 = MagicMock()
+
+        uri = "test://uri/"
+
+        self.collector.add_input_dataset(context=hook_1, uri=uri)
+        self.collector.add_input_dataset(context=hook_2, uri=uri)
+        self.collector.add_input_dataset(context=hook_1, uri=uri, dataset_extra={"key": "value"})
+
+        collected_inputs = self.collector.collected_datasets.inputs
+
+        assert len(collected_inputs) == 3
+        assert collected_inputs[0].dataset.uri == "test://uri/"
+        assert collected_inputs[0].dataset == collected_inputs[1].dataset
+        assert collected_inputs[0].count == 1
+        assert collected_inputs[0].context == collected_inputs[2].context == hook_1
+        assert collected_inputs[1].count == 1
+        assert collected_inputs[1].context == hook_2
+        assert collected_inputs[2].count == 1
+        assert collected_inputs[2].dataset.extra == {"key": "value"}
+
     @patch("airflow.lineage.hook.ProvidersManager")
     def test_create_dataset(self, mock_providers_manager):
         def create_dataset(arg1, arg2="default", extra=None):
             return Dataset(uri=f"myscheme://{arg1}/{arg2}", extra=extra)
 
         mock_providers_manager.return_value.dataset_factories = {"myscheme": create_dataset}
-        collector = HookLineageCollector()
-        assert collector.create_dataset(
+        assert self.collector.create_dataset(
             scheme="myscheme", uri=None, dataset_kwargs={"arg1": "value_1"}, 
dataset_extra=None
         ) == Dataset("myscheme://value_1/default")
-        assert collector.create_dataset(
+        assert self.collector.create_dataset(
             scheme="myscheme",
             uri=None,
             dataset_kwargs={"arg1": "value_1", "arg2": "value_2"},
@@ -81,21 +110,25 @@ class TestHookLineageCollector:
         ) == Dataset("myscheme://value_1/value_2", extra={"key": "value"})
 
     def test_collected_datasets(self):
-        collector = HookLineageCollector()
-        inputs = [(MagicMock()), MagicMock(spec=Dataset)]
-        outputs = [(MagicMock()), MagicMock(spec=Dataset)]
-        collector.inputs = inputs
-        collector.outputs = outputs
+        context_input = MagicMock()
+        context_output = MagicMock()
+
+        self.collector.add_input_dataset(context_input, uri="test://input")
+        self.collector.add_output_dataset(context_output, uri="test://output")
+
+        hook_lineage = self.collector.collected_datasets
+        assert len(hook_lineage.inputs) == 1
+        assert hook_lineage.inputs[0].dataset.uri == "test://input/"
+        assert hook_lineage.inputs[0].context == context_input
 
-        hook_lineage = collector.collected_datasets
-        assert hook_lineage.inputs == inputs
-        assert hook_lineage.outputs == outputs
+        assert len(hook_lineage.outputs) == 1
+        assert hook_lineage.outputs[0].dataset.uri == "test://output/"
 
     def test_has_collected(self):
         collector = HookLineageCollector()
         assert not collector.has_collected
 
-        collector.inputs = [MagicMock(spec=Dataset), MagicMock()]
+        collector._inputs = {"unique_key": (MagicMock(spec=Dataset), 
MagicMock())}
         assert collector.has_collected
 
 
diff --git a/tests/providers/amazon/aws/hooks/test_s3.py b/tests/providers/amazon/aws/hooks/test_s3.py
index acedf3d011..9dade82004 100644
--- a/tests/providers/amazon/aws/hooks/test_s3.py
+++ b/tests/providers/amazon/aws/hooks/test_s3.py
@@ -395,7 +395,7 @@ class TestAwsS3Hook:
         hook = S3Hook()
         hook.load_string("Contént", "my_key", s3_bucket)
         assert len(hook_lineage_collector.collected_datasets.outputs) == 1
-        assert hook_lineage_collector.collected_datasets.outputs[0][0] == Dataset(
+        assert hook_lineage_collector.collected_datasets.outputs[0].dataset == Dataset(
             uri=f"s3://{s3_bucket}/my_key"
         )
 
@@ -988,7 +988,7 @@ class TestAwsS3Hook:
         path.write_text("Content")
         hook.load_file(path, "my_key", s3_bucket)
         assert len(hook_lineage_collector.collected_datasets.outputs) == 1
-        assert hook_lineage_collector.collected_datasets.outputs[0][0] == Dataset(
+        assert hook_lineage_collector.collected_datasets.outputs[0].dataset == Dataset(
             uri=f"s3://{s3_bucket}/my_key"
         )
 
@@ -1060,12 +1060,12 @@ class TestAwsS3Hook:
         ):
             mock_hook.copy_object("my_key", "my_key3", s3_bucket, s3_bucket)
             assert len(hook_lineage_collector.collected_datasets.inputs) == 1
-            assert hook_lineage_collector.collected_datasets.inputs[0][0] == Dataset(
+            assert hook_lineage_collector.collected_datasets.inputs[0].dataset == Dataset(
                 uri=f"s3://{s3_bucket}/my_key"
             )
 
             assert len(hook_lineage_collector.collected_datasets.outputs) == 1
-            assert hook_lineage_collector.collected_datasets.outputs[0][0] == Dataset(
+            assert hook_lineage_collector.collected_datasets.outputs[0].dataset == Dataset(
                 uri=f"s3://{s3_bucket}/my_key3"
             )
 
@@ -1198,7 +1198,7 @@ class TestAwsS3Hook:
         s3_hook.download_file(key=key, bucket_name=bucket)
 
         assert len(hook_lineage_collector.collected_datasets.inputs) == 1
-        assert hook_lineage_collector.collected_datasets.inputs[0][0] == Dataset(
+        assert hook_lineage_collector.collected_datasets.inputs[0].dataset == Dataset(
             uri="s3://test_bucket/test_key"
         )
 
@@ -1250,12 +1250,12 @@ class TestAwsS3Hook:
         )
 
         assert len(hook_lineage_collector.collected_datasets.inputs) == 1
-        assert hook_lineage_collector.collected_datasets.inputs[0][0] == Dataset(
+        assert hook_lineage_collector.collected_datasets.inputs[0].dataset == Dataset(
             uri="s3://test_bucket/test_key/test.log"
         )
 
         assert len(hook_lineage_collector.collected_datasets.outputs) == 1
-        assert hook_lineage_collector.collected_datasets.outputs[0][0] == Dataset(
+        assert hook_lineage_collector.collected_datasets.outputs[0].dataset == Dataset(
             uri=f"file://{local_path}/test.log",
         )
 
