This is an automated email from the ASF dual-hosted git repository. arm pushed a commit to branch arm in repository https://gitbox.apache.org/repos/asf/tooling-trusted-releases.git
commit a81c982cfea59bd2d41b2b3a9745ee798479d654 Author: Alastair McFarlane <[email protected]> AuthorDate: Mon Mar 9 17:55:09 2026 +0000 #643 - Add safe.RevisionNumber and utilise unsafe.UnsafeStr for remaining str types. --- atr/admin/__init__.py | 13 +++++++--- atr/api/__init__.py | 47 +++++++++++++++++++---------------- atr/attestable.py | 41 ++++++++++++++++-------------- atr/blueprints/api.py | 2 +- atr/blueprints/common.py | 51 ++++++++++++++++++++++++-------------- atr/blueprints/get.py | 2 +- atr/blueprints/post.py | 2 +- atr/construct.py | 8 +++--- atr/db/interaction.py | 33 ++++++++++++++---------- atr/get/announce.py | 2 +- atr/get/checks.py | 2 +- atr/get/committees.py | 11 ++++---- atr/get/keys.py | 13 +++++----- atr/get/manual.py | 6 ++--- atr/get/projects.py | 21 ++++++++-------- atr/get/ref.py | 9 ++++--- atr/get/test.py | 7 +++--- atr/get/voting.py | 6 ++--- atr/merge.py | 8 +++--- atr/models/api.py | 8 +++--- atr/models/safe.py | 25 ++++++++++++++++--- atr/models/sql.py | 10 ++++++++ atr/models/unsafe.py | 5 +++- atr/paths.py | 6 ++--- atr/post/announce.py | 6 ++--- atr/post/draft.py | 12 +++------ atr/post/keys.py | 9 ++++--- atr/post/manual.py | 2 +- atr/post/projects.py | 11 ++++---- atr/post/revisions.py | 2 +- atr/post/upload.py | 9 ++++--- atr/post/voting.py | 6 ++--- atr/shared/ignores.py | 3 ++- atr/shared/revisions.py | 3 ++- atr/ssh.py | 4 ++- atr/storage/readers/releases.py | 5 ++-- atr/storage/writers/announce.py | 12 ++++----- atr/storage/writers/checks.py | 4 +-- atr/storage/writers/release.py | 16 +++++++----- atr/storage/writers/revision.py | 18 ++++++++------ atr/storage/writers/vote.py | 6 ++--- atr/tasks/__init__.py | 22 ++++++++-------- atr/tasks/checks/__init__.py | 18 +++++++------- atr/tasks/checks/compare.py | 2 +- atr/tasks/quarantine.py | 2 +- atr/tasks/vote.py | 2 +- atr/web.py | 7 +++--- atr/worker.py | 5 ++-- tests/unit/test_create_revision.py | 15 +++++++++-- 49 files changed, 316 insertions(+), 223 deletions(-) diff --git a/atr/admin/__init__.py b/atr/admin/__init__.py index c6763110..3fe6e7dc 100644 --- a/atr/admin/__init__.py +++ b/atr/admin/__init__.py @@ -599,7 +599,8 @@ async def ongoing_tasks_get( ) -> web.QuartResponse: project = safe.ProjectName(project_name) version = safe.VersionName(version_name) - return await _ongoing_tasks(session, project, version, revision) + revision_number = safe.RevisionNumber(revision) + return await _ongoing_tasks(session, project, version, revision_number) @admin.post("/ongoing-tasks/<project_name>/<version_name>/<revision>") @@ -608,7 +609,8 @@ async def ongoing_tasks_post( ) -> web.QuartResponse: project = safe.ProjectName(project_name) version = safe.VersionName(version_name) - return await _ongoing_tasks(session, project, version, revision) + revision_number = safe.RevisionNumber(revision) + return await _ongoing_tasks(session, project, version, revision_number) @admin.get("/performance") @@ -1172,13 +1174,16 @@ async def _get_filesystem_dirs_unfinished(filesystem_dirs: list[str]) -> None: async def _ongoing_tasks( - session: web.Committer, project_name: safe.ProjectName, version_name: safe.VersionName, revision: str + session: web.Committer, + project_name: safe.ProjectName, + version_name: safe.VersionName, + revision: safe.RevisionNumber, ) -> web.QuartResponse: try: ongoing = await interaction.tasks_ongoing(project_name, version_name, revision) return web.TextResponse(str(ongoing)) except Exception: - log.exception(f"Error fetching ongoing task count for {project_name!s} {version_name!s} rev {revision}:") + log.exception(f"Error fetching ongoing task count for {project_name!s} {version_name!s} rev {revision!s}:") return web.TextResponse("") diff --git a/atr/api/__init__.py b/atr/api/__init__.py index 97c5b750..6ff0c505 100644 --- a/atr/api/__init__.py +++ b/atr/api/__init__.py @@ -39,6 +39,7 @@ import atr.log as log import atr.models as models import atr.models.safe as safe import atr.models.sql as sql +import atr.models.unsafe as unsafe import atr.paths as paths import atr.principal as principal import atr.storage as storage @@ -100,7 +101,7 @@ async def checks_list_revision( _checks_list: Literal["checks/list"], project_name: safe.ProjectName, version_name: safe.VersionName, - revision: str, + revision: safe.RevisionNumber, ) -> DictResponse: """ URL: GET /checks/list/<project_name>/<version_name>/<revision> @@ -121,7 +122,7 @@ async def checks_list_revision( exceptions.NotFound(f"Release '{release_name}' does not exist") ) - revision_result = await data.revision(release_name=release_name, number=revision).get() + revision_result = await data.revision(release_name=release_name, number=str(revision)).get() if revision_result is None: raise exceptions.NotFound(f"Revision '{revision}' does not exist for release '{release_name}'") @@ -130,7 +131,7 @@ async def checks_list_revision( return models.api.ChecksListResults( endpoint="/checks/list", checks=check_results, - checks_revision=revision, + checks_revision=str(revision), current_phase=release_result.phase, ).model_dump(mode="json"), 200 @@ -141,7 +142,7 @@ async def checks_ongoing( _checks_ongoing: Literal["checks/ongoing"], project_name: safe.ProjectName, version_name: safe.VersionName, - revision: str | None = None, + revision: safe.RevisionNumber | None = None, ) -> DictResponse: """ URL: GET /checks/ongoing/<project_name>/<version_name>[/<revision>] @@ -179,7 +180,7 @@ async def checks_ongoing( @quart_schema.validate_response(models.api.CommitteeGetResults, 200) async def committee_get( _committee_get: Literal["committee/get"], - name: str, + name: unsafe.UnsafeStr, ) -> DictResponse: """ URL: GET /committee/get/<name> @@ -192,7 +193,9 @@ async def committee_get( "simple-example". """ async with db.session() as data: - committee = await data.committee(name=name).demand(exceptions.NotFound(f"Committee '{name}' was not found")) + committee = await data.committee(name=str(name)).demand( + exceptions.NotFound(f"Committee '{name!s}' was not found") + ) return models.api.CommitteeGetResults( endpoint="/committee/get", committee=committee, @@ -203,7 +206,7 @@ async def committee_get( @quart_schema.validate_response(models.api.CommitteeKeysResults, 200) async def committee_keys( _committee_keys: Literal["committee/keys"], - name: str, + name: unsafe.UnsafeStr, ) -> DictResponse: """ URL: GET /committee/keys/<name> @@ -216,8 +219,8 @@ async def committee_keys( "simple-example". """ async with db.session() as data: - committee = await data.committee(name=name, _public_signing_keys=True).demand( - exceptions.NotFound(f"Committee '{name}' was not found") + committee = await data.committee(name=str(name), _public_signing_keys=True).demand( + exceptions.NotFound(f"Committee '{name!s}' was not found") ) return models.api.CommitteeKeysResults( endpoint="/committee/keys", @@ -229,7 +232,7 @@ async def committee_keys( @quart_schema.validate_response(models.api.CommitteeProjectsResults, 200) async def committee_projects( _committee_projects: Literal["committee/projects"], - name: str, + name: unsafe.UnsafeStr, ) -> DictResponse: """ URL: GET /committee/projects/<name> @@ -242,8 +245,8 @@ async def committee_projects( "simple-example". """ async with db.session() as data: - committee = await data.committee(name=name, _projects=True).demand( - exceptions.NotFound(f"Committee '{name}' was not found") + committee = await data.committee(name=str(name), _projects=True).demand( + exceptions.NotFound(f"Committee '{name!s}' was not found") ) return models.api.CommitteeProjectsResults( endpoint="/committee/projects", @@ -583,7 +586,7 @@ async def key_delete( @quart_schema.validate_response(models.api.KeyGetResults, 200) async def key_get( _key_get: Literal["key/get"], - fingerprint: str, + fingerprint: unsafe.UnsafeStr, ) -> DictResponse: """ URL: GET /key/get/<fingerprint> @@ -593,8 +596,8 @@ async def key_get( All public OpenPGP keys stored within the database are accessible. """ async with db.session() as data: - key = await data.public_signing_key(fingerprint=fingerprint.lower()).demand( - exceptions.NotFound(f"Key '{fingerprint}' not found") + key = await data.public_signing_key(fingerprint=str(fingerprint).lower()).demand( + exceptions.NotFound(f"Key '{fingerprint!s}' not found") ) return models.api.KeyGetResults( endpoint="/key/get", @@ -669,7 +672,7 @@ async def keys_upload( @quart_schema.validate_response(models.api.KeysUserResults, 200) async def keys_user( _keys_user: Literal["keys/user"], - asf_uid: str, + asf_uid: unsafe.UnsafeStr, ) -> DictResponse: """ URL: GET /keys/user/<asf_uid> @@ -677,7 +680,7 @@ async def keys_user( List public OpenPGP keys by the ASF UID of a user. """ async with db.session() as data: - keys = await data.public_signing_key(apache_uid=asf_uid).all() + keys = await data.public_signing_key(apache_uid=str(asf_uid)).all() return models.api.KeysUserResults( endpoint="/keys/user", keys=keys, @@ -1068,7 +1071,7 @@ async def release_paths( _release_paths: Literal["release/paths"], project_name: safe.ProjectName, version_name: safe.VersionName, - revision: str | None = None, + revision: safe.RevisionNumber | None = None, ) -> DictResponse: """ URL: GET /release/paths/<project_name>/<version_name>[/<revision>] @@ -1081,8 +1084,8 @@ async def release_paths( if revision is None: dir_path = paths.release_directory(release) else: - await data.revision(release_name=release_name, number=revision).demand(exceptions.NotFound()) - dir_path = paths.release_directory_version(release) / revision + await data.revision(release_name=release_name, number=str(revision)).demand(exceptions.NotFound()) + dir_path = paths.release_directory_version(release) / str(revision) if not (await aiofiles.os.path.isdir(dir_path)): raise exceptions.NotFound("Files not found") files: list[str] = [str(path) for path in [p async for p in util.paths_recursive(dir_path)]] @@ -1322,7 +1325,7 @@ async def ssh_key_delete( @rate_limiter.rate_limit(10, datetime.timedelta(hours=1)) async def ssh_keys_list( _ssh_keys_list: Literal["ssh-keys/list"], - asf_uid: str, + asf_uid: unsafe.UnsafeStr, query_args: models.api.SshKeysListQuery, ) -> DictResponse: """ @@ -1335,7 +1338,7 @@ async def ssh_keys_list( async with db.session() as data: statement = ( sqlmodel.select(sql.SSHKey) - .where(sql.SSHKey.asf_uid == asf_uid) + .where(sql.SSHKey.asf_uid == str(asf_uid)) .limit(query_args.limit) .offset(query_args.offset) .order_by(via(sql.SSHKey.fingerprint).asc()) diff --git a/atr/attestable.py b/atr/attestable.py index f5e8592e..a020776d 100644 --- a/atr/attestable.py +++ b/atr/attestable.py @@ -36,31 +36,34 @@ if TYPE_CHECKING: def attestable_checks_path( - project_name: safe.ProjectName, version_name: safe.VersionName, revision_number: str + project_name: safe.ProjectName, version_name: safe.VersionName, revision_number: safe.RevisionNumber ) -> pathlib.Path: - return paths.get_attestable_dir() / str(project_name) / str(version_name) / f"{revision_number}.checks.json" + return paths.get_attestable_dir() / str(project_name) / str(version_name) / f"{revision_number!s}.checks.json" def attestable_path( - project_name: safe.ProjectName, version_name: safe.VersionName, revision_number: str + project_name: safe.ProjectName, version_name: safe.VersionName, revision_number: safe.RevisionNumber ) -> pathlib.Path: - return paths.get_attestable_dir() / str(project_name) / str(version_name) / f"{revision_number}.json" + return paths.get_attestable_dir() / str(project_name) / str(version_name) / f"{revision_number!s}.json" def attestable_paths_path( - project_name: safe.ProjectName, version_name: safe.VersionName, revision_number: str + project_name: safe.ProjectName, version_name: safe.VersionName, revision_number: safe.RevisionNumber ) -> pathlib.Path: - return paths.get_attestable_dir() / str(project_name) / str(version_name) / f"{revision_number}.paths.json" + return paths.get_attestable_dir() / str(project_name) / str(version_name) / f"{revision_number!s}.paths.json" def github_tp_payload_path( - project_name: safe.ProjectName, version_name: safe.VersionName, revision_number: str + project_name: safe.ProjectName, version_name: safe.VersionName, revision_number: safe.RevisionNumber ) -> pathlib.Path: - return paths.get_attestable_dir() / str(project_name) / str(version_name) / f"{revision_number}.github-tp.json" + return paths.get_attestable_dir() / str(project_name) / str(version_name) / f"{revision_number!s}.github-tp.json" async def github_tp_payload_write( - project_name: safe.ProjectName, version_name: safe.VersionName, revision_number: str, github_payload: dict[str, Any] + project_name: safe.ProjectName, + version_name: safe.VersionName, + revision_number: safe.RevisionNumber, + github_payload: dict[str, Any], ) -> None: payload_path = github_tp_payload_path(project_name, version_name, revision_number) await util.atomic_write_file(payload_path, json.dumps(github_payload, indent=2)) @@ -69,7 +72,7 @@ async def github_tp_payload_write( async def load( project_name: safe.ProjectName, version_name: safe.VersionName, - revision_number: str, + revision_number: safe.RevisionNumber, ) -> models.AttestableV1 | None: file_path = attestable_path(project_name, version_name, revision_number) if not await aiofiles.os.path.isfile(file_path): @@ -86,7 +89,7 @@ async def load( async def load_checks( project_name: safe.ProjectName, version_name: safe.VersionName, - revision_number: str, + revision_number: safe.RevisionNumber, ) -> dict[str, dict[str, str]]: file_path = attestable_checks_path(project_name, version_name, revision_number) # TODO: Once we're sure everyone is on V2, we should be strict about failures here, @@ -108,7 +111,7 @@ async def load_checks( async def load_paths( project_name: safe.ProjectName, version_name: safe.VersionName, - revision_number: str, + revision_number: safe.RevisionNumber, ) -> dict[str, str] | None: file_path = attestable_paths_path(project_name, version_name, revision_number) if await aiofiles.os.path.isfile(file_path): @@ -174,7 +177,7 @@ async def paths_to_hashes_and_sizes(directory: pathlib.Path) -> tuple[dict[str, async def write_checks_data( project_name: safe.ProjectName, version_name: safe.VersionName, - revision_number: str, + revision_number: safe.RevisionNumber, rel_path: str, checks: dict[str, str], ) -> None: @@ -198,7 +201,7 @@ async def write_checks_data( async def write_files_data( project_name: safe.ProjectName, version_name: safe.VersionName, - revision_number: str, + revision_number: safe.RevisionNumber, release_policy: dict[str, Any] | None, uploader_uid: str, previous: models.AttestableV1 | None, @@ -222,7 +225,7 @@ def _compute_hashes_with_attribution( # noqa: C901 path_to_size: dict[str, int], previous: models.AttestableV1 | None, uploader_uid: str, - revision_number: str, + revision_number: safe.RevisionNumber, ) -> dict[str, models.HashEntry]: previous_hash_to_paths: dict[str, set[str]] = {} if previous is not None: @@ -243,7 +246,7 @@ def _compute_hashes_with_attribution( # noqa: C901 if hash_ref not in new_hashes: new_hashes[hash_ref] = models.HashEntry( size=file_size, - uploaders=[(uploader_uid, revision_number)], + uploaders=[(uploader_uid, str(revision_number))], basenames=sorted(current_basenames), ) continue @@ -256,8 +259,8 @@ def _compute_hashes_with_attribution( # noqa: C901 if len(current_paths) > len(previous_paths): existing_entries = set(new_hashes[hash_ref].uploaders) - if (uploader_uid, revision_number) not in existing_entries: - new_hashes[hash_ref].uploaders.append((uploader_uid, revision_number)) + if (uploader_uid, str(revision_number)) not in existing_entries: + new_hashes[hash_ref].uploaders.append((uploader_uid, str(revision_number))) return new_hashes @@ -265,7 +268,7 @@ def _compute_hashes_with_attribution( # noqa: C901 def _generate_files_data( path_to_hash: dict[str, str], path_to_size: dict[str, int], - revision_number: str, + revision_number: safe.RevisionNumber, release_policy: dict[str, Any] | None, uploader_uid: str, previous: models.AttestableV1 | None, diff --git a/atr/blueprints/api.py b/atr/blueprints/api.py index 5df652e7..f04428cb 100644 --- a/atr/blueprints/api.py +++ b/atr/blueprints/api.py @@ -73,7 +73,7 @@ def typed(func: Callable[..., Any]) -> Callable[..., Any]: query_safe_params = common.safe_params_for_type(query_param[1]) if query_param is not None else [] async def wrapper(*_args: Any, **kwargs: Any) -> Any: - await common.run_validators(kwargs, validated_params) + await common.validate_params(kwargs, validated_params) kwargs.update(literal_params) if body_param is not None: diff --git a/atr/blueprints/common.py b/atr/blueprints/common.py index 670d52a2..b9c2d219 100644 --- a/atr/blueprints/common.py +++ b/atr/blueprints/common.py @@ -38,7 +38,7 @@ QUART_CONVERTERS: dict[Any, str] = { unsafe.Path: "path", } -VALIDATED_TYPES: set[Any] = {safe.ProjectName, safe.VersionName} +VALIDATED_TYPES: set[Any] = {safe.ProjectName, safe.RevisionNumber, safe.VersionName, unsafe.UnsafeStr} async def authenticate() -> web.Committer: @@ -101,12 +101,7 @@ def build_path( form_param = (param_name, hint) continue - segment = _param_to_segment(param_name, hint, func.__name__) - segments.append(segment) - if hint in VALIDATED_TYPES: - validated_params.append((param_name, hint)) - elif get_origin(hint) is Literal: - literal_params[param_name] = str(get_args(hint)[0]) + _classify_url_param(param_name, hint, func.__name__, segments, validated_params, literal_params) path = "/" + "/".join(segments) return path, validated_params, literal_params, form_param, public @@ -165,18 +160,12 @@ def build_api_path( inner, is_optional = _unwrap_optional(hint) if is_optional: - segment = _param_to_segment(param_name, inner, func.__name__) - segments.append(segment) + segments.append(_param_to_segment(param_name, inner, func.__name__)) optional_params.append(param_name) # Note - this means that safe types which are optional will not get validated - no current use case for this continue - segment = _param_to_segment(param_name, hint, func.__name__) - segments.append(segment) - if hint in VALIDATED_TYPES: - validated_params.append((param_name, hint)) - elif get_origin(hint) is Literal: - literal_params[param_name] = str(get_args(hint)[0]) + _classify_url_param(param_name, hint, func.__name__, segments, validated_params, literal_params) path = "/" + "/".join(segments) return path, validated_params, literal_params, body_param, query_param, optional_params @@ -196,9 +185,9 @@ def safe_params_for_type(cls: type) -> list[tuple[str, type]]: return [(name, hint) for name, hint in hints.items() if hint in VALIDATED_TYPES] -async def run_validators(kwargs: dict[str, Any], validated_params: list[tuple[str, type]]) -> None: +async def validate_params(kwargs: dict[str, Any], known_params: list[tuple[str, type]]) -> None: """Validate URL parameters in order, using the type-specific validators.""" - for param_name, param_type in validated_params: + for param_name, param_type in known_params: raw = kwargs[param_name] if param_type is safe.ProjectName: try: @@ -210,6 +199,13 @@ async def run_validators(kwargs: dict[str, Any], validated_params: list[tuple[st kwargs[param_name] = safe.VersionName(raw) except ValueError: raise base.ASFQuartException(f"Version name {param_name!r} is invalid. ") + elif param_type is safe.RevisionNumber: + try: + kwargs[param_name] = safe.RevisionNumber(raw) + except ValueError: + raise base.ASFQuartException(f"Revision number {param_name!r} is invalid. ") + elif param_type is unsafe.UnsafeStr: + kwargs[param_name] = unsafe.UnsafeStr(raw) async def validate_safe_fields( @@ -227,12 +223,31 @@ async def validate_safe_fields( value = getattr(instance, name, None) if value is not None: temp[name] = str(value) - await run_validators(temp, [(n, t) for n, t in safe_params if n in temp]) + await validate_params(temp, [(n, t) for n, t in safe_params if n in temp]) for name, _ in safe_params: if name in temp: setattr(instance, name, temp[name]) +def _classify_url_param( + param_name: str, + hint: Any, + func_name: str, + segments: list[str], + validated_params: list[tuple[str, type]], + literal_params: dict[str, str], +) -> None: + """Build a URL segment for a parameter and classify it as validated or literal.""" + segment = _param_to_segment(param_name, hint, func_name) + segments.append(segment) + if hint in VALIDATED_TYPES: + validated_params.append((param_name, hint)) + elif get_origin(hint) is Literal: + literal_params[param_name] = str(get_args(hint)[0]) + elif hint is str: + raise TypeError(f"Parameter {param_name!r} in {func_name} is unguarded str") + + def _is_body_type(hint: Any) -> bool: """Check if a type hint is a pydantic BaseModel subclass (but not a Form).""" if not isinstance(hint, type): diff --git a/atr/blueprints/get.py b/atr/blueprints/get.py index ea48245a..05e597b5 100644 --- a/atr/blueprints/get.py +++ b/atr/blueprints/get.py @@ -65,7 +65,7 @@ def typed(func: Callable[..., Any]) -> web.RouteFunction[Any]: async def wrapper(*_args: Any, **kwargs: Any) -> Any: enhanced_session = await common.authenticate_public() if public else await common.authenticate() - await common.run_validators(kwargs, validated_params) + await common.validate_params(kwargs, validated_params) kwargs.update(literal_params) start_time_ns = time.perf_counter_ns() diff --git a/atr/blueprints/post.py b/atr/blueprints/post.py index 77fdd44a..6c585507 100644 --- a/atr/blueprints/post.py +++ b/atr/blueprints/post.py @@ -70,7 +70,7 @@ def typed(func: Callable[..., Any]) -> web.RouteFunction[Any]: async def wrapper(*_args: Any, **kwargs: Any) -> Any: enhanced_session = await common.authenticate_public() if public else await common.authenticate() - await common.run_validators(kwargs, validated_params) + await common.validate_params(kwargs, validated_params) kwargs.update(literal_params) if check_access and (enhanced_session is not None) and (project_name_var is not None): diff --git a/atr/construct.py b/atr/construct.py index f4ca7326..32a18770 100644 --- a/atr/construct.py +++ b/atr/construct.py @@ -56,7 +56,7 @@ class AnnounceReleaseOptions: fullname: str project_name: safe.ProjectName version_name: safe.VersionName - revision_number: str + revision_number: safe.RevisionNumber @dataclasses.dataclass @@ -65,7 +65,7 @@ class StartVoteOptions: fullname: str project_name: safe.ProjectName version_name: safe.VersionName - revision_number: str + revision_number: safe.RevisionNumber vote_duration: int @@ -102,7 +102,7 @@ async def announce_release_subject_and_body( raise RuntimeError(f"Release {options.project_name} {options.version_name} has no committee") committee = release.committee - revision = await data.revision(release_name=release.name, number=options.revision_number).get() + revision = await data.revision(release_name=release.name, number=str(options.revision_number)).get() revision_number = revision.number if revision else "" revision_tag = revision.tag if (revision and revision.tag) else "" @@ -206,7 +206,7 @@ async def start_vote_subject_and_body(subject: str, body: str, options: StartVot raise RuntimeError(f"Release {options.project_name} {options.version_name} has no committee") committee = release.committee - revision = await data.revision(release_name=release.name, number=options.revision_number).get() + revision = await data.revision(release_name=release.name, number=str(options.revision_number)).get() revision_number = revision.number if revision else "" revision_tag = revision.tag if (revision and revision.tag) else "" diff --git a/atr/db/interaction.py b/atr/db/interaction.py index e921f208..35f86a8b 100644 --- a/atr/db/interaction.py +++ b/atr/db/interaction.py @@ -164,13 +164,13 @@ async def candidates(project: sql.Project) -> list[sql.Release]: async def checks_for( release: sql.Release, - revision: str | None = None, + revision: safe.RevisionNumber | None = None, rel_path: str | None = None, caller_data: db.Session | None = None, ) -> list[sql.CheckResult]: """Get the check results for a release, optionally for a specific revision and/or file path.""" if revision is None: - revision = release.unwrap_revision_number + revision = release.safe_latest_revision_number file_path_checks = await attestable.load_checks(release.safe_project_name, release.safe_version_name, revision) if file_path_checks: if rel_path is not None: @@ -194,7 +194,10 @@ async def checks_for( async def count_checks_for_revision_by_status( - status: sql.CheckResultStatus, release: sql.Release, revision_number: str, caller_data: db.Session | None = None + status: sql.CheckResultStatus, + release: sql.Release, + revision_number: safe.RevisionNumber, + caller_data: db.Session | None = None, ): file_path_checks = await attestable.load_checks( release.safe_project_name, release.safe_version_name, revision_number @@ -228,14 +231,18 @@ async def full_releases(project: sql.Project) -> list[sql.Release]: return await releases_by_phase(project, sql.ReleasePhase.RELEASE) -async def has_blocker_checks(release: sql.Release, revision_number: str, caller_data: db.Session | None = None) -> bool: +async def has_blocker_checks( + release: sql.Release, revision_number: safe.RevisionNumber, caller_data: db.Session | None = None +) -> bool: count = await count_checks_for_revision_by_status( sql.CheckResultStatus.BLOCKER, release, revision_number, caller_data ) return count > 0 -async def has_failing_checks(release: sql.Release, revision_number: str, caller_data: db.Session | None = None) -> bool: +async def has_failing_checks( + release: sql.Release, revision_number: safe.RevisionNumber, caller_data: db.Session | None = None +) -> bool: count = await count_checks_for_revision_by_status( sql.CheckResultStatus.FAILURE, release, revision_number, caller_data ) @@ -244,7 +251,7 @@ async def has_failing_checks(release: sql.Release, revision_number: str, caller_ async def latest_info( project_name: safe.ProjectName, version_name: safe.VersionName -) -> tuple[str, str, datetime.datetime] | None: +) -> tuple[safe.RevisionNumber, str, datetime.datetime] | None: """Get the name, editor, and timestamp of the latest revision.""" release_name = sql.release_name(project_name, version_name) async with db.session() as data: @@ -258,7 +265,7 @@ async def latest_info( revision = await data.revision(release_name=str(release_name), number=release.latest_revision_number).get() if not revision: return None - return revision.number, revision.asfuid, revision.created + return revision.safe_number, revision.asfuid, revision.created async def latest_revision(release: sql.Release, caller_data: db.Session | None = None) -> sql.Revision | None: @@ -298,7 +305,7 @@ async def release_ready_for_vote( # noqa: C901 session: web.Committer, project_name: safe.ProjectName, version_name: safe.VersionName, - revision: str, + revision: safe.RevisionNumber, data: db.Session, manual_vote: bool = False, ) -> tuple[sql.Release, sql.Committee] | str: @@ -315,7 +322,7 @@ async def release_ready_for_vote( # noqa: C901 if selected_revision_number is None: return "No revision found for this release" - if selected_revision_number != revision: + if release.safe_latest_revision_number != revision: return "This revision does not match the revision you are voting on" committee = release.committee @@ -397,7 +404,7 @@ def task_recipient_get(latest_vote_task: sql.Task) -> str | None: async def tasks_ongoing( - project_name: safe.ProjectName, version_name: safe.VersionName, revision_number: str | None = None + project_name: safe.ProjectName, version_name: safe.VersionName, revision_number: safe.RevisionNumber | None = None ) -> int: tasks = sqlmodel.select(sqlalchemy.func.count()).select_from(sql.Task) async with db.session() as data: @@ -405,7 +412,7 @@ async def tasks_ongoing( sql.Task.project_name == str(project_name), sql.Task.version_name == str(version_name), sql.Task.revision_number - == (sql.RELEASE_LATEST_REVISION_NUMBER if (revision_number is None) else revision_number), + == (sql.RELEASE_LATEST_REVISION_NUMBER if (revision_number is None) else str(revision_number)), sql.validate_instrumented_attribute(sql.Task.status).in_([sql.TaskStatus.QUEUED, sql.TaskStatus.ACTIVE]), ) result = await data.execute(query) @@ -415,7 +422,7 @@ async def tasks_ongoing( async def tasks_ongoing_revision( project_name: safe.ProjectName, version_name: safe.VersionName, - revision_number: str | None = None, + revision_number: safe.RevisionNumber | None = None, ) -> tuple[int, str | None]: via = sql.validate_instrumented_attribute subquery = ( @@ -438,7 +445,7 @@ async def tasks_ongoing_revision( .where( sql.Task.project_name == str(project_name), sql.Task.version_name == str(version_name), - sql.Task.revision_number == (subquery if (revision_number is None) else revision_number), + sql.Task.revision_number == (subquery if (revision_number is None) else str(revision_number)), sql.validate_instrumented_attribute(sql.Task.status).in_( [sql.TaskStatus.QUEUED, sql.TaskStatus.ACTIVE], ), diff --git a/atr/get/announce.py b/atr/get/announce.py index b33d4069..5054f68d 100644 --- a/atr/get/announce.py +++ b/atr/get/announce.py @@ -69,7 +69,7 @@ async def selected( fullname=session.fullname, project_name=project_name, version_name=version_name, - revision_number=latest_revision_number, + revision_number=release.safe_latest_revision_number, ) default_subject, default_body = await construct.announce_release_subject_and_body( default_subject_template, default_body_template, options diff --git a/atr/get/checks.py b/atr/get/checks.py index 03e26b91..52ef3a6d 100644 --- a/atr/get/checks.py +++ b/atr/get/checks.py @@ -146,7 +146,7 @@ async def selected_revision( _checks: Literal["checks"], project_name: safe.ProjectName, version_name: safe.VersionName, - revision_number: str, + revision_number: safe.RevisionNumber, ) -> web.QuartResponse: """ URL: /checks/<project_name>/<version_name>/<revision_number> diff --git a/atr/get/committees.py b/atr/get/committees.py index f565d8b9..1ab70af9 100644 --- a/atr/get/committees.py +++ b/atr/get/committees.py @@ -24,6 +24,7 @@ import atr.blueprints.get as get import atr.db as db import atr.form as form import atr.models.sql as sql +import atr.models.unsafe as unsafe import atr.post as post import atr.shared as shared import atr.template as template @@ -48,17 +49,17 @@ async def directory(_session: web.Public, _committees: Literal["committees"]) -> @get.typed -async def view(_session: web.Public, _committees: Literal["committees"], name: str) -> str: +async def view(_session: web.Public, _committees: Literal["committees"], name: unsafe.UnsafeStr) -> str: """ URL: /committees/<name> """ # TODO: Could also import this from keys.py async with db.session() as data: committee = await data.committee( - name=name, + name=str(name), _projects=True, _public_signing_keys=True, - ).demand(base.ASFQuartException(f"Committee {name} not found", errorcode=404)) + ).demand(base.ASFQuartException(f"Committee {name!s} not found", errorcode=404)) project_list = list(committee.projects) for project in project_list: # Workaround for the usual loading problem @@ -74,8 +75,8 @@ async def view(_session: web.Public, _committees: Literal["committees"], name: s model_cls=shared.keys.UpdateCommitteeKeysForm, action=util.as_url(post.keys.keys), submit_label="Regenerate KEYS file", - defaults={"committee_name": name}, + defaults={"committee_name": str(name)}, empty=True, ), - is_standing=util.committee_is_standing(name), + is_standing=util.committee_is_standing(str(name)), ) diff --git a/atr/get/keys.py b/atr/get/keys.py index 18223ff8..d0fba17f 100644 --- a/atr/get/keys.py +++ b/atr/get/keys.py @@ -27,6 +27,7 @@ import atr.db as db import atr.form as form import atr.htm as htm import atr.models.sql as sql +import atr.models.unsafe as unsafe import atr.post as post import atr.shared as shared import atr.storage as storage @@ -72,14 +73,14 @@ async def add(_session: web.Committer, _keys_add: Literal["keys/add"]) -> str: @get.typed -async def details(session: web.Committer, _keys_details: Literal["keys/details"], fingerprint: str) -> str: +async def details(session: web.Committer, _keys_details: Literal["keys/details"], fingerprint: unsafe.UnsafeStr) -> str: """ URL: /keys/details/<fingerprint> Display details for a specific OpenPGP key. """ - fingerprint = fingerprint.lower() + key_fingerprint = str(fingerprint).lower() async with db.session() as data: - key, is_owner = await _key_and_is_owner(data, session, fingerprint) + key, is_owner = await _key_and_is_owner(data, session, key_fingerprint) user_committees = [] if is_owner: project_list = session.committees + session.projects @@ -154,7 +155,7 @@ async def details(session: web.Committer, _keys_details: Literal["keys/details"] checkboxes = _render_committee_checkboxes(committee_choices, current_committee_names) pmc_div.form( method="post", - action=util.as_url(post.keys.details, fingerprint=fingerprint), + action=util.as_url(post.keys.details, fingerprint=key_fingerprint), )[ form.csrf_input(), checkboxes, @@ -182,7 +183,7 @@ async def details(session: web.Committer, _keys_details: Literal["keys/details"] @get.typed async def export( - _session: web.Committer, _keys_export: Literal["keys/export"], committee_name: str + _session: web.Committer, _keys_export: Literal["keys/export"], committee_name: unsafe.UnsafeStr ) -> web.TextResponse: """ URL: /keys/export/<committee_name> @@ -190,7 +191,7 @@ async def export( """ async with storage.write() as write: wafc = write.as_foundation_committer() - keys_file_text = await wafc.keys.keys_file_text(committee_name) + keys_file_text = await wafc.keys.keys_file_text(str(committee_name)) return web.TextResponse(keys_file_text) diff --git a/atr/get/manual.py b/atr/get/manual.py index 7d90f0e9..e9ba1162 100644 --- a/atr/get/manual.py +++ b/atr/get/manual.py @@ -71,7 +71,7 @@ async def start_selected_revision( _manual_start: Literal["manual/start"], project_name: safe.ProjectName, version_name: safe.VersionName, - revision: str, + revision: safe.RevisionNumber, ) -> web.WerkzeugResponse | str: """ URL: /manual/start/<project_name>/<version_name>/<revision> @@ -87,12 +87,12 @@ async def start_selected_revision( error=error, project_name=str(project_name), version_name=str(version_name), - revision=revision, + revision=str(revision), ) case (release, _committee): pass - content = await _render_page(release=release, revision=revision) + content = await _render_page(release=release, revision=str(revision)) return await template.blank( title=f"Start manual vote on {release.project.short_display_name} {release.version}", content=content diff --git a/atr/get/projects.py b/atr/get/projects.py index ec992b13..d57622cf 100644 --- a/atr/get/projects.py +++ b/atr/get/projects.py @@ -35,6 +35,7 @@ import atr.get.start as start import atr.htm as htm import atr.models.safe as safe import atr.models.sql as sql +import atr.models.unsafe as unsafe import atr.post as post import atr.registry as registry import atr.shared as shared @@ -46,33 +47,33 @@ import atr.web as web @get.typed async def add_project( - session: web.Committer, _project_add: Literal["project/add"], committee_name: str + session: web.Committer, _project_add: Literal["project/add"], committee_name: unsafe.UnsafeStr ) -> web.WerkzeugResponse | str: """ URL: /project/add/<committee_name> """ - await session.check_access_committee(committee_name) + await session.check_access_committee(str(committee_name)) async with db.session() as data: - committee = await data.committee(name=committee_name).demand( - base.ASFQuartException(f"Committee {committee_name} not found", errorcode=404) + committee = await data.committee(name=str(committee_name)).demand( + base.ASFQuartException(f"Committee {committee_name!s} not found", errorcode=404) ) page = htm.Block() - page.p[htm.a(".atr-back-link", href=util.as_url(committees.view, name=committee_name))["← Back to committee"]] + page.p[htm.a(".atr-back-link", href=util.as_url(committees.view, name=str(committee_name)))["← Back to committee"]] page.h1["Add project"] page.p[f"Add a new project to the {committee.display_name} committee."] - committee_display_name = committee.full_name or committee_name.title() + committee_display_name = committee.full_name or str(committee_name).title() form.render_block( page, model_cls=shared.projects.AddProjectForm, - action=util.as_url(post.projects.add_project, committee_name=committee_name), + action=util.as_url(post.projects.add_project, committee_name=str(committee_name)), submit_label="Add project", - cancel_url=util.as_url(committees.view, name=committee_name), + cancel_url=util.as_url(committees.view, name=str(committee_name)), defaults={ - "committee_name": committee_name, + "committee_name": str(committee_name), }, ) @@ -80,7 +81,7 @@ async def add_project( page.append( htpy.div( "#projects-add-config.d-none", - data_committee_name=committee_name, + data_committee_name=str(committee_name), data_committee_display_name=committee_display_name, ) ) diff --git a/atr/get/ref.py b/atr/get/ref.py index 9a665ac5..84b18a76 100644 --- a/atr/get/ref.py +++ b/atr/get/ref.py @@ -41,9 +41,10 @@ async def resolve(_session: web.Public, _ref: Literal["ref"], ref_path: unsafe.P Resolve a code reference to a GitHub permalink. """ project_root = pathlib.Path(config.get().PROJECT_ROOT) + path = str(ref_path) - if ":" in ref_path: - file_path_str, symbol = ref_path.rsplit(":", 1) + if ":" in path: + file_path_str, symbol = path.rsplit(":", 1) resolved_file, validated_path_str = _validate_and_resolve_path(file_path_str, project_root) if (not await aiofiles.os.path.exists(resolved_file)) or (not await aiofiles.os.path.isfile(resolved_file)): @@ -57,8 +58,8 @@ async def resolve(_session: web.Public, _ref: Literal["ref"], ref_path: unsafe.P github_url = f"https://github.com/apache/tooling-trusted-releases/blob/main/{validated_path_str}#L{line_number}" return quart.redirect(github_url, code=303) - is_directory = ref_path.endswith("/") - path_str = ref_path.rstrip("/") + is_directory = path.endswith("/") + path_str = path.rstrip("/") resolved_path, validated_path_str = _validate_and_resolve_path(path_str, project_root) if not await aiofiles.os.path.exists(resolved_path): diff --git a/atr/get/test.py b/atr/get/test.py index 7d36472f..a8075c23 100644 --- a/atr/get/test.py +++ b/atr/get/test.py @@ -33,6 +33,7 @@ import atr.htm as htm import atr.models.safe as safe import atr.models.session import atr.models.sql as sql +import atr.models.unsafe as unsafe import atr.paths as paths import atr.shared as shared import atr.storage as storage @@ -204,7 +205,7 @@ async def test_single(session: web.Public, _test_single: Literal["test/single"]) async def test_vote( session: web.Public, _test_vote: Literal["test/vote"], - category: str, + category: unsafe.UnsafeStr, project_name: safe.ProjectName, version_name: safe.VersionName, ) -> str: @@ -222,10 +223,10 @@ async def test_vote( "pmc_member_rm": vote.UserCategory.PMC_MEMBER_RM, } - user_category = category_map.get(category.lower()) + user_category = category_map.get(str(category).lower()) if user_category is None: raise base.ASFQuartException( - f"Invalid category: {category}. Valid options: {', '.join(category_map.keys())}", + f"Invalid category: {category!s}. Valid options: {', '.join(category_map.keys())}", errorcode=400, ) diff --git a/atr/get/voting.py b/atr/get/voting.py index a92c496d..90db36ff 100644 --- a/atr/get/voting.py +++ b/atr/get/voting.py @@ -47,7 +47,7 @@ async def selected_revision( _voting: Literal["voting"], project_name: safe.ProjectName, version_name: safe.VersionName, - revision: str, + revision: safe.RevisionNumber, ) -> web.WerkzeugResponse | str: """ URL: /voting/<project_name>/<version_name>/<revision> @@ -63,7 +63,7 @@ async def selected_revision( error=error, project_name=str(project_name), version_name=str(version_name), - revision=revision, + revision=str(revision), ) case (release, committee): pass @@ -94,7 +94,7 @@ async def selected_revision( content = await _render_page( release=release, - revision_number=revision, + revision_number=str(revision), permitted_recipients=permitted_recipients, default_subject=default_subject, subject_template_hash=subject_template_hash, diff --git a/atr/merge.py b/atr/merge.py index 81ddc81a..84e7a050 100644 --- a/atr/merge.py +++ b/atr/merge.py @@ -38,7 +38,7 @@ async def merge( prior_dir: pathlib.Path, project_name: safe.ProjectName, version_name: safe.VersionName, - prior_revision_number: str, + prior_revision_number: safe.RevisionNumber, temp_dir: pathlib.Path, n_inodes: dict[str, int], n_hashes: dict[str, str], @@ -122,7 +122,7 @@ async def _add_from_prior( prior_hashes: dict[str, str] | None, project_name: safe.ProjectName, version_name: safe.VersionName, - prior_revision_number: str, + prior_revision_number: safe.RevisionNumber, ) -> dict[str, str] | None: target = temp_dir / path await asyncio.to_thread(_makedirs_with_permissions, target.parent, temp_dir) @@ -179,7 +179,7 @@ async def _merge_all_present( prior_hashes: dict[str, str] | None, project_name: safe.ProjectName, version_name: safe.VersionName, - prior_revision_number: str, + prior_revision_number: safe.RevisionNumber, ) -> dict[str, str] | None: # Cases 6, 8: prior and new share an inode so they already agree if p_ino == n_ino: @@ -240,7 +240,7 @@ async def _replace_with_prior( prior_hashes: dict[str, str] | None, project_name: safe.ProjectName, version_name: safe.VersionName, - prior_revision_number: str, + prior_revision_number: safe.RevisionNumber, ) -> dict[str, str] | None: await aiofiles.os.remove(temp_dir / path) await aiofiles.os.link(prior_dir / path, temp_dir / path) diff --git a/atr/models/api.py b/atr/models/api.py index 56f2e737..342773b2 100644 --- a/atr/models/api.py +++ b/atr/models/api.py @@ -166,7 +166,7 @@ class DistributionRecordResults(schema.Strict): class IgnoreAddArgs(schema.Strict): project_name: safe.ProjectName = schema.example("example") release_glob: str | None = schema.default_example(None, "example-0.0.*") - revision_number: str | None = schema.default_example(None, "00001") + revision_number: safe.RevisionNumber | None = schema.default_example(None, "00001") checker_glob: str | None = schema.default_example(None, "atr.tasks.checks.license.files") primary_rel_path_glob: str | None = schema.default_example(None, "apache-example-0.0.1-*.tar.gz") member_rel_path_glob: str | None = schema.default_example(None, "apache-example-0.0.1/*.xml") @@ -357,7 +357,7 @@ class PublisherReleaseAnnounceArgs(schema.Strict): publisher: str = schema.example("user") jwt: str = schema.example("eyJhbGciOiJIUzI1[...]mMjLiuyu5CSpyHI=") version: safe.VersionName = schema.example("0.0.1") - revision: str = schema.example("00005") + revision: safe.RevisionNumber = schema.example("00005") email_to: str = schema.example("[email protected]") body: str = schema.example("The Apache Example team is pleased to announce the release of Example 1.0.0...") path_suffix: str = schema.example("example/1.0.0") @@ -396,7 +396,7 @@ class PublisherVoteResolveResults(schema.Strict): class ReleaseAnnounceArgs(schema.Strict): project: safe.ProjectName = schema.example("example") version: safe.VersionName = schema.example("1.0.0") - revision: str = schema.example("00005") + revision: safe.RevisionNumber = schema.example("00005") email_to: str = schema.example("[email protected]") body: str = schema.example("The Apache Example team is pleased to announce the release of Example 1.0.0...") path_suffix: str = schema.example("example/1.0.0") @@ -581,7 +581,7 @@ class VoteResolveResults(schema.Strict): class VoteStartArgs(schema.Strict): project: safe.ProjectName = schema.example("example") version: safe.VersionName = schema.example("0.0.1") - revision: str = schema.example("00005") + revision: safe.RevisionNumber = schema.example("00005") email_to: str = schema.example("[email protected]") vote_duration: int = schema.example(10) subject: str = schema.example("[VOTE] Apache Example 0.0.1 release") diff --git a/atr/models/safe.py b/atr/models/safe.py index 083a8b2b..9ec73067 100644 --- a/atr/models/safe.py +++ b/atr/models/safe.py @@ -22,16 +22,17 @@ import unicodedata from typing import Any, Final _ALPHANUM: Final = frozenset(string.ascii_letters + string.digits + "-") +_NUMERIC: Final = frozenset(string.digits) _VERSION_CHARS: Final = _ALPHANUM | frozenset(".+") -class Alphanumeric: +class SafeType: __slots__ = ("_value",) @classmethod def _valid_chars(cls) -> frozenset[str]: # default is the base set; subclasses can override this method - return _ALPHANUM + return frozenset() def _additional_validations(self, value: str): pass @@ -53,7 +54,7 @@ class Alphanumeric: return True def __eq__(self, other: object) -> bool: - if isinstance(other, Alphanumeric): + if isinstance(other, self.__class__): return self._value == other._value return NotImplemented @@ -76,6 +77,20 @@ class Alphanumeric: ) +class Alphanumeric(SafeType): + @classmethod + def _valid_chars(cls) -> frozenset[str]: + # default is the base set; subclasses can override this method + return _ALPHANUM + + +class Numeric(SafeType): + @classmethod + def _valid_chars(cls) -> frozenset[str]: + # default is the base set; subclasses can override this method + return _NUMERIC + + class ProjectName(Alphanumeric): """A project name that has been validated for safety.""" @@ -88,6 +103,10 @@ class ReleaseName(Alphanumeric): return _VERSION_CHARS +class RevisionNumber(Numeric): + """A revision number that has been validated for safety.""" + + class VersionName(Alphanumeric): """A version name that has been validated for safety""" diff --git a/atr/models/sql.py b/atr/models/sql.py index 3b288602..2ba21727 100644 --- a/atr/models/sql.py +++ b/atr/models/sql.py @@ -923,6 +923,11 @@ class Release(sqlmodel.SQLModel, table=True): # return None return project.committee + @property + def safe_latest_revision_number(self) -> safe.RevisionNumber: + """Get the typesafe validated name for the Release""" + return safe.RevisionNumber(self.unwrap_revision_number) + @property def safe_name(self) -> safe.ReleaseName: """Get the typesafe validated name for the Release""" @@ -1317,6 +1322,11 @@ class Revision(sqlmodel.SQLModel, table=True): tag: str | None = sqlmodel.Field(default=None, **example("rc1")) was_quarantined: bool = sqlmodel.Field(default=False, **example(False)) + @property + def safe_number(self) -> safe.RevisionNumber: + """Get the typesafe validated number for the revision""" + return safe.RevisionNumber(self.number) + def model_post_init(self, _context): if isinstance(self.created, str): self.created = datetime.datetime.fromisoformat(self.created.rstrip("Z")) diff --git a/atr/models/unsafe.py b/atr/models/unsafe.py index 6e38d4e4..d76127e7 100644 --- a/atr/models/unsafe.py +++ b/atr/models/unsafe.py @@ -28,5 +28,8 @@ class UnsafeStr: def __repr__(self) -> str: return f"UnsafeStr({self._value!r})" + def __str__(self) -> str: + return self._value -Path = NewType("Path", str) + +Path = NewType("Path", UnsafeStr) diff --git a/atr/paths.py b/atr/paths.py index 2aae97ae..7931db1d 100644 --- a/atr/paths.py +++ b/atr/paths.py @@ -23,9 +23,9 @@ import atr.models.sql as sql def base_path_for_revision( - project_name: safe.ProjectName, version_name: safe.VersionName, revision: str + project_name: safe.ProjectName, version_name: safe.VersionName, revision: safe.RevisionNumber ) -> pathlib.Path: - return pathlib.Path(get_unfinished_dir(), str(project_name), str(version_name), revision) + return pathlib.Path(get_unfinished_dir(), str(project_name), str(version_name), str(revision)) def get_attestable_dir() -> pathlib.Path: @@ -135,6 +135,6 @@ def release_directory_version(release: sql.Release) -> pathlib.Path: def revision_path_for_file( - project_name: safe.ProjectName, version_name: safe.VersionName, revision: str, file_name: str + project_name: safe.ProjectName, version_name: safe.VersionName, revision: safe.RevisionNumber, file_name: str ) -> pathlib.Path: return base_path_for_revision(project_name, version_name, revision) / file_name diff --git a/atr/post/announce.py b/atr/post/announce.py index f988b40a..23ea535f 100644 --- a/atr/post/announce.py +++ b/atr/post/announce.py @@ -61,14 +61,14 @@ async def selected( with_release_policy=True, with_project_release_policy=True, ) - preview_revision_number = release.unwrap_revision_number + preview_revision_number = release.safe_latest_revision_number # Validate that the revision number matches - if announce_form.revision_number != preview_revision_number: + if announce_form.revision_number != str(preview_revision_number): return await session.redirect( get.announce.selected, error=f"The release has been updated since you loaded the form. " - f"Please review the current revision ({preview_revision_number}) and submit the form again.", + f"Please review the current revision ({preview_revision_number!s}) and submit the form again.", project_name=str(project_name), version_name=str(version_name), ) diff --git a/atr/post/draft.py b/atr/post/draft.py index 675d4ef8..dcfa6add 100644 --- a/atr/post/draft.py +++ b/atr/post/draft.py @@ -257,7 +257,8 @@ async def sbomgen( URL: /draft/sbomgen/<project_name>/<version_name>/<file_path> Generate a CycloneDX SBOM file for a candidate draft file, creating a new revision. """ - rel_path = form.to_relpath(file_path) + path = str(file_path) + rel_path = form.to_relpath(path) if rel_path is None: await quart.flash("Invalid file path", "error") return await session.redirect( @@ -265,14 +266,9 @@ async def sbomgen( ) # Check that the file is a .tar.gz archive before creating a revision - if not ( - file_path.endswith(".tar.gz") - or file_path.endswith(".tgz") - or file_path.endswith(".zip") - or file_path.endswith(".jar") - ): + if not (path.endswith(".tar.gz") or path.endswith(".tgz") or path.endswith(".zip") or path.endswith(".jar")): raise base.ASFQuartException( - f"SBOM generation requires .tar.gz, .tgz, .zip or .jar files. Received: {file_path}", errorcode=400 + f"SBOM generation requires .tar.gz, .tgz, .zip or .jar files. Received: {path}", errorcode=400 ) try: diff --git a/atr/post/keys.py b/atr/post/keys.py index 9042211a..df295ea1 100644 --- a/atr/post/keys.py +++ b/atr/post/keys.py @@ -31,6 +31,7 @@ import atr.htm as htm import atr.log as log import atr.models.safe as safe import atr.models.sql as sql +import atr.models.unsafe as unsafe import atr.shared as shared import atr.storage as storage import atr.storage.outcome as outcome @@ -93,18 +94,18 @@ async def add( async def details( session: web.Committer, _keys_details: Literal["keys/details"], - fingerprint: str, + fingerprint: unsafe.UnsafeStr, update_form: shared.keys.UpdateKeyCommitteesForm, ) -> web.WerkzeugResponse: """ URL: /keys/details/<fingerprint> Update committee associations for an OpenPGP key. """ - fingerprint = fingerprint.lower() + key_fingerprint = str(fingerprint).lower() try: async with db.session() as data: - key = await data.public_signing_key(fingerprint=fingerprint, _committees=True).get() + key = await data.public_signing_key(fingerprint=key_fingerprint, _committees=True).get() if not key: await quart.flash("OpenPGP key not found", "error") return await session.redirect(get.keys.keys) @@ -135,7 +136,7 @@ async def details( log.exception("Error updating key committee associations:") await quart.flash(f"An unexpected error occurred: {e!s}", "error") - return await session.redirect(get.keys.details, fingerprint=fingerprint) + return await session.redirect(get.keys.details, fingerprint=key_fingerprint) @post.typed diff --git a/atr/post/manual.py b/atr/post/manual.py index 4fd8a06e..591b323d 100644 --- a/atr/post/manual.py +++ b/atr/post/manual.py @@ -87,7 +87,7 @@ async def start_selected_revision( _manual_start: Literal["manual/start"], project_name: safe.ProjectName, version_name: safe.VersionName, - revision: str, + revision: safe.RevisionNumber, _form: form.Empty, ) -> web.WerkzeugResponse | str: """ diff --git a/atr/post/projects.py b/atr/post/projects.py index 68f9d878..6763bf48 100644 --- a/atr/post/projects.py +++ b/atr/post/projects.py @@ -27,6 +27,7 @@ import atr.db as db import atr.get as get import atr.models.safe as safe import atr.models.sql as sql +import atr.models.unsafe as unsafe import atr.shared as shared import atr.storage as storage import atr.web as web @@ -36,7 +37,7 @@ import atr.web as web async def add_project( session: web.Committer, _project_add: Literal["project/add"], - committee_name: str, + committee_name: unsafe.UnsafeStr, project_form: shared.projects.AddProjectForm, ) -> web.WerkzeugResponse: """ @@ -47,12 +48,12 @@ async def add_project( # TODO: Is this right? Name is unvalidated async with storage.write(session) as write: - wacm = await write.as_project_committee_member(safe.ProjectName(committee_name)) + wacm = await write.as_project_committee_member(safe.ProjectName(str(committee_name))) try: - await wacm.project.create(committee_name, display_name, label) + await wacm.project.create(str(committee_name), display_name, label) except storage.AccessError as e: return await session.redirect( - get.projects.add_project, committee_name=committee_name, error=f"Error adding project: {e}" + get.projects.add_project, committee_name=str(committee_name), error=f"Error adding project: {e}" ) return await session.redirect( @@ -88,7 +89,7 @@ async def delete( async def view( session: web.Committer, _projects: Literal["projects"], - name: str, + name: unsafe.UnsafeStr, project_form: shared.projects.ProjectViewForm, ) -> web.WerkzeugResponse: """ diff --git a/atr/post/revisions.py b/atr/post/revisions.py index a89a22ac..1acd8d9c 100644 --- a/atr/post/revisions.py +++ b/atr/post/revisions.py @@ -60,7 +60,7 @@ async def _set_revision( if release.phase not in {sql.ReleasePhase.RELEASE_CANDIDATE_DRAFT, sql.ReleasePhase.RELEASE_PREVIEW}: raise base.ASFQuartException("Cannot set revision for non-draft or preview release", errorcode=400) - selected_revision = await data.revision(release_name=release.name, number=selected_revision_number).demand( + selected_revision = await data.revision(release_name=release.name, number=str(selected_revision_number)).demand( base.ASFQuartException(f"Revision {selected_revision_number} not found", errorcode=404) ) if (release.phase == sql.ReleasePhase.RELEASE_PREVIEW) and ( diff --git a/atr/post/upload.py b/atr/post/upload.py index 54b1a056..930eda14 100644 --- a/atr/post/upload.py +++ b/atr/post/upload.py @@ -33,6 +33,7 @@ import atr.get as get import atr.log as log import atr.models.safe as safe import atr.models.sql as sql +import atr.models.unsafe as unsafe import atr.paths as paths import atr.shared as shared import atr.storage as storage @@ -49,14 +50,14 @@ async def finalise( _upload_finalise: Literal["upload/finalise"], project_name: safe.ProjectName, version_name: safe.VersionName, - upload_session: str, + upload_session: unsafe.UnsafeStr, ) -> web.WerkzeugResponse: """ URL: /upload/finalise/<project_name>/<version_name>/<upload_session> """ try: - staging_dir = paths.get_upload_staging_dir(upload_session) + staging_dir = paths.get_upload_staging_dir(str(upload_session)) except ValueError: return _json_error("Invalid session token", 400) @@ -143,14 +144,14 @@ async def stage( _upload_stage: Literal["upload/stage"], _project_name: safe.ProjectName, _version_name: safe.VersionName, - upload_session: str, + upload_session: unsafe.UnsafeStr, ) -> web.WerkzeugResponse: """ URL: /upload/stage/<project_name>/<version_name>/<upload_session> """ try: - staging_dir = paths.get_upload_staging_dir(upload_session) + staging_dir = paths.get_upload_staging_dir(str(upload_session)) except ValueError: return _json_error("Invalid session token", 400) diff --git a/atr/post/voting.py b/atr/post/voting.py index bed6534e..cac23325 100644 --- a/atr/post/voting.py +++ b/atr/post/voting.py @@ -43,7 +43,7 @@ async def body_preview( _voting_body_preview: Literal["voting/body/preview"], project_name: safe.ProjectName, version_name: safe.VersionName, - revision_number: str, + revision_number: safe.RevisionNumber, preview_form: BodyPreviewForm, ) -> web.QuartResponse: """ @@ -72,7 +72,7 @@ async def selected_revision( _voting: Literal["voting"], project_name: safe.ProjectName, version_name: safe.VersionName, - revision: str, + revision: safe.RevisionNumber, start_voting_form: shared.voting.StartVotingForm, ) -> web.WerkzeugResponse | str: """ @@ -89,7 +89,7 @@ async def selected_revision( error=error, project_name=str(project_name), version_name=str(version_name), - revision=revision, + revision=str(revision), ) case (release, committee): pass diff --git a/atr/shared/ignores.py b/atr/shared/ignores.py index 2b41f66f..56cb0b84 100644 --- a/atr/shared/ignores.py +++ b/atr/shared/ignores.py @@ -23,6 +23,7 @@ from typing import Annotated, Literal import pydantic import atr.form as form +import atr.models.safe as safe import atr.models.sql as sql import atr.models.validation as validation @@ -43,7 +44,7 @@ class IgnoreStatus(enum.Enum): class AddIgnoreForm(form.Form): variant: ADD = form.value(ADD) release_glob: str = form.label("Release pattern", default="") - revision_number: str = form.label("Revision number (literal)", default="") + revision_number: safe.RevisionNumber = form.label("Revision number (literal)", default="") checker_glob: str = form.label("Checker pattern", default="") primary_rel_path_glob: str = form.label("Primary rel path pattern", default="") member_rel_path_glob: str = form.label("Member rel path pattern", default="") diff --git a/atr/shared/revisions.py b/atr/shared/revisions.py index c8b6dae9..f981cf2c 100644 --- a/atr/shared/revisions.py +++ b/atr/shared/revisions.py @@ -21,6 +21,7 @@ from typing import Annotated, Literal import pydantic import atr.form as form +import atr.models.safe as safe SET_REVISION = Literal["set_revision"] SET_TAG = Literal["set_tag"] @@ -28,7 +29,7 @@ SET_TAG = Literal["set_tag"] class SetRevisionForm(form.Form): variant: SET_REVISION = form.value(SET_REVISION) - revision_number: str = form.label("Revision number", widget=form.Widget.HIDDEN) + revision_number: safe.RevisionNumber = form.label("Revision number", widget=form.Widget.HIDDEN) class SetTagForm(form.Form): diff --git a/atr/ssh.py b/atr/ssh.py index be81b1c9..7e2f53a5 100644 --- a/atr/ssh.py +++ b/atr/ssh.py @@ -616,7 +616,9 @@ async def _step_07b_process_validated_rsync_write( else: github_payload = server._get_github_payload(process) if github_payload is not None: - await attestable.github_tp_payload_write(project_name, version_name, result.number, github_payload) + await attestable.github_tp_payload_write( + project_name, version_name, result.safe_number, github_payload + ) log.info(f"rsync upload successful for revision {result.number}") host = config.get().APP_HOST message = f"\nATR: Created revision {result.number} of {project_name} {version_name}\n" diff --git a/atr/storage/readers/releases.py b/atr/storage/readers/releases.py index cfbfca0e..502f1109 100644 --- a/atr/storage/readers/releases.py +++ b/atr/storage/readers/releases.py @@ -24,6 +24,7 @@ import pathlib import atr.classify as classify import atr.db as db import atr.db.interaction as interaction +import atr.models.safe as safe import atr.models.sql as sql import atr.paths as paths import atr.storage as storage @@ -60,7 +61,7 @@ class GeneralPublic: latest_revision_number = release.latest_revision_number if latest_revision_number is None: return None - await self.__successes_errors_warnings(release, latest_revision_number, info) + await self.__successes_errors_warnings(release, release.safe_latest_revision_number, info) base_path = paths.release_directory(release) source_matcher = None source_artifact_paths = release.project.policy_source_artifact_paths @@ -121,7 +122,7 @@ class GeneralPublic: ) async def __successes_errors_warnings( - self, release: sql.Release, latest_revision_number: str, info: types.PathInfo + self, release: sql.Release, latest_revision_number: safe.RevisionNumber, info: types.PathInfo ) -> None: match_ignore = await self.__read_as.checks.ignores_matcher(release.safe_project_name) attestable_checks = await interaction.checks_for( diff --git a/atr/storage/writers/announce.py b/atr/storage/writers/announce.py index a56a2997..69f10aa5 100644 --- a/atr/storage/writers/announce.py +++ b/atr/storage/writers/announce.py @@ -106,7 +106,7 @@ class CommitteeMember(CommitteeParticipant): self, project_name: safe.ProjectName, version_name: safe.VersionName, - preview_revision_number: str, + preview_revision_number: safe.RevisionNumber, recipient: str, body: str, download_path_suffix: str, @@ -124,14 +124,14 @@ class CommitteeMember(CommitteeParticipant): project_name=str(project_name), version=str(version_name), phase=sql.ReleasePhase.RELEASE_PREVIEW, - latest_revision_number=preview_revision_number, + latest_revision_number=str(preview_revision_number), _project_release_policy=True, _revisions=True, _distributions=True, _release_policy=True, ).demand( storage.AccessError( - f"Release {project_name} {version_name} {preview_revision_number} does not exist", + f"Release {project_name!s} {version_name!s} {preview_revision_number!s} does not exist", ) ) if (committee := release.project.committee) is None: @@ -199,7 +199,7 @@ class CommitteeMember(CommitteeParticipant): asf_uid=self.__asf_uid, project_name=str(project_name), version_name=str(version_name), - revision_number=preview_revision_number, + revision_number=str(preview_revision_number), source_directory=unfinished_dir, target_directory=finished_dir, email_recipient=recipient, @@ -282,7 +282,7 @@ class CommitteeMember(CommitteeParticipant): return predicted_finished_release async def __promote_in_database( - self, release: sql.Release, preview_revision_number: str, release_date: datetime.datetime + self, release: sql.Release, preview_revision_number: safe.RevisionNumber, release_date: datetime.datetime ) -> None: """Promote a release preview to a release and delete its old revisions.""" via = sql.validate_instrumented_attribute @@ -292,7 +292,7 @@ class CommitteeMember(CommitteeParticipant): .where( via(sql.Release.name) == release.name, via(sql.Release.phase) == sql.ReleasePhase.RELEASE_PREVIEW, - sql.latest_revision_number_query() == preview_revision_number, + sql.latest_revision_number_query() == str(preview_revision_number), ) .values( phase=sql.ReleasePhase.RELEASE, diff --git a/atr/storage/writers/checks.py b/atr/storage/writers/checks.py index 8140dc28..36094fdd 100644 --- a/atr/storage/writers/checks.py +++ b/atr/storage/writers/checks.py @@ -95,7 +95,7 @@ class CommitteeMember(CommitteeParticipant): self, project_name: safe.ProjectName, release_glob: str | None = None, - revision_number: str | None = None, + revision_number: safe.RevisionNumber | None = None, checker_glob: str | None = None, primary_rel_path_glob: str | None = None, member_rel_path_glob: str | None = None, @@ -115,7 +115,7 @@ class CommitteeMember(CommitteeParticipant): created=datetime.datetime.now(datetime.UTC), project_name=str(project_name), release_glob=release_glob, - revision_number=revision_number, + revision_number=str(revision_number), checker_glob=checker_glob, primary_rel_path_glob=primary_rel_path_glob, member_rel_path_glob=member_rel_path_glob, diff --git a/atr/storage/writers/release.py b/atr/storage/writers/release.py index e273def3..443dd8e0 100644 --- a/atr/storage/writers/release.py +++ b/atr/storage/writers/release.py @@ -321,7 +321,7 @@ class CommitteeParticipant(FoundationCommitter): async def promote_to_candidate( self, release_name: safe.ReleaseName, - selected_revision_number: str, + selected_revision_number: safe.RevisionNumber, vote_manual: bool = False, ) -> str | None: """Promote a release candidate draft to a new phase.""" @@ -330,6 +330,7 @@ class CommitteeParticipant(FoundationCommitter): ) project_name = release_for_pre_checks.safe_project_name version_name = release_for_pre_checks.safe_version_name + revision_number = release_for_pre_checks.safe_latest_revision_number # Check for ongoing tasks ongoing_tasks = await self.__tasks_ongoing(project_name, version_name, selected_revision_number) @@ -341,7 +342,7 @@ class CommitteeParticipant(FoundationCommitter): return "This release is not in the candidate draft phase" # Check that the revision number is the latest - if release_for_pre_checks.latest_revision_number != selected_revision_number: + if revision_number != selected_revision_number: return "The selected revision number does not match the latest revision number" # Check that there is at least one file in the draft @@ -356,7 +357,7 @@ class CommitteeParticipant(FoundationCommitter): .where( via(sql.Release.name) == release_for_pre_checks.name, via(sql.Release.phase) == sql.ReleasePhase.RELEASE_CANDIDATE_DRAFT, - sql.latest_revision_number_query() == selected_revision_number, + sql.latest_revision_number_query() == str(selected_revision_number), ) .values( phase=sql.ReleasePhase.RELEASE_CANDIDATE, @@ -377,7 +378,7 @@ class CommitteeParticipant(FoundationCommitter): self.__write_as.append_to_audit_log( asf_uid=self.__asf_uid, release_name=str(release_name), - selected_revision_number=selected_revision_number, + selected_revision_number=str(selected_revision_number), vote_manual=vote_manual, ) return None @@ -765,14 +766,17 @@ class CommitteeParticipant(FoundationCommitter): moved_files_names.append(f.name) async def __tasks_ongoing( - self, project_name: safe.ProjectName, version_name: safe.VersionName, revision_number: str | None = None + self, + project_name: safe.ProjectName, + version_name: safe.VersionName, + revision_number: safe.RevisionNumber | None = None, ) -> int: tasks = sqlmodel.select(sqlalchemy.func.count()).select_from(sql.Task) query = tasks.where( sql.Task.project_name == str(project_name), sql.Task.version_name == str(version_name), sql.Task.revision_number - == (sql.RELEASE_LATEST_REVISION_NUMBER if (revision_number is None) else revision_number), + == (sql.RELEASE_LATEST_REVISION_NUMBER if (revision_number is None) else str(revision_number)), sql.validate_instrumented_attribute(sql.Task.status).in_([sql.TaskStatus.QUEUED, sql.TaskStatus.ACTIVE]), ) result = await self.__data.execute(query) diff --git a/atr/storage/writers/revision.py b/atr/storage/writers/revision.py index 0f9394d6..d0f37af6 100644 --- a/atr/storage/writers/revision.py +++ b/atr/storage/writers/revision.py @@ -196,7 +196,7 @@ async def _commit_new_revision( await attestable.write_files_data( project_name, version_name, - new_revision.number, + new_revision.safe_number, policy.model_dump() if policy else None, asf_uid, previous_attestable, @@ -216,7 +216,7 @@ async def _commit_new_revision( # It does, however, need a transaction to be created using data.begin() if release.phase == sql.ReleasePhase.RELEASE_CANDIDATE_DRAFT: # Must use caller_data here because we acquired the write lock - await tasks.draft_checks(asf_uid, project_name, version_name, new_revision.number, caller_data=data) + await tasks.draft_checks(asf_uid, project_name, version_name, new_revision.safe_number, caller_data=data) return new_revision @@ -263,12 +263,14 @@ async def _lock_and_merge( if ( merge_enabled and (old_revision is not None) + and (latest is not None) and (prior_revision_name is not None) and (prior_revision_name != old_revision.name) ): merge_base_revision_name = prior_revision_name - prior_number = prior_revision_name.split()[-1] - prior_dir = paths.release_directory_base(merged_release) / prior_number + # This won't be None because prior_revision_name is not None here + prior_number = latest.safe_number + prior_dir = paths.release_directory_base(merged_release) / str(prior_number) await merge.merge( base_inodes, base_hashes, @@ -359,7 +361,7 @@ class CommitteeParticipant(FoundationCommitter): set_local_cache: bool = False, reset_to_global_cache: bool = False, modify: Callable[[pathlib.Path, sql.Revision | None], Awaitable[None]] | None = None, - clone_from: str | None = None, + clone_from: safe.RevisionNumber | None = None, ) -> sql.Revision | sql.Quarantined: """Create a new revision, quarantining archives that require validation.""" release_name = sql.release_name(str(project_name), str(version_name)) @@ -368,7 +370,7 @@ class CommitteeParticipant(FoundationCommitter): RuntimeError("Release does not exist for new revision creation") ) if clone_from is not None: - old_revision = await data.revision(release_name=release_name, number=clone_from).demand( + old_revision = await data.revision(release_name=release_name, number=str(clone_from)).demand( RuntimeError(f"Revision {clone_from} does not exist") ) else: @@ -379,7 +381,7 @@ class CommitteeParticipant(FoundationCommitter): release.check_cache_key = None if clone_from is not None: - old_release_dir = paths.release_directory_base(release) / clone_from + old_release_dir = paths.release_directory_base(release) / str(clone_from) else: old_release_dir = paths.release_directory(release) merge_enabled = clone_from is None @@ -427,7 +429,7 @@ class CommitteeParticipant(FoundationCommitter): try: path_to_hash, path_to_size = await attestable.paths_to_hashes_and_sizes(temp_dir_path) - parent_revision_number = old_revision.number if old_revision else None + parent_revision_number = old_revision.safe_number if old_revision else None previous_attestable = None if parent_revision_number is not None: previous_attestable = await attestable.load(project_name, version_name, parent_revision_number) diff --git a/atr/storage/writers/vote.py b/atr/storage/writers/vote.py index 1afb6d56..fb8e7c59 100644 --- a/atr/storage/writers/vote.py +++ b/atr/storage/writers/vote.py @@ -139,7 +139,7 @@ class CommitteeParticipant(FoundationCommitter): email_to: str, project_name: safe.ProjectName, version_name: safe.VersionName, - selected_revision_number: str, + selected_revision_number: safe.RevisionNumber, vote_duration_choice: int, subject: str, body_data: str, @@ -358,7 +358,7 @@ class CommitteeMember(CommitteeParticipant): fullname=asf_fullname, project_name=release.safe_project_name, version_name=release.safe_version_name, - revision_number=revision_number, + revision_number=release.safe_latest_revision_number, vote_duration=vote_duration, ) subject_data, body_data = await construct.start_vote_subject_and_body( @@ -369,7 +369,7 @@ class CommitteeMember(CommitteeParticipant): permitted_recipients=[incubator_vote_address], project_name=release.safe_project_name, version_name=release.safe_version_name, - selected_revision_number=revision_number, + selected_revision_number=release.safe_latest_revision_number, asf_uid=self.__asf_uid, asf_fullname=asf_fullname, vote_duration_choice=vote_duration, diff --git a/atr/tasks/__init__.py b/atr/tasks/__init__.py index 4bdd1ae4..15c4e1e2 100644 --- a/atr/tasks/__init__.py +++ b/atr/tasks/__init__.py @@ -54,7 +54,7 @@ import atr.util as util async def asc_checks( - asf_uid: str, release: sql.Release, revision: str, signature_path: str, data: db.Session + asf_uid: str, release: sql.Release, revision: safe.RevisionNumber, signature_path: str, data: db.Session ) -> list[sql.Task | None]: """Create signature check task for a .asc file.""" tasks = [] @@ -135,13 +135,13 @@ async def draft_checks( asf_uid: str, project_name: safe.ProjectName, release_version: safe.VersionName, - revision_number: str, + revision_number: safe.RevisionNumber, caller_data: db.Session | None = None, ) -> int: """Core logic to analyse a draft revision and queue checks.""" # Construct path to the specific revision # We don't have the release object here, so we can't use util.release_directory - revision_path = file_paths.get_unfinished_dir() / str(project_name) / str(release_version) / revision_number + revision_path = file_paths.get_unfinished_dir() / str(project_name) / str(release_version) / str(revision_number) relative_paths = [path async for path in util.paths_recursive(revision_path)] async with db.ensure_session(caller_data) as data: @@ -265,7 +265,7 @@ async def queued( asf_uid: str, task_type: sql.TaskType, release: sql.Release, - revision_number: str, + revision_number: safe.RevisionNumber, primary_rel_path: str | None = None, extra_args: dict[str, Any] | None = None, check_cache_key: dict[str, Any] | None = None, @@ -289,7 +289,7 @@ async def queued( asf_uid=asf_uid, project_name=release.project.name, version_name=release.version, - revision_number=revision_number, + revision_number=str(revision_number), primary_rel_path=primary_rel_path, inputs_hash=hash_val, ) @@ -352,7 +352,7 @@ def resolve(task_type: sql.TaskType) -> Callable[..., Awaitable[results.Results async def sha_checks( - asf_uid: str, release: sql.Release, revision: str, hash_file: str, data: db.Session + asf_uid: str, release: sql.Release, revision: safe.RevisionNumber, hash_file: str, data: db.Session ) -> list[sql.Task | None]: """Create hash check task for a .sha256 or .sha512 file.""" tasks = [] @@ -380,7 +380,7 @@ async def sha_checks( async def tar_gz_checks( - asf_uid: str, release: sql.Release, revision: str, path: str, data: db.Session + asf_uid: str, release: sql.Release, revision: safe.RevisionNumber, path: str, data: db.Session ) -> list[sql.Task | None]: """Create check tasks for a .tar.gz or .tgz file.""" # This release has committee, as guaranteed in draft_checks @@ -489,7 +489,7 @@ async def workflow_update( async def zip_checks( - asf_uid: str, release: sql.Release, revision: str, path: str, data: db.Session + asf_uid: str, release: sql.Release, revision: safe.RevisionNumber, path: str, data: db.Session ) -> list[sql.Task | None]: """Create check tasks for a .zip file.""" # This release has committee, as guaranteed in draft_checks @@ -592,7 +592,7 @@ async def _draft_file_checks( project_name: safe.ProjectName, release: sql.Release, release_version: safe.VersionName, - revision_number: str, + revision_number: safe.RevisionNumber, ): path_str = str(path) task_function: Callable[[str, sql.Release, str, str, db.Session], Awaitable[list[sql.Task | None]]] | None = None @@ -603,7 +603,7 @@ async def _draft_file_checks( if task_function: for task in await task_function(asf_uid, release, revision_number, path_str, data): if task: - task.revision_number = revision_number + task.revision_number = str(revision_number) await _add_task(data, task) # TODO: Should we check .json files for their content? # Ideally we would not have to do that @@ -617,7 +617,7 @@ async def _draft_file_checks( extra_args={ "project_name": str(project_name), "version_name": str(release_version), - "revision_number": revision_number, + "revision_number": str(revision_number), "previous_release_version": previous_version.version if previous_version else None, "file_path": path_str, "asf_uid": asf_uid, diff --git a/atr/tasks/checks/__init__.py b/atr/tasks/checks/__init__.py index 99681200..2e552c9b 100644 --- a/atr/tasks/checks/__init__.py +++ b/atr/tasks/checks/__init__.py @@ -53,7 +53,7 @@ class FunctionArguments: asf_uid: str project_name: safe.ProjectName version_name: safe.VersionName - revision_number: str + revision_number: safe.RevisionNumber primary_rel_path: str | None extra_args: dict[str, Any] @@ -65,7 +65,7 @@ class Recorder: version_name: safe.VersionName primary_rel_path: str | None member_rel_path: str | None - revision_number: str + revision_number: safe.RevisionNumber afresh: bool __cached: bool __input_hash: str | None @@ -76,7 +76,7 @@ class Recorder: inputs_hash: str | None, project_name: safe.ProjectName, version_name: safe.VersionName, - revision_number: str, + revision_number: safe.RevisionNumber, primary_rel_path: str | None = None, member_rel_path: str | None = None, afresh: bool = True, @@ -102,7 +102,7 @@ class Recorder: inputs_hash: str, project_name: safe.ProjectName, version_name: safe.VersionName, - revision_number: str, + revision_number: safe.RevisionNumber, primary_rel_path: str | None = None, member_rel_path: str | None = None, afresh: bool = True, @@ -146,7 +146,7 @@ class Recorder: result = sql.CheckResult( release_name=str(self.release_name), - revision_number=self.revision_number, + revision_number=str(self.revision_number), checker=self.checker, primary_rel_path=primary_rel_path or self.primary_rel_path, member_rel_path=member_rel_path, @@ -317,7 +317,7 @@ async def resolve_cache_key( checker_version: str, policy_keys: list[str], release: sql.Release, - revision: str, + revision: safe.RevisionNumber, args: dict[str, Any] | None = None, file: str | None = None, path: pathlib.Path | None = None, @@ -384,7 +384,7 @@ async def _resolve_all_files(release: sql.Release, rel_path: str | None = None) return [] if not ( base_path := file_paths.base_path_for_revision( - release.safe_project_name, release.safe_version_name, release.latest_revision_number + release.safe_project_name, release.safe_version_name, release.safe_latest_revision_number ) ): return [] @@ -407,7 +407,7 @@ async def _resolve_github_tp_sha(release: sql.Release, rel_path: str | None = No if not release.latest_revision_number: return "" payload_path = attestable.github_tp_payload_path( - release.safe_project_name, release.safe_version_name, release.latest_revision_number + release.safe_project_name, release.safe_version_name, release.safe_latest_revision_number ) if not await aiofiles.os.path.isfile(payload_path): return "" @@ -435,7 +435,7 @@ async def _resolve_unsuffixed_file_hash(release: sql.Release, rel_path: str | No if (not rel_path) or (not release.latest_revision_number): return "" abs_path = file_paths.revision_path_for_file( - release.safe_project_name, release.safe_version_name, release.latest_revision_number, rel_path + release.safe_project_name, release.safe_version_name, release.safe_latest_revision_number, rel_path ) plain_path = abs_path.with_suffix("") if await aiofiles.os.path.isfile(plain_path): diff --git a/atr/tasks/checks/compare.py b/atr/tasks/checks/compare.py index f2d29881..54906530 100644 --- a/atr/tasks/checks/compare.py +++ b/atr/tasks/checks/compare.py @@ -374,7 +374,7 @@ async def _find_archive_root(archive_path: pathlib.Path, extract_dir: pathlib.Pa async def _load_tp_payload( - project_name: safe.ProjectName, version_name: safe.VersionName, revision_number: str + project_name: safe.ProjectName, version_name: safe.VersionName, revision_number: safe.RevisionNumber ) -> github_models.TrustedPublisherPayload | None: payload_path = attestable.github_tp_payload_path(project_name, version_name, revision_number) if not await aiofiles.os.path.isfile(payload_path): diff --git a/atr/tasks/quarantine.py b/atr/tasks/quarantine.py index 81beef4f..85f261cf 100644 --- a/atr/tasks/quarantine.py +++ b/atr/tasks/quarantine.py @@ -320,7 +320,7 @@ async def _promote( previous_attestable = None if old_revision is not None: - previous_attestable = await attestable.load(project_name, version_name, old_revision.number) + previous_attestable = await attestable.load(project_name, version_name, old_revision.safe_number) base_inodes: dict[str, int] = {} base_hashes: dict[str, str] = {} diff --git a/atr/tasks/vote.py b/atr/tasks/vote.py index 665daa4b..3f6ea9a2 100644 --- a/atr/tasks/vote.py +++ b/atr/tasks/vote.py @@ -78,7 +78,7 @@ async def _initiate_core_logic(args: Initiate) -> results.Results | None: raise VoteInitiationError(f"No revisions found for release {args.release_name!s}") ongoing_tasks = await interaction.tasks_ongoing( - release.safe_project_name, release.safe_version_name, latest_revision_number + release.safe_project_name, release.safe_version_name, release.safe_latest_revision_number ) if ongoing_tasks > 0: raise VoteInitiationError( diff --git a/atr/web.py b/atr/web.py index 66671007..e8cbd8e0 100644 --- a/atr/web.py +++ b/atr/web.py @@ -165,7 +165,7 @@ class Committer: project_name: safe.ProjectName, version_name: safe.VersionName, phase: sql.ReleasePhase | db.NotSet | None = db.NOT_SET, - latest_revision_number: str | db.NotSet | None = db.NOT_SET, + latest_revision_number: safe.RevisionNumber | db.NotSet | None = db.NOT_SET, data: db.Session | None = None, with_committee: bool = True, with_project: bool = True, @@ -182,13 +182,14 @@ class Committer: phase_value = sql.ReleasePhase.RELEASE_CANDIDATE_DRAFT else: phase_value = phase + revision = db.NOT_SET if latest_revision_number == db.NOT_SET else str(latest_revision_number) release_name = sql.release_name(project_name, version_name) if data is None: async with db.session() as data: release = await data.release( name=str(release_name), phase=phase_value, - latest_revision_number=latest_revision_number, + latest_revision_number=revision, _committee=with_committee, _project=with_project, _release_policy=with_release_policy, @@ -200,7 +201,7 @@ class Committer: release = await data.release( name=str(release_name), phase=phase_value, - latest_revision_number=latest_revision_number, + latest_revision_number=revision, _committee=with_committee, _project=with_project, _release_policy=with_release_policy, diff --git a/atr/worker.py b/atr/worker.py index 1694069f..1f65865a 100644 --- a/atr/worker.py +++ b/atr/worker.py @@ -121,6 +121,7 @@ async def _execute_check_task( project_name = safe.ProjectName(task_obj.project_name) version_name = safe.VersionName(task_obj.version_name) + revision_number = safe.RevisionNumber(task_obj.revision_number) async def recorder_factory() -> checks.Recorder: return await checks.Recorder.create( @@ -128,7 +129,7 @@ async def _execute_check_task( inputs_hash=task_obj.inputs_hash or "", project_name=project_name, version_name=version_name, - revision_number=task_obj.revision_number or "", + revision_number=revision_number, primary_rel_path=task_obj.primary_rel_path, ) @@ -137,7 +138,7 @@ async def _execute_check_task( asf_uid=task_obj.asf_uid, project_name=project_name, version_name=version_name, - revision_number=task_obj.revision_number, + revision_number=revision_number, primary_rel_path=task_obj.primary_rel_path, extra_args=task_args, ) diff --git a/tests/unit/test_create_revision.py b/tests/unit/test_create_revision.py index 8a94458d..82d07ad3 100644 --- a/tests/unit/test_create_revision.py +++ b/tests/unit/test_create_revision.py @@ -21,6 +21,7 @@ import unittest.mock as mock import pytest +import atr.models.safe as safe import atr.models.sql as sql import atr.storage.types as types import atr.storage.writers.revision as revision @@ -58,6 +59,10 @@ class FakeRevision: self.release_name = release_name self.was_quarantined = was_quarantined + @property + def safe_number(self) -> safe.RevisionNumber: + return safe.RevisionNumber(self.number) + class MockSafeData: def __init__(self, parent_name: str, new_number: str = "00006"): @@ -109,10 +114,12 @@ async def test_clone_from_older_revision_skips_merge_without_intervening_change( latest_revision = mock.MagicMock() latest_revision.name = f"{release_name} 00005" latest_revision.number = "00005" + latest_revision.safe_number = safe.RevisionNumber("00005") selected_revision = mock.MagicMock() selected_revision.name = f"{release_name} 00002" selected_revision.number = "00002" + selected_revision.safe_number = safe.RevisionNumber("00002") mock_session = _mock_db_session(release, selected_revision=selected_revision) participant = _make_participant() @@ -149,7 +156,9 @@ async def test_clone_from_older_revision_skips_merge_without_intervening_change( mock.patch.object(revision.paths, "release_directory", return_value=tmp_path / "releases" / "00006"), mock.patch.object(revision.paths, "release_directory_base", return_value=tmp_path / "releases"), ): - await participant.create_revision_with_quarantine("proj", "1.0", "test", clone_from="00002") + await participant.create_revision_with_quarantine( + safe.ProjectName("proj"), safe.VersionName("1.0"), "test", clone_from=safe.RevisionNumber("00002") + ) if merge_mock.called: raise AssertionError( @@ -193,10 +202,12 @@ async def test_intervening_revision_triggers_merge_and_uses_latest_parent(tmp_pa old_revision = mock.MagicMock() old_revision.name = f"{release_name} 00005" old_revision.number = "00005" + old_revision.safe_number = safe.RevisionNumber("00005") intervening_revision = mock.MagicMock() intervening_revision.name = f"{release_name} 00006" intervening_revision.number = "00006" + intervening_revision.safe_number = safe.RevisionNumber("00006") first_attestable = mock.MagicMock(paths={"dist/a.tar.gz": "h1"}) second_attestable = mock.MagicMock(paths={"dist/b.tar.gz": "h2"}) @@ -244,7 +255,7 @@ async def test_intervening_revision_triggers_merge_and_uses_latest_parent(tmp_pa merge_await_args = merge_mock.await_args assert merge_await_args is not None - assert merge_await_args.args[5] == "00006" + assert merge_await_args.args[5] == safe.RevisionNumber("00006") @pytest.mark.asyncio --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
