This is an automated email from the ASF dual-hosted git repository.

sbp pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tooling-trusted-release.git


The following commit(s) were added to refs/heads/main by this push:
     new 6f7def2  Use stricter types and automatic validation for task results
6f7def2 is described below

commit 6f7def2f7c859593fc0e5cd460f411ea816041d5
Author: Sean B. Palmer <[email protected]>
AuthorDate: Tue Jul 1 14:42:38 2025 +0100

    Use stricter types and automatic validation for task results
---
 atr/db/models.py              | 25 ++++++++++++++-
 atr/results.py                | 75 +++++++++++++++++++++++++++++++++++++++++++
 atr/routes/compose.py         | 30 +++--------------
 atr/routes/release.py         |  4 ---
 atr/routes/resolve.py         | 24 +++-----------
 atr/routes/vote.py            | 12 +++++--
 atr/tasks/__init__.py         |  3 +-
 atr/tasks/checks/hashing.py   |  3 +-
 atr/tasks/checks/license.py   |  5 +--
 atr/tasks/checks/paths.py     |  3 +-
 atr/tasks/checks/rat.py       |  3 +-
 atr/tasks/checks/signature.py |  3 +-
 atr/tasks/checks/targz.py     |  5 +--
 atr/tasks/checks/zipformat.py |  5 +--
 atr/tasks/keys.py             |  3 +-
 atr/tasks/message.py          | 13 ++++----
 atr/tasks/sbom.py             |  8 +++--
 atr/tasks/svn.py              |  8 +++--
 atr/tasks/task.py             | 14 +++-----
 atr/tasks/vote.py             | 31 +++++++++---------
 atr/worker.py                 | 40 ++++-------------------
 playwright/test.py            |  4 +--
 22 files changed, 186 insertions(+), 135 deletions(-)

diff --git a/atr/db/models.py b/atr/db/models.py
index b026a30..ca7013a 100644
--- a/atr/db/models.py
+++ b/atr/db/models.py
@@ -31,6 +31,7 @@ import sqlalchemy.orm as orm
 import sqlalchemy.sql.expression as expression
 import sqlmodel
 
+import atr.results as results
 import atr.schema as schema
 
 sqlmodel.SQLModel.metadata = sqlalchemy.MetaData(
@@ -77,6 +78,28 @@ class UTCDateTime(sqlalchemy.types.TypeDecorator):
             return value
 
 
+class ResultsJSON(sqlalchemy.types.TypeDecorator):
+    impl = sqlalchemy.JSON
+    cache_ok = True
+
+    def process_bind_param(self, value, dialect):
+        if value is None:
+            return None
+        if hasattr(value, "model_dump"):
+            return value.model_dump()
+        if isinstance(value, dict):
+            return value
+        raise ValueError("Unsupported value for Results column")
+
+    def process_result_value(self, value, dialect):
+        if value is None:
+            return None
+        try:
+            return results.ResultsAdapter.validate_python(value)
+        except pydantic.ValidationError:
+            return None
+
+
 class UserRole(str, enum.Enum):
     COMMITTEE_MEMBER = "committee_member"
     RELEASE_MANAGER = "release_manager"
@@ -557,7 +580,7 @@ class Task(sqlmodel.SQLModel, table=True):
         default=None,
         sa_column=sqlalchemy.Column(UTCDateTime),
     )
-    result: Any | None = sqlmodel.Field(default=None, sa_column=sqlalchemy.Column(sqlalchemy.JSON))
+    result: results.Results | None = sqlmodel.Field(default=None, sa_column=sqlalchemy.Column(ResultsJSON))
     error: str | None = None
 
     # Used for check tasks
diff --git a/atr/results.py b/atr/results.py
new file mode 100644
index 0000000..cf87346
--- /dev/null
+++ b/atr/results.py
@@ -0,0 +1,75 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from typing import Annotated, Literal
+
+from pydantic import TypeAdapter
+
+import atr.schema as schema
+
+# TODO: If we put this in atr.tasks.results, we get a circular import error
+
+
+class HashingCheck(schema.Strict):
+    """Result of the task to check the hash of a file."""
+
+    kind: Literal["hashing_check"] = schema.Field(alias="kind")
+    hash_algorithm: str = schema.description("The hash algorithm used")
+    hash_value: str = schema.description("The hash value of the file")
+    hash_file_path: str = schema.description("The path to the hash file")
+
+
+class MessageSend(schema.Strict):
+    """Result of the task to send an email."""
+
+    kind: Literal["message_send"] = schema.Field(alias="kind")
+    mid: str = schema.description("The message ID of the email")
+    mail_send_warnings: list[str] = schema.description("Warnings from the mail server")
+
+
+class SBOMGenerateCycloneDX(schema.Strict):
+    """Result of the task to generate a CycloneDX SBOM."""
+
+    kind: Literal["sbom_generate_cyclonedx"] = schema.Field(alias="kind")
+    msg: str = schema.description("The message from the SBOM generation")
+
+
+class SvnImportFiles(schema.Strict):
+    """Result of the task to import files from SVN."""
+
+    kind: Literal["svn_import"] = schema.Field(alias="kind")
+    msg: str = schema.description("The message from the SVN import")
+
+
+class VoteInitiate(schema.Strict):
+    """Result of the task to initiate a vote."""
+
+    kind: Literal["vote_initiate"] = schema.Field(alias="kind")
+    message: str = schema.description("The message from the vote initiation")
+    email_to: str = schema.description("The email address the vote was sent to")
+    vote_end: str = schema.description("The date and time the vote ends")
+    subject: str = schema.description("The subject of the vote email")
+    mid: str | None = schema.description("The message ID of the vote email")
+    mail_send_warnings: list[str] = schema.description("Warnings from the mail server")
+
+
+Results = Annotated[
+    HashingCheck | MessageSend | SBOMGenerateCycloneDX | SvnImportFiles | VoteInitiate,
+    schema.Field(discriminator="kind"),
+]
+
+ResultsAdapter = TypeAdapter(Results)
diff --git a/atr/routes/compose.py b/atr/routes/compose.py
index 154807a..0d41509 100644
--- a/atr/routes/compose.py
+++ b/atr/routes/compose.py
@@ -15,7 +15,6 @@
 # specific language governing permissions and limitations
 # under the License.
 
-import json
 from typing import TYPE_CHECKING
 
 import werkzeug.wrappers.response as response
@@ -24,6 +23,7 @@ import wtforms
 import atr.db as db
 import atr.db.interaction as interaction
 import atr.db.models as models
+import atr.results as results
 import atr.revision as revision
 import atr.routes as routes
 import atr.routes.draft as draft
@@ -124,28 +124,8 @@ def _warnings_from_vote_result(vote_task: models.Task | None) -> list[str]:
     if not vote_task or (not vote_task.result):
         return ["No vote task result found."]
 
-    if not isinstance(vote_task.result, list):
-        return ["Vote task result is not a list."]
+    vote_task_result = vote_task.result
+    if not isinstance(vote_task_result, results.VoteInitiate):
+        return ["Vote task result is not a results.VoteInitiate instance."]
 
-    if len(vote_task.result) != 1:
-        return ["Vote task result list length invalid."]
-
-    if not (first_task_result := vote_task.result[0]):
-        return ["Vote task result item is empty."]
-
-    if not isinstance(first_task_result, str):
-        return ["Vote task result item is not a string."]
-
-    try:
-        data_after_json_parse = json.loads(first_task_result)
-    except json.JSONDecodeError:
-        return ["Vote task result content not valid JSON."]
-
-    if not isinstance(data_after_json_parse, dict):
-        return ["Vote task result JSON content not a dictionary."]
-
-    existing_warnings_list = data_after_json_parse.get("mail_send_warnings", [])
-    if not isinstance(existing_warnings_list, list):
-        return ["Vote task result mail_send_warnings is not a list."]
-
-    return existing_warnings_list
+    return vote_task_result.mail_send_warnings
diff --git a/atr/routes/release.py b/atr/routes/release.py
index a68a28c..bd5b03b 100644
--- a/atr/routes/release.py
+++ b/atr/routes/release.py
@@ -49,10 +49,6 @@ async def bulk_status(session: routes.CommitterSession, task_id: int) -> str | r
         if task.task_type != "package_bulk_download":
             return await session.redirect(root.index, error=f"Task with ID {task_id} is not a bulk download task.")
 
-        # If result is a list or tuple with a single item, extract it
-        if isinstance(task.result, list | tuple) and (len(task.result) == 1):
-            task.result = task.result[0]
-
         # Get the release associated with this task if available
         release = None
         # Debug print the task.task_args using the logger
diff --git a/atr/routes/resolve.py b/atr/routes/resolve.py
index a0fd266..0e3b846 100644
--- a/atr/routes/resolve.py
+++ b/atr/routes/resolve.py
@@ -15,7 +15,6 @@
 # specific language governing permissions and limitations
 # under the License.
 
-import json
 
 import quart
 import sqlmodel
@@ -24,6 +23,7 @@ import werkzeug.wrappers.response as response
 import atr.construct as construct
 import atr.db as db
 import atr.db.models as models
+import atr.results as results
 import atr.revision as revision
 import atr.routes as routes
 import atr.routes.compose as compose
@@ -111,25 +111,11 @@ def task_mid_get(latest_vote_task: models.Task) -> str | None:
     if util.is_dev_environment():
         return vote.TEST_MID
     # TODO: Improve this
-    task_mid = None
 
-    try:
-        for result in latest_vote_task.result or []:
-            if isinstance(result, str):
-                parsed_result = json.loads(result)
-            else:
-                # Shouldn't happen
-                parsed_result = result
-            if isinstance(parsed_result, dict):
-                task_mid = parsed_result.get("mid", "(mid not found in result)")
-                break
-            else:
-                task_mid = "(malformed result)"
-
-    except (json.JSONDecodeError, TypeError):
-        task_mid = "(malformed result)"
-
-    return task_mid
+    result = latest_vote_task.result
+    if not isinstance(result, results.VoteInitiate):
+        return None
+    return result.mid
 
 
 async def _resolve_vote(
diff --git a/atr/routes/vote.py b/atr/routes/vote.py
index 9d4fbc4..a79470e 100644
--- a/atr/routes/vote.py
+++ b/atr/routes/vote.py
@@ -16,7 +16,6 @@
 # under the License.
 
 import enum
-import json
 import logging
 import time
 from collections.abc import Generator
@@ -29,6 +28,7 @@ import wtforms
 
 import atr.db as db
 import atr.db.models as models
+import atr.results as results
 import atr.routes as routes
 import atr.routes.compose as compose
 import atr.routes.resolve as resolve
@@ -116,7 +116,15 @@ async def selected(session: routes.CommitterSession, project_name: str, version_
         if util.is_dev_environment():
             logging.warning("Setting vote task to completed in dev environment")
             latest_vote_task.status = models.TaskStatus.COMPLETED
-            latest_vote_task.result = [json.dumps({"mid": TEST_MID})]
+            latest_vote_task.result = results.VoteInitiate(
+                kind="vote_initiate",
+                message="Vote announcement email sent successfully",
+                email_to="[email protected]",
+                vote_end="2025-07-01 12:00:00",
+                subject="Test vote",
+                mid=TEST_MID,
+                mail_send_warnings=[],
+            )
 
         # Move task_mid_get here?
         task_mid = resolve.task_mid_get(latest_vote_task)
diff --git a/atr/tasks/__init__.py b/atr/tasks/__init__.py
index 90c5b37..8971a8e 100644
--- a/atr/tasks/__init__.py
+++ b/atr/tasks/__init__.py
@@ -21,6 +21,7 @@ from typing import Any, Final
 
 import atr.db as db
 import atr.db.models as models
+import atr.results as results
 import atr.tasks.checks.hashing as hashing
 import atr.tasks.checks.license as license
 import atr.tasks.checks.paths as paths
@@ -139,7 +140,7 @@ def queued(
     )
 
 
-def resolve(task_type: models.TaskType) -> Callable[..., Awaitable[str | None]]:  # noqa: C901
+def resolve(task_type: models.TaskType) -> Callable[..., Awaitable[results.Results | None]]:  # noqa: C901
     match task_type:
         case models.TaskType.HASHING_CHECK:
             return hashing.check
diff --git a/atr/tasks/checks/hashing.py b/atr/tasks/checks/hashing.py
index 299521c..a4f043c 100644
--- a/atr/tasks/checks/hashing.py
+++ b/atr/tasks/checks/hashing.py
@@ -22,12 +22,13 @@ from typing import Final
 
 import aiofiles
 
+import atr.results as results
 import atr.tasks.checks as checks
 
 _LOGGER: Final = logging.getLogger(__name__)
 
 
-async def check(args: checks.FunctionArguments) -> str | None:
+async def check(args: checks.FunctionArguments) -> results.Results | None:
     """Check the hash of a file."""
     recorder = await args.recorder()
     if not (hash_abs_path := await recorder.abs_path()):
diff --git a/atr/tasks/checks/license.py b/atr/tasks/checks/license.py
index 834c2ae..fd12c26 100644
--- a/atr/tasks/checks/license.py
+++ b/atr/tasks/checks/license.py
@@ -25,6 +25,7 @@ from collections.abc import Iterator
 from typing import Any, Final
 
 import atr.db.models as models
+import atr.results as results
 import atr.schema as schema
 import atr.static as static
 import atr.tarzip as tarzip
@@ -121,7 +122,7 @@ Result = ArtifactResult | MemberResult | MemberSkippedResult
 # Tasks
 
 
-async def files(args: checks.FunctionArguments) -> str | None:
+async def files(args: checks.FunctionArguments) -> results.Results | None:
     """Check that the LICENSE and NOTICE files exist and are valid."""
     recorder = await args.recorder()
     if not (artifact_abs_path := await recorder.abs_path()):
@@ -148,7 +149,7 @@ async def files(args: checks.FunctionArguments) -> str | None:
     return None
 
 
-async def headers(args: checks.FunctionArguments) -> str | None:
+async def headers(args: checks.FunctionArguments) -> results.Results | None:
     """Check that all source files have valid license headers."""
     recorder = await args.recorder()
     if not (artifact_abs_path := await recorder.abs_path()):
diff --git a/atr/tasks/checks/paths.py b/atr/tasks/checks/paths.py
index 5294604..fb5721f 100644
--- a/atr/tasks/checks/paths.py
+++ b/atr/tasks/checks/paths.py
@@ -23,6 +23,7 @@ from typing import Final
 import aiofiles.os
 
 import atr.analysis as analysis
+import atr.results as results
 import atr.tasks.checks as checks
 import atr.util as util
 
@@ -30,7 +31,7 @@ _ALLOWED_TOP_LEVEL = {"CHANGES", "LICENSE", "NOTICE", "README"}
 _LOGGER: Final = logging.getLogger(__name__)
 
 
-async def check(args: checks.FunctionArguments) -> None:
+async def check(args: checks.FunctionArguments) -> results.Results | None:
     """Check file path structure and naming conventions against ASF release policy for all files in a release."""
     # We refer to the following authoritative policies:
     # - Release Creation Process (RCP)
diff --git a/atr/tasks/checks/rat.py b/atr/tasks/checks/rat.py
index 5ffe721..589f2f5 100644
--- a/atr/tasks/checks/rat.py
+++ b/atr/tasks/checks/rat.py
@@ -26,6 +26,7 @@ from typing import Any, Final
 
 import atr.archives as archives
 import atr.config as config
+import atr.results as results
 import atr.tasks.checks as checks
 import atr.util as util
 
@@ -44,7 +45,7 @@ _LOGGER: Final = logging.getLogger(__name__)
 _RAT_EXCLUDES_FILENAMES: Final[set[str]] = {".rat-excludes", "rat-excludes.txt"}
 
 
-async def check(args: checks.FunctionArguments) -> str | None:
+async def check(args: checks.FunctionArguments) -> results.Results | None:
     """Use Apache RAT to check the licenses of the files in the artifact."""
     recorder = await args.recorder()
     if not (artifact_abs_path := await recorder.abs_path()):
diff --git a/atr/tasks/checks/signature.py b/atr/tasks/checks/signature.py
index ed972e5..d78480f 100644
--- a/atr/tasks/checks/signature.py
+++ b/atr/tasks/checks/signature.py
@@ -26,13 +26,14 @@ import sqlmodel
 
 import atr.db as db
 import atr.db.models as models
+import atr.results as results
 import atr.tasks.checks as checks
 import atr.util as util
 
 _LOGGER: Final = logging.getLogger(__name__)
 
 
-async def check(args: checks.FunctionArguments) -> str | None:
+async def check(args: checks.FunctionArguments) -> results.Results | None:
     """Check a signature file."""
     recorder = await args.recorder()
     if not (primary_abs_path := await recorder.abs_path()):
diff --git a/atr/tasks/checks/targz.py b/atr/tasks/checks/targz.py
index b0454a8..1a78959 100644
--- a/atr/tasks/checks/targz.py
+++ b/atr/tasks/checks/targz.py
@@ -21,6 +21,7 @@ import tarfile
 from typing import Final
 
 import atr.archives as archives
+import atr.results as results
 import atr.tasks.checks as checks
 
 _LOGGER: Final = logging.getLogger(__name__)
@@ -32,7 +33,7 @@ class RootDirectoryError(Exception):
     ...
 
 
-async def integrity(args: checks.FunctionArguments) -> str | None:
+async def integrity(args: checks.FunctionArguments) -> results.Results | None:
     """Check the integrity of a .tar.gz file."""
     recorder = await args.recorder()
     if not (artifact_abs_path := await recorder.abs_path()):
@@ -72,7 +73,7 @@ def root_directory(tgz_path: str) -> str:
     return root
 
 
-async def structure(args: checks.FunctionArguments) -> str | None:
+async def structure(args: checks.FunctionArguments) -> results.Results | None:
     """Check the structure of a .tar.gz file."""
     recorder = await args.recorder()
     if not (artifact_abs_path := await recorder.abs_path()):
diff --git a/atr/tasks/checks/zipformat.py b/atr/tasks/checks/zipformat.py
index a0167c1..0b31e5b 100644
--- a/atr/tasks/checks/zipformat.py
+++ b/atr/tasks/checks/zipformat.py
@@ -21,12 +21,13 @@ import os
 import zipfile
 from typing import Any, Final
 
+import atr.results as results
 import atr.tasks.checks as checks
 
 _LOGGER: Final = logging.getLogger(__name__)
 
 
-async def integrity(args: checks.FunctionArguments) -> str | None:
+async def integrity(args: checks.FunctionArguments) -> results.Results | None:
     """Check that the zip archive is not corrupted and can be opened."""
     recorder = await args.recorder()
     if not (artifact_abs_path := await recorder.abs_path()):
@@ -46,7 +47,7 @@ async def integrity(args: checks.FunctionArguments) -> str | None:
     return None
 
 
-async def structure(args: checks.FunctionArguments) -> str | None:
+async def structure(args: checks.FunctionArguments) -> results.Results | None:
     """Check that the zip archive has a single root directory matching the artifact name."""
     recorder = await args.recorder()
     if not (artifact_abs_path := await recorder.abs_path()):
diff --git a/atr/tasks/keys.py b/atr/tasks/keys.py
index 7616467..b19a970 100644
--- a/atr/tasks/keys.py
+++ b/atr/tasks/keys.py
@@ -23,6 +23,7 @@ import aiofiles
 import atr.db as db
 import atr.db.interaction as interaction
 import atr.db.models as models
+import atr.results as results
 import atr.schema as schema
 import atr.tasks.checks as checks
 import atr.util as util
@@ -38,7 +39,7 @@ class ImportFile(schema.Strict):
 
 
 @checks.with_model(ImportFile)
-async def import_file(args: ImportFile) -> str | None:
+async def import_file(args: ImportFile) -> results.Results | None:
     """Import a KEYS file from a draft release candidate revision."""
     async with db.session() as data:
         release = await data.release(name=args.release_name).demand(
diff --git a/atr/tasks/message.py b/atr/tasks/message.py
index 57d4790..4a0bb1e 100644
--- a/atr/tasks/message.py
+++ b/atr/tasks/message.py
@@ -15,11 +15,11 @@
 # specific language governing permissions and limitations
 # under the License.
 
-import json
 import logging
 from typing import Final
 
 import atr.mail as mail
+import atr.results as results
 import atr.schema as schema
 import atr.tasks.checks as checks
 import atr.util as util
@@ -43,7 +43,7 @@ class SendError(Exception): ...
 
 
 @checks.with_model(Send)
-async def send(args: Send) -> str | None:
+async def send(args: Send) -> results.Results | None:
     if args.email_recipient not in util.permitted_recipients(args.email_sender):
         raise SendError(f"You are not permitted to send announcements to {args.email_recipient}")
 
@@ -59,14 +59,15 @@ async def send(args: Send) -> str | None:
     # TODO: Move this call into send itself?
     await mail.set_secret_key_default()
     mid, mail_errors = await mail.send(message)
-
-    result_data: dict[str, str | list[str]] = {"mid": mid}
     if mail_errors:
         _LOGGER.warning(f"Mail sending to {args.email_recipient} for subject '{args.subject}' encountered errors:")
         for error in mail_errors:
             _LOGGER.warning(f"- {error}")
-        result_data["mail_send_warnings"] = mail_errors
 
     # TODO: Record the vote in the database?
     # We'd need to sync with manual votes too
-    return json.dumps(result_data)
+    return results.MessageSend(
+        kind="message_send",
+        mid=mid,
+        mail_send_warnings=mail_errors,
+    )
diff --git a/atr/tasks/sbom.py b/atr/tasks/sbom.py
index 4f86c10..1425e4f 100644
--- a/atr/tasks/sbom.py
+++ b/atr/tasks/sbom.py
@@ -25,6 +25,7 @@ import aiofiles
 
 import atr.archives as archives
 import atr.config as config
+import atr.results as results
 import atr.schema as schema
 import atr.tasks.checks as checks
 import atr.tasks.checks.targz as targz
@@ -50,7 +51,7 @@ class SBOMGenerationError(Exception):
 
 
 @checks.with_model(GenerateCycloneDX)
-async def generate_cyclonedx(args: GenerateCycloneDX) -> str | None:
+async def generate_cyclonedx(args: GenerateCycloneDX) -> results.Results | None:
     """Generate a CycloneDX SBOM for the given artifact and write it to the output path."""
     try:
         result_data = await _generate_cyclonedx_core(args.artifact_path, args.output_path)
@@ -58,7 +59,10 @@ async def generate_cyclonedx(args: GenerateCycloneDX) -> str | None:
         msg = result_data["message"]
         if not isinstance(msg, str):
             raise SBOMGenerationError(f"Invalid message type: {type(msg)}")
-        return msg
+        return results.SBOMGenerateCycloneDX(
+            kind="sbom_generate_cyclonedx",
+            msg=msg,
+        )
     except (archives.ExtractionError, SBOMGenerationError) as e:
         _LOGGER.error(f"SBOM generation failed for {args.artifact_path}: {e}")
         raise
diff --git a/atr/tasks/svn.py b/atr/tasks/svn.py
index c09f003..ab32287 100644
--- a/atr/tasks/svn.py
+++ b/atr/tasks/svn.py
@@ -23,6 +23,7 @@ from typing import Any, Final
 import aiofiles.os
 import aioshutil
 
+import atr.results as results
 import atr.revision as revision
 import atr.schema as schema
 import atr.tasks.checks as checks
@@ -50,11 +51,14 @@ class SvnImportError(Exception):
 
 
 @checks.with_model(SvnImport)
-async def import_files(args: SvnImport) -> str | None:
+async def import_files(args: SvnImport) -> results.Results | None:
     """Import files from SVN into a draft release candidate revision."""
     try:
         result_message = await _import_files_core(args)
-        return result_message
+        return results.SvnImportFiles(
+            kind="svn_import",
+            msg=result_message,
+        )
     except SvnImportError as e:
         _LOGGER.error(f"SVN import failed: {e.details}")
         raise
diff --git a/atr/tasks/task.py b/atr/tasks/task.py
index f73add5..fd62f1d 100644
--- a/atr/tasks/task.py
+++ b/atr/tasks/task.py
@@ -17,9 +17,10 @@
 
 from __future__ import annotations
 
-from typing import Any, Final
+from typing import Final
 
 import atr.db.models as models
+import atr.results as results
 
 QUEUED: Final = models.TaskStatus.QUEUED
 ACTIVE: Final = models.TaskStatus.ACTIVE
@@ -30,13 +31,6 @@ FAILED: Final = models.TaskStatus.FAILED
 class Error(Exception):
     """Error during task execution."""
 
-    def __init__(self, message: str, *result: Any) -> None:
+    def __init__(self, message: str, *result: results.Results | None) -> None:
         self.message = message
-        self.result = tuple(result)
-
-
-def results_as_tuple(item: Any) -> tuple[Any, ...]:
-    """Ensure that returned results are structured as a tuple."""
-    if not isinstance(item, tuple):
-        return (item,)
-    return item
+        self.result = result
diff --git a/atr/tasks/vote.py b/atr/tasks/vote.py
index 68742f5..968af7f 100644
--- a/atr/tasks/vote.py
+++ b/atr/tasks/vote.py
@@ -16,14 +16,14 @@
 # under the License.
 
 import datetime
-import json
 import logging
-from typing import Any, Final
+from typing import Final
 
 import atr.construct as construct
 import atr.db as db
 import atr.db.interaction as interaction
 import atr.mail as mail
+import atr.results as results
 import atr.schema as schema
 import atr.tasks.checks as checks
 import atr.util as util
@@ -49,11 +49,10 @@ class VoteInitiationError(Exception): ...
 
 
 @checks.with_model(Initiate)
-async def initiate(args: Initiate) -> str | None:
+async def initiate(args: Initiate) -> results.Results | None:
     """Initiate a vote for a release."""
     try:
-        result_data = await _initiate_core_logic(args)
-        return json.dumps(result_data)
+        return await _initiate_core_logic(args)
 
     except VoteInitiationError as e:
         _LOGGER.error(f"Vote initiation failed: {e}")
@@ -63,7 +62,7 @@ async def initiate(args: Initiate) -> str | None:
         raise
 
 
-async def _initiate_core_logic(args: Initiate) -> dict[str, Any]:
+async def _initiate_core_logic(args: Initiate) -> results.Results | None:
     """Get arguments, create an email, and then send it to the recipient."""
     _LOGGER.info("Starting initiate_core")
 
@@ -137,18 +136,18 @@ async def _initiate_core_logic(args: Initiate) -> dict[str, Any]:
     mid, mail_errors = await mail.send(message)
 
     # Original success message structure
-    result_data: dict[str, str | list[str]] = {
-        "message": "Vote announcement email sent successfully",
-        "email_to": args.email_to,
-        "vote_end": vote_end_str,
-        "subject": subject,
-        "mid": mid,
-    }
+    result = results.VoteInitiate(
+        kind="vote_initiate",
+        message="Vote announcement email sent successfully",
+        email_to=args.email_to,
+        vote_end=vote_end_str,
+        subject=subject,
+        mid=mid,
+        mail_send_warnings=mail_errors,
+    )
 
     if mail_errors:
         _LOGGER.warning(f"Start vote for {args.release_name}: sending to {args.email_to}  gave errors: {mail_errors}")
-        result_data["mail_send_warnings"] = mail_errors
     else:
         _LOGGER.info(f"Vote email sent successfully to {args.email_to}")
-
-    return result_data
+    return result
diff --git a/atr/worker.py b/atr/worker.py
index 7333ba9..daa1448 100644
--- a/atr/worker.py
+++ b/atr/worker.py
@@ -25,7 +25,6 @@
 import asyncio
 import datetime
 import inspect
-import json
 import logging
 import os
 import resource
@@ -37,6 +36,7 @@ import sqlmodel
 
 import atr.db as db
 import atr.db.models as models
+import atr.results as results
 import atr.tasks as tasks
 import atr.tasks.checks as checks
 import atr.tasks.task as task
@@ -105,34 +105,6 @@ def _setup_logging() -> None:
 # Task functions
 
 
-async def _task_error_handle(task_id: int, e: Exception) -> None:
-    """Handle task error by updating the database with error information."""
-    if isinstance(e, task.Error):
-        _LOGGER.error(f"Task {task_id} failed: {e.message}")
-        _LOGGER.error("".join(traceback.format_exception(e)))
-        result = json.dumps(e.result)
-
-        async with db.session() as data:
-            async with data.begin():
-                task_obj = await data.task(id=task_id).get()
-                if task_obj:
-                    task_obj.status = task.FAILED
-                    task_obj.completed = datetime.datetime.now(datetime.UTC)
-                    task_obj.error = e.message
-                    task_obj.result = result
-    else:
-        _LOGGER.error(f"Task {task_id} failed: {e}")
-        _LOGGER.error("".join(traceback.format_exception(e)))
-
-        async with db.session() as data:
-            async with data.begin():
-                task_obj = await data.task(id=task_id).get()
-                if task_obj:
-                    task_obj.status = task.FAILED
-                    task_obj.completed = datetime.datetime.now(datetime.UTC)
-                    task_obj.error = str(e)
-
-
 async def _task_next_claim() -> tuple[int, str, list[str] | dict[str, Any]] | None:
     """
     Attempt to claim the oldest unclaimed task.
@@ -181,10 +153,10 @@ async def _task_process(task_id: int, task_type: str, task_args: list[str] | dic
         task_type_member = models.TaskType(task_type)
     except ValueError as e:
         _LOGGER.error(f"Invalid task type: {task_type}")
-        await _task_result_process(task_id, tuple(), task.FAILED, str(e))
+        await _task_result_process(task_id, None, task.FAILED, str(e))
         return
 
-    task_results: tuple[Any, ...]
+    task_results: results.Results | None
     try:
         handler = tasks.resolve(task_type_member)
         sig = inspect.signature(handler)
@@ -235,11 +207,11 @@ async def _task_process(task_id: int, task_type: str, task_args: list[str] | dic
             # Otherwise, it's not a check handler
             handler_result = await handler(task_args)
 
-        task_results = (handler_result,)
+        task_results = handler_result
         status = task.COMPLETED
         error = None
     except Exception as e:
-        task_results = tuple()
+        task_results = None
         status = task.FAILED
         error_details = traceback.format_exc()
         _LOGGER.error(f"Task {task_id} failed processing: {error_details}")
@@ -248,7 +220,7 @@ async def _task_process(task_id: int, task_type: str, task_args: list[str] | dic
 
 
 async def _task_result_process(
-    task_id: int, task_results: tuple[Any, ...], status: models.TaskStatus, error: str | None = None
+    task_id: int, task_results: results.Results | None, status: models.TaskStatus, error: str | None = None
 ) -> None:
     """Process and store task results in the database."""
     async with db.session() as data:
diff --git a/playwright/test.py b/playwright/test.py
index 1e365b1..da5de65 100644
--- a/playwright/test.py
+++ b/playwright/test.py
@@ -207,7 +207,7 @@ def lifecycle_05_resolve_vote(page: sync_api.Page, credentials: Credentials, ver
         logging.warning("Vote initiation banner not detected after 15s, proceeding anyway")
 
     logging.info("Locating the 'Resolve vote' button")
-    tabulate_form_locator = page.locator(f'form[action="/vote/tooling-test-example/{version_name}/tabulate"]')
+    tabulate_form_locator = page.locator(f'form[action="/vote/tooling-test-example/{version_name}/resolve"]')
     sync_api.expect(tabulate_form_locator).to_be_visible()
 
     tabulate_button_locator = tabulate_form_locator.locator('button[type="submit"]:has-text("Resolve vote")')
@@ -216,7 +216,7 @@ def lifecycle_05_resolve_vote(page: sync_api.Page, credentials: Credentials, ver
     tabulate_button_locator.click()
 
     logging.info("Waiting for navigation to tabulated votes page")
-    wait_for_path(page, f"/vote/tooling-test-example/{version_name}/tabulate")
+    wait_for_path(page, f"/vote/tooling-test-example/{version_name}/resolve")
 
     logging.info("Locating the resolve vote form on the tabulated votes page")
     resolve_form_locator = page.locator(f'form[action="/resolve/tooling-test-example/{version_name}"]')


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to