This is an automated email from the ASF dual-hosted git repository.
sbp pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tooling-trusted-releases.git
The following commit(s) were added to refs/heads/main by this push:
new 5886533 Preserve submitted values in forms with errors
5886533 is described below
commit 58865339f08e7ddc44072a5f9cdee470cc86c8ff
Author: Sean B. Palmer <[email protected]>
AuthorDate: Fri Nov 7 16:20:07 2025 +0000
Preserve submitted values in forms with errors
---
atr/blueprints/post.py | 23 ++-------
atr/form.py | 138 +++++++++++++++++++++++++++++++++++++------------
2 files changed, 109 insertions(+), 52 deletions(-)
diff --git a/atr/blueprints/post.py b/atr/blueprints/post.py
index acbce1b..bee19da 100644
--- a/atr/blueprints/post.py
+++ b/atr/blueprints/post.py
@@ -132,8 +132,8 @@ def form(
) -> Callable[[Callable[..., Awaitable[Any]]], Callable[..., Awaitable[Any]]]:
def decorator(func: Callable[..., Awaitable[Any]]) -> Callable[..., Awaitable[Any]]:
async def wrapper(session: web.Committer | None, *args: Any, **kwargs: Any) -> Any:
+ form_data = await atr.form.quart_request()
try:
- form_data = await atr.form.quart_request()
context = {"session": session}
validated_form = atr.form.validate(form_cls, form_data, context)
return await func(session, validated_form, *args, **kwargs)
@@ -141,22 +141,7 @@ def form(
errors = e.errors()
if len(errors) == 0:
raise RuntimeError("Validation failed, but no errors were reported")
-
- flash_data = {}
- for i, error in enumerate(errors):
- loc = error["loc"]
- kind = error["type"]
- msg = error["msg"]
- msg = msg.replace(": An email address", " because an email address")
- msg = msg.replace("Value error, ", "")
- original = error["input"]
- field_name, field_label = atr.form.name_and_label(form_cls, i, loc)
- flash_data[field_name] = {
- "label": field_label,
- "original": original,
- "kind": kind,
- "msg": msg,
- }
+ flash_data = atr.form.flash_error_data(form_cls, errors, form_data)
plural = len(errors) > 1
summary = f"Please fix the following issue{'s' if plural else ''}:"
@@ -165,10 +150,12 @@ def form(
if i > 9:
ul.li["And more, not shown here..."]
break
- ul.li[htm.strong[flash_datum["label"]], ": ", flash_datum["msg"]]
+ if "msg" in flash_datum:
+ ul.li[htm.strong[flash_datum["label"]], ": ", flash_datum["msg"]]
summary = f"{summary}\n{ul.collect()}"
# TODO: Centralise all uses of markupsafe.Markup
+ # log.info(f"Flash data: {flash_data}")
await quart.flash(markupsafe.Markup(summary), category="error")
await quart.flash(json.dumps(flash_data), category="form-error-data")
return quart.redirect(quart.request.path)
diff --git a/atr/form.py b/atr/form.py
index 6c5c375..ab13cd6 100644
--- a/atr/form.py
+++ b/atr/form.py
@@ -36,6 +36,7 @@ if TYPE_CHECKING:
from collections.abc import Iterator
import markupsafe
+ import pydantic_core
import atr.web as web
@@ -65,6 +66,47 @@ class Widget(enum.Enum):
URL = "url"
+def flash_error_data(
+ form_cls: type[Form], errors: list[pydantic_core.ErrorDetails], form_data: dict[str, Any]
+) -> dict[str, Any]:
+ flash_data = {}
+ error_field_names = set()
+
+ for i, error in enumerate(errors):
+ loc = error["loc"]
+ kind = error["type"]
+ msg = error["msg"]
+ msg = msg.replace(": An email address", " because an email address")
+ msg = msg.replace("Value error, ", "")
+ original = error["input"]
+ field_name, field_label = name_and_label(form_cls, i, loc)
+ flash_data[field_name] = {
+ "label": field_label,
+ "original": json_suitable(original),
+ "kind": kind,
+ "msg": msg,
+ }
+ error_field_names.add(field_name)
+
+ for field_name, field_value in form_data.items():
+ if (field_name not in error_field_names) and (field_name != "csrf_token"):
+ flash_data[f"!{field_name}"] = {
+ "original": json_suitable(field_value),
+ }
+ return flash_data
+
+
+def json_suitable(field_value: Any) -> Any:
+ if isinstance(field_value, datastructures.FileStorage):
+ return field_value.filename
+ elif isinstance(field_value, list):
+ if all(isinstance(f, datastructures.FileStorage) for f in field_value):
+ return [f.filename for f in field_value]
+ else:
+ return field_value
+ return field_value
+
+
def label(description: str, *, default: Any = ..., widget: Widget | None = None) -> Any:
extra: dict[str, Any] = {"widget": widget.value} if widget else {}
return pydantic.Field(default, description=description, json_schema_extra=extra)
@@ -138,8 +180,6 @@ async def render_columns(
except (json.JSONDecodeError, IndexError):
pass
- label_classes = "col-sm-3 col-form-label text-sm-end"
-
field_rows: list[htm.Element] = []
hidden_fields: list[htm.Element | htm.VoidElement | markupsafe.Markup] = []
@@ -150,38 +190,11 @@ async def render_columns(
if field_name == "csrf_token":
continue
- if defaults:
- field_value = defaults.get(field_name)
- elif not field_info.is_required():
- # Use the Pydantic default if no user default are provided
- field_value = field_info.get_default(call_default_factory=True)
- else:
- field_value = None
- field_errors = errors.get(field_name) if errors else None
-
- if (field_name == DISCRIMINATOR_NAME) and (field_info.default is not None):
- default_value = field_info.default
- hidden_fields.append(htpy.input(type="hidden", name=DISCRIMINATOR_NAME, value=default_value))
- continue
-
- label_text = field_info.description or field_name.replace("_", " ").title()
- is_required = field_info.is_required()
-
- has_flash_error = field_name in flash_error_data
- label_classes_with_error = f"{label_classes} text-danger" if has_flash_error else label_classes
- label_elem = htpy.label(for_=field_name, class_=label_classes_with_error)[label_text]
-
- widget_elem = _render_widget(
- field_name=field_name,
- field_info=field_info,
- field_value=field_value,
- field_errors=field_errors,
- is_required=is_required,
- )
-
- row_div = htm.div(".mb-3.row")
- widget_div = htm.div(".col-sm-8")
- field_rows.append(row_div[label_elem, widget_div[widget_elem]])
+ hidden_field, row = _render_row(field_info, field_name, flash_error_data, defaults, errors)
+ if hidden_field:
+ hidden_fields.append(hidden_field)
+ if row:
+ field_rows.append(row)
form_children: list[htm.Element | htm.VoidElement | markupsafe.Markup] = hidden_fields + field_rows
@@ -539,3 +552,60 @@ def _get_widget_type(field_info: pydantic.fields.FieldInfo) -> Widget: # noqa:
return Widget.FILES
return Widget.TEXT
+
+
+def _render_row(
+ field_info: pydantic.fields.FieldInfo,
+ field_name: str,
+ flash_error_data: dict[str, Any],
+ defaults: dict[str, Any] | None,
+ errors: dict[str, list[str]] | None,
+) -> tuple[htm.VoidElement | None, htm.Element | None]:
+ widget_type = _get_widget_type(field_info)
+ has_flash_error = field_name in flash_error_data
+ has_flash_data = f"!{field_name}" in flash_error_data
+
+ if widget_type: # not in (Widget.FILE, Widget.FILES):
+ if has_flash_error:
+ field_value = flash_error_data[field_name]["original"]
+ elif has_flash_data:
+ field_value = flash_error_data[f"!{field_name}"]["original"]
+ elif defaults:
+ field_value = defaults.get(field_name)
+ elif not field_info.is_required():
+ field_value = field_info.get_default(call_default_factory=True)
+ else:
+ field_value = None
+ elif defaults:
+ field_value = defaults.get(field_name)
+ elif not field_info.is_required():
+ field_value = field_info.get_default(call_default_factory=True)
+ else:
+ field_value = None
+
+ if (widget_type in (Widget.CHECKBOXES, Widget.FILES)) and (not isinstance(field_value, list)):
+ field_value = [field_value]
+ field_errors = errors.get(field_name) if errors else None
+
+ if (field_name == DISCRIMINATOR_NAME) and (field_info.default is not None):
+ default_value = field_info.default
+ return htpy.input(type="hidden", name=DISCRIMINATOR_NAME, value=default_value), None
+
+ label_text = field_info.description or field_name.replace("_", " ").title()
+ is_required = field_info.is_required()
+
+ label_classes = "col-sm-3 col-form-label text-sm-end"
+ label_classes_with_error = f"{label_classes} text-danger" if has_flash_error else label_classes
+ label_elem = htpy.label(for_=field_name, class_=label_classes_with_error)[label_text]
+
+ widget_elem = _render_widget(
+ field_name=field_name,
+ field_info=field_info,
+ field_value=field_value,
+ field_errors=field_errors,
+ is_required=is_required,
+ )
+
+ row_div = htm.div(".mb-3.row")
+ widget_div = htm.div(".col-sm-8")
+ return None, row_div[label_elem, widget_div[widget_elem]]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]