This is an automated email from the ASF dual-hosted git repository.
kaxilnaik pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new d42623e2b46 Fix `ObjectStoragePath` to exclude `conn_id` from storage
options passed to fsspec (#62701)
d42623e2b46 is described below
commit d42623e2b461fa362f2a5f7e6f2ca9a028d9e4b9
Author: Jason(Zhe-You) Liu <[email protected]>
AuthorDate: Tue Mar 3 21:04:22 2026 +0800
Fix `ObjectStoragePath` to exclude `conn_id` from storage options passed to
fsspec (#62701)
* Fix `ObjectStoragePath` to exclude `conn_id` from storage options passed to
fsspec
* Add a `conn_id` property to avoid re-initialization
* Fix `/` (truediv) and `joinpath` operations by overriding the `_from_upath` method
* Fix mypy errors
---
task-sdk/src/airflow/sdk/io/path.py | 48 ++++++++++++++++++++++-----------
task-sdk/tests/task_sdk/io/test_path.py | 43 +++++++++++++++++++++++++++++
2 files changed, 75 insertions(+), 16 deletions(-)
diff --git a/task-sdk/src/airflow/sdk/io/path.py b/task-sdk/src/airflow/sdk/io/path.py
index 89c49759c77..cf78bb4ebc6 100644
--- a/task-sdk/src/airflow/sdk/io/path.py
+++ b/task-sdk/src/airflow/sdk/io/path.py
@@ -23,7 +23,7 @@ from urllib.parse import urlsplit
from fsspec.utils import stringify_path
from upath import UPath
-from upath.extensions import ProxyUPath
+from upath.extensions import ProxyUPath, classmethod_or_method
from airflow.sdk.io.stat import stat_result
from airflow.sdk.io.store import attach
@@ -84,7 +84,7 @@ class ObjectStoragePath(ProxyUPath):
sep: ClassVar[str] = "/"
root_marker: ClassVar[str] = "/"
- __slots__ = ("_hash_cached",)
+ __slots__ = ("_conn_id", "_hash_cached")
def __init__(
self,
@@ -99,7 +99,7 @@ class ObjectStoragePath(ProxyUPath):
if args:
arg0 = args[0]
if isinstance(arg0, type(self)):
- storage_options["conn_id"] =
arg0.storage_options.get("conn_id")
+ storage_options["conn_id"] = arg0.conn_id
else:
parsed_url = urlsplit(stringify_path(arg0))
userinfo, have_info, hostinfo =
parsed_url.netloc.rpartition("@")
@@ -111,13 +111,33 @@ class ObjectStoragePath(ProxyUPath):
# override conn_id if explicitly provided
if conn_id is not None:
storage_options["conn_id"] = conn_id
+
+ # pop conn_id before calling super to prevent it from being passed
+ # to the underlying fsspec filesystem, which doesn't understand it
+ self._conn_id = storage_options.pop("conn_id", None)
super().__init__(*args, protocol=protocol, **storage_options)
+ @classmethod_or_method # type: ignore[arg-type]
+ def _from_upath(cls_or_self, upath, /):
+ """Wrap a UPath, propagating conn_id from the calling instance."""
+ is_instance = isinstance(cls_or_self, ObjectStoragePath)
+ cls = type(cls_or_self) if is_instance else cls_or_self
+ if isinstance(upath, cls):
+ return upath
+ obj = object.__new__(cls)
+ obj.__wrapped__ = upath
+ obj._conn_id = getattr(cls_or_self, "_conn_id", None) if is_instance
else None
+ return obj
+
+ @property
+ def conn_id(self) -> str | None:
+ """Return the connection ID for this path."""
+ return getattr(self, "_conn_id", None)
+
@property
def fs(self) -> AbstractFileSystem:
"""Return the filesystem for this path, using airflow's attach
mechanism."""
- conn_id = self.storage_options.get("conn_id")
- return attach(self.protocol or "file", conn_id).fs
+ return attach(self.protocol or "file", self.conn_id).fs
def __hash__(self) -> int:
self._hash_cached: int
@@ -134,7 +154,7 @@ class ObjectStoragePath(ProxyUPath):
return (
isinstance(other, ObjectStoragePath)
and self.protocol == other.protocol
- and self.storage_options.get("conn_id") ==
other.storage_options.get("conn_id")
+ and self.conn_id == other.conn_id
)
@property
@@ -169,7 +189,7 @@ class ObjectStoragePath(ProxyUPath):
return stat_result(
self.fs.stat(self.path),
protocol=self.protocol,
- conn_id=self.storage_options.get("conn_id"),
+ conn_id=self.conn_id,
)
def samefile(self, other_path: Any) -> bool:
@@ -353,7 +373,7 @@ class ObjectStoragePath(ProxyUPath):
src_obj = ObjectStoragePath(
path,
protocol=self.protocol,
- conn_id=self.storage_options.get("conn_id"),
+ conn_id=self.conn_id,
)
# skip directories, empty directories will not be created
@@ -424,13 +444,10 @@ class ObjectStoragePath(ProxyUPath):
self.move(dst_path, recursive=recursive, **kwargs)
def serialize(self) -> dict[str, Any]:
- _kwargs = {**self.storage_options}
- conn_id = _kwargs.pop("conn_id", None)
-
return {
"path": str(self),
- "conn_id": conn_id,
- "kwargs": _kwargs,
+ "conn_id": self.conn_id,
+ "kwargs": {**self.storage_options},
}
@classmethod
@@ -445,7 +462,6 @@ class ObjectStoragePath(ProxyUPath):
return ObjectStoragePath(path, conn_id=conn_id, **_kwargs)
def __str__(self):
- conn_id = self.storage_options.get("conn_id")
- if self.protocol and conn_id:
- return f"{self.protocol}://{conn_id}@{self.path}"
+ if self.protocol and self.conn_id:
+ return f"{self.protocol}://{self.conn_id}@{self.path}"
return super().__str__()
diff --git a/task-sdk/tests/task_sdk/io/test_path.py b/task-sdk/tests/task_sdk/io/test_path.py
index 01fc925de30..e65c6584293 100644
--- a/task-sdk/tests/task_sdk/io/test_path.py
+++ b/task-sdk/tests/task_sdk/io/test_path.py
@@ -59,6 +59,49 @@ def test_str(input_str):
assert str(o) == input_str
+class TestConnIdPropagation:
+ """conn_id must survive all path-producing operations."""
+
+ @pytest.fixture
+ def base(self):
+ return ObjectStoragePath("s3://aws_default@bucket/prefix")
+
+ def test_truediv(self, base):
+ child = base / "x"
+ assert child.conn_id == "aws_default"
+
+ def test_joinpath(self, base):
+ child = base.joinpath("a", "b")
+ assert child.conn_id == "aws_default"
+
+ def test_parent(self, base):
+ assert base.parent.conn_id == "aws_default"
+
+ def test_parents(self, base):
+ for p in base.parents:
+ assert p.conn_id == "aws_default"
+
+ def test_with_name(self, base):
+ assert base.with_name("other").conn_id == "aws_default"
+
+ def test_with_suffix(self, base):
+ p = ObjectStoragePath("s3://aws_default@bucket/file.txt")
+ assert p.with_suffix(".csv").conn_id == "aws_default"
+
+ def test_with_stem(self, base):
+ p = ObjectStoragePath("s3://aws_default@bucket/file.txt")
+ assert p.with_stem("other").conn_id == "aws_default"
+
+ def test_nested_truediv(self, base):
+ grandchild = base / "x" / "y" / "z"
+ assert grandchild.conn_id == "aws_default"
+
+ def test_no_conn_id_stays_none(self):
+ p = ObjectStoragePath("s3://bucket/key")
+ child = p / "x"
+ assert child.conn_id is None
+
+
def test_cwd():
assert ObjectStoragePath.cwd()