This is an automated email from the ASF dual-hosted git repository.

sbp pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tooling-trusted-release.git


The following commit(s) were added to refs/heads/main by this push:
     new 9881905  Use Quart-Schema for API POST input validation
9881905 is described below

commit 9881905bcf5899bc7945ad87098397cc07ce5e14
Author: Sean B. Palmer <[email protected]>
AuthorDate: Mon Jul 14 19:14:22 2025 +0100

    Use Quart-Schema for API POST input validation
---
 atr/blueprints/api/__init__.py | 17 ++++++++-
 atr/blueprints/api/api.py      | 85 ++++++++++++++++++++----------------------
 atr/models/api.py              | 27 ++++++--------
 atr/models/results.py          |  2 -
 4 files changed, 67 insertions(+), 64 deletions(-)

diff --git a/atr/blueprints/api/__init__.py b/atr/blueprints/api/__init__.py
index c08aa9d..fefb170 100644
--- a/atr/blueprints/api/__init__.py
+++ b/atr/blueprints/api/__init__.py
@@ -16,10 +16,13 @@
 # under the License.
 
 import sys
+from typing import Any
 
 import asfquart.base as base
+import pydantic
 import quart
 import quart.blueprints as blueprints
+import quart_schema
 import werkzeug.exceptions as exceptions
 
 BLUEPRINT = quart.Blueprint("api_blueprint", __name__, url_prefix="/api")
@@ -52,7 +55,17 @@ async def _handle_not_found(err: exceptions.NotFound) -> 
tuple[quart.Response, i
     return _json_error(err.description or err.name, 404)
 
 
-def _json_error(message: str, status_code: int | None) -> 
tuple[quart.Response, int]:
+@BLUEPRINT.errorhandler(quart_schema.RequestSchemaValidationError)
+async def _handle_request_validation(err: 
quart_schema.RequestSchemaValidationError) -> tuple[quart.Response, int]:
+    if not isinstance(err.validation_error, pydantic.ValidationError):
+        raise err.validation_error
+    verr: pydantic.ValidationError = err.validation_error
+    return _json_error("Input validation failed", 400, {"validation_details": 
verr.errors()})
+
+
+def _json_error(
+    message: str, status_code: int | None, extra: dict[str, Any] | None = None
+) -> tuple[quart.Response, int]:
     payload = {"error": message}
     show_traceback = False
     if show_traceback:
@@ -60,6 +73,8 @@ def _json_error(message: str, status_code: int | None) -> 
tuple[quart.Response,
 
         traceback_str = "".join(traceback.format_exception(*sys.exc_info()))
         payload["traceback"] = traceback_str
+    if extra is not None:
+        payload.update(extra)
     return quart.jsonify(payload), status_code or 500
 
 
diff --git a/atr/blueprints/api/api.py b/atr/blueprints/api/api.py
index 80c59ca..f50661f 100644
--- a/atr/blueprints/api/api.py
+++ b/atr/blueprints/api/api.py
@@ -99,7 +99,7 @@ async def checks_ongoing_project_version(
 ) -> tuple[Mapping[str, Any], int]:
     """Return a count of all unfinished check results for a given release."""
     _simple_check(project, version, revision)
-    ongoing_tasks_count, latest_revision = await 
interaction.tasks_ongoing_revision(project, version, revision)
+    ongoing_tasks_count, _latest_revision = await 
interaction.tasks_ongoing_revision(project, version, revision)
     # TODO: Is there a way to return just an int?
     # The ResponseReturnValue type in quart does not allow int
     # And if we use quart.jsonify, we must return quart.Response which 
quart_schema tries to validate
@@ -158,15 +158,14 @@ async def committees_name_projects(name: str) -> 
tuple[list[Mapping], int]:
 @api.BLUEPRINT.route("/draft/delete", methods=["POST"])
 @jwtoken.require
 @quart_schema.security_scheme([{"BearerAuth": []}])
+@quart_schema.validate_request(models.api.ProjectVersion)
 @quart_schema.validate_response(dict[str, str], 200)
-async def draft_delete_project_version() -> tuple[dict[str, str], int]:
-    payload = await _payload_get()
-    req = models.api.DraftDeleteRequest.model_validate(payload)
+async def draft_delete_project_version(data: models.api.ProjectVersion) -> 
tuple[dict[str, str], int]:
     asf_uid = _jwt_asf_uid()
 
-    async with db.session() as data:
-        release_name = sql.release_name(req.project_name, req.version)
-        release = await data.release(
+    async with db.session() as db_data:
+        release_name = sql.release_name(data.project, data.version)
+        release = await db_data.release(
             name=release_name, phase=sql.ReleasePhase.RELEASE_CANDIDATE_DRAFT, 
_committee=True
         ).demand(exceptions.NotFound())
         if not (user.is_committee_member(release.project.committee, asf_uid) 
or user.is_admin(asf_uid)):
@@ -180,7 +179,7 @@ async def draft_delete_project_version() -> tuple[dict[str, 
str], int]:
         await interaction.release_delete(
             release_name, phase=sql.ReleasePhase.RELEASE_CANDIDATE_DRAFT, 
include_downloads=False
         )
-        await data.commit()
+        await db_data.commit()
     return {"deleted": release_name}, 200
 
 
@@ -207,22 +206,20 @@ async def list_project_version(
 
 
 @api.BLUEPRINT.route("/jwt", methods=["POST"])
-async def pat_jwt_post() -> quart.Response:
+@quart_schema.validate_request(models.api.AsfuidPat)
+async def pat_jwt_post(data: models.api.AsfuidPat) -> quart.Response:
     """Generate a JWT from a valid PAT."""
     # Expects {"asfuid": "uid", "pat": "pat-token"}
     # Returns {"asfuid": "uid", "jwt": "jwt-token"}
-
-    payload = await _payload_get()
-    pat_request = models.api.PATJWTRequest.model_validate(payload)
-    token_hash = hashlib.sha3_256(pat_request.pat.encode()).hexdigest()
-    pat_rec = await _get_pat(pat_request.asfuid, token_hash)
+    token_hash = hashlib.sha3_256(data.pat.encode()).hexdigest()
+    pat_rec = await _get_pat(data.asfuid, token_hash)
 
     now = datetime.datetime.now(datetime.UTC)
     if (pat_rec is None) or (pat_rec.expires < now):
         return quart.Response("Invalid PAT", status=401)
 
-    jwt_token = jwtoken.issue(pat_request.asfuid)
-    return quart.jsonify({"asfuid": pat_request.asfuid, "jwt": jwt_token})
+    jwt_token = jwtoken.issue(data.asfuid)
+    return quart.jsonify({"asfuid": data.asfuid, "jwt": jwt_token})
 
 
 @api.BLUEPRINT.route("/keys")
@@ -321,18 +318,16 @@ async def releases(query_args: models.api.Releases) -> 
quart.Response:
 @api.BLUEPRINT.route("/releases/create", methods=["POST"])
 @jwtoken.require
 @quart_schema.security_scheme([{"BearerAuth": []}])
+@quart_schema.validate_request(models.api.ProjectVersion)
 @quart_schema.validate_response(sql.Release, 201)
-async def releases_create() -> tuple[Mapping, int]:
+async def releases_create(data: models.api.ProjectVersion) -> tuple[Mapping, 
int]:
     """Create a new release draft for a project via POSTed JSON."""
-
-    payload = await _payload_get()
-    request_data = models.api.ReleaseCreateRequest.model_validate(payload)
     asf_uid = _jwt_asf_uid()
 
     try:
         release, _project = await start.create_release_draft(
-            project_name=request_data.project_name,
-            version=request_data.version,
+            project_name=data.project,
+            version=data.version,
             asf_uid=asf_uid,
         )
     except routes.FlashError as exc:
@@ -462,13 +457,21 @@ async def tasks(query_args: models.api.Task) -> 
quart.Response:
         return quart.jsonify(result)
 
 
+@api.BLUEPRINT.route("/vote/resolve", methods=["POST"])
+@jwtoken.require
+@quart_schema.security_scheme([{"BearerAuth": []}])
+@quart_schema.validate_request(models.api.VoteStart)
+@quart_schema.validate_response(sql.Task, 201)
+async def vote_resolve(req: models.api.VoteStart) -> tuple[Mapping, int]:
+    return {}, 200
+
+
 @api.BLUEPRINT.route("/vote/start", methods=["POST"])
 @jwtoken.require
 @quart_schema.security_scheme([{"BearerAuth": []}])
+@quart_schema.validate_request(models.api.VoteStart)
 @quart_schema.validate_response(sql.Task, 201)
-async def vote_start() -> tuple[Mapping, int]:
-    payload = await _payload_get()
-    req = models.api.VoteStartRequest.model_validate(payload)
+async def vote_start(req: models.api.VoteStart) -> tuple[Mapping, int]:
     asf_uid = _jwt_asf_uid()
 
     permitted_recipients = util.permitted_recipients(asf_uid)
@@ -476,7 +479,7 @@ async def vote_start() -> tuple[Mapping, int]:
         raise exceptions.Forbidden("Invalid mailing list choice")
 
     async with db.session() as data:
-        release_name = sql.release_name(req.project_name, req.version)
+        release_name = sql.release_name(req.project, req.version)
         release = await data.release(name=release_name, _project=True, 
_committee=True).demand(exceptions.NotFound())
 
         if not (user.is_committee_member(release.committee, asf_uid) or 
user.is_admin(asf_uid)):
@@ -503,7 +506,7 @@ async def vote_start() -> tuple[Mapping, int]:
                 subject=req.subject,
                 body=req.body,
             ).model_dump(),
-            project_name=req.project_name,
+            project_name=req.project,
             version_name=req.version,
         )
         data.add(task)
@@ -514,19 +517,18 @@ async def vote_start() -> tuple[Mapping, int]:
 @api.BLUEPRINT.route("/upload", methods=["POST"])
 @jwtoken.require
 @quart_schema.security_scheme([{"BearerAuth": []}])
+@quart_schema.validate_request(models.api.ProjectVersionRelpathContent)
 @quart_schema.validate_response(sql.Revision, 201)
-async def upload() -> tuple[Mapping, int]:
-    payload = await _payload_get()
-    req = models.api.FileUploadRequest.model_validate(payload)
+async def upload(data: models.api.ProjectVersionRelpathContent) -> 
tuple[Mapping, int]:
     asf_uid = _jwt_asf_uid()
 
-    async with db.session() as data:
-        project = await data.project(name=req.project_name, 
_committee=True).demand(exceptions.NotFound())
+    async with db.session() as db_data:
+        project = await db_data.project(name=data.project, 
_committee=True).demand(exceptions.NotFound())
         # TODO: user.is_participant(project, asf_uid)
         if not (user.is_committee_member(project.committee, asf_uid) or 
user.is_admin(asf_uid)):
             raise exceptions.Forbidden("You do not have permission to upload 
to this project")
 
-    revision = await _upload_process_file(req, asf_uid)
+    revision = await _upload_process_file(data, asf_uid)
     return revision.model_dump(), 201
 
 
@@ -556,24 +558,17 @@ def _pagination_args_validate(query_args: 
models.api.Pagination) -> None:
         raise exceptions.BadRequest("Maximum limit of 1000 exceeded")
 
 
-async def _payload_get() -> dict:
-    payload = await quart.request.get_json(force=True, silent=False)
-    if not isinstance(payload, dict):
-        raise exceptions.BadRequest("Invalid JSON")
-    return payload
-
-
 def _simple_check(*args: str | None) -> None:
     for arg in args:
         if arg == "None":
             raise exceptions.BadRequest("Argument cannot be the string 'None'")
 
 
-async def _upload_process_file(req: models.api.FileUploadRequest, asf_uid: 
str) -> sql.Revision:
-    file_bytes = base64.b64decode(req.content, validate=True)
-    file_path = req.rel_path.lstrip("/")
+async def _upload_process_file(args: models.api.ProjectVersionRelpathContent, 
asf_uid: str) -> sql.Revision:
+    file_bytes = base64.b64decode(args.content, validate=True)
+    file_path = args.relpath.lstrip("/")
     description = f"Upload via API: {file_path}"
-    async with revision.create_and_manage(req.project_name, req.version, 
asf_uid, description=description) as creating:
+    async with revision.create_and_manage(args.project, args.version, asf_uid, 
description=description) as creating:
         target_path = pathlib.Path(creating.interim_path) / file_path
         await aiofiles.os.makedirs(target_path.parent, exist_ok=True)
         async with aiofiles.open(target_path, "wb") as f:
@@ -581,5 +576,5 @@ async def _upload_process_file(req: 
models.api.FileUploadRequest, asf_uid: str)
     if creating.new is None:
         raise exceptions.InternalServerError("Failed to create revision")
     async with db.session() as data:
-        release_name = sql.release_name(req.project_name, req.version)
+        release_name = sql.release_name(args.project, args.version)
         return await data.revision(release_name=release_name, 
number=creating.new.number).demand(exceptions.NotFound())
diff --git a/atr/models/api.py b/atr/models/api.py
index f88f552..bb668ac 100644
--- a/atr/models/api.py
+++ b/atr/models/api.py
@@ -38,34 +38,29 @@ class Task(Pagination):
     status: str | None = None
 
 
-class DraftDeleteRequest(schema.Strict):
-    project_name: str
-    version: str
+class AsfuidPat(schema.Strict):
+    asfuid: str
+    pat: str
 
 
-class FileUploadRequest(schema.Strict):
-    project_name: str
+class ProjectVersion(schema.Strict):
+    project: str
     version: str
-    rel_path: str
-    content: str
 
 
-class PATJWTRequest(schema.Strict):
-    asfuid: str
-    pat: str
-
-
-class ReleaseCreateRequest(schema.Strict):
-    project_name: str
+class ProjectVersionRelpathContent(schema.Strict):
+    project: str
     version: str
+    relpath: str
+    content: str
 
 
 class ResultCount(schema.Strict):
     count: int
 
 
-class VoteStartRequest(schema.Strict):
-    project_name: str
+class VoteStart(schema.Strict):
+    project: str
     version: str
     revision: str
     email_to: str
diff --git a/atr/models/results.py b/atr/models/results.py
index 735b4e6..5f6c6f6 100644
--- a/atr/models/results.py
+++ b/atr/models/results.py
@@ -21,8 +21,6 @@ from pydantic import TypeAdapter
 
 from . import schema
 
-# TODO: If we put this in atr.tasks.results, we get a circular import error
-
 
 class HashingCheck(schema.Strict):
     """Result of the task to check the hash of a file."""


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to