This is an automated email from the ASF dual-hosted git repository.
arm pushed a commit to branch arm
in repository https://gitbox.apache.org/repos/asf/tooling-trusted-releases.git
The following commit(s) were added to refs/heads/arm by this push:
new 424bddf8 #1136 - change import paths to stop importing all tasks via
SSH (-> writers -> announce -> message) module.
424bddf8 is described below
commit 424bddf8c68a23df0e7596eacb716b75ae20e1d9
Author: Alastair McFarlane <[email protected]>
AuthorDate: Tue Apr 7 13:59:26 2026 +0100
#1136 - change import paths to stop importing all tasks via SSH (-> writers
-> announce -> message) module.
---
atr/docs/tasks.md | 5 ++
atr/get/distribution.py | 8 +-
atr/get/finish.py | 8 +-
atr/mail.py | 8 +-
atr/models/__init__.py | 19 +++-
atr/models/args.py | 170 +++++++++++++++++++++++++++++++++++
atr/models/{__init__.py => mail.py} | 22 ++---
atr/sbom/cyclonedx.py | 11 +--
atr/ssh.py | 5 --
atr/storage/writers/announce.py | 4 +-
atr/storage/writers/distributions.py | 4 +-
atr/storage/writers/sbom.py | 10 +--
atr/storage/writers/vote.py | 11 ++-
atr/tasks/__init__.py | 27 +++---
atr/tasks/distribution.py | 19 ++--
atr/tasks/gha.py | 72 ++++++---------
atr/tasks/keys.py | 21 ++---
atr/tasks/maintenance.py | 19 ++--
atr/tasks/message.py | 70 +++++----------
atr/tasks/metadata.py | 18 ++--
atr/tasks/quarantine.py | 30 +++----
atr/tasks/sbom.py | 53 +++--------
atr/tasks/svn.py | 21 ++---
atr/tasks/vote.py | 66 ++++++--------
tests/unit/test_quarantine_task.py | 13 +--
25 files changed, 372 insertions(+), 342 deletions(-)
diff --git a/atr/docs/tasks.md b/atr/docs/tasks.md
index 8761a25c..a948e4ef 100644
--- a/atr/docs/tasks.md
+++ b/atr/docs/tasks.md
@@ -71,3 +71,8 @@ class TaskType(str, enum.Enum):
LICENSE_FILES = "license_files"
LICENSE_HEADERS = "license_headers"
```
+
+### Task arguments
+
+The arguments to a task should be defined in a class, which should reside in
`atr/models/args.py`. This allows them to be imported
+from elsewhere in the system without importing the whole task module.
diff --git a/atr/get/distribution.py b/atr/get/distribution.py
index 68135180..0e45faa0 100644
--- a/atr/get/distribution.py
+++ b/atr/get/distribution.py
@@ -24,12 +24,12 @@ import atr.blueprints.get as get
import atr.db as db
import atr.form as form
import atr.htm as htm
+import atr.models.args as args
import atr.models.safe as safe
import atr.models.sql as sql
import atr.post as post
import atr.render as render
import atr.shared as shared
-import atr.tasks.gha as gha
import atr.template as template
import atr.util as util
import atr.web as web
@@ -344,7 +344,7 @@ def _render_distribution_tasks(
def _render_task(task: sql.Task) -> htm.Element:
"""Render a distribution task's details."""
- args: gha.DistributionWorkflow =
gha.DistributionWorkflow.model_validate(task.task_args)
+ workflow_args: args.DistributionWorkflow =
args.DistributionWorkflow.model_validate(task.task_args)
task_date = task.added.strftime("%Y-%m-%d %H:%M:%S")
task_status = task.status.value
workflow_status = task.workflow.status if task.workflow else ""
@@ -353,11 +353,11 @@ def _render_task(task: sql.Task) -> htm.Element:
)
if task_status != sql.TaskStatus.COMPLETED:
return htm.details(".ms-4")[
- htm.summary[f"{task_date} {args.platform} ({args.package}
{args.version})"],
+ htm.summary[f"{task_date} {workflow_args.platform}
({workflow_args.package} {workflow_args.version})"],
htm.p(".ms-4")[task.error if task.error else
task_status.capitalize()],
]
else:
return htm.details(".ms-4")[
- htm.summary[f"{task_date} {args.platform} ({args.package}
{args.version})"],
+ htm.summary[f"{task_date} {workflow_args.platform}
({workflow_args.package} {workflow_args.version})"],
*[htm.p(".ms-4")[w] for w in workflow_message.split("\n")],
]
diff --git a/atr/get/finish.py b/atr/get/finish.py
index 1fe1a6c1..e07abd4b 100644
--- a/atr/get/finish.py
+++ b/atr/get/finish.py
@@ -39,12 +39,12 @@ import atr.get.revisions as revisions
import atr.get.root as root
import atr.htm as htm
import atr.mapping as mapping
+import atr.models.args as args
import atr.models.safe as safe
import atr.models.sql as sql
import atr.paths as paths
import atr.render as render
import atr.shared as shared
-import atr.tasks.gha as gha
import atr.template as template
import atr.util as util
import atr.web as web
@@ -479,7 +479,7 @@ def _render_release_card(release: sql.Release,
announce_disable_message: str) ->
def _render_task(task: sql.Task) -> htm.Element:
"""Render a distribution task's details."""
- args: gha.DistributionWorkflow =
gha.DistributionWorkflow.model_validate(task.task_args)
+ workflow_args: args.DistributionWorkflow =
args.DistributionWorkflow.model_validate(task.task_args)
task_date = task.added.strftime("%Y-%m-%d %H:%M:%S")
task_status = task.status.value
workflow_status = task.workflow.status if task.workflow else ""
@@ -488,12 +488,12 @@ def _render_task(task: sql.Task) -> htm.Element:
)
if task_status != sql.TaskStatus.COMPLETED:
return htm.details(".ms-4")[
- htm.summary[f"{task_date} {args.platform} ({args.package}
{args.version})"],
+ htm.summary[f"{task_date} {workflow_args.platform}
({workflow_args.package} {workflow_args.version})"],
htm.p(".ms-4")[task.error if task.error else
task_status.capitalize()],
]
else:
return htm.details(".ms-4")[
- htm.summary[f"{task_date} {args.platform} ({args.package}
{args.version})"],
+ htm.summary[f"{task_date} {workflow_args.platform}
({workflow_args.package} {workflow_args.version})"],
*[htm.p(".ms-4")[w] for w in workflow_message.split("\n")],
]
diff --git a/atr/mail.py b/atr/mail.py
index 4cd9e00e..4c8d1134 100644
--- a/atr/mail.py
+++ b/atr/mail.py
@@ -20,7 +20,6 @@ import email.headerregistry as headerregistry
import email.message as message
import email.policy as policy
import email.utils as utils
-import enum
import ssl
import time
import uuid
@@ -30,6 +29,7 @@ import aiosmtplib
# import dkim
import atr.log as log
+import atr.models.mail as models_mail
import atr.util as util
# TODO: We should choose a pattern for globals
@@ -41,11 +41,7 @@ _MAIL_RELAY: Final[str] = "mail-relay.apache.org"
_SMTP_PORT: Final[int] = 587
_SMTP_TIMEOUT: Final[int] = 30
-
-class MailFooterCategory(enum.StrEnum):
- NONE = "none"
- USER = "user"
- AUTO = "auto"
+MailFooterCategory = models_mail.MailFooterCategory
@dataclasses.dataclass
diff --git a/atr/models/__init__.py b/atr/models/__init__.py
index 20d8698e..ce0d46cb 100644
--- a/atr/models/__init__.py
+++ b/atr/models/__init__.py
@@ -15,15 +15,32 @@
# specific language governing permissions and limitations
# under the License.
-from . import api, basic, distribution, github, helpers, results, safe,
schema, session, sql, tabulate, validation
+from . import (
+ api,
+ args,
+ basic,
+ distribution,
+ github,
+ helpers,
+ mail,
+ results,
+ safe,
+ schema,
+ session,
+ sql,
+ tabulate,
+ validation,
+)
# If we use .__name__, pyright gives a warning
__all__ = [
"api",
+ "args",
"basic",
"distribution",
"github",
"helpers",
+ "mail",
"results",
"safe",
"schema",
diff --git a/atr/models/args.py b/atr/models/args.py
new file mode 100644
index 00000000..b2cec6ff
--- /dev/null
+++ b/atr/models/args.py
@@ -0,0 +1,170 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from typing import Annotated, Any
+
+import pydantic
+
+from . import mail, safe, schema
+
+
+class ConvertCycloneDX(schema.Strict):
+ """Arguments for the task to convert an artifact to a CycloneDX SBOM."""
+
+ artifact_path: safe.StatePath = schema.description("Absolute path to the
artifact")
+ output_path: safe.StatePath = schema.description("Absolute path where the
generated SBOM JSON should be written")
+ revision: safe.RevisionNumber = schema.description("Revision number")
+
+
+class DistributionWorkflow(schema.Strict):
+ """Arguments for the task to start a GitHub Actions distribution
workflow."""
+
+ namespace: str = schema.description("Namespace to distribute to")
+ package: str = schema.description("Package to distribute")
+ version: str = schema.description("Version to distribute")
+ staging: bool = schema.description("Whether this is a staging
distribution")
+ project_key: str = schema.description("Project name in ATR")
+ version_key: str = schema.description("Version name in ATR")
+ phase: str = schema.description("Release phase in ATR")
+ asf_uid: str = schema.description("ASF UID of the user triggering the
workflow")
+ committee_key: str = schema.description("Committee name in ATR")
+ platform: str = schema.description("Distribution platform")
+ arguments: dict[str, str] = schema.description("Workflow arguments")
+ name: str = schema.description("Name of the run")
+
+
+class DistributionStatusCheckArgs(schema.Strict):
+ """Arguments for the task to re-check distribution statuses."""
+
+ next_schedule_seconds: int = pydantic.Field(default=0, description="The
next scheduled time")
+ asf_uid: str = schema.description("ASF UID of the user triggering the
workflow")
+
+
+class FileArgs(schema.Strict):
+ """Arguments for SBOM file processing tasks."""
+
+ project_key: safe.ProjectKey = schema.description("Project name")
+ version_key: safe.VersionKey = schema.description("Version name")
+ revision_number: safe.RevisionNumber = schema.description("Revision
number")
+ file_path: safe.RelPath = schema.description("Relative path to the SBOM
file")
+ asf_uid: str | None = None
+
+
+class ScoreArgs(FileArgs):
+ """Arguments for SBOM file scoring tasks."""
+
+ previous_release_version: safe.VersionKey | None =
schema.description("Previous release version")
+
+
+class GenerateCycloneDX(schema.Strict):
+ """Arguments for the task to generate a CycloneDX SBOM from an artifact."""
+
+ artifact_path: safe.StatePath = schema.description("Absolute path to the
artifact")
+ output_path: safe.StatePath = schema.description("Absolute path where the
generated SBOM JSON should be written")
+
+
+class ImportFile(schema.Strict):
+ """Import a KEYS file from a draft release candidate revision."""
+
+ asf_uid: str
+ project_key: safe.ProjectKey
+ version_key: safe.VersionKey
+
+
+class Initiate(schema.Strict):
+ """Arguments for the task to start a vote."""
+
+ release_key: str = schema.description("The name of the release to vote on")
+ email_to: pydantic.EmailStr = schema.description("The mailing list To
address")
+ vote_duration: int = schema.description("Duration of the vote in hours")
+ initiator_id: str = schema.description("ASF ID of the vote initiator")
+ initiator_fullname: str = schema.description("Full name of the vote
initiator")
+ subject: str = schema.description("Subject line for the vote email")
+ body: str = schema.description("Body content for the vote email")
+ email_cc: list[pydantic.EmailStr] = schema.factory(list)
+ email_bcc: list[pydantic.EmailStr] = schema.factory(list)
+
+
+class MaintenanceArgs(schema.Strict):
+ """Arguments for the task to perform scheduled maintenance."""
+
+ asf_uid: str = schema.description("The ASF UID of the user triggering the
maintenance")
+ next_schedule_seconds: int = pydantic.Field(default=0, description="The
next scheduled time")
+
+
+class QuarantineArchiveEntry(schema.Strict):
+ """An archive entry in a quarantine validation task."""
+
+ rel_path: str
+ content_hash: str
+
+
+class QuarantineValidate(schema.Strict):
+ """Arguments for the task to validate a quarantined upload."""
+
+ quarantined_id: int
+ archives: list[QuarantineArchiveEntry]
+
+
+def _ensure_footer_enum(value: Any) -> mail.MailFooterCategory | None:
+ if isinstance(value, mail.MailFooterCategory):
+ return value
+ if isinstance(value, str):
+ return mail.MailFooterCategory(value)
+ else:
+ return None
+
+
+class Send(schema.Strict):
+ """Arguments for the task to send an email."""
+
+ email_sender: pydantic.EmailStr = schema.description("The email address of
the sender")
+ email_to: pydantic.EmailStr = schema.description("The email To address")
+ subject: str = schema.description("The subject of the email")
+ body: str = schema.description("The body of the email")
+ in_reply_to: str | None = schema.description("The message ID of the email
to reply to")
+ email_cc: list[pydantic.EmailStr] = schema.factory(list)
+ email_bcc: list[pydantic.EmailStr] = schema.factory(list)
+ footer_category: Annotated[mail.MailFooterCategory,
pydantic.BeforeValidator(_ensure_footer_enum)] = (
+ schema.description("The category of email footer to include")
+ )
+
+
+class SvnImport(schema.Strict):
+ """Arguments for the task to import files from SVN."""
+
+ svn_url: safe.RelPath
+ revision: str
+ target_subdirectory: str | None
+ project_key: safe.ProjectKey
+ version_key: safe.VersionKey
+ asf_uid: str
+
+
+class Update(schema.Strict):
+ """Arguments for the task to update metadata from remote data sources."""
+
+ asf_uid: str = schema.description("The ASF UID of the user triggering the
update")
+ next_schedule_seconds: int = pydantic.Field(default=0, description="The
next scheduled time")
+
+
+class WorkflowStatusCheck(schema.Strict):
+ """Arguments for the task to check the status of a GitHub Actions
workflow."""
+
+ run_id: int | None = schema.description("Run ID")
+ next_schedule_seconds: int = pydantic.Field(default=0, description="The
next scheduled time")
+ asf_uid: str = schema.description("ASF UID of the user triggering the
workflow")
diff --git a/atr/models/__init__.py b/atr/models/mail.py
similarity index 69%
copy from atr/models/__init__.py
copy to atr/models/mail.py
index 20d8698e..9d67ec92 100644
--- a/atr/models/__init__.py
+++ b/atr/models/mail.py
@@ -15,20 +15,10 @@
# specific language governing permissions and limitations
# under the License.
-from . import api, basic, distribution, github, helpers, results, safe,
schema, session, sql, tabulate, validation
+import enum
-# If we use .__name__, pyright gives a warning
-__all__ = [
- "api",
- "basic",
- "distribution",
- "github",
- "helpers",
- "results",
- "safe",
- "schema",
- "session",
- "sql",
- "tabulate",
- "validation",
-]
+
+class MailFooterCategory(enum.StrEnum):
+ NONE = "none"
+ USER = "user"
+ AUTO = "auto"
diff --git a/atr/sbom/cyclonedx.py b/atr/sbom/cyclonedx.py
index 8bc8d19b..961cf4e4 100644
--- a/atr/sbom/cyclonedx.py
+++ b/atr/sbom/cyclonedx.py
@@ -21,16 +21,13 @@ import os
import subprocess
from typing import TYPE_CHECKING
-import cyclonedx.exception
-import cyclonedx.schema
-import cyclonedx.validation
-import cyclonedx.validation.json
-
from .utilities import get_pointer
if TYPE_CHECKING:
from collections.abc import Iterable
+ import cyclonedx.validation.json
+
from . import models
@@ -64,6 +61,10 @@ def validate_cli(bundle_value: models.bundle.Bundle) ->
list[str] | None:
def validate_py(
bundle_value: models.bundle.Bundle,
) -> Iterable[cyclonedx.validation.json.JsonValidationError] | None:
+ import cyclonedx.exception
+ import cyclonedx.schema
+ import cyclonedx.validation.json
+
json_sv = get_pointer(bundle_value.doc, "/specVersion")
schema_version = cyclonedx.schema.SchemaVersion.V1_6
if isinstance(json_sv, str):
diff --git a/atr/ssh.py b/atr/ssh.py
index 228f6507..c44f5d15 100644
--- a/atr/ssh.py
+++ b/atr/ssh.py
@@ -50,19 +50,14 @@ import atr.user as user
import atr.util as util
_CONFIG: Final = config.get()
-
_SSH_AUDIT_POLICY: Final = builtin_policies.BUILTIN_POLICIES["Hardened OpenSSH
Server v9.9 (version 1)"]
-
_ASYNCSSH_SUPPORTED_ENC: Final = {bytes(a) for a in
encryption.get_encryption_algs()}
_ASYNCSSH_SUPPORTED_KEX: Final = {bytes(a) for a in kex.get_kex_algs()}
_ASYNCSSH_SUPPORTED_MAC: Final = {bytes(a) for a in mac.get_mac_algs()}
-
-
_APPROVED_CIPHERS: Final = util.intersect_algs(_SSH_AUDIT_POLICY, "ciphers",
_ASYNCSSH_SUPPORTED_ENC)
_APPROVED_KEX: Final = util.intersect_algs(_SSH_AUDIT_POLICY, "kex",
_ASYNCSSH_SUPPORTED_KEX)
_APPROVED_MACS: Final = util.intersect_algs(_SSH_AUDIT_POLICY, "macs",
_ASYNCSSH_SUPPORTED_MAC)
-
_PATH_ALPHANUM: Final = frozenset(string.ascii_letters + string.digits + "-")
# From a survey of version numbers we find that only . and - are used
# We also allow + which is in common use
diff --git a/atr/storage/writers/announce.py b/atr/storage/writers/announce.py
index ebd786f4..76967a46 100644
--- a/atr/storage/writers/announce.py
+++ b/atr/storage/writers/announce.py
@@ -29,12 +29,12 @@ import sqlmodel
import atr.construct as construct
import atr.db as db
import atr.mail as mail
+import atr.models.args as args
import atr.models.basic as basic
import atr.models.safe as safe
import atr.models.sql as sql
import atr.paths as paths
import atr.storage as storage
-import atr.tasks.message as message
import atr.util as util
@@ -233,7 +233,7 @@ class CommitteeMember(CommitteeParticipant):
task = sql.Task(
status=sql.TaskStatus.QUEUED,
task_type=sql.TaskType.MESSAGE_SEND,
- task_args=message.Send(
+ task_args=args.Send(
email_sender=f"{asf_uid}@apache.org",
email_to=email_to,
subject=subject,
diff --git a/atr/storage/writers/distributions.py
b/atr/storage/writers/distributions.py
index 047b7017..c586a499 100644
--- a/atr/storage/writers/distributions.py
+++ b/atr/storage/writers/distributions.py
@@ -23,10 +23,10 @@ import datetime
import atr.db as db
import atr.log as log
import atr.models as models
+import atr.models.args as args
import atr.shared.distribution as distribution
import atr.storage as storage
import atr.storage.outcome as outcome
-import atr.tasks.gha as gha
import atr.util as util
@@ -108,7 +108,7 @@ class CommitteeMember(CommitteeParticipant):
) -> models.sql.Task:
dist_task = models.sql.Task(
task_type=models.sql.TaskType.DISTRIBUTION_WORKFLOW,
- task_args=gha.DistributionWorkflow(
+ task_args=args.DistributionWorkflow(
name=str(release_key),
namespace=str(owner_namespace) if owner_namespace else "",
package=str(package),
diff --git a/atr/storage/writers/sbom.py b/atr/storage/writers/sbom.py
index f2605887..0e2e16ed 100644
--- a/atr/storage/writers/sbom.py
+++ b/atr/storage/writers/sbom.py
@@ -22,10 +22,10 @@ import asyncio
import datetime
import atr.db as db
+import atr.models.args as args
import atr.models.safe as safe
import atr.models.sql as sql
import atr.storage as storage
-import atr.tasks.sbom as sbom
import atr.util as util
@@ -81,7 +81,7 @@ class CommitteeParticipant(FoundationCommitter):
) -> sql.Task:
sbom_task = sql.Task(
task_type=sql.TaskType.SBOM_AUGMENT,
- task_args=sbom.FileArgs(
+ task_args=args.FileArgs(
project_key=project_key,
version_key=version_key,
revision_number=revision_number,
@@ -111,7 +111,7 @@ class CommitteeParticipant(FoundationCommitter):
) -> sql.Task:
sbom_task = sql.Task(
task_type=sql.TaskType.SBOM_CONVERT,
- task_args=sbom.ConvertCycloneDX(
+ task_args=args.ConvertCycloneDX(
artifact_path=file_path,
revision=revision_number,
output_path=sbom_path,
@@ -142,7 +142,7 @@ class CommitteeParticipant(FoundationCommitter):
output_path = await asyncio.to_thread(_resolved_path_str,
sbom_path_in_new_revision)
sbom_task = sql.Task(
task_type=sql.TaskType.SBOM_GENERATE_CYCLONEDX,
- task_args=sbom.GenerateCycloneDX(
+ task_args=args.GenerateCycloneDX(
artifact_path=artifact_path,
output_path=output_path,
).model_dump(),
@@ -167,7 +167,7 @@ class CommitteeParticipant(FoundationCommitter):
) -> sql.Task:
sbom_task = sql.Task(
task_type=sql.TaskType.SBOM_OSV_SCAN,
- task_args=sbom.FileArgs(
+ task_args=args.FileArgs(
project_key=project_key,
version_key=version_key,
revision_number=revision_number,
diff --git a/atr/storage/writers/vote.py b/atr/storage/writers/vote.py
index c2562d0b..dd81c350 100644
--- a/atr/storage/writers/vote.py
+++ b/atr/storage/writers/vote.py
@@ -28,12 +28,11 @@ import atr.db as db
import atr.db.interaction as interaction
import atr.log as log
import atr.mail as mail
+import atr.models.args as args
import atr.models.results as results
import atr.models.safe as safe
import atr.models.sql as sql
import atr.storage as storage
-import atr.tasks.message as message
-import atr.tasks.vote as tasks_vote
import atr.util as util
@@ -122,7 +121,7 @@ class CommitteeParticipant(FoundationCommitter):
task = sql.Task(
status=sql.TaskStatus.QUEUED,
task_type=sql.TaskType.MESSAGE_SEND,
- task_args=message.Send(
+ task_args=args.Send(
email_sender=email_sender,
email_to=email_to,
subject=subject,
@@ -199,7 +198,7 @@ class CommitteeParticipant(FoundationCommitter):
task = sql.Task(
status=sql.TaskStatus.QUEUED,
task_type=sql.TaskType.VOTE_INITIATE,
- task_args=tasks_vote.Initiate(
+ task_args=args.Initiate(
release_key=release.key,
email_to=email_to,
vote_duration=vote_duration_choice,
@@ -527,7 +526,7 @@ class CommitteeMember(CommitteeParticipant):
task = sql.Task(
status=sql.TaskStatus.QUEUED,
task_type=sql.TaskType.MESSAGE_SEND,
- task_args=message.Send(
+ task_args=args.Send(
email_sender=email_sender,
email_to=email_to,
subject=subject,
@@ -546,7 +545,7 @@ class CommitteeMember(CommitteeParticipant):
task = sql.Task(
status=sql.TaskStatus.QUEUED,
task_type=sql.TaskType.MESSAGE_SEND,
- task_args=message.Send(
+ task_args=args.Send(
email_sender=email_sender,
email_to=extra_destination[0],
subject=subject,
diff --git a/atr/tasks/__init__.py b/atr/tasks/__init__.py
index 71d16967..ca62d8ce 100644
--- a/atr/tasks/__init__.py
+++ b/atr/tasks/__init__.py
@@ -28,6 +28,7 @@ import atr.attestable as attestable
import atr.db as db
import atr.hashes as hashes
import atr.log as log
+import atr.models.args as args
import atr.models.results as results
import atr.models.safe as safe
import atr.models.sql as sql
@@ -115,14 +116,14 @@ async def distribution_status_check(
schedule_next: bool = False,
) -> sql.Task:
"""Queue a workflow status update task."""
- args = distribution.DistributionStatusCheckArgs(next_schedule_seconds=0,
asf_uid=asf_uid)
+ task_args = args.DistributionStatusCheckArgs(next_schedule_seconds=0,
asf_uid=asf_uid)
if schedule_next:
- args.next_schedule_seconds = _EVERY_2_MINUTES
+ task_args.next_schedule_seconds = _EVERY_2_MINUTES
async with db.ensure_session(caller_data) as data:
task = sql.Task(
status=sql.TaskStatus.QUEUED,
task_type=sql.TaskType.DISTRIBUTION_STATUS,
- task_args=args.model_dump(),
+ task_args=task_args.model_dump(),
asf_uid=asf_uid,
revision_number=None,
primary_rel_path=None,
@@ -221,7 +222,7 @@ async def keys_import_file(
sql.Task(
status=sql.TaskStatus.QUEUED,
task_type=sql.TaskType.KEYS_IMPORT_FILE,
- task_args=keys.ImportFile(
+ task_args=args.ImportFile(
asf_uid=asf_uid,
project_key=project_key,
version_key=version_key,
@@ -243,14 +244,14 @@ async def run_maintenance(
schedule_next: bool = False,
) -> sql.Task:
"""Queue a maintenance task."""
- args = maintenance.MaintenanceArgs(asf_uid=asf_uid,
next_schedule_seconds=0)
+ task_args = args.MaintenanceArgs(asf_uid=asf_uid, next_schedule_seconds=0)
if schedule_next:
- args.next_schedule_seconds = _DAILY
+ task_args.next_schedule_seconds = _DAILY
async with db.ensure_session(caller_data) as data:
task = sql.Task(
status=sql.TaskStatus.QUEUED,
task_type=sql.TaskType.MAINTENANCE,
- task_args=args.model_dump(),
+ task_args=task_args.model_dump(),
asf_uid=asf_uid,
revision_number=None,
primary_rel_path=None,
@@ -272,14 +273,14 @@ async def metadata_update(
schedule_next: bool = False,
) -> sql.Task:
"""Queue a metadata update task."""
- args = metadata.Update(asf_uid=asf_uid, next_schedule_seconds=0)
+ task_args = args.Update(asf_uid=asf_uid, next_schedule_seconds=0)
if schedule_next:
- args.next_schedule_seconds = _DAILY
+ task_args.next_schedule_seconds = _DAILY
async with db.ensure_session(caller_data) as data:
task = sql.Task(
status=sql.TaskStatus.QUEUED,
task_type=sql.TaskType.METADATA_UPDATE,
- task_args=args.model_dump(),
+ task_args=task_args.model_dump(),
asf_uid=asf_uid,
revision_number=None,
primary_rel_path=None,
@@ -501,14 +502,14 @@ async def workflow_update(
schedule_next: bool = False,
) -> sql.Task:
"""Queue a workflow status update task."""
- args = gha.WorkflowStatusCheck(next_schedule_seconds=0, run_id=0,
asf_uid=asf_uid)
+ task_args = args.WorkflowStatusCheck(next_schedule_seconds=0, run_id=0,
asf_uid=asf_uid)
if schedule_next:
- args.next_schedule_seconds = _EVERY_2_MINUTES
+ task_args.next_schedule_seconds = _EVERY_2_MINUTES
async with db.ensure_session(caller_data) as data:
task = sql.Task(
status=sql.TaskStatus.QUEUED,
task_type=sql.TaskType.WORKFLOW_STATUS,
- task_args=args.model_dump(),
+ task_args=task_args.model_dump(),
asf_uid=asf_uid,
revision_number=None,
primary_rel_path=None,
diff --git a/atr/tasks/distribution.py b/atr/tasks/distribution.py
index d3742690..33ab78c5 100644
--- a/atr/tasks/distribution.py
+++ b/atr/tasks/distribution.py
@@ -14,12 +14,10 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-import pydantic
-
import atr.db as db
import atr.log as log
+import atr.models.args as args
import atr.models.results as results
-import atr.models.schema as schema
import atr.shared.distribution as distribution
import atr.storage as storage
import atr.tasks as tasks
@@ -28,15 +26,10 @@ import atr.tasks.checks as checks
_RETRY_LIMIT = 5
-class DistributionStatusCheckArgs(schema.Strict):
- """Arguments for the task to re-check distribution statuses."""
-
- next_schedule_seconds: int = pydantic.Field(default=0, description="The
next scheduled time")
- asf_uid: str = schema.description("ASF UID of the user triggering the
workflow")
-
-
-@checks.with_model(DistributionStatusCheckArgs)
-async def status_check(args: DistributionStatusCheckArgs, *, task_id: int |
None = None) -> results.Results | None:
+@checks.with_model(args.DistributionStatusCheckArgs)
+async def status_check(
+ task_args: args.DistributionStatusCheckArgs, *, task_id: int | None = None
+) -> results.Results | None:
log.info("Checking pending recorded distributions")
dists = []
async with db.session() as data:
@@ -68,7 +61,7 @@ async def status_check(args: DistributionStatusCheckArgs, *,
task_id: int | None
except (distribution.DistributionError, storage.AccessError) as e:
msg = f"Failed to record distribution: {e}"
log.error(msg)
- await tasks.schedule_next(args.asf_uid, args.next_schedule_seconds,
tasks.distribution_status_check)
+ await tasks.schedule_next(task_args.asf_uid,
task_args.next_schedule_seconds, tasks.distribution_status_check)
return results.DistributionStatusCheck(
kind="distribution_status",
)
diff --git a/atr/tasks/gha.py b/atr/tasks/gha.py
index 660040b4..68179574 100644
--- a/atr/tasks/gha.py
+++ b/atr/tasks/gha.py
@@ -21,14 +21,13 @@ from collections.abc import Callable
from typing import Any, Final, NoReturn
import aiohttp
-import pydantic
import atr.config as config
import atr.db as db
import atr.log as log
+import atr.models.args as args
import atr.models.results as results
import atr.models.safe as safe
-import atr.models.schema as schema
import atr.models.sql as sql
import atr.storage as storage
import atr.tasks as tasks
@@ -42,31 +41,8 @@ _FAILED_STATUSES: Final[list[str]] = ["failure",
"startup_failure"]
_TIMEOUT_S = 60
-class DistributionWorkflow(schema.Strict):
- """Arguments for the task to start a Github Actions workflow."""
-
- namespace: str = schema.description("Namespace to distribute to")
- package: str = schema.description("Package to distribute")
- version: str = schema.description("Version to distribute")
- staging: bool = schema.description("Whether this is a staging
distribution")
- project_key: str = schema.description("Project name in ATR")
- version_key: str = schema.description("Version name in ATR")
- phase: str = schema.description("Release phase in ATR")
- asf_uid: str = schema.description("ASF UID of the user triggering the
workflow")
- committee_key: str = schema.description("Committee name in ATR")
- platform: str = schema.description("Distribution platform")
- arguments: dict[str, str] = schema.description("Workflow arguments")
- name: str = schema.description("Name of the run")
-
-
-class WorkflowStatusCheck(schema.Strict):
- run_id: int | None = schema.description("Run ID")
- next_schedule_seconds: int = pydantic.Field(default=0, description="The
next scheduled time")
- asf_uid: str = schema.description("ASF UID of the user triggering the
workflow")
-
-
-@checks.with_model(WorkflowStatusCheck)
-async def status_check(args: WorkflowStatusCheck) ->
results.DistributionWorkflowStatus:
+@checks.with_model(args.WorkflowStatusCheck)
+async def status_check(task_args: args.WorkflowStatusCheck) ->
results.DistributionWorkflowStatus:
"""Check remote workflow statuses."""
headers = {"Accept": "application/vnd.github+json", "Authorization":
f"Bearer {config.get().GITHUB_TOKEN}"}
@@ -124,7 +100,7 @@ async def status_check(args: WorkflowStatusCheck) ->
results.DistributionWorkflo
f"Workflow status update completed: updated {updated_count}
workflow(s)",
)
- await tasks.schedule_next(args.asf_uid, args.next_schedule_seconds,
tasks.workflow_update)
+ await tasks.schedule_next(task_args.asf_uid,
task_args.next_schedule_seconds, tasks.workflow_update)
return results.DistributionWorkflowStatus(
kind="distribution_workflow_status",
@@ -136,35 +112,37 @@ async def status_check(args: WorkflowStatusCheck) ->
results.DistributionWorkflo
_fail(f"Unexpected error during workflow status update: {e!s}")
-@checks.with_model(DistributionWorkflow)
-async def trigger_workflow(args: DistributionWorkflow, *, task_id: int | None
= None) -> results.Results | None:
- unique_id = f"atr-dist-{args.name}-{uuid.uuid4()}"
- project = safe.ProjectKey(args.project_key)
- safe.VersionKey(args.version_key)
+@checks.with_model(args.DistributionWorkflow)
+async def trigger_workflow(
+ task_args: args.DistributionWorkflow, *, task_id: int | None = None
+) -> results.Results | None:
+ unique_id = f"atr-dist-{task_args.name}-{uuid.uuid4()}"
+ project = safe.ProjectKey(task_args.project_key)
+ safe.VersionKey(task_args.version_key)
try:
- sql_platform = sql.DistributionPlatform[args.platform]
+ sql_platform = sql.DistributionPlatform[task_args.platform]
except KeyError:
- _fail(f"Invalid platform: {args.platform}")
- workflow = f"distribute-{sql_platform.value.gh_slug}{'-stg' if
args.staging else ''}.yml"
+ _fail(f"Invalid platform: {task_args.platform}")
+ workflow = f"distribute-{sql_platform.value.gh_slug}{'-stg' if
task_args.staging else ''}.yml"
payload = {
"ref": "main",
"inputs": {
"atr-id": unique_id,
- "asf-uid": args.asf_uid,
- "project": args.project_key,
- "phase": args.phase,
- "version": args.version_key,
- "distribution-owner-namespace": args.namespace,
- "distribution-package": args.package,
- "distribution-version": args.version,
- # **args.arguments,
+ "asf-uid": task_args.asf_uid,
+ "project": task_args.project_key,
+ "phase": task_args.phase,
+ "version": task_args.version_key,
+ "distribution-owner-namespace": task_args.namespace,
+ "distribution-package": task_args.package,
+ "distribution-version": task_args.version,
+ # **task_args.arguments,
},
}
headers = {"Accept": "application/vnd.github+json", "Authorization":
f"Bearer {config.get().GITHUB_TOKEN}"}
log.info(
f"Triggering Github workflow apache/tooling-actions/{workflow} with
args: {
- json.dumps(args.arguments, indent=2)
+ json.dumps(task_args.arguments, indent=2)
}"
)
async with util.create_secure_session() as session:
@@ -182,13 +160,13 @@ async def trigger_workflow(args: DistributionWorkflow, *,
task_id: int | None =
if run.get("status") in _FAILED_STATUSES:
_fail(f"Github workflow apache/tooling-actions/{workflow} run
{run_id} failed with error")
- async with storage.write_as_committee_member(args.committee_key,
args.asf_uid) as w:
+ async with storage.write_as_committee_member(task_args.committee_key,
task_args.asf_uid) as w:
try:
await w.workflowstatus.add_workflow_status(workflow, run_id,
project, task_id, status=run.get("status"))
except storage.AccessError as e:
_fail(f"Failed to record distribution: {e}")
return results.DistributionWorkflow(
- kind="distribution_workflow", name=args.name, run_id=run_id,
url=run.get("html_url", "")
+ kind="distribution_workflow", name=task_args.name, run_id=run_id,
url=run.get("html_url", "")
)
diff --git a/atr/tasks/keys.py b/atr/tasks/keys.py
index b3743555..c08fdc27 100644
--- a/atr/tasks/keys.py
+++ b/atr/tasks/keys.py
@@ -15,27 +15,18 @@
# specific language governing permissions and limitations
# under the License.
+import atr.models.args as args
import atr.models.results as results
-import atr.models.safe as safe
-import atr.models.schema as schema
import atr.storage as storage
import atr.tasks.checks as checks
-class ImportFile(schema.Strict):
+@checks.with_model(args.ImportFile)
+async def import_file(task_args: args.ImportFile) -> results.Results | None:
"""Import a KEYS file from a draft release candidate revision."""
-
- asf_uid: str
- project_key: safe.ProjectKey
- version_key: safe.VersionKey
-
-
-@checks.with_model(ImportFile)
-async def import_file(args: ImportFile) -> results.Results | None:
- """Import a KEYS file from a draft release candidate revision."""
- async with storage.write(args.asf_uid) as write:
- wacm = await write.as_project_committee_member(args.project_key)
- outcomes = await wacm.keys.import_keys_file(args.project_key,
args.version_key)
+ async with storage.write(task_args.asf_uid) as write:
+ wacm = await write.as_project_committee_member(task_args.project_key)
+ outcomes = await wacm.keys.import_keys_file(task_args.project_key,
task_args.version_key)
if outcomes.any_error:
# TODO: Log this? This code is unused anyway
pass
diff --git a/atr/tasks/maintenance.py b/atr/tasks/maintenance.py
index 171c9a76..bc102751 100644
--- a/atr/tasks/maintenance.py
+++ b/atr/tasks/maintenance.py
@@ -15,30 +15,21 @@
# specific language governing permissions and limitations
# under the License.
-import pydantic
-
import atr.log as log
+import atr.models.args as args
import atr.models.results as results
-import atr.models.schema as schema
import atr.tasks as tasks
import atr.tasks.checks as checks
-class MaintenanceArgs(schema.Strict):
- """Arguments for the task to perform scheduled maintenance."""
-
- asf_uid: str = schema.description("The ASF UID of the user triggering the
maintenance")
- next_schedule_seconds: int = pydantic.Field(default=0, description="The
next scheduled time")
-
-
class MaintenanceError(Exception):
pass
-@checks.with_model(MaintenanceArgs)
-async def run(args: MaintenanceArgs) -> results.Results | None:
+@checks.with_model(args.MaintenanceArgs)
+async def run(task_args: args.MaintenanceArgs) -> results.Results | None:
"""Run maintenance."""
- log.info(f"Starting maintenance (user: {args.asf_uid})")
+ log.info(f"Starting maintenance (user: {task_args.asf_uid})")
try:
await _storage_maintenance()
@@ -47,7 +38,7 @@ async def run(args: MaintenanceArgs) -> results.Results |
None:
"Storage maintenance completed successfully",
)
- await tasks.schedule_next(args.asf_uid, args.next_schedule_seconds,
tasks.run_maintenance)
+ await tasks.schedule_next(task_args.asf_uid,
task_args.next_schedule_seconds, tasks.run_maintenance)
return results.Maintenance(
kind="maintenance",
diff --git a/atr/tasks/message.py b/atr/tasks/message.py
index 7e3f59d9..75f937c3 100644
--- a/atr/tasks/message.py
+++ b/atr/tasks/message.py
@@ -14,64 +14,36 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from typing import Annotated, Any
-
-import pydantic
-
import atr.ldap as ldap
import atr.log as log
import atr.mail as mail
+import atr.models.args as args
import atr.models.results as results
-import atr.models.schema as schema
import atr.storage as storage
import atr.tasks.checks as checks
-def _ensure_footer_enum(value: Any) -> mail.MailFooterCategory | None:
- if isinstance(value, mail.MailFooterCategory):
- return value
- if isinstance(value, str):
- return mail.MailFooterCategory(value)
- else:
- return None
-
-
-class Send(schema.Strict):
- """Arguments for the task to send an email."""
-
- email_sender: pydantic.EmailStr = schema.description("The email address of
the sender")
- email_to: pydantic.EmailStr = schema.description("The email To address")
- subject: str = schema.description("The subject of the email")
- body: str = schema.description("The body of the email")
- in_reply_to: str | None = schema.description("The message ID of the email
to reply to")
- email_cc: list[pydantic.EmailStr] = schema.factory(list)
- email_bcc: list[pydantic.EmailStr] = schema.factory(list)
- footer_category: Annotated[mail.MailFooterCategory,
pydantic.BeforeValidator(_ensure_footer_enum)] = (
- schema.description("The category of email footer to include")
- )
-
-
class SendError(Exception):
pass
-@checks.with_model(Send)
-async def send(args: Send) -> results.Results | None:
- if "@" not in args.email_sender:
- log.warning(f"Invalid email sender: {args.email_sender}")
- sender_asf_uid = args.email_sender
- elif args.email_sender.endswith("@apache.org"):
- sender_asf_uid = args.email_sender.split("@")[0]
+@checks.with_model(args.Send)
+async def send(task_args: args.Send) -> results.Results | None:
+ if "@" not in task_args.email_sender:
+ log.warning(f"Invalid email sender: {task_args.email_sender}")
+ sender_asf_uid = task_args.email_sender
+ elif task_args.email_sender.endswith("@apache.org"):
+ sender_asf_uid = task_args.email_sender.split("@")[0]
else:
- raise SendError(f"Invalid email sender: {args.email_sender}")
+ raise SendError(f"Invalid email sender: {task_args.email_sender}")
sender_account = await ldap.account_lookup(sender_asf_uid)
if sender_account is None:
- raise SendError(f"Invalid email account: {args.email_sender}")
+ raise SendError(f"Invalid email account: {task_args.email_sender}")
if ldap.is_banned(sender_account):
- raise SendError(f"Email account {args.email_sender} is banned")
+ raise SendError(f"Email account {task_args.email_sender} is banned")
- all_recipients = [args.email_to, *args.email_cc, *args.email_bcc]
+ all_recipients = [task_args.email_to, *task_args.email_cc,
*task_args.email_bcc]
for addr in all_recipients:
recipient_domain = addr.split("@")[-1]
sending_to_self = addr == f"{sender_asf_uid}@apache.org"
@@ -81,23 +53,23 @@ async def send(args: Send) -> results.Results | None:
raise SendError(f"You are not permitted to send emails to {addr}")
message = mail.Message(
- email_sender=args.email_sender,
- email_to=args.email_to,
- subject=args.subject,
- body=args.body,
- in_reply_to=args.in_reply_to,
- email_cc=args.email_cc,
- email_bcc=args.email_bcc,
+ email_sender=task_args.email_sender,
+ email_to=task_args.email_to,
+ subject=task_args.subject,
+ body=task_args.body,
+ in_reply_to=task_args.in_reply_to,
+ email_cc=task_args.email_cc,
+ email_bcc=task_args.email_bcc,
)
- footer_category = mail.MailFooterCategory(args.footer_category)
+ footer_category = mail.MailFooterCategory(task_args.footer_category)
async with storage.write(sender_asf_uid) as write:
wafc = write.as_foundation_committer()
mid, mail_errors = await wafc.mail.send(message, footer_category)
if mail_errors:
- log.warning(f"Mail sending to {args.email_to} for subject
'{args.subject}' encountered errors:")
+ log.warning(f"Mail sending to {task_args.email_to} for subject
'{task_args.subject}' encountered errors:")
for error in mail_errors:
log.warning(f"- {error}")
diff --git a/atr/tasks/metadata.py b/atr/tasks/metadata.py
index 1edf8fe1..e9042054 100644
--- a/atr/tasks/metadata.py
+++ b/atr/tasks/metadata.py
@@ -16,31 +16,23 @@
# under the License.
import aiohttp
-import pydantic
import atr.datasources.apache as apache
import atr.log as log
+import atr.models.args as args
import atr.models.results as results
-import atr.models.schema as schema
import atr.tasks as tasks
import atr.tasks.checks as checks
-class Update(schema.Strict):
- """Arguments for the task to update metadata from remote data sources."""
-
- asf_uid: str = schema.description("The ASF UID of the user triggering the
update")
- next_schedule_seconds: int = pydantic.Field(default=0, description="The
next scheduled time")
-
-
class UpdateError(Exception):
pass
-@checks.with_model(Update)
-async def update(args: Update) -> results.Results | None:
+@checks.with_model(args.Update)
+async def update(task_args: args.Update) -> results.Results | None:
"""Update metadata from remote data sources."""
- log.info(f"Starting metadata update for user {args.asf_uid}")
+ log.info(f"Starting metadata update for user {task_args.asf_uid}")
try:
added_count, updated_count = await apache.update_metadata()
@@ -49,7 +41,7 @@ async def update(args: Update) -> results.Results | None:
f"Metadata update completed successfully: added {added_count},
updated {updated_count}",
)
- await tasks.schedule_next(args.asf_uid, args.next_schedule_seconds,
tasks.metadata_update)
+ await tasks.schedule_next(task_args.asf_uid,
task_args.next_schedule_seconds, tasks.metadata_update)
return results.MetadataUpdate(
kind="metadata_update",
diff --git a/atr/tasks/quarantine.py b/atr/tasks/quarantine.py
index 9137f11d..d6441297 100644
--- a/atr/tasks/quarantine.py
+++ b/atr/tasks/quarantine.py
@@ -36,9 +36,9 @@ import atr.db as db
import atr.detection as detection
import atr.hashes as hashes
import atr.log as log
+import atr.models.args as args
import atr.models.results as results
import atr.models.safe as safe
-import atr.models.schema as schema
import atr.models.sql as sql
import atr.paths as paths
import atr.storage.writers.revision as revision
@@ -47,16 +47,6 @@ import atr.tasks.checks as checks
import atr.util as util
-class QuarantineArchiveEntry(schema.Strict):
- rel_path: str
- content_hash: str
-
-
-class QuarantineValidate(schema.Strict):
- quarantined_id: int
- archives: list[QuarantineArchiveEntry]
-
-
def backfill_archive_cache() -> list[tuple[str, safe.StatePath, float]]:
done_file = _backfill_done_file()
done_file_path = done_file.path
@@ -103,17 +93,17 @@ def backfill_archive_cache() -> list[tuple[str,
safe.StatePath, float]]:
return results_list
-@checks.with_model(QuarantineValidate)
-async def validate(args: QuarantineValidate) -> results.Results | None:
+@checks.with_model(args.QuarantineValidate)
+async def validate(task_args: args.QuarantineValidate) -> results.Results |
None:
async with db.session() as data:
- quarantined = await data.quarantined(id=args.quarantined_id,
_release=True).get()
+ quarantined = await data.quarantined(id=task_args.quarantined_id,
_release=True).get()
if quarantined is None:
- log.error(f"Quarantined row {args.quarantined_id} not found")
+ log.error(f"Quarantined row {task_args.quarantined_id} not found")
return None
if quarantined.status != sql.QuarantineStatus.PENDING:
- log.error(f"Quarantined row {args.quarantined_id} is not PENDING")
+ log.error(f"Quarantined row {task_args.quarantined_id} is not PENDING")
return None
release = quarantined.release
@@ -125,7 +115,7 @@ async def validate(args: QuarantineValidate) ->
results.Results | None:
await _mark_failed(quarantined, None, "Quarantine directory does not
exist")
return None
- file_entries, any_failed = await _run_safety_checks(args.archives,
quarantine_dir)
+ file_entries, any_failed = await _run_safety_checks(task_args.archives,
quarantine_dir)
if any_failed:
await _mark_failed(quarantined, file_entries)
@@ -133,7 +123,7 @@ async def validate(args: QuarantineValidate) ->
results.Results | None:
return None
try:
- await _extract_archives(args.archives, quarantine_dir, project_key,
version_key, file_entries)
+ await _extract_archives(task_args.archives, quarantine_dir,
project_key, version_key, file_entries)
except Exception as exc:
await _mark_failed(quarantined, file_entries, f"Archive extraction
failed: {exc}")
await aioshutil.rmtree(quarantine_dir)
@@ -225,7 +215,7 @@ def _extract_archive_to_dir(
async def _extract_archives(
- archives: list[QuarantineArchiveEntry],
+ archives: list[args.QuarantineArchiveEntry],
quarantine_dir: safe.StatePath,
project_key: safe.ProjectKey,
version_key: safe.VersionKey,
@@ -369,7 +359,7 @@ async def _promote(
async def _run_safety_checks(
- archives: list[QuarantineArchiveEntry], quarantine_dir: safe.StatePath
+ archives: list[args.QuarantineArchiveEntry], quarantine_dir: safe.StatePath
) -> tuple[list[sql.QuarantineFileEntryV1], bool]:
file_entries: list[sql.QuarantineFileEntryV1] = []
any_failed = False
diff --git a/atr/tasks/sbom.py b/atr/tasks/sbom.py
index 07df0793..85be0749 100644
--- a/atr/tasks/sbom.py
+++ b/atr/tasks/sbom.py
@@ -27,9 +27,9 @@ import aiofiles.os
import atr.archives as archives
import atr.config as config
import atr.log as log
+import atr.models.args as args
import atr.models.results as results
import atr.models.safe as safe
-import atr.models.schema as schema
import atr.models.sql as sql
import atr.paths as paths
import atr.sbom as sbom
@@ -40,21 +40,6 @@ import atr.util as util
_CONFIG: Final = config.get()
-class ConvertCycloneDX(schema.Strict):
- """Arguments for the task to generate a CycloneDX SBOM."""
-
- artifact_path: safe.StatePath = schema.description("Absolute path to the
artifact")
- output_path: safe.StatePath = schema.description("Absolute path where the
generated SBOM JSON should be written")
- revision: safe.RevisionNumber = schema.description("Revision number")
-
-
-class GenerateCycloneDX(schema.Strict):
- """Arguments for the task to generate a CycloneDX SBOM."""
-
- artifact_path: safe.StatePath = schema.description("Absolute path to the
artifact")
- output_path: safe.StatePath = schema.description("Absolute path where the
generated SBOM JSON should be written")
-
-
class SBOMConversionError(Exception):
"""Custom exception for SBOM conversion failures."""
@@ -85,20 +70,8 @@ class SBOMScoringError(Exception):
self.context = context if (context is not None) else {}
-class FileArgs(schema.Strict):
- project_key: safe.ProjectKey = schema.description("Project name")
- version_key: safe.VersionKey = schema.description("Version name")
- revision_number: safe.RevisionNumber = schema.description("Revision
number")
- file_path: safe.RelPath = schema.description("Relative path to the SBOM
file")
- asf_uid: str | None = None
-
-
-class ScoreArgs(FileArgs):
- previous_release_version: safe.VersionKey | None =
schema.description("Previous release version")
-
-
-@checks.with_model(FileArgs)
-async def augment(args: FileArgs) -> results.Results | None:
+@checks.with_model(args.FileArgs)
+async def augment(args: args.FileArgs) -> results.Results | None:
revision_str = str(args.revision_number)
path_str = str(args.file_path)
@@ -149,8 +122,8 @@ async def augment(args: FileArgs) -> results.Results | None:
)
-@checks.with_model(ConvertCycloneDX)
-async def convert_cyclonedx(args: ConvertCycloneDX) -> results.Results | None:
+@checks.with_model(args.ConvertCycloneDX)
+async def convert_cyclonedx(args: args.ConvertCycloneDX) -> results.Results |
None:
"""Generate a JSON CycloneDX SBOM from a given XML SBOM."""
try:
result_data = await _convert_cyclonedx_core(args.artifact_path,
args.output_path, args.revision)
@@ -166,8 +139,8 @@ async def convert_cyclonedx(args: ConvertCycloneDX) ->
results.Results | None:
raise
-@checks.with_model(GenerateCycloneDX)
-async def generate_cyclonedx(args: GenerateCycloneDX) -> results.Results |
None:
+@checks.with_model(args.GenerateCycloneDX)
+async def generate_cyclonedx(args: args.GenerateCycloneDX) -> results.Results
| None:
"""Generate a CycloneDX SBOM for the given artifact and write it to the
output path."""
try:
result_data = await _generate_cyclonedx_core(args.artifact_path,
args.output_path)
@@ -184,8 +157,8 @@ async def generate_cyclonedx(args: GenerateCycloneDX) ->
results.Results | None:
raise
-@checks.with_model(FileArgs)
-async def osv_scan(args: FileArgs) -> results.Results | None:
+@checks.with_model(args.FileArgs)
+async def osv_scan(args: args.FileArgs) -> results.Results | None:
revision_str = str(args.revision_number)
path_str = str(args.file_path)
@@ -251,8 +224,8 @@ async def osv_scan(args: FileArgs) -> results.Results |
None:
)
-@checks.with_model(FileArgs)
-async def score_qs(args: FileArgs) -> results.Results | None:
+@checks.with_model(args.FileArgs)
+async def score_qs(args: args.FileArgs) -> results.Results | None:
path_str = str(args.file_path)
base_dir = paths.get_unfinished_dir_for(args.project_key,
args.version_key, args.revision_number)
@@ -290,8 +263,8 @@ async def score_qs(args: FileArgs) -> results.Results |
None:
)
-@checks.with_model(ScoreArgs)
-async def score_tool(args: ScoreArgs) -> results.Results | None:
+@checks.with_model(args.ScoreArgs)
+async def score_tool(args: args.ScoreArgs) -> results.Results | None:
path_str = str(args.file_path)
base_dir = paths.get_unfinished_dir_for(args.project_key,
args.version_key, args.revision_number)
diff --git a/atr/tasks/svn.py b/atr/tasks/svn.py
index 7dbad904..37f9fcf2 100644
--- a/atr/tasks/svn.py
+++ b/atr/tasks/svn.py
@@ -22,9 +22,9 @@ import aiofiles.os
import aioshutil
import atr.log as log
+import atr.models.args as args
import atr.models.results as results
import atr.models.safe as safe
-import atr.models.schema as schema
import atr.models.sql as sql
import atr.storage as storage
import atr.tasks.checks as checks
@@ -32,17 +32,6 @@ import atr.tasks.checks as checks
_SVN_BASE_URL: Final[str] = "https://dist.apache.org/repos/dist"
-class SvnImport(schema.Strict):
- """Arguments for the task to import files from SVN."""
-
- svn_url: safe.RelPath
- revision: str
- target_subdirectory: str | None
- project_key: safe.ProjectKey
- version_key: safe.VersionKey
- asf_uid: str
-
-
class SvnImportError(Exception):
"""Custom exception for SVN import failures."""
@@ -51,12 +40,12 @@ class SvnImportError(Exception):
self.details = details or {}
-@checks.with_model(SvnImport)
-async def import_files(args: SvnImport) -> results.Results | None:
+@checks.with_model(args.SvnImport)
+async def import_files(task_args: args.SvnImport) -> results.Results | None:
"""Import files from SVN into a draft release candidate revision."""
# audit_guidance any file uploads are from known and managed repositories
so file size is not an issue
try:
- result_message = await _import_files_core(args)
+ result_message = await _import_files_core(task_args)
return results.SvnImportFiles(
kind="svn_import",
msg=result_message,
@@ -69,7 +58,7 @@ async def import_files(args: SvnImport) -> results.Results |
None:
raise
-async def _import_files_core(args: SvnImport) -> str:
+async def _import_files_core(args: args.SvnImport) -> str:
"""Core logic to perform the SVN export."""
project_str = str(args.project_key)
diff --git a/atr/tasks/vote.py b/atr/tasks/vote.py
index c4b93419..8d4bc6f1 100644
--- a/atr/tasks/vote.py
+++ b/atr/tasks/vote.py
@@ -17,43 +17,27 @@
import datetime
-import pydantic
-
import atr.db as db
import atr.db.interaction as interaction
import atr.log as log
import atr.mail as mail
+import atr.models.args as args
import atr.models.results as results
import atr.models.safe as safe
-import atr.models.schema as schema
import atr.storage as storage
import atr.tasks.checks as checks
import atr.util as util
-class Initiate(schema.Strict):
- """Arguments for the task to start a vote."""
-
- release_key: str = schema.description("The name of the release to vote on")
- email_to: pydantic.EmailStr = schema.description("The mailing list To
address")
- vote_duration: int = schema.description("Duration of the vote in hours")
- initiator_id: str = schema.description("ASF ID of the vote initiator")
- initiator_fullname: str = schema.description("Full name of the vote
initiator")
- subject: str = schema.description("Subject line for the vote email")
- body: str = schema.description("Body content for the vote email")
- email_cc: list[pydantic.EmailStr] = schema.factory(list)
- email_bcc: list[pydantic.EmailStr] = schema.factory(list)
-
-
class VoteInitiationError(Exception):
pass
-@checks.with_model(Initiate)
-async def initiate(args: Initiate) -> results.Results | None:
+@checks.with_model(args.Initiate)
+async def initiate(task_args: args.Initiate) -> results.Results | None:
"""Initiate a vote for a release."""
try:
- return await _initiate_core_logic(args)
+ return await _initiate_core_logic(task_args)
except VoteInitiationError as e:
log.error(f"Vote initiation failed: {e}")
@@ -63,34 +47,36 @@ async def initiate(args: Initiate) -> results.Results |
None:
raise
-async def _initiate_core_logic(args: Initiate) -> results.Results | None:
+async def _initiate_core_logic(task_args: args.Initiate) -> results.Results |
None:
"""Get arguments, create an email, and then send it to the recipient."""
log.info("Starting initiate_core")
- safe.ReleaseKey(args.release_key)
+ safe.ReleaseKey(task_args.release_key)
# Validate arguments
- all_addrs = [args.email_to, *args.email_cc, *args.email_bcc]
+ all_addrs = [task_args.email_to, *task_args.email_cc, *task_args.email_bcc]
for addr in all_addrs:
if not (addr.endswith("@apache.org") or addr.endswith(".apache.org")):
log.error(f"Invalid destination email address: {addr}")
raise VoteInitiationError(f"Invalid destination email address:
{addr}")
async with db.session() as data:
- release = await data.release(key=args.release_key, _project=True,
_committee=True).demand(
- VoteInitiationError(f"Release {args.release_key!s} not found")
+ release = await data.release(key=task_args.release_key, _project=True,
_committee=True).demand(
+ VoteInitiationError(f"Release {task_args.release_key!s} not found")
)
latest_revision_number = release.latest_revision_number
if latest_revision_number is None:
- raise VoteInitiationError(f"No revisions found for release
{args.release_key!s}")
+ raise VoteInitiationError(f"No revisions found for release
{task_args.release_key!s}")
ongoing_tasks = await interaction.tasks_ongoing(
release.safe_project_key, release.safe_version_key,
release.safe_latest_revision_number
)
if ongoing_tasks > 0:
- raise VoteInitiationError(f"Cannot start vote for
{args.release_key!s} as {ongoing_tasks} are not complete")
+ raise VoteInitiationError(
+ f"Cannot start vote for {task_args.release_key!s} as
{ongoing_tasks} are not complete"
+ )
# Calculate vote end date
- vote_duration_hours = args.vote_duration
+ vote_duration_hours = task_args.vote_duration
vote_start = datetime.datetime.now(datetime.UTC)
vote_end = vote_start + datetime.timedelta(hours=vote_duration_hours)
@@ -112,36 +98,36 @@ async def _initiate_core_logic(args: Initiate) ->
results.Results | None:
raise VoteInitiationError(error_msg)
# The subject and body have already been substituted by the route handler
- subject = args.subject
- body = args.body
+ subject = task_args.subject
+ body = task_args.body
- permitted_recipients = util.permitted_voting_recipients(args.initiator_id,
release.committee.key)
+ permitted_recipients =
util.permitted_voting_recipients(task_args.initiator_id, release.committee.key)
for addr in all_addrs:
if addr not in permitted_recipients:
log.error(f"Invalid mailing list choice: {addr} not in
{permitted_recipients}")
raise VoteInitiationError("Invalid mailing list choice")
# Create mail message
- log.info(f"Creating mail message for {args.email_to}")
+ log.info(f"Creating mail message for {task_args.email_to}")
message = mail.Message(
- email_sender=f"{args.initiator_id}@apache.org",
- email_to=args.email_to,
+ email_sender=f"{task_args.initiator_id}@apache.org",
+ email_to=task_args.email_to,
subject=subject,
body=body,
- email_cc=args.email_cc,
- email_bcc=args.email_bcc,
+ email_cc=task_args.email_cc,
+ email_bcc=task_args.email_bcc,
)
- async with storage.write(args.initiator_id) as write:
+ async with storage.write(task_args.initiator_id) as write:
wafc = write.as_foundation_committer()
mid, mail_errors = await wafc.mail.send(message,
mail.MailFooterCategory.USER)
# Original success message structure
- all_destinations = [args.email_to, *args.email_cc, *args.email_bcc]
+ all_destinations = [task_args.email_to, *task_args.email_cc,
*task_args.email_bcc]
result = results.VoteInitiate(
kind="vote_initiate",
message="Vote announcement email sent successfully",
- email_to=args.email_to,
+ email_to=task_args.email_to,
vote_end=vote_end_str,
subject=subject,
mid=mid,
@@ -149,7 +135,7 @@ async def _initiate_core_logic(args: Initiate) ->
results.Results | None:
)
if mail_errors:
- log.warning(f"Start vote for {args.release_key}: sending to
{all_destinations} gave errors: {mail_errors}")
+ log.warning(f"Start vote for {task_args.release_key}: sending to
{all_destinations} gave errors: {mail_errors}")
else:
log.info(f"Vote email sent successfully to {all_destinations}")
return result
diff --git a/tests/unit/test_quarantine_task.py
b/tests/unit/test_quarantine_task.py
index 2eaf4efa..095a2cac 100644
--- a/tests/unit/test_quarantine_task.py
+++ b/tests/unit/test_quarantine_task.py
@@ -24,6 +24,7 @@ import unittest.mock as mock
import pytest
+import atr.models.args as args
import atr.models.safe as safe
import atr.models.sql as sql
import atr.storage as storage
@@ -97,7 +98,7 @@ async def
test_extract_archives_discards_staging_dir_on_enotempty_collision(
entries = [sql.QuarantineFileEntryV1(rel_path=archive_rel_path,
size_bytes=7, content_hash="blake3:ghi", errors=[])]
await quarantine._extract_archives(
- [quarantine.QuarantineArchiveEntry(rel_path=archive_rel_path,
content_hash="blake3:ghi")],
+ [args.QuarantineArchiveEntry(rel_path=archive_rel_path,
content_hash="blake3:ghi")],
quarantine_dir,
safe.ProjectKey("proj"),
safe.VersionKey("1.0"),
@@ -143,7 +144,7 @@ async def
test_extract_archives_discards_staging_dir_when_other_worker_wins(
entries = [sql.QuarantineFileEntryV1(rel_path=archive_rel_path,
size_bytes=7, content_hash="blake3:def", errors=[])]
await quarantine._extract_archives(
- [quarantine.QuarantineArchiveEntry(rel_path=archive_rel_path,
content_hash="blake3:def")],
+ [args.QuarantineArchiveEntry(rel_path=archive_rel_path,
content_hash="blake3:def")],
quarantine_dir,
safe.ProjectKey("proj"),
safe.VersionKey("1.0"),
@@ -180,7 +181,7 @@ async def
test_extract_archives_propagates_exarch_error_to_file_entry(
with pytest.raises(RuntimeError, match="unsafe zip detected"):
await quarantine._extract_archives(
- [quarantine.QuarantineArchiveEntry(rel_path=archive_rel_path,
content_hash="blake3:bad")],
+ [args.QuarantineArchiveEntry(rel_path=archive_rel_path,
content_hash="blake3:bad")],
quarantine_dir,
safe.ProjectKey("proj"),
safe.VersionKey("1.0"),
@@ -217,7 +218,7 @@ async def
test_extract_archives_stages_in_temporary_then_promotes(
entries = [sql.QuarantineFileEntryV1(rel_path=archive_rel_path,
size_bytes=7, content_hash="blake3:abc", errors=[])]
await quarantine._extract_archives(
- [quarantine.QuarantineArchiveEntry(rel_path=archive_rel_path,
content_hash="blake3:abc")],
+ [args.QuarantineArchiveEntry(rel_path=archive_rel_path,
content_hash="blake3:abc")],
quarantine_dir,
safe.ProjectKey("proj"),
safe.VersionKey("1.0"),
@@ -337,7 +338,7 @@ async def test_run_safety_checks_safe_archive(tmp_path:
pathlib.Path):
archive_path = tmp_path / "safe.tar.gz"
_create_safe_tar_gz(archive_path)
- archives = [quarantine.QuarantineArchiveEntry(rel_path="safe.tar.gz",
content_hash="abc123")]
+ archives = [args.QuarantineArchiveEntry(rel_path="safe.tar.gz",
content_hash="abc123")]
entries, any_failed = await quarantine._run_safety_checks(archives,
tmp_path)
assert not any_failed
@@ -352,7 +353,7 @@ async def test_run_safety_checks_unsafe_archive(tmp_path:
pathlib.Path):
archive_path = tmp_path / "unsafe.tar.gz"
_create_traversal_tar_gz(archive_path)
- archives = [quarantine.QuarantineArchiveEntry(rel_path="unsafe.tar.gz",
content_hash="def456")]
+ archives = [args.QuarantineArchiveEntry(rel_path="unsafe.tar.gz",
content_hash="def456")]
entries, any_failed = await quarantine._run_safety_checks(archives,
tmp_path)
assert any_failed
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]