This is an automated email from the ASF dual-hosted git repository.
sbp pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tooling-trusted-release.git
The following commit(s) were added to refs/heads/main by this push:
new 6f7def2 Use stricter types and automatic validation for task results
6f7def2 is described below
commit 6f7def2f7c859593fc0e5cd460f411ea816041d5
Author: Sean B. Palmer <[email protected]>
AuthorDate: Tue Jul 1 14:42:38 2025 +0100
Use stricter types and automatic validation for task results
---
atr/db/models.py | 25 ++++++++++++++-
atr/results.py | 75 +++++++++++++++++++++++++++++++++++++++++++
atr/routes/compose.py | 30 +++--------------
atr/routes/release.py | 4 ---
atr/routes/resolve.py | 24 +++-----------
atr/routes/vote.py | 12 +++++--
atr/tasks/__init__.py | 3 +-
atr/tasks/checks/hashing.py | 3 +-
atr/tasks/checks/license.py | 5 +--
atr/tasks/checks/paths.py | 3 +-
atr/tasks/checks/rat.py | 3 +-
atr/tasks/checks/signature.py | 3 +-
atr/tasks/checks/targz.py | 5 +--
atr/tasks/checks/zipformat.py | 5 +--
atr/tasks/keys.py | 3 +-
atr/tasks/message.py | 13 ++++----
atr/tasks/sbom.py | 8 +++--
atr/tasks/svn.py | 8 +++--
atr/tasks/task.py | 14 +++-----
atr/tasks/vote.py | 31 +++++++++---------
atr/worker.py | 40 ++++-------------------
playwright/test.py | 4 +--
22 files changed, 186 insertions(+), 135 deletions(-)
diff --git a/atr/db/models.py b/atr/db/models.py
index b026a30..ca7013a 100644
--- a/atr/db/models.py
+++ b/atr/db/models.py
@@ -31,6 +31,7 @@ import sqlalchemy.orm as orm
import sqlalchemy.sql.expression as expression
import sqlmodel
+import atr.results as results
import atr.schema as schema
sqlmodel.SQLModel.metadata = sqlalchemy.MetaData(
@@ -77,6 +78,28 @@ class UTCDateTime(sqlalchemy.types.TypeDecorator):
return value
+class ResultsJSON(sqlalchemy.types.TypeDecorator):
+ impl = sqlalchemy.JSON
+ cache_ok = True
+
+ def process_bind_param(self, value, dialect):
+ if value is None:
+ return None
+ if hasattr(value, "model_dump"):
+ return value.model_dump()
+ if isinstance(value, dict):
+ return value
+ raise ValueError("Unsupported value for Results column")
+
+ def process_result_value(self, value, dialect):
+ if value is None:
+ return None
+ try:
+ return results.ResultsAdapter.validate_python(value)
+ except pydantic.ValidationError:
+ return None
+
+
class UserRole(str, enum.Enum):
COMMITTEE_MEMBER = "committee_member"
RELEASE_MANAGER = "release_manager"
@@ -557,7 +580,7 @@ class Task(sqlmodel.SQLModel, table=True):
default=None,
sa_column=sqlalchemy.Column(UTCDateTime),
)
- result: Any | None = sqlmodel.Field(default=None, sa_column=sqlalchemy.Column(sqlalchemy.JSON))
+ result: results.Results | None = sqlmodel.Field(default=None, sa_column=sqlalchemy.Column(ResultsJSON))
error: str | None = None
# Used for check tasks
diff --git a/atr/results.py b/atr/results.py
new file mode 100644
index 0000000..cf87346
--- /dev/null
+++ b/atr/results.py
@@ -0,0 +1,75 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from typing import Annotated, Literal
+
+from pydantic import TypeAdapter
+
+import atr.schema as schema
+
+# TODO: If we put this in atr.tasks.results, we get a circular import error
+
+
+class HashingCheck(schema.Strict):
+ """Result of the task to check the hash of a file."""
+
+ kind: Literal["hashing_check"] = schema.Field(alias="kind")
+ hash_algorithm: str = schema.description("The hash algorithm used")
+ hash_value: str = schema.description("The hash value of the file")
+ hash_file_path: str = schema.description("The path to the hash file")
+
+
+class MessageSend(schema.Strict):
+ """Result of the task to send an email."""
+
+ kind: Literal["message_send"] = schema.Field(alias="kind")
+ mid: str = schema.description("The message ID of the email")
+ mail_send_warnings: list[str] = schema.description("Warnings from the mail server")
+
+
+class SBOMGenerateCycloneDX(schema.Strict):
+ """Result of the task to generate a CycloneDX SBOM."""
+
+ kind: Literal["sbom_generate_cyclonedx"] = schema.Field(alias="kind")
+ msg: str = schema.description("The message from the SBOM generation")
+
+
+class SvnImportFiles(schema.Strict):
+ """Result of the task to import files from SVN."""
+
+ kind: Literal["svn_import"] = schema.Field(alias="kind")
+ msg: str = schema.description("The message from the SVN import")
+
+
+class VoteInitiate(schema.Strict):
+ """Result of the task to initiate a vote."""
+
+ kind: Literal["vote_initiate"] = schema.Field(alias="kind")
+ message: str = schema.description("The message from the vote initiation")
+ email_to: str = schema.description("The email address the vote was sent to")
+ vote_end: str = schema.description("The date and time the vote ends")
+ subject: str = schema.description("The subject of the vote email")
+ mid: str | None = schema.description("The message ID of the vote email")
+ mail_send_warnings: list[str] = schema.description("Warnings from the mail server")
+
+
+Results = Annotated[
+ HashingCheck | MessageSend | SBOMGenerateCycloneDX | SvnImportFiles | VoteInitiate,
+ schema.Field(discriminator="kind"),
+]
+
+ResultsAdapter = TypeAdapter(Results)
diff --git a/atr/routes/compose.py b/atr/routes/compose.py
index 154807a..0d41509 100644
--- a/atr/routes/compose.py
+++ b/atr/routes/compose.py
@@ -15,7 +15,6 @@
# specific language governing permissions and limitations
# under the License.
-import json
from typing import TYPE_CHECKING
import werkzeug.wrappers.response as response
@@ -24,6 +23,7 @@ import wtforms
import atr.db as db
import atr.db.interaction as interaction
import atr.db.models as models
+import atr.results as results
import atr.revision as revision
import atr.routes as routes
import atr.routes.draft as draft
@@ -124,28 +124,8 @@ def _warnings_from_vote_result(vote_task: models.Task | None) -> list[str]:
if not vote_task or (not vote_task.result):
return ["No vote task result found."]
- if not isinstance(vote_task.result, list):
- return ["Vote task result is not a list."]
+ vote_task_result = vote_task.result
+ if not isinstance(vote_task_result, results.VoteInitiate):
+ return ["Vote task result is not a results.VoteInitiate instance."]
- if len(vote_task.result) != 1:
- return ["Vote task result list length invalid."]
-
- if not (first_task_result := vote_task.result[0]):
- return ["Vote task result item is empty."]
-
- if not isinstance(first_task_result, str):
- return ["Vote task result item is not a string."]
-
- try:
- data_after_json_parse = json.loads(first_task_result)
- except json.JSONDecodeError:
- return ["Vote task result content not valid JSON."]
-
- if not isinstance(data_after_json_parse, dict):
- return ["Vote task result JSON content not a dictionary."]
-
- existing_warnings_list = data_after_json_parse.get("mail_send_warnings", [])
- if not isinstance(existing_warnings_list, list):
- return ["Vote task result mail_send_warnings is not a list."]
-
- return existing_warnings_list
+ return vote_task_result.mail_send_warnings
diff --git a/atr/routes/release.py b/atr/routes/release.py
index a68a28c..bd5b03b 100644
--- a/atr/routes/release.py
+++ b/atr/routes/release.py
@@ -49,10 +49,6 @@ async def bulk_status(session: routes.CommitterSession,
task_id: int) -> str | r
if task.task_type != "package_bulk_download":
return await session.redirect(root.index, error=f"Task with ID {task_id} is not a bulk download task.")
- # If result is a list or tuple with a single item, extract it
- if isinstance(task.result, list | tuple) and (len(task.result) == 1):
- task.result = task.result[0]
-
# Get the release associated with this task if available
release = None
# Debug print the task.task_args using the logger
diff --git a/atr/routes/resolve.py b/atr/routes/resolve.py
index a0fd266..0e3b846 100644
--- a/atr/routes/resolve.py
+++ b/atr/routes/resolve.py
@@ -15,7 +15,6 @@
# specific language governing permissions and limitations
# under the License.
-import json
import quart
import sqlmodel
@@ -24,6 +23,7 @@ import werkzeug.wrappers.response as response
import atr.construct as construct
import atr.db as db
import atr.db.models as models
+import atr.results as results
import atr.revision as revision
import atr.routes as routes
import atr.routes.compose as compose
@@ -111,25 +111,11 @@ def task_mid_get(latest_vote_task: models.Task) -> str |
None:
if util.is_dev_environment():
return vote.TEST_MID
# TODO: Improve this
- task_mid = None
- try:
- for result in latest_vote_task.result or []:
- if isinstance(result, str):
- parsed_result = json.loads(result)
- else:
- # Shouldn't happen
- parsed_result = result
- if isinstance(parsed_result, dict):
- task_mid = parsed_result.get("mid", "(mid not found in result)")
- break
- else:
- task_mid = "(malformed result)"
-
- except (json.JSONDecodeError, TypeError):
- task_mid = "(malformed result)"
-
- return task_mid
+ result = latest_vote_task.result
+ if not isinstance(result, results.VoteInitiate):
+ return None
+ return result.mid
async def _resolve_vote(
diff --git a/atr/routes/vote.py b/atr/routes/vote.py
index 9d4fbc4..a79470e 100644
--- a/atr/routes/vote.py
+++ b/atr/routes/vote.py
@@ -16,7 +16,6 @@
# under the License.
import enum
-import json
import logging
import time
from collections.abc import Generator
@@ -29,6 +28,7 @@ import wtforms
import atr.db as db
import atr.db.models as models
+import atr.results as results
import atr.routes as routes
import atr.routes.compose as compose
import atr.routes.resolve as resolve
@@ -116,7 +116,15 @@ async def selected(session: routes.CommitterSession,
project_name: str, version_
if util.is_dev_environment():
logging.warning("Setting vote task to completed in dev environment")
latest_vote_task.status = models.TaskStatus.COMPLETED
- latest_vote_task.result = [json.dumps({"mid": TEST_MID})]
+ latest_vote_task.result = results.VoteInitiate(
+ kind="vote_initiate",
+ message="Vote announcement email sent successfully",
+ email_to="[email protected]",
+ vote_end="2025-07-01 12:00:00",
+ subject="Test vote",
+ mid=TEST_MID,
+ mail_send_warnings=[],
+ )
# Move task_mid_get here?
task_mid = resolve.task_mid_get(latest_vote_task)
diff --git a/atr/tasks/__init__.py b/atr/tasks/__init__.py
index 90c5b37..8971a8e 100644
--- a/atr/tasks/__init__.py
+++ b/atr/tasks/__init__.py
@@ -21,6 +21,7 @@ from typing import Any, Final
import atr.db as db
import atr.db.models as models
+import atr.results as results
import atr.tasks.checks.hashing as hashing
import atr.tasks.checks.license as license
import atr.tasks.checks.paths as paths
@@ -139,7 +140,7 @@ def queued(
)
-def resolve(task_type: models.TaskType) -> Callable[..., Awaitable[str | None]]: # noqa: C901
+def resolve(task_type: models.TaskType) -> Callable[..., Awaitable[results.Results | None]]: # noqa: C901
match task_type:
case models.TaskType.HASHING_CHECK:
return hashing.check
diff --git a/atr/tasks/checks/hashing.py b/atr/tasks/checks/hashing.py
index 299521c..a4f043c 100644
--- a/atr/tasks/checks/hashing.py
+++ b/atr/tasks/checks/hashing.py
@@ -22,12 +22,13 @@ from typing import Final
import aiofiles
+import atr.results as results
import atr.tasks.checks as checks
_LOGGER: Final = logging.getLogger(__name__)
-async def check(args: checks.FunctionArguments) -> str | None:
+async def check(args: checks.FunctionArguments) -> results.Results | None:
"""Check the hash of a file."""
recorder = await args.recorder()
if not (hash_abs_path := await recorder.abs_path()):
diff --git a/atr/tasks/checks/license.py b/atr/tasks/checks/license.py
index 834c2ae..fd12c26 100644
--- a/atr/tasks/checks/license.py
+++ b/atr/tasks/checks/license.py
@@ -25,6 +25,7 @@ from collections.abc import Iterator
from typing import Any, Final
import atr.db.models as models
+import atr.results as results
import atr.schema as schema
import atr.static as static
import atr.tarzip as tarzip
@@ -121,7 +122,7 @@ Result = ArtifactResult | MemberResult | MemberSkippedResult
# Tasks
-async def files(args: checks.FunctionArguments) -> str | None:
+async def files(args: checks.FunctionArguments) -> results.Results | None:
"""Check that the LICENSE and NOTICE files exist and are valid."""
recorder = await args.recorder()
if not (artifact_abs_path := await recorder.abs_path()):
@@ -148,7 +149,7 @@ async def files(args: checks.FunctionArguments) -> str |
None:
return None
-async def headers(args: checks.FunctionArguments) -> str | None:
+async def headers(args: checks.FunctionArguments) -> results.Results | None:
"""Check that all source files have valid license headers."""
recorder = await args.recorder()
if not (artifact_abs_path := await recorder.abs_path()):
diff --git a/atr/tasks/checks/paths.py b/atr/tasks/checks/paths.py
index 5294604..fb5721f 100644
--- a/atr/tasks/checks/paths.py
+++ b/atr/tasks/checks/paths.py
@@ -23,6 +23,7 @@ from typing import Final
import aiofiles.os
import atr.analysis as analysis
+import atr.results as results
import atr.tasks.checks as checks
import atr.util as util
@@ -30,7 +31,7 @@ _ALLOWED_TOP_LEVEL = {"CHANGES", "LICENSE", "NOTICE",
"README"}
_LOGGER: Final = logging.getLogger(__name__)
-async def check(args: checks.FunctionArguments) -> None:
+async def check(args: checks.FunctionArguments) -> results.Results | None:
"""Check file path structure and naming conventions against ASF release policy for all files in a release."""
# We refer to the following authoritative policies:
# - Release Creation Process (RCP)
diff --git a/atr/tasks/checks/rat.py b/atr/tasks/checks/rat.py
index 5ffe721..589f2f5 100644
--- a/atr/tasks/checks/rat.py
+++ b/atr/tasks/checks/rat.py
@@ -26,6 +26,7 @@ from typing import Any, Final
import atr.archives as archives
import atr.config as config
+import atr.results as results
import atr.tasks.checks as checks
import atr.util as util
@@ -44,7 +45,7 @@ _LOGGER: Final = logging.getLogger(__name__)
_RAT_EXCLUDES_FILENAMES: Final[set[str]] = {".rat-excludes",
"rat-excludes.txt"}
-async def check(args: checks.FunctionArguments) -> str | None:
+async def check(args: checks.FunctionArguments) -> results.Results | None:
"""Use Apache RAT to check the licenses of the files in the artifact."""
recorder = await args.recorder()
if not (artifact_abs_path := await recorder.abs_path()):
diff --git a/atr/tasks/checks/signature.py b/atr/tasks/checks/signature.py
index ed972e5..d78480f 100644
--- a/atr/tasks/checks/signature.py
+++ b/atr/tasks/checks/signature.py
@@ -26,13 +26,14 @@ import sqlmodel
import atr.db as db
import atr.db.models as models
+import atr.results as results
import atr.tasks.checks as checks
import atr.util as util
_LOGGER: Final = logging.getLogger(__name__)
-async def check(args: checks.FunctionArguments) -> str | None:
+async def check(args: checks.FunctionArguments) -> results.Results | None:
"""Check a signature file."""
recorder = await args.recorder()
if not (primary_abs_path := await recorder.abs_path()):
diff --git a/atr/tasks/checks/targz.py b/atr/tasks/checks/targz.py
index b0454a8..1a78959 100644
--- a/atr/tasks/checks/targz.py
+++ b/atr/tasks/checks/targz.py
@@ -21,6 +21,7 @@ import tarfile
from typing import Final
import atr.archives as archives
+import atr.results as results
import atr.tasks.checks as checks
_LOGGER: Final = logging.getLogger(__name__)
@@ -32,7 +33,7 @@ class RootDirectoryError(Exception):
...
-async def integrity(args: checks.FunctionArguments) -> str | None:
+async def integrity(args: checks.FunctionArguments) -> results.Results | None:
"""Check the integrity of a .tar.gz file."""
recorder = await args.recorder()
if not (artifact_abs_path := await recorder.abs_path()):
@@ -72,7 +73,7 @@ def root_directory(tgz_path: str) -> str:
return root
-async def structure(args: checks.FunctionArguments) -> str | None:
+async def structure(args: checks.FunctionArguments) -> results.Results | None:
"""Check the structure of a .tar.gz file."""
recorder = await args.recorder()
if not (artifact_abs_path := await recorder.abs_path()):
diff --git a/atr/tasks/checks/zipformat.py b/atr/tasks/checks/zipformat.py
index a0167c1..0b31e5b 100644
--- a/atr/tasks/checks/zipformat.py
+++ b/atr/tasks/checks/zipformat.py
@@ -21,12 +21,13 @@ import os
import zipfile
from typing import Any, Final
+import atr.results as results
import atr.tasks.checks as checks
_LOGGER: Final = logging.getLogger(__name__)
-async def integrity(args: checks.FunctionArguments) -> str | None:
+async def integrity(args: checks.FunctionArguments) -> results.Results | None:
"""Check that the zip archive is not corrupted and can be opened."""
recorder = await args.recorder()
if not (artifact_abs_path := await recorder.abs_path()):
@@ -46,7 +47,7 @@ async def integrity(args: checks.FunctionArguments) -> str |
None:
return None
-async def structure(args: checks.FunctionArguments) -> str | None:
+async def structure(args: checks.FunctionArguments) -> results.Results | None:
"""Check that the zip archive has a single root directory matching the artifact name."""
recorder = await args.recorder()
if not (artifact_abs_path := await recorder.abs_path()):
diff --git a/atr/tasks/keys.py b/atr/tasks/keys.py
index 7616467..b19a970 100644
--- a/atr/tasks/keys.py
+++ b/atr/tasks/keys.py
@@ -23,6 +23,7 @@ import aiofiles
import atr.db as db
import atr.db.interaction as interaction
import atr.db.models as models
+import atr.results as results
import atr.schema as schema
import atr.tasks.checks as checks
import atr.util as util
@@ -38,7 +39,7 @@ class ImportFile(schema.Strict):
@checks.with_model(ImportFile)
-async def import_file(args: ImportFile) -> str | None:
+async def import_file(args: ImportFile) -> results.Results | None:
"""Import a KEYS file from a draft release candidate revision."""
async with db.session() as data:
release = await data.release(name=args.release_name).demand(
diff --git a/atr/tasks/message.py b/atr/tasks/message.py
index 57d4790..4a0bb1e 100644
--- a/atr/tasks/message.py
+++ b/atr/tasks/message.py
@@ -15,11 +15,11 @@
# specific language governing permissions and limitations
# under the License.
-import json
import logging
from typing import Final
import atr.mail as mail
+import atr.results as results
import atr.schema as schema
import atr.tasks.checks as checks
import atr.util as util
@@ -43,7 +43,7 @@ class SendError(Exception): ...
@checks.with_model(Send)
-async def send(args: Send) -> str | None:
+async def send(args: Send) -> results.Results | None:
if args.email_recipient not in util.permitted_recipients(args.email_sender):
raise SendError(f"You are not permitted to send announcements to {args.email_recipient}")
@@ -59,14 +59,15 @@ async def send(args: Send) -> str | None:
# TODO: Move this call into send itself?
await mail.set_secret_key_default()
mid, mail_errors = await mail.send(message)
-
- result_data: dict[str, str | list[str]] = {"mid": mid}
if mail_errors:
_LOGGER.warning(f"Mail sending to {args.email_recipient} for subject '{args.subject}' encountered errors:")
for error in mail_errors:
_LOGGER.warning(f"- {error}")
- result_data["mail_send_warnings"] = mail_errors
# TODO: Record the vote in the database?
# We'd need to sync with manual votes too
- return json.dumps(result_data)
+ return results.MessageSend(
+ kind="message_send",
+ mid=mid,
+ mail_send_warnings=mail_errors,
+ )
diff --git a/atr/tasks/sbom.py b/atr/tasks/sbom.py
index 4f86c10..1425e4f 100644
--- a/atr/tasks/sbom.py
+++ b/atr/tasks/sbom.py
@@ -25,6 +25,7 @@ import aiofiles
import atr.archives as archives
import atr.config as config
+import atr.results as results
import atr.schema as schema
import atr.tasks.checks as checks
import atr.tasks.checks.targz as targz
@@ -50,7 +51,7 @@ class SBOMGenerationError(Exception):
@checks.with_model(GenerateCycloneDX)
-async def generate_cyclonedx(args: GenerateCycloneDX) -> str | None:
+async def generate_cyclonedx(args: GenerateCycloneDX) -> results.Results | None:
"""Generate a CycloneDX SBOM for the given artifact and write it to the output path."""
try:
result_data = await _generate_cyclonedx_core(args.artifact_path, args.output_path)
@@ -58,7 +59,10 @@ async def generate_cyclonedx(args: GenerateCycloneDX) -> str
| None:
msg = result_data["message"]
if not isinstance(msg, str):
raise SBOMGenerationError(f"Invalid message type: {type(msg)}")
- return msg
+ return results.SBOMGenerateCycloneDX(
+ kind="sbom_generate_cyclonedx",
+ msg=msg,
+ )
except (archives.ExtractionError, SBOMGenerationError) as e:
_LOGGER.error(f"SBOM generation failed for {args.artifact_path}: {e}")
raise
diff --git a/atr/tasks/svn.py b/atr/tasks/svn.py
index c09f003..ab32287 100644
--- a/atr/tasks/svn.py
+++ b/atr/tasks/svn.py
@@ -23,6 +23,7 @@ from typing import Any, Final
import aiofiles.os
import aioshutil
+import atr.results as results
import atr.revision as revision
import atr.schema as schema
import atr.tasks.checks as checks
@@ -50,11 +51,14 @@ class SvnImportError(Exception):
@checks.with_model(SvnImport)
-async def import_files(args: SvnImport) -> str | None:
+async def import_files(args: SvnImport) -> results.Results | None:
"""Import files from SVN into a draft release candidate revision."""
try:
result_message = await _import_files_core(args)
- return result_message
+ return results.SvnImportFiles(
+ kind="svn_import",
+ msg=result_message,
+ )
except SvnImportError as e:
_LOGGER.error(f"SVN import failed: {e.details}")
raise
diff --git a/atr/tasks/task.py b/atr/tasks/task.py
index f73add5..fd62f1d 100644
--- a/atr/tasks/task.py
+++ b/atr/tasks/task.py
@@ -17,9 +17,10 @@
from __future__ import annotations
-from typing import Any, Final
+from typing import Final
import atr.db.models as models
+import atr.results as results
QUEUED: Final = models.TaskStatus.QUEUED
ACTIVE: Final = models.TaskStatus.ACTIVE
@@ -30,13 +31,6 @@ FAILED: Final = models.TaskStatus.FAILED
class Error(Exception):
"""Error during task execution."""
- def __init__(self, message: str, *result: Any) -> None:
+ def __init__(self, message: str, *result: results.Results | None) -> None:
self.message = message
- self.result = tuple(result)
-
-
-def results_as_tuple(item: Any) -> tuple[Any, ...]:
- """Ensure that returned results are structured as a tuple."""
- if not isinstance(item, tuple):
- return (item,)
- return item
+ self.result = result
diff --git a/atr/tasks/vote.py b/atr/tasks/vote.py
index 68742f5..968af7f 100644
--- a/atr/tasks/vote.py
+++ b/atr/tasks/vote.py
@@ -16,14 +16,14 @@
# under the License.
import datetime
-import json
import logging
-from typing import Any, Final
+from typing import Final
import atr.construct as construct
import atr.db as db
import atr.db.interaction as interaction
import atr.mail as mail
+import atr.results as results
import atr.schema as schema
import atr.tasks.checks as checks
import atr.util as util
@@ -49,11 +49,10 @@ class VoteInitiationError(Exception): ...
@checks.with_model(Initiate)
-async def initiate(args: Initiate) -> str | None:
+async def initiate(args: Initiate) -> results.Results | None:
"""Initiate a vote for a release."""
try:
- result_data = await _initiate_core_logic(args)
- return json.dumps(result_data)
+ return await _initiate_core_logic(args)
except VoteInitiationError as e:
_LOGGER.error(f"Vote initiation failed: {e}")
@@ -63,7 +62,7 @@ async def initiate(args: Initiate) -> str | None:
raise
-async def _initiate_core_logic(args: Initiate) -> dict[str, Any]:
+async def _initiate_core_logic(args: Initiate) -> results.Results | None:
"""Get arguments, create an email, and then send it to the recipient."""
_LOGGER.info("Starting initiate_core")
@@ -137,18 +136,18 @@ async def _initiate_core_logic(args: Initiate) ->
dict[str, Any]:
mid, mail_errors = await mail.send(message)
# Original success message structure
- result_data: dict[str, str | list[str]] = {
- "message": "Vote announcement email sent successfully",
- "email_to": args.email_to,
- "vote_end": vote_end_str,
- "subject": subject,
- "mid": mid,
- }
+ result = results.VoteInitiate(
+ kind="vote_initiate",
+ message="Vote announcement email sent successfully",
+ email_to=args.email_to,
+ vote_end=vote_end_str,
+ subject=subject,
+ mid=mid,
+ mail_send_warnings=mail_errors,
+ )
if mail_errors:
_LOGGER.warning(f"Start vote for {args.release_name}: sending to {args.email_to} gave errors: {mail_errors}")
- result_data["mail_send_warnings"] = mail_errors
else:
_LOGGER.info(f"Vote email sent successfully to {args.email_to}")
-
- return result_data
+ return result
diff --git a/atr/worker.py b/atr/worker.py
index 7333ba9..daa1448 100644
--- a/atr/worker.py
+++ b/atr/worker.py
@@ -25,7 +25,6 @@
import asyncio
import datetime
import inspect
-import json
import logging
import os
import resource
@@ -37,6 +36,7 @@ import sqlmodel
import atr.db as db
import atr.db.models as models
+import atr.results as results
import atr.tasks as tasks
import atr.tasks.checks as checks
import atr.tasks.task as task
@@ -105,34 +105,6 @@ def _setup_logging() -> None:
# Task functions
-async def _task_error_handle(task_id: int, e: Exception) -> None:
- """Handle task error by updating the database with error information."""
- if isinstance(e, task.Error):
- _LOGGER.error(f"Task {task_id} failed: {e.message}")
- _LOGGER.error("".join(traceback.format_exception(e)))
- result = json.dumps(e.result)
-
- async with db.session() as data:
- async with data.begin():
- task_obj = await data.task(id=task_id).get()
- if task_obj:
- task_obj.status = task.FAILED
- task_obj.completed = datetime.datetime.now(datetime.UTC)
- task_obj.error = e.message
- task_obj.result = result
- else:
- _LOGGER.error(f"Task {task_id} failed: {e}")
- _LOGGER.error("".join(traceback.format_exception(e)))
-
- async with db.session() as data:
- async with data.begin():
- task_obj = await data.task(id=task_id).get()
- if task_obj:
- task_obj.status = task.FAILED
- task_obj.completed = datetime.datetime.now(datetime.UTC)
- task_obj.error = str(e)
-
-
async def _task_next_claim() -> tuple[int, str, list[str] | dict[str, Any]] | None:
"""
Attempt to claim the oldest unclaimed task.
@@ -181,10 +153,10 @@ async def _task_process(task_id: int, task_type: str,
task_args: list[str] | dic
task_type_member = models.TaskType(task_type)
except ValueError as e:
_LOGGER.error(f"Invalid task type: {task_type}")
- await _task_result_process(task_id, tuple(), task.FAILED, str(e))
+ await _task_result_process(task_id, None, task.FAILED, str(e))
return
- task_results: tuple[Any, ...]
+ task_results: results.Results | None
try:
handler = tasks.resolve(task_type_member)
sig = inspect.signature(handler)
@@ -235,11 +207,11 @@ async def _task_process(task_id: int, task_type: str,
task_args: list[str] | dic
# Otherwise, it's not a check handler
handler_result = await handler(task_args)
- task_results = (handler_result,)
+ task_results = handler_result
status = task.COMPLETED
error = None
except Exception as e:
- task_results = tuple()
+ task_results = None
status = task.FAILED
error_details = traceback.format_exc()
_LOGGER.error(f"Task {task_id} failed processing: {error_details}")
@@ -248,7 +220,7 @@ async def _task_process(task_id: int, task_type: str,
task_args: list[str] | dic
async def _task_result_process(
- task_id: int, task_results: tuple[Any, ...], status: models.TaskStatus, error: str | None = None
+ task_id: int, task_results: results.Results | None, status: models.TaskStatus, error: str | None = None
) -> None:
"""Process and store task results in the database."""
async with db.session() as data:
diff --git a/playwright/test.py b/playwright/test.py
index 1e365b1..da5de65 100644
--- a/playwright/test.py
+++ b/playwright/test.py
@@ -207,7 +207,7 @@ def lifecycle_05_resolve_vote(page: sync_api.Page,
credentials: Credentials, ver
logging.warning("Vote initiation banner not detected after 15s, proceeding anyway")
logging.info("Locating the 'Resolve vote' button")
- tabulate_form_locator = page.locator(f'form[action="/vote/tooling-test-example/{version_name}/tabulate"]')
+ tabulate_form_locator = page.locator(f'form[action="/vote/tooling-test-example/{version_name}/resolve"]')
sync_api.expect(tabulate_form_locator).to_be_visible()
tabulate_button_locator = tabulate_form_locator.locator('button[type="submit"]:has-text("Resolve vote")')
@@ -216,7 +216,7 @@ def lifecycle_05_resolve_vote(page: sync_api.Page,
credentials: Credentials, ver
tabulate_button_locator.click()
logging.info("Waiting for navigation to tabulated votes page")
- wait_for_path(page, f"/vote/tooling-test-example/{version_name}/tabulate")
+ wait_for_path(page, f"/vote/tooling-test-example/{version_name}/resolve")
logging.info("Locating the resolve vote form on the tabulated votes page")
resolve_form_locator = page.locator(f'form[action="/resolve/tooling-test-example/{version_name}"]')
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]