https://github.com/python/cpython/commit/d7ae4dc5c14bc014ca0c056dab54c86ba8f395cb
commit: d7ae4dc5c14bc014ca0c056dab54c86ba8f395cb
branch: main
author: Barney Gale <[email protected]>
committer: barneygale <[email protected]>
date: 2024-08-23T20:03:11+01:00
summary:

GH-73991: Disallow copying directory into itself via `pathlib.Path.copy()` 
(#122924)

files:
M Lib/pathlib/_abc.py
M Lib/test/test_pathlib/test_pathlib_abc.py

diff --git a/Lib/pathlib/_abc.py b/Lib/pathlib/_abc.py
index 720756cac66f68..9943ea4d14148e 100644
--- a/Lib/pathlib/_abc.py
+++ b/Lib/pathlib/_abc.py
@@ -14,6 +14,7 @@
 import functools
 import operator
 import posixpath
+from errno import EINVAL
 from glob import _GlobberBase, _no_recurse_symlinks
 from stat import S_ISDIR, S_ISLNK, S_ISREG, S_ISSOCK, S_ISBLK, S_ISCHR, 
S_ISFIFO
 from pathlib._os import copyfileobj
@@ -564,14 +565,38 @@ def samefile(self, other_path):
         return (st.st_ino == other_st.st_ino and
                 st.st_dev == other_st.st_dev)
 
-    def _samefile_safe(self, other_path):
+    def _ensure_different_file(self, other_path):
         """
-        Like samefile(), but returns False rather than raising OSError.
+        Raise OSError(EINVAL) if both paths refer to the same file.
         """
         try:
-            return self.samefile(other_path)
+            if not self.samefile(other_path):
+                return
         except (OSError, ValueError):
-            return False
+            return
+        err = OSError(EINVAL, "Source and target are the same file")
+        err.filename = str(self)
+        err.filename2 = str(other_path)
+        raise err
+
+    def _ensure_distinct_path(self, other_path):
+        """
+        Raise OSError(EINVAL) if the other path equals or is within this path.
+        """
+        # Note: there is no straightforward, foolproof algorithm to determine
+        # if one directory is within another (a particularly perverse example
+        # would be a single network share mounted in one location via NFS, and
+        # in another location via CIFS), so we simply check whether the
+        # other path is lexically equal to, or within, this path.
+        if self == other_path:
+            err = OSError(EINVAL, "Source and target are the same path")
+        elif self in other_path.parents:
+            err = OSError(EINVAL, "Source path is a parent of target path")
+        else:
+            return
+        err.filename = str(self)
+        err.filename2 = str(other_path)
+        raise err
 
     def open(self, mode='r', buffering=-1, encoding=None,
              errors=None, newline=None):
@@ -826,8 +851,7 @@ def _copy_file(self, target):
         """
         Copy the contents of this file to the given target.
         """
-        if self._samefile_safe(target):
-            raise OSError(f"{self!r} and {target!r} are the same file")
+        self._ensure_different_file(target)
         with self.open('rb') as source_f:
             try:
                 with target.open('wb') as target_f:
@@ -847,6 +871,13 @@ def copy(self, target, *, follow_symlinks=True, 
dirs_exist_ok=False,
         """
         if not isinstance(target, PathBase):
             target = self.with_segments(target)
+        try:
+            self._ensure_distinct_path(target)
+        except OSError as err:
+            if on_error is None:
+                raise
+            on_error(err)
+            return
         stack = [(self, target)]
         while stack:
             src, dst = stack.pop()
diff --git a/Lib/test/test_pathlib/test_pathlib_abc.py 
b/Lib/test/test_pathlib/test_pathlib_abc.py
index f222fd5b1ec082..5b714756e95e10 100644
--- a/Lib/test/test_pathlib/test_pathlib_abc.py
+++ b/Lib/test/test_pathlib/test_pathlib_abc.py
@@ -1501,19 +1501,20 @@ def iterdir(self):
             raise FileNotFoundError(errno.ENOENT, "File not found", path)
 
     def mkdir(self, mode=0o777, parents=False, exist_ok=False):
-        path = str(self.resolve())
-        if path in self._directories:
+        path = str(self.parent.resolve() / self.name)
+        parent = str(self.parent.resolve())
+        if path in self._directories or path in self._symlinks:
             if exist_ok:
                 return
             else:
                 raise FileExistsError(errno.EEXIST, "File exists", path)
         try:
             if self.name:
-                self._directories[str(self.parent)].add(self.name)
+                self._directories[parent].add(self.name)
             self._directories[path] = set()
         except KeyError:
             if not parents:
-                raise FileNotFoundError(errno.ENOENT, "File not found", 
str(self.parent)) from None
+                raise FileNotFoundError(errno.ENOENT, "File not found", 
parent) from None
             self.parent.mkdir(parents=True, exist_ok=True)
             self.mkdir(mode, parents=False, exist_ok=exist_ok)
 
@@ -1758,6 +1759,32 @@ def test_copy_symlink_follow_symlinks_false(self):
         self.assertTrue(target.is_symlink())
         self.assertEqual(source.readlink(), target.readlink())
 
+    @needs_symlinks
+    def test_copy_symlink_to_itself(self):
+        base = self.cls(self.base)
+        source = base / 'linkA'
+        self.assertRaises(OSError, source.copy, source)
+
+    @needs_symlinks
+    def test_copy_symlink_to_existing_symlink(self):
+        base = self.cls(self.base)
+        source = base / 'copySource'
+        target = base / 'copyTarget'
+        source.symlink_to(base / 'fileA')
+        target.symlink_to(base / 'dirC')
+        self.assertRaises(OSError, source.copy, target)
+        self.assertRaises(OSError, source.copy, target, follow_symlinks=False)
+
+    @needs_symlinks
+    def test_copy_symlink_to_existing_directory_symlink(self):
+        base = self.cls(self.base)
+        source = base / 'copySource'
+        target = base / 'copyTarget'
+        source.symlink_to(base / 'fileA')
+        target.symlink_to(base / 'dirC')
+        self.assertRaises(OSError, source.copy, target)
+        self.assertRaises(OSError, source.copy, target, follow_symlinks=False)
+
     @needs_symlinks
     def test_copy_directory_symlink_follow_symlinks_false(self):
         base = self.cls(self.base)
@@ -1769,6 +1796,42 @@ def 
test_copy_directory_symlink_follow_symlinks_false(self):
         self.assertTrue(target.is_symlink())
         self.assertEqual(source.readlink(), target.readlink())
 
+    @needs_symlinks
+    def test_copy_directory_symlink_to_itself(self):
+        base = self.cls(self.base)
+        source = base / 'linkB'
+        self.assertRaises(OSError, source.copy, source)
+        self.assertRaises(OSError, source.copy, source, follow_symlinks=False)
+
+    @needs_symlinks
+    def test_copy_directory_symlink_into_itself(self):
+        base = self.cls(self.base)
+        source = base / 'linkB'
+        target = base / 'linkB' / 'copyB'
+        self.assertRaises(OSError, source.copy, target)
+        self.assertRaises(OSError, source.copy, target, follow_symlinks=False)
+        self.assertFalse(target.exists())
+
+    @needs_symlinks
+    def test_copy_directory_symlink_to_existing_symlink(self):
+        base = self.cls(self.base)
+        source = base / 'copySource'
+        target = base / 'copyTarget'
+        source.symlink_to(base / 'dirC')
+        target.symlink_to(base / 'fileA')
+        self.assertRaises(FileExistsError, source.copy, target)
+        self.assertRaises(FileExistsError, source.copy, target, 
follow_symlinks=False)
+
+    @needs_symlinks
+    def test_copy_directory_symlink_to_existing_directory_symlink(self):
+        base = self.cls(self.base)
+        source = base / 'copySource'
+        target = base / 'copyTarget'
+        source.symlink_to(base / 'dirC' / 'dirD')
+        target.symlink_to(base / 'dirC')
+        self.assertRaises(FileExistsError, source.copy, target)
+        self.assertRaises(FileExistsError, source.copy, target, 
follow_symlinks=False)
+
     def test_copy_file_to_existing_file(self):
         base = self.cls(self.base)
         source = base / 'fileA'
@@ -1782,8 +1845,7 @@ def test_copy_file_to_existing_directory(self):
         base = self.cls(self.base)
         source = base / 'fileA'
         target = base / 'dirA'
-        with self.assertRaises(OSError):
-            source.copy(target)
+        self.assertRaises(OSError, source.copy, target)
 
     @needs_symlinks
     def test_copy_file_to_existing_symlink(self):
@@ -1823,6 +1885,13 @@ def test_copy_file_empty(self):
         self.assertTrue(target.exists())
         self.assertEqual(target.read_bytes(), b'')
 
+    def test_copy_file_to_itself(self):
+        base = self.cls(self.base)
+        source = base / 'empty'
+        source.write_bytes(b'')
+        self.assertRaises(OSError, source.copy, source)
+        self.assertRaises(OSError, source.copy, source, follow_symlinks=False)
+
     def test_copy_dir_simple(self):
         base = self.cls(self.base)
         source = base / 'dirC'
@@ -1909,6 +1978,28 @@ def 
test_copy_dir_to_existing_directory_dirs_exist_ok(self):
         self.assertTrue(target.joinpath('fileC').read_text(),
                         "this is file C\n")
 
+    def test_copy_dir_to_itself(self):
+        base = self.cls(self.base)
+        source = base / 'dirC'
+        self.assertRaises(OSError, source.copy, source)
+        self.assertRaises(OSError, source.copy, source, follow_symlinks=False)
+
+    def test_copy_dir_to_itself_on_error(self):
+        base = self.cls(self.base)
+        source = base / 'dirC'
+        errors = []
+        source.copy(source, on_error=errors.append)
+        self.assertEqual(len(errors), 1)
+        self.assertIsInstance(errors[0], OSError)
+
+    def test_copy_dir_into_itself(self):
+        base = self.cls(self.base)
+        source = base / 'dirC'
+        target = base / 'dirC' / 'dirD' / 'copyC'
+        self.assertRaises(OSError, source.copy, target)
+        self.assertRaises(OSError, source.copy, target, follow_symlinks=False)
+        self.assertFalse(target.exists())
+
     def test_copy_missing_on_error(self):
         base = self.cls(self.base)
         source = base / 'foo'
@@ -2876,8 +2967,12 @@ def readlink(self):
             raise FileNotFoundError(errno.ENOENT, "File not found", path)
 
     def symlink_to(self, target, target_is_directory=False):
-        self._directories[str(self.parent)].add(self.name)
-        self._symlinks[str(self)] = str(target)
+        path = str(self.parent.resolve() / self.name)
+        parent = str(self.parent.resolve())
+        if path in self._symlinks:
+            raise FileExistsError(errno.EEXIST, "File exists", path)
+        self._directories[parent].add(self.name)
+        self._symlinks[path] = str(target)
 
 
 class DummyPathWithSymlinksTest(DummyPathTest):

_______________________________________________
Python-checkins mailing list -- [email protected]
To unsubscribe send an email to [email protected]
https://mail.python.org/mailman3/lists/python-checkins.python.org/
Member address: [email protected]

Reply via email to