This is an automated email from the ASF dual-hosted git repository.

bolke pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new fa94ee9ebb Simplify io.path (#35747)
fa94ee9ebb is described below

commit fa94ee9ebbc091cf8e0ec76e498d095b6221bcde
Author: Tzu-ping Chung <[email protected]>
AuthorDate: Mon Nov 27 02:33:14 2023 +0900

    Simplify io.path (#35747)
    
    As far as I can tell, the accessor is only ever given a conn_id, so I've
    removed the option to pass a store. A conn_id can always be obtained
    from a store, and the extra initialization costs almost nothing thanks
    to caching.
---
 airflow/io/path.py    | 50 +++++++++++++++++++++++---------------------------
 tests/io/test_path.py |  6 +++---
 2 files changed, 26 insertions(+), 30 deletions(-)

diff --git a/airflow/io/path.py b/airflow/io/path.py
index 9fd51091dc..f5eeb14eff 100644
--- a/airflow/io/path.py
+++ b/airflow/io/path.py
@@ -40,21 +40,20 @@ if typing.TYPE_CHECKING:
 
 PT = typing.TypeVar("PT", bound="ObjectStoragePath")
 
-default = "file"
-
 
 class _AirflowCloudAccessor(_CloudAccessor):
     __slots__ = ("_store",)
 
-    def __init__(self, parsed_url: SplitResult | None, **kwargs: typing.Any) -> None:
-        store = kwargs.pop("store", None)
-        conn_id = kwargs.pop("conn_id", None)
-        if store:
-            self._store = store
-        elif parsed_url and parsed_url.scheme:
+    def __init__(
+        self,
+        parsed_url: SplitResult | None,
+        conn_id: str | None = None,
+        **kwargs: typing.Any,
+    ) -> None:
+        if parsed_url and parsed_url.scheme:
             self._store = attach(parsed_url.scheme, conn_id)
         else:
-            self._store = attach(default, conn_id)
+            self._store = attach("file", conn_id)
 
     @property
     def _fs(self) -> AbstractFileSystem:
@@ -71,7 +70,7 @@ class ObjectStoragePath(CloudPath):
 
     __version__: typing.ClassVar[int] = 1
 
-    _default_accessor = _AirflowCloudAccessor
+    _default_accessor: type[_CloudAccessor] = _AirflowCloudAccessor
 
     sep: typing.ClassVar[str] = "/"
     root_marker: typing.ClassVar[str] = "/"
@@ -89,15 +88,18 @@ class ObjectStoragePath(CloudPath):
         "_hash",
     )
 
-    def __new__(cls: type[PT], *args: str | os.PathLike, **kwargs: typing.Any) -> PT:
+    def __new__(
+        cls: type[PT],
+        *args: str | os.PathLike,
+        scheme: str | None = None,
+        **kwargs: typing.Any,
+    ) -> PT:
         args_list = list(args)
 
-        try:
-            other = args_list.pop(0)
-        except IndexError:
-            other = "."
+        if args_list:
+            other = args_list.pop(0) or "."
         else:
-            other = other or "."
+            other = "."
 
         if isinstance(other, PurePath):
             _cls: typing.Any = type(other)
@@ -123,20 +125,14 @@ class ObjectStoragePath(CloudPath):
 
         url = stringify_path(other)
         parsed_url: SplitResult = urlsplit(url)
-        protocol: str | None = split_protocol(url)[0] or parsed_url.scheme
-
-        # allow override of protocol
-        protocol = kwargs.get("scheme", protocol)
 
-        for key in ["scheme", "url"]:
-            val = kwargs.pop(key, None)
-            if val:
-                parsed_url = parsed_url._replace(**{key: val})
+        if scheme:  # allow override of protocol
+            parsed_url = parsed_url._replace(scheme=scheme)
 
-        if not parsed_url.path:
-            parsed_url = parsed_url._replace(path="/")  # ensure path has root
+        if not parsed_url.path:  # ensure path has root
+            parsed_url = parsed_url._replace(path="/")
 
-        if not protocol:
+        if not parsed_url.scheme and not split_protocol(url)[0]:
             args_list.insert(0, url)
         else:
             args_list.insert(0, parsed_url.path)
diff --git a/tests/io/test_path.py b/tests/io/test_path.py
index 832f8ae663..1ac263c59f 100644
--- a/tests/io/test_path.py
+++ b/tests/io/test_path.py
@@ -124,13 +124,13 @@ class TestFs:
             ),
         ],
     )
-    def test_standard_extended_api(self, fn, args, fn2, path, expected_args, expected_kwargs):
+    def test_standard_extended_api(self, monkeypatch, fn, args, fn2, path, expected_args, expected_kwargs):
         _fs = mock.Mock()
         _fs._strip_protocol.return_value = "/"
         _fs.conn_id = "fake"
 
-        store = attach(protocol="mock", fs=_fs)
-        o = ObjectStoragePath(path, store=store)
+        store = attach(protocol="file", conn_id="fake", fs=_fs)
+        o = ObjectStoragePath(path, conn_id="fake")
 
         getattr(o, fn)(**args)
         getattr(store.fs, fn2).assert_called_once_with(expected_args, **expected_kwargs)

Reply via email to