This is an automated email from the ASF dual-hosted git repository. sbp pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/tooling-trusted-releases.git
commit 7a8534d11bfb2f0e716013595908c591b16e9fdf Author: Sean B. Palmer <[email protected]> AuthorDate: Sun Nov 9 11:29:36 2025 +0000 Allow custom field errors after Pydantic validation --- atr/blueprints/post.py | 63 ++++++++++++++++++++++++++++------------------- atr/form.py | 17 ++++++++++++- atr/htm.py | 7 ++++-- atr/post/test.py | 5 ++++ atr/principal.py | 2 +- atr/web.py | 67 ++++++++++++++++++++++++++++++++++++++++---------- 6 files changed, 119 insertions(+), 42 deletions(-) diff --git a/atr/blueprints/post.py b/atr/blueprints/post.py index 1e1286d..5498b3e 100644 --- a/atr/blueprints/post.py +++ b/atr/blueprints/post.py @@ -24,12 +24,10 @@ from typing import Any import asfquart.auth as auth import asfquart.base as base import asfquart.session -import markupsafe import pydantic import quart import atr.form -import atr.htm as htm import atr.log as log import atr.web as web @@ -108,21 +106,34 @@ def empty() -> Callable[[Callable[..., Awaitable[Any]]], Callable[..., Awaitable # pass def decorator(func: Callable[..., Awaitable[Any]]) -> Callable[..., Awaitable[Any]]: async def wrapper(session: web.Committer | None, *args: Any, **kwargs: Any) -> Any: + match session: + case web.Committer() as committer: + form_data = await committer.form_data() + case None: + form_data = await atr.form.quart_request() try: - form_data = await atr.form.quart_request() - context = {"session": session} - atr.form.validate(atr.form.Empty, form_data, context) + context = { + "args": args, + "kwargs": kwargs, + "session": session, + } + match session: + case web.Committer() as committer: + await committer.form_validate(atr.form.Empty, context=context) + case None: + atr.form.validate(atr.form.Empty, form_data, context=context) return await func(session, *args, **kwargs) except pydantic.ValidationError: - # This presumably should not happen, because the CSRF token checker reaches it first - msg = "Sorry, your form session expired for security reasons. Please try again." + # This could happen if the form was tampered with + # It should not happen if the CSRF token is invalid + msg = "Sorry, there was an empty form validation error. Please try again." await quart.flash(msg, "error") return quart.redirect(quart.request.path) - wrapper.__name__ = func.__name__ - wrapper.__module__ = func.__module__ - wrapper.__doc__ = func.__doc__ wrapper.__annotations__ = func.__annotations__.copy() + wrapper.__doc__ = func.__doc__ + wrapper.__module__ = func.__module__ + wrapper.__name__ = func.__name__ return wrapper return decorator @@ -133,31 +144,33 @@ def form( ) -> Callable[[Callable[..., Awaitable[Any]]], Callable[..., Awaitable[Any]]]: def decorator(func: Callable[..., Awaitable[Any]]) -> Callable[..., Awaitable[Any]]: async def wrapper(session: web.Committer | None, *args: Any, **kwargs: Any) -> Any: - form_data = await atr.form.quart_request() + match session: + case web.Committer() as committer: + form_data = await committer.form_data() + case None: + form_data = await atr.form.quart_request() try: - context = {"session": session} - validated_form = atr.form.validate(form_cls, form_data, context) + context = { + "args": args, + "kwargs": kwargs, + "session": session, + } + match session: + case web.Committer() as committer: + validated_form = await committer.form_validate(form_cls, context) + case None: + validated_form = atr.form.validate(form_cls, form_data, context=context) return await func(session, validated_form, *args, **kwargs) except pydantic.ValidationError as e: errors = e.errors() if len(errors) == 0: raise RuntimeError("Validation failed, but no errors were reported") flash_data = atr.form.flash_error_data(form_cls, errors, form_data) - - plural = len(errors) > 1 - summary = f"Please fix the following issue{'s' if plural else ''}:" - ul = htm.Block(htm.ul, classes=".mt-2.mb-0") - for i, flash_datum in enumerate(flash_data.values()): - if i > 9: - ul.li["And more, not shown here..."] - break - if "msg" in flash_datum: - ul.li[htm.strong[flash_datum["label"]], ": ", flash_datum["msg"]] - summary = f"{summary}\n{ul.collect()}" + summary = atr.form.flash_error_summary(errors, flash_data) # TODO: Centralise all uses of markupsafe.Markup # log.info(f"Flash data: {flash_data}") - await quart.flash(markupsafe.Markup(summary), category="error") + await quart.flash(summary, category="error") await quart.flash(json.dumps(flash_data), category="form-error-data") return quart.redirect(quart.request.path) diff --git a/atr/form.py b/atr/form.py index 06744e4..81ad0d7 100644 --- a/atr/form.py +++ b/atr/form.py @@ -23,6 +23,7 @@ import types from typing import TYPE_CHECKING, Annotated, Any, Final, Literal, get_args, get_origin import htpy +import markupsafe import pydantic import pydantic.functional_validators as functional_validators import quart @@ -35,7 +36,6 @@ import atr.models.schema as schema if TYPE_CHECKING: from collections.abc import Iterator - import markupsafe import pydantic_core import atr.web as web @@ -96,6 +96,21 @@ def flash_error_data( return flash_data +def flash_error_summary(errors: list[pydantic_core.ErrorDetails], flash_data: dict[str, Any]) -> markupsafe.Markup: + div = htm.Block() + plural = len(errors) > 1 + div.text(f"Please fix the following issue{'s' if plural else ''}:") + with div.block(htm.ul, classes=".mt-2.mb-0") as ul: + for i, flash_datum in enumerate(flash_data.values()): + if i > 9: + ul.li["And more, not shown here..."] + break + if "msg" in flash_datum: + ul.li[htm.strong[flash_datum["label"]], ": ", flash_datum["msg"]] + summary = div.collect() + return markupsafe.Markup(summary) + + def json_suitable(field_value: Any) -> Any: if isinstance(field_value, datastructures.FileStorage): return field_value.filename diff --git a/atr/htm.py b/atr/htm.py index 720febe..530800d 100644 --- a/atr/htm.py +++ b/atr/htm.py @@ -129,9 +129,12 @@ class Block: @contextlib.contextmanager def block( - self, element: Element | None = None, separator: Element | VoidElement | str | None = None + self, + element: Element | None = None, + classes: str | None = None, + separator: Element | VoidElement | str | None = None, ) -> Generator[Block, Any, Any]: - block = Block(element) + block = Block(element, classes=classes) yield block self.append(block.collect(separator=separator)) diff --git a/atr/post/test.py b/atr/post/test.py index 3e3ef1e..d5a4a01 100644 --- a/atr/post/test.py +++ b/atr/post/test.py @@ -55,6 +55,11 @@ async def test_multiple(session: web.Committer | None, form: shared.test.Multipl async def test_single(session: web.Committer | None, form: shared.test.SingleForm) -> web.WerkzeugResponse: file_names = [f.filename for f in form.files] if form.files else [] compatibility_names = [f.value for f in form.compatibility] if form.compatibility else [] + if (form.message == "Forbidden message!") and (session is not None): + return await session.form_error( + "message", + "You are not permitted to submit the forbidden message", + ) msg = ( f"Single form received:" f" name={form.name}," diff --git a/atr/principal.py b/atr/principal.py index 66c7d99..911dce1 100644 --- a/atr/principal.py +++ b/atr/principal.py @@ -356,7 +356,7 @@ class Authorisation(AsyncObject): case ArgumentNoneType() | web.Committer(): match asf_uid: case web.Committer(): - asfquart_session = asf_uid._session + asfquart_session = asf_uid.session case _: asfquart_session = await asfquart.session.read() # asfquart_session = await session.read() diff --git a/atr/web.py b/atr/web.py index a843fde..da2c0b5 100644 --- a/atr/web.py +++ b/atr/web.py @@ -17,16 +17,19 @@ from __future__ import annotations +import json import urllib.parse from typing import TYPE_CHECKING, Any, Protocol, TypeVar import asfquart.base as base import asfquart.session as session +import pydantic_core import quart import werkzeug.datastructures.headers import atr.config as config import atr.db as db +import atr.form as form import atr.htm as htm import atr.models.sql as sql import atr.user as user @@ -35,9 +38,9 @@ import atr.util as util if TYPE_CHECKING: from collections.abc import Awaitable, Sequence + import pydantic import werkzeug.wrappers.response as response - R = TypeVar("R", covariant=True) type WerkzeugResponse = response.Response @@ -58,19 +61,25 @@ class Committer: """Session with extra information about committers.""" def __init__(self, web_session: session.ClientSession) -> None: - self._projects: list[sql.Project] | None = None - self._session = web_session + self.__form_cls: type[form.Form] | None = None + self.__form_data: dict[str, Any] | None = None + self.__projects: list[sql.Project] | None = None + self.session = web_session @property def asf_uid(self) -> str: - if self._session.uid is None: + if self.session.uid is None: raise base.ASFQuartException("Not authenticated", errorcode=401) - return self._session.uid + return self.session.uid def __getattr__(self, name: str) -> Any: # TODO: Not type safe, should subclass properly if possible # For example, we can access session.no_such_attr and the type checkers won't notice - return getattr(self._session, name) + return getattr(self.session, name) + + @property + def app_host(self) -> str: + return config.get().APP_HOST async def check_access(self, project_name: str) -> None: if not any((p.name == project_name) for p in (await self.user_projects)): @@ -93,9 +102,37 @@ class Committer: return raise base.ASFQuartException("You do not have access to this committee", errorcode=403) - @property - def app_host(self) -> str: - return config.get().APP_HOST + async def form_data(self) -> dict[str, Any]: + if self.__form_data is None: + self.__form_data = await form.quart_request() + # Avoid mutations from writing back to our copy + return self.__form_data.copy() + + async def form_error(self, field_name: str, error_msg: str) -> WerkzeugResponse: + if self.__form_cls is None: + raise ValueError("Form class not set") + if self.__form_data is None: + raise ValueError("Form data not set") + errors = [ + pydantic_core.ErrorDetails( + loc=(field_name,), + msg=error_msg, + input=self.__form_data[field_name], + type="atr_error", + ) + ] + flash_data = form.flash_error_data(self.__form_cls, errors, self.__form_data) + summary = form.flash_error_summary(errors, flash_data) + + await quart.flash(summary, category="error") + await quart.flash(json.dumps(flash_data), category="form-error-data") + return quart.redirect(quart.request.path) + + async def form_validate(self, form_cls: type[form.Form], context: dict[str, Any]) -> pydantic.BaseModel: + self.__form_cls = form_cls + if self.__form_data is None: + self.__form_data = await form.quart_request() + return form.validate(form_cls, self.__form_data.copy(), context=context) @property def host(self) -> str: @@ -165,7 +202,7 @@ class Committer: @property async def user_candidate_drafts(self) -> list[sql.Release]: - return await user.candidate_drafts(self.uid, user_projects=self._projects) + return await user.candidate_drafts(self.uid, user_projects=self.__projects) # @property # async def user_committees(self) -> list[models.Committee]: @@ -173,9 +210,9 @@ class Committer: @property async def user_projects(self) -> list[sql.Project]: - if self._projects is None: - self._projects = await user.projects(self.uid) - return self._projects[:] + if self.__projects is None: + self.__projects = await user.projects(self.uid) + return self.__projects[:] class ElementResponse(quart.Response): @@ -241,6 +278,10 @@ class ZipResponse(quart.Response): super().__init__(response, status=status, headers=raw_headers, mimetype="application/zip") +async def form_error(error: str) -> None: + pass + + async def redirect[R]( route: RouteFunction[R], success: str | None = None, error: str | None = None, **kwargs: Any ) -> WerkzeugResponse: --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
