This is an automated email from the ASF dual-hosted git repository.

sbp pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tooling-trusted-releases.git


The following commit(s) were added to refs/heads/main by this push:
     new 709c232  Resolve discriminated unions when extracting fields
709c232 is described below

commit 709c23259f88521254f1c47278c24fb2a1fef977
Author: Sean B. Palmer <[email protected]>
AuthorDate: Mon Nov 10 16:26:57 2025 +0000

    Resolve discriminated unions when extracting fields
---
 atr/form.py            | 54 +++++++++++++++++++++++++++++++++++++++++++++-----
 atr/static/css/atr.css |  4 ++++
 2 files changed, 53 insertions(+), 5 deletions(-)

diff --git a/atr/form.py b/atr/form.py
index 2652858..7ca9d13 100644
--- a/atr/form.py
+++ b/atr/form.py
@@ -20,7 +20,7 @@ from __future__ import annotations
 import enum
 import json
 import types
-from typing import TYPE_CHECKING, Annotated, Any, Final, Literal, get_args, get_origin
+from typing import TYPE_CHECKING, Annotated, Any, Final, Literal, TypeAliasType, get_args, get_origin
 
 import htpy
 import markupsafe
@@ -68,11 +68,18 @@ class Widget(enum.Enum):
 
 
 def flash_error_data(
-    form_cls: type[Form], errors: list[pydantic_core.ErrorDetails], form_data: dict[str, Any]
+    form_cls: type[Form] | TypeAliasType, errors: list[pydantic_core.ErrorDetails], form_data: dict[str, Any]
 ) -> dict[str, Any]:
     flash_data = {}
     error_field_names = set()
 
+    # It is not valid Python syntax to use type[Form]() in a match branch
+    if isinstance(form_cls, TypeAliasType):
+        discriminator_value = _discriminator_from_errors(errors)
+        concrete_cls = _get_concrete_cls(form_cls, discriminator_value)
+    else:
+        concrete_cls = form_cls
+
     for i, error in enumerate(errors):
         loc = error["loc"]
         kind = error["type"]
@@ -80,7 +87,7 @@ def flash_error_data(
         msg = msg.replace(": An email address", " because an email address")
         msg = msg.replace("Value error, ", "")
         original = error["input"]
-        field_name, field_label = name_and_label(form_cls, i, loc)
+        field_name, field_label = name_and_label(concrete_cls, i, loc)
         flash_data[field_name] = {
             "label": field_label,
             "original": json_suitable(original),
@@ -98,7 +105,7 @@ def flash_error_data(
 
 
 def flash_error_summary(errors: list[pydantic_core.ErrorDetails], flash_data: dict[str, Any]) -> markupsafe.Markup:
-    div = htm.Block()
+    div = htm.Block(htm.div, classes=".atr-initial")
     plural = len(errors) > 1
     div.text(f"Please fix the following issue{'s' if plural else ''}:")
     with div.block(htm.ul, classes=".mt-2.mb-0") as ul:
@@ -144,8 +151,9 @@ def name_and_label(form_cls: type[Form], i: int, loc: tuple[str | int, ...]) ->
             else:
                 field_label = field_name.replace("_", " ").title()
             return field_name, field_label
+    # Might be a model validation error
     field_name = f".{i}"
-    field_label = "?"
+    field_label = "*"
     return field_name, field_label
 
 
@@ -180,6 +188,32 @@ async def quart_request() -> dict[str, Any]:
     return combined_data
 
 
+def _discriminator_from_errors(errors: list[pydantic_core.ErrorDetails]) -> str:
+    discriminator_value = None
+    for error in errors:
+        loc = error["loc"]
+        if loc and isinstance(loc[0], str):
+            discriminator_value = loc[0]
+            error["loc"] = loc[1:]
+    if discriminator_value is None:
+        raise ValueError("Discriminator not found")
+    return discriminator_value
+
+
+def _get_concrete_cls(form_cls: TypeAliasType, discriminator_value: str) -> type[Form]:
+    alias_value = form_cls.__value__
+    while get_origin(alias_value) is Annotated:
+        alias_value = get_args(alias_value)[0]
+    members = get_args(alias_value)
+    if not members:
+        raise ValueError(f"No members found for union type: {alias_value}")
+    for member in members:
+        field = member.model_fields.get(DISCRIMINATOR_NAME)
+        if field and (field.default == discriminator_value):
+            return member
+    raise ValueError(f"Discriminator value {discriminator_value} not found in union type: {alias_value}")
+
+
 def _get_flash_error_data() -> dict[str, Any]:
     flashed_error_messages = quart.get_flashed_messages(category_filter=["form-error-data"])
     if flashed_error_messages:
@@ -260,6 +294,11 @@ async def render(
     return htm.form(form_classes, action=action, method="post", enctype="multipart/form-data")[form_children]
 
 
+async def render_block(block: htm.Block, *args, **kwargs) -> None:
+    rendered = await render(*args, **kwargs)
+    block.append(rendered)
+
+
 def session(info: pydantic.ValidationInfo) -> web.Committer | None:
     ctx: dict[str, Any] = info.context or {}
     return ctx.get("session")
@@ -325,6 +364,8 @@ Bool = Annotated[
 
 Email = pydantic.EmailStr
 
+URL = pydantic.HttpUrl
+
 File = Annotated[
     datastructures.FileStorage,
     functional_validators.BeforeValidator(to_filestorage),
@@ -599,6 +640,9 @@ def _get_widget_type(field_info: pydantic.fields.FieldInfo) -> Widget:  # noqa:
     if annotation is pydantic.EmailStr:
         return Widget.EMAIL
 
+    if annotation is pydantic.HttpUrl:
+        return Widget.URL
+
     if annotation in (int, float):
         return Widget.NUMBER
 
diff --git a/atr/static/css/atr.css b/atr/static/css/atr.css
index 4561a0a..e962a4f 100644
--- a/atr/static/css/atr.css
+++ b/atr/static/css/atr.css
@@ -515,6 +515,10 @@ aside.sidebar nav a:hover {
     display: none;
 }
 
+.atr-initial {
+    display: initial;
+}
+
 .atr-pre-wrap {
     white-space: pre-wrap;
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to