This is an automated email from the ASF dual-hosted git repository.
sbp pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tooling-trusted-releases.git
The following commit(s) were added to refs/heads/main by this push:
new 709c232 Resolve discriminated unions when extracting fields
709c232 is described below
commit 709c23259f88521254f1c47278c24fb2a1fef977
Author: Sean B. Palmer <[email protected]>
AuthorDate: Mon Nov 10 16:26:57 2025 +0000
Resolve discriminated unions when extracting fields
---
atr/form.py | 54 +++++++++++++++++++++++++++++++++++++++++++++-----
atr/static/css/atr.css | 4 ++++
2 files changed, 53 insertions(+), 5 deletions(-)
diff --git a/atr/form.py b/atr/form.py
index 2652858..7ca9d13 100644
--- a/atr/form.py
+++ b/atr/form.py
@@ -20,7 +20,7 @@ from __future__ import annotations
import enum
import json
import types
-from typing import TYPE_CHECKING, Annotated, Any, Final, Literal, get_args,
get_origin
+from typing import TYPE_CHECKING, Annotated, Any, Final, Literal,
TypeAliasType, get_args, get_origin
import htpy
import markupsafe
@@ -68,11 +68,18 @@ class Widget(enum.Enum):
def flash_error_data(
- form_cls: type[Form], errors: list[pydantic_core.ErrorDetails], form_data:
dict[str, Any]
+ form_cls: type[Form] | TypeAliasType, errors:
list[pydantic_core.ErrorDetails], form_data: dict[str, Any]
) -> dict[str, Any]:
flash_data = {}
error_field_names = set()
+ # It is not valid Python syntax to use type[Form]() in a match branch
+ if isinstance(form_cls, TypeAliasType):
+ discriminator_value = _discriminator_from_errors(errors)
+ concrete_cls = _get_concrete_cls(form_cls, discriminator_value)
+ else:
+ concrete_cls = form_cls
+
for i, error in enumerate(errors):
loc = error["loc"]
kind = error["type"]
@@ -80,7 +87,7 @@ def flash_error_data(
msg = msg.replace(": An email address", " because an email address")
msg = msg.replace("Value error, ", "")
original = error["input"]
- field_name, field_label = name_and_label(form_cls, i, loc)
+ field_name, field_label = name_and_label(concrete_cls, i, loc)
flash_data[field_name] = {
"label": field_label,
"original": json_suitable(original),
@@ -98,7 +105,7 @@ def flash_error_data(
def flash_error_summary(errors: list[pydantic_core.ErrorDetails], flash_data:
dict[str, Any]) -> markupsafe.Markup:
- div = htm.Block()
+ div = htm.Block(htm.div, classes=".atr-initial")
plural = len(errors) > 1
div.text(f"Please fix the following issue{'s' if plural else ''}:")
with div.block(htm.ul, classes=".mt-2.mb-0") as ul:
@@ -144,8 +151,9 @@ def name_and_label(form_cls: type[Form], i: int, loc:
tuple[str | int, ...]) ->
else:
field_label = field_name.replace("_", " ").title()
return field_name, field_label
+ # Might be a model validation error
field_name = f".{i}"
- field_label = "?"
+ field_label = "*"
return field_name, field_label
@@ -180,6 +188,32 @@ async def quart_request() -> dict[str, Any]:
return combined_data
+def _discriminator_from_errors(errors: list[pydantic_core.ErrorDetails]) ->
str:
+ discriminator_value = None
+ for error in errors:
+ loc = error["loc"]
+ if loc and isinstance(loc[0], str):
+ discriminator_value = loc[0]
+ error["loc"] = loc[1:]
+ if discriminator_value is None:
+ raise ValueError("Discriminator not found")
+ return discriminator_value
+
+
+def _get_concrete_cls(form_cls: TypeAliasType, discriminator_value: str) ->
type[Form]:
+ alias_value = form_cls.__value__
+ while get_origin(alias_value) is Annotated:
+ alias_value = get_args(alias_value)[0]
+ members = get_args(alias_value)
+ if not members:
+ raise ValueError(f"No members found for union type: {alias_value}")
+ for member in members:
+ field = member.model_fields.get(DISCRIMINATOR_NAME)
+ if field and (field.default == discriminator_value):
+ return member
+ raise ValueError(f"Discriminator value {discriminator_value} not found in
union type: {alias_value}")
+
+
def _get_flash_error_data() -> dict[str, Any]:
flashed_error_messages =
quart.get_flashed_messages(category_filter=["form-error-data"])
if flashed_error_messages:
@@ -260,6 +294,11 @@ async def render(
return htm.form(form_classes, action=action, method="post",
enctype="multipart/form-data")[form_children]
+async def render_block(block: htm.Block, *args, **kwargs) -> None:
+ rendered = await render(*args, **kwargs)
+ block.append(rendered)
+
+
def session(info: pydantic.ValidationInfo) -> web.Committer | None:
ctx: dict[str, Any] = info.context or {}
return ctx.get("session")
@@ -325,6 +364,8 @@ Bool = Annotated[
Email = pydantic.EmailStr
+URL = pydantic.HttpUrl
+
File = Annotated[
datastructures.FileStorage,
functional_validators.BeforeValidator(to_filestorage),
@@ -599,6 +640,9 @@ def _get_widget_type(field_info: pydantic.fields.FieldInfo)
-> Widget: # noqa:
if annotation is pydantic.EmailStr:
return Widget.EMAIL
+ if annotation is pydantic.HttpUrl:
+ return Widget.URL
+
if annotation in (int, float):
return Widget.NUMBER
diff --git a/atr/static/css/atr.css b/atr/static/css/atr.css
index 4561a0a..e962a4f 100644
--- a/atr/static/css/atr.css
+++ b/atr/static/css/atr.css
@@ -515,6 +515,10 @@ aside.sidebar nav a:hover {
display: none;
}
+.atr-initial {
+ display: initial;
+}
+
.atr-pre-wrap {
white-space: pre-wrap;
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]