This is an automated email from the ASF dual-hosted git repository.

mobuchowski pushed a commit to branch aip-62/object-storage
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit 945e7c4ff9d22eb04ee44a0ed41f8eb1776740c2
Author: Maciej Obuchowski <[email protected]>
AuthorDate: Tue Jul 16 11:34:10 2024 +0200

    openlineage: add support for hook lineage for Object Store
    
    Signed-off-by: Maciej Obuchowski <[email protected]>
---
 airflow/io/path.py       | 61 +++++++++++++++++++++++++++++++++++++++++++--
 tests/io/test_wrapper.py | 65 ++++++++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 124 insertions(+), 2 deletions(-)

diff --git a/airflow/io/path.py b/airflow/io/path.py
index 5f39782e2a..5cd64a9d54 100644
--- a/airflow/io/path.py
+++ b/airflow/io/path.py
@@ -27,8 +27,10 @@ from fsspec.utils import stringify_path
 from upath.implementations.cloud import CloudPath
 from upath.registry import get_upath_class
 
-from airflow.io.store import attach
+from airflow.io.store import ObjectStore, attach
 from airflow.io.utils.stat import stat_result
+from airflow.lineage.hook import get_hook_lineage_collector
+from airflow.utils.log.logging_mixin import LoggingMixin
 
 if typing.TYPE_CHECKING:
     from fsspec import AbstractFileSystem
@@ -39,6 +41,47 @@ PT = typing.TypeVar("PT", bound="ObjectStoragePath")
 default = "file"
 
 
+class TrackingFileWrapper(LoggingMixin):
+    """Wrapper that tracks file operations to intercept lineage."""
+
+    def __init__(self, path: ObjectStoragePath, obj):
+        super().__init__()
+        self._path: ObjectStoragePath = path
+        self._obj = obj
+
+    def __getattr__(self, name):
+        attr = getattr(self._obj, name)
+        if callable(attr):
+            # If the attribute is a method, wrap it in another method to intercept the call
+            def wrapper(*args, **kwargs):
+                self.log.error("Calling method: %s", name)
+                if name == "read":
+                    get_hook_lineage_collector().add_input_dataset(
+                        context=self._path.store, uri=str(self._path)
+                    )
+                elif name == "write":
+                    get_hook_lineage_collector().add_output_dataset(
+                        context=self._path.store, uri=str(self._path)
+                    )
+                result = attr(*args, **kwargs)
+                return result
+
+            return wrapper
+        return attr
+
+    def __getitem__(self, key):
+        # Intercept item access
+        self.log.error("Accessing item: %s", key)
+        return self._obj[key]
+
+    def __enter__(self):
+        self._obj.__enter__()
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        self._obj.__exit__(exc_type, exc_val, exc_tb)
+
+
 class ObjectStoragePath(CloudPath):
     """A path-like object for object storage."""
 
@@ -94,6 +137,15 @@ class ObjectStoragePath(CloudPath):
             and self.storage_options.get("conn_id") == other.storage_options.get("conn_id")
         )
 
+    @property
+    def store(self) -> ObjectStore:
+        return ObjectStore(
+            protocol=self.protocol,
+            conn_id=self.storage_options.get("conn_id"),
+            fs=self.fs,
+            storage_options=self._storage_options,
+        )
+
     @property
     def container(self) -> str:
         return self.bucket
@@ -121,7 +173,7 @@ class ObjectStoragePath(CloudPath):
     def open(self, mode="r", **kwargs):
         """Open the file pointed to by this path."""
         kwargs.setdefault("block_size", kwargs.pop("buffering", None))
-        return self.fs.open(self.path, mode=mode, **kwargs)
+        return TrackingFileWrapper(self, self.fs.open(self.path, mode=mode, **kwargs))
 
     def stat(self) -> stat_result:  # type: ignore[override]
         """Call ``stat`` and return the result."""
@@ -276,6 +328,9 @@ class ObjectStoragePath(CloudPath):
         if isinstance(dst, str):
             dst = ObjectStoragePath(dst)
 
+        get_hook_lineage_collector().add_input_dataset(context=self.store, uri=str(self.path))
+        get_hook_lineage_collector().add_output_dataset(context=dst.store, uri=str(dst.path))
+
         # same -> same
         if self.samestore(dst):
             self.fs.copy(self.path, dst.path, recursive=recursive, **kwargs)
@@ -339,6 +394,8 @@ class ObjectStoragePath(CloudPath):
             path = ObjectStoragePath(path)
 
         if self.samestore(path):
+            get_hook_lineage_collector().add_input_dataset(context=self.store, uri=str(self))
+            get_hook_lineage_collector().add_input_dataset(context=path.store, uri=str(path))
             return self.fs.move(self.path, path.path, recursive=recursive, **kwargs)
 
         # non-local copy
diff --git a/tests/io/test_wrapper.py b/tests/io/test_wrapper.py
new file mode 100644
index 0000000000..4c7023b3dd
--- /dev/null
+++ b/tests/io/test_wrapper.py
@@ -0,0 +1,65 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import uuid
+from unittest.mock import patch
+
+from airflow.datasets import Dataset
+from airflow.io.path import ObjectStoragePath
+from airflow.lineage import hook
+
+
+@patch("airflow.providers_manager.ProvidersManager")
+def test_wrapper_catches_reads_writes(providers_manager):
+    providers_manager.return_value._dataset_factories = lambda x: Dataset(uri=x)
+    uri = f"file:///tmp/{str(uuid.uuid4())}"
+    path = ObjectStoragePath(uri)
+    file = path.open("w")
+    file.write("aaa")
+    file.close()
+
+    assert len(hook.get_hook_lineage_collector().outputs) == 1
+    assert hook.get_hook_lineage_collector().outputs[0][0] == Dataset(uri=uri)
+
+    file = path.open("r")
+    file.read()
+    file.close()
+
+    path.unlink(missing_ok=True)
+
+    assert len(hook.get_hook_lineage_collector().inputs) == 1
+    assert hook.get_hook_lineage_collector().inputs[0][0] == Dataset(uri=uri)
+
+
+@patch("airflow.providers_manager.ProvidersManager")
+def test_wrapper_works_with_contextmanager(providers_manager):
+    providers_manager.return_value._dataset_factories = lambda x: Dataset(uri=x)
+    uri = f"file:///tmp/{str(uuid.uuid4())}"
+    path = ObjectStoragePath(uri)
+    with path.open("w") as file:
+        file.write("asdf")
+
+    assert len(hook.get_hook_lineage_collector().outputs) == 1
+    assert hook.get_hook_lineage_collector().outputs[0][0] == Dataset(uri=uri)
+
+    with path.open("r") as file:
+        file.read()
+    path.unlink(missing_ok=True)
+
+    assert len(hook.get_hook_lineage_collector().inputs) == 1
+    assert hook.get_hook_lineage_collector().inputs[0][0] == Dataset(uri=uri)

Reply via email to