This is an automated email from the ASF dual-hosted git repository.
sbp pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tooling-trusted-release.git
The following commit(s) were added to refs/heads/main by this push:
new 9881905 Use Quart-Schema for API POST input validation
9881905 is described below
commit 9881905bcf5899bc7945ad87098397cc07ce5e14
Author: Sean B. Palmer <[email protected]>
AuthorDate: Mon Jul 14 19:14:22 2025 +0100
Use Quart-Schema for API POST input validation
---
atr/blueprints/api/__init__.py | 17 ++++++++-
atr/blueprints/api/api.py | 85 ++++++++++++++++++++----------------------
atr/models/api.py | 27 ++++++--------
atr/models/results.py | 2 -
4 files changed, 67 insertions(+), 64 deletions(-)
diff --git a/atr/blueprints/api/__init__.py b/atr/blueprints/api/__init__.py
index c08aa9d..fefb170 100644
--- a/atr/blueprints/api/__init__.py
+++ b/atr/blueprints/api/__init__.py
@@ -16,10 +16,13 @@
# under the License.
import sys
+from typing import Any
import asfquart.base as base
+import pydantic
import quart
import quart.blueprints as blueprints
+import quart_schema
import werkzeug.exceptions as exceptions
BLUEPRINT = quart.Blueprint("api_blueprint", __name__, url_prefix="/api")
@@ -52,7 +55,17 @@ async def _handle_not_found(err: exceptions.NotFound) ->
tuple[quart.Response, i
return _json_error(err.description or err.name, 404)
-def _json_error(message: str, status_code: int | None) ->
tuple[quart.Response, int]:
+@BLUEPRINT.errorhandler(quart_schema.RequestSchemaValidationError)
+async def _handle_request_validation(err:
quart_schema.RequestSchemaValidationError) -> tuple[quart.Response, int]:
+ if not isinstance(err.validation_error, pydantic.ValidationError):
+ raise err.validation_error
+ verr: pydantic.ValidationError = err.validation_error
+ return _json_error("Input validation failed", 400, {"validation_details":
verr.errors()})
+
+
+def _json_error(
+ message: str, status_code: int | None, extra: dict[str, Any] | None = None
+) -> tuple[quart.Response, int]:
payload = {"error": message}
show_traceback = False
if show_traceback:
@@ -60,6 +73,8 @@ def _json_error(message: str, status_code: int | None) ->
tuple[quart.Response,
traceback_str = "".join(traceback.format_exception(*sys.exc_info()))
payload["traceback"] = traceback_str
+ if extra is not None:
+ payload.update(extra)
return quart.jsonify(payload), status_code or 500
diff --git a/atr/blueprints/api/api.py b/atr/blueprints/api/api.py
index 80c59ca..f50661f 100644
--- a/atr/blueprints/api/api.py
+++ b/atr/blueprints/api/api.py
@@ -99,7 +99,7 @@ async def checks_ongoing_project_version(
) -> tuple[Mapping[str, Any], int]:
"""Return a count of all unfinished check results for a given release."""
_simple_check(project, version, revision)
- ongoing_tasks_count, latest_revision = await
interaction.tasks_ongoing_revision(project, version, revision)
+ ongoing_tasks_count, _latest_revision = await
interaction.tasks_ongoing_revision(project, version, revision)
# TODO: Is there a way to return just an int?
# The ResponseReturnValue type in quart does not allow int
# And if we use quart.jsonify, we must return quart.Response which
quart_schema tries to validate
@@ -158,15 +158,14 @@ async def committees_name_projects(name: str) ->
tuple[list[Mapping], int]:
@api.BLUEPRINT.route("/draft/delete", methods=["POST"])
@jwtoken.require
@quart_schema.security_scheme([{"BearerAuth": []}])
+@quart_schema.validate_request(models.api.ProjectVersion)
@quart_schema.validate_response(dict[str, str], 200)
-async def draft_delete_project_version() -> tuple[dict[str, str], int]:
- payload = await _payload_get()
- req = models.api.DraftDeleteRequest.model_validate(payload)
+async def draft_delete_project_version(data: models.api.ProjectVersion) ->
tuple[dict[str, str], int]:
asf_uid = _jwt_asf_uid()
- async with db.session() as data:
- release_name = sql.release_name(req.project_name, req.version)
- release = await data.release(
+ async with db.session() as db_data:
+ release_name = sql.release_name(data.project, data.version)
+ release = await db_data.release(
name=release_name, phase=sql.ReleasePhase.RELEASE_CANDIDATE_DRAFT,
_committee=True
).demand(exceptions.NotFound())
if not (user.is_committee_member(release.project.committee, asf_uid)
or user.is_admin(asf_uid)):
@@ -180,7 +179,7 @@ async def draft_delete_project_version() -> tuple[dict[str,
str], int]:
await interaction.release_delete(
release_name, phase=sql.ReleasePhase.RELEASE_CANDIDATE_DRAFT,
include_downloads=False
)
- await data.commit()
+ await db_data.commit()
return {"deleted": release_name}, 200
@@ -207,22 +206,20 @@ async def list_project_version(
@api.BLUEPRINT.route("/jwt", methods=["POST"])
-async def pat_jwt_post() -> quart.Response:
+@quart_schema.validate_request(models.api.AsfuidPat)
+async def pat_jwt_post(data: models.api.AsfuidPat) -> quart.Response:
"""Generate a JWT from a valid PAT."""
# Expects {"asfuid": "uid", "pat": "pat-token"}
# Returns {"asfuid": "uid", "jwt": "jwt-token"}
-
- payload = await _payload_get()
- pat_request = models.api.PATJWTRequest.model_validate(payload)
- token_hash = hashlib.sha3_256(pat_request.pat.encode()).hexdigest()
- pat_rec = await _get_pat(pat_request.asfuid, token_hash)
+ token_hash = hashlib.sha3_256(data.pat.encode()).hexdigest()
+ pat_rec = await _get_pat(data.asfuid, token_hash)
now = datetime.datetime.now(datetime.UTC)
if (pat_rec is None) or (pat_rec.expires < now):
return quart.Response("Invalid PAT", status=401)
- jwt_token = jwtoken.issue(pat_request.asfuid)
- return quart.jsonify({"asfuid": pat_request.asfuid, "jwt": jwt_token})
+ jwt_token = jwtoken.issue(data.asfuid)
+ return quart.jsonify({"asfuid": data.asfuid, "jwt": jwt_token})
@api.BLUEPRINT.route("/keys")
@@ -321,18 +318,16 @@ async def releases(query_args: models.api.Releases) ->
quart.Response:
@api.BLUEPRINT.route("/releases/create", methods=["POST"])
@jwtoken.require
@quart_schema.security_scheme([{"BearerAuth": []}])
+@quart_schema.validate_request(models.api.ProjectVersion)
@quart_schema.validate_response(sql.Release, 201)
-async def releases_create() -> tuple[Mapping, int]:
+async def releases_create(data: models.api.ProjectVersion) -> tuple[Mapping,
int]:
"""Create a new release draft for a project via POSTed JSON."""
-
- payload = await _payload_get()
- request_data = models.api.ReleaseCreateRequest.model_validate(payload)
asf_uid = _jwt_asf_uid()
try:
release, _project = await start.create_release_draft(
- project_name=request_data.project_name,
- version=request_data.version,
+ project_name=data.project,
+ version=data.version,
asf_uid=asf_uid,
)
except routes.FlashError as exc:
@@ -462,13 +457,21 @@ async def tasks(query_args: models.api.Task) ->
quart.Response:
return quart.jsonify(result)
+@api.BLUEPRINT.route("/vote/resolve", methods=["POST"])
+@jwtoken.require
+@quart_schema.security_scheme([{"BearerAuth": []}])
+@quart_schema.validate_request(models.api.VoteStart)
+@quart_schema.validate_response(sql.Task, 201)
+async def vote_resolve(req: models.api.VoteStart) -> tuple[Mapping, int]:
+ return {}, 200
+
+
@api.BLUEPRINT.route("/vote/start", methods=["POST"])
@jwtoken.require
@quart_schema.security_scheme([{"BearerAuth": []}])
+@quart_schema.validate_request(models.api.VoteStart)
@quart_schema.validate_response(sql.Task, 201)
-async def vote_start() -> tuple[Mapping, int]:
- payload = await _payload_get()
- req = models.api.VoteStartRequest.model_validate(payload)
+async def vote_start(req: models.api.VoteStart) -> tuple[Mapping, int]:
asf_uid = _jwt_asf_uid()
permitted_recipients = util.permitted_recipients(asf_uid)
@@ -476,7 +479,7 @@ async def vote_start() -> tuple[Mapping, int]:
raise exceptions.Forbidden("Invalid mailing list choice")
async with db.session() as data:
- release_name = sql.release_name(req.project_name, req.version)
+ release_name = sql.release_name(req.project, req.version)
release = await data.release(name=release_name, _project=True,
_committee=True).demand(exceptions.NotFound())
if not (user.is_committee_member(release.committee, asf_uid) or
user.is_admin(asf_uid)):
@@ -503,7 +506,7 @@ async def vote_start() -> tuple[Mapping, int]:
subject=req.subject,
body=req.body,
).model_dump(),
- project_name=req.project_name,
+ project_name=req.project,
version_name=req.version,
)
data.add(task)
@@ -514,19 +517,18 @@ async def vote_start() -> tuple[Mapping, int]:
@api.BLUEPRINT.route("/upload", methods=["POST"])
@jwtoken.require
@quart_schema.security_scheme([{"BearerAuth": []}])
+@quart_schema.validate_request(models.api.ProjectVersionRelpathContent)
@quart_schema.validate_response(sql.Revision, 201)
-async def upload() -> tuple[Mapping, int]:
- payload = await _payload_get()
- req = models.api.FileUploadRequest.model_validate(payload)
+async def upload(data: models.api.ProjectVersionRelpathContent) ->
tuple[Mapping, int]:
asf_uid = _jwt_asf_uid()
- async with db.session() as data:
- project = await data.project(name=req.project_name,
_committee=True).demand(exceptions.NotFound())
+ async with db.session() as db_data:
+ project = await db_data.project(name=data.project,
_committee=True).demand(exceptions.NotFound())
# TODO: user.is_participant(project, asf_uid)
if not (user.is_committee_member(project.committee, asf_uid) or
user.is_admin(asf_uid)):
raise exceptions.Forbidden("You do not have permission to upload
to this project")
- revision = await _upload_process_file(req, asf_uid)
+ revision = await _upload_process_file(data, asf_uid)
return revision.model_dump(), 201
@@ -556,24 +558,17 @@ def _pagination_args_validate(query_args:
models.api.Pagination) -> None:
raise exceptions.BadRequest("Maximum limit of 1000 exceeded")
-async def _payload_get() -> dict:
- payload = await quart.request.get_json(force=True, silent=False)
- if not isinstance(payload, dict):
- raise exceptions.BadRequest("Invalid JSON")
- return payload
-
-
def _simple_check(*args: str | None) -> None:
for arg in args:
if arg == "None":
raise exceptions.BadRequest("Argument cannot be the string 'None'")
-async def _upload_process_file(req: models.api.FileUploadRequest, asf_uid:
str) -> sql.Revision:
- file_bytes = base64.b64decode(req.content, validate=True)
- file_path = req.rel_path.lstrip("/")
+async def _upload_process_file(args: models.api.ProjectVersionRelpathContent,
asf_uid: str) -> sql.Revision:
+ file_bytes = base64.b64decode(args.content, validate=True)
+ file_path = args.relpath.lstrip("/")
description = f"Upload via API: {file_path}"
- async with revision.create_and_manage(req.project_name, req.version,
asf_uid, description=description) as creating:
+ async with revision.create_and_manage(args.project, args.version, asf_uid,
description=description) as creating:
target_path = pathlib.Path(creating.interim_path) / file_path
await aiofiles.os.makedirs(target_path.parent, exist_ok=True)
async with aiofiles.open(target_path, "wb") as f:
@@ -581,5 +576,5 @@ async def _upload_process_file(req:
models.api.FileUploadRequest, asf_uid: str)
if creating.new is None:
raise exceptions.InternalServerError("Failed to create revision")
async with db.session() as data:
- release_name = sql.release_name(req.project_name, req.version)
+ release_name = sql.release_name(args.project, args.version)
return await data.revision(release_name=release_name,
number=creating.new.number).demand(exceptions.NotFound())
diff --git a/atr/models/api.py b/atr/models/api.py
index f88f552..bb668ac 100644
--- a/atr/models/api.py
+++ b/atr/models/api.py
@@ -38,34 +38,29 @@ class Task(Pagination):
status: str | None = None
-class DraftDeleteRequest(schema.Strict):
- project_name: str
- version: str
+class AsfuidPat(schema.Strict):
+ asfuid: str
+ pat: str
-class FileUploadRequest(schema.Strict):
- project_name: str
+class ProjectVersion(schema.Strict):
+ project: str
version: str
- rel_path: str
- content: str
-class PATJWTRequest(schema.Strict):
- asfuid: str
- pat: str
-
-
-class ReleaseCreateRequest(schema.Strict):
- project_name: str
+class ProjectVersionRelpathContent(schema.Strict):
+ project: str
version: str
+ relpath: str
+ content: str
class ResultCount(schema.Strict):
count: int
-class VoteStartRequest(schema.Strict):
- project_name: str
+class VoteStart(schema.Strict):
+ project: str
version: str
revision: str
email_to: str
diff --git a/atr/models/results.py b/atr/models/results.py
index 735b4e6..5f6c6f6 100644
--- a/atr/models/results.py
+++ b/atr/models/results.py
@@ -21,8 +21,6 @@ from pydantic import TypeAdapter
from . import schema
-# TODO: If we put this in atr.tasks.results, we get a circular import error
-
class HashingCheck(schema.Strict):
"""Result of the task to check the hash of a file."""
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]