This is an automated email from the ASF dual-hosted git repository. mobuchowski pushed a commit to branch aip-62/object-storage in repository https://gitbox.apache.org/repos/asf/airflow.git
commit 945e7c4ff9d22eb04ee44a0ed41f8eb1776740c2
Author: Maciej Obuchowski <[email protected]>
AuthorDate: Tue Jul 16 11:34:10 2024 +0200

    openlineage: add support for hook lineage for Object Store

    Signed-off-by: Maciej Obuchowski <[email protected]>
---
 airflow/io/path.py       | 91 ++++++++++++++++++++++++++++++++++++++++++++++--
 tests/io/test_wrapper.py | 65 ++++++++++++++++++++++++++++++++++
 2 files changed, 154 insertions(+), 2 deletions(-)

diff --git a/airflow/io/path.py b/airflow/io/path.py
index 5f39782e2a..5cd64a9d54 100644
--- a/airflow/io/path.py
+++ b/airflow/io/path.py
@@ -27,8 +27,10 @@ from fsspec.utils import stringify_path
 from upath.implementations.cloud import CloudPath
 from upath.registry import get_upath_class
 
-from airflow.io.store import attach
+from airflow.io.store import ObjectStore, attach
 from airflow.io.utils.stat import stat_result
+from airflow.lineage.hook import get_hook_lineage_collector
+from airflow.utils.log.logging_mixin import LoggingMixin
 
 if typing.TYPE_CHECKING:
     from fsspec import AbstractFileSystem
@@ -39,6 +41,76 @@ PT = typing.TypeVar("PT", bound="ObjectStoragePath")
 default = "file"
 
 
+class TrackingFileWrapper(LoggingMixin):
+    """
+    Wrapper that tracks file operations to intercept lineage.
+
+    Calls to reading methods register the wrapped path as a hook-lineage input dataset and
+    calls to writing methods register it as an output dataset; any other attribute access
+    is delegated unchanged to the wrapped file object.
+    """
+
+    _READ_METHODS = frozenset({"read", "readinto", "readline", "readlines"})
+    _WRITE_METHODS = frozenset({"write", "writelines"})
+
+    def __init__(self, path: ObjectStoragePath, obj):
+        super().__init__()
+        self._path: ObjectStoragePath = path
+        self._obj = obj
+        # Each direction is reported at most once per opened file, however many calls are made.
+        self._input_recorded = False
+        self._output_recorded = False
+
+    def _add_input(self) -> None:
+        """Record the wrapped path as a hook-lineage input dataset (once per wrapper)."""
+        if not self._input_recorded:
+            get_hook_lineage_collector().add_input_dataset(context=self._path.store, uri=str(self._path))
+            self._input_recorded = True
+
+    def _add_output(self) -> None:
+        """Record the wrapped path as a hook-lineage output dataset (once per wrapper)."""
+        if not self._output_recorded:
+            get_hook_lineage_collector().add_output_dataset(context=self._path.store, uri=str(self._path))
+            self._output_recorded = True
+
+    def __getattr__(self, name):
+        attr = getattr(self._obj, name)
+        if callable(attr):
+            # If the attribute is a method, wrap it in another method to intercept the call
+            def wrapper(*args, **kwargs):
+                self.log.debug("Calling method: %s", name)
+                result = attr(*args, **kwargs)
+                # Lineage is recorded only after the call succeeded.
+                if name in self._READ_METHODS:
+                    self._add_input()
+                elif name in self._WRITE_METHODS:
+                    self._add_output()
+                return result
+
+            return wrapper
+        return attr
+
+    def __getitem__(self, key):
+        # Intercept item access
+        self.log.debug("Accessing item: %s", key)
+        return self._obj[key]
+
+    def __iter__(self):
+        # Dunder methods are looked up on the type, not through __getattr__, so iteration
+        # must be delegated explicitly (otherwise Python falls back to __getitem__(0), ...).
+        # Iterating a file reads it, so the path is recorded as an input.
+        iterator = iter(self._obj)
+        self._add_input()
+        return iterator
+
+    def __enter__(self):
+        self._obj.__enter__()
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        return self._obj.__exit__(exc_type, exc_val, exc_tb)
+
+
 class ObjectStoragePath(CloudPath):
     """A path-like object for object storage."""
 
@@ -94,6 +166,16 @@ class ObjectStoragePath(CloudPath):
             and self.storage_options.get("conn_id") == other.storage_options.get("conn_id")
         )
 
+    @property
+    def store(self) -> ObjectStore:
+        """Return a new ``ObjectStore`` built from this path's protocol, connection and filesystem."""
+        return ObjectStore(
+            protocol=self.protocol,
+            conn_id=self.storage_options.get("conn_id"),
+            fs=self.fs,
+            storage_options=self._storage_options,
+        )
+
     @property
     def container(self) -> str:
         return self.bucket
@@ -121,7 +203,7 @@ class ObjectStoragePath(CloudPath):
     def open(self, mode="r", **kwargs):
         """Open the file pointed to by this path."""
         kwargs.setdefault("block_size", kwargs.pop("buffering", None))
-        return self.fs.open(self.path, mode=mode, **kwargs)
+        return TrackingFileWrapper(self, self.fs.open(self.path, mode=mode, **kwargs))
 
     def stat(self) -> stat_result:  # type: ignore[override]
         """Call ``stat`` and return the result."""
@@ -276,6 +358,9 @@ class ObjectStoragePath(CloudPath):
         if isinstance(dst, str):
             dst = ObjectStoragePath(dst)
 
+        get_hook_lineage_collector().add_input_dataset(context=self.store, uri=str(self))
+        get_hook_lineage_collector().add_output_dataset(context=dst.store, uri=str(dst))
+
         # same -> same
         if self.samestore(dst):
             self.fs.copy(self.path, dst.path, recursive=recursive, **kwargs)
@@ -339,6 +424,8 @@ class ObjectStoragePath(CloudPath):
             path = ObjectStoragePath(path)
 
         if self.samestore(path):
+            get_hook_lineage_collector().add_input_dataset(context=self.store, uri=str(self))
+            get_hook_lineage_collector().add_output_dataset(context=path.store, uri=str(path))
             return self.fs.move(self.path, path.path, recursive=recursive, **kwargs)
 
         # non-local copy
diff --git a/tests/io/test_wrapper.py b/tests/io/test_wrapper.py
new file mode 100644
index 0000000000..4c7023b3dd
--- /dev/null
+++ b/tests/io/test_wrapper.py
@@ -0,0 +1,65 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import uuid
+from unittest.mock import patch
+
+from airflow.datasets import Dataset
+from airflow.io.path import ObjectStoragePath
+from airflow.lineage import hook
+
+
+@patch("airflow.providers_manager.ProvidersManager")
+def test_wrapper_catches_reads_writes(providers_manager):
+    providers_manager.return_value._dataset_factories = lambda x: Dataset(uri=x)
+    uri = f"file:///tmp/{uuid.uuid4()}"
+    path = ObjectStoragePath(uri)
+    file = path.open("w")
+    file.write("aaa")
+    file.close()
+
+    outputs = hook.get_hook_lineage_collector().outputs  # may also hold other tests' records
+    assert [entry[0] for entry in outputs].count(Dataset(uri=uri)) == 1
+
+    file = path.open("r")
+    file.read()
+    file.close()
+
+    path.unlink(missing_ok=True)
+
+    inputs = hook.get_hook_lineage_collector().inputs  # may also hold other tests' records
+    assert [entry[0] for entry in inputs].count(Dataset(uri=uri)) == 1
+
+
+@patch("airflow.providers_manager.ProvidersManager")
+def test_wrapper_works_with_contextmanager(providers_manager):
+    providers_manager.return_value._dataset_factories = lambda x: Dataset(uri=x)
+    uri = f"file:///tmp/{uuid.uuid4()}"
+    path = ObjectStoragePath(uri)
+    with path.open("w") as file:
+        file.write("asdf")
+
+    outputs = hook.get_hook_lineage_collector().outputs  # may also hold other tests' records
+    assert [entry[0] for entry in outputs].count(Dataset(uri=uri)) == 1
+
+    with path.open("r") as file:
+        file.read()
+    path.unlink(missing_ok=True)
+
+    inputs = hook.get_hook_lineage_collector().inputs  # may also hold other tests' records
+    assert [entry[0] for entry in inputs].count(Dataset(uri=uri)) == 1
