This is an automated email from the ASF dual-hosted git repository. arm pushed a commit to branch taint_tracking_types in repository https://gitbox.apache.org/repos/asf/tooling-trusted-releases.git
commit 3e3c8d92fb107b6ecba4ebe2e8638a2992e6c6d6 Author: Alastair McFarlane <[email protected]> AuthorDate: Wed Feb 25 11:56:34 2026 +0000 First cut of taint tracking types for project and version --- atr/blueprints/get.py | 165 ++++++++++++++++++++++++++++++++++++++++++++----- atr/blueprints/post.py | 32 +++++----- atr/cache.py | 62 +++++++++++++++++++ atr/construct.py | 25 ++++---- atr/db/__init__.py | 5 +- atr/db/interaction.py | 9 ++- atr/get/announce.py | 18 ++++-- atr/get/checklist.py | 11 +++- atr/get/checks.py | 27 +++++--- atr/get/compose.py | 12 +++- atr/models/safe.py | 62 +++++++++++++++++++ atr/models/unsafe.py | 28 +++++++++ atr/server.py | 4 ++ atr/shared/web.py | 16 ++--- atr/web.py | 12 ++-- 15 files changed, 410 insertions(+), 78 deletions(-) diff --git a/atr/blueprints/get.py b/atr/blueprints/get.py index f715defd..1af5d8e4 100644 --- a/atr/blueprints/get.py +++ b/atr/blueprints/get.py @@ -15,18 +15,23 @@ # specific language governing permissions and limitations # under the License. 
+import inspect import time from collections.abc import Awaitable, Callable from types import ModuleType -from typing import Any +from typing import Any, Literal, get_args, get_origin, get_type_hints import asfquart.auth as auth import asfquart.base as base import asfquart.session import quart +import atr.cache as cache +import atr.db as db import atr.ldap as ldap import atr.log as log +import atr.models.safe as safe +import atr.models.sql as sql import atr.web as web _BLUEPRINT_NAME = "get_blueprint" @@ -34,17 +39,151 @@ _BLUEPRINT = quart.Blueprint(_BLUEPRINT_NAME, __name__) _routes: list[str] = [] +async def _authenticate() -> web.Committer: + web_session = await asfquart.session.read() + if web_session is None: + raise base.ASFQuartException("Not authenticated", errorcode=401) + if (web_session.uid is None) or (not await ldap.is_active(web_session.uid)): + asfquart.session.clear() + raise base.ASFQuartException("Account is disabled", errorcode=401) + return web.Committer(web_session) + + +def _register(func: Callable[..., Any]) -> None: + module_name = func.__module__.split(".")[-1] + _routes.append(f"get.{module_name}.{func.__name__}") + + +async def _validate_project(raw: str) -> safe.ProjectName: + if cache.project_version_has_project(raw): + return safe.ProjectName(raw) + async with db.session() as data: + project = await data.project(name=raw, status=sql.ProjectStatus.ACTIVE, _committee=False).get() + if project is None: + raise base.ASFQuartException(f"Project {raw!r} not found", errorcode=404) + return safe.ProjectName(project.name) + + +async def _validate_version(project_name: safe.ProjectName, raw: str) -> safe.VersionName: + if cache.project_version_has_version(project_name, raw): + return safe.VersionName(raw) + async with db.session() as data: + release = await data.release( + project_name=str(project_name), + version=raw, + _project=False, + _committee=False, + ).get() + if release is None: + raise base.ASFQuartException(f"Version {raw!r} not 
found for project {project_name!s}", errorcode=404) + return safe.VersionName(release.version) + + +_QUART_CONVERTERS: dict[type, str] = { + int: "int", + float: "float", +} + +_VALIDATED_TYPES: set[Any] = {safe.ProjectName, safe.VersionName} + + +def _build_path(func: Callable[..., Any]) -> tuple[str, list[tuple[str, type]], dict[str, str]]: + """Inspect a function's type hints to build a URL path and a validation plan. + + Returns (path, validated_params, literal_params) where validated_params is a + list of (param_name, param_type) for each parameter that needs async + validation, and literal_params maps parameter names to their values. + """ + hints = get_type_hints(func, include_extras=True) + params = list(inspect.signature(func).parameters.keys()) + segments: list[str] = [] + validated_params: list[tuple[str, type]] = [] + literal_params: dict[str, str] = {} + + for param_name in params: + # This is the session object + if param_name == "session": + continue + + hint = hints.get(param_name) + if hint is None: + raise TypeError(f"Parameter {param_name!r} in {func.__name__} has no type annotation") + + origin = get_origin(hint) + + if origin is Literal: + literal_value = get_args(hint)[0] + segments.append(str(literal_value)) + literal_params[param_name] = str(literal_value) + elif hint in _VALIDATED_TYPES: + segments.append(f"<{param_name}>") + validated_params.append((param_name, hint)) + elif hint in _QUART_CONVERTERS: + converter = _QUART_CONVERTERS[hint] + segments.append(f"<{converter}:{param_name}>") + elif hint is str: + segments.append(f"<{param_name}>") + else: + raise TypeError(f"Parameter {param_name!r} in {func.__name__} has unsupported type {hint!r}") + + path = "/" + "/".join(segments) + return path, validated_params, literal_params + + +async def _run_validators(kwargs: dict[str, Any], validated_params: list[tuple[str, type]]) -> None: + """Validate URL parameters in order, using the cache/DB validators.""" + for param_name, param_type in 
validated_params: + raw = kwargs[param_name] + if param_type is safe.ProjectName: + kwargs[param_name] = await _validate_project(raw) + elif param_type is safe.VersionName: + project_name = kwargs.get("project_name", "") + kwargs[param_name] = await _validate_version(project_name, raw) + + +def typed(func: Callable[..., Any]) -> web.RouteFunction[Any]: + """Decorator that derives the URL path from the function's type annotations. + + - Literal["..."] parameters become literal path segments + - safe.ProjectName / safe.VersionName parameters are validated via cache/DB + - int, float use Quart's built-in type converters + - str parameters pass through as-is + """ + path, validated_params, literal_params = _build_path(func) + + async def wrapper(*_args: Any, **kwargs: Any) -> Any: + enhanced_session = await _authenticate() + await _run_validators(kwargs, validated_params) + kwargs.update(literal_params) + + start_time_ns = time.perf_counter_ns() + response = await func(enhanced_session, **kwargs) + end_time_ns = time.perf_counter_ns() + total_ns = end_time_ns - start_time_ns + total_ms = total_ns // 1_000_000 + + log.performance( + f"GET {path} {func.__name__} = 0 0 {total_ms}", + ) + + return response + + endpoint = func.__module__.replace(".", "_") + "_" + func.__name__ + wrapper.__name__ = func.__name__ + wrapper.__doc__ = func.__doc__ + wrapper.__annotations__["endpoint"] = _BLUEPRINT_NAME + "." 
+ endpoint + + decorated = auth.require(auth.Requirements.committer)(wrapper) + _BLUEPRINT.add_url_rule(path, endpoint=endpoint, view_func=decorated, methods=["GET"]) + _register(func) + + return decorated + + def committer(path: str) -> Callable[[web.CommitterRouteFunction[Any]], web.RouteFunction[Any]]: def decorator(func: web.CommitterRouteFunction[Any]) -> web.RouteFunction[Any]: async def wrapper(*args: Any, **kwargs: Any) -> Any: - web_session = await asfquart.session.read() - if web_session is None: - raise base.ASFQuartException("Not authenticated", errorcode=401) - if (web_session.uid is None) or (not await ldap.is_active(web_session.uid)): - asfquart.session.clear() - raise base.ASFQuartException("Account is disabled", errorcode=401) - - enhanced_session = web.Committer(web_session) + enhanced_session = await _authenticate() start_time_ns = time.perf_counter_ns() response = await func(enhanced_session, *args, **kwargs) end_time_ns = time.perf_counter_ns() @@ -65,9 +204,7 @@ def committer(path: str) -> Callable[[web.CommitterRouteFunction[Any]], web.Rout decorated = auth.require(auth.Requirements.committer)(wrapper) _BLUEPRINT.add_url_rule(path, endpoint=endpoint, view_func=decorated, methods=["GET"]) - - module_name = func.__module__.split(".")[-1] - _routes.append(f"get.{module_name}.{func.__name__}") + _register(func) return decorated @@ -87,9 +224,7 @@ def public(path: str) -> Callable[[Callable[..., Awaitable[Any]]], web.RouteFunc wrapper.__annotations__["endpoint"] = _BLUEPRINT_NAME + "." 
+ endpoint _BLUEPRINT.add_url_rule(path, endpoint=endpoint, view_func=wrapper, methods=["GET"]) - - module_name = func.__module__.split(".")[-1] - _routes.append(f"get.{module_name}.{func.__name__}") + _register(func) return wrapper diff --git a/atr/blueprints/post.py b/atr/blueprints/post.py index 8c9133a0..dca1ccc2 100644 --- a/atr/blueprints/post.py +++ b/atr/blueprints/post.py @@ -37,17 +37,25 @@ _BLUEPRINT = quart.Blueprint(_BLUEPRINT_NAME, __name__) _routes: list[str] = [] +async def _authenticate() -> web.Committer: + web_session = await asfquart.session.read() + if web_session is None: + raise base.ASFQuartException("Not authenticated", errorcode=401) + if (web_session.uid is None) or (not await ldap.is_active(web_session.uid)): + asfquart.session.clear() + raise base.ASFQuartException("Account is disabled", errorcode=401) + return web.Committer(web_session) + + +def _register(func: Callable[..., Any]) -> None: + module_name = func.__module__.split(".")[-1] + _routes.append(f"post.{module_name}.{func.__name__}") + + def committer(path: str) -> Callable[[web.CommitterRouteFunction[Any]], web.RouteFunction[Any]]: def decorator(func: web.CommitterRouteFunction[Any]) -> web.RouteFunction[Any]: async def wrapper(*args: Any, **kwargs: Any) -> Any: - web_session = await asfquart.session.read() - if web_session is None: - raise base.ASFQuartException("Not authenticated", errorcode=401) - if (web_session.uid is None) or (not await ldap.is_active(web_session.uid)): - asfquart.session.clear() - raise base.ASFQuartException("Account is disabled", errorcode=401) - - enhanced_session = web.Committer(web_session) + enhanced_session = await _authenticate() start_time_ns = time.perf_counter_ns() response = await func(enhanced_session, *args, **kwargs) end_time_ns = time.perf_counter_ns() @@ -68,9 +76,7 @@ def committer(path: str) -> Callable[[web.CommitterRouteFunction[Any]], web.Rout decorated = auth.require(auth.Requirements.committer)(wrapper) 
_BLUEPRINT.add_url_rule(path, endpoint=endpoint, view_func=decorated, methods=["POST"]) - - module_name = func.__module__.split(".")[-1] - _routes.append(f"post.{module_name}.{func.__name__}") + _register(func) return decorated @@ -187,9 +193,7 @@ def public(path: str) -> Callable[[Callable[..., Awaitable[Any]]], web.RouteFunc wrapper.__name__ = func.__name__ _BLUEPRINT.add_url_rule(path, endpoint=endpoint, view_func=wrapper, methods=["POST"]) - - module_name = func.__module__.split(".")[-1] - _routes.append(f"post.{module_name}.{func.__name__}") + _register(func) return wrapper diff --git a/atr/cache.py b/atr/cache.py index f652a02d..e578f0aa 100644 --- a/atr/cache.py +++ b/atr/cache.py @@ -27,11 +27,14 @@ import pydantic import atr.config as config import atr.ldap as ldap import atr.log as log +import atr.models.safe as safe import atr.models.schema as schema # Fifth prime after 3600 ADMINS_POLL_INTERVAL_SECONDS: Final[int] = 3631 +PROJECT_VERSION_POLL_INTERVAL_SECONDS: Final[int] = 307 + class AdminsCache(schema.Strict): refreshed: datetime.datetime = schema.description("When the cache was last refreshed") @@ -92,6 +95,39 @@ async def admins_startup_load() -> None: log.warning(f"Failed to fetch admin users from LDAP at startup: {e}") +def project_version_get() -> dict[str, set[str]]: + if asfquart.APP is not None: + return asfquart.APP.extensions.get("project_versions", {}) + return {} + + +def project_version_has_project(project_name: str) -> bool: + return project_name in project_version_get() + + +def project_version_has_version(project_name: safe.ProjectName, version_name: str) -> bool: + projects = project_version_get() + if str(project_name) not in projects: + return False + return version_name in projects[str(project_name)] + + +async def project_version_refresh_loop() -> None: + while True: + await asyncio.sleep(PROJECT_VERSION_POLL_INTERVAL_SECONDS) + try: + await _project_version_refresh() + except Exception as e: + log.warning(f"Project/version cache 
refresh failed: {e}") + + +async def project_version_startup_load() -> None: + try: + await _project_version_refresh() + except Exception as e: + log.warning(f"Failed to populate project/version cache at startup: {e}") + + def _admins_path() -> pathlib.Path: return pathlib.Path(config.get().STATE_DIR) / "cache" / "admins.json" @@ -134,3 +170,29 @@ def _admins_update_app_extensions(admins: frozenset[str]) -> None: app = asfquart.APP app.extensions["admins"] = admins app.extensions["admins_refreshed"] = datetime.datetime.now(datetime.UTC) + + +async def _project_version_fetch_from_db() -> dict[str, set[str]]: + import atr.db as db + import atr.models.sql as sql + + projects: dict[str, set[str]] = {} + async with db.session() as data: + all_projects = await data.project(status=sql.ProjectStatus.ACTIVE, _committee=False).all() + for project in all_projects: + all_releases = await data.release(project_name=project.name, _project=False, _committee=False).all() + projects[project.name] = {release.version for release in all_releases} + return projects + + +async def _project_version_refresh() -> None: + projects = await _project_version_fetch_from_db() + _project_version_update_app_extensions(projects) + total_versions = sum(len(v) for v in projects.values()) + log.info(f"Project/version cache refreshed: {len(projects)} projects, {total_versions} versions") + + +def _project_version_update_app_extensions(projects: dict[str, set[str]]) -> None: + app = asfquart.APP + app.extensions["project_versions"] = projects + app.extensions["project_versions_refreshed"] = datetime.datetime.now(datetime.UTC) diff --git a/atr/construct.py b/atr/construct.py index f36ba1a6..6a639f47 100644 --- a/atr/construct.py +++ b/atr/construct.py @@ -25,6 +25,7 @@ import quart import atr.config as config import atr.db as db +import atr.models.safe as safe import atr.models.sql as sql import atr.paths as paths import atr.util as util @@ -53,8 +54,8 @@ TEMPLATE_VARIABLES: list[tuple[str, str, 
set[Context]]] = [ class AnnounceReleaseOptions: asfuid: str fullname: str - project_name: str - version_name: str + project_name: safe.ProjectName + version_name: safe.VersionName revision_number: str @@ -68,11 +69,11 @@ class StartVoteOptions: vote_duration: int -async def announce_release_default(project_name: str) -> str: +async def announce_release_default(project_name: safe.ProjectName) -> str: async with db.session() as data: - project = await data.project(name=project_name, status=sql.ProjectStatus.ACTIVE, _release_policy=True).demand( - RuntimeError(f"Project {project_name} not found") - ) + project = await data.project( + name=str(project_name), status=sql.ProjectStatus.ACTIVE, _release_policy=True + ).demand(RuntimeError(f"Project {project_name} not found")) return project.policy_announce_release_template @@ -131,11 +132,11 @@ async def announce_release_subject_and_body( return subject, body -async def announce_release_subject_default(project_name: str) -> str: +async def announce_release_subject_default(project_name: safe.ProjectName) -> str: async with db.session() as data: - project = await data.project(name=project_name, status=sql.ProjectStatus.ACTIVE, _release_policy=True).demand( - RuntimeError(f"Project {project_name} not found") - ) + project = await data.project( + name=str(project_name), status=sql.ProjectStatus.ACTIVE, _release_policy=True + ).demand(RuntimeError(f"Project {project_name} not found")) return project.policy_announce_release_subject @@ -151,7 +152,7 @@ def announce_template_variables() -> list[tuple[str, str]]: def checklist_body( markdown: str, project: sql.Project, - version_name: str, + version_name: safe.VersionName, committee: sql.Committee, revision: sql.Revision | None, ) -> str: @@ -172,7 +173,7 @@ def checklist_body( markdown = markdown.replace("{{REVIEW_URL}}", review_url) markdown = markdown.replace("{{REVISION}}", revision_number) markdown = markdown.replace("{{TAG}}", revision_tag) - markdown = 
markdown.replace("{{VERSION}}", version_name) + markdown = markdown.replace("{{VERSION}}", str(version_name)) return markdown diff --git a/atr/db/__init__.py b/atr/db/__init__.py index 97b1f967..a4c62e3a 100644 --- a/atr/db/__init__.py +++ b/atr/db/__init__.py @@ -34,6 +34,7 @@ import sqlmodel.sql.expression as expression import atr.config as config import atr.log as log +import atr.models.safe as safe import atr.models.schema as schema import atr.models.sql as sql import atr.util as util @@ -478,9 +479,9 @@ class Session(sqlalchemy.ext.asyncio.AsyncSession): name: Opt[str] = NOT_SET, phase: Opt[sql.ReleasePhase] = NOT_SET, created: Opt[datetime.datetime] = NOT_SET, - project_name: Opt[str] = NOT_SET, + project_name: Opt[safe.ProjectName] = NOT_SET, package_managers: Opt[list[str]] = NOT_SET, - version: Opt[str] = NOT_SET, + version: Opt[safe.VersionName] = NOT_SET, sboms: Opt[list[str]] = NOT_SET, release_policy_id: Opt[int] = NOT_SET, votes: Opt[list[sql.VoteEntry]] = NOT_SET, diff --git a/atr/db/interaction.py b/atr/db/interaction.py index 6e7e17a5..3213bc74 100644 --- a/atr/db/interaction.py +++ b/atr/db/interaction.py @@ -32,6 +32,7 @@ import atr.jwtoken as jwtoken import atr.ldap as ldap import atr.log as log import atr.models.results as results +import atr.models.safe as safe import atr.models.sql as sql import atr.user as user import atr.util as util @@ -391,12 +392,14 @@ def task_recipient_get(latest_vote_task: sql.Task) -> str | None: return result.email_to -async def tasks_ongoing(project_name: str, version_name: str, revision_number: str | None = None) -> int: +async def tasks_ongoing( + project_name: safe.ProjectName, version_name: safe.VersionName, revision_number: str | None = None +) -> int: tasks = sqlmodel.select(sqlalchemy.func.count()).select_from(sql.Task) async with db.session() as data: query = tasks.where( - sql.Task.project_name == project_name, - sql.Task.version_name == version_name, + sql.Task.project_name == str(project_name), + 
sql.Task.version_name == str(version_name), sql.Task.revision_number == (sql.RELEASE_LATEST_REVISION_NUMBER if (revision_number is None) else revision_number), sql.validate_instrumented_attribute(sql.Task.status).in_([sql.TaskStatus.QUEUED, sql.TaskStatus.ACTIVE]), diff --git a/atr/get/announce.py b/atr/get/announce.py index 235c8463..1f41617c 100644 --- a/atr/get/announce.py +++ b/atr/get/announce.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - +from typing import Literal import htpy import markupsafe @@ -26,6 +26,7 @@ import atr.construct as construct import atr.form as form import atr.get.projects as projects import atr.htm as htm +import atr.models.safe as safe import atr.models.sql as sql import atr.post as post import atr.render as render @@ -35,12 +36,17 @@ import atr.util as util import atr.web as web -@get.committer("/announce/<project_name>/<version_name>") -async def selected(session: web.Committer, project_name: str, version_name: str) -> str | web.WerkzeugResponse: +@get.typed +async def selected( + _announce: Literal["announce"], + session: web.Committer, + project_name: safe.ProjectName, + version_name: safe.VersionName, +) -> str | web.WerkzeugResponse: """Allow the user to announce a release preview.""" await session.check_access(project_name) - release = await _get_page_data(project_name, session, version_name) + release = await _get_page_data(session, project_name, version_name) latest_revision_number = release.latest_revision_number if latest_revision_number is None: @@ -102,7 +108,9 @@ async def selected(session: web.Committer, project_name: str, version_name: str) ) -async def _get_page_data(project_name: str, session: web.Committer, version_name: str) -> sql.Release: +async def _get_page_data( + session: web.Committer, project_name: safe.ProjectName, version_name: safe.VersionName +) -> sql.Release: release = await
session.release( project_name, version_name, diff --git a/atr/get/checklist.py b/atr/get/checklist.py index d33db173..f1c4ddb1 100644 --- a/atr/get/checklist.py +++ b/atr/get/checklist.py @@ -14,6 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from typing import Literal import cmarkgfm import markupsafe @@ -25,14 +26,20 @@ import atr.db as db import atr.db.interaction as interaction import atr.get.vote as vote import atr.htm as htm +import atr.models.safe as safe import atr.render as render import atr.template as template import atr.util as util import atr.web as web [email protected]("/checklist/<project_name>/<version_name>") -async def selected(session: web.Committer | None, project_name: str, version_name: str) -> str: +@get.typed +async def selected( + session: web.Committer | None, + _checklist: Literal["checklist"], + project_name: safe.ProjectName, + version_name: safe.VersionName +) -> str: async with db.session() as data: release = await data.release( project_name=project_name, diff --git a/atr/get/checks.py b/atr/get/checks.py index 8faa5ce2..820ca1ea 100644 --- a/atr/get/checks.py +++ b/atr/get/checks.py @@ -17,7 +17,7 @@ import pathlib from collections.abc import Callable -from typing import NamedTuple +from typing import Literal, NamedTuple import asfquart.base as base import htpy @@ -33,6 +33,7 @@ import atr.get.report as report import atr.get.sbom as sbom import atr.get.vote as vote import atr.htm as htm +import atr.models.safe as safe import atr.models.sql as sql import atr.paths as paths import atr.post as post @@ -97,8 +98,13 @@ async def get_file_totals(release: sql.Release, session: web.Committer | None) - return totals [email protected]("/checks/<project_name>/<version_name>") -async def selected(session: web.Committer | None, project_name: str, version_name: str) -> str: +@get.typed +async def selected( + session: web.Committer |
None, + _checks: Literal["checks"], + project_name: safe.ProjectName, + version_name: safe.VersionName +) -> str: """Show the file checks for a release candidate.""" async with db.session() as data: release = await data.release( @@ -134,11 +140,12 @@ async def selected(session: web.Committer | None, project_name: str, version_nam ) -@get.committer("/checks/<project_name>/<version_name>/<revision_number>") +@get.typed async def selected_revision( session: web.Committer, - project_name: str, - version_name: str, + _checks: Literal["checks"], + project_name: safe.ProjectName, + version_name: safe.VersionName, revision_number: str, ) -> web.QuartResponse: """Return JSON with ongoing count and HTML fragments for dynamic updates.""" @@ -161,7 +168,7 @@ async def selected_revision( ongoing_count = await interaction.tasks_ongoing(project_name, version_name, revision_number) - checks_summary_elem = shared.web._render_checks_summary(info, project_name, version_name) + checks_summary_elem = shared.web.render_checks_summary(info, project_name, version_name) checks_summary_html = str(checks_summary_elem) if checks_summary_elem else "" delete_file_forms: dict[str, str] = {} @@ -170,7 +177,7 @@ async def selected_revision( delete_file_forms[str(path)] = str( form.render( model_cls=draft.DeleteFileForm, - action=util.as_url(post.draft.delete_file, project_name=project_name, version_name=version_name), + action=util.as_url(post.draft.delete_file, project_name=str(project_name), version_name=str(version_name)), form_classes=".d-inline-block.m-0", submit_classes="btn-sm btn-outline-danger", submit_label="Delete", @@ -188,8 +195,8 @@ async def selected_revision( "check-selected-path-table.html", paths=paths, info=info, - project_name=project_name, - version_name=version_name, + project_name=str(project_name), + version_name=str(version_name), release=release, phase=release.phase.value, delete_file_forms=delete_file_forms, diff --git a/atr/get/compose.py b/atr/get/compose.py
index 6e400872..cd7a1a71 100644 --- a/atr/get/compose.py +++ b/atr/get/compose.py @@ -15,18 +15,26 @@ # specific language governing permissions and limitations # under the License. +from typing import Literal + import asfquart.base as base import atr.blueprints.get as get import atr.db as db import atr.mapping as mapping +import atr.models.safe as safe import atr.models.sql as sql import atr.shared as shared import atr.web as web -@get.committer("/compose/<project_name>/<version_name>") -async def selected(session: web.Committer, project_name: str, version_name: str) -> web.WerkzeugResponse | str: +@get.typed +async def selected( + session: web.Committer, + _compose: Literal["compose"], + project_name: safe.ProjectName, + version_name: safe.VersionName, +) -> web.WerkzeugResponse | str: """Show the contents of the release candidate draft.""" await session.check_access(project_name) diff --git a/atr/models/safe.py b/atr/models/safe.py new file mode 100644 index 00000000..c6f9893b --- /dev/null +++ b/atr/models/safe.py @@ -0,0 +1,62 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License.
+ + +class ProjectName: + """A project name that has been validated against the cache or database.""" + + __slots__ = ("_value",) + + def __init__(self, value: str) -> None: + self._value = value + + def __eq__(self, other: object) -> bool: + if isinstance(other, ProjectName): + return self._value == other._value + return NotImplemented + + def __hash__(self) -> int: + return hash(self._value) + + def __repr__(self) -> str: + return f"ProjectName({self._value!r})" + + def __str__(self) -> str: + return self._value + + +class VersionName: + """A version name that has been validated against the cache or database.""" + + __slots__ = ("_value",) + + def __init__(self, value: str) -> None: + self._value = value + + def __eq__(self, other: object) -> bool: + if isinstance(other, VersionName): + return self._value == other._value + return NotImplemented + + def __hash__(self) -> int: + return hash(self._value) + + def __repr__(self) -> str: + return f"VersionName({self._value!r})" + + def __str__(self) -> str: + return self._value diff --git a/atr/models/unsafe.py b/atr/models/unsafe.py new file mode 100644 index 00000000..55174c0d --- /dev/null +++ b/atr/models/unsafe.py @@ -0,0 +1,28 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ + +class UnsafeStr: + """A raw string from URL routing that has not been validated.""" + + __slots__ = ("_value",) + + def __init__(self, value: str) -> None: + self._value = value + + def __repr__(self) -> str: + return f"UnsafeStr({self._value!r})" diff --git a/atr/server.py b/atr/server.py index 6dcee179..fe73716a 100644 --- a/atr/server.py +++ b/atr/server.py @@ -277,6 +277,10 @@ def _app_setup_lifecycle(app: base.QuartApp, app_config: type[config.AppConfig]) admins_task = asyncio.create_task(cache.admins_refresh_loop()) app.extensions["admins_task"] = admins_task + await cache.project_version_startup_load() + project_version_task = asyncio.create_task(cache.project_version_refresh_loop()) + app.extensions["project_version_task"] = project_version_task + worker_manager = manager.get_worker_manager() await worker_manager.start() diff --git a/atr/shared/web.py b/atr/shared/web.py index 12a3a411..640026f5 100644 --- a/atr/shared/web.py +++ b/atr/shared/web.py @@ -24,6 +24,7 @@ import atr.form as form import atr.get as get import atr.htm as htm import atr.models.results as results +import atr.models.safe as safe import atr.models.sql as sql import atr.paths as paths import atr.post as post @@ -140,7 +141,7 @@ async def check( is_local_caching = release.check_cache_key is not None - checks_summary_html = _render_checks_summary(info, release.project.name, release.version) + checks_summary_html = render_checks_summary(info, release.project.name, release.version) return await template.render( "check-selected.html", @@ -179,12 +180,9 @@ async def check( checks_summary_html=checks_summary_html, ) - -def _checker_display_name(checker: str) -> str: - return checker.removeprefix("atr.tasks.checks.").replace("_", " ").replace(".", " ").title() - - -def _render_checks_summary(info: types.PathInfo | None, project_name: str, version_name: str) -> htm.Element | None: +def render_checks_summary( + info: types.PathInfo | None, project_name: safe.ProjectName, version_name: 
safe.VersionName +) -> htm.Element | None: if (info is None) or (not info.checker_stats): return None @@ -210,7 +208,7 @@ def _render_checks_summary(info: types.PathInfo | None, project_name: str, versi files_div = htm.Block(htm.div, classes=".mt-2.atr-checks-files") all_files = set(stat.failure_files.keys()) | set(stat.warning_files.keys()) | set(stat.blocker_files.keys()) for file_path in sorted(all_files): - report_url = f"/report/{project_name}/{version_name}/{file_path}" + report_url = f"/report/{project_name!s}/{version_name!s}/{file_path}" error_count = stat.failure_files.get(file_path, 0) blocker_count = stat.blocker_files.get(file_path, 0) warning_count = stat.warning_files.get(file_path, 0) @@ -234,6 +232,8 @@ def _render_checks_summary(info: types.PathInfo | None, project_name: str, versi card.append(body.collect()) return card.collect() +def _checker_display_name(checker: str) -> str: + return checker.removeprefix("atr.tasks.checks.").replace("_", " ").replace(".", " ").title() def _warnings_from_vote_result(vote_task: sql.Task | None) -> list[str]: # TODO: Replace this with a schema.Strict model diff --git a/atr/web.py b/atr/web.py index 4d7d8c9c..2a2834f4 100644 --- a/atr/web.py +++ b/atr/web.py @@ -32,6 +32,7 @@ import atr.config as config import atr.db as db import atr.form as form import atr.htm as htm +import atr.models.safe as safe import atr.models.sql as sql import atr.user as user import atr.util as util @@ -42,6 +43,7 @@ if TYPE_CHECKING: import pydantic import werkzeug.wrappers.response as response + R = TypeVar("R", covariant=True) type WerkzeugResponse = response.Response @@ -86,8 +88,8 @@ class Committer: def is_admin(self) -> bool: return user.is_admin(self.uid) - async def check_access(self, project_name: str) -> None: - if not any((p.name == project_name) for p in (await self.user_projects)): + async def check_access(self, project_name: safe.ProjectName) -> None: + if not any((p.name == str(project_name)) for p in (await 
self.user_projects)): if self.is_admin: # Admins can view all projects # But we must warn them when the project is not one of their own @@ -160,8 +162,8 @@ class Committer: async def release( self, - project_name: str, - version_name: str, + project_name: safe.ProjectName, + version_name: safe.VersionName, phase: sql.ReleasePhase | db.NotSet | None = db.NOT_SET, latest_revision_number: str | db.NotSet | None = db.NOT_SET, data: db.Session | None = None, @@ -180,7 +182,7 @@ class Committer: phase_value = sql.ReleasePhase.RELEASE_CANDIDATE_DRAFT else: phase_value = phase - release_name = sql.release_name(project_name, version_name) + release_name = sql.release_name(str(project_name), str(version_name)) if data is None: async with db.session() as data: release = await data.release( --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
