This is an automated email from the ASF dual-hosted git repository. sbp pushed a commit to branch sbp in repository https://gitbox.apache.org/repos/asf/tooling-trusted-releases.git
commit c0900cf5fdaa64204a9e2d631b77c8fbcc616baf Author: Sean B. Palmer <[email protected]> AuthorDate: Thu Feb 26 20:12:27 2026 +0000 Add a task to validate quarantined files and reject or promote them --- atr/db/__init__.py | 25 +++ atr/docs/code-conventions.md | 16 ++ atr/hashes.py | 4 + atr/models/sql.py | 1 + atr/storage/writers/revision.py | 6 +- atr/tasks/__init__.py | 3 + atr/tasks/quarantine.py | 214 +++++++++++++++++++++++++ tests/unit/test_create_revision.py | 2 + tests/unit/test_quarantine_task.py | 312 +++++++++++++++++++++++++++++++++++++ 9 files changed, 581 insertions(+), 2 deletions(-) diff --git a/atr/db/__init__.py b/atr/db/__init__.py index 97b1f967..6b3e1504 100644 --- a/atr/db/__init__.py +++ b/atr/db/__init__.py @@ -473,6 +473,31 @@ class Session(sqlalchemy.ext.asyncio.AsyncSession): result = await self.execute(stmt) return result.scalars().one_or_none() + def quarantined( + self, + id: Opt[int] = NOT_SET, + release_name: Opt[str] = NOT_SET, + status: Opt[sql.QuarantineStatus] = NOT_SET, + token: Opt[str] = NOT_SET, + _release: bool = False, + ) -> Query[sql.Quarantined]: + query = sqlmodel.select(sql.Quarantined) + + via = sql.validate_instrumented_attribute + if is_defined(id): + query = query.where(via(sql.Quarantined.id) == id) + if is_defined(release_name): + query = query.where(sql.Quarantined.release_name == release_name) + if is_defined(status): + query = query.where(sql.Quarantined.status == status) + if is_defined(token): + query = query.where(sql.Quarantined.token == token) + + if _release: + query = query.options(joined_load(sql.Quarantined.release)) + + return Query(self, query) + def release( self, name: Opt[str] = NOT_SET, diff --git a/atr/docs/code-conventions.md b/atr/docs/code-conventions.md index 1786ec1a..2c8f77d1 100644 --- a/atr/docs/code-conventions.md +++ b/atr/docs/code-conventions.md @@ -216,6 +216,22 @@ Exceptions to this rule apply only in these scenarios: If either exception applies, either submit a brief issue 
with the blockbuster traceback, notify the team via Slack, or add a code comment if part of another commit. An ATR Tooling engineer will address the issue without requiring significant time investment from you. +### Use explicit `commit()` for database transactions + +When writing database mutations within a `db.session()`, prefer calling `await data.commit()` explicitly after the mutations, rather than wrapping them in `async with data.begin():`. Both forms commit the transaction when the block completes without error, but the explicit commit makes the transaction boundary visible, and it is the pattern we use most often. + +```python +# Prefer +async with db.session() as data: + data.add(item) + await data.commit() + +# Avoid +async with db.session() as data: + async with data.begin(): + data.add(item) +``` + ### Always use parentheses to group complex nested subexpressions Complex subexpressions are those which contain a keyword or operator. diff --git a/atr/hashes.py b/atr/hashes.py index 35c5ac0c..274abf33 100644 --- a/atr/hashes.py +++ b/atr/hashes.py @@ -63,3 +63,7 @@ async def file_sha3(path: str) -> str: while chunk := await f.read(4096): sha3.update(chunk) return sha3.hexdigest() + + +def filesystem_cache_archives_key(content_hash: str) -> str: + return content_hash.replace(":", "_") diff --git a/atr/models/sql.py b/atr/models/sql.py index 1435a986..86a5d39b 100644 --- a/atr/models/sql.py +++ b/atr/models/sql.py @@ -206,6 +206,7 @@ class TaskType(enum.StrEnum): MESSAGE_SEND = "message_send" METADATA_UPDATE = "metadata_update" PATHS_CHECK = "paths_check" + QUARANTINE_VALIDATE = "quarantine_validate" RAT_CHECK = "rat_check" SBOM_AUGMENT = "sbom_augment" SBOM_GENERATE_CYCLONEDX = "sbom_generate_cyclonedx" diff --git a/atr/storage/writers/revision.py b/atr/storage/writers/revision.py index c99443cb..ac67099d 100644 --- a/atr/storage/writers/revision.py +++ b/atr/storage/writers/revision.py @@ -66,7 +66,7 @@ class SafeSession: return False -async def _finalise_revision( +async def finalise_revision( data: db.Session, *, asf_uid: str,
@@ -85,6 +85,7 @@ async def _finalise_revision( temp_dir: str, temp_dir_path: pathlib.Path, version_name: str, + was_quarantined: bool = False, ) -> sql.Revision: try: # This is the only place where models.Revision is constructed @@ -98,6 +99,7 @@ async def _finalise_revision( created=datetime.datetime.now(datetime.UTC), phase=release.phase, description=description, + was_quarantined=was_quarantined, ) # Acquire the write lock and add the row @@ -326,7 +328,7 @@ class CommitteeParticipant(FoundationCommitter): raise async with SafeSession(temp_dir) as data: - return await _finalise_revision( + return await finalise_revision( data, asf_uid=asf_uid, base_hashes=base_hashes, diff --git a/atr/tasks/__init__.py b/atr/tasks/__init__.py index 6c4a47f0..b79b8017 100644 --- a/atr/tasks/__init__.py +++ b/atr/tasks/__init__.py @@ -45,6 +45,7 @@ import atr.tasks.gha as gha import atr.tasks.keys as keys import atr.tasks.message as message import atr.tasks.metadata as metadata +import atr.tasks.quarantine as quarantine import atr.tasks.sbom as sbom import atr.tasks.svn as svn import atr.tasks.vote as vote @@ -303,6 +304,8 @@ def resolve(task_type: sql.TaskType) -> Callable[..., Awaitable[results.Results return metadata.update case sql.TaskType.PATHS_CHECK: return paths.check + case sql.TaskType.QUARANTINE_VALIDATE: + return quarantine.validate case sql.TaskType.RAT_CHECK: return rat.check case sql.TaskType.SBOM_AUGMENT: diff --git a/atr/tasks/quarantine.py b/atr/tasks/quarantine.py new file mode 100644 index 00000000..9ca22bbc --- /dev/null +++ b/atr/tasks/quarantine.py @@ -0,0 +1,214 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +import asyncio +import datetime +import pathlib + +import aiofiles.os +import aioshutil + +import atr.attestable as attestable +import atr.config as config +import atr.db as db +import atr.detection as detection +import atr.hashes as hashes +import atr.log as log +import atr.models.results as results +import atr.models.schema as schema +import atr.models.sql as sql +import atr.paths as paths +import atr.storage.writers.revision as revision +import atr.tasks.checks as checks +import atr.util as util + + +class QuarantineArchiveEntry(schema.Strict): + rel_path: str + content_hash: str + + +class QuarantineValidate(schema.Strict): + quarantined_id: int + archives: list[QuarantineArchiveEntry] + + [email protected]_model(QuarantineValidate) +async def validate(args: QuarantineValidate) -> results.Results | None: + async with db.session() as data: + quarantined = await data.quarantined(id=args.quarantined_id, _release=True).get() + + if quarantined is None: + log.error(f"Quarantined row {args.quarantined_id} not found") + return None + + if quarantined.status != sql.QuarantineStatus.PENDING: + log.error(f"Quarantined row {args.quarantined_id} is not PENDING") + return None + + release = quarantined.release + project_name = release.project_name + version_name = release.version + quarantine_dir = paths.quarantine_directory(quarantined) + + if not await aiofiles.os.path.isdir(quarantine_dir): + await _mark_failed(quarantined, None, "Quarantine directory does not exist") + return None + + file_entries, any_failed = await 
_run_safety_checks(args.archives, quarantine_dir) + + if any_failed: + await _mark_failed(quarantined, file_entries) + await aioshutil.rmtree(quarantine_dir) + return None + + try: + await _extract_archives_to_cache(args.archives, quarantine_dir) + except Exception: + await _mark_failed(quarantined, file_entries, "Archive extraction to cache failed") + await aioshutil.rmtree(quarantine_dir) + return None + + await _promote(quarantined, project_name, version_name, release, str(quarantine_dir)) + return None + + +async def _extract_archives_to_cache(archives: list[QuarantineArchiveEntry], quarantine_dir: pathlib.Path) -> None: + # Cannot import as archives because that shadows the parameter name + import atr.archives + + conf = config.get() + cache_base = pathlib.Path(conf.STATE_DIR) / "cache" / "archives" + await aiofiles.os.makedirs(cache_base, exist_ok=True) + + for archive in archives: + cache_dir = cache_base / hashes.filesystem_cache_archives_key(archive.content_hash) + if await aiofiles.os.path.isdir(cache_dir): + continue + archive_path = str(quarantine_dir / archive.rel_path) + extract_dir = str(cache_dir) + await aiofiles.os.makedirs(extract_dir, exist_ok=True) + try: + await asyncio.to_thread( + atr.archives.extract, + archive_path, + extract_dir, + max_size=conf.MAX_EXTRACT_SIZE, + chunk_size=conf.EXTRACT_CHUNK_SIZE, + ) + except Exception: + log.exception(f"Failed to extract archive {archive.rel_path} to cache") + await aioshutil.rmtree(cache_dir, ignore_errors=True) + raise + + +async def _mark_failed( + quarantined: sql.Quarantined, + file_entries: list[sql.QuarantineFileEntryV1] | None, + message: str | None = None, +) -> None: + async with db.session() as data: + managed = await data.merge(quarantined) + managed.status = sql.QuarantineStatus.FAILED + managed.completed = datetime.datetime.now(datetime.UTC) + if file_entries is not None: + managed.file_metadata = file_entries + await data.commit() + if message: + log.error(f"Quarantine {quarantined.id} 
failed: {message}") + else: + log.error(f"Quarantine {quarantined.id} failed safety checks") + + +async def _promote( + quarantined: sql.Quarantined, + project_name: str, + version_name: str, + release: sql.Release, + quarantine_dir: str, +) -> None: + quarantine_dir_path = pathlib.Path(quarantine_dir) + release_name = release.name + + path_to_hash, path_to_size = await attestable.paths_to_hashes_and_sizes(quarantine_dir_path) + + old_revision: sql.Revision | None = None + if quarantined.prior_revision_name is not None: + prior_number = quarantined.prior_revision_name.split()[-1] + async with db.session() as data: + old_revision = await data.revision(release_name=release_name, number=prior_number).get() + + previous_attestable = None + if old_revision is not None: + previous_attestable = await attestable.load(project_name, version_name, old_revision.number) + + base_inodes: dict[str, int] = {} + base_hashes: dict[str, str] = {} + if old_revision is not None: + old_release_dir = paths.release_directory_base(release) / old_revision.number + base_inodes = await asyncio.to_thread(util.paths_to_inodes, old_release_dir) + base_hashes = dict(previous_attestable.paths) if (previous_attestable is not None) else {} + n_inodes = await asyncio.to_thread(util.paths_to_inodes, quarantine_dir_path) + + async with revision.SafeSession(quarantine_dir) as data: + await revision.finalise_revision( + data, + asf_uid=quarantined.asf_uid, + base_hashes=base_hashes, + base_inodes=base_inodes, + description=quarantined.description, + merge_enabled=True, + n_inodes=n_inodes, + old_revision=old_revision, + path_to_hash=path_to_hash, + path_to_size=path_to_size, + previous_attestable=previous_attestable, + project_name=project_name, + release=release, + release_name=release_name, + temp_dir=quarantine_dir, + temp_dir_path=quarantine_dir_path, + version_name=version_name, + was_quarantined=True, + ) + + async with db.session() as data: + await data.delete(quarantined) + await data.commit() + 
+ +async def _run_safety_checks( + archives: list[QuarantineArchiveEntry], quarantine_dir: pathlib.Path +) -> tuple[list[sql.QuarantineFileEntryV1], bool]: + file_entries: list[sql.QuarantineFileEntryV1] = [] + any_failed = False + for archive in archives: + archive_path = str(quarantine_dir / archive.rel_path) + stat = await aiofiles.os.stat(archive_path) + errors = await asyncio.to_thread(detection.check_archive_safety, archive_path) + entry = sql.QuarantineFileEntryV1( + rel_path=archive.rel_path, + size_bytes=stat.st_size, + content_hash=archive.content_hash, + errors=errors, + ) + file_entries.append(entry) + if errors: + any_failed = True + return file_entries, any_failed diff --git a/tests/unit/test_create_revision.py b/tests/unit/test_create_revision.py index 47d50c44..a38e9090 100644 --- a/tests/unit/test_create_revision.py +++ b/tests/unit/test_create_revision.py @@ -43,6 +43,7 @@ class FakeRevision: created: object, phase: sql.ReleasePhase, description: str | None, + was_quarantined: bool = False, ): self.asfuid = asfuid self.created = created @@ -53,6 +54,7 @@ class FakeRevision: self.phase = phase self.release = release self.release_name = release_name + self.was_quarantined = was_quarantined class MockSafeData: diff --git a/tests/unit/test_quarantine_task.py b/tests/unit/test_quarantine_task.py new file mode 100644 index 00000000..6f24c0f3 --- /dev/null +++ b/tests/unit/test_quarantine_task.py @@ -0,0 +1,312 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import io +import pathlib +import tarfile +import unittest.mock as mock + +import pytest + +import atr.models.sql as sql +import atr.tasks as tasks +import atr.tasks.quarantine as quarantine + + [email protected] +async def test_mark_failed_persists_on_managed_instance(): + # This is a regression test for a bug during development + # Instance of issue #299 + detached = mock.MagicMock(spec=sql.Quarantined) + detached.id = 42 + + managed = mock.MagicMock(spec=sql.Quarantined) + + mock_data = mock.AsyncMock() + mock_data.merge = mock.AsyncMock(return_value=managed) + + mock_session_ctx = mock.AsyncMock() + mock_session_ctx.__aenter__ = mock.AsyncMock(return_value=mock_data) + mock_session_ctx.__aexit__ = mock.AsyncMock(return_value=False) + + entries = [ + sql.QuarantineFileEntryV1(rel_path="bad.tar.gz", size_bytes=100, content_hash="abc", errors=["traversal"]) + ] + + with mock.patch.object(quarantine.db, "session", return_value=mock_session_ctx): + await quarantine._mark_failed(detached, entries) + + assert managed.status == sql.QuarantineStatus.FAILED + assert managed.file_metadata == entries + assert managed.completed is not None + + [email protected] +async def test_promote_finalises_revision_and_deletes_quarantined(tmp_path: pathlib.Path): + quarantine_dir_path = tmp_path / "quarantine" + quarantine_dir_path.mkdir() + quarantine_dir = str(quarantine_dir_path) + + release = mock.MagicMock() + release.name = "proj-1.0" + + quarantined_row = mock.MagicMock(spec=sql.Quarantined) + quarantined_row.prior_revision_name = None + quarantined_row.asf_uid = 
"testuser" + quarantined_row.description = "test upload" + + mock_safe_ctx = mock.MagicMock() + mock_safe_ctx.__aenter__ = mock.AsyncMock(return_value=mock.AsyncMock()) + mock_safe_ctx.__aexit__ = mock.AsyncMock(return_value=False) + + mock_delete_data = mock.AsyncMock() + mock_delete_data.delete = mock.AsyncMock() + mock_delete_ctx = mock.AsyncMock() + mock_delete_ctx.__aenter__ = mock.AsyncMock(return_value=mock_delete_data) + mock_delete_ctx.__aexit__ = mock.AsyncMock(return_value=False) + + with ( + mock.patch.object( + quarantine.attestable, + "paths_to_hashes_and_sizes", + new_callable=mock.AsyncMock, + return_value=({"file.txt": "hash1"}, {"file.txt": 100}), + ), + mock.patch.object(quarantine.util, "paths_to_inodes", return_value={"file.txt": 12345}), + mock.patch.object(quarantine.revision, "SafeSession", return_value=mock_safe_ctx), + mock.patch.object(quarantine.revision, "finalise_revision", new_callable=mock.AsyncMock) as mock_finalise, + mock.patch.object(quarantine.db, "session", return_value=mock_delete_ctx), + ): + await quarantine._promote(quarantined_row, "proj", "1.0", release, quarantine_dir) + + mock_finalise.assert_awaited_once() + call_kwargs = mock_finalise.call_args.kwargs + assert call_kwargs["was_quarantined"] is True + assert call_kwargs["project_name"] == "proj" + assert call_kwargs["release"] is release + assert call_kwargs["path_to_hash"] == {"file.txt": "hash1"} + mock_delete_data.delete.assert_awaited_once_with(quarantined_row) + + +def test_resolve_returns_quarantine_handler(): + handler = tasks.resolve(sql.TaskType.QUARANTINE_VALIDATE) + assert handler is quarantine.validate + + [email protected] +async def test_run_safety_checks_safe_archive(tmp_path: pathlib.Path): + archive_path = tmp_path / "safe.tar.gz" + _create_safe_tar_gz(archive_path) + + archives = [quarantine.QuarantineArchiveEntry(rel_path="safe.tar.gz", content_hash="abc123")] + entries, any_failed = await quarantine._run_safety_checks(archives, tmp_path) + + assert 
not any_failed + assert len(entries) == 1 + assert entries[0].rel_path == "safe.tar.gz" + assert entries[0].content_hash == "abc123" + assert entries[0].errors == [] + + [email protected] +async def test_run_safety_checks_unsafe_archive(tmp_path: pathlib.Path): + archive_path = tmp_path / "unsafe.tar.gz" + _create_traversal_tar_gz(archive_path) + + archives = [quarantine.QuarantineArchiveEntry(rel_path="unsafe.tar.gz", content_hash="def456")] + entries, any_failed = await quarantine._run_safety_checks(archives, tmp_path) + + assert any_failed + assert len(entries) == 1 + assert len(entries[0].errors) > 0 + + [email protected] +async def test_validate_extraction_failure_marks_failed_and_deletes_dir(tmp_path: pathlib.Path): + quarantine_dir = tmp_path / "quarantine" + quarantine_dir.mkdir() + + row = _make_quarantined_row() + mock_data = _make_session_returning(row) + + ok_entries = [sql.QuarantineFileEntryV1(rel_path="ok.tar.gz", size_bytes=50, content_hash="abc", errors=[])] + + with ( + mock.patch.object(quarantine.db, "session", return_value=mock_data), + mock.patch.object(quarantine.paths, "quarantine_directory", return_value=quarantine_dir), + mock.patch.object( + quarantine, + "_run_safety_checks", + new_callable=mock.AsyncMock, + return_value=(ok_entries, False), + ), + mock.patch.object( + quarantine, + "_extract_archives_to_cache", + new_callable=mock.AsyncMock, + side_effect=RuntimeError("Extraction failure"), + ), + mock.patch.object(quarantine, "_mark_failed", new_callable=mock.AsyncMock) as mock_mark, + mock.patch.object(quarantine.aioshutil, "rmtree", new_callable=mock.AsyncMock) as mock_rmtree, + ): + result = await quarantine.validate( + {"quarantined_id": 1, "archives": [{"rel_path": "ok.tar.gz", "content_hash": "abc"}]} + ) + + assert result is None + mock_mark.assert_awaited_once_with(row, ok_entries, "Archive extraction to cache failed") + mock_rmtree.assert_awaited_once_with(quarantine_dir) + + [email protected] +async def 
test_validate_missing_quarantined_row(): + mock_data = mock.AsyncMock() + mock_query = mock.MagicMock() + mock_query.get = mock.AsyncMock(return_value=None) + mock_data.quarantined = mock.MagicMock(return_value=mock_query) + mock_data.__aenter__ = mock.AsyncMock(return_value=mock_data) + mock_data.__aexit__ = mock.AsyncMock(return_value=False) + + with mock.patch.object(quarantine.db, "session", return_value=mock_data): + result = await quarantine.validate({"quarantined_id": 999, "archives": []}) + + assert result is None + + [email protected] +async def test_validate_non_pending_status(): + quarantined_row = mock.MagicMock() + quarantined_row.status = sql.QuarantineStatus.FAILED + + mock_data = mock.AsyncMock() + mock_query = mock.MagicMock() + mock_query.get = mock.AsyncMock(return_value=quarantined_row) + mock_data.quarantined = mock.MagicMock(return_value=mock_query) + mock_data.__aenter__ = mock.AsyncMock(return_value=mock_data) + mock_data.__aexit__ = mock.AsyncMock(return_value=False) + + with mock.patch.object(quarantine.db, "session", return_value=mock_data): + result = await quarantine.validate({"quarantined_id": 1, "archives": []}) + + assert result is None + + [email protected] +async def test_validate_safety_failure_marks_failed_and_deletes_dir(tmp_path: pathlib.Path): + quarantine_dir = tmp_path / "quarantine" + quarantine_dir.mkdir() + + row = _make_quarantined_row() + mock_data = _make_session_returning(row) + + fail_entries = [ + sql.QuarantineFileEntryV1( + rel_path="unsafe.tar.gz", size_bytes=50, content_hash="def", errors=["path traversal"] + ) + ] + + with ( + mock.patch.object(quarantine.db, "session", return_value=mock_data), + mock.patch.object(quarantine.paths, "quarantine_directory", return_value=quarantine_dir), + mock.patch.object( + quarantine, + "_run_safety_checks", + new_callable=mock.AsyncMock, + return_value=(fail_entries, True), + ), + mock.patch.object(quarantine, "_mark_failed", new_callable=mock.AsyncMock) as mock_mark, + 
mock.patch.object(quarantine.aioshutil, "rmtree", new_callable=mock.AsyncMock) as mock_rmtree, + ): + result = await quarantine.validate( + {"quarantined_id": 1, "archives": [{"rel_path": "unsafe.tar.gz", "content_hash": "def"}]} + ) + + assert result is None + mock_mark.assert_awaited_once_with(row, fail_entries) + mock_rmtree.assert_awaited_once_with(quarantine_dir) + + [email protected] +async def test_validate_success_calls_promote(tmp_path: pathlib.Path): + quarantine_dir = tmp_path / "quarantine" + quarantine_dir.mkdir() + + row = _make_quarantined_row() + mock_data = _make_session_returning(row) + + ok_entries = [sql.QuarantineFileEntryV1(rel_path="ok.tar.gz", size_bytes=50, content_hash="abc", errors=[])] + + with ( + mock.patch.object(quarantine.db, "session", return_value=mock_data), + mock.patch.object(quarantine.paths, "quarantine_directory", return_value=quarantine_dir), + mock.patch.object( + quarantine, + "_run_safety_checks", + new_callable=mock.AsyncMock, + return_value=(ok_entries, False), + ), + mock.patch.object(quarantine, "_extract_archives_to_cache", new_callable=mock.AsyncMock), + mock.patch.object(quarantine, "_promote", new_callable=mock.AsyncMock) as mock_promote, + mock.patch.object(quarantine, "_mark_failed", new_callable=mock.AsyncMock) as mock_mark, + ): + result = await quarantine.validate( + {"quarantined_id": 1, "archives": [{"rel_path": "ok.tar.gz", "content_hash": "abc"}]} + ) + + assert result is None + mock_promote.assert_awaited_once_with(row, "proj", "1.0", row.release, str(quarantine_dir)) + mock_mark.assert_not_awaited() + + +def _create_safe_tar_gz(path: pathlib.Path) -> None: + buf = io.BytesIO() + with tarfile.open(fileobj=buf, mode="w:gz") as tar: + info = tarfile.TarInfo(name="README.txt") + content = b"Safe content" + info.size = len(content) + tar.addfile(info, io.BytesIO(content)) + path.write_bytes(buf.getvalue()) + + +def _create_traversal_tar_gz(path: pathlib.Path) -> None: + buf = io.BytesIO() + with 
tarfile.open(fileobj=buf, mode="w:gz") as tar: + info = tarfile.TarInfo(name="../../../etc/passwd") + content = b"traversal" + info.size = len(content) + tar.addfile(info, io.BytesIO(content)) + path.write_bytes(buf.getvalue()) + + +def _make_quarantined_row() -> mock.MagicMock: + row = mock.MagicMock(spec=sql.Quarantined) + row.id = 1 + row.status = sql.QuarantineStatus.PENDING + row.release = mock.MagicMock() + row.release.project_name = "proj" + row.release.version = "1.0" + return row + + +def _make_session_returning(quarantined_row: mock.MagicMock) -> mock.AsyncMock: + mock_data = mock.AsyncMock() + mock_query = mock.MagicMock() + mock_query.get = mock.AsyncMock(return_value=quarantined_row) + mock_data.quarantined = mock.MagicMock(return_value=mock_query) + mock_data.__aenter__ = mock.AsyncMock(return_value=mock_data) + mock_data.__aexit__ = mock.AsyncMock(return_value=False) + return mock_data --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
