This is an automated email from the ASF dual-hosted git repository.
bolke pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 08bc0f4490 Update ObjectStoragePath for universal_pathlib>=v0.2.1 (#37524)
08bc0f4490 is described below
commit 08bc0f44904fe0d8bc8779e0e892e4d42def3983
Author: Andreas Poehlmann <[email protected]>
AuthorDate: Tue Feb 20 10:53:49 2024 +0100
Update ObjectStoragePath for universal_pathlib>=v0.2.1 (#37524)
This updates ObjectStoragePath to be compatible with universal_pathlib >= 0.2.1,
which in turn makes it compatible with Python 3.12+.
---
airflow/io/path.py | 169 +++++++++----------------
airflow/providers/common/io/xcom/backend.py | 4 +-
pyproject.toml | 9 +-
tests/io/test_path.py | 135 +++++++++++++-------
tests/providers/common/io/xcom/test_backend.py | 2 +-
5 files changed, 154 insertions(+), 165 deletions(-)
diff --git a/airflow/io/path.py b/airflow/io/path.py
index d65d837e7e..cb4c48c476 100644
--- a/airflow/io/path.py
+++ b/airflow/io/path.py
@@ -17,24 +17,20 @@
from __future__ import annotations
import contextlib
-import functools
import os
import shutil
import typing
-from pathlib import PurePath
+from typing import Any, Mapping
from urllib.parse import urlsplit
-from fsspec.core import split_protocol
from fsspec.utils import stringify_path
-from upath.implementations.cloud import CloudPath, _CloudAccessor
+from upath.implementations.cloud import CloudPath
from upath.registry import get_upath_class
from airflow.io.store import attach
from airflow.io.utils.stat import stat_result
if typing.TYPE_CHECKING:
- from urllib.parse import SplitResult
-
from fsspec import AbstractFileSystem
@@ -43,124 +39,68 @@ PT = typing.TypeVar("PT", bound="ObjectStoragePath")
default = "file"
-class _AirflowCloudAccessor(_CloudAccessor):
- __slots__ = ("_store",)
-
- def __init__(
- self,
- parsed_url: SplitResult | None,
- conn_id: str | None = None,
- **kwargs: typing.Any,
- ) -> None:
- # warning: we are not calling super().__init__ here
- # as it will try to create a new fs from a different
- # set if registered filesystems
- if parsed_url and parsed_url.scheme:
- self._store = attach(parsed_url.scheme, conn_id)
- else:
- self._store = attach("file", conn_id)
-
- @property
- def _fs(self) -> AbstractFileSystem:
- return self._store.fs
-
- def __eq__(self, other):
- return isinstance(other, _AirflowCloudAccessor) and self._store ==
other._store
-
-
class ObjectStoragePath(CloudPath):
"""A path-like object for object storage."""
- _accessor: _AirflowCloudAccessor
-
__version__: typing.ClassVar[int] = 1
- _default_accessor = _AirflowCloudAccessor
+ _protocol_dispatch = False
sep: typing.ClassVar[str] = "/"
root_marker: typing.ClassVar[str] = "/"
- _bucket: str
- _key: str
- _protocol: str
- _hash: int | None
-
- __slots__ = (
- "_bucket",
- "_key",
- "_conn_id",
- "_protocol",
- "_hash",
- )
-
- def __new__(
- cls: type[PT],
- *args: str | os.PathLike,
- scheme: str | None = None,
- conn_id: str | None = None,
- **kwargs: typing.Any,
- ) -> PT:
- args_list = list(args)
-
- if args_list:
- other = args_list.pop(0) or "."
- else:
- other = "."
-
- if isinstance(other, PurePath):
- _cls: typing.Any = type(other)
- drv, root, parts = _cls._parse_args(args_list)
- drv, root, parts = _cls._flavour.join_parsed_parts(
- other._drv, # type: ignore[attr-defined]
- other._root, # type: ignore[attr-defined]
- other._parts, # type: ignore[attr-defined]
- drv,
- root,
- parts, # type: ignore
- )
-
- _kwargs = getattr(other, "_kwargs", {})
- _url = getattr(other, "_url", None)
- other_kwargs = _kwargs.copy()
- if _url and _url.scheme:
- other_kwargs["url"] = _url
- new_kwargs = _kwargs.copy()
- new_kwargs.update(kwargs)
-
- return _cls(_cls._format_parsed_parts(drv, root, parts,
**other_kwargs), **new_kwargs)
-
- url = stringify_path(other)
- parsed_url: SplitResult = urlsplit(url)
-
- if scheme: # allow override of protocol
- parsed_url = parsed_url._replace(scheme=scheme)
-
- if not parsed_url.path: # ensure path has root
- parsed_url = parsed_url._replace(path="/")
-
- if not parsed_url.scheme and not split_protocol(url)[0]:
- args_list.insert(0, url)
- else:
- args_list.insert(0, parsed_url.path)
+ __slots__ = ("_hash_cached",)
+
+ @classmethod
+ def _transform_init_args(
+ cls,
+ args: tuple[str | os.PathLike, ...],
+ protocol: str,
+ storage_options: dict[str, Any],
+ ) -> tuple[tuple[str | os.PathLike, ...], str, dict[str, Any]]:
+ """Extract conn_id from the URL and set it as a storage option."""
+ if args:
+ arg0 = args[0]
+ parsed_url = urlsplit(stringify_path(arg0))
+ userinfo, have_info, hostinfo = parsed_url.netloc.rpartition("@")
+ if have_info:
+ storage_options.setdefault("conn_id", userinfo or None)
+ parsed_url = parsed_url._replace(netloc=hostinfo)
+ args = (parsed_url.geturl(),) + args[1:]
+ protocol = protocol or parsed_url.scheme
+ return args, protocol, storage_options
- # This matches the parsing logic in urllib.parse; see:
- #
https://github.com/python/cpython/blob/46adf6b701c440e047abf925df9a75a/Lib/urllib/parse.py#L194-L203
- userinfo, have_info, hostinfo = parsed_url.netloc.rpartition("@")
- if have_info:
- conn_id = conn_id or userinfo or None
- parsed_url = parsed_url._replace(netloc=hostinfo)
+ @classmethod
+ def _parse_storage_options(
+ cls, urlpath: str, protocol: str, storage_options: Mapping[str, Any]
+ ) -> dict[str, Any]:
+ fs = attach(protocol or "file",
conn_id=storage_options.get("conn_id")).fs
+ pth_storage_options = type(fs)._get_kwargs_from_urls(urlpath)
+ return {**pth_storage_options, **storage_options}
- return cls._from_parts(args_list, url=parsed_url, conn_id=conn_id,
**kwargs) # type: ignore
+ @classmethod
+ def _fs_factory(
+ cls, urlpath: str, protocol: str, storage_options: Mapping[str, Any]
+ ) -> AbstractFileSystem:
+ return attach(protocol or "file", storage_options.get("conn_id")).fs
- @functools.lru_cache
def __hash__(self) -> int:
- return hash(str(self))
+ self._hash_cached: int
+ try:
+ return self._hash_cached
+ except AttributeError:
+ self._hash_cached = hash(str(self))
+ return self._hash_cached
def __eq__(self, other: typing.Any) -> bool:
return self.samestore(other) and str(self) == str(other)
def samestore(self, other: typing.Any) -> bool:
- return isinstance(other, ObjectStoragePath) and self._accessor ==
other._accessor
+ return (
+ isinstance(other, ObjectStoragePath)
+ and self.protocol == other.protocol
+ and self.storage_options.get("conn_id") ==
other.storage_options.get("conn_id")
+ )
@property
def container(self) -> str:
@@ -186,12 +126,17 @@ class ObjectStoragePath(CloudPath):
def namespace(self) -> str:
return f"{self.protocol}://{self.bucket}" if self.bucket else
self.protocol
+ def open(self, mode="r", **kwargs):
+ """Open the file pointed to by this path."""
+ kwargs.setdefault("block_size", kwargs.pop("buffering", None))
+ return self.fs.open(self.path, mode=mode, **kwargs)
+
def stat(self) -> stat_result: # type: ignore[override]
"""Call ``stat`` and return the result."""
return stat_result(
- self._accessor.stat(self),
+ self.fs.stat(self.path),
protocol=self.protocol,
- conn_id=self._accessor._store.conn_id,
+ conn_id=self.storage_options.get("conn_id"),
)
def samefile(self, other_path: typing.Any) -> bool:
@@ -368,7 +313,11 @@ class ObjectStoragePath(CloudPath):
if path == self.path:
continue
- src_obj = ObjectStoragePath(path,
conn_id=self._accessor._store.conn_id)
+ src_obj = ObjectStoragePath(
+ path,
+ protocol=self.protocol,
+ conn_id=self.storage_options.get("conn_id"),
+ )
# skip directories, empty directories will not be created
if src_obj.is_dir():
@@ -401,7 +350,7 @@ class ObjectStoragePath(CloudPath):
self.unlink()
def serialize(self) -> dict[str, typing.Any]:
- _kwargs = self._kwargs.copy()
+ _kwargs = {**self.storage_options}
conn_id = _kwargs.pop("conn_id", None)
return {
diff --git a/airflow/providers/common/io/xcom/backend.py
b/airflow/providers/common/io/xcom/backend.py
index 6e995c30e1..3028a49be2 100644
--- a/airflow/providers/common/io/xcom/backend.py
+++ b/airflow/providers/common/io/xcom/backend.py
@@ -132,7 +132,7 @@ class XComObjectStoreBackend(BaseXCom):
if not p.parent.exists():
p.parent.mkdir(parents=True, exist_ok=True)
- with p.open("wb", compression=compression) as f:
+ with p.open(mode="wb", compression=compression) as f:
f.write(s_val)
return BaseXCom.serialize_value(str(p))
@@ -152,7 +152,7 @@ class XComObjectStoreBackend(BaseXCom):
try:
p = ObjectStoragePath(path) / XComObjectStoreBackend._get_key(data)
- return json.load(p.open("rb", compression="infer"),
cls=XComDecoder)
+ return json.load(p.open(mode="rb", compression="infer"),
cls=XComDecoder)
except TypeError:
return data
except ValueError:
diff --git a/pyproject.toml b/pyproject.toml
index f53c1002a3..42265978a7 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -148,14 +148,7 @@ dependencies = [
# We should also remove "licenses/LICENSE-unicodecsv.txt" file when we
remove this dependency
"unicodecsv>=0.14.1",
# The Universal Pathlib provides Pathlib-like interface for FSSPEC
- # In 0.1. *It was not very well defined for extension, so the way how we
use it for 0.1.*
- # so we used a lot of private methods and attributes that were not defined
in the interface
- # an they are broken with version 0.2.0 which is much better suited for
extension and supports
- # Python 3.12. We should limit it, unti we migrate to 0.2.0
- # See:
https://github.com/fsspec/universal_pathlib/pull/173#issuecomment-1937090528
- # This is prerequistite to make Airflow compatible with Python 3.12
- # Tracked in https://github.com/apache/airflow/pull/36755
- "universal-pathlib>=0.1.4,<0.2.0",
+ "universal-pathlib>=0.2.1",
# Werkzug 3 breaks Flask-Login 0.6.2, also connexion needs to be updated
to >= 3.0
# we should remove this limitation when FAB supports Flask 2.3 and we
migrate connexion to 3+
"werkzeug>=2.0,<3",
diff --git a/tests/io/test_path.py b/tests/io/test_path.py
index deb8d412cc..e03b40e0e4 100644
--- a/tests/io/test_path.py
+++ b/tests/io/test_path.py
@@ -20,11 +20,13 @@ from __future__ import annotations
import uuid
from stat import S_ISDIR, S_ISREG
from tempfile import NamedTemporaryFile
+from typing import Any, ClassVar
from unittest import mock
import pytest
from fsspec.implementations.local import LocalFileSystem
-from fsspec.utils import stringify_path
+from fsspec.implementations.memory import MemoryFileSystem
+from fsspec.registry import _registry as _fsspec_registry,
register_implementation
from airflow.datasets import Dataset
from airflow.io import _register_filesystems, get_fs
@@ -38,19 +40,46 @@ FOO = "file:///mnt/warehouse/foo"
BAR = FOO
-class FakeRemoteFileSystem(LocalFileSystem):
- id = "fakefs"
- auto_mk_dir = True
+class FakeLocalFileSystem(MemoryFileSystem):
+ protocol = ("file", "local")
+ root_marker = "/"
+ store: ClassVar[dict[str, Any]] = {}
+ pseudo_dirs = [""]
- @property
- def fsid(self):
- return self.id
+ def __init__(self, *args, **kwargs):
+ self.conn_id = kwargs.pop("conn_id", None)
+ super().__init__(*args, **kwargs)
@classmethod
- def _strip_protocol(cls, path) -> str:
- path = stringify_path(path)
- i = path.find("://")
- return path[i + 3 :] if i > 0 else path
+ def _strip_protocol(cls, path):
+ for protocol in cls.protocol:
+ if path.startswith(f"{protocol}://"):
+ return path[len(f"{protocol}://") :]
+ if "::" in path or "://" in path:
+ return path.rstrip("/")
+ path = path.lstrip("/").rstrip("/")
+ return "/" + path if path else ""
+
+
+class FakeRemoteFileSystem(MemoryFileSystem):
+ protocol = ("s3", "fakefs", "ffs", "ffs2")
+ root_marker = ""
+ store: ClassVar[dict[str, Any]] = {}
+ pseudo_dirs = [""]
+
+ def __init__(self, *args, **kwargs):
+ self.conn_id = kwargs.pop("conn_id", None)
+ super().__init__(*args, **kwargs)
+
+ @classmethod
+ def _strip_protocol(cls, path):
+ for protocol in cls.protocol:
+ if path.startswith(f"{protocol}://"):
+ return path[len(f"{protocol}://") :]
+ if "::" in path or "://" in path:
+ return path.rstrip("/")
+ path = path.lstrip("/").rstrip("/")
+ return "/" + path if path else ""
def get_fs_no_storage_options(_: str):
@@ -60,10 +89,15 @@ def get_fs_no_storage_options(_: str):
class TestFs:
def setup_class(self):
self._store_cache = _STORE_CACHE.copy()
+ self._fsspec_registry = _fsspec_registry.copy()
+ for protocol in FakeRemoteFileSystem.protocol:
+ register_implementation(protocol, FakeRemoteFileSystem,
clobber=True)
def teardown(self):
_STORE_CACHE.clear()
_STORE_CACHE.update(self._store_cache)
+ _fsspec_registry.clear()
+ _fsspec_registry.update(self._fsspec_registry)
def test_alias(self):
store = attach("file", alias="local")
@@ -71,22 +105,24 @@ class TestFs:
assert "local" in _STORE_CACHE
def test_init_objectstoragepath(self):
- path = ObjectStoragePath("file://bucket/key/part1/part2")
+ attach("s3", fs=FakeRemoteFileSystem())
+
+ path = ObjectStoragePath("s3://bucket/key/part1/part2")
assert path.bucket == "bucket"
assert path.key == "key/part1/part2"
- assert path.protocol == "file"
+ assert path.protocol == "s3"
assert path.path == "bucket/key/part1/part2"
path2 = ObjectStoragePath(path / "part3")
assert path2.bucket == "bucket"
assert path2.key == "key/part1/part2/part3"
- assert path2.protocol == "file"
+ assert path2.protocol == "s3"
assert path2.path == "bucket/key/part1/part2/part3"
path3 = ObjectStoragePath(path2 / "2023")
assert path3.bucket == "bucket"
assert path3.key == "key/part1/part2/part3/2023"
- assert path3.protocol == "file"
+ assert path3.protocol == "s3"
assert path3.path == "bucket/key/part1/part2/part3/2023"
def test_read_write(self):
@@ -116,49 +152,57 @@ class TestFs:
assert not o.exists()
- @pytest.fixture()
- def fake_fs(self):
- fs = mock.Mock()
- fs._strip_protocol.return_value = "/"
- fs.conn_id = "fake"
- return fs
-
- def test_objectstoragepath_init_conn_id_in_uri(self, fake_fs):
- fake_fs.stat.return_value = {"stat": "result"}
- attach(protocol="fake", conn_id="fake", fs=fake_fs)
+ def test_objectstoragepath_init_conn_id_in_uri(self):
+ attach(protocol="fake", conn_id="fake",
fs=FakeRemoteFileSystem(conn_id="fake"))
p = ObjectStoragePath("fake://fake@bucket/path")
- assert p.stat() == {"stat": "result", "conn_id": "fake", "protocol":
"fake"}
+ p.touch()
+ fsspec_info = p.fs.info(p.path)
+ assert p.stat() == {**fsspec_info, "conn_id": "fake", "protocol":
"fake"}
+
+ @pytest.fixture
+ def fake_local_files(self):
+ obj = FakeLocalFileSystem()
+ obj.touch(FOO)
+ try:
+ yield
+ finally:
+ FakeLocalFileSystem.store.clear()
+ FakeLocalFileSystem.pseudo_dirs[:] = [""]
@pytest.mark.parametrize(
"fn, args, fn2, path, expected_args, expected_kwargs",
[
- ("checksum", {}, "checksum", FOO,
FakeRemoteFileSystem._strip_protocol(BAR), {}),
- ("size", {}, "size", FOO,
FakeRemoteFileSystem._strip_protocol(BAR), {}),
+ ("checksum", {}, "checksum", FOO,
FakeLocalFileSystem._strip_protocol(BAR), {}),
+ ("size", {}, "size", FOO,
FakeLocalFileSystem._strip_protocol(BAR), {}),
(
"sign",
{"expiration": 200, "extra": "xtra"},
"sign",
FOO,
- FakeRemoteFileSystem._strip_protocol(BAR),
+ FakeLocalFileSystem._strip_protocol(BAR),
{"expiration": 200, "extra": "xtra"},
),
- ("ukey", {}, "ukey", FOO,
FakeRemoteFileSystem._strip_protocol(BAR), {}),
+ ("ukey", {}, "ukey", FOO,
FakeLocalFileSystem._strip_protocol(BAR), {}),
(
"read_block",
{"offset": 0, "length": 1},
"read_block",
FOO,
- FakeRemoteFileSystem._strip_protocol(BAR),
+ FakeLocalFileSystem._strip_protocol(BAR),
{"delimiter": None, "length": 1, "offset": 0},
),
],
)
- def test_standard_extended_api(self, fake_fs, fn, args, fn2, path,
expected_args, expected_kwargs):
- store = attach(protocol="file", conn_id="fake", fs=fake_fs)
- o = ObjectStoragePath(path, conn_id="fake")
+ def test_standard_extended_api(
+ self, fake_local_files, fn, args, fn2, path, expected_args,
expected_kwargs
+ ):
+ fs = FakeLocalFileSystem()
+ with mock.patch.object(fs, fn2) as method:
+ attach(protocol="file", conn_id="fake", fs=fs)
+ o = ObjectStoragePath(path, conn_id="fake")
- getattr(o, fn)(**args)
- getattr(store.fs, fn2).assert_called_once_with(expected_args,
**expected_kwargs)
+ getattr(o, fn)(**args)
+ method.assert_called_once_with(expected_args, **expected_kwargs)
def test_stat(self):
with NamedTemporaryFile() as f:
@@ -168,6 +212,8 @@ class TestFs:
assert S_ISDIR(o.parent.stat().st_mode)
def test_bucket_key_protocol(self):
+ attach(protocol="s3", fs=FakeRemoteFileSystem())
+
bucket = "bkt"
key = "yek"
protocol = "s3"
@@ -227,24 +273,23 @@ class TestFs:
_to.unlink()
def test_copy_remote_remote(self):
- # foo = xxx added to prevent same fs token
- attach("ffs", fs=FakeRemoteFileSystem(auto_mkdir=True, foo="bar"))
- attach("ffs2", fs=FakeRemoteFileSystem(auto_mkdir=True, foo="baz"))
+ attach("ffs", fs=FakeRemoteFileSystem(skip_instance_cache=True))
+ attach("ffs2", fs=FakeRemoteFileSystem(skip_instance_cache=True))
- dir_src = f"/tmp/{str(uuid.uuid4())}"
- dir_dst = f"/tmp/{str(uuid.uuid4())}"
+ dir_src = f"bucket1/{str(uuid.uuid4())}"
+ dir_dst = f"bucket2/{str(uuid.uuid4())}"
key = "foo/bar/baz.txt"
- # note we are dealing with object storage characteristics
- # while working on a local filesystem, so it might feel not intuitive
_from = ObjectStoragePath(f"ffs://{dir_src}")
_from_file = _from / key
_from_file.touch()
+ assert _from.bucket == "bucket1"
assert _from_file.exists()
_to = ObjectStoragePath(f"ffs2://{dir_dst}")
_from.copy(_to)
+ assert _to.bucket == "bucket2"
assert _to.exists()
assert _to.is_dir()
assert (_to / _from.key / key).exists()
@@ -254,7 +299,7 @@ class TestFs:
_to.rmdir(recursive=True)
def test_serde_objectstoragepath(self):
- path = "file://bucket/key/part1/part2"
+ path = "file:///bucket/key/part1/part2"
o = ObjectStoragePath(path)
s = o.serialize()
@@ -312,6 +357,8 @@ class TestFs:
_register_filesystems.cache_clear()
def test_dataset(self):
+ attach("s3", fs=FakeRemoteFileSystem())
+
p = "s3"
f = "/tmp/foo"
i = Dataset(uri=f"{p}://{f}", extra={"foo": "bar"})
diff --git a/tests/providers/common/io/xcom/test_backend.py
b/tests/providers/common/io/xcom/test_backend.py
index fce5ed985e..0641e18fe0 100644
--- a/tests/providers/common/io/xcom/test_backend.py
+++ b/tests/providers/common/io/xcom/test_backend.py
@@ -181,7 +181,7 @@ class TestXcomObjectStoreBackend:
run_id=task_instance.run_id,
session=session,
)
- assert self.path in qry.first().value
+ assert str(p) == qry.first().value
@pytest.mark.db_test
def test_clear(self, task_instance, session):