This is an automated email from the ASF dual-hosted git repository.
arm pushed a commit to branch arm
in repository https://gitbox.apache.org/repos/asf/tooling-trusted-releases.git
The following commit(s) were added to refs/heads/arm by this push:
new a72a347a #643 - Add safe.RevisionNumber and utilise unsafe.UnsafeStr
for remaining str types.
a72a347a is described below
commit a72a347ab5bdba54017e9d3aca7ec618bd885381
Author: Alastair McFarlane <[email protected]>
AuthorDate: Mon Mar 9 17:55:09 2026 +0000
#643 - Add safe.RevisionNumber and utilise unsafe.UnsafeStr for remaining
str types.
---
atr/admin/__init__.py | 13 +++++++---
atr/api/__init__.py | 47 +++++++++++++++++++----------------
atr/attestable.py | 41 ++++++++++++++++--------------
atr/blueprints/api.py | 2 +-
atr/blueprints/common.py | 51 ++++++++++++++++++++++++--------------
atr/blueprints/get.py | 2 +-
atr/blueprints/post.py | 2 +-
atr/construct.py | 8 +++---
atr/db/interaction.py | 33 ++++++++++++++----------
atr/get/announce.py | 2 +-
atr/get/checks.py | 2 +-
atr/get/committees.py | 11 ++++----
atr/get/keys.py | 13 +++++-----
atr/get/manual.py | 6 ++---
atr/get/projects.py | 21 ++++++++--------
atr/get/ref.py | 9 ++++---
atr/get/test.py | 7 +++---
atr/get/voting.py | 6 ++---
atr/merge.py | 8 +++---
atr/models/api.py | 8 +++---
atr/models/safe.py | 25 ++++++++++++++++---
atr/models/sql.py | 10 ++++++++
atr/models/unsafe.py | 5 +++-
atr/paths.py | 6 ++---
atr/post/announce.py | 6 ++---
atr/post/draft.py | 12 +++------
atr/post/keys.py | 9 ++++---
atr/post/manual.py | 2 +-
atr/post/projects.py | 11 ++++----
atr/post/upload.py | 9 ++++---
atr/post/voting.py | 6 ++---
atr/shared/ignores.py | 3 ++-
atr/ssh.py | 4 ++-
atr/storage/readers/releases.py | 5 ++--
atr/storage/writers/announce.py | 12 ++++-----
atr/storage/writers/checks.py | 4 +--
atr/storage/writers/release.py | 16 +++++++-----
atr/storage/writers/revision.py | 18 ++++++++------
atr/storage/writers/vote.py | 6 ++---
atr/tasks/__init__.py | 22 ++++++++--------
atr/tasks/checks/__init__.py | 18 +++++++-------
atr/tasks/checks/compare.py | 2 +-
atr/tasks/quarantine.py | 2 +-
atr/tasks/vote.py | 2 +-
atr/web.py | 7 +++---
atr/worker.py | 5 ++--
tests/unit/test_create_revision.py | 15 +++++++++--
47 files changed, 313 insertions(+), 221 deletions(-)
diff --git a/atr/admin/__init__.py b/atr/admin/__init__.py
index c6763110..3fe6e7dc 100644
--- a/atr/admin/__init__.py
+++ b/atr/admin/__init__.py
@@ -599,7 +599,8 @@ async def ongoing_tasks_get(
) -> web.QuartResponse:
project = safe.ProjectName(project_name)
version = safe.VersionName(version_name)
- return await _ongoing_tasks(session, project, version, revision)
+ revision_number = safe.RevisionNumber(revision)
+ return await _ongoing_tasks(session, project, version, revision_number)
@admin.post("/ongoing-tasks/<project_name>/<version_name>/<revision>")
@@ -608,7 +609,8 @@ async def ongoing_tasks_post(
) -> web.QuartResponse:
project = safe.ProjectName(project_name)
version = safe.VersionName(version_name)
- return await _ongoing_tasks(session, project, version, revision)
+ revision_number = safe.RevisionNumber(revision)
+ return await _ongoing_tasks(session, project, version, revision_number)
@admin.get("/performance")
@@ -1172,13 +1174,16 @@ async def
_get_filesystem_dirs_unfinished(filesystem_dirs: list[str]) -> None:
async def _ongoing_tasks(
- session: web.Committer, project_name: safe.ProjectName, version_name:
safe.VersionName, revision: str
+ session: web.Committer,
+ project_name: safe.ProjectName,
+ version_name: safe.VersionName,
+ revision: safe.RevisionNumber,
) -> web.QuartResponse:
try:
ongoing = await interaction.tasks_ongoing(project_name, version_name,
revision)
return web.TextResponse(str(ongoing))
except Exception:
- log.exception(f"Error fetching ongoing task count for {project_name!s}
{version_name!s} rev {revision}:")
+ log.exception(f"Error fetching ongoing task count for {project_name!s}
{version_name!s} rev {revision!s}:")
return web.TextResponse("")
diff --git a/atr/api/__init__.py b/atr/api/__init__.py
index 97c5b750..6ff0c505 100644
--- a/atr/api/__init__.py
+++ b/atr/api/__init__.py
@@ -39,6 +39,7 @@ import atr.log as log
import atr.models as models
import atr.models.safe as safe
import atr.models.sql as sql
+import atr.models.unsafe as unsafe
import atr.paths as paths
import atr.principal as principal
import atr.storage as storage
@@ -100,7 +101,7 @@ async def checks_list_revision(
_checks_list: Literal["checks/list"],
project_name: safe.ProjectName,
version_name: safe.VersionName,
- revision: str,
+ revision: safe.RevisionNumber,
) -> DictResponse:
"""
URL: GET /checks/list/<project_name>/<version_name>/<revision>
@@ -121,7 +122,7 @@ async def checks_list_revision(
exceptions.NotFound(f"Release '{release_name}' does not exist")
)
- revision_result = await data.revision(release_name=release_name,
number=revision).get()
+ revision_result = await data.revision(release_name=release_name,
number=str(revision)).get()
if revision_result is None:
raise exceptions.NotFound(f"Revision '{revision}' does not exist
for release '{release_name}'")
@@ -130,7 +131,7 @@ async def checks_list_revision(
return models.api.ChecksListResults(
endpoint="/checks/list",
checks=check_results,
- checks_revision=revision,
+ checks_revision=str(revision),
current_phase=release_result.phase,
).model_dump(mode="json"), 200
@@ -141,7 +142,7 @@ async def checks_ongoing(
_checks_ongoing: Literal["checks/ongoing"],
project_name: safe.ProjectName,
version_name: safe.VersionName,
- revision: str | None = None,
+ revision: safe.RevisionNumber | None = None,
) -> DictResponse:
"""
URL: GET /checks/ongoing/<project_name>/<version_name>[/<revision>]
@@ -179,7 +180,7 @@ async def checks_ongoing(
@quart_schema.validate_response(models.api.CommitteeGetResults, 200)
async def committee_get(
_committee_get: Literal["committee/get"],
- name: str,
+ name: unsafe.UnsafeStr,
) -> DictResponse:
"""
URL: GET /committee/get/<name>
@@ -192,7 +193,9 @@ async def committee_get(
"simple-example".
"""
async with db.session() as data:
- committee = await
data.committee(name=name).demand(exceptions.NotFound(f"Committee '{name}' was
not found"))
+ committee = await data.committee(name=str(name)).demand(
+ exceptions.NotFound(f"Committee '{name!s}' was not found")
+ )
return models.api.CommitteeGetResults(
endpoint="/committee/get",
committee=committee,
@@ -203,7 +206,7 @@ async def committee_get(
@quart_schema.validate_response(models.api.CommitteeKeysResults, 200)
async def committee_keys(
_committee_keys: Literal["committee/keys"],
- name: str,
+ name: unsafe.UnsafeStr,
) -> DictResponse:
"""
URL: GET /committee/keys/<name>
@@ -216,8 +219,8 @@ async def committee_keys(
"simple-example".
"""
async with db.session() as data:
- committee = await data.committee(name=name,
_public_signing_keys=True).demand(
- exceptions.NotFound(f"Committee '{name}' was not found")
+ committee = await data.committee(name=str(name),
_public_signing_keys=True).demand(
+ exceptions.NotFound(f"Committee '{name!s}' was not found")
)
return models.api.CommitteeKeysResults(
endpoint="/committee/keys",
@@ -229,7 +232,7 @@ async def committee_keys(
@quart_schema.validate_response(models.api.CommitteeProjectsResults, 200)
async def committee_projects(
_committee_projects: Literal["committee/projects"],
- name: str,
+ name: unsafe.UnsafeStr,
) -> DictResponse:
"""
URL: GET /committee/projects/<name>
@@ -242,8 +245,8 @@ async def committee_projects(
"simple-example".
"""
async with db.session() as data:
- committee = await data.committee(name=name, _projects=True).demand(
- exceptions.NotFound(f"Committee '{name}' was not found")
+ committee = await data.committee(name=str(name),
_projects=True).demand(
+ exceptions.NotFound(f"Committee '{name!s}' was not found")
)
return models.api.CommitteeProjectsResults(
endpoint="/committee/projects",
@@ -583,7 +586,7 @@ async def key_delete(
@quart_schema.validate_response(models.api.KeyGetResults, 200)
async def key_get(
_key_get: Literal["key/get"],
- fingerprint: str,
+ fingerprint: unsafe.UnsafeStr,
) -> DictResponse:
"""
URL: GET /key/get/<fingerprint>
@@ -593,8 +596,8 @@ async def key_get(
All public OpenPGP keys stored within the database are accessible.
"""
async with db.session() as data:
- key = await
data.public_signing_key(fingerprint=fingerprint.lower()).demand(
- exceptions.NotFound(f"Key '{fingerprint}' not found")
+ key = await
data.public_signing_key(fingerprint=str(fingerprint).lower()).demand(
+ exceptions.NotFound(f"Key '{fingerprint!s}' not found")
)
return models.api.KeyGetResults(
endpoint="/key/get",
@@ -669,7 +672,7 @@ async def keys_upload(
@quart_schema.validate_response(models.api.KeysUserResults, 200)
async def keys_user(
_keys_user: Literal["keys/user"],
- asf_uid: str,
+ asf_uid: unsafe.UnsafeStr,
) -> DictResponse:
"""
URL: GET /keys/user/<asf_uid>
@@ -677,7 +680,7 @@ async def keys_user(
List public OpenPGP keys by the ASF UID of a user.
"""
async with db.session() as data:
- keys = await data.public_signing_key(apache_uid=asf_uid).all()
+ keys = await data.public_signing_key(apache_uid=str(asf_uid)).all()
return models.api.KeysUserResults(
endpoint="/keys/user",
keys=keys,
@@ -1068,7 +1071,7 @@ async def release_paths(
_release_paths: Literal["release/paths"],
project_name: safe.ProjectName,
version_name: safe.VersionName,
- revision: str | None = None,
+ revision: safe.RevisionNumber | None = None,
) -> DictResponse:
"""
URL: GET /release/paths/<project_name>/<version_name>[/<revision>]
@@ -1081,8 +1084,8 @@ async def release_paths(
if revision is None:
dir_path = paths.release_directory(release)
else:
- await data.revision(release_name=release_name,
number=revision).demand(exceptions.NotFound())
- dir_path = paths.release_directory_version(release) / revision
+ await data.revision(release_name=release_name,
number=str(revision)).demand(exceptions.NotFound())
+ dir_path = paths.release_directory_version(release) / str(revision)
if not (await aiofiles.os.path.isdir(dir_path)):
raise exceptions.NotFound("Files not found")
files: list[str] = [str(path) for path in [p async for p in
util.paths_recursive(dir_path)]]
@@ -1322,7 +1325,7 @@ async def ssh_key_delete(
@rate_limiter.rate_limit(10, datetime.timedelta(hours=1))
async def ssh_keys_list(
_ssh_keys_list: Literal["ssh-keys/list"],
- asf_uid: str,
+ asf_uid: unsafe.UnsafeStr,
query_args: models.api.SshKeysListQuery,
) -> DictResponse:
"""
@@ -1335,7 +1338,7 @@ async def ssh_keys_list(
async with db.session() as data:
statement = (
sqlmodel.select(sql.SSHKey)
- .where(sql.SSHKey.asf_uid == asf_uid)
+ .where(sql.SSHKey.asf_uid == str(asf_uid))
.limit(query_args.limit)
.offset(query_args.offset)
.order_by(via(sql.SSHKey.fingerprint).asc())
diff --git a/atr/attestable.py b/atr/attestable.py
index f5e8592e..a020776d 100644
--- a/atr/attestable.py
+++ b/atr/attestable.py
@@ -36,31 +36,34 @@ if TYPE_CHECKING:
def attestable_checks_path(
- project_name: safe.ProjectName, version_name: safe.VersionName,
revision_number: str
+ project_name: safe.ProjectName, version_name: safe.VersionName,
revision_number: safe.RevisionNumber
) -> pathlib.Path:
- return paths.get_attestable_dir() / str(project_name) / str(version_name)
/ f"{revision_number}.checks.json"
+ return paths.get_attestable_dir() / str(project_name) / str(version_name)
/ f"{revision_number!s}.checks.json"
def attestable_path(
- project_name: safe.ProjectName, version_name: safe.VersionName,
revision_number: str
+ project_name: safe.ProjectName, version_name: safe.VersionName,
revision_number: safe.RevisionNumber
) -> pathlib.Path:
- return paths.get_attestable_dir() / str(project_name) / str(version_name)
/ f"{revision_number}.json"
+ return paths.get_attestable_dir() / str(project_name) / str(version_name)
/ f"{revision_number!s}.json"
def attestable_paths_path(
- project_name: safe.ProjectName, version_name: safe.VersionName,
revision_number: str
+ project_name: safe.ProjectName, version_name: safe.VersionName,
revision_number: safe.RevisionNumber
) -> pathlib.Path:
- return paths.get_attestable_dir() / str(project_name) / str(version_name)
/ f"{revision_number}.paths.json"
+ return paths.get_attestable_dir() / str(project_name) / str(version_name)
/ f"{revision_number!s}.paths.json"
def github_tp_payload_path(
- project_name: safe.ProjectName, version_name: safe.VersionName,
revision_number: str
+ project_name: safe.ProjectName, version_name: safe.VersionName,
revision_number: safe.RevisionNumber
) -> pathlib.Path:
- return paths.get_attestable_dir() / str(project_name) / str(version_name)
/ f"{revision_number}.github-tp.json"
+ return paths.get_attestable_dir() / str(project_name) / str(version_name)
/ f"{revision_number!s}.github-tp.json"
async def github_tp_payload_write(
- project_name: safe.ProjectName, version_name: safe.VersionName,
revision_number: str, github_payload: dict[str, Any]
+ project_name: safe.ProjectName,
+ version_name: safe.VersionName,
+ revision_number: safe.RevisionNumber,
+ github_payload: dict[str, Any],
) -> None:
payload_path = github_tp_payload_path(project_name, version_name,
revision_number)
await util.atomic_write_file(payload_path, json.dumps(github_payload,
indent=2))
@@ -69,7 +72,7 @@ async def github_tp_payload_write(
async def load(
project_name: safe.ProjectName,
version_name: safe.VersionName,
- revision_number: str,
+ revision_number: safe.RevisionNumber,
) -> models.AttestableV1 | None:
file_path = attestable_path(project_name, version_name, revision_number)
if not await aiofiles.os.path.isfile(file_path):
@@ -86,7 +89,7 @@ async def load(
async def load_checks(
project_name: safe.ProjectName,
version_name: safe.VersionName,
- revision_number: str,
+ revision_number: safe.RevisionNumber,
) -> dict[str, dict[str, str]]:
file_path = attestable_checks_path(project_name, version_name,
revision_number)
# TODO: Once we're sure everyone is on V2, we should be strict about
failures here,
@@ -108,7 +111,7 @@ async def load_checks(
async def load_paths(
project_name: safe.ProjectName,
version_name: safe.VersionName,
- revision_number: str,
+ revision_number: safe.RevisionNumber,
) -> dict[str, str] | None:
file_path = attestable_paths_path(project_name, version_name,
revision_number)
if await aiofiles.os.path.isfile(file_path):
@@ -174,7 +177,7 @@ async def paths_to_hashes_and_sizes(directory:
pathlib.Path) -> tuple[dict[str,
async def write_checks_data(
project_name: safe.ProjectName,
version_name: safe.VersionName,
- revision_number: str,
+ revision_number: safe.RevisionNumber,
rel_path: str,
checks: dict[str, str],
) -> None:
@@ -198,7 +201,7 @@ async def write_checks_data(
async def write_files_data(
project_name: safe.ProjectName,
version_name: safe.VersionName,
- revision_number: str,
+ revision_number: safe.RevisionNumber,
release_policy: dict[str, Any] | None,
uploader_uid: str,
previous: models.AttestableV1 | None,
@@ -222,7 +225,7 @@ def _compute_hashes_with_attribution( # noqa: C901
path_to_size: dict[str, int],
previous: models.AttestableV1 | None,
uploader_uid: str,
- revision_number: str,
+ revision_number: safe.RevisionNumber,
) -> dict[str, models.HashEntry]:
previous_hash_to_paths: dict[str, set[str]] = {}
if previous is not None:
@@ -243,7 +246,7 @@ def _compute_hashes_with_attribution( # noqa: C901
if hash_ref not in new_hashes:
new_hashes[hash_ref] = models.HashEntry(
size=file_size,
- uploaders=[(uploader_uid, revision_number)],
+ uploaders=[(uploader_uid, str(revision_number))],
basenames=sorted(current_basenames),
)
continue
@@ -256,8 +259,8 @@ def _compute_hashes_with_attribution( # noqa: C901
if len(current_paths) > len(previous_paths):
existing_entries = set(new_hashes[hash_ref].uploaders)
- if (uploader_uid, revision_number) not in existing_entries:
- new_hashes[hash_ref].uploaders.append((uploader_uid,
revision_number))
+ if (uploader_uid, str(revision_number)) not in existing_entries:
+ new_hashes[hash_ref].uploaders.append((uploader_uid,
str(revision_number)))
return new_hashes
@@ -265,7 +268,7 @@ def _compute_hashes_with_attribution( # noqa: C901
def _generate_files_data(
path_to_hash: dict[str, str],
path_to_size: dict[str, int],
- revision_number: str,
+ revision_number: safe.RevisionNumber,
release_policy: dict[str, Any] | None,
uploader_uid: str,
previous: models.AttestableV1 | None,
diff --git a/atr/blueprints/api.py b/atr/blueprints/api.py
index 5df652e7..f04428cb 100644
--- a/atr/blueprints/api.py
+++ b/atr/blueprints/api.py
@@ -73,7 +73,7 @@ def typed(func: Callable[..., Any]) -> Callable[..., Any]:
query_safe_params = common.safe_params_for_type(query_param[1]) if
query_param is not None else []
async def wrapper(*_args: Any, **kwargs: Any) -> Any:
- await common.run_validators(kwargs, validated_params)
+ await common.validate_params(kwargs, validated_params)
kwargs.update(literal_params)
if body_param is not None:
diff --git a/atr/blueprints/common.py b/atr/blueprints/common.py
index 670d52a2..b9c2d219 100644
--- a/atr/blueprints/common.py
+++ b/atr/blueprints/common.py
@@ -38,7 +38,7 @@ QUART_CONVERTERS: dict[Any, str] = {
unsafe.Path: "path",
}
-VALIDATED_TYPES: set[Any] = {safe.ProjectName, safe.VersionName}
+VALIDATED_TYPES: set[Any] = {safe.ProjectName, safe.RevisionNumber,
safe.VersionName, unsafe.UnsafeStr}
async def authenticate() -> web.Committer:
@@ -101,12 +101,7 @@ def build_path(
form_param = (param_name, hint)
continue
- segment = _param_to_segment(param_name, hint, func.__name__)
- segments.append(segment)
- if hint in VALIDATED_TYPES:
- validated_params.append((param_name, hint))
- elif get_origin(hint) is Literal:
- literal_params[param_name] = str(get_args(hint)[0])
+ _classify_url_param(param_name, hint, func.__name__, segments,
validated_params, literal_params)
path = "/" + "/".join(segments)
return path, validated_params, literal_params, form_param, public
@@ -165,18 +160,12 @@ def build_api_path(
inner, is_optional = _unwrap_optional(hint)
if is_optional:
- segment = _param_to_segment(param_name, inner, func.__name__)
- segments.append(segment)
+ segments.append(_param_to_segment(param_name, inner,
func.__name__))
optional_params.append(param_name)
# Note - this means that safe types which are optional will not
get validated - no current use case for this
continue
- segment = _param_to_segment(param_name, hint, func.__name__)
- segments.append(segment)
- if hint in VALIDATED_TYPES:
- validated_params.append((param_name, hint))
- elif get_origin(hint) is Literal:
- literal_params[param_name] = str(get_args(hint)[0])
+ _classify_url_param(param_name, hint, func.__name__, segments,
validated_params, literal_params)
path = "/" + "/".join(segments)
return path, validated_params, literal_params, body_param, query_param,
optional_params
@@ -196,9 +185,9 @@ def safe_params_for_type(cls: type) -> list[tuple[str,
type]]:
return [(name, hint) for name, hint in hints.items() if hint in
VALIDATED_TYPES]
-async def run_validators(kwargs: dict[str, Any], validated_params:
list[tuple[str, type]]) -> None:
+async def validate_params(kwargs: dict[str, Any], known_params:
list[tuple[str, type]]) -> None:
"""Validate URL parameters in order, using the type-specific validators."""
- for param_name, param_type in validated_params:
+ for param_name, param_type in known_params:
raw = kwargs[param_name]
if param_type is safe.ProjectName:
try:
@@ -210,6 +199,13 @@ async def run_validators(kwargs: dict[str, Any],
validated_params: list[tuple[st
kwargs[param_name] = safe.VersionName(raw)
except ValueError:
raise base.ASFQuartException(f"Version name {param_name!r} is
invalid. ")
+ elif param_type is safe.RevisionNumber:
+ try:
+ kwargs[param_name] = safe.RevisionNumber(raw)
+ except ValueError:
+ raise base.ASFQuartException(f"Revision number {param_name!r}
is invalid. ")
+ elif param_type is unsafe.UnsafeStr:
+ kwargs[param_name] = unsafe.UnsafeStr(raw)
async def validate_safe_fields(
@@ -227,12 +223,31 @@ async def validate_safe_fields(
value = getattr(instance, name, None)
if value is not None:
temp[name] = str(value)
- await run_validators(temp, [(n, t) for n, t in safe_params if n in temp])
+ await validate_params(temp, [(n, t) for n, t in safe_params if n in temp])
for name, _ in safe_params:
if name in temp:
setattr(instance, name, temp[name])
+def _classify_url_param(
+ param_name: str,
+ hint: Any,
+ func_name: str,
+ segments: list[str],
+ validated_params: list[tuple[str, type]],
+ literal_params: dict[str, str],
+) -> None:
+ """Build a URL segment for a parameter and classify it as validated or
literal."""
+ segment = _param_to_segment(param_name, hint, func_name)
+ segments.append(segment)
+ if hint in VALIDATED_TYPES:
+ validated_params.append((param_name, hint))
+ elif get_origin(hint) is Literal:
+ literal_params[param_name] = str(get_args(hint)[0])
+ elif hint is str:
+ raise TypeError(f"Parameter {param_name!r} in {func_name} is unguarded
str")
+
+
def _is_body_type(hint: Any) -> bool:
"""Check if a type hint is a pydantic BaseModel subclass (but not a
Form)."""
if not isinstance(hint, type):
diff --git a/atr/blueprints/get.py b/atr/blueprints/get.py
index ea48245a..05e597b5 100644
--- a/atr/blueprints/get.py
+++ b/atr/blueprints/get.py
@@ -65,7 +65,7 @@ def typed(func: Callable[..., Any]) -> web.RouteFunction[Any]:
async def wrapper(*_args: Any, **kwargs: Any) -> Any:
enhanced_session = await common.authenticate_public() if public else
await common.authenticate()
- await common.run_validators(kwargs, validated_params)
+ await common.validate_params(kwargs, validated_params)
kwargs.update(literal_params)
start_time_ns = time.perf_counter_ns()
diff --git a/atr/blueprints/post.py b/atr/blueprints/post.py
index 77fdd44a..6c585507 100644
--- a/atr/blueprints/post.py
+++ b/atr/blueprints/post.py
@@ -70,7 +70,7 @@ def typed(func: Callable[..., Any]) -> web.RouteFunction[Any]:
async def wrapper(*_args: Any, **kwargs: Any) -> Any:
enhanced_session = await common.authenticate_public() if public else
await common.authenticate()
- await common.run_validators(kwargs, validated_params)
+ await common.validate_params(kwargs, validated_params)
kwargs.update(literal_params)
if check_access and (enhanced_session is not None) and
(project_name_var is not None):
diff --git a/atr/construct.py b/atr/construct.py
index f4ca7326..32a18770 100644
--- a/atr/construct.py
+++ b/atr/construct.py
@@ -56,7 +56,7 @@ class AnnounceReleaseOptions:
fullname: str
project_name: safe.ProjectName
version_name: safe.VersionName
- revision_number: str
+ revision_number: safe.RevisionNumber
@dataclasses.dataclass
@@ -65,7 +65,7 @@ class StartVoteOptions:
fullname: str
project_name: safe.ProjectName
version_name: safe.VersionName
- revision_number: str
+ revision_number: safe.RevisionNumber
vote_duration: int
@@ -102,7 +102,7 @@ async def announce_release_subject_and_body(
raise RuntimeError(f"Release {options.project_name}
{options.version_name} has no committee")
committee = release.committee
- revision = await data.revision(release_name=release.name,
number=options.revision_number).get()
+ revision = await data.revision(release_name=release.name,
number=str(options.revision_number)).get()
revision_number = revision.number if revision else ""
revision_tag = revision.tag if (revision and revision.tag) else ""
@@ -206,7 +206,7 @@ async def start_vote_subject_and_body(subject: str, body:
str, options: StartVot
raise RuntimeError(f"Release {options.project_name}
{options.version_name} has no committee")
committee = release.committee
- revision = await data.revision(release_name=release.name,
number=options.revision_number).get()
+ revision = await data.revision(release_name=release.name,
number=str(options.revision_number)).get()
revision_number = revision.number if revision else ""
revision_tag = revision.tag if (revision and revision.tag) else ""
diff --git a/atr/db/interaction.py b/atr/db/interaction.py
index e921f208..35f86a8b 100644
--- a/atr/db/interaction.py
+++ b/atr/db/interaction.py
@@ -164,13 +164,13 @@ async def candidates(project: sql.Project) ->
list[sql.Release]:
async def checks_for(
release: sql.Release,
- revision: str | None = None,
+ revision: safe.RevisionNumber | None = None,
rel_path: str | None = None,
caller_data: db.Session | None = None,
) -> list[sql.CheckResult]:
"""Get the check results for a release, optionally for a specific revision
and/or file path."""
if revision is None:
- revision = release.unwrap_revision_number
+ revision = release.safe_latest_revision_number
file_path_checks = await attestable.load_checks(release.safe_project_name,
release.safe_version_name, revision)
if file_path_checks:
if rel_path is not None:
@@ -194,7 +194,10 @@ async def checks_for(
async def count_checks_for_revision_by_status(
- status: sql.CheckResultStatus, release: sql.Release, revision_number: str,
caller_data: db.Session | None = None
+ status: sql.CheckResultStatus,
+ release: sql.Release,
+ revision_number: safe.RevisionNumber,
+ caller_data: db.Session | None = None,
):
file_path_checks = await attestable.load_checks(
release.safe_project_name, release.safe_version_name, revision_number
@@ -228,14 +231,18 @@ async def full_releases(project: sql.Project) ->
list[sql.Release]:
return await releases_by_phase(project, sql.ReleasePhase.RELEASE)
-async def has_blocker_checks(release: sql.Release, revision_number: str,
caller_data: db.Session | None = None) -> bool:
+async def has_blocker_checks(
+ release: sql.Release, revision_number: safe.RevisionNumber, caller_data:
db.Session | None = None
+) -> bool:
count = await count_checks_for_revision_by_status(
sql.CheckResultStatus.BLOCKER, release, revision_number, caller_data
)
return count > 0
-async def has_failing_checks(release: sql.Release, revision_number: str,
caller_data: db.Session | None = None) -> bool:
+async def has_failing_checks(
+ release: sql.Release, revision_number: safe.RevisionNumber, caller_data:
db.Session | None = None
+) -> bool:
count = await count_checks_for_revision_by_status(
sql.CheckResultStatus.FAILURE, release, revision_number, caller_data
)
@@ -244,7 +251,7 @@ async def has_failing_checks(release: sql.Release,
revision_number: str, caller_
async def latest_info(
project_name: safe.ProjectName, version_name: safe.VersionName
-) -> tuple[str, str, datetime.datetime] | None:
+) -> tuple[safe.RevisionNumber, str, datetime.datetime] | None:
"""Get the name, editor, and timestamp of the latest revision."""
release_name = sql.release_name(project_name, version_name)
async with db.session() as data:
@@ -258,7 +265,7 @@ async def latest_info(
revision = await data.revision(release_name=str(release_name),
number=release.latest_revision_number).get()
if not revision:
return None
- return revision.number, revision.asfuid, revision.created
+ return revision.safe_number, revision.asfuid, revision.created
async def latest_revision(release: sql.Release, caller_data: db.Session | None
= None) -> sql.Revision | None:
@@ -298,7 +305,7 @@ async def release_ready_for_vote( # noqa: C901
session: web.Committer,
project_name: safe.ProjectName,
version_name: safe.VersionName,
- revision: str,
+ revision: safe.RevisionNumber,
data: db.Session,
manual_vote: bool = False,
) -> tuple[sql.Release, sql.Committee] | str:
@@ -315,7 +322,7 @@ async def release_ready_for_vote( # noqa: C901
if selected_revision_number is None:
return "No revision found for this release"
- if selected_revision_number != revision:
+ if release.safe_latest_revision_number != revision:
return "This revision does not match the revision you are voting on"
committee = release.committee
@@ -397,7 +404,7 @@ def task_recipient_get(latest_vote_task: sql.Task) -> str |
None:
async def tasks_ongoing(
- project_name: safe.ProjectName, version_name: safe.VersionName,
revision_number: str | None = None
+ project_name: safe.ProjectName, version_name: safe.VersionName,
revision_number: safe.RevisionNumber | None = None
) -> int:
tasks = sqlmodel.select(sqlalchemy.func.count()).select_from(sql.Task)
async with db.session() as data:
@@ -405,7 +412,7 @@ async def tasks_ongoing(
sql.Task.project_name == str(project_name),
sql.Task.version_name == str(version_name),
sql.Task.revision_number
- == (sql.RELEASE_LATEST_REVISION_NUMBER if (revision_number is
None) else revision_number),
+ == (sql.RELEASE_LATEST_REVISION_NUMBER if (revision_number is
None) else str(revision_number)),
sql.validate_instrumented_attribute(sql.Task.status).in_([sql.TaskStatus.QUEUED,
sql.TaskStatus.ACTIVE]),
)
result = await data.execute(query)
@@ -415,7 +422,7 @@ async def tasks_ongoing(
async def tasks_ongoing_revision(
project_name: safe.ProjectName,
version_name: safe.VersionName,
- revision_number: str | None = None,
+ revision_number: safe.RevisionNumber | None = None,
) -> tuple[int, str | None]:
via = sql.validate_instrumented_attribute
subquery = (
@@ -438,7 +445,7 @@ async def tasks_ongoing_revision(
.where(
sql.Task.project_name == str(project_name),
sql.Task.version_name == str(version_name),
- sql.Task.revision_number == (subquery if (revision_number is None)
else revision_number),
+ sql.Task.revision_number == (subquery if (revision_number is None)
else str(revision_number)),
sql.validate_instrumented_attribute(sql.Task.status).in_(
[sql.TaskStatus.QUEUED, sql.TaskStatus.ACTIVE],
),
diff --git a/atr/get/announce.py b/atr/get/announce.py
index b33d4069..5054f68d 100644
--- a/atr/get/announce.py
+++ b/atr/get/announce.py
@@ -69,7 +69,7 @@ async def selected(
fullname=session.fullname,
project_name=project_name,
version_name=version_name,
- revision_number=latest_revision_number,
+ revision_number=release.safe_latest_revision_number,
)
default_subject, default_body = await
construct.announce_release_subject_and_body(
default_subject_template, default_body_template, options
diff --git a/atr/get/checks.py b/atr/get/checks.py
index 03e26b91..52ef3a6d 100644
--- a/atr/get/checks.py
+++ b/atr/get/checks.py
@@ -146,7 +146,7 @@ async def selected_revision(
_checks: Literal["checks"],
project_name: safe.ProjectName,
version_name: safe.VersionName,
- revision_number: str,
+ revision_number: safe.RevisionNumber,
) -> web.QuartResponse:
"""
URL: /checks/<project_name>/<version_name>/<revision_number>
diff --git a/atr/get/committees.py b/atr/get/committees.py
index f565d8b9..1ab70af9 100644
--- a/atr/get/committees.py
+++ b/atr/get/committees.py
@@ -24,6 +24,7 @@ import atr.blueprints.get as get
import atr.db as db
import atr.form as form
import atr.models.sql as sql
+import atr.models.unsafe as unsafe
import atr.post as post
import atr.shared as shared
import atr.template as template
@@ -48,17 +49,17 @@ async def directory(_session: web.Public, _committees:
Literal["committees"]) ->
@get.typed
-async def view(_session: web.Public, _committees: Literal["committees"], name:
str) -> str:
+async def view(_session: web.Public, _committees: Literal["committees"], name:
unsafe.UnsafeStr) -> str:
"""
URL: /committees/<name>
"""
# TODO: Could also import this from keys.py
async with db.session() as data:
committee = await data.committee(
- name=name,
+ name=str(name),
_projects=True,
_public_signing_keys=True,
- ).demand(base.ASFQuartException(f"Committee {name} not found",
errorcode=404))
+ ).demand(base.ASFQuartException(f"Committee {name!s} not found",
errorcode=404))
project_list = list(committee.projects)
for project in project_list:
# Workaround for the usual loading problem
@@ -74,8 +75,8 @@ async def view(_session: web.Public, _committees:
Literal["committees"], name: s
model_cls=shared.keys.UpdateCommitteeKeysForm,
action=util.as_url(post.keys.keys),
submit_label="Regenerate KEYS file",
- defaults={"committee_name": name},
+ defaults={"committee_name": str(name)},
empty=True,
),
- is_standing=util.committee_is_standing(name),
+ is_standing=util.committee_is_standing(str(name)),
)
diff --git a/atr/get/keys.py b/atr/get/keys.py
index 18223ff8..d0fba17f 100644
--- a/atr/get/keys.py
+++ b/atr/get/keys.py
@@ -27,6 +27,7 @@ import atr.db as db
import atr.form as form
import atr.htm as htm
import atr.models.sql as sql
+import atr.models.unsafe as unsafe
import atr.post as post
import atr.shared as shared
import atr.storage as storage
@@ -72,14 +73,14 @@ async def add(_session: web.Committer, _keys_add:
Literal["keys/add"]) -> str:
@get.typed
-async def details(session: web.Committer, _keys_details:
Literal["keys/details"], fingerprint: str) -> str:
+async def details(session: web.Committer, _keys_details:
Literal["keys/details"], fingerprint: unsafe.UnsafeStr) -> str:
"""
URL: /keys/details/<fingerprint>
Display details for a specific OpenPGP key.
"""
- fingerprint = fingerprint.lower()
+ key_fingerprint = str(fingerprint).lower()
async with db.session() as data:
- key, is_owner = await _key_and_is_owner(data, session, fingerprint)
+ key, is_owner = await _key_and_is_owner(data, session, key_fingerprint)
user_committees = []
if is_owner:
project_list = session.committees + session.projects
@@ -154,7 +155,7 @@ async def details(session: web.Committer, _keys_details:
Literal["keys/details"]
checkboxes = _render_committee_checkboxes(committee_choices,
current_committee_names)
pmc_div.form(
method="post",
- action=util.as_url(post.keys.details, fingerprint=fingerprint),
+ action=util.as_url(post.keys.details, fingerprint=key_fingerprint),
)[
form.csrf_input(),
checkboxes,
@@ -182,7 +183,7 @@ async def details(session: web.Committer, _keys_details:
Literal["keys/details"]
@get.typed
async def export(
- _session: web.Committer, _keys_export: Literal["keys/export"],
committee_name: str
+ _session: web.Committer, _keys_export: Literal["keys/export"],
committee_name: unsafe.UnsafeStr
) -> web.TextResponse:
"""
URL: /keys/export/<committee_name>
@@ -190,7 +191,7 @@ async def export(
"""
async with storage.write() as write:
wafc = write.as_foundation_committer()
- keys_file_text = await wafc.keys.keys_file_text(committee_name)
+ keys_file_text = await wafc.keys.keys_file_text(str(committee_name))
return web.TextResponse(keys_file_text)
diff --git a/atr/get/manual.py b/atr/get/manual.py
index 7d90f0e9..e9ba1162 100644
--- a/atr/get/manual.py
+++ b/atr/get/manual.py
@@ -71,7 +71,7 @@ async def start_selected_revision(
_manual_start: Literal["manual/start"],
project_name: safe.ProjectName,
version_name: safe.VersionName,
- revision: str,
+ revision: safe.RevisionNumber,
) -> web.WerkzeugResponse | str:
"""
URL: /manual/start/<project_name>/<version_name>/<revision>
@@ -87,12 +87,12 @@ async def start_selected_revision(
error=error,
project_name=str(project_name),
version_name=str(version_name),
- revision=revision,
+ revision=str(revision),
)
case (release, _committee):
pass
- content = await _render_page(release=release, revision=revision)
+ content = await _render_page(release=release, revision=str(revision))
return await template.blank(
title=f"Start manual vote on {release.project.short_display_name}
{release.version}", content=content
diff --git a/atr/get/projects.py b/atr/get/projects.py
index ec992b13..d57622cf 100644
--- a/atr/get/projects.py
+++ b/atr/get/projects.py
@@ -35,6 +35,7 @@ import atr.get.start as start
import atr.htm as htm
import atr.models.safe as safe
import atr.models.sql as sql
+import atr.models.unsafe as unsafe
import atr.post as post
import atr.registry as registry
import atr.shared as shared
@@ -46,33 +47,33 @@ import atr.web as web
@get.typed
async def add_project(
- session: web.Committer, _project_add: Literal["project/add"],
committee_name: str
+ session: web.Committer, _project_add: Literal["project/add"],
committee_name: unsafe.UnsafeStr
) -> web.WerkzeugResponse | str:
"""
URL: /project/add/<committee_name>
"""
- await session.check_access_committee(committee_name)
+ await session.check_access_committee(str(committee_name))
async with db.session() as data:
- committee = await data.committee(name=committee_name).demand(
- base.ASFQuartException(f"Committee {committee_name} not found",
errorcode=404)
+ committee = await data.committee(name=str(committee_name)).demand(
+ base.ASFQuartException(f"Committee {committee_name!s} not found",
errorcode=404)
)
page = htm.Block()
- page.p[htm.a(".atr-back-link", href=util.as_url(committees.view,
name=committee_name))["← Back to committee"]]
+ page.p[htm.a(".atr-back-link", href=util.as_url(committees.view,
name=str(committee_name)))["← Back to committee"]]
page.h1["Add project"]
page.p[f"Add a new project to the {committee.display_name} committee."]
- committee_display_name = committee.full_name or committee_name.title()
+ committee_display_name = committee.full_name or str(committee_name).title()
form.render_block(
page,
model_cls=shared.projects.AddProjectForm,
- action=util.as_url(post.projects.add_project,
committee_name=committee_name),
+ action=util.as_url(post.projects.add_project,
committee_name=str(committee_name)),
submit_label="Add project",
- cancel_url=util.as_url(committees.view, name=committee_name),
+ cancel_url=util.as_url(committees.view, name=str(committee_name)),
defaults={
- "committee_name": committee_name,
+ "committee_name": str(committee_name),
},
)
@@ -80,7 +81,7 @@ async def add_project(
page.append(
htpy.div(
"#projects-add-config.d-none",
- data_committee_name=committee_name,
+ data_committee_name=str(committee_name),
data_committee_display_name=committee_display_name,
)
)
diff --git a/atr/get/ref.py b/atr/get/ref.py
index 9a665ac5..84b18a76 100644
--- a/atr/get/ref.py
+++ b/atr/get/ref.py
@@ -41,9 +41,10 @@ async def resolve(_session: web.Public, _ref:
Literal["ref"], ref_path: unsafe.P
Resolve a code reference to a GitHub permalink.
"""
project_root = pathlib.Path(config.get().PROJECT_ROOT)
+ path = str(ref_path)
- if ":" in ref_path:
- file_path_str, symbol = ref_path.rsplit(":", 1)
+ if ":" in path:
+ file_path_str, symbol = path.rsplit(":", 1)
resolved_file, validated_path_str =
_validate_and_resolve_path(file_path_str, project_root)
if (not await aiofiles.os.path.exists(resolved_file)) or (not await
aiofiles.os.path.isfile(resolved_file)):
@@ -57,8 +58,8 @@ async def resolve(_session: web.Public, _ref: Literal["ref"],
ref_path: unsafe.P
github_url =
f"https://github.com/apache/tooling-trusted-releases/blob/main/{validated_path_str}#L{line_number}"
return quart.redirect(github_url, code=303)
- is_directory = ref_path.endswith("/")
- path_str = ref_path.rstrip("/")
+ is_directory = path.endswith("/")
+ path_str = path.rstrip("/")
resolved_path, validated_path_str = _validate_and_resolve_path(path_str,
project_root)
if not await aiofiles.os.path.exists(resolved_path):
diff --git a/atr/get/test.py b/atr/get/test.py
index 7d36472f..a8075c23 100644
--- a/atr/get/test.py
+++ b/atr/get/test.py
@@ -33,6 +33,7 @@ import atr.htm as htm
import atr.models.safe as safe
import atr.models.session
import atr.models.sql as sql
+import atr.models.unsafe as unsafe
import atr.paths as paths
import atr.shared as shared
import atr.storage as storage
@@ -204,7 +205,7 @@ async def test_single(session: web.Public, _test_single:
Literal["test/single"])
async def test_vote(
session: web.Public,
_test_vote: Literal["test/vote"],
- category: str,
+ category: unsafe.UnsafeStr,
project_name: safe.ProjectName,
version_name: safe.VersionName,
) -> str:
@@ -222,10 +223,10 @@ async def test_vote(
"pmc_member_rm": vote.UserCategory.PMC_MEMBER_RM,
}
- user_category = category_map.get(category.lower())
+ user_category = category_map.get(str(category).lower())
if user_category is None:
raise base.ASFQuartException(
- f"Invalid category: {category}. Valid options: {',
'.join(category_map.keys())}",
+ f"Invalid category: {category!s}. Valid options: {',
'.join(category_map.keys())}",
errorcode=400,
)
diff --git a/atr/get/voting.py b/atr/get/voting.py
index a92c496d..90db36ff 100644
--- a/atr/get/voting.py
+++ b/atr/get/voting.py
@@ -47,7 +47,7 @@ async def selected_revision(
_voting: Literal["voting"],
project_name: safe.ProjectName,
version_name: safe.VersionName,
- revision: str,
+ revision: safe.RevisionNumber,
) -> web.WerkzeugResponse | str:
"""
URL: /voting/<project_name>/<version_name>/<revision>
@@ -63,7 +63,7 @@ async def selected_revision(
error=error,
project_name=str(project_name),
version_name=str(version_name),
- revision=revision,
+ revision=str(revision),
)
case (release, committee):
pass
@@ -94,7 +94,7 @@ async def selected_revision(
content = await _render_page(
release=release,
- revision_number=revision,
+ revision_number=str(revision),
permitted_recipients=permitted_recipients,
default_subject=default_subject,
subject_template_hash=subject_template_hash,
diff --git a/atr/merge.py b/atr/merge.py
index 81ddc81a..84e7a050 100644
--- a/atr/merge.py
+++ b/atr/merge.py
@@ -38,7 +38,7 @@ async def merge(
prior_dir: pathlib.Path,
project_name: safe.ProjectName,
version_name: safe.VersionName,
- prior_revision_number: str,
+ prior_revision_number: safe.RevisionNumber,
temp_dir: pathlib.Path,
n_inodes: dict[str, int],
n_hashes: dict[str, str],
@@ -122,7 +122,7 @@ async def _add_from_prior(
prior_hashes: dict[str, str] | None,
project_name: safe.ProjectName,
version_name: safe.VersionName,
- prior_revision_number: str,
+ prior_revision_number: safe.RevisionNumber,
) -> dict[str, str] | None:
target = temp_dir / path
await asyncio.to_thread(_makedirs_with_permissions, target.parent,
temp_dir)
@@ -179,7 +179,7 @@ async def _merge_all_present(
prior_hashes: dict[str, str] | None,
project_name: safe.ProjectName,
version_name: safe.VersionName,
- prior_revision_number: str,
+ prior_revision_number: safe.RevisionNumber,
) -> dict[str, str] | None:
# Cases 6, 8: prior and new share an inode so they already agree
if p_ino == n_ino:
@@ -240,7 +240,7 @@ async def _replace_with_prior(
prior_hashes: dict[str, str] | None,
project_name: safe.ProjectName,
version_name: safe.VersionName,
- prior_revision_number: str,
+ prior_revision_number: safe.RevisionNumber,
) -> dict[str, str] | None:
await aiofiles.os.remove(temp_dir / path)
await aiofiles.os.link(prior_dir / path, temp_dir / path)
diff --git a/atr/models/api.py b/atr/models/api.py
index 56f2e737..342773b2 100644
--- a/atr/models/api.py
+++ b/atr/models/api.py
@@ -166,7 +166,7 @@ class DistributionRecordResults(schema.Strict):
class IgnoreAddArgs(schema.Strict):
project_name: safe.ProjectName = schema.example("example")
release_glob: str | None = schema.default_example(None, "example-0.0.*")
- revision_number: str | None = schema.default_example(None, "00001")
+ revision_number: safe.RevisionNumber | None = schema.default_example(None,
"00001")
checker_glob: str | None = schema.default_example(None,
"atr.tasks.checks.license.files")
primary_rel_path_glob: str | None = schema.default_example(None,
"apache-example-0.0.1-*.tar.gz")
member_rel_path_glob: str | None = schema.default_example(None,
"apache-example-0.0.1/*.xml")
@@ -357,7 +357,7 @@ class PublisherReleaseAnnounceArgs(schema.Strict):
publisher: str = schema.example("user")
jwt: str = schema.example("eyJhbGciOiJIUzI1[...]mMjLiuyu5CSpyHI=")
version: safe.VersionName = schema.example("0.0.1")
- revision: str = schema.example("00005")
+ revision: safe.RevisionNumber = schema.example("00005")
email_to: str = schema.example("[email protected]")
body: str = schema.example("The Apache Example team is pleased to announce
the release of Example 1.0.0...")
path_suffix: str = schema.example("example/1.0.0")
@@ -396,7 +396,7 @@ class PublisherVoteResolveResults(schema.Strict):
class ReleaseAnnounceArgs(schema.Strict):
project: safe.ProjectName = schema.example("example")
version: safe.VersionName = schema.example("1.0.0")
- revision: str = schema.example("00005")
+ revision: safe.RevisionNumber = schema.example("00005")
email_to: str = schema.example("[email protected]")
body: str = schema.example("The Apache Example team is pleased to announce
the release of Example 1.0.0...")
path_suffix: str = schema.example("example/1.0.0")
@@ -581,7 +581,7 @@ class VoteResolveResults(schema.Strict):
class VoteStartArgs(schema.Strict):
project: safe.ProjectName = schema.example("example")
version: safe.VersionName = schema.example("0.0.1")
- revision: str = schema.example("00005")
+ revision: safe.RevisionNumber = schema.example("00005")
email_to: str = schema.example("[email protected]")
vote_duration: int = schema.example(10)
subject: str = schema.example("[VOTE] Apache Example 0.0.1 release")
diff --git a/atr/models/safe.py b/atr/models/safe.py
index 083a8b2b..9ec73067 100644
--- a/atr/models/safe.py
+++ b/atr/models/safe.py
@@ -22,16 +22,17 @@ import unicodedata
from typing import Any, Final
_ALPHANUM: Final = frozenset(string.ascii_letters + string.digits + "-")
+_NUMERIC: Final = frozenset(string.digits)
_VERSION_CHARS: Final = _ALPHANUM | frozenset(".+")
-class Alphanumeric:
+class SafeType:
__slots__ = ("_value",)
@classmethod
def _valid_chars(cls) -> frozenset[str]:
# default is the base set; subclasses can override this method
- return _ALPHANUM
+ return frozenset()
def _additional_validations(self, value: str):
pass
@@ -53,7 +54,7 @@ class Alphanumeric:
return True
def __eq__(self, other: object) -> bool:
- if isinstance(other, Alphanumeric):
+ if isinstance(other, self.__class__):
return self._value == other._value
return NotImplemented
@@ -76,6 +77,20 @@ class Alphanumeric:
)
+class Alphanumeric(SafeType):
+ @classmethod
+ def _valid_chars(cls) -> frozenset[str]:
+ # default is the base set; subclasses can override this method
+ return _ALPHANUM
+
+
+class Numeric(SafeType):
+ @classmethod
+ def _valid_chars(cls) -> frozenset[str]:
+ # default is the base set; subclasses can override this method
+ return _NUMERIC
+
+
class ProjectName(Alphanumeric):
"""A project name that has been validated for safety."""
@@ -88,6 +103,10 @@ class ReleaseName(Alphanumeric):
return _VERSION_CHARS
+class RevisionNumber(Numeric):
+ """A revision number that has been validated for safety."""
+
+
class VersionName(Alphanumeric):
"""A version name that has been validated for safety"""
diff --git a/atr/models/sql.py b/atr/models/sql.py
index 3b288602..2ba21727 100644
--- a/atr/models/sql.py
+++ b/atr/models/sql.py
@@ -923,6 +923,11 @@ class Release(sqlmodel.SQLModel, table=True):
# return None
return project.committee
+ @property
+ def safe_latest_revision_number(self) -> safe.RevisionNumber:
+ """Get the typesafe validated name for the Release"""
+ return safe.RevisionNumber(self.unwrap_revision_number)
+
@property
def safe_name(self) -> safe.ReleaseName:
"""Get the typesafe validated name for the Release"""
@@ -1317,6 +1322,11 @@ class Revision(sqlmodel.SQLModel, table=True):
tag: str | None = sqlmodel.Field(default=None, **example("rc1"))
was_quarantined: bool = sqlmodel.Field(default=False, **example(False))
+ @property
+ def safe_number(self) -> safe.RevisionNumber:
+ """Get the typesafe validated number for the revision"""
+ return safe.RevisionNumber(self.number)
+
def model_post_init(self, _context):
if isinstance(self.created, str):
self.created =
datetime.datetime.fromisoformat(self.created.rstrip("Z"))
diff --git a/atr/models/unsafe.py b/atr/models/unsafe.py
index 6e38d4e4..d76127e7 100644
--- a/atr/models/unsafe.py
+++ b/atr/models/unsafe.py
@@ -28,5 +28,8 @@ class UnsafeStr:
def __repr__(self) -> str:
return f"UnsafeStr({self._value!r})"
+ def __str__(self) -> str:
+ return self._value
-Path = NewType("Path", str)
+
+Path = NewType("Path", UnsafeStr)
diff --git a/atr/paths.py b/atr/paths.py
index 2aae97ae..7931db1d 100644
--- a/atr/paths.py
+++ b/atr/paths.py
@@ -23,9 +23,9 @@ import atr.models.sql as sql
def base_path_for_revision(
- project_name: safe.ProjectName, version_name: safe.VersionName, revision:
str
+ project_name: safe.ProjectName, version_name: safe.VersionName, revision:
safe.RevisionNumber
) -> pathlib.Path:
- return pathlib.Path(get_unfinished_dir(), str(project_name),
str(version_name), revision)
+ return pathlib.Path(get_unfinished_dir(), str(project_name),
str(version_name), str(revision))
def get_attestable_dir() -> pathlib.Path:
@@ -135,6 +135,6 @@ def release_directory_version(release: sql.Release) ->
pathlib.Path:
def revision_path_for_file(
- project_name: safe.ProjectName, version_name: safe.VersionName, revision:
str, file_name: str
+ project_name: safe.ProjectName, version_name: safe.VersionName, revision:
safe.RevisionNumber, file_name: str
) -> pathlib.Path:
return base_path_for_revision(project_name, version_name, revision) /
file_name
diff --git a/atr/post/announce.py b/atr/post/announce.py
index f988b40a..23ea535f 100644
--- a/atr/post/announce.py
+++ b/atr/post/announce.py
@@ -61,14 +61,14 @@ async def selected(
with_release_policy=True,
with_project_release_policy=True,
)
- preview_revision_number = release.unwrap_revision_number
+ preview_revision_number = release.safe_latest_revision_number
# Validate that the revision number matches
- if announce_form.revision_number != preview_revision_number:
+ if announce_form.revision_number != str(preview_revision_number):
return await session.redirect(
get.announce.selected,
error=f"The release has been updated since you loaded the form. "
- f"Please review the current revision ({preview_revision_number})
and submit the form again.",
+ f"Please review the current revision ({preview_revision_number!s})
and submit the form again.",
project_name=str(project_name),
version_name=str(version_name),
)
diff --git a/atr/post/draft.py b/atr/post/draft.py
index 675d4ef8..dcfa6add 100644
--- a/atr/post/draft.py
+++ b/atr/post/draft.py
@@ -257,7 +257,8 @@ async def sbomgen(
URL: /draft/sbomgen/<project_name>/<version_name>/<file_path>
Generate a CycloneDX SBOM file for a candidate draft file, creating a new
revision.
"""
- rel_path = form.to_relpath(file_path)
+ path = str(file_path)
+ rel_path = form.to_relpath(path)
if rel_path is None:
await quart.flash("Invalid file path", "error")
return await session.redirect(
@@ -265,14 +266,9 @@ async def sbomgen(
)
# Check that the file is a .tar.gz archive before creating a revision
- if not (
- file_path.endswith(".tar.gz")
- or file_path.endswith(".tgz")
- or file_path.endswith(".zip")
- or file_path.endswith(".jar")
- ):
+ if not (path.endswith(".tar.gz") or path.endswith(".tgz") or
path.endswith(".zip") or path.endswith(".jar")):
raise base.ASFQuartException(
- f"SBOM generation requires .tar.gz, .tgz, .zip or .jar files.
Received: {file_path}", errorcode=400
+ f"SBOM generation requires .tar.gz, .tgz, .zip or .jar files.
Received: {path}", errorcode=400
)
try:
diff --git a/atr/post/keys.py b/atr/post/keys.py
index 9042211a..df295ea1 100644
--- a/atr/post/keys.py
+++ b/atr/post/keys.py
@@ -31,6 +31,7 @@ import atr.htm as htm
import atr.log as log
import atr.models.safe as safe
import atr.models.sql as sql
+import atr.models.unsafe as unsafe
import atr.shared as shared
import atr.storage as storage
import atr.storage.outcome as outcome
@@ -93,18 +94,18 @@ async def add(
async def details(
session: web.Committer,
_keys_details: Literal["keys/details"],
- fingerprint: str,
+ fingerprint: unsafe.UnsafeStr,
update_form: shared.keys.UpdateKeyCommitteesForm,
) -> web.WerkzeugResponse:
"""
URL: /keys/details/<fingerprint>
Update committee associations for an OpenPGP key.
"""
- fingerprint = fingerprint.lower()
+ key_fingerprint = str(fingerprint).lower()
try:
async with db.session() as data:
- key = await data.public_signing_key(fingerprint=fingerprint,
_committees=True).get()
+ key = await data.public_signing_key(fingerprint=key_fingerprint,
_committees=True).get()
if not key:
await quart.flash("OpenPGP key not found", "error")
return await session.redirect(get.keys.keys)
@@ -135,7 +136,7 @@ async def details(
log.exception("Error updating key committee associations:")
await quart.flash(f"An unexpected error occurred: {e!s}", "error")
- return await session.redirect(get.keys.details, fingerprint=fingerprint)
+ return await session.redirect(get.keys.details,
fingerprint=key_fingerprint)
@post.typed
diff --git a/atr/post/manual.py b/atr/post/manual.py
index 4fd8a06e..591b323d 100644
--- a/atr/post/manual.py
+++ b/atr/post/manual.py
@@ -87,7 +87,7 @@ async def start_selected_revision(
_manual_start: Literal["manual/start"],
project_name: safe.ProjectName,
version_name: safe.VersionName,
- revision: str,
+ revision: safe.RevisionNumber,
_form: form.Empty,
) -> web.WerkzeugResponse | str:
"""
diff --git a/atr/post/projects.py b/atr/post/projects.py
index 68f9d878..6763bf48 100644
--- a/atr/post/projects.py
+++ b/atr/post/projects.py
@@ -27,6 +27,7 @@ import atr.db as db
import atr.get as get
import atr.models.safe as safe
import atr.models.sql as sql
+import atr.models.unsafe as unsafe
import atr.shared as shared
import atr.storage as storage
import atr.web as web
@@ -36,7 +37,7 @@ import atr.web as web
async def add_project(
session: web.Committer,
_project_add: Literal["project/add"],
- committee_name: str,
+ committee_name: unsafe.UnsafeStr,
project_form: shared.projects.AddProjectForm,
) -> web.WerkzeugResponse:
"""
@@ -47,12 +48,12 @@ async def add_project(
# TODO: Is this right? Name is unvalidated
async with storage.write(session) as write:
- wacm = await
write.as_project_committee_member(safe.ProjectName(committee_name))
+ wacm = await
write.as_project_committee_member(safe.ProjectName(str(committee_name)))
try:
- await wacm.project.create(committee_name, display_name, label)
+ await wacm.project.create(str(committee_name), display_name, label)
except storage.AccessError as e:
return await session.redirect(
- get.projects.add_project, committee_name=committee_name,
error=f"Error adding project: {e}"
+ get.projects.add_project, committee_name=str(committee_name),
error=f"Error adding project: {e}"
)
return await session.redirect(
@@ -88,7 +89,7 @@ async def delete(
async def view(
session: web.Committer,
_projects: Literal["projects"],
- name: str,
+ name: unsafe.UnsafeStr,
project_form: shared.projects.ProjectViewForm,
) -> web.WerkzeugResponse:
"""
diff --git a/atr/post/upload.py b/atr/post/upload.py
index 54b1a056..930eda14 100644
--- a/atr/post/upload.py
+++ b/atr/post/upload.py
@@ -33,6 +33,7 @@ import atr.get as get
import atr.log as log
import atr.models.safe as safe
import atr.models.sql as sql
+import atr.models.unsafe as unsafe
import atr.paths as paths
import atr.shared as shared
import atr.storage as storage
@@ -49,14 +50,14 @@ async def finalise(
_upload_finalise: Literal["upload/finalise"],
project_name: safe.ProjectName,
version_name: safe.VersionName,
- upload_session: str,
+ upload_session: unsafe.UnsafeStr,
) -> web.WerkzeugResponse:
"""
URL: /upload/finalise/<project_name>/<version_name>/<upload_session>
"""
try:
- staging_dir = paths.get_upload_staging_dir(upload_session)
+ staging_dir = paths.get_upload_staging_dir(str(upload_session))
except ValueError:
return _json_error("Invalid session token", 400)
@@ -143,14 +144,14 @@ async def stage(
_upload_stage: Literal["upload/stage"],
_project_name: safe.ProjectName,
_version_name: safe.VersionName,
- upload_session: str,
+ upload_session: unsafe.UnsafeStr,
) -> web.WerkzeugResponse:
"""
URL: /upload/stage/<project_name>/<version_name>/<upload_session>
"""
try:
- staging_dir = paths.get_upload_staging_dir(upload_session)
+ staging_dir = paths.get_upload_staging_dir(str(upload_session))
except ValueError:
return _json_error("Invalid session token", 400)
diff --git a/atr/post/voting.py b/atr/post/voting.py
index bed6534e..cac23325 100644
--- a/atr/post/voting.py
+++ b/atr/post/voting.py
@@ -43,7 +43,7 @@ async def body_preview(
_voting_body_preview: Literal["voting/body/preview"],
project_name: safe.ProjectName,
version_name: safe.VersionName,
- revision_number: str,
+ revision_number: safe.RevisionNumber,
preview_form: BodyPreviewForm,
) -> web.QuartResponse:
"""
@@ -72,7 +72,7 @@ async def selected_revision(
_voting: Literal["voting"],
project_name: safe.ProjectName,
version_name: safe.VersionName,
- revision: str,
+ revision: safe.RevisionNumber,
start_voting_form: shared.voting.StartVotingForm,
) -> web.WerkzeugResponse | str:
"""
@@ -89,7 +89,7 @@ async def selected_revision(
error=error,
project_name=str(project_name),
version_name=str(version_name),
- revision=revision,
+ revision=str(revision),
)
case (release, committee):
pass
diff --git a/atr/shared/ignores.py b/atr/shared/ignores.py
index 2b41f66f..56cb0b84 100644
--- a/atr/shared/ignores.py
+++ b/atr/shared/ignores.py
@@ -23,6 +23,7 @@ from typing import Annotated, Literal
import pydantic
import atr.form as form
+import atr.models.safe as safe
import atr.models.sql as sql
import atr.models.validation as validation
@@ -43,7 +44,7 @@ class IgnoreStatus(enum.Enum):
class AddIgnoreForm(form.Form):
variant: ADD = form.value(ADD)
release_glob: str = form.label("Release pattern", default="")
- revision_number: str = form.label("Revision number (literal)", default="")
+ revision_number: safe.RevisionNumber = form.label("Revision number
(literal)", default="")
checker_glob: str = form.label("Checker pattern", default="")
primary_rel_path_glob: str = form.label("Primary rel path pattern",
default="")
member_rel_path_glob: str = form.label("Member rel path pattern",
default="")
diff --git a/atr/ssh.py b/atr/ssh.py
index be81b1c9..7e2f53a5 100644
--- a/atr/ssh.py
+++ b/atr/ssh.py
@@ -616,7 +616,9 @@ async def _step_07b_process_validated_rsync_write(
else:
github_payload = server._get_github_payload(process)
if github_payload is not None:
- await attestable.github_tp_payload_write(project_name,
version_name, result.number, github_payload)
+ await attestable.github_tp_payload_write(
+ project_name, version_name, result.safe_number,
github_payload
+ )
log.info(f"rsync upload successful for revision
{result.number}")
host = config.get().APP_HOST
message = f"\nATR: Created revision {result.number} of
{project_name} {version_name}\n"
diff --git a/atr/storage/readers/releases.py b/atr/storage/readers/releases.py
index cfbfca0e..502f1109 100644
--- a/atr/storage/readers/releases.py
+++ b/atr/storage/readers/releases.py
@@ -24,6 +24,7 @@ import pathlib
import atr.classify as classify
import atr.db as db
import atr.db.interaction as interaction
+import atr.models.safe as safe
import atr.models.sql as sql
import atr.paths as paths
import atr.storage as storage
@@ -60,7 +61,7 @@ class GeneralPublic:
latest_revision_number = release.latest_revision_number
if latest_revision_number is None:
return None
- await self.__successes_errors_warnings(release,
latest_revision_number, info)
+ await self.__successes_errors_warnings(release,
release.safe_latest_revision_number, info)
base_path = paths.release_directory(release)
source_matcher = None
source_artifact_paths = release.project.policy_source_artifact_paths
@@ -121,7 +122,7 @@ class GeneralPublic:
)
async def __successes_errors_warnings(
- self, release: sql.Release, latest_revision_number: str, info:
types.PathInfo
+ self, release: sql.Release, latest_revision_number:
safe.RevisionNumber, info: types.PathInfo
) -> None:
match_ignore = await
self.__read_as.checks.ignores_matcher(release.safe_project_name)
attestable_checks = await interaction.checks_for(
diff --git a/atr/storage/writers/announce.py b/atr/storage/writers/announce.py
index a56a2997..69f10aa5 100644
--- a/atr/storage/writers/announce.py
+++ b/atr/storage/writers/announce.py
@@ -106,7 +106,7 @@ class CommitteeMember(CommitteeParticipant):
self,
project_name: safe.ProjectName,
version_name: safe.VersionName,
- preview_revision_number: str,
+ preview_revision_number: safe.RevisionNumber,
recipient: str,
body: str,
download_path_suffix: str,
@@ -124,14 +124,14 @@ class CommitteeMember(CommitteeParticipant):
project_name=str(project_name),
version=str(version_name),
phase=sql.ReleasePhase.RELEASE_PREVIEW,
- latest_revision_number=preview_revision_number,
+ latest_revision_number=str(preview_revision_number),
_project_release_policy=True,
_revisions=True,
_distributions=True,
_release_policy=True,
).demand(
storage.AccessError(
- f"Release {project_name} {version_name}
{preview_revision_number} does not exist",
+ f"Release {project_name!s} {version_name!s}
{preview_revision_number!s} does not exist",
)
)
if (committee := release.project.committee) is None:
@@ -199,7 +199,7 @@ class CommitteeMember(CommitteeParticipant):
asf_uid=self.__asf_uid,
project_name=str(project_name),
version_name=str(version_name),
- revision_number=preview_revision_number,
+ revision_number=str(preview_revision_number),
source_directory=unfinished_dir,
target_directory=finished_dir,
email_recipient=recipient,
@@ -282,7 +282,7 @@ class CommitteeMember(CommitteeParticipant):
return predicted_finished_release
async def __promote_in_database(
- self, release: sql.Release, preview_revision_number: str,
release_date: datetime.datetime
+ self, release: sql.Release, preview_revision_number:
safe.RevisionNumber, release_date: datetime.datetime
) -> None:
"""Promote a release preview to a release and delete its old
revisions."""
via = sql.validate_instrumented_attribute
@@ -292,7 +292,7 @@ class CommitteeMember(CommitteeParticipant):
.where(
via(sql.Release.name) == release.name,
via(sql.Release.phase) == sql.ReleasePhase.RELEASE_PREVIEW,
- sql.latest_revision_number_query() == preview_revision_number,
+ sql.latest_revision_number_query() ==
str(preview_revision_number),
)
.values(
phase=sql.ReleasePhase.RELEASE,
diff --git a/atr/storage/writers/checks.py b/atr/storage/writers/checks.py
index 8140dc28..36094fdd 100644
--- a/atr/storage/writers/checks.py
+++ b/atr/storage/writers/checks.py
@@ -95,7 +95,7 @@ class CommitteeMember(CommitteeParticipant):
self,
project_name: safe.ProjectName,
release_glob: str | None = None,
- revision_number: str | None = None,
+ revision_number: safe.RevisionNumber | None = None,
checker_glob: str | None = None,
primary_rel_path_glob: str | None = None,
member_rel_path_glob: str | None = None,
@@ -115,7 +115,7 @@ class CommitteeMember(CommitteeParticipant):
created=datetime.datetime.now(datetime.UTC),
project_name=str(project_name),
release_glob=release_glob,
- revision_number=revision_number,
+ revision_number=str(revision_number),
checker_glob=checker_glob,
primary_rel_path_glob=primary_rel_path_glob,
member_rel_path_glob=member_rel_path_glob,
diff --git a/atr/storage/writers/release.py b/atr/storage/writers/release.py
index e273def3..443dd8e0 100644
--- a/atr/storage/writers/release.py
+++ b/atr/storage/writers/release.py
@@ -321,7 +321,7 @@ class CommitteeParticipant(FoundationCommitter):
async def promote_to_candidate(
self,
release_name: safe.ReleaseName,
- selected_revision_number: str,
+ selected_revision_number: safe.RevisionNumber,
vote_manual: bool = False,
) -> str | None:
"""Promote a release candidate draft to a new phase."""
@@ -330,6 +330,7 @@ class CommitteeParticipant(FoundationCommitter):
)
project_name = release_for_pre_checks.safe_project_name
version_name = release_for_pre_checks.safe_version_name
+ revision_number = release_for_pre_checks.safe_latest_revision_number
# Check for ongoing tasks
ongoing_tasks = await self.__tasks_ongoing(project_name, version_name,
selected_revision_number)
@@ -341,7 +342,7 @@ class CommitteeParticipant(FoundationCommitter):
return "This release is not in the candidate draft phase"
# Check that the revision number is the latest
- if release_for_pre_checks.latest_revision_number !=
selected_revision_number:
+ if revision_number != selected_revision_number:
return "The selected revision number does not match the latest
revision number"
# Check that there is at least one file in the draft
@@ -356,7 +357,7 @@ class CommitteeParticipant(FoundationCommitter):
.where(
via(sql.Release.name) == release_for_pre_checks.name,
via(sql.Release.phase) ==
sql.ReleasePhase.RELEASE_CANDIDATE_DRAFT,
- sql.latest_revision_number_query() == selected_revision_number,
+ sql.latest_revision_number_query() ==
str(selected_revision_number),
)
.values(
phase=sql.ReleasePhase.RELEASE_CANDIDATE,
@@ -377,7 +378,7 @@ class CommitteeParticipant(FoundationCommitter):
self.__write_as.append_to_audit_log(
asf_uid=self.__asf_uid,
release_name=str(release_name),
- selected_revision_number=selected_revision_number,
+ selected_revision_number=str(selected_revision_number),
vote_manual=vote_manual,
)
return None
@@ -765,14 +766,17 @@ class CommitteeParticipant(FoundationCommitter):
moved_files_names.append(f.name)
async def __tasks_ongoing(
- self, project_name: safe.ProjectName, version_name: safe.VersionName,
revision_number: str | None = None
+ self,
+ project_name: safe.ProjectName,
+ version_name: safe.VersionName,
+ revision_number: safe.RevisionNumber | None = None,
) -> int:
tasks = sqlmodel.select(sqlalchemy.func.count()).select_from(sql.Task)
query = tasks.where(
sql.Task.project_name == str(project_name),
sql.Task.version_name == str(version_name),
sql.Task.revision_number
- == (sql.RELEASE_LATEST_REVISION_NUMBER if (revision_number is
None) else revision_number),
+ == (sql.RELEASE_LATEST_REVISION_NUMBER if (revision_number is
None) else str(revision_number)),
sql.validate_instrumented_attribute(sql.Task.status).in_([sql.TaskStatus.QUEUED,
sql.TaskStatus.ACTIVE]),
)
result = await self.__data.execute(query)
diff --git a/atr/storage/writers/revision.py b/atr/storage/writers/revision.py
index 0f9394d6..d0f37af6 100644
--- a/atr/storage/writers/revision.py
+++ b/atr/storage/writers/revision.py
@@ -196,7 +196,7 @@ async def _commit_new_revision(
await attestable.write_files_data(
project_name,
version_name,
- new_revision.number,
+ new_revision.safe_number,
policy.model_dump() if policy else None,
asf_uid,
previous_attestable,
@@ -216,7 +216,7 @@ async def _commit_new_revision(
# It does, however, need a transaction to be created using data.begin()
if release.phase == sql.ReleasePhase.RELEASE_CANDIDATE_DRAFT:
# Must use caller_data here because we acquired the write lock
- await tasks.draft_checks(asf_uid, project_name, version_name,
new_revision.number, caller_data=data)
+ await tasks.draft_checks(asf_uid, project_name, version_name,
new_revision.safe_number, caller_data=data)
return new_revision
@@ -263,12 +263,14 @@ async def _lock_and_merge(
if (
merge_enabled
and (old_revision is not None)
+ and (latest is not None)
and (prior_revision_name is not None)
and (prior_revision_name != old_revision.name)
):
merge_base_revision_name = prior_revision_name
- prior_number = prior_revision_name.split()[-1]
- prior_dir = paths.release_directory_base(merged_release) / prior_number
+ # This won't be None because prior_revision_name is not None here
+ prior_number = latest.safe_number
+ prior_dir = paths.release_directory_base(merged_release) /
str(prior_number)
await merge.merge(
base_inodes,
base_hashes,
@@ -359,7 +361,7 @@ class CommitteeParticipant(FoundationCommitter):
set_local_cache: bool = False,
reset_to_global_cache: bool = False,
modify: Callable[[pathlib.Path, sql.Revision | None], Awaitable[None]]
| None = None,
- clone_from: str | None = None,
+ clone_from: safe.RevisionNumber | None = None,
) -> sql.Revision | sql.Quarantined:
"""Create a new revision, quarantining archives that require
validation."""
release_name = sql.release_name(str(project_name), str(version_name))
@@ -368,7 +370,7 @@ class CommitteeParticipant(FoundationCommitter):
RuntimeError("Release does not exist for new revision
creation")
)
if clone_from is not None:
- old_revision = await data.revision(release_name=release_name,
number=clone_from).demand(
+ old_revision = await data.revision(release_name=release_name,
number=str(clone_from)).demand(
RuntimeError(f"Revision {clone_from} does not exist")
)
else:
@@ -379,7 +381,7 @@ class CommitteeParticipant(FoundationCommitter):
release.check_cache_key = None
if clone_from is not None:
- old_release_dir = paths.release_directory_base(release) /
clone_from
+ old_release_dir = paths.release_directory_base(release) /
str(clone_from)
else:
old_release_dir = paths.release_directory(release)
merge_enabled = clone_from is None
@@ -427,7 +429,7 @@ class CommitteeParticipant(FoundationCommitter):
try:
path_to_hash, path_to_size = await
attestable.paths_to_hashes_and_sizes(temp_dir_path)
- parent_revision_number = old_revision.number if old_revision else
None
+ parent_revision_number = old_revision.safe_number if old_revision
else None
previous_attestable = None
if parent_revision_number is not None:
previous_attestable = await attestable.load(project_name,
version_name, parent_revision_number)
diff --git a/atr/storage/writers/vote.py b/atr/storage/writers/vote.py
index 1afb6d56..fb8e7c59 100644
--- a/atr/storage/writers/vote.py
+++ b/atr/storage/writers/vote.py
@@ -139,7 +139,7 @@ class CommitteeParticipant(FoundationCommitter):
email_to: str,
project_name: safe.ProjectName,
version_name: safe.VersionName,
- selected_revision_number: str,
+ selected_revision_number: safe.RevisionNumber,
vote_duration_choice: int,
subject: str,
body_data: str,
@@ -358,7 +358,7 @@ class CommitteeMember(CommitteeParticipant):
fullname=asf_fullname,
project_name=release.safe_project_name,
version_name=release.safe_version_name,
- revision_number=revision_number,
+ revision_number=release.safe_latest_revision_number,
vote_duration=vote_duration,
)
subject_data, body_data = await
construct.start_vote_subject_and_body(
@@ -369,7 +369,7 @@ class CommitteeMember(CommitteeParticipant):
permitted_recipients=[incubator_vote_address],
project_name=release.safe_project_name,
version_name=release.safe_version_name,
- selected_revision_number=revision_number,
+ selected_revision_number=release.safe_latest_revision_number,
asf_uid=self.__asf_uid,
asf_fullname=asf_fullname,
vote_duration_choice=vote_duration,
diff --git a/atr/tasks/__init__.py b/atr/tasks/__init__.py
index 4bdd1ae4..15c4e1e2 100644
--- a/atr/tasks/__init__.py
+++ b/atr/tasks/__init__.py
@@ -54,7 +54,7 @@ import atr.util as util
async def asc_checks(
- asf_uid: str, release: sql.Release, revision: str, signature_path: str,
data: db.Session
+ asf_uid: str, release: sql.Release, revision: safe.RevisionNumber,
signature_path: str, data: db.Session
) -> list[sql.Task | None]:
"""Create signature check task for a .asc file."""
tasks = []
@@ -135,13 +135,13 @@ async def draft_checks(
asf_uid: str,
project_name: safe.ProjectName,
release_version: safe.VersionName,
- revision_number: str,
+ revision_number: safe.RevisionNumber,
caller_data: db.Session | None = None,
) -> int:
"""Core logic to analyse a draft revision and queue checks."""
# Construct path to the specific revision
# We don't have the release object here, so we can't use
util.release_directory
- revision_path = file_paths.get_unfinished_dir() / str(project_name) /
str(release_version) / revision_number
+ revision_path = file_paths.get_unfinished_dir() / str(project_name) /
str(release_version) / str(revision_number)
relative_paths = [path async for path in
util.paths_recursive(revision_path)]
async with db.ensure_session(caller_data) as data:
@@ -265,7 +265,7 @@ async def queued(
asf_uid: str,
task_type: sql.TaskType,
release: sql.Release,
- revision_number: str,
+ revision_number: safe.RevisionNumber,
primary_rel_path: str | None = None,
extra_args: dict[str, Any] | None = None,
check_cache_key: dict[str, Any] | None = None,
@@ -289,7 +289,7 @@ async def queued(
asf_uid=asf_uid,
project_name=release.project.name,
version_name=release.version,
- revision_number=revision_number,
+ revision_number=str(revision_number),
primary_rel_path=primary_rel_path,
inputs_hash=hash_val,
)
@@ -352,7 +352,7 @@ def resolve(task_type: sql.TaskType) -> Callable[...,
Awaitable[results.Results
async def sha_checks(
- asf_uid: str, release: sql.Release, revision: str, hash_file: str, data:
db.Session
+ asf_uid: str, release: sql.Release, revision: safe.RevisionNumber,
hash_file: str, data: db.Session
) -> list[sql.Task | None]:
"""Create hash check task for a .sha256 or .sha512 file."""
tasks = []
@@ -380,7 +380,7 @@ async def sha_checks(
async def tar_gz_checks(
- asf_uid: str, release: sql.Release, revision: str, path: str, data:
db.Session
+ asf_uid: str, release: sql.Release, revision: safe.RevisionNumber, path:
str, data: db.Session
) -> list[sql.Task | None]:
"""Create check tasks for a .tar.gz or .tgz file."""
# This release has committee, as guaranteed in draft_checks
@@ -489,7 +489,7 @@ async def workflow_update(
async def zip_checks(
- asf_uid: str, release: sql.Release, revision: str, path: str, data:
db.Session
+ asf_uid: str, release: sql.Release, revision: safe.RevisionNumber, path:
str, data: db.Session
) -> list[sql.Task | None]:
"""Create check tasks for a .zip file."""
# This release has committee, as guaranteed in draft_checks
@@ -592,7 +592,7 @@ async def _draft_file_checks(
project_name: safe.ProjectName,
release: sql.Release,
release_version: safe.VersionName,
- revision_number: str,
+ revision_number: safe.RevisionNumber,
):
path_str = str(path)
task_function: Callable[[str, sql.Release, str, str, db.Session],
Awaitable[list[sql.Task | None]]] | None = None
@@ -603,7 +603,7 @@ async def _draft_file_checks(
if task_function:
for task in await task_function(asf_uid, release, revision_number,
path_str, data):
if task:
- task.revision_number = revision_number
+ task.revision_number = str(revision_number)
await _add_task(data, task)
# TODO: Should we check .json files for their content?
# Ideally we would not have to do that
@@ -617,7 +617,7 @@ async def _draft_file_checks(
extra_args={
"project_name": str(project_name),
"version_name": str(release_version),
- "revision_number": revision_number,
+ "revision_number": str(revision_number),
"previous_release_version": previous_version.version if
previous_version else None,
"file_path": path_str,
"asf_uid": asf_uid,
diff --git a/atr/tasks/checks/__init__.py b/atr/tasks/checks/__init__.py
index 99681200..2e552c9b 100644
--- a/atr/tasks/checks/__init__.py
+++ b/atr/tasks/checks/__init__.py
@@ -53,7 +53,7 @@ class FunctionArguments:
asf_uid: str
project_name: safe.ProjectName
version_name: safe.VersionName
- revision_number: str
+ revision_number: safe.RevisionNumber
primary_rel_path: str | None
extra_args: dict[str, Any]
@@ -65,7 +65,7 @@ class Recorder:
version_name: safe.VersionName
primary_rel_path: str | None
member_rel_path: str | None
- revision_number: str
+ revision_number: safe.RevisionNumber
afresh: bool
__cached: bool
__input_hash: str | None
@@ -76,7 +76,7 @@ class Recorder:
inputs_hash: str | None,
project_name: safe.ProjectName,
version_name: safe.VersionName,
- revision_number: str,
+ revision_number: safe.RevisionNumber,
primary_rel_path: str | None = None,
member_rel_path: str | None = None,
afresh: bool = True,
@@ -102,7 +102,7 @@ class Recorder:
inputs_hash: str,
project_name: safe.ProjectName,
version_name: safe.VersionName,
- revision_number: str,
+ revision_number: safe.RevisionNumber,
primary_rel_path: str | None = None,
member_rel_path: str | None = None,
afresh: bool = True,
@@ -146,7 +146,7 @@ class Recorder:
result = sql.CheckResult(
release_name=str(self.release_name),
- revision_number=self.revision_number,
+ revision_number=str(self.revision_number),
checker=self.checker,
primary_rel_path=primary_rel_path or self.primary_rel_path,
member_rel_path=member_rel_path,
@@ -317,7 +317,7 @@ async def resolve_cache_key(
checker_version: str,
policy_keys: list[str],
release: sql.Release,
- revision: str,
+ revision: safe.RevisionNumber,
args: dict[str, Any] | None = None,
file: str | None = None,
path: pathlib.Path | None = None,
@@ -384,7 +384,7 @@ async def _resolve_all_files(release: sql.Release,
rel_path: str | None = None)
return []
if not (
base_path := file_paths.base_path_for_revision(
- release.safe_project_name, release.safe_version_name,
release.latest_revision_number
+ release.safe_project_name, release.safe_version_name,
release.safe_latest_revision_number
)
):
return []
@@ -407,7 +407,7 @@ async def _resolve_github_tp_sha(release: sql.Release,
rel_path: str | None = No
if not release.latest_revision_number:
return ""
payload_path = attestable.github_tp_payload_path(
- release.safe_project_name, release.safe_version_name,
release.latest_revision_number
+ release.safe_project_name, release.safe_version_name,
release.safe_latest_revision_number
)
if not await aiofiles.os.path.isfile(payload_path):
return ""
@@ -435,7 +435,7 @@ async def _resolve_unsuffixed_file_hash(release:
sql.Release, rel_path: str | No
if (not rel_path) or (not release.latest_revision_number):
return ""
abs_path = file_paths.revision_path_for_file(
- release.safe_project_name, release.safe_version_name,
release.latest_revision_number, rel_path
+ release.safe_project_name, release.safe_version_name,
release.safe_latest_revision_number, rel_path
)
plain_path = abs_path.with_suffix("")
if await aiofiles.os.path.isfile(plain_path):
diff --git a/atr/tasks/checks/compare.py b/atr/tasks/checks/compare.py
index f2d29881..54906530 100644
--- a/atr/tasks/checks/compare.py
+++ b/atr/tasks/checks/compare.py
@@ -374,7 +374,7 @@ async def _find_archive_root(archive_path: pathlib.Path,
extract_dir: pathlib.Pa
async def _load_tp_payload(
- project_name: safe.ProjectName, version_name: safe.VersionName,
revision_number: str
+ project_name: safe.ProjectName, version_name: safe.VersionName,
revision_number: safe.RevisionNumber
) -> github_models.TrustedPublisherPayload | None:
payload_path = attestable.github_tp_payload_path(project_name,
version_name, revision_number)
if not await aiofiles.os.path.isfile(payload_path):
diff --git a/atr/tasks/quarantine.py b/atr/tasks/quarantine.py
index 81beef4f..85f261cf 100644
--- a/atr/tasks/quarantine.py
+++ b/atr/tasks/quarantine.py
@@ -320,7 +320,7 @@ async def _promote(
previous_attestable = None
if old_revision is not None:
- previous_attestable = await attestable.load(project_name,
version_name, old_revision.number)
+ previous_attestable = await attestable.load(project_name,
version_name, old_revision.safe_number)
base_inodes: dict[str, int] = {}
base_hashes: dict[str, str] = {}
diff --git a/atr/tasks/vote.py b/atr/tasks/vote.py
index 665daa4b..3f6ea9a2 100644
--- a/atr/tasks/vote.py
+++ b/atr/tasks/vote.py
@@ -78,7 +78,7 @@ async def _initiate_core_logic(args: Initiate) ->
results.Results | None:
raise VoteInitiationError(f"No revisions found for release
{args.release_name!s}")
ongoing_tasks = await interaction.tasks_ongoing(
- release.safe_project_name, release.safe_version_name,
latest_revision_number
+ release.safe_project_name, release.safe_version_name,
release.safe_latest_revision_number
)
if ongoing_tasks > 0:
raise VoteInitiationError(
diff --git a/atr/web.py b/atr/web.py
index 66671007..e8cbd8e0 100644
--- a/atr/web.py
+++ b/atr/web.py
@@ -165,7 +165,7 @@ class Committer:
project_name: safe.ProjectName,
version_name: safe.VersionName,
phase: sql.ReleasePhase | db.NotSet | None = db.NOT_SET,
- latest_revision_number: str | db.NotSet | None = db.NOT_SET,
+ latest_revision_number: safe.RevisionNumber | db.NotSet | None =
db.NOT_SET,
data: db.Session | None = None,
with_committee: bool = True,
with_project: bool = True,
@@ -182,13 +182,14 @@ class Committer:
phase_value = sql.ReleasePhase.RELEASE_CANDIDATE_DRAFT
else:
phase_value = phase
+ revision = db.NOT_SET if latest_revision_number == db.NOT_SET else
str(latest_revision_number)
release_name = sql.release_name(project_name, version_name)
if data is None:
async with db.session() as data:
release = await data.release(
name=str(release_name),
phase=phase_value,
- latest_revision_number=latest_revision_number,
+ latest_revision_number=revision,
_committee=with_committee,
_project=with_project,
_release_policy=with_release_policy,
@@ -200,7 +201,7 @@ class Committer:
release = await data.release(
name=str(release_name),
phase=phase_value,
- latest_revision_number=latest_revision_number,
+ latest_revision_number=revision,
_committee=with_committee,
_project=with_project,
_release_policy=with_release_policy,
diff --git a/atr/worker.py b/atr/worker.py
index 1694069f..1f65865a 100644
--- a/atr/worker.py
+++ b/atr/worker.py
@@ -121,6 +121,7 @@ async def _execute_check_task(
project_name = safe.ProjectName(task_obj.project_name)
version_name = safe.VersionName(task_obj.version_name)
+ revision_number = safe.RevisionNumber(task_obj.revision_number)
async def recorder_factory() -> checks.Recorder:
return await checks.Recorder.create(
@@ -128,7 +129,7 @@ async def _execute_check_task(
inputs_hash=task_obj.inputs_hash or "",
project_name=project_name,
version_name=version_name,
- revision_number=task_obj.revision_number or "",
+ revision_number=revision_number,
primary_rel_path=task_obj.primary_rel_path,
)
@@ -137,7 +138,7 @@ async def _execute_check_task(
asf_uid=task_obj.asf_uid,
project_name=project_name,
version_name=version_name,
- revision_number=task_obj.revision_number,
+ revision_number=revision_number,
primary_rel_path=task_obj.primary_rel_path,
extra_args=task_args,
)
diff --git a/tests/unit/test_create_revision.py
b/tests/unit/test_create_revision.py
index 8a94458d..82d07ad3 100644
--- a/tests/unit/test_create_revision.py
+++ b/tests/unit/test_create_revision.py
@@ -21,6 +21,7 @@ import unittest.mock as mock
import pytest
+import atr.models.safe as safe
import atr.models.sql as sql
import atr.storage.types as types
import atr.storage.writers.revision as revision
@@ -58,6 +59,10 @@ class FakeRevision:
self.release_name = release_name
self.was_quarantined = was_quarantined
+ @property
+ def safe_number(self) -> safe.RevisionNumber:
+ return safe.RevisionNumber(self.number)
+
class MockSafeData:
def __init__(self, parent_name: str, new_number: str = "00006"):
@@ -109,10 +114,12 @@ async def
test_clone_from_older_revision_skips_merge_without_intervening_change(
latest_revision = mock.MagicMock()
latest_revision.name = f"{release_name} 00005"
latest_revision.number = "00005"
+ latest_revision.safe_number = safe.RevisionNumber("00005")
selected_revision = mock.MagicMock()
selected_revision.name = f"{release_name} 00002"
selected_revision.number = "00002"
+ selected_revision.safe_number = safe.RevisionNumber("00002")
mock_session = _mock_db_session(release,
selected_revision=selected_revision)
participant = _make_participant()
@@ -149,7 +156,9 @@ async def
test_clone_from_older_revision_skips_merge_without_intervening_change(
mock.patch.object(revision.paths, "release_directory",
return_value=tmp_path / "releases" / "00006"),
mock.patch.object(revision.paths, "release_directory_base",
return_value=tmp_path / "releases"),
):
- await participant.create_revision_with_quarantine("proj", "1.0",
"test", clone_from="00002")
+ await participant.create_revision_with_quarantine(
+ safe.ProjectName("proj"), safe.VersionName("1.0"), "test",
clone_from=safe.RevisionNumber("00002")
+ )
if merge_mock.called:
raise AssertionError(
@@ -193,10 +202,12 @@ async def
test_intervening_revision_triggers_merge_and_uses_latest_parent(tmp_pa
old_revision = mock.MagicMock()
old_revision.name = f"{release_name} 00005"
old_revision.number = "00005"
+ old_revision.safe_number = safe.RevisionNumber("00005")
intervening_revision = mock.MagicMock()
intervening_revision.name = f"{release_name} 00006"
intervening_revision.number = "00006"
+ intervening_revision.safe_number = safe.RevisionNumber("00006")
first_attestable = mock.MagicMock(paths={"dist/a.tar.gz": "h1"})
second_attestable = mock.MagicMock(paths={"dist/b.tar.gz": "h2"})
@@ -244,7 +255,7 @@ async def
test_intervening_revision_triggers_merge_and_uses_latest_parent(tmp_pa
merge_await_args = merge_mock.await_args
assert merge_await_args is not None
- assert merge_await_args.args[5] == "00006"
+ assert merge_await_args.args[5] == safe.RevisionNumber("00006")
@pytest.mark.asyncio
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]