Fix _EbuildFetcherProcess to handle SIGTERM, so that FETCHCOMMAND
processes will not be left running in the background:

* Convert the fetch function to an async_fetch coroutine function
  so that it can use asyncio.CancelledError handlers to terminate
  running processes.

* Use multiprocessing.active_children() to detect and terminate
  any processes that asyncio.CancelledError handlers did not have
  an opportunity to terminate because the exception arrived too
  soon after fork/spawn.

* Add unit test to verify that a child process is correctly
  killed when EbuildFetcher is cancelled, with short timeout in
  case it takes some time for the process to disappear.

Bug: https://bugs.gentoo.org/936273
Signed-off-by: Zac Medico <zmed...@gentoo.org>
---
Changes in [PATCH v2]:
* Use multiple multiprocessing.active_children() loops to
  allow some more time for children to respond to SIGTERM.

* Fix event loop recursion in _userpriv_test_write_file.

* Optimize away the spawn in _userpriv_test_write_file if a
  portage group write permission bit is detected.

 lib/_emerge/EbuildFetcher.py           |  68 ++++++++++++---
 lib/portage/package/ebuild/fetch.py    | 116 ++++++++++++++++++++++---
 lib/portage/tests/ebuild/test_fetch.py | 100 ++++++++++++++++++++-
 lib/portage/tests/util/test_socks5.py  |  16 +++-
 4 files changed, 268 insertions(+), 32 deletions(-)

diff --git a/lib/_emerge/EbuildFetcher.py b/lib/_emerge/EbuildFetcher.py
index 81d4b1054b..994271236c 100644
--- a/lib/_emerge/EbuildFetcher.py
+++ b/lib/_emerge/EbuildFetcher.py
@@ -4,6 +4,8 @@
 import copy
 import functools
 import io
+import multiprocessing
+import signal
 import sys
 
 import portage
@@ -17,11 +19,12 @@ from portage.package.ebuild.fetch import (
     _check_distfile,
     _drop_privs_userfetch,
     _want_userfetch,
-    fetch,
+    async_fetch,
 )
 from portage.util._async.AsyncTaskFuture import AsyncTaskFuture
 from portage.util._async.ForkProcess import ForkProcess
 from portage.util._pty import _create_pty_or_pipe
+from portage.util.futures import asyncio
 from _emerge.CompositeTask import CompositeTask
 
 
@@ -34,6 +37,7 @@ class EbuildFetcher(CompositeTask):
         "logfile",
         "pkg",
         "prefetch",
+        "pre_exec",
         "_fetcher_proc",
     )
 
@@ -253,6 +257,7 @@ class _EbuildFetcherProcess(ForkProcess):
             self._get_manifest(),
             self._uri_map,
             self.fetchonly,
+            self.pre_exec,
         )
         ForkProcess._start(self)
 
@@ -263,7 +268,10 @@ class _EbuildFetcherProcess(ForkProcess):
         self._settings = None
 
     @staticmethod
-    def _target(settings, manifest, uri_map, fetchonly):
+    def _target(settings, manifest, uri_map, fetchonly, pre_exec):
+        if pre_exec is not None:
+            pre_exec()
+
         # Force consistent color output, in case we are capturing fetch
         # output through a normal pipe due to unavailability of ptys.
         portage.output.havecolor = settings.get("NOCOLOR") not in ("yes", 
"true")
@@ -273,17 +281,53 @@ class _EbuildFetcherProcess(ForkProcess):
         if _want_userfetch(settings):
             _drop_privs_userfetch(settings)
 
-        rval = 1
         allow_missing = manifest.allow_missing or "digest" in settings.features
-        if fetch(
-            uri_map,
-            settings,
-            fetchonly=fetchonly,
-            digests=copy.deepcopy(manifest.getTypeDigests("DIST")),
-            allow_missing_digests=allow_missing,
-        ):
-            rval = os.EX_OK
-        return rval
+
+        async def main():
+            loop = asyncio.get_event_loop()
+            task = asyncio.ensure_future(
+                async_fetch(
+                    uri_map,
+                    settings,
+                    fetchonly=fetchonly,
+                    digests=copy.deepcopy(manifest.getTypeDigests("DIST")),
+                    allow_missing_digests=allow_missing,
+                )
+            )
+
+            def sigterm_handler(signum, _frame):
+                loop.call_soon_threadsafe(task.cancel)
+                signal.signal(signal.SIGTERM, signal.SIG_IGN)
+
+            signal.signal(signal.SIGTERM, sigterm_handler)
+            try:
+                await task
+            except asyncio.CancelledError:
+                # If asyncio.CancelledError arrives too soon after fork/spawn
+                # then handlers will not have an opportunity to terminate
+                # the corresponding process, so clean up after this race.
+                for proc in multiprocessing.active_children():
+                    proc.terminate()
+
+                # Use a non-zero timeout only for the first join because
+                # later joins are delayed by the first join.
+                timeout = 0.25
+                for proc in multiprocessing.active_children():
+                    proc.join(timeout)
+                    timeout = 0
+
+                for proc in multiprocessing.active_children():
+                    proc.kill()
+                    # Wait upon the process in order to ensure that its
+                    # pid will trigger ProcessLookupError for tests.
+                    proc.join()
+
+                signal.signal(signal.SIGTERM, signal.SIG_DFL)
+                os.kill(os.getpid(), signal.SIGTERM)
+
+            return os.EX_OK if task.result() else 1
+
+        return asyncio.run(main())
 
     def _get_ebuild_path(self):
         if self.ebuild_path is not None:
diff --git a/lib/portage/package/ebuild/fetch.py 
b/lib/portage/package/ebuild/fetch.py
index ed40cf6ede..b3ffa7a930 100644
--- a/lib/portage/package/ebuild/fetch.py
+++ b/lib/portage/package/ebuild/fetch.py
@@ -1,4 +1,4 @@
-# Copyright 2010-2021 Gentoo Authors
+# Copyright 2010-2024 Gentoo Authors
 # Distributed under the terms of the GNU General Public License v2
 
 __all__ = ["fetch"]
@@ -73,6 +73,7 @@ from portage.util import (
     writemsg_level,
     writemsg_stdout,
 )
+from portage.util.futures import asyncio
 from portage.process import spawn
 
 _download_suffix = ".__download__"
@@ -111,7 +112,7 @@ def _drop_privs_userfetch(settings):
     """
     spawn_kwargs = dict(_userpriv_spawn_kwargs)
     try:
-        _ensure_distdir(settings, settings["DISTDIR"])
+        asyncio.run(_ensure_distdir(settings, settings["DISTDIR"]))
     except PortageException:
         if not os.path.isdir(settings["DISTDIR"]):
             raise
@@ -179,13 +180,37 @@ def _spawn_fetch(settings, args, **kwargs):
     return rval
 
 
+# Instrumentation hooks for use by unit tests.
+_async_spawn_fetch_pre_wait = None
+_async_spawn_fetch_post_terminate = None
+
+
+async def _async_spawn_fetch(settings, args, **kwargs):
+    kwargs["returnproc"] = True
+    proc = _spawn_fetch(settings, args, **kwargs)
+    try:
+        if _async_spawn_fetch_pre_wait is not None:
+            _async_spawn_fetch_pre_wait(proc)
+        return await proc.wait()
+    except asyncio.CancelledError:
+        proc.terminate()
+        if _async_spawn_fetch_post_terminate is not None:
+            _async_spawn_fetch_post_terminate(proc)
+        raise
+
+
+_async_spawn_fetch.__doc__ = _spawn_fetch.__doc__
+_async_spawn_fetch.__doc__ += """
+    This function is a coroutine.
+"""
+
 _userpriv_test_write_file_cache = {}
 _userpriv_test_write_cmd_script = (
     ">> %(file_path)s 2>/dev/null ; rval=$? ; " + "rm -f  %(file_path)s ; exit 
$rval"
 )
 
 
-def _userpriv_test_write_file(settings, file_path):
+async def _userpriv_test_write_file(settings, file_path):
     """
     Drop privileges and try to open a file for writing. The file may or
     may not exist, and the parent directory is assumed to exist. The file
@@ -201,20 +226,40 @@ def _userpriv_test_write_file(settings, file_path):
     if rval is not None:
         return rval
 
+    # Optimize away the spawn when privileges do not need to be dropped.
+    if not (
+        "userfetch" in settings.features
+        and os.getuid() == 0
+        and portage_gid
+        and portage_uid
+        and hasattr(os, "setgroups")
+    ):
+        rval = os.access(os.path.dirname(file_path), os.W_OK)
+        _userpriv_test_write_file_cache[file_path] = rval
+        return rval
+
+    # Optimize away the spawn if we can detect a portage group write
+    # permission bit, but if this optimization fails then continue with
+    # the spawn for ACL support.
+    st = os.stat(os.path.dirname(file_path))
+    if st.st_gid == int(portage_gid) and stat.S_IMODE(st.st_mode) & 0o020:
+        _userpriv_test_write_file_cache[file_path] = True
+        return True
+
     args = [
         BASH_BINARY,
         "-c",
         _userpriv_test_write_cmd_script % {"file_path": 
_shell_quote(file_path)},
     ]
 
-    returncode = _spawn_fetch(settings, args)
+    returncode = await _async_spawn_fetch(settings, args)
 
     rval = returncode == os.EX_OK
     _userpriv_test_write_file_cache[file_path] = rval
     return rval
 
 
-def _ensure_distdir(settings, distdir):
+async def _ensure_distdir(settings, distdir):
     """
     Ensure that DISTDIR exists with appropriate permissions.
 
@@ -240,7 +285,7 @@ def _ensure_distdir(settings, distdir):
     userpriv = portage.data.secpass >= 2 and "userpriv" in settings.features
     write_test_file = os.path.join(distdir, ".__portage_test_write__")
 
-    if _userpriv_test_write_file(settings, write_test_file):
+    if await _userpriv_test_write_file(settings, write_test_file):
         return
 
     _userpriv_test_write_file_cache.pop(write_test_file, None)
@@ -687,7 +732,12 @@ def get_mirror_url(mirror_url, filename, mysettings, 
cache_path=None):
     @param cache_path: Path for mirror metadata cache
     @return: Full URL to fetch
     """
+    return asyncio.run(
+        async_mirror_url(mirror_url, filename, mysettings, 
cache_path=cache_path)
+    )
 
+
+async def async_mirror_url(mirror_url, filename, mysettings, cache_path=None):
     mirror_conf = MirrorLayoutConfig()
 
     cache = {}
@@ -708,7 +758,7 @@ def get_mirror_url(mirror_url, filename, mysettings, 
cache_path=None):
             if mirror_url[:1] == "/":
                 tmpfile = os.path.join(mirror_url, "layout.conf")
                 mirror_conf.read_from_file(tmpfile)
-            elif fetch(
+            elif await async_fetch(
                 {tmpfile: (mirror_url + "/distfiles/layout.conf",)},
                 mysettings,
                 force=1,
@@ -738,6 +788,12 @@ def get_mirror_url(mirror_url, filename, mysettings, 
cache_path=None):
         return mirror_url + "/distfiles/" + path
 
 
+async_mirror_url.__doc__ = get_mirror_url.__doc__
+async_mirror_url.__doc__ += """
+    This function is a coroutine.
+"""
+
+
 def fetch(
     myuris,
     mysettings,
@@ -783,6 +839,34 @@ def fetch(
     @rtype: int
     @return: 1 if successful, 0 otherwise.
     """
+    return asyncio.run(
+        async_fetch(
+            myuris,
+            mysettings,
+            listonly=listonly,
+            fetchonly=fetchonly,
+            locks_in_subdir=locks_in_subdir,
+            use_locks=use_locks,
+            try_mirrors=try_mirrors,
+            digests=digests,
+            allow_missing_digests=allow_missing_digests,
+            force=force,
+        )
+    )
+
+
+async def async_fetch(
+    myuris,
+    mysettings,
+    listonly=0,
+    fetchonly=0,
+    locks_in_subdir=".locks",
+    use_locks=1,
+    try_mirrors=1,
+    digests=None,
+    allow_missing_digests=True,
+    force=False,
+):
 
     if force and digests:
         # Since the force parameter can trigger unnecessary fetch when the
@@ -1050,7 +1134,7 @@ def fetch(
             for l in itertools.chain(*location_lists):
                 filedict[myfile].append(
                     functools.partial(
-                        get_mirror_url, l, myfile, mysettings, mirror_cache
+                        async_mirror_url, l, myfile, mysettings, mirror_cache
                     )
                 )
         if myuri is None:
@@ -1119,7 +1203,7 @@ def fetch(
 
     if can_fetch and not fetch_to_ro:
         try:
-            _ensure_distdir(mysettings, mysettings["DISTDIR"])
+            await _ensure_distdir(mysettings, mysettings["DISTDIR"])
         except PortageException as e:
             if not os.path.isdir(mysettings["DISTDIR"]):
                 writemsg(f"!!! {str(e)}\n", noiselevel=-1)
@@ -1381,7 +1465,7 @@ def fetch(
                 if distdir_writable and ro_distdirs:
                     readonly_file = None
                     for x in ro_distdirs:
-                        filename = get_mirror_url(x, myfile, mysettings)
+                        filename = await async_mirror_url(x, myfile, 
mysettings)
                         match, mystat = _check_distfile(
                             filename, pruned_digests, eout, 
hash_filter=hash_filter
                         )
@@ -1427,7 +1511,7 @@ def fetch(
 
                 if fsmirrors and not os.path.exists(myfile_path) and has_space:
                     for mydir in fsmirrors:
-                        mirror_file = get_mirror_url(mydir, myfile, mysettings)
+                        mirror_file = await async_mirror_url(mydir, myfile, 
mysettings)
                         try:
                             shutil.copyfile(mirror_file, download_path)
                             writemsg(_("Local mirror has file: %s\n") % myfile)
@@ -1554,7 +1638,7 @@ def fetch(
             while uri_list:
                 loc = uri_list.pop()
                 if isinstance(loc, functools.partial):
-                    loc = loc()
+                    loc = await loc()
                 # Eliminate duplicates here in case we've switched to
                 # "primaryuri" mode on the fly due to a checksum failure.
                 if loc in tried_locations:
@@ -1740,7 +1824,7 @@ def fetch(
 
                     myret = -1
                     try:
-                        myret = _spawn_fetch(mysettings, myfetch)
+                        myret = await _async_spawn_fetch(mysettings, myfetch)
 
                     finally:
                         try:
@@ -1992,3 +2076,9 @@ def fetch(
     if failed_files:
         return 0
     return 1
+
+
+async_fetch.__doc__ = fetch.__doc__
+async_fetch.__doc__ += """
+    This function is a coroutine.
+"""
diff --git a/lib/portage/tests/ebuild/test_fetch.py 
b/lib/portage/tests/ebuild/test_fetch.py
index 1856bb52b8..1ad9580362 100644
--- a/lib/portage/tests/ebuild/test_fetch.py
+++ b/lib/portage/tests/ebuild/test_fetch.py
@@ -1,9 +1,11 @@
-# Copyright 2019-2023 Gentoo Authors
+# Copyright 2019-2024 Gentoo Authors
 # Distributed under the terms of the GNU General Public License v2
 
 import functools
 import io
+import multiprocessing
 import shlex
+import signal
 import tempfile
 import types
 
@@ -36,6 +38,85 @@ from _emerge.Package import Package
 
 
 class EbuildFetchTestCase(TestCase):
+
+    async def _test_interrupt(self, loop, server, async_fetch, pkg, 
ebuild_path):
+        """Test interrupt, with server responses temporarily paused."""
+        server.pause()
+        pr, pw = multiprocessing.Pipe(duplex=False)
+        timeout = loop.create_future()
+        loop.add_reader(pr.fileno(), lambda: timeout.done() or 
timeout.set_result(None))
+        self.assertEqual(
+            await async_fetch(
+                pkg,
+                ebuild_path,
+                timeout=timeout,
+                pre_exec=functools.partial(self._pre_exec_interrupt_patch, pw),
+            ),
+            -signal.SIGTERM,
+        )
+        loop.remove_reader(pr.fileno())
+        pw.close()
+
+        # Read pid written by _async_spawn_fetch_pre_wait hook (the
+        # corresponding write served to trigger the timeout above).
+        pid = pr.recv()
+
+        # Read pid written by _async_spawn_fetch_post_terminate hook,
+        # in order to know when the ProcessLookupError test should
+        # succeed.
+        pid = pr.recv()
+        pr.close()
+
+        # Poll the process table until the pid has disappeared,
+        # and fail if a short timeout expires.
+        tries = 10
+        while tries:
+            tries -= 1
+
+            msg = None
+            if tries <= 0:
+                try:
+                    with open(f"/proc/{pid}/status") as f:
+                        for line in f:
+                            if line.startswith("State:"):
+                                msg = line
+                                break
+                except OSError:
+                    pass
+
+            try:
+                with self.assertRaises(ProcessLookupError, msg=msg):
+                    os.kill(pid, 0)
+            except Exception:
+                if tries <= 0:
+                    raise
+                await asyncio.sleep(0.1)
+            else:
+                break
+
+        server.resume()
+
+    @staticmethod
+    def _pre_exec_interrupt_patch(pw):
+        portage.package.ebuild.fetch._async_spawn_fetch_pre_wait = 
functools.partial(
+            EbuildFetchTestCase._fetch_pre_wait,
+            pw,
+        )
+        portage.package.ebuild.fetch._async_spawn_fetch_post_terminate = (
+            functools.partial(
+                EbuildFetchTestCase._fetch_post_terminate,
+                pw,
+            )
+        )
+
+    @staticmethod
+    def _fetch_pre_wait(pw, proc):
+        pw.send(proc.pid)
+
+    @staticmethod
+    def _fetch_post_terminate(pw, proc):
+        pw.send(proc.pid)
+
     def testEbuildFetch(self):
         user_config = {
             "make.conf": ('GENTOO_MIRRORS="{scheme}://{host}:{port}"',),
@@ -338,7 +419,7 @@ class EbuildFetchTestCase(TestCase):
 
             config_pool = config_pool_cls(settings)
 
-            def async_fetch(pkg, ebuild_path):
+            def async_fetch(pkg, ebuild_path, pre_exec=None, timeout=None):
                 fetcher = EbuildFetcher(
                     config_pool=config_pool,
                     ebuild_path=ebuild_path,
@@ -346,9 +427,15 @@ class EbuildFetchTestCase(TestCase):
                     fetchall=True,
                     pkg=pkg,
                     scheduler=loop,
+                    pre_exec=pre_exec,
                 )
                 fetcher.start()
-                return fetcher.async_wait()
+                waiter = fetcher.async_wait()
+                if timeout is not None:
+                    timeout.add_done_callback(
+                        lambda timeout: waiter.done() or fetcher.cancel()
+                    )
+                return waiter
 
             for cpv in ebuilds:
                 metadata = dict(
@@ -414,6 +501,13 @@ class EbuildFetchTestCase(TestCase):
                     with open(os.path.join(settings["DISTDIR"], k), "rb") as f:
                         self.assertEqual(f.read(), distfiles[k])
 
+                # Test interrupt, with server responses temporarily paused.
+                for k in settings["AA"].split():
+                    os.unlink(os.path.join(settings["DISTDIR"], k))
+                loop.run_until_complete(
+                    self._test_interrupt(loop, server, async_fetch, pkg, 
ebuild_path)
+                )
+
                 # Test empty files in DISTDIR
                 for k in settings["AA"].split():
                     file_path = os.path.join(settings["DISTDIR"], k)
diff --git a/lib/portage/tests/util/test_socks5.py 
b/lib/portage/tests/util/test_socks5.py
index a8cd0c46c4..e7bc2d699f 100644
--- a/lib/portage/tests/util/test_socks5.py
+++ b/lib/portage/tests/util/test_socks5.py
@@ -58,19 +58,27 @@ class AsyncHTTPServer:
         self.server_port = None
         self._httpd = None
 
+    def pause(self):
+        """Pause responses (useful for testing timeouts)."""
+        self._loop.remove_reader(self._httpd.socket.fileno())
+
+    def resume(self):
+        """Resume responses following a previous call to pause."""
+        self._loop.add_reader(
+            self._httpd.socket.fileno(), self._httpd._handle_request_noblock
+        )
+
     def __enter__(self):
         httpd = self._httpd = HTTPServer(
             (self._host, 0), functools.partial(_Handler, self._content)
         )
         self.server_port = httpd.server_port
-        self._loop.add_reader(
-            httpd.socket.fileno(), self._httpd._handle_request_noblock
-        )
+        self.resume()
         return self
 
     def __exit__(self, exc_type, exc_value, exc_traceback):
         if self._httpd is not None:
-            self._loop.remove_reader(self._httpd.socket.fileno())
+            self.pause()
             self._httpd.socket.close()
             self._httpd = None
 
-- 
2.44.2


Reply via email to