This is an automated email from the ASF dual-hosted git repository. mobuchowski pushed a commit to branch aip-62/object-storage in repository https://gitbox.apache.org/repos/asf/airflow.git
commit ae8432f1f33501a5c2143a95cebf9910ffd852a0 Author: Maciej Obuchowski <[email protected]> AuthorDate: Tue Jul 16 11:34:10 2024 +0200 openlineage: add support for hook lineage for Object Store Signed-off-by: Maciej Obuchowski <[email protected]> --- airflow/io/path.py | 61 +++++++++++++++++++++++++++++++++++++++++++-- tests/io/test_path.py | 42 +++++++++++++++++++++++-------- tests/io/test_wrapper.py | 65 ++++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 155 insertions(+), 13 deletions(-) diff --git a/airflow/io/path.py b/airflow/io/path.py index 5f39782e2a..5cd64a9d54 100644 --- a/airflow/io/path.py +++ b/airflow/io/path.py @@ -27,8 +27,10 @@ from fsspec.utils import stringify_path from upath.implementations.cloud import CloudPath from upath.registry import get_upath_class -from airflow.io.store import attach +from airflow.io.store import ObjectStore, attach from airflow.io.utils.stat import stat_result +from airflow.lineage.hook import get_hook_lineage_collector +from airflow.utils.log.logging_mixin import LoggingMixin if typing.TYPE_CHECKING: from fsspec import AbstractFileSystem @@ -39,6 +41,47 @@ PT = typing.TypeVar("PT", bound="ObjectStoragePath") default = "file" +class TrackingFileWrapper(LoggingMixin): + """Wrapper that tracks file operations to intercept lineage.""" + + def __init__(self, path: ObjectStoragePath, obj): + super().__init__() + self._path: ObjectStoragePath = path + self._obj = obj + + def __getattr__(self, name): + attr = getattr(self._obj, name) + if callable(attr): + # If the attribute is a method, wrap it in another method to intercept the call + def wrapper(*args, **kwargs): + self.log.error("Calling method: %s", name) + if name == "read": + get_hook_lineage_collector().add_input_dataset( + context=self._path.store, uri=str(self._path) + ) + elif name == "write": + get_hook_lineage_collector().add_output_dataset( + context=self._path.store, uri=str(self._path) + ) + result = attr(*args, **kwargs) + return result + + return wrapper + return attr + + def __getitem__(self, key): + # Intercept item access + self.log.error("Accessing item: %s", key) + return self._obj[key] + + def __enter__(self): + self._obj.__enter__() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self._obj.__exit__(exc_type, exc_val, exc_tb) + + class ObjectStoragePath(CloudPath): """A path-like object for object storage.""" @@ -94,6 +137,15 @@ class ObjectStoragePath(CloudPath): and self.storage_options.get("conn_id") == other.storage_options.get("conn_id") ) + @property + def store(self) -> ObjectStore: + return ObjectStore( + protocol=self.protocol, + conn_id=self.storage_options.get("conn_id"), + fs=self.fs, + storage_options=self._storage_options, + ) + @property def container(self) -> str: return self.bucket @@ -121,7 +173,7 @@ class ObjectStoragePath(CloudPath): def open(self, mode="r", **kwargs): """Open the file pointed to by this path.""" kwargs.setdefault("block_size", kwargs.pop("buffering", None)) - return self.fs.open(self.path, mode=mode, **kwargs) + return TrackingFileWrapper(self, self.fs.open(self.path, mode=mode, **kwargs)) def stat(self) -> stat_result: # type: ignore[override] """Call ``stat`` and return the result.""" @@ -276,6 +328,9 @@ class ObjectStoragePath(CloudPath): if isinstance(dst, str): dst = ObjectStoragePath(dst) + get_hook_lineage_collector().add_input_dataset(context=self.store, uri=str(self.path)) + get_hook_lineage_collector().add_output_dataset(context=dst.store, uri=str(dst.path)) + # same -> same if self.samestore(dst): self.fs.copy(self.path, dst.path, recursive=recursive, **kwargs) @@ -339,6 +394,8 @@ class ObjectStoragePath(CloudPath): path = ObjectStoragePath(path) if self.samestore(path): + get_hook_lineage_collector().add_input_dataset(context=self.store, uri=str(self)) + get_hook_lineage_collector().add_input_dataset(context=path.store, uri=str(path)) return self.fs.move(self.path, path.path, recursive=recursive, **kwargs) # non-local copy diff --git a/tests/io/test_path.py b/tests/io/test_path.py index 1ccdfbfb79..3238a4852b 100644 --- a/tests/io/test_path.py +++ b/tests/io/test_path.py @@ -267,9 +267,11 @@ class TestFs: with pytest.raises(ValueError): o1.relative_to(o3) - def test_move_local(self): - _from = ObjectStoragePath(f"file:///tmp/{str(uuid.uuid4())}") - _to = ObjectStoragePath(f"file:///tmp/{str(uuid.uuid4())}") + def test_move_local(self, hook_lineage_collector): + _from_path = f"file:///tmp/{str(uuid.uuid4())}" + _to_path = f"file:///tmp/{str(uuid.uuid4())}" + _from = ObjectStoragePath(_from_path) + _to = ObjectStoragePath(_to_path) _from.touch() _from.move(_to) @@ -278,13 +280,19 @@ class TestFs: _to.unlink() - def test_move_remote(self): + assert len(hook_lineage_collector.collected_datasets.inputs) == 1 + assert len(hook_lineage_collector.collected_datasets.outputs) == 1 + assert hook_lineage_collector.collected_datasets.inputs[0][0] == Dataset(uri=_from_path) + assert hook_lineage_collector.collected_datasets.outputs[0][0] == Dataset(uri=_to_path) + + def test_move_remote(self, hook_lineage_collector): attach("fakefs", fs=FakeRemoteFileSystem()) - _from = ObjectStoragePath(f"file:///tmp/{str(uuid.uuid4())}") - print(_from) - _to = ObjectStoragePath(f"fakefs:///tmp/{str(uuid.uuid4())}") - print(_to) + _from_path = f"file:///tmp/{str(uuid.uuid4())}" + _to_path = f"fakefs:///tmp/{str(uuid.uuid4())}" + + _from = ObjectStoragePath(_from_path) + _to = ObjectStoragePath(_to_path) _from.touch() _from.move(_to) @@ -293,7 +301,12 @@ class TestFs: _to.unlink() - def test_copy_remote_remote(self): + assert len(hook_lineage_collector.collected_datasets.inputs) == 1 + assert len(hook_lineage_collector.collected_datasets.outputs) == 1 + assert hook_lineage_collector.collected_datasets.inputs[0][0] == Dataset(uri=_from_path) + assert hook_lineage_collector.collected_datasets.outputs[0][0] == Dataset(uri=_to_path) + + def test_copy_remote_remote(self, hook_lineage_collector): attach("ffs", fs=FakeRemoteFileSystem(skip_instance_cache=True)) attach("ffs2", fs=FakeRemoteFileSystem(skip_instance_cache=True)) @@ -301,13 +314,15 @@ class TestFs: dir_dst = f"bucket2/{str(uuid.uuid4())}" key = "foo/bar/baz.txt" - _from = ObjectStoragePath(f"ffs://{dir_src}") + _from_path = f"ffs://{dir_src}" + _from = ObjectStoragePath(_from_path) _from_file = _from / key _from_file.touch() assert _from.bucket == "bucket1" assert _from_file.exists() - _to = ObjectStoragePath(f"ffs2://{dir_dst}") + _to_path = f"ffs2://{dir_dst}" + _to = ObjectStoragePath(_to_path) _from.copy(_to) assert _to.bucket == "bucket2" @@ -319,6 +334,11 @@ class TestFs: _from.rmdir(recursive=True) _to.rmdir(recursive=True) + assert len(hook_lineage_collector.collected_datasets.inputs) == 1 + assert len(hook_lineage_collector.collected_datasets.outputs) == 1 + assert hook_lineage_collector.collected_datasets.inputs[0][0] == Dataset(uri=_from_path) + assert hook_lineage_collector.collected_datasets.outputs[0][0] == Dataset(uri=_to_path) + def test_serde_objectstoragepath(self): path = "file:///bucket/key/part1/part2" o = ObjectStoragePath(path) diff --git a/tests/io/test_wrapper.py b/tests/io/test_wrapper.py new file mode 100644 index 0000000000..4c7023b3dd --- /dev/null +++ b/tests/io/test_wrapper.py @@ -0,0 +1,65 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import uuid +from unittest.mock import patch + +from airflow.datasets import Dataset +from airflow.io.path import ObjectStoragePath +from airflow.lineage import hook + + +@patch("airflow.providers_manager.ProvidersManager") +def test_wrapper_catches_reads_writes(providers_manager): + providers_manager.return_value._dataset_factories = lambda x: Dataset(uri=x) + uri = f"file:///tmp/{str(uuid.uuid4())}" + path = ObjectStoragePath(uri) + file = path.open("w") + file.write("aaa") + file.close() + + assert len(hook.get_hook_lineage_collector().outputs) == 1 + assert hook.get_hook_lineage_collector().outputs[0][0] == Dataset(uri=uri) + + file = path.open("r") + file.read() + file.close() + + path.unlink(missing_ok=True) + + assert len(hook.get_hook_lineage_collector().inputs) == 1 + assert hook.get_hook_lineage_collector().inputs[0][0] == Dataset(uri=uri) + + +@patch("airflow.providers_manager.ProvidersManager") +def test_wrapper_works_with_contextmanager(providers_manager): + providers_manager.return_value._dataset_factories = lambda x: Dataset(uri=x) + uri = f"file:///tmp/{str(uuid.uuid4())}" + path = ObjectStoragePath(uri) + with path.open("w") as file: + file.write("asdf") + + assert len(hook.get_hook_lineage_collector().outputs) == 1 + assert hook.get_hook_lineage_collector().outputs[0][0] == Dataset(uri=uri) + + with path.open("r") as file: + file.read() + path.unlink(missing_ok=True) + + assert len(hook.get_hook_lineage_collector().inputs) == 1 + assert hook.get_hook_lineage_collector().inputs[0][0] == Dataset(uri=uri)
