This is an automated email from the ASF dual-hosted git repository.

bolke pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 08bc0f4490 Update ObjectStoragePath for universal_pathlib>=v0.2.1 (#37524)
08bc0f4490 is described below

commit 08bc0f44904fe0d8bc8779e0e892e4d42def3983
Author: Andreas Poehlmann <[email protected]>
AuthorDate: Tue Feb 20 10:53:49 2024 +0100

    Update ObjectStoragePath for universal_pathlib>=v0.2.1 (#37524)
    
    This updates ObjectStoragePath to be compatible with universal_pathlib >= 0.2.1 which in turn makes it compatible with Python 3.12+.
---
 airflow/io/path.py                             | 169 +++++++++----------------
 airflow/providers/common/io/xcom/backend.py    |   4 +-
 pyproject.toml                                 |   9 +-
 tests/io/test_path.py                          | 135 +++++++++++++-------
 tests/providers/common/io/xcom/test_backend.py |   2 +-
 5 files changed, 154 insertions(+), 165 deletions(-)

diff --git a/airflow/io/path.py b/airflow/io/path.py
index d65d837e7e..cb4c48c476 100644
--- a/airflow/io/path.py
+++ b/airflow/io/path.py
@@ -17,24 +17,20 @@
 from __future__ import annotations
 
 import contextlib
-import functools
 import os
 import shutil
 import typing
-from pathlib import PurePath
+from typing import Any, Mapping
 from urllib.parse import urlsplit
 
-from fsspec.core import split_protocol
 from fsspec.utils import stringify_path
-from upath.implementations.cloud import CloudPath, _CloudAccessor
+from upath.implementations.cloud import CloudPath
 from upath.registry import get_upath_class
 
 from airflow.io.store import attach
 from airflow.io.utils.stat import stat_result
 
 if typing.TYPE_CHECKING:
-    from urllib.parse import SplitResult
-
     from fsspec import AbstractFileSystem
 
 
@@ -43,124 +39,68 @@ PT = typing.TypeVar("PT", bound="ObjectStoragePath")
 default = "file"
 
 
-class _AirflowCloudAccessor(_CloudAccessor):
-    __slots__ = ("_store",)
-
-    def __init__(
-        self,
-        parsed_url: SplitResult | None,
-        conn_id: str | None = None,
-        **kwargs: typing.Any,
-    ) -> None:
-        # warning: we are not calling super().__init__ here
-        # as it will try to create a new fs from a different
-        # set if registered filesystems
-        if parsed_url and parsed_url.scheme:
-            self._store = attach(parsed_url.scheme, conn_id)
-        else:
-            self._store = attach("file", conn_id)
-
-    @property
-    def _fs(self) -> AbstractFileSystem:
-        return self._store.fs
-
-    def __eq__(self, other):
-        return isinstance(other, _AirflowCloudAccessor) and self._store == other._store
-
-
 class ObjectStoragePath(CloudPath):
     """A path-like object for object storage."""
 
-    _accessor: _AirflowCloudAccessor
-
     __version__: typing.ClassVar[int] = 1
 
-    _default_accessor = _AirflowCloudAccessor
+    _protocol_dispatch = False
 
     sep: typing.ClassVar[str] = "/"
     root_marker: typing.ClassVar[str] = "/"
 
-    _bucket: str
-    _key: str
-    _protocol: str
-    _hash: int | None
-
-    __slots__ = (
-        "_bucket",
-        "_key",
-        "_conn_id",
-        "_protocol",
-        "_hash",
-    )
-
-    def __new__(
-        cls: type[PT],
-        *args: str | os.PathLike,
-        scheme: str | None = None,
-        conn_id: str | None = None,
-        **kwargs: typing.Any,
-    ) -> PT:
-        args_list = list(args)
-
-        if args_list:
-            other = args_list.pop(0) or "."
-        else:
-            other = "."
-
-        if isinstance(other, PurePath):
-            _cls: typing.Any = type(other)
-            drv, root, parts = _cls._parse_args(args_list)
-            drv, root, parts = _cls._flavour.join_parsed_parts(
-                other._drv,  # type: ignore[attr-defined]
-                other._root,  # type: ignore[attr-defined]
-                other._parts,  # type: ignore[attr-defined]
-                drv,
-                root,
-                parts,  # type: ignore
-            )
-
-            _kwargs = getattr(other, "_kwargs", {})
-            _url = getattr(other, "_url", None)
-            other_kwargs = _kwargs.copy()
-            if _url and _url.scheme:
-                other_kwargs["url"] = _url
-            new_kwargs = _kwargs.copy()
-            new_kwargs.update(kwargs)
-
-            return _cls(_cls._format_parsed_parts(drv, root, parts, **other_kwargs), **new_kwargs)
-
-        url = stringify_path(other)
-        parsed_url: SplitResult = urlsplit(url)
-
-        if scheme:  # allow override of protocol
-            parsed_url = parsed_url._replace(scheme=scheme)
-
-        if not parsed_url.path:  # ensure path has root
-            parsed_url = parsed_url._replace(path="/")
-
-        if not parsed_url.scheme and not split_protocol(url)[0]:
-            args_list.insert(0, url)
-        else:
-            args_list.insert(0, parsed_url.path)
+    __slots__ = ("_hash_cached",)
+
+    @classmethod
+    def _transform_init_args(
+        cls,
+        args: tuple[str | os.PathLike, ...],
+        protocol: str,
+        storage_options: dict[str, Any],
+    ) -> tuple[tuple[str | os.PathLike, ...], str, dict[str, Any]]:
+        """Extract conn_id from the URL and set it as a storage option."""
+        if args:
+            arg0 = args[0]
+            parsed_url = urlsplit(stringify_path(arg0))
+            userinfo, have_info, hostinfo = parsed_url.netloc.rpartition("@")
+            if have_info:
+                storage_options.setdefault("conn_id", userinfo or None)
+                parsed_url = parsed_url._replace(netloc=hostinfo)
+            args = (parsed_url.geturl(),) + args[1:]
+            protocol = protocol or parsed_url.scheme
+        return args, protocol, storage_options
 
-        # This matches the parsing logic in urllib.parse; see:
-        # https://github.com/python/cpython/blob/46adf6b701c440e047abf925df9a75a/Lib/urllib/parse.py#L194-L203
-        userinfo, have_info, hostinfo = parsed_url.netloc.rpartition("@")
-        if have_info:
-            conn_id = conn_id or userinfo or None
-            parsed_url = parsed_url._replace(netloc=hostinfo)
+    @classmethod
+    def _parse_storage_options(
+        cls, urlpath: str, protocol: str, storage_options: Mapping[str, Any]
+    ) -> dict[str, Any]:
+        fs = attach(protocol or "file", conn_id=storage_options.get("conn_id")).fs
+        pth_storage_options = type(fs)._get_kwargs_from_urls(urlpath)
+        return {**pth_storage_options, **storage_options}
 
-        return cls._from_parts(args_list, url=parsed_url, conn_id=conn_id, **kwargs)  # type: ignore
+    @classmethod
+    def _fs_factory(
+        cls, urlpath: str, protocol: str, storage_options: Mapping[str, Any]
+    ) -> AbstractFileSystem:
+        return attach(protocol or "file", storage_options.get("conn_id")).fs
 
-    @functools.lru_cache
     def __hash__(self) -> int:
-        return hash(str(self))
+        self._hash_cached: int
+        try:
+            return self._hash_cached
+        except AttributeError:
+            self._hash_cached = hash(str(self))
+            return self._hash_cached
 
     def __eq__(self, other: typing.Any) -> bool:
         return self.samestore(other) and str(self) == str(other)
 
     def samestore(self, other: typing.Any) -> bool:
-        return isinstance(other, ObjectStoragePath) and self._accessor == other._accessor
+        return (
+            isinstance(other, ObjectStoragePath)
+            and self.protocol == other.protocol
+            and self.storage_options.get("conn_id") == other.storage_options.get("conn_id")
+        )
 
     @property
     def container(self) -> str:
@@ -186,12 +126,17 @@ class ObjectStoragePath(CloudPath):
     def namespace(self) -> str:
         return f"{self.protocol}://{self.bucket}" if self.bucket else self.protocol
 
+    def open(self, mode="r", **kwargs):
+        """Open the file pointed to by this path."""
+        kwargs.setdefault("block_size", kwargs.pop("buffering", None))
+        return self.fs.open(self.path, mode=mode, **kwargs)
+
     def stat(self) -> stat_result:  # type: ignore[override]
         """Call ``stat`` and return the result."""
         return stat_result(
-            self._accessor.stat(self),
+            self.fs.stat(self.path),
             protocol=self.protocol,
-            conn_id=self._accessor._store.conn_id,
+            conn_id=self.storage_options.get("conn_id"),
         )
 
     def samefile(self, other_path: typing.Any) -> bool:
@@ -368,7 +313,11 @@ class ObjectStoragePath(CloudPath):
                 if path == self.path:
                     continue
 
-                src_obj = ObjectStoragePath(path, conn_id=self._accessor._store.conn_id)
+                src_obj = ObjectStoragePath(
+                    path,
+                    protocol=self.protocol,
+                    conn_id=self.storage_options.get("conn_id"),
+                )
 
                 # skip directories, empty directories will not be created
                 if src_obj.is_dir():
@@ -401,7 +350,7 @@ class ObjectStoragePath(CloudPath):
         self.unlink()
 
     def serialize(self) -> dict[str, typing.Any]:
-        _kwargs = self._kwargs.copy()
+        _kwargs = {**self.storage_options}
         conn_id = _kwargs.pop("conn_id", None)
 
         return {
diff --git a/airflow/providers/common/io/xcom/backend.py b/airflow/providers/common/io/xcom/backend.py
index 6e995c30e1..3028a49be2 100644
--- a/airflow/providers/common/io/xcom/backend.py
+++ b/airflow/providers/common/io/xcom/backend.py
@@ -132,7 +132,7 @@ class XComObjectStoreBackend(BaseXCom):
             if not p.parent.exists():
                 p.parent.mkdir(parents=True, exist_ok=True)
 
-            with p.open("wb", compression=compression) as f:
+            with p.open(mode="wb", compression=compression) as f:
                 f.write(s_val)
 
             return BaseXCom.serialize_value(str(p))
@@ -152,7 +152,7 @@ class XComObjectStoreBackend(BaseXCom):
 
         try:
             p = ObjectStoragePath(path) / XComObjectStoreBackend._get_key(data)
-            return json.load(p.open("rb", compression="infer"), cls=XComDecoder)
+            return json.load(p.open(mode="rb", compression="infer"), cls=XComDecoder)
         except TypeError:
             return data
         except ValueError:
diff --git a/pyproject.toml b/pyproject.toml
index f53c1002a3..42265978a7 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -148,14 +148,7 @@ dependencies = [
     # We should also remove "licenses/LICENSE-unicodecsv.txt" file when we remove this dependency
     "unicodecsv>=0.14.1",
     # The Universal Pathlib provides  Pathlib-like interface for FSSPEC
-    # In 0.1. *It was not very well defined for extension, so the way how we use it for 0.1.*
-    # so we used a lot of private methods and attributes that were not defined in the interface
-    # an they are broken with version 0.2.0 which is much better suited for extension and supports
-    # Python 3.12. We should limit it, unti we migrate to 0.2.0
-    # See: https://github.com/fsspec/universal_pathlib/pull/173#issuecomment-1937090528
-    # This is prerequistite to make Airflow compatible with Python 3.12
-    # Tracked in https://github.com/apache/airflow/pull/36755
-    "universal-pathlib>=0.1.4,<0.2.0",
+    "universal-pathlib>=0.2.1",
     # Werkzug 3 breaks Flask-Login 0.6.2, also connexion needs to be updated to >= 3.0
     # we should remove this limitation when FAB supports Flask 2.3 and we migrate connexion to 3+
     "werkzeug>=2.0,<3",
diff --git a/tests/io/test_path.py b/tests/io/test_path.py
index deb8d412cc..e03b40e0e4 100644
--- a/tests/io/test_path.py
+++ b/tests/io/test_path.py
@@ -20,11 +20,13 @@ from __future__ import annotations
 import uuid
 from stat import S_ISDIR, S_ISREG
 from tempfile import NamedTemporaryFile
+from typing import Any, ClassVar
 from unittest import mock
 
 import pytest
 from fsspec.implementations.local import LocalFileSystem
-from fsspec.utils import stringify_path
+from fsspec.implementations.memory import MemoryFileSystem
+from fsspec.registry import _registry as _fsspec_registry, register_implementation
 
 from airflow.datasets import Dataset
 from airflow.io import _register_filesystems, get_fs
@@ -38,19 +40,46 @@ FOO = "file:///mnt/warehouse/foo"
 BAR = FOO
 
 
-class FakeRemoteFileSystem(LocalFileSystem):
-    id = "fakefs"
-    auto_mk_dir = True
+class FakeLocalFileSystem(MemoryFileSystem):
+    protocol = ("file", "local")
+    root_marker = "/"
+    store: ClassVar[dict[str, Any]] = {}
+    pseudo_dirs = [""]
 
-    @property
-    def fsid(self):
-        return self.id
+    def __init__(self, *args, **kwargs):
+        self.conn_id = kwargs.pop("conn_id", None)
+        super().__init__(*args, **kwargs)
 
     @classmethod
-    def _strip_protocol(cls, path) -> str:
-        path = stringify_path(path)
-        i = path.find("://")
-        return path[i + 3 :] if i > 0 else path
+    def _strip_protocol(cls, path):
+        for protocol in cls.protocol:
+            if path.startswith(f"{protocol}://"):
+                return path[len(f"{protocol}://") :]
+        if "::" in path or "://" in path:
+            return path.rstrip("/")
+        path = path.lstrip("/").rstrip("/")
+        return "/" + path if path else ""
+
+
+class FakeRemoteFileSystem(MemoryFileSystem):
+    protocol = ("s3", "fakefs", "ffs", "ffs2")
+    root_marker = ""
+    store: ClassVar[dict[str, Any]] = {}
+    pseudo_dirs = [""]
+
+    def __init__(self, *args, **kwargs):
+        self.conn_id = kwargs.pop("conn_id", None)
+        super().__init__(*args, **kwargs)
+
+    @classmethod
+    def _strip_protocol(cls, path):
+        for protocol in cls.protocol:
+            if path.startswith(f"{protocol}://"):
+                return path[len(f"{protocol}://") :]
+        if "::" in path or "://" in path:
+            return path.rstrip("/")
+        path = path.lstrip("/").rstrip("/")
+        return "/" + path if path else ""
 
 
 def get_fs_no_storage_options(_: str):
@@ -60,10 +89,15 @@ def get_fs_no_storage_options(_: str):
 class TestFs:
     def setup_class(self):
         self._store_cache = _STORE_CACHE.copy()
+        self._fsspec_registry = _fsspec_registry.copy()
+        for protocol in FakeRemoteFileSystem.protocol:
+            register_implementation(protocol, FakeRemoteFileSystem, clobber=True)
 
     def teardown(self):
         _STORE_CACHE.clear()
         _STORE_CACHE.update(self._store_cache)
+        _fsspec_registry.clear()
+        _fsspec_registry.update(self._fsspec_registry)
 
     def test_alias(self):
         store = attach("file", alias="local")
@@ -71,22 +105,24 @@ class TestFs:
         assert "local" in _STORE_CACHE
 
     def test_init_objectstoragepath(self):
-        path = ObjectStoragePath("file://bucket/key/part1/part2")
+        attach("s3", fs=FakeRemoteFileSystem())
+
+        path = ObjectStoragePath("s3://bucket/key/part1/part2")
         assert path.bucket == "bucket"
         assert path.key == "key/part1/part2"
-        assert path.protocol == "file"
+        assert path.protocol == "s3"
         assert path.path == "bucket/key/part1/part2"
 
         path2 = ObjectStoragePath(path / "part3")
         assert path2.bucket == "bucket"
         assert path2.key == "key/part1/part2/part3"
-        assert path2.protocol == "file"
+        assert path2.protocol == "s3"
         assert path2.path == "bucket/key/part1/part2/part3"
 
         path3 = ObjectStoragePath(path2 / "2023")
         assert path3.bucket == "bucket"
         assert path3.key == "key/part1/part2/part3/2023"
-        assert path3.protocol == "file"
+        assert path3.protocol == "s3"
         assert path3.path == "bucket/key/part1/part2/part3/2023"
 
     def test_read_write(self):
@@ -116,49 +152,57 @@ class TestFs:
 
         assert not o.exists()
 
-    @pytest.fixture()
-    def fake_fs(self):
-        fs = mock.Mock()
-        fs._strip_protocol.return_value = "/"
-        fs.conn_id = "fake"
-        return fs
-
-    def test_objectstoragepath_init_conn_id_in_uri(self, fake_fs):
-        fake_fs.stat.return_value = {"stat": "result"}
-        attach(protocol="fake", conn_id="fake", fs=fake_fs)
+    def test_objectstoragepath_init_conn_id_in_uri(self):
+        attach(protocol="fake", conn_id="fake", fs=FakeRemoteFileSystem(conn_id="fake"))
         p = ObjectStoragePath("fake://fake@bucket/path")
-        assert p.stat() == {"stat": "result", "conn_id": "fake", "protocol": "fake"}
+        p.touch()
+        fsspec_info = p.fs.info(p.path)
+        assert p.stat() == {**fsspec_info, "conn_id": "fake", "protocol": "fake"}
+
+    @pytest.fixture
+    def fake_local_files(self):
+        obj = FakeLocalFileSystem()
+        obj.touch(FOO)
+        try:
+            yield
+        finally:
+            FakeLocalFileSystem.store.clear()
+            FakeLocalFileSystem.pseudo_dirs[:] = [""]
 
     @pytest.mark.parametrize(
         "fn, args, fn2, path, expected_args, expected_kwargs",
         [
-            ("checksum", {}, "checksum", FOO, FakeRemoteFileSystem._strip_protocol(BAR), {}),
-            ("size", {}, "size", FOO, FakeRemoteFileSystem._strip_protocol(BAR), {}),
+            ("checksum", {}, "checksum", FOO, FakeLocalFileSystem._strip_protocol(BAR), {}),
+            ("size", {}, "size", FOO, FakeLocalFileSystem._strip_protocol(BAR), {}),
             (
                 "sign",
                 {"expiration": 200, "extra": "xtra"},
                 "sign",
                 FOO,
-                FakeRemoteFileSystem._strip_protocol(BAR),
+                FakeLocalFileSystem._strip_protocol(BAR),
                 {"expiration": 200, "extra": "xtra"},
             ),
-            ("ukey", {}, "ukey", FOO, FakeRemoteFileSystem._strip_protocol(BAR), {}),
+            ("ukey", {}, "ukey", FOO, FakeLocalFileSystem._strip_protocol(BAR), {}),
             (
                 "read_block",
                 {"offset": 0, "length": 1},
                 "read_block",
                 FOO,
-                FakeRemoteFileSystem._strip_protocol(BAR),
+                FakeLocalFileSystem._strip_protocol(BAR),
                 {"delimiter": None, "length": 1, "offset": 0},
             ),
         ],
     )
-    def test_standard_extended_api(self, fake_fs, fn, args, fn2, path, expected_args, expected_kwargs):
-        store = attach(protocol="file", conn_id="fake", fs=fake_fs)
-        o = ObjectStoragePath(path, conn_id="fake")
+    def test_standard_extended_api(
+        self, fake_local_files, fn, args, fn2, path, expected_args, expected_kwargs
+    ):
+        fs = FakeLocalFileSystem()
+        with mock.patch.object(fs, fn2) as method:
+            attach(protocol="file", conn_id="fake", fs=fs)
+            o = ObjectStoragePath(path, conn_id="fake")
 
-        getattr(o, fn)(**args)
-        getattr(store.fs, fn2).assert_called_once_with(expected_args, **expected_kwargs)
+            getattr(o, fn)(**args)
+            method.assert_called_once_with(expected_args, **expected_kwargs)
 
     def test_stat(self):
         with NamedTemporaryFile() as f:
@@ -168,6 +212,8 @@ class TestFs:
             assert S_ISDIR(o.parent.stat().st_mode)
 
     def test_bucket_key_protocol(self):
+        attach(protocol="s3", fs=FakeRemoteFileSystem())
+
         bucket = "bkt"
         key = "yek"
         protocol = "s3"
@@ -227,24 +273,23 @@ class TestFs:
         _to.unlink()
 
     def test_copy_remote_remote(self):
-        # foo = xxx added to prevent same fs token
-        attach("ffs", fs=FakeRemoteFileSystem(auto_mkdir=True, foo="bar"))
-        attach("ffs2", fs=FakeRemoteFileSystem(auto_mkdir=True, foo="baz"))
+        attach("ffs", fs=FakeRemoteFileSystem(skip_instance_cache=True))
+        attach("ffs2", fs=FakeRemoteFileSystem(skip_instance_cache=True))
 
-        dir_src = f"/tmp/{str(uuid.uuid4())}"
-        dir_dst = f"/tmp/{str(uuid.uuid4())}"
+        dir_src = f"bucket1/{str(uuid.uuid4())}"
+        dir_dst = f"bucket2/{str(uuid.uuid4())}"
         key = "foo/bar/baz.txt"
 
-        # note we are dealing with object storage characteristics
-        # while working on a local filesystem, so it might feel not intuitive
         _from = ObjectStoragePath(f"ffs://{dir_src}")
         _from_file = _from / key
         _from_file.touch()
+        assert _from.bucket == "bucket1"
         assert _from_file.exists()
 
         _to = ObjectStoragePath(f"ffs2://{dir_dst}")
         _from.copy(_to)
 
+        assert _to.bucket == "bucket2"
         assert _to.exists()
         assert _to.is_dir()
         assert (_to / _from.key / key).exists()
@@ -254,7 +299,7 @@ class TestFs:
         _to.rmdir(recursive=True)
 
     def test_serde_objectstoragepath(self):
-        path = "file://bucket/key/part1/part2"
+        path = "file:///bucket/key/part1/part2"
         o = ObjectStoragePath(path)
 
         s = o.serialize()
@@ -312,6 +357,8 @@ class TestFs:
             _register_filesystems.cache_clear()
 
     def test_dataset(self):
+        attach("s3", fs=FakeRemoteFileSystem())
+
         p = "s3"
         f = "/tmp/foo"
         i = Dataset(uri=f"{p}://{f}", extra={"foo": "bar"})
diff --git a/tests/providers/common/io/xcom/test_backend.py b/tests/providers/common/io/xcom/test_backend.py
index fce5ed985e..0641e18fe0 100644
--- a/tests/providers/common/io/xcom/test_backend.py
+++ b/tests/providers/common/io/xcom/test_backend.py
@@ -181,7 +181,7 @@ class TestXcomObjectStoreBackend:
             run_id=task_instance.run_id,
             session=session,
         )
-        assert self.path in qry.first().value
+        assert str(p) == qry.first().value
 
     @pytest.mark.db_test
     def test_clear(self, task_instance, session):

Reply via email to