This is an automated email from the ASF dual-hosted git repository.
mobuchowski pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 072f3b95ad Make HookLineageCollector group datasets by (#41034)
072f3b95ad is described below
commit 072f3b95ad2de17fc871d319ac96ba77201ef801
Author: Jakub Dardzinski <[email protected]>
AuthorDate: Mon Jul 29 11:59:00 2024 +0200
Make HookLineageCollector group datasets by (#41034)
URI + hash of extra + id of hook object.
This is to avoid going out of memory when sending
large numbers of datasets to the collector.
Add `DatasetLineageInfo` class that is returned
within `HookLineage`.
It is used to store count and context information about datasets
collected at the hook lineage level.
Signed-off-by: Jakub Dardzinski <[email protected]>
---
airflow/lineage/hook.py | 72 +++++++++++++++++++++----
tests/io/test_path.py | 22 ++++----
tests/io/test_wrapper.py | 16 +++---
tests/lineage/test_hook.py | 81 ++++++++++++++++++++---------
tests/providers/amazon/aws/hooks/test_s3.py | 14 ++---
5 files changed, 148 insertions(+), 57 deletions(-)
diff --git a/airflow/lineage/hook.py b/airflow/lineage/hook.py
index 744db3fb38..4ff35e4d9c 100644
--- a/airflow/lineage/hook.py
+++ b/airflow/lineage/hook.py
@@ -17,6 +17,9 @@
# under the License.
from __future__ import annotations
+import hashlib
+import json
+from collections import defaultdict
from typing import TYPE_CHECKING, Union
import attr
@@ -35,12 +38,32 @@ if TYPE_CHECKING:
_hook_lineage_collector: HookLineageCollector | None = None
[email protected]
+class DatasetLineageInfo:
+ """
+ Holds lineage information for a single dataset.
+
+ This class represents the lineage information for a single dataset,
including the dataset itself,
+ the count of how many times it has been encountered, and the context in
which it was encountered.
+ """
+
+ dataset: Dataset
+ count: int
+ context: LineageContext
+
+
@attr.define
class HookLineage:
- """Holds lineage collected by HookLineageCollector."""
+ """
+ Holds lineage collected by HookLineageCollector.
+
+ This class represents the lineage information collected by the
`HookLineageCollector`. It stores
+ the input and output datasets, each with an associated count indicating
how many times the dataset
+ has been encountered during the hook execution.
+ """
- inputs: list[tuple[Dataset, LineageContext]] = attr.ib(factory=list)
- outputs: list[tuple[Dataset, LineageContext]] = attr.ib(factory=list)
+ inputs: list[DatasetLineageInfo] = attr.ib(factory=list)
+ outputs: list[DatasetLineageInfo] = attr.ib(factory=list)
class HookLineageCollector(LoggingMixin):
@@ -52,8 +75,24 @@ class HookLineageCollector(LoggingMixin):
def __init__(self, **kwargs):
super().__init__(**kwargs)
- self.inputs: list[tuple[Dataset, LineageContext]] = []
- self.outputs: list[tuple[Dataset, LineageContext]] = []
+ # Dictionary to store input datasets, counted by unique key (dataset
URI, MD5 hash of extra
+ # dictionary, and LineageContext's unique identifier)
+ self._inputs: dict[str, tuple[Dataset, LineageContext]] = {}
+ self._outputs: dict[str, tuple[Dataset, LineageContext]] = {}
+ self._input_counts: dict[str, int] = defaultdict(int)
+ self._output_counts: dict[str, int] = defaultdict(int)
+
+ def _generate_key(self, dataset: Dataset, context: LineageContext) -> str:
+ """
+ Generate a unique key for the given dataset and context.
+
+ This method creates a unique key by combining the dataset URI, the MD5
hash of the dataset's extra
+ dictionary, and the LineageContext's unique identifier. This ensures
that the generated key is
+ unique for each combination of dataset and context.
+ """
+ extra_str = json.dumps(dataset.extra, sort_keys=True)
+ extra_hash = hashlib.md5(extra_str.encode()).hexdigest()
+ return f"{dataset.uri}_{extra_hash}_{id(context)}"
def create_dataset(
self, scheme: str | None, uri: str | None, dataset_kwargs: dict |
None, dataset_extra: dict | None
@@ -106,7 +145,10 @@ class HookLineageCollector(LoggingMixin):
scheme=scheme, uri=uri, dataset_kwargs=dataset_kwargs,
dataset_extra=dataset_extra
)
if dataset:
- self.inputs.append((dataset, context))
+ key = self._generate_key(dataset, context)
+ if key not in self._inputs:
+ self._inputs[key] = (dataset, context)
+ self._input_counts[key] += 1
def add_output_dataset(
self,
@@ -121,17 +163,29 @@ class HookLineageCollector(LoggingMixin):
scheme=scheme, uri=uri, dataset_kwargs=dataset_kwargs,
dataset_extra=dataset_extra
)
if dataset:
- self.outputs.append((dataset, context))
+ key = self._generate_key(dataset, context)
+ if key not in self._outputs:
+ self._outputs[key] = (dataset, context)
+ self._output_counts[key] += 1
@property
def collected_datasets(self) -> HookLineage:
"""Get the collected hook lineage information."""
- return HookLineage(self.inputs, self.outputs)
+ return HookLineage(
+ [
+ DatasetLineageInfo(dataset=dataset,
count=self._input_counts[key], context=context)
+ for key, (dataset, context) in self._inputs.items()
+ ],
+ [
+ DatasetLineageInfo(dataset=dataset,
count=self._output_counts[key], context=context)
+ for key, (dataset, context) in self._outputs.items()
+ ],
+ )
@property
def has_collected(self) -> bool:
"""Check if any datasets have been collected."""
- return len(self.inputs) != 0 or len(self.outputs) != 0
+ return len(self._inputs) != 0 or len(self._outputs) != 0
class NoOpCollector(HookLineageCollector):
diff --git a/tests/io/test_path.py b/tests/io/test_path.py
index c9c5a99bc4..195c2423b1 100644
--- a/tests/io/test_path.py
+++ b/tests/io/test_path.py
@@ -280,10 +280,12 @@ class TestFs:
_to.unlink()
- assert len(hook_lineage_collector.collected_datasets.inputs) == 1
- assert len(hook_lineage_collector.collected_datasets.outputs) == 1
- assert hook_lineage_collector.collected_datasets.inputs[0][0] ==
Dataset(uri=_from_path)
- assert hook_lineage_collector.collected_datasets.outputs[0][0] ==
Dataset(uri=_to_path)
+ collected_datasets = hook_lineage_collector.collected_datasets
+
+ assert len(collected_datasets.inputs) == 1
+ assert len(collected_datasets.outputs) == 1
+ assert collected_datasets.inputs[0].dataset == Dataset(uri=_from_path)
+ assert collected_datasets.outputs[0].dataset == Dataset(uri=_to_path)
def test_move_remote(self, hook_lineage_collector):
attach("fakefs", fs=FakeRemoteFileSystem())
@@ -301,10 +303,12 @@ class TestFs:
_to.unlink()
- assert len(hook_lineage_collector.collected_datasets.inputs) == 1
- assert len(hook_lineage_collector.collected_datasets.outputs) == 1
- assert hook_lineage_collector.collected_datasets.inputs[0][0] ==
Dataset(uri=str(_from))
- assert hook_lineage_collector.collected_datasets.outputs[0][0] ==
Dataset(uri=str(_to))
+ collected_datasets = hook_lineage_collector.collected_datasets
+
+ assert len(collected_datasets.inputs) == 1
+ assert len(collected_datasets.outputs) == 1
+ assert collected_datasets.inputs[0].dataset == Dataset(uri=str(_from))
+ assert collected_datasets.outputs[0].dataset == Dataset(uri=str(_to))
def test_copy_remote_remote(self, hook_lineage_collector):
attach("ffs", fs=FakeRemoteFileSystem(skip_instance_cache=True))
@@ -335,7 +339,7 @@ class TestFs:
_to.rmdir(recursive=True)
assert len(hook_lineage_collector.collected_datasets.inputs) == 1
- assert hook_lineage_collector.collected_datasets.inputs[0][0] ==
Dataset(uri=str(_from_file))
+ assert hook_lineage_collector.collected_datasets.inputs[0].dataset ==
Dataset(uri=str(_from_file))
# Empty file - shutil.copyfileobj does nothing
assert len(hook_lineage_collector.collected_datasets.outputs) == 0
diff --git a/tests/io/test_wrapper.py b/tests/io/test_wrapper.py
index dab0bee6ec..e00c5ab22b 100644
--- a/tests/io/test_wrapper.py
+++ b/tests/io/test_wrapper.py
@@ -32,8 +32,8 @@ def test_wrapper_catches_reads_writes(providers_manager,
hook_lineage_collector)
file.write("aaa")
file.close()
- assert len(hook_lineage_collector.outputs) == 1
- assert hook_lineage_collector.outputs[0][0] == Dataset(uri=uri)
+ assert len(hook_lineage_collector._outputs) == 1
+ assert next(iter(hook_lineage_collector._outputs.values()))[0] ==
Dataset(uri=uri)
file = path.open("r")
file.read()
@@ -41,8 +41,8 @@ def test_wrapper_catches_reads_writes(providers_manager,
hook_lineage_collector)
path.unlink(missing_ok=True)
- assert len(hook_lineage_collector.inputs) == 1
- assert hook_lineage_collector.inputs[0][0] == Dataset(uri=uri)
+ assert len(hook_lineage_collector._inputs) == 1
+ assert next(iter(hook_lineage_collector._inputs.values()))[0] ==
Dataset(uri=uri)
@patch("airflow.providers_manager.ProvidersManager")
@@ -53,12 +53,12 @@ def
test_wrapper_works_with_contextmanager(providers_manager, hook_lineage_colle
with path.open("w") as file:
file.write("asdf")
- assert len(hook_lineage_collector.outputs) == 1
- assert hook_lineage_collector.outputs[0][0] == Dataset(uri=uri)
+ assert len(hook_lineage_collector._outputs) == 1
+ assert next(iter(hook_lineage_collector._outputs.values()))[0] ==
Dataset(uri=uri)
with path.open("r") as file:
file.read()
path.unlink(missing_ok=True)
- assert len(hook_lineage_collector.inputs) == 1
- assert hook_lineage_collector.inputs[0][0] == Dataset(uri=uri)
+ assert len(hook_lineage_collector._inputs) == 1
+ assert next(iter(hook_lineage_collector._inputs.values()))[0] ==
Dataset(uri=uri)
diff --git a/tests/lineage/test_hook.py b/tests/lineage/test_hook.py
index 15d4d6c1e4..97e160e1f9 100644
--- a/tests/lineage/test_hook.py
+++ b/tests/lineage/test_hook.py
@@ -26,6 +26,7 @@ from airflow.datasets import Dataset
from airflow.hooks.base import BaseHook
from airflow.lineage import hook
from airflow.lineage.hook import (
+ DatasetLineageInfo,
HookLineage,
HookLineageCollector,
HookLineageReader,
@@ -36,44 +37,72 @@ from tests.test_utils.mock_plugins import
mock_plugin_manager
class TestHookLineageCollector:
+ def setup_method(self):
+ self.collector = HookLineageCollector()
+
def test_are_datasets_collected(self):
- lineage_collector = HookLineageCollector()
- assert lineage_collector is not None
- assert lineage_collector.collected_datasets == HookLineage()
+ assert self.collector is not None
+ assert self.collector.collected_datasets == HookLineage()
input_hook = BaseHook()
output_hook = BaseHook()
- lineage_collector.add_input_dataset(input_hook,
uri="s3://in_bucket/file")
- lineage_collector.add_output_dataset(
+ self.collector.add_input_dataset(input_hook, uri="s3://in_bucket/file")
+ self.collector.add_output_dataset(
output_hook,
uri="postgres://example.com:5432/database/default/table"
)
- assert lineage_collector.collected_datasets == HookLineage(
- [(Dataset("s3://in_bucket/file"), input_hook)],
- [(Dataset("postgres://example.com:5432/database/default/table"),
output_hook)],
+ assert self.collector.collected_datasets == HookLineage(
+ [DatasetLineageInfo(dataset=Dataset("s3://in_bucket/file"),
count=1, context=input_hook)],
+ [
+ DatasetLineageInfo(
+
dataset=Dataset("postgres://example.com:5432/database/default/table"),
+ count=1,
+ context=output_hook,
+ )
+ ],
)
@patch("airflow.lineage.hook.Dataset")
def test_add_input_dataset(self, mock_dataset):
- collector = HookLineageCollector()
- dataset = MagicMock(spec=Dataset)
+ dataset = MagicMock(spec=Dataset, extra={})
mock_dataset.return_value = dataset
hook = MagicMock()
- collector.add_input_dataset(hook, uri="test_uri")
+ self.collector.add_input_dataset(hook, uri="test_uri")
- assert collector.inputs == [(dataset, hook)]
+ assert next(iter(self.collector._inputs.values())) == (dataset, hook)
mock_dataset.assert_called_once_with(uri="test_uri", extra=None)
+ def test_grouping_datasets(self):
+ hook_1 = MagicMock()
+ hook_2 = MagicMock()
+
+ uri = "test://uri/"
+
+ self.collector.add_input_dataset(context=hook_1, uri=uri)
+ self.collector.add_input_dataset(context=hook_2, uri=uri)
+ self.collector.add_input_dataset(context=hook_1, uri=uri,
dataset_extra={"key": "value"})
+
+ collected_inputs = self.collector.collected_datasets.inputs
+
+ assert len(collected_inputs) == 3
+ assert collected_inputs[0].dataset.uri == "test://uri/"
+ assert collected_inputs[0].dataset == collected_inputs[1].dataset
+ assert collected_inputs[0].count == 1
+ assert collected_inputs[0].context == collected_inputs[2].context ==
hook_1
+ assert collected_inputs[1].count == 1
+ assert collected_inputs[1].context == hook_2
+ assert collected_inputs[2].count == 1
+ assert collected_inputs[2].dataset.extra == {"key": "value"}
+
@patch("airflow.lineage.hook.ProvidersManager")
def test_create_dataset(self, mock_providers_manager):
def create_dataset(arg1, arg2="default", extra=None):
return Dataset(uri=f"myscheme://{arg1}/{arg2}", extra=extra)
mock_providers_manager.return_value.dataset_factories = {"myscheme":
create_dataset}
- collector = HookLineageCollector()
- assert collector.create_dataset(
+ assert self.collector.create_dataset(
scheme="myscheme", uri=None, dataset_kwargs={"arg1": "value_1"},
dataset_extra=None
) == Dataset("myscheme://value_1/default")
- assert collector.create_dataset(
+ assert self.collector.create_dataset(
scheme="myscheme",
uri=None,
dataset_kwargs={"arg1": "value_1", "arg2": "value_2"},
@@ -81,21 +110,25 @@ class TestHookLineageCollector:
) == Dataset("myscheme://value_1/value_2", extra={"key": "value"})
def test_collected_datasets(self):
- collector = HookLineageCollector()
- inputs = [(MagicMock()), MagicMock(spec=Dataset)]
- outputs = [(MagicMock()), MagicMock(spec=Dataset)]
- collector.inputs = inputs
- collector.outputs = outputs
+ context_input = MagicMock()
+ context_output = MagicMock()
+
+ self.collector.add_input_dataset(context_input, uri="test://input")
+ self.collector.add_output_dataset(context_output, uri="test://output")
+
+ hook_lineage = self.collector.collected_datasets
+ assert len(hook_lineage.inputs) == 1
+ assert hook_lineage.inputs[0].dataset.uri == "test://input/"
+ assert hook_lineage.inputs[0].context == context_input
- hook_lineage = collector.collected_datasets
- assert hook_lineage.inputs == inputs
- assert hook_lineage.outputs == outputs
+ assert len(hook_lineage.outputs) == 1
+ assert hook_lineage.outputs[0].dataset.uri == "test://output/"
def test_has_collected(self):
collector = HookLineageCollector()
assert not collector.has_collected
- collector.inputs = [MagicMock(spec=Dataset), MagicMock()]
+ collector._inputs = {"unique_key": (MagicMock(spec=Dataset),
MagicMock())}
assert collector.has_collected
diff --git a/tests/providers/amazon/aws/hooks/test_s3.py
b/tests/providers/amazon/aws/hooks/test_s3.py
index acedf3d011..9dade82004 100644
--- a/tests/providers/amazon/aws/hooks/test_s3.py
+++ b/tests/providers/amazon/aws/hooks/test_s3.py
@@ -395,7 +395,7 @@ class TestAwsS3Hook:
hook = S3Hook()
hook.load_string("Contént", "my_key", s3_bucket)
assert len(hook_lineage_collector.collected_datasets.outputs) == 1
- assert hook_lineage_collector.collected_datasets.outputs[0][0] ==
Dataset(
+ assert hook_lineage_collector.collected_datasets.outputs[0].dataset ==
Dataset(
uri=f"s3://{s3_bucket}/my_key"
)
@@ -988,7 +988,7 @@ class TestAwsS3Hook:
path.write_text("Content")
hook.load_file(path, "my_key", s3_bucket)
assert len(hook_lineage_collector.collected_datasets.outputs) == 1
- assert hook_lineage_collector.collected_datasets.outputs[0][0] ==
Dataset(
+ assert hook_lineage_collector.collected_datasets.outputs[0].dataset ==
Dataset(
uri=f"s3://{s3_bucket}/my_key"
)
@@ -1060,12 +1060,12 @@ class TestAwsS3Hook:
):
mock_hook.copy_object("my_key", "my_key3", s3_bucket, s3_bucket)
assert len(hook_lineage_collector.collected_datasets.inputs) == 1
- assert hook_lineage_collector.collected_datasets.inputs[0][0] ==
Dataset(
+ assert hook_lineage_collector.collected_datasets.inputs[0].dataset
== Dataset(
uri=f"s3://{s3_bucket}/my_key"
)
assert len(hook_lineage_collector.collected_datasets.outputs) == 1
- assert hook_lineage_collector.collected_datasets.outputs[0][0] ==
Dataset(
+ assert
hook_lineage_collector.collected_datasets.outputs[0].dataset == Dataset(
uri=f"s3://{s3_bucket}/my_key3"
)
@@ -1198,7 +1198,7 @@ class TestAwsS3Hook:
s3_hook.download_file(key=key, bucket_name=bucket)
assert len(hook_lineage_collector.collected_datasets.inputs) == 1
- assert hook_lineage_collector.collected_datasets.inputs[0][0] ==
Dataset(
+ assert hook_lineage_collector.collected_datasets.inputs[0].dataset ==
Dataset(
uri="s3://test_bucket/test_key"
)
@@ -1250,12 +1250,12 @@ class TestAwsS3Hook:
)
assert len(hook_lineage_collector.collected_datasets.inputs) == 1
- assert hook_lineage_collector.collected_datasets.inputs[0][0] ==
Dataset(
+ assert hook_lineage_collector.collected_datasets.inputs[0].dataset ==
Dataset(
uri="s3://test_bucket/test_key/test.log"
)
assert len(hook_lineage_collector.collected_datasets.outputs) == 1
- assert hook_lineage_collector.collected_datasets.outputs[0][0] ==
Dataset(
+ assert hook_lineage_collector.collected_datasets.outputs[0].dataset ==
Dataset(
uri=f"file://{local_path}/test.log",
)