This is an automated email from the ASF dual-hosted git repository.

mobuchowski pushed a commit to branch aip-62/object-storage
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit ae8432f1f33501a5c2143a95cebf9910ffd852a0
Author: Maciej Obuchowski <[email protected]>
AuthorDate: Tue Jul 16 11:34:10 2024 +0200

    openlineage: add support for hook lineage for Object Store
    
    Signed-off-by: Maciej Obuchowski <[email protected]>
---
 airflow/io/path.py       | 61 +++++++++++++++++++++++++++++++++++++++++++--
 tests/io/test_path.py    | 42 +++++++++++++++++++++++--------
 tests/io/test_wrapper.py | 65 ++++++++++++++++++++++++++++++++++++++++++++++++
 3 files changed, 155 insertions(+), 13 deletions(-)

diff --git a/airflow/io/path.py b/airflow/io/path.py
index 5f39782e2a..5cd64a9d54 100644
--- a/airflow/io/path.py
+++ b/airflow/io/path.py
@@ -27,8 +27,10 @@ from fsspec.utils import stringify_path
 from upath.implementations.cloud import CloudPath
 from upath.registry import get_upath_class
 
-from airflow.io.store import attach
+from airflow.io.store import ObjectStore, attach
 from airflow.io.utils.stat import stat_result
+from airflow.lineage.hook import get_hook_lineage_collector
+from airflow.utils.log.logging_mixin import LoggingMixin
 
 if typing.TYPE_CHECKING:
     from fsspec import AbstractFileSystem
@@ -39,6 +41,80 @@ PT = typing.TypeVar("PT", bound="ObjectStoragePath")
 default = "file"
 
 
+class TrackingFileWrapper(LoggingMixin):
+    """
+    Wrap a file-like object and report hook lineage for reads and writes made through it.
+
+    The first read-style call (``read``, ``readline``, ``readlines``, ``readinto``) or iteration
+    reports the wrapped path as an input dataset; the first write-style call (``write``,
+    ``writelines``) reports it as an output dataset. Each direction is reported at most once per
+    opened file, so chunked I/O does not add a collector entry per call. All other attributes are
+    delegated unchanged to the wrapped object.
+    """
+
+    _READ_METHODS = frozenset({"read", "readline", "readlines", "readinto"})
+    _WRITE_METHODS = frozenset({"write", "writelines"})
+
+    def __init__(self, path: ObjectStoragePath, obj):
+        super().__init__()
+        self._path: ObjectStoragePath = path
+        self._obj = obj
+        self._input_recorded = False
+        self._output_recorded = False
+
+    def _record_input(self) -> None:
+        """Report the wrapped path as a lineage input the first time it is read."""
+        if not self._input_recorded:
+            self._input_recorded = True
+            get_hook_lineage_collector().add_input_dataset(context=self._path.store, uri=str(self._path))
+
+    def _record_output(self) -> None:
+        """Report the wrapped path as a lineage output the first time it is written."""
+        if not self._output_recorded:
+            self._output_recorded = True
+            get_hook_lineage_collector().add_output_dataset(context=self._path.store, uri=str(self._path))
+
+    def __getattr__(self, name):
+        # Only reached for attributes not found on the wrapper itself. Guard ``_obj`` so an instance
+        # created without ``__init__`` (e.g. by copy or pickle) raises instead of recursing forever.
+        if name == "_obj":
+            raise AttributeError(name)
+        attr = getattr(self._obj, name)
+        if not callable(attr):
+            return attr
+
+        # Wrap methods so read/write calls are reported before being delegated.
+        def wrapper(*args, **kwargs):
+            self.log.debug("Calling method: %s", name)
+            if name in self._READ_METHODS:
+                self._record_input()
+            elif name in self._WRITE_METHODS:
+                self._record_output()
+            return attr(*args, **kwargs)
+
+        return wrapper
+
+    def __iter__(self):
+        # Special methods bypass ``__getattr__``; without this, ``for line in f`` falls back to the
+        # sequence protocol via ``__getitem__`` and fails on file objects.
+        iterator = iter(self._obj)
+        self._record_input()
+        return iterator
+
+    def __getitem__(self, key):
+        # Intercept item access
+        self.log.debug("Accessing item: %s", key)
+        return self._obj[key]
+
+    def __enter__(self):
+        self._obj.__enter__()
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        # Return the wrapped object's result so its exception-suppression decision is preserved.
+        return self._obj.__exit__(exc_type, exc_val, exc_tb)
+
+
 class ObjectStoragePath(CloudPath):
     """A path-like object for object storage."""
 
@@ -94,6 +170,20 @@ class ObjectStoragePath(CloudPath):
             and self.storage_options.get("conn_id") == other.storage_options.get("conn_id")
         )
 
+    @property
+    def store(self) -> ObjectStore:
+        """
+        Return an ``ObjectStore`` for this path's protocol, connection, filesystem and options.
+
+        A new instance is built on every access; it is passed as the lineage context for this path.
+        """
+        return ObjectStore(
+            protocol=self.protocol,
+            conn_id=self.storage_options.get("conn_id"),
+            fs=self.fs,
+            storage_options=self._storage_options,
+        )
+
     @property
     def container(self) -> str:
         return self.bucket
@@ -121,7 +211,8 @@ class ObjectStoragePath(CloudPath):
     def open(self, mode="r", **kwargs):
         """Open the file pointed to by this path."""
         kwargs.setdefault("block_size", kwargs.pop("buffering", None))
-        return self.fs.open(self.path, mode=mode, **kwargs)
+        # Wrap the file object so reads and writes made through it are reported as hook lineage.
+        return TrackingFileWrapper(self, self.fs.open(self.path, mode=mode, **kwargs))
 
     def stat(self) -> stat_result:  # type: ignore[override]
         """Call ``stat`` and return the result."""
@@ -276,6 +367,12 @@ class ObjectStoragePath(CloudPath):
         if isinstance(dst, str):
             dst = ObjectStoragePath(dst)
 
+        # Report full URIs via str(); ``.path`` is the scheme-less filesystem path.
+        # NOTE(review): if a branch below copies by streaming through ``open()``, TrackingFileWrapper
+        # also reports per-file entries for it - confirm this double reporting is intended.
+        get_hook_lineage_collector().add_input_dataset(context=self.store, uri=str(self))
+        get_hook_lineage_collector().add_output_dataset(context=dst.store, uri=str(dst))
+
         # same -> same
         if self.samestore(dst):
             self.fs.copy(self.path, dst.path, recursive=recursive, **kwargs)
@@ -339,6 +436,9 @@ class ObjectStoragePath(CloudPath):
             path = ObjectStoragePath(path)
 
         if self.samestore(path):
+            # The move source is a lineage input and the destination is a lineage output.
+            get_hook_lineage_collector().add_input_dataset(context=self.store, uri=str(self))
+            get_hook_lineage_collector().add_output_dataset(context=path.store, uri=str(path))
             return self.fs.move(self.path, path.path, recursive=recursive, **kwargs)
 
         # non-local copy
diff --git a/tests/io/test_path.py b/tests/io/test_path.py
index 1ccdfbfb79..3238a4852b 100644
--- a/tests/io/test_path.py
+++ b/tests/io/test_path.py
@@ -267,9 +267,11 @@ class TestFs:
         with pytest.raises(ValueError):
             o1.relative_to(o3)
 
-    def test_move_local(self):
-        _from = ObjectStoragePath(f"file:///tmp/{str(uuid.uuid4())}")
-        _to = ObjectStoragePath(f"file:///tmp/{str(uuid.uuid4())}")
+    def test_move_local(self, hook_lineage_collector):
+        _from_path = f"file:///tmp/{str(uuid.uuid4())}"
+        _to_path = f"file:///tmp/{str(uuid.uuid4())}"
+        _from = ObjectStoragePath(_from_path)
+        _to = ObjectStoragePath(_to_path)
 
         _from.touch()
         _from.move(_to)
@@ -278,13 +280,19 @@ class TestFs:
 
         _to.unlink()
 
-    def test_move_remote(self):
+        assert len(hook_lineage_collector.collected_datasets.inputs) == 1
+        assert len(hook_lineage_collector.collected_datasets.outputs) == 1
+        assert hook_lineage_collector.collected_datasets.inputs[0][0] == Dataset(uri=_from_path)
+        assert hook_lineage_collector.collected_datasets.outputs[0][0] == Dataset(uri=_to_path)
+
+    def test_move_remote(self, hook_lineage_collector):
         attach("fakefs", fs=FakeRemoteFileSystem())
 
-        _from = ObjectStoragePath(f"file:///tmp/{str(uuid.uuid4())}")
-        print(_from)
-        _to = ObjectStoragePath(f"fakefs:///tmp/{str(uuid.uuid4())}")
-        print(_to)
+        _from_path = f"file:///tmp/{str(uuid.uuid4())}"
+        _to_path = f"fakefs:///tmp/{str(uuid.uuid4())}"
+
+        _from = ObjectStoragePath(_from_path)
+        _to = ObjectStoragePath(_to_path)
 
         _from.touch()
         _from.move(_to)
@@ -293,7 +301,12 @@ class TestFs:
 
         _to.unlink()
 
-    def test_copy_remote_remote(self):
+        assert len(hook_lineage_collector.collected_datasets.inputs) == 1
+        assert len(hook_lineage_collector.collected_datasets.outputs) == 1
+        assert hook_lineage_collector.collected_datasets.inputs[0][0] == Dataset(uri=_from_path)
+        assert hook_lineage_collector.collected_datasets.outputs[0][0] == Dataset(uri=_to_path)
+
+    def test_copy_remote_remote(self, hook_lineage_collector):
         attach("ffs", fs=FakeRemoteFileSystem(skip_instance_cache=True))
         attach("ffs2", fs=FakeRemoteFileSystem(skip_instance_cache=True))
 
@@ -301,13 +314,15 @@ class TestFs:
         dir_dst = f"bucket2/{str(uuid.uuid4())}"
         key = "foo/bar/baz.txt"
 
-        _from = ObjectStoragePath(f"ffs://{dir_src}")
+        _from_path = f"ffs://{dir_src}"
+        _from = ObjectStoragePath(_from_path)
         _from_file = _from / key
         _from_file.touch()
         assert _from.bucket == "bucket1"
         assert _from_file.exists()
 
-        _to = ObjectStoragePath(f"ffs2://{dir_dst}")
+        _to_path = f"ffs2://{dir_dst}"
+        _to = ObjectStoragePath(_to_path)
         _from.copy(_to)
 
         assert _to.bucket == "bucket2"
@@ -319,6 +334,11 @@ class TestFs:
         _from.rmdir(recursive=True)
         _to.rmdir(recursive=True)
 
+        assert len(hook_lineage_collector.collected_datasets.inputs) == 1
+        assert len(hook_lineage_collector.collected_datasets.outputs) == 1
+        assert hook_lineage_collector.collected_datasets.inputs[0][0] == Dataset(uri=_from_path)
+        assert hook_lineage_collector.collected_datasets.outputs[0][0] == Dataset(uri=_to_path)
+
     def test_serde_objectstoragepath(self):
         path = "file:///bucket/key/part1/part2"
         o = ObjectStoragePath(path)
diff --git a/tests/io/test_wrapper.py b/tests/io/test_wrapper.py
new file mode 100644
index 0000000000..4c7023b3dd
--- /dev/null
+++ b/tests/io/test_wrapper.py
@@ -0,0 +1,59 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import uuid
+
+from airflow.datasets import Dataset
+from airflow.io.path import ObjectStoragePath
+
+
+def test_wrapper_catches_reads_writes(hook_lineage_collector):
+    uri = f"file:///tmp/{str(uuid.uuid4())}"
+    path = ObjectStoragePath(uri)
+    file = path.open("w")
+    file.write("aaa")
+    file.close()
+
+    assert len(hook_lineage_collector.collected_datasets.outputs) == 1
+    assert hook_lineage_collector.collected_datasets.outputs[0][0] == Dataset(uri=uri)
+
+    file = path.open("r")
+    file.read()
+    file.close()
+
+    path.unlink(missing_ok=True)
+
+    assert len(hook_lineage_collector.collected_datasets.inputs) == 1
+    assert hook_lineage_collector.collected_datasets.inputs[0][0] == Dataset(uri=uri)
+
+
+def test_wrapper_works_with_contextmanager(hook_lineage_collector):
+    uri = f"file:///tmp/{str(uuid.uuid4())}"
+    path = ObjectStoragePath(uri)
+    with path.open("w") as file:
+        file.write("asdf")
+
+    assert len(hook_lineage_collector.collected_datasets.outputs) == 1
+    assert hook_lineage_collector.collected_datasets.outputs[0][0] == Dataset(uri=uri)
+
+    with path.open("r") as file:
+        file.read()
+    path.unlink(missing_ok=True)
+
+    assert len(hook_lineage_collector.collected_datasets.inputs) == 1
+    assert hook_lineage_collector.collected_datasets.inputs[0][0] == Dataset(uri=uri)

Reply via email to