This is an automated email from the ASF dual-hosted git repository.
sbp pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tooling-trusted-release.git
The following commit(s) were added to refs/heads/main by this push:
new d79a675 Implement atomic operations for draft modifications
d79a675 is described below
commit d79a675f4d72066fe7acabcda55d898757f2f835
Author: Sean B. Palmer <[email protected]>
AuthorDate: Wed Apr 9 14:33:01 2025 +0100
Implement atomic operations for draft modifications
- Adds a revision.py module with a context manager for managing a new
revision when modifying a draft.
- Ensures that file operations in the web interface, such as adding or
deleting files, create new draft revisions.
- Standardises the interface for calling check functions using a
FunctionArgument dataclass (due to a problem using Pydantic with this).
- Allows access to absolute paths through a newly renamed Recorder class.
- Integrates draft check task queueing with the revision creation process.
- Modifies the Task model to store revision and path context directly.
- Ensures that check functions operate on a specific revision.
- Uses "latest" for user-facing views of draft content.
- Refactors the ssh and rsync processing to use the new revision logic.
- Improves session management within the worker process.
- Removes the RSYNC_ANALYSE task, integrating this into the context manager.
---
atr/db/models.py | 7 +-
atr/revision.py | 142 ++++++++++++++++++
atr/routes/candidate.py | 23 ++-
atr/routes/draft.py | 333 ++++++++++++++++++++++++------------------
atr/server.py | 2 +-
atr/ssh.py | 272 +++++++++++++++++++---------------
atr/tasks/__init__.py | 147 +++++++------------
atr/tasks/checks/__init__.py | 121 +++++++--------
atr/tasks/checks/hashing.py | 18 +--
atr/tasks/checks/license.py | 36 ++---
atr/tasks/checks/paths.py | 39 +++--
atr/tasks/checks/rat.py | 39 ++---
atr/tasks/checks/signature.py | 46 +++---
atr/tasks/checks/targz.py | 41 ++----
atr/tasks/checks/zipformat.py | 64 ++++----
atr/tasks/rsync.py | 51 -------
atr/util.py | 169 ++++++++++-----------
atr/worker.py | 51 ++++++-
18 files changed, 863 insertions(+), 738 deletions(-)
diff --git a/atr/db/models.py b/atr/db/models.py
index 31319fc..9e9f187 100644
--- a/atr/db/models.py
+++ b/atr/db/models.py
@@ -291,7 +291,7 @@ class TaskType(str, enum.Enum):
LICENSE_HEADERS = "license_headers"
PATHS_CHECK = "paths_check"
RAT_CHECK = "rat_check"
- RSYNC_ANALYSE = "rsync_analyse"
+ # RSYNC_ANALYSE = "rsync_analyse"
SBOM_GENERATE_CYCLONEDX = "sbom_generate_cyclonedx"
SIGNATURE_CHECK = "signature_check"
TARGZ_INTEGRITY = "targz_integrity"
@@ -325,10 +325,13 @@ class Task(sqlmodel.SQLModel, table=True):
)
result: Any | None = sqlmodel.Field(default=None,
sa_column=sqlalchemy.Column(sqlalchemy.JSON))
error: str | None = None
+
+ # Used for check tasks
+ # We don't put these in task_args because we want to query them efficiently
release_name: str | None = sqlmodel.Field(default=None,
foreign_key="release.name")
release: Optional["Release"] =
sqlmodel.Relationship(back_populates="tasks")
- # Identifier for the draft revision that this task targets, if any
draft_revision: str | None = sqlmodel.Field(default=None, index=True)
+ primary_rel_path: str | None = sqlmodel.Field(default=None, index=True)
# Create an index on status and added for efficient task claiming
__table_args__ = (
diff --git a/atr/revision.py b/atr/revision.py
new file mode 100644
index 0000000..95e6bc7
--- /dev/null
+++ b/atr/revision.py
@@ -0,0 +1,142 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import contextlib
+import datetime
+import logging
+import pathlib
+from collections.abc import AsyncGenerator
+
+import aiofiles.os
+import aioshutil
+
+import atr.db as db
+import atr.db.models as models
+import atr.tasks as tasks
+import atr.util as util
+
+_LOGGER = logging.getLogger(__name__)
+
+
[email protected]
+async def create_and_manage(
+ project_name: str, version_name: str, asf_uid: str
+) -> AsyncGenerator[tuple[pathlib.Path, str]]:
+ """Manage the creation and symlinking of a draft release candidate
revision."""
+ draft_base_dir = util.get_release_candidate_draft_dir()
+ release_dir = draft_base_dir / project_name / version_name
+ latest_symlink_path = release_dir / "latest"
+ new_revision_name = _new_name(asf_uid)
+ new_revision_dir = release_dir / new_revision_name
+
+ # Ensure that the base directory for the release exists
+ await aiofiles.os.makedirs(release_dir, exist_ok=True)
+
+ # Check for the parent revision
+ parent_revision_dir, parent_revision_id = await
_manage_draft_revision_find_parent(release_dir, latest_symlink_path)
+
+ temp_dir_created = False
+ try:
+ # Create the new revision directory
+ if parent_revision_dir:
+ _LOGGER.info(f"Creating new revision {new_revision_name} by
hard-linking from {parent_revision_id}")
+ await util.create_hard_link_clone(parent_revision_dir,
new_revision_dir)
+ else:
+ _LOGGER.info(f"Creating new empty revision directory
{new_revision_name}")
+ await aiofiles.os.makedirs(new_revision_dir)
+ temp_dir_created = True
+
+ # Yield control to the block within "async with"
+ yield new_revision_dir, new_revision_name
+
+ # If the "with" block completed without error, store the parent link
+ if parent_revision_id is not None:
+ _LOGGER.info(f"Storing parent link for {new_revision_name} ->
{parent_revision_id}")
+ try:
+ async with db.session() as data:
+ async with data.begin():
+ data.add(models.TextValue(ns="draft_parent",
key=new_revision_name, value=parent_revision_id))
+ except Exception as db_e:
+ _LOGGER.error(f"Failed to store parent link for
{new_revision_name}: {db_e}")
+ # Raise again to ensure clean up in the finally block
+ raise
+
+ _LOGGER.info(f'Updating "latest" symlink to point to
{new_revision_name}')
+ # Target must be relative for the symlink to work correctly within the
release directory
+ await util.update_atomic_symlink(latest_symlink_path,
new_revision_name)
+ # Schedule the checks to be run
+ await tasks.draft_checks(project_name, version_name, new_revision_name)
+
+ except Exception:
+ _LOGGER.exception(f"Error during draft revision management for
{new_revision_name}, cleaning up")
+ # Raise the exception again after the clean up attempt
+ raise
+ finally:
+ # Clean up only if an error occurred during the "with" block or
initial setup
+ # Check whether new_revision_dir exists and whether we should remove it
+ if temp_dir_created:
+ # Determine whether an exception occurred within the "with" block
+ # We just check whether the symlink was updated
+ should_clean_up = True
+ if await aiofiles.os.path.islink(latest_symlink_path):
+ try:
+ target = await
aiofiles.os.readlink(str(latest_symlink_path))
+ if target == new_revision_name:
+ # Symlink points to the new dir, assume success
+ should_clean_up = False
+ except OSError:
+ # Error reading link, proceed with clean up
+ ...
+
+ if should_clean_up:
+ _LOGGER.warning(f"Cleaning up potentially incomplete revision
directory: {new_revision_dir}")
+ with contextlib.suppress(Exception):
+ # Prevent clean_up errors from masking original exception
+ await aioshutil.rmtree(new_revision_dir) # type:
ignore[call-arg]
+
+
+async def _manage_draft_revision_find_parent(
+ release_dir: pathlib.Path, latest_symlink_path: pathlib.Path
+) -> tuple[pathlib.Path | None, str | None]:
+ """Check for and validate the parent revision based on the "latest"
symlink."""
+ parent_revision_dir: pathlib.Path | None = None
+ parent_revision_id: str | None = None
+
+ if await aiofiles.os.path.islink(latest_symlink_path):
+ try:
+ target = await aiofiles.os.readlink(str(latest_symlink_path))
+ # Assume target is relative to release_dir
+ potential_parent_dir = (release_dir / target).resolve()
+ if await aiofiles.os.path.isdir(potential_parent_dir):
+ parent_revision_dir = potential_parent_dir
+ parent_revision_id = potential_parent_dir.name
+ _LOGGER.info(f'Found existing "latest" pointing to parent
revision: {parent_revision_id}')
+ else:
+ _LOGGER.warning(f'A "latest" symlink exists but points to
non-directory: {target}, treating as new')
+ except OSError as e:
+ _LOGGER.warning(f'Error reading "latest" symlink: {e}, treating as
new')
+
+ return parent_revision_dir, parent_revision_id
+
+
+def _new_name(asf_uid: str) -> str:
+ """Generate a new revision name with timestamp truncated to
milliseconds."""
+ now_utc = datetime.datetime.now(datetime.UTC)
+ time_prefix = now_utc.strftime("%Y-%m-%dT%H.%M.%S")
+ milliseconds = now_utc.microsecond // 1000
+ timestamp_str = f"{time_prefix}.{milliseconds:03d}Z"
+ return f"{asf_uid}@{timestamp_str}"
diff --git a/atr/routes/candidate.py b/atr/routes/candidate.py
index d8227e0..5cafb67 100644
--- a/atr/routes/candidate.py
+++ b/atr/routes/candidate.py
@@ -31,6 +31,7 @@ import wtforms
import atr.db as db
import atr.db.models as models
+import atr.revision as revision
import atr.routes as routes
import atr.tasks.vote as tasks_vote
import atr.user as user
@@ -87,6 +88,7 @@ async def viewer(session: routes.CommitterSession,
project_name: str, version_na
file_stats = [
stat async for stat in
util.content_list(util.get_release_candidate_dir(), project_name, version_name)
]
+ logging.warning(f"File stats: {file_stats}")
return await quart.render_template(
"phase-viewer.html",
@@ -415,20 +417,27 @@ async def _resolve_post(session: routes.CommitterSession)
-> response.Response:
await data.commit()
- await _resolve_post_files(project_name, release, vote_result)
+ await _resolve_post_files(project_name, release, vote_result, session.uid)
return await session.redirect(resolve, success=success_message)
-async def _resolve_post_files(project_name: str, release: models.Release,
vote_result: str) -> None:
+async def _resolve_post_files(project_name: str, release: models.Release,
vote_result: str, asf_uid: str) -> None:
# TODO: Obtain a lock for this
source = str(util.get_release_candidate_dir() / project_name /
release.version)
if vote_result == "passed":
+ # The vote passed, so promote the release candidate to the release
preview directory
target = str(util.get_release_preview_dir() / project_name /
release.version)
- else:
- target = str(util.get_release_candidate_draft_dir() / project_name /
release.version)
- if await aiofiles.os.path.exists(target):
- raise base.ASFQuartException("Release already exists", errorcode=400)
- await aioshutil.move(source, target)
+ if await aiofiles.os.path.exists(target):
+ raise base.ASFQuartException("Release already exists",
errorcode=400)
+ await aioshutil.move(source, target)
+ return
+
+ # The vote failed, so move the release candidate to the release draft
directory
+ async with revision.create_and_manage(project_name, release.version,
asf_uid) as (
+ new_revision_dir,
+ _new_revision_name,
+ ):
+ await aioshutil.move(source, new_revision_dir)
async def _task_archive_url(task_mid: str) -> str | None:
diff --git a/atr/routes/draft.py b/atr/routes/draft.py
index 14c120d..5f25ba2 100644
--- a/atr/routes/draft.py
+++ b/atr/routes/draft.py
@@ -36,8 +36,8 @@ import wtforms
import atr.analysis as analysis
import atr.db as db
import atr.db.models as models
+import atr.revision as revision
import atr.routes as routes
-import atr.tasks as tasks
import atr.tasks.sbom as sbom
import atr.util as util
@@ -111,7 +111,7 @@ async def _number_of_release_files(release: models.Release)
-> int:
"""Return the number of files in the release."""
path_project = release.project.name
path_version = release.version
- path = util.get_release_candidate_draft_dir() / path_project / path_version
+ path = util.get_release_candidate_draft_dir() / path_project /
path_version / "latest"
return len(await util.paths_recursive(path))
@@ -187,7 +187,7 @@ async def add_file(session: routes.CommitterSession,
project_name: str, version_
file_name = pathlib.Path(form.file_name.data)
file_data = form.file_data.data
- number_of_files = await _upload_files(project_name, version_name,
file_name, file_data)
+ number_of_files = await _upload_files(project_name, version_name,
session.uid, file_name, file_data)
return await session.redirect(
review,
success=f"{number_of_files} file{'' if number_of_files == 1
else 's'} added successfully",
@@ -255,7 +255,7 @@ async def delete(session: routes.CommitterSession) ->
response.Response:
@routes.committer("/draft/delete-file/<project_name>/<version_name>",
methods=["POST"])
async def delete_file(session: routes.CommitterSession, project_name: str,
version_name: str) -> response.Response:
- """Delete a specific file from the release candidate."""
+ """Delete a specific file from the release candidate, creating a new
revision."""
form = await DeleteFileForm.create_form(data=await quart.request.form)
if not await form.validate_on_submit():
return await session.redirect(review, project_name=project_name,
version_name=version_name)
@@ -264,40 +264,49 @@ async def delete_file(session: routes.CommitterSession,
project_name: str, versi
if not any((p.name == project_name) for p in (await
session.user_projects)):
raise base.ASFQuartException("You do not have access to this project",
errorcode=403)
- async with db.session() as data:
- # Check that the release exists
- await data.release(name=models.release_name(project_name,
version_name), _project=True).demand(
- base.ASFQuartException("Release does not exist", errorcode=404)
- )
-
- file_path = str(form.file_path.data)
- full_path_obj = util.get_release_candidate_draft_dir() / project_name
/ version_name / file_path
- full_path = str(full_path_obj)
-
- # Check that the file exists
- if not await aiofiles.os.path.exists(full_path):
- raise base.ASFQuartException("File does not exist", errorcode=404)
-
- # Check whether the file is an artifact
- metadata_files = 0
- if analysis.is_artifact(full_path_obj):
- # If so, delete all associated metadata files
- for p in await util.paths_recursive(full_path_obj.parent):
- if p.name.startswith(full_path_obj.name + "."):
- await aiofiles.os.remove(full_path_obj.parent / p.name)
- metadata_files += 1
+ rel_path_to_delete = pathlib.Path(str(form.file_path.data))
+ metadata_files_deleted = 0
- # Delete the file
- await aiofiles.os.remove(full_path)
-
- # Ensure that checks are queued again
- await tasks.draft_checks(project_name, version_name, caller_data=data)
- await data.commit()
+ try:
+ async with revision.create_and_manage(project_name, version_name,
session.uid) as (
+ new_revision_dir,
+ new_revision_name,
+ ):
+ # Path to delete within the new revision directory
+ path_in_new_revision = new_revision_dir / rel_path_to_delete
+
+ # Check that the file exists in the new revision
+ if not await aiofiles.os.path.exists(path_in_new_revision):
+ # This indicates a potential severe issue with hard linking or
logic
+ logging.error(
+ f"SEVERE ERROR! File {rel_path_to_delete} not found in new
revision"
+ f" {new_revision_name} before deletion"
+ )
+ raise routes.FlashError("File to delete was not found in the
new revision")
+
+ # Check whether the file is an artifact
+ if analysis.is_artifact(path_in_new_revision):
+ # If so, delete all associated metadata files in the new
revision
+ for p in await
util.paths_recursive(path_in_new_revision.parent):
+ # Construct full path within the new revision
+ metadata_path_obj = new_revision_dir / p
+ if p.name.startswith(rel_path_to_delete.name + "."):
+ await aiofiles.os.remove(metadata_path_obj)
+ metadata_files_deleted += 1
+
+ # Delete the file
+ await aiofiles.os.remove(path_in_new_revision)
+
+ except Exception as e:
+ logging.exception("Error deleting file:")
+ await quart.flash(f"Error deleting file: {e!s}", "error")
+ return await session.redirect(review, project_name=project_name,
version_name=version_name)
- success_message = "File deleted successfully"
- if metadata_files:
+ success_message = f"File '{rel_path_to_delete.name}' deleted successfully"
+ if metadata_files_deleted:
success_message += (
- f", and {metadata_files} associated metadata file{'' if
metadata_files == 1 else 's'} deleted"
+ f", and {metadata_files_deleted} associated metadata "
+ f"file{'' if metadata_files_deleted == 1 else 's'} deleted"
)
return await session.redirect(review, success=success_message,
project_name=project_name, version_name=version_name)
@@ -306,52 +315,55 @@ async def delete_file(session: routes.CommitterSession,
project_name: str, versi
async def hashgen(
session: routes.CommitterSession, project_name: str, version_name: str,
file_path: str
) -> response.Response:
- """Generate an sha256 or sha512 hash file for a candidate draft file."""
+ """Generate an sha256 or sha512 hash file for a candidate draft file,
creating a new revision."""
# Check that the user has access to the project
if not any((p.name == project_name) for p in (await
session.user_projects)):
raise base.ASFQuartException("You do not have access to this project",
errorcode=403)
- async with db.session() as data:
- # Check that the release exists
- await data.release(name=models.release_name(project_name,
version_name), _project=True).demand(
- base.ASFQuartException("Release does not exist", errorcode=404)
- )
-
- # Get the hash type from the form data
- # This is just a button, so we don't make a whole form validation
schema for it
- form = await quart.request.form
- hash_type = form.get("hash_type")
- if hash_type not in {"sha256", "sha512"}:
- raise base.ASFQuartException("Invalid hash type", errorcode=400)
+ # Get the hash type from the form data
+ # This is just a button, so we don't make a whole form validation schema
for it
+ form = await quart.request.form
+ hash_type = form.get("hash_type")
+ if hash_type not in {"sha256", "sha512"}:
+ raise base.ASFQuartException("Invalid hash type", errorcode=400)
- # Construct paths
- base_path = util.get_release_candidate_draft_dir() / project_name /
version_name
- full_path = base_path / file_path
- hash_path = file_path + f".{hash_type}"
- full_hash_path = base_path / hash_path
+ rel_path = pathlib.Path(file_path)
- # Check that the source file exists
- if not await aiofiles.os.path.exists(full_path):
- raise base.ASFQuartException("Source file does not exist",
errorcode=404)
-
- # Check that the hash file does not already exist
- if await aiofiles.os.path.exists(full_hash_path):
- raise base.ASFQuartException(f"{hash_type} file already exists",
errorcode=400)
-
- # Read the file and compute the hash
- hash_obj = hashlib.sha256() if hash_type == "sha256" else
hashlib.sha512()
- async with aiofiles.open(full_path, "rb") as f:
- while chunk := await f.read(8192):
- hash_obj.update(chunk)
-
- # Write the hash file
- hash_value = hash_obj.hexdigest()
- async with aiofiles.open(full_hash_path, "w") as f:
- await f.write(f"{hash_value} {file_path}\n")
-
- # Ensure that checks are queued again
- await tasks.draft_checks(project_name, version_name, caller_data=data)
- await data.commit()
+ try:
+ async with revision.create_and_manage(project_name, version_name,
session.uid) as (
+ new_revision_dir,
+ new_revision_name,
+ ):
+ path_in_new_revision = new_revision_dir / rel_path
+ hash_path_rel = rel_path.name + f".{hash_type}"
+ hash_path_in_new_revision = new_revision_dir / rel_path.parent /
hash_path_rel
+
+ # Check that the source file exists in the new revision
+ if not await aiofiles.os.path.exists(path_in_new_revision):
+ logging.error(
+ f"Source file {rel_path} not found in new revision
{new_revision_name} for hash generation."
+ )
+ raise routes.FlashError("Source file not found in the new
revision.")
+
+ # Check that the hash file does not already exist in the new
revision
+ if await aiofiles.os.path.exists(hash_path_in_new_revision):
+ raise base.ASFQuartException(f"{hash_type} file already
exists", errorcode=400)
+
+ # Read the source file from the new revision and compute the hash
+ hash_obj = hashlib.sha256() if hash_type == "sha256" else
hashlib.sha512()
+ async with aiofiles.open(path_in_new_revision, "rb") as f:
+ while chunk := await f.read(8192):
+ hash_obj.update(chunk)
+
+ # Write the hash file into the new revision
+ hash_value = hash_obj.hexdigest()
+ async with aiofiles.open(hash_path_in_new_revision, "w") as f:
+ await f.write(f"{hash_value} {rel_path.name}\n")
+
+ except Exception as e:
+ logging.exception("Error generating hash file:")
+ await quart.flash(f"Error generating hash file: {e!s}", "error")
+ return await session.redirect(review, project_name=project_name,
version_name=version_name)
return await session.redirect(
review, success=f"{hash_type} file generated successfully",
project_name=project_name, version_name=version_name
@@ -448,7 +460,7 @@ async def review(session: routes.CommitterSession,
project_name: str, version_na
base.ASFQuartException("Release does not exist", errorcode=404)
)
- base_path = util.get_release_candidate_draft_dir() / project_name /
version_name
+ base_path = util.get_release_candidate_draft_dir() / project_name /
version_name / "latest"
paths = await util.paths_recursive(base_path)
# paths_set = set(paths)
path_templates = {}
@@ -486,7 +498,7 @@ async def review(session: routes.CommitterSession,
project_name: str, version_na
path_metadata.add(path)
# Get modified time
- full_path = str(util.get_release_candidate_draft_dir() / project_name
/ version_name / path)
+ full_path = str(util.get_release_candidate_draft_dir() / project_name
/ version_name / "latest" / path)
path_modified[path] = int(await aiofiles.os.path.getmtime(full_path))
# Get successes, warnings, and errors
@@ -547,7 +559,8 @@ async def review_path(session: routes.CommitterSession,
project_name: str, versi
base.ASFQuartException("Release does not exist", errorcode=404)
)
- abs_path = util.get_release_candidate_draft_dir() / project_name /
version_name / rel_path
+ # TODO: When we do more than one thing in a dir, we should use the
revision directory directly
+ abs_path = util.get_release_candidate_draft_dir() / project_name /
version_name / "latest" / rel_path
# Check that the file exists
if not await aiofiles.os.path.exists(abs_path):
@@ -594,53 +607,75 @@ async def review_path(session: routes.CommitterSession,
project_name: str, versi
async def sbomgen(
session: routes.CommitterSession, project_name: str, version_name: str,
file_path: str
) -> response.Response:
- """Generate a CycloneDX SBOM file for a candidate draft file."""
+ """Generate a CycloneDX SBOM file for a candidate draft file, creating a
new revision."""
# Check that the user has access to the project
if not any((p.name == project_name) for p in (await
session.user_projects)):
raise base.ASFQuartException("You do not have access to this project",
errorcode=403)
- async with db.session() as data:
- # Check that the release exists
- release = await data.release(name=models.release_name(project_name,
version_name), _project=True).demand(
- base.ASFQuartException("Release does not exist", errorcode=404)
- )
+ rel_path = pathlib.Path(file_path)
- # Construct paths
- base_path = util.get_release_candidate_draft_dir() / project_name /
version_name
- full_path = base_path / file_path
- # Standard CycloneDX extension
- sbom_path_rel = file_path + ".cdx.json"
- full_sbom_path = base_path / sbom_path_rel
+ # Check that the file is a .tar.gz archive before creating a revision
+ if not (file_path.endswith(".tar.gz") or file_path.endswith(".tgz")):
+ raise base.ASFQuartException("SBOM generation is only supported for
.tar.gz files", errorcode=400)
- # Check that the source file exists
- if not await aiofiles.os.path.exists(full_path):
- raise base.ASFQuartException("Source artifact file does not
exist", errorcode=404)
-
- # Check that the file is a .tar.gz archive
- if not file_path.endswith(".tar.gz"):
- raise base.ASFQuartException("SBOM generation is only supported
for .tar.gz files", errorcode=400)
-
- # Check that the SBOM file does not already exist
- if await aiofiles.os.path.exists(full_sbom_path):
- raise base.ASFQuartException("SBOM file already exists",
errorcode=400)
-
- # Create and queue the task
- sbom_task = models.Task(
- task_type=models.TaskType.SBOM_GENERATE_CYCLONEDX,
- task_args=sbom.GenerateCycloneDX(
- artifact_path=str(full_path),
- output_path=str(full_sbom_path),
- ).model_dump(),
- added=datetime.datetime.now(datetime.UTC),
- status=models.TaskStatus.QUEUED,
- release_name=release.name,
- )
- data.add(sbom_task)
- await data.commit()
+ try:
+ async with revision.create_and_manage(project_name, version_name,
session.uid) as (
+ new_revision_dir,
+ new_revision_name,
+ ):
+ path_in_new_revision = new_revision_dir / rel_path
+ sbom_path_rel = rel_path.with_suffix(rel_path.suffix +
".cdx.json").name
+ sbom_path_in_new_revision = new_revision_dir / rel_path.parent /
sbom_path_rel
+
+ # Check that the source file exists in the new revision
+ if not await aiofiles.os.path.exists(path_in_new_revision):
+ logging.error(
+ f"Source file {rel_path} not found in new revision
{new_revision_name} for SBOM generation."
+ )
+ raise routes.FlashError("Source artifact file not found in the
new revision.")
+
+ # Check that the SBOM file does not already exist in the new
revision
+ if await aiofiles.os.path.exists(sbom_path_in_new_revision):
+ raise base.ASFQuartException("SBOM file already exists",
errorcode=400)
+
+ # Create and queue the task, using paths within the new revision
+ async with db.session() as data:
+ # We still need release.name for the task metadata
+ release = await data.release(
+ name=models.release_name(project_name, version_name),
_project=True
+ ).demand(base.ASFQuartException("Release does not exist",
errorcode=404))
+
+ sbom_task = models.Task(
+ task_type=models.TaskType.SBOM_GENERATE_CYCLONEDX,
+ task_args=sbom.GenerateCycloneDX(
+ artifact_path=str(path_in_new_revision.resolve()),
+ output_path=str(sbom_path_in_new_revision.resolve()),
+ ).model_dump(),
+ added=datetime.datetime.now(datetime.UTC),
+ status=models.TaskStatus.QUEUED,
+ release_name=release.name,
+ draft_revision=new_revision_name,
+ )
+ data.add(sbom_task)
+ await data.commit()
+
+ # We must wait until the sbom_task is complete before we can
queue checks
+ # Maximum wait time is 60 * 100ms = 6000ms
+ for _attempt in range(60):
+ await data.refresh(sbom_task)
+ if sbom_task.status != models.TaskStatus.QUEUED:
+ break
+ # Wait 100ms before checking again
+ await asyncio.sleep(0.1)
+
+ except Exception as e:
+ logging.exception("Error generating SBOM:")
+ await quart.flash(f"Error generating SBOM: {e!s}", "error")
+ return await session.redirect(review, project_name=project_name,
version_name=version_name)
return await session.redirect(
review,
- success=f"SBOM generation task queued for
{pathlib.Path(file_path).name}",
+ success=f"SBOM generation task queued for {rel_path.name}",
project_name=project_name,
version_name=version_name,
)
@@ -659,7 +694,7 @@ async def tools(session: routes.CommitterSession,
project_name: str, version_nam
base.ASFQuartException("Release does not exist", errorcode=404)
)
- full_path = str(util.get_release_candidate_draft_dir() / project_name
/ version_name / file_path)
+ full_path = str(util.get_release_candidate_draft_dir() / project_name
/ version_name / "latest" / file_path)
# Check that the file exists
if not await aiofiles.os.path.exists(full_path):
@@ -734,7 +769,7 @@ async def viewer_path(
)
_max_view_size = 1 * 1024 * 1024
- full_path = util.get_release_candidate_draft_dir() / project_name /
version_name / file_path
+ full_path = util.get_release_candidate_draft_dir() / project_name /
version_name / "latest" / file_path
content, is_text, is_truncated, error_message = await
util.read_file_for_viewer(full_path, _max_view_size)
return await quart.render_template(
"phase-viewer-path.html",
@@ -850,7 +885,9 @@ async def _promote(
if release.phase != models.ReleasePhase.RELEASE_CANDIDATE_DRAFT:
return await session.redirect(promote, error="This release is not in
the candidate draft phase")
- source_dir = util.get_release_candidate_draft_dir() / project_name /
version_name
+ base_dir = util.get_release_candidate_draft_dir() / project_name /
version_name
+ # Read the "latest" symlink and make an absolute path from it relative to
base_dir
+ source_dir = base_dir / await aiofiles.os.readlink(str(base_dir /
"latest"))
target_dir: pathlib.Path
success_message: str
@@ -872,14 +909,16 @@ async def _promote(
target_dir = util.get_release_dir() / project_name / version_name
success_message = "Candidate draft successfully promoted to release
(after announcement)"
else:
- # Should not happen due to form validation
+ # Should not happen, due to form validation
return await session.redirect(promote, error="Unsupported target
phase")
if await aiofiles.os.path.exists(target_dir):
return await session.redirect(promote, error=f"Target directory
{target_dir.name} already exists")
await data.commit()
+ logging.warning(f"Moving {source_dir} to {target_dir} (base: {base_dir})")
await aioshutil.move(str(source_dir), str(target_dir))
+ await aioshutil.rmtree(str(base_dir)) # type: ignore[call-arg]
return await session.redirect(promote, success=success_message)
@@ -887,33 +926,39 @@ async def _promote(
async def _upload_files(
project_name: str,
version_name: str,
+ asf_uid: str,
file_name: pathlib.Path | None,
files: Sequence[datastructures.FileStorage],
) -> int:
- """Process and save the uploaded files."""
- # Create target directory
- target_dir = util.get_release_candidate_draft_dir() / project_name /
version_name
- target_dir.mkdir(parents=True, exist_ok=True)
-
- def get_filepath(file: datastructures.FileStorage) -> pathlib.Path:
- # Use the original filename if no path is specified
- if not file_name:
- if not file.filename:
- raise routes.FlashError("No filename provided")
- return pathlib.Path(file.filename)
- else:
- return file_name
-
- for file in files:
- # Save file to specified path
- file_path = get_filepath(file)
- target_path = target_dir / file_path.relative_to(file_path.anchor)
- target_path.parent.mkdir(parents=True, exist_ok=True)
-
- await _save_file(file, target_path)
-
- # Ensure that checks are queued again
- await tasks.draft_checks(project_name, version_name)
+ """Process and save the uploaded files into a new draft revision."""
+ async with revision.create_and_manage(project_name, version_name, asf_uid)
as (
+ new_revision_dir,
+ _new_revision_name,
+ ):
+
+ def get_target_path(file: datastructures.FileStorage) -> pathlib.Path:
+ # Determine the target path within the new revision directory
+ relative_file_path: pathlib.Path
+ if not file_name:
+ if not file.filename:
+ raise routes.FlashError("No filename provided")
+ # Use the original name
+ relative_file_path = pathlib.Path(file.filename)
+ else:
+ # Use the provided name, relative to its anchor
+ # In other words, ignore the leading "/"
+ relative_file_path = file_name.relative_to(file_name.anchor)
+
+ # Construct path inside the new revision directory
+ target_path = new_revision_dir / relative_file_path
+ return target_path
+
+ # Save each uploaded file to the new revision directory
+ for file in files:
+ target_path = get_target_path(file)
+ # Ensure parent directories exist within the new revision
+ target_path.parent.mkdir(parents=True, exist_ok=True)
+ await _save_file(file, target_path)
return len(files)
diff --git a/atr/server.py b/atr/server.py
index 19ffc0f..141408e 100644
--- a/atr/server.py
+++ b/atr/server.py
@@ -64,7 +64,7 @@ def register_routes(app: base.QuartApp) -> ModuleType:
# NOTE: These imports are for their side effects only
import atr.routes.modules as modules
- # Add a global error handler to show helpful error messages with
tracebacks.
+ # Add a global error handler to show helpful error messages with tracebacks
@app.errorhandler(Exception)
async def handle_any_exception(error: Exception) -> Any:
import traceback
diff --git a/atr/ssh.py b/atr/ssh.py
index ad7fb2b..6e95838 100644
--- a/atr/ssh.py
+++ b/atr/ssh.py
@@ -23,7 +23,7 @@ import datetime
import logging
import os
import string
-from typing import Final
+from typing import Final, TypeVar
import aiofiles
import aiofiles.os
@@ -32,13 +32,14 @@ import asyncssh
import atr.config as config
import atr.db as db
import atr.db.models as models
-import atr.tasks.rsync as rsync
+import atr.revision as revision
import atr.user as user
-import atr.util as util
_LOGGER: Final = logging.getLogger(__name__)
_CONFIG: Final = config.get()
+T = TypeVar("T")
+
class _SSHServer(asyncssh.SSHServer):
"""Simple SSH server that handles connections."""
@@ -165,66 +166,48 @@ def _command_path_validate(path: str) -> tuple[str, str]
| str:
return path_project, path_version
-def _command_simple_validate(argv: list[str]) -> str | None:
+def _command_simple_validate(argv: list[str]) -> tuple[str | None, int]:
if argv[0] != "rsync":
- return "The first argument should be rsync"
+ return "The first argument should be rsync", -1
if argv[1] != "--server":
- return "The second argument should be --server"
+ return "The second argument should be --server", -1
# TODO: Might need to accept permutations of this
# Also certain versions of rsync might change the options
acceptable_options: Final[str] = "vlogDtpre"
if not argv[2].startswith(f"-{acceptable_options}."):
- return f"The third argument should start with -{acceptable_options}."
+ return f"The third argument should start with -{acceptable_options}.",
-1
if not argv[2][len(f"-{acceptable_options}.") :].isalpha():
- return "The third argument should be a valid command"
+ return "The third argument should be a valid command", -1
# Support --delete as an optional argument before the path
if argv[3] != "--delete":
# No --delete, short command
if argv[3] != ".":
- return "The fourth argument should be ."
+ return "The fourth argument should be .", -1
if len(argv) != 5:
- return "There should be 5 arguments"
+ return "There should be 5 arguments", -1
+ path_index = 4
else:
# Has --delete, long command
if argv[4] != ".":
- return "The fifth argument should be ."
+ return "The fifth argument should be .", -1
if len(argv) != 6:
- return "There should be 6 arguments"
-
- return None
-
-
-async def _command_validate(process: asyncssh.SSHServerProcess) -> tuple[str,
str, list[str]] | None:
- def fail(message: str) -> tuple[str, str, list[str]] | None:
- # NOTE: Changing the return type to just None really confuses mypy
- _LOGGER.error(message)
- process.stderr.write(f"ATR SSH error: {message}\nCommand:
{process.command}\n".encode())
- process.exit(1)
- return None
-
- command = process.command
- if not command:
- return fail("No command specified")
-
- _LOGGER.info(f"Command received: {command}")
- argv = command.split()
+ return "There should be 6 arguments", -1
+ path_index = 5
- error = _command_simple_validate(argv)
- if error:
- return fail(error)
+ return None, path_index
- if argv[3] == "--delete":
- path_index = 5
- else:
- path_index = 4
+async def _command_validate(
+ process: asyncssh.SSHServerProcess, argv: list[str], path_index: int
+) -> tuple[str, str] | None:
result = _command_path_validate(argv[path_index])
if isinstance(result, str):
- return fail(result)
+ _fail(process, result, None)
+ return None
path_project, path_version = result
# Ensure that the user has permission to upload to this project
@@ -232,30 +215,141 @@ async def _command_validate(process:
asyncssh.SSHServerProcess) -> tuple[str, st
project = await data.project(name=path_project, _committee=True).get()
if not project:
# Projects are public, so existence information is public
- return fail("This project does not exist")
+ _fail(process, "This project does not exist", None)
+ return None
release = await data.release(project_id=project.id,
version=path_version).get()
# The SSH UID has also been validated by SSH as being the ASF UID
# Since users can only set an SSH key when authenticated using ASF
OAuth
ssh_uid = process.get_extra_info("username")
if not release:
# The user is requesting to create a new release
- # Check if the user has permission to create a release for this
project
+ # Check whether the user has permission to create a release for
this project
if not user.is_committee_member(project.committee, ssh_uid):
- return fail("You must be a member of this project's committee
to create a release")
+ _fail(process, "You must be a member of this project's
committee to create a release", None)
+ return None
else:
# The user is requesting to upload to an existing release
- # Check if the user has permission to upload to this release
+ # Check whether the user has permission to upload to this release
if not user.is_committer(release.committee, ssh_uid):
- return fail("You must be a member of this project's committee
or a committer to upload to this release")
+ _fail(
+ process,
+ "You must be a member of this project's committee or a
committer to upload to this release",
+ None,
+ )
+ return None
+ return path_project, path_version
+
+
+async def _ensure_release_object(
+ process: asyncssh.SSHServerProcess, project_name: str, version_name: str,
new_draft_revision: str
+) -> bool:
+ try:
+ async with db.session() as data:
+ async with data.begin():
+ release = await data.release(
+ name=models.release_name(project_name, version_name),
_committee=True
+ ).get()
+ if release is None:
+ project = await data.project(name=project_name,
_committee=True).demand(
+ RuntimeError("Project not found")
+ )
+ # Create a new release object
+ release = models.Release(
+ project_id=project.id,
+ project=project,
+ version=version_name,
+ stage=models.ReleaseStage.RELEASE_CANDIDATE,
+ phase=models.ReleasePhase.RELEASE_CANDIDATE_DRAFT,
+ created=datetime.datetime.now(datetime.UTC),
+ )
+ data.add(release)
+ elif release.phase !=
models.ReleasePhase.RELEASE_CANDIDATE_DRAFT:
+ return _fail(
+ process, f"Release {release.name} is no longer in
draft phase ({release.phase.value})", False
+ )
+
+ # # TODO: We now do this in the context manager, so we can
delete the rsync task
+ # data.add(
+ # models.Task(
+ # status=models.TaskStatus.QUEUED,
+ # task_type=models.TaskType.RSYNC_ANALYSE,
+ # task_args=rsync.Analyse(
+ # project_name=project_name,
+ # release_version=version_name,
+ # draft_revision=new_draft_revision,
+ # ).model_dump(),
+ # release_name=models.release_name(project_name,
version_name),
+ # draft_revision=new_draft_revision,
+ # )
+ # )
+ return True
+ except Exception as e:
+ _LOGGER.exception("Error finalising upload in database")
+ return _fail(process, f"Internal error finalising upload: {e}", False)
+
+
+async def _execute_rsync(process: asyncssh.SSHServerProcess, argv: list[str])
-> int:
+ # This is Step 2 of the upload process
+ _LOGGER.info(f"Executing modified command: {' '.join(argv)}")
+ proc = await asyncio.create_subprocess_shell(
+ " ".join(argv),
+ stdin=asyncio.subprocess.PIPE,
+ stdout=asyncio.subprocess.PIPE,
+ stderr=asyncio.subprocess.PIPE,
+ )
+ await process.redirect(stdin=proc.stdin, stdout=proc.stdout,
stderr=proc.stderr)
+ exit_status = await proc.wait()
+ return exit_status
- # Set the target directory to the release storage directory
- argv[path_index] = str(util.get_release_candidate_draft_dir() /
path_project / path_version)
- _LOGGER.info(f"Modified command: {argv}")
- # Create the release's storage directory if it doesn't exist
- await aiofiles.os.makedirs(argv[path_index], exist_ok=True)
+def _fail(proc: asyncssh.SSHServerProcess, message: str, return_value: T) -> T:
+ _LOGGER.error(message)
+ proc.stderr.write(f"ATR SSH error: {message}\n".encode())
+ proc.exit(1)
+ return return_value
- return path_project, path_version, argv
+
+async def _process_validated_rsync(
+ process: asyncssh.SSHServerProcess,
+ argv: list[str],
+ path_index: int,
+ project_name: str,
+ version_name: str,
+) -> None:
+ asf_uid = process.get_extra_info("username")
+ exit_status = 1
+
+ try:
+ async with revision.create_and_manage(project_name, version_name,
asf_uid) as (
+ new_revision_dir,
+ new_draft_revision,
+ ):
+ # Update the rsync command path to the new temporary revision
directory
+ argv[path_index] = str(new_revision_dir)
+
+ # Execute the rsync command
+ exit_status = await _execute_rsync(process, argv)
+ if exit_status != 0:
+ _LOGGER.error(
+ f"rsync failed with exit status {exit_status} for revision
{new_draft_revision}. \
+ Command: {process.command} (run as {' '.join(argv)})"
+ )
+ process.exit(exit_status)
+ return
+
+ # Ensure that the release object exists and is in the correct phase
+ if not await _ensure_release_object(process, project_name,
version_name, new_draft_revision):
+ process.exit(1)
+ return
+
+ # Exit with the rsync exit status
+ process.exit(exit_status)
+
+ except Exception as e:
+ _LOGGER.exception(f"Error during draft revision processing for
{project_name}-{version_name}")
+ _fail(process, f"Internal error processing revision: {e}", None)
+ if not process.is_closing():
+ process.exit(1)
async def _handle_client(process: asyncssh.SSHServerProcess) -> None:
@@ -263,73 +357,23 @@ async def _handle_client(process:
asyncssh.SSHServerProcess) -> None:
asf_uid = process.get_extra_info("username")
_LOGGER.info(f"Handling command for authenticated user: {asf_uid}")
- validation_results = await _command_validate(process)
- if not validation_results:
+ if not process.command:
+ process.stderr.write(b"ATR SSH error: No command specified\n")
+ process.exit(1)
return
- project_name, release_version, argv = validation_results
- try:
- # Create subprocess to actually run the command
- # NOTE: asyncio base_events subprocess_shell requires cmd be str |
bytes
- # Ought to be list[str] | list[bytes] really
- proc = await asyncio.create_subprocess_shell(
- " ".join(argv),
- stdin=asyncio.subprocess.PIPE,
- stdout=asyncio.subprocess.PIPE,
- stderr=asyncio.subprocess.PIPE,
- )
-
- # Redirect I/O between SSH process and the subprocess
- await process.redirect(stdin=proc.stdin, stdout=proc.stdout,
stderr=proc.stderr)
-
- # Wait for the process to complete
- exit_status = await proc.wait()
- if exit_status != 0:
- _LOGGER.error(f"Command {process.command} failed with exit status
{exit_status}")
- process.exit(exit_status)
- return
+ _LOGGER.info(f"Command received: {process.command}")
+ argv = process.command.split()
- # Start a task to process the new files
- async with db.session() as data:
- release = await
data.release(name=models.release_name(project_name, release_version),
_committee=True).get()
- # Create the release if it does not already exist
- if release is None:
- project = await data.project(name=project_name,
_committee=True).demand(
- RuntimeError("Project not found")
- )
- release = models.Release(
- project_id=project.id,
- project=project,
- version=release_version,
- stage=models.ReleaseStage.RELEASE_CANDIDATE,
- phase=models.ReleasePhase.RELEASE_CANDIDATE_DRAFT,
- created=datetime.datetime.now(),
- )
- data.add(release)
- await data.commit()
- if release.stage != models.ReleaseStage.RELEASE_CANDIDATE:
- raise RuntimeError("Release is not in the candidate stage")
- if release.phase != models.ReleasePhase.RELEASE_CANDIDATE_DRAFT:
- raise RuntimeError("Release is not in the candidate draft
phase")
-
- # Add a task to analyse the new files
- data.add(
- models.Task(
- status=models.TaskStatus.QUEUED,
- task_type=models.TaskType.RSYNC_ANALYSE,
- task_args=rsync.Analyse(
- project_name=project_name,
- release_version=release_version,
- ).model_dump(),
- )
- )
- await data.commit()
+ simple_validation_error, path_index = _command_simple_validate(argv)
+ if simple_validation_error:
+ process.stderr.write(f"ATR SSH error:
{simple_validation_error}\nCommand: {process.command}\n".encode())
+ process.exit(1)
+ return
- # Exit the SSH process with the same status as the rsync process
- # Should be 0 here
- process.exit(exit_status)
+ validation_results = await _command_validate(process, argv, path_index)
+ if not validation_results:
+ return
+ project_name, version_name = validation_results
- except Exception as e:
- _LOGGER.exception(f"Error executing command {process.command}")
- process.stderr.write(f"Error: {e!s}\n")
- process.exit(1)
+ await _process_validated_rsync(process, argv, path_index, project_name,
version_name)
diff --git a/atr/tasks/__init__.py b/atr/tasks/__init__.py
index 34d56e0..3d19989 100644
--- a/atr/tasks/__init__.py
+++ b/atr/tasks/__init__.py
@@ -21,7 +21,6 @@ from typing import Any, Final
import atr.db as db
import atr.db.models as models
-import atr.tasks.checks as checks
import atr.tasks.checks.hashing as hashing
import atr.tasks.checks.license as license
import atr.tasks.checks.paths as paths
@@ -29,37 +28,38 @@ import atr.tasks.checks.rat as rat
import atr.tasks.checks.signature as signature
import atr.tasks.checks.targz as targz
import atr.tasks.checks.zipformat as zipformat
-import atr.tasks.rsync as rsync
+
+# import atr.tasks.rsync as rsync
import atr.tasks.sbom as sbom
import atr.tasks.vote as vote
import atr.util as util
-async def asc_checks(release: models.Release, signature_path: str) ->
list[models.Task]:
+async def asc_checks(release: models.Release, draft_revision: str,
signature_path: str) -> list[models.Task]:
"""Create signature check task for a .asc file."""
tasks = []
if release.committee:
tasks.append(
- models.Task(
- status=models.TaskStatus.QUEUED,
- task_type=models.TaskType.SIGNATURE_CHECK,
- task_args=signature.Check(
- release_name=release.name,
- committee_name=release.committee.name,
- signature_rel_path=signature_path,
- ).model_dump(),
- release_name=release.name,
- ),
+ queued(
+ models.TaskType.SIGNATURE_CHECK,
+ release,
+ draft_revision,
+ signature_path,
+ {"extra_args": {"committee_name": release.committee.name}},
+ )
)
return tasks
-async def draft_checks(project_name: str, release_version: str, caller_data:
db.Session | None = None) -> int:
- """Core logic to analyse an rsync upload and queue checks."""
- base_path = util.get_release_candidate_draft_dir() / project_name /
release_version
- relative_paths = await util.paths_recursive(base_path)
+async def draft_checks(
+ project_name: str, release_version: str, draft_revision: str, caller_data:
db.Session | None = None
+) -> int:
+ """Core logic to analyse a draft revision and queue checks."""
+ # Construct path to the specific revision
+ revision_path = util.get_release_candidate_draft_dir() / project_name /
release_version / draft_revision
+ relative_paths = await util.paths_recursive(revision_path)
session_context = db.session() if (caller_data is None) else
contextlib.nullcontext(caller_data)
async with session_context as data:
@@ -68,17 +68,14 @@ async def draft_checks(project_name: str, release_version:
str, caller_data: db.
)
for path in relative_paths:
path_str = str(path)
+ task_function: Callable[[models.Release, str, str],
Awaitable[list[models.Task]]] | None = None
for suffix, task_function in TASK_FUNCTIONS.items():
if path.name.endswith(suffix):
- for task in await task_function(release, path_str):
+ for task in await task_function(release, draft_revision,
path_str):
+ task.draft_revision = draft_revision
data.add(task)
- path_check_task = models.Task(
- status=models.TaskStatus.QUEUED,
- task_type=models.TaskType.PATHS_CHECK,
- task_args=paths.Check(release_name=release.name).model_dump(),
- release_name=release.name,
- )
+ path_check_task = queued(models.TaskType.PATHS_CHECK, release,
draft_revision)
data.add(path_check_task)
if caller_data is None:
await data.commit()
@@ -86,6 +83,23 @@ async def draft_checks(project_name: str, release_version:
str, caller_data: db.
return len(relative_paths)
+def queued(
+ task_type: models.TaskType,
+ release: models.Release,
+ draft_revision: str,
+ primary_rel_path: str | None = None,
+ extra_args: dict[str, Any] | None = None,
+) -> models.Task:
+ return models.Task(
+ status=models.TaskStatus.QUEUED,
+ task_type=task_type,
+ task_args=extra_args or {},
+ release_name=release.name,
+ draft_revision=draft_revision,
+ primary_rel_path=primary_rel_path,
+ )
+
+
def resolve(task_type: models.TaskType) -> Callable[..., Awaitable[str |
None]]: # noqa: C901
match task_type:
case models.TaskType.HASHING_CHECK:
@@ -98,8 +112,8 @@ def resolve(task_type: models.TaskType) -> Callable[...,
Awaitable[str | None]]:
return paths.check
case models.TaskType.RAT_CHECK:
return rat.check
- case models.TaskType.RSYNC_ANALYSE:
- return rsync.analyse
+ # case models.TaskType.RSYNC_ANALYSE:
+ # return rsync.analyse
case models.TaskType.SBOM_GENERATE_CYCLONEDX:
return sbom.generate_cyclonedx
case models.TaskType.SIGNATURE_CHECK:
@@ -122,90 +136,35 @@ def resolve(task_type: models.TaskType) -> Callable[...,
Awaitable[str | None]]:
# Otherwise we lose exhaustiveness checking
-async def sha_checks(release: models.Release, hash_file: str) ->
list[models.Task]:
+async def sha_checks(release: models.Release, draft_revision: str, hash_file:
str) -> list[models.Task]:
"""Create hash check task for a .sha256 or .sha512 file."""
tasks = []
- tasks.append(
- models.Task(
- status=models.TaskStatus.QUEUED,
- task_type=models.TaskType.HASHING_CHECK,
- task_args=checks.ReleaseAndRelPath(
- release_name=release.name,
- primary_rel_path=hash_file,
- ).model_dump(),
- release_name=release.name,
- ),
- )
+ tasks.append(queued(models.TaskType.HASHING_CHECK, release,
draft_revision, hash_file))
return tasks
-async def tar_gz_checks(release: models.Release, path: str) ->
list[models.Task]:
+async def tar_gz_checks(release: models.Release, draft_revision: str, path:
str) -> list[models.Task]:
"""Create check tasks for a .tar.gz or .tgz file."""
tasks = [
- models.Task(
- status=models.TaskStatus.QUEUED,
- task_type=models.TaskType.LICENSE_FILES,
- task_args=checks.ReleaseAndRelPath(release_name=release.name,
primary_rel_path=path).model_dump(),
- release_name=release.name,
- ),
- models.Task(
- status=models.TaskStatus.QUEUED,
- task_type=models.TaskType.LICENSE_HEADERS,
- task_args=checks.ReleaseAndRelPath(release_name=release.name,
primary_rel_path=path).model_dump(),
- release_name=release.name,
- ),
- models.Task(
- status=models.TaskStatus.QUEUED,
- task_type=models.TaskType.RAT_CHECK,
- task_args=rat.Check(release_name=release.name,
primary_rel_path=path).model_dump(),
- release_name=release.name,
- ),
- models.Task(
- status=models.TaskStatus.QUEUED,
- task_type=models.TaskType.TARGZ_INTEGRITY,
- task_args=targz.Integrity(release_name=release.name,
primary_rel_path=path).model_dump(),
- release_name=release.name,
- ),
- models.Task(
- status=models.TaskStatus.QUEUED,
- task_type=models.TaskType.TARGZ_STRUCTURE,
- task_args=checks.ReleaseAndRelPath(release_name=release.name,
primary_rel_path=path).model_dump(),
- release_name=release.name,
- ),
+ queued(models.TaskType.LICENSE_FILES, release, draft_revision, path),
+ queued(models.TaskType.LICENSE_HEADERS, release, draft_revision, path),
+ queued(models.TaskType.RAT_CHECK, release, draft_revision, path),
+ queued(models.TaskType.TARGZ_INTEGRITY, release, draft_revision, path),
+ queued(models.TaskType.TARGZ_STRUCTURE, release, draft_revision, path),
]
return tasks
-async def zip_checks(release: models.Release, path: str) -> list[models.Task]:
+async def zip_checks(release: models.Release, draft_revision: str, path: str)
-> list[models.Task]:
"""Create check tasks for a .zip file."""
tasks = [
- models.Task(
- status=models.TaskStatus.QUEUED,
- task_type=models.TaskType.ZIPFORMAT_INTEGRITY,
- task_args=checks.ReleaseAndRelPath(release_name=release.name,
primary_rel_path=path).model_dump(),
- release_name=release.name,
- ),
- models.Task(
- status=models.TaskStatus.QUEUED,
- task_type=models.TaskType.ZIPFORMAT_LICENSE_FILES,
- task_args=checks.ReleaseAndRelPath(release_name=release.name,
primary_rel_path=path).model_dump(),
- release_name=release.name,
- ),
- models.Task(
- status=models.TaskStatus.QUEUED,
- task_type=models.TaskType.ZIPFORMAT_LICENSE_HEADERS,
- task_args=checks.ReleaseAndRelPath(release_name=release.name,
primary_rel_path=path).model_dump(),
- release_name=release.name,
- ),
- models.Task(
- status=models.TaskStatus.QUEUED,
- task_type=models.TaskType.ZIPFORMAT_STRUCTURE,
- task_args=checks.ReleaseAndRelPath(release_name=release.name,
primary_rel_path=path).model_dump(),
- release_name=release.name,
- ),
+ queued(models.TaskType.ZIPFORMAT_INTEGRITY, release, draft_revision,
path),
+ queued(models.TaskType.ZIPFORMAT_LICENSE_FILES, release,
draft_revision, path),
+ queued(models.TaskType.ZIPFORMAT_LICENSE_HEADERS, release,
draft_revision, path),
+ queued(models.TaskType.ZIPFORMAT_STRUCTURE, release, draft_revision,
path),
]
return tasks
diff --git a/atr/tasks/checks/__init__.py b/atr/tasks/checks/__init__.py
index 5480427..76a07df 100644
--- a/atr/tasks/checks/__init__.py
+++ b/atr/tasks/checks/__init__.py
@@ -17,64 +17,84 @@
from __future__ import annotations
+import dataclasses
import datetime
+import pathlib
from functools import wraps
-from types import FunctionType
-from typing import TYPE_CHECKING, Any, TypeVar
+from typing import TYPE_CHECKING, Any
-import pydantic
import sqlmodel
if TYPE_CHECKING:
- import pathlib
from collections.abc import Awaitable, Callable
+ import pydantic
+
import atr.db as db
import atr.db.models as models
import atr.util as util
-class Check:
+# Pydantic does not like Callable types, so we use a dataclass instead
+# It says: "you should define `Callable`, then call
`FunctionArguments.model_rebuild()`"
[email protected]
+class FunctionArguments:
+ recorder: Callable[[], Awaitable[Recorder]]
+ release_name: str
+ draft_revision: str
+ primary_rel_path: str | None
+ extra_args: dict[str, Any]
+
+
+class Recorder:
+ checker: str
+ release_name: str
+ project_name: str
+ version_name: str
+ primary_rel_path: str | None
+ draft_revision: str
+ afresh: bool
+
def __init__(
self,
checker: str | Callable[..., Any],
release_name: str,
+ draft_revision: str,
primary_rel_path: str | None = None,
afresh: bool = True,
) -> None:
- if isinstance(checker, FunctionType):
- checker = function_key(checker)
- if not isinstance(checker, str):
- raise ValueError("Checker must be a string or a callable")
- self.checker = checker
- project_name, version_name = models.project_version(release_name)
- self.project_name = project_name
- self.version_name = version_name
+ self.checker = function_key(checker) if callable(checker) else checker
self.release_name = release_name
+ self.draft_revision = draft_revision
self.primary_rel_path = primary_rel_path
self.afresh = afresh
self._constructed = False
+ project_name, version_name = models.project_version(release_name)
+ self.project_name = project_name
+ self.version_name = version_name
+
@classmethod
async def create(
cls,
checker: str | Callable[..., Any],
release_name: str,
+ draft_revision: str,
primary_rel_path: str | None = None,
afresh: bool = True,
- ) -> Check:
- check = cls(checker, release_name, primary_rel_path, afresh)
+ ) -> Recorder:
+ recorder = cls(checker, release_name, draft_revision,
primary_rel_path, afresh)
if afresh is True:
# Clear outer path whether it's specified or not
- await check.clear(primary_rel_path)
- check._constructed = True
- return check
+ await recorder.clear(primary_rel_path)
+ recorder._constructed = True
+ return recorder
async def _add(
self, status: models.CheckResultStatus, message: str, data: Any,
primary_rel_path: str | None = None
) -> models.CheckResult:
if self._constructed is False:
- raise RuntimeError("Cannot add check result to a check that has
not been constructed")
+ raise RuntimeError("Cannot add check result to a recorder that has
not been constructed")
if primary_rel_path is not None:
if self.primary_rel_path is not None:
raise ValueError("Cannot specify path twice")
@@ -101,13 +121,24 @@ class Check:
return result
async def abs_path(self, rel_path: str | None = None) -> pathlib.Path |
None:
+ """Construct the absolute path using the required draft_revision."""
+ base_dir = util.get_release_candidate_draft_dir()
+ project_part = self.project_name
+ version_part = self.version_name
+ revision_part = self.draft_revision
+
+ # Determine the relative path part
+ rel_path_part: str | None = None
if rel_path is not None:
- return util.get_release_candidate_draft_dir() / self.project_name
/ self.version_name / rel_path
- if self.primary_rel_path is not None:
- return (
- util.get_release_candidate_draft_dir() / self.project_name /
self.version_name / self.primary_rel_path
- )
- return util.get_release_candidate_draft_dir() / self.project_name /
self.version_name
+ rel_path_part = rel_path
+ elif self.primary_rel_path is not None:
+ rel_path_part = self.primary_rel_path
+
+ # Construct the absolute path
+ abs_path_parts: list[str | pathlib.Path] = [base_dir, project_part,
version_part, revision_part]
+ if isinstance(rel_path_part, str):
+ abs_path_parts.append(rel_path_part)
+ return pathlib.Path(*abs_path_parts)
async def clear(self, primary_rel_path: str | None = None) -> None:
async with db.session() as data:
@@ -136,45 +167,15 @@ def function_key(func: Callable[..., Any]) -> str:
return func.__module__ + "." + func.__name__
-# def rel_path(abs_path: str) -> str:
-# """Return the relative path for a given absolute path."""
-# conf = config.get()
-# phase_dir = pathlib.Path(conf.PHASE_STORAGE_DIR)
-# phase_sub_dir = pathlib.Path(abs_path).relative_to(phase_dir)
-# # Skip the first component, which is the phase name
-# # And the next two components, which are the project name and version
name
-# return str(pathlib.Path(*phase_sub_dir.parts[3:]))
-
-
-# def using(cls: type[pydantic.BaseModel]) -> Callable[[Callable[..., Any]],
Callable[..., Any]]:
-# """Decorator to specify the parameters for a check."""
-
-# def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
-# @wraps(func)
-# async def wrapper(data_dict: dict[str, Any], *args: Any, **kwargs:
Any) -> Any:
-# model_instance = cls(**data_dict)
-# return await func(model_instance, *args, **kwargs)
-# return wrapper
-
-# return decorator
+def with_model(cls: type[pydantic.BaseModel]) -> Callable[[Callable[...,
Any]], Callable[..., Any]]:
+ """Decorator to specify the parameters for a check."""
-
-T = TypeVar("T", bound=pydantic.BaseModel)
-R = TypeVar("R")
-
-
-def with_model(model_class: type[T]) -> Callable[[Callable[...,
Awaitable[R]]], Callable[..., Awaitable[R]]]:
- def decorator(func: Callable[..., Awaitable[R]]) -> Callable[...,
Awaitable[R]]:
+ def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
@wraps(func)
- async def wrapper(data_dict: dict[str, Any], *args: Any, **kwargs:
Any) -> R:
- model_instance = model_class(**data_dict)
+ async def wrapper(data_dict: dict[str, Any], *args: Any, **kwargs:
Any) -> Any:
+ model_instance = cls(**data_dict)
return await func(model_instance, *args, **kwargs)
return wrapper
return decorator
-
-
-class ReleaseAndRelPath(pydantic.BaseModel):
- release_name: str = pydantic.Field(..., description="Release name")
- primary_rel_path: str = pydantic.Field(..., description="Primary relative
path to check")
diff --git a/atr/tasks/checks/hashing.py b/atr/tasks/checks/hashing.py
index b02d048..299521c 100644
--- a/atr/tasks/checks/hashing.py
+++ b/atr/tasks/checks/hashing.py
@@ -27,19 +27,15 @@ import atr.tasks.checks as checks
_LOGGER: Final = logging.getLogger(__name__)
[email protected]_model(checks.ReleaseAndRelPath)
-async def check(args: checks.ReleaseAndRelPath) -> str | None:
+async def check(args: checks.FunctionArguments) -> str | None:
"""Check the hash of a file."""
- check_obj = await checks.Check.create(
- checker=check, release_name=args.release_name,
primary_rel_path=args.primary_rel_path
- )
-
- if not (hash_abs_path := await check_obj.abs_path()):
+ recorder = await args.recorder()
+ if not (hash_abs_path := await recorder.abs_path()):
return None
algorithm = hash_abs_path.suffix.lstrip(".")
if algorithm not in {"sha256", "sha512"}:
- await check_obj.failure("Unsupported hash algorithm", {"algorithm":
algorithm})
+ await recorder.failure("Unsupported hash algorithm", {"algorithm":
algorithm})
return None
# Remove the hash file suffix to get the artifact path
@@ -69,15 +65,15 @@ async def check(args: checks.ReleaseAndRelPath) -> str |
None:
expected_hash = expected_hash.strip().split()[0]
if secrets.compare_digest(computed_hash, expected_hash):
- await check_obj.success(
+ await recorder.success(
f"Hash ({algorithm}) matches expected value",
{"computed_hash": computed_hash, "expected_hash":
expected_hash},
)
else:
- await check_obj.failure(
+ await recorder.failure(
f"Hash ({algorithm}) mismatch",
{"computed_hash": computed_hash, "expected_hash":
expected_hash},
)
except Exception as e:
- await check_obj.failure("Unable to verify hash", {"error": str(e)})
+ await recorder.failure("Unable to verify hash", {"error": str(e)})
return None
diff --git a/atr/tasks/checks/license.py b/atr/tasks/checks/license.py
index 389b24f..7a22a39 100644
--- a/atr/tasks/checks/license.py
+++ b/atr/tasks/checks/license.py
@@ -172,14 +172,10 @@ INCLUDED_PATTERNS: Final[list[str]] = [
# Tasks
[email protected]_model(checks.ReleaseAndRelPath)
-async def files(args: checks.ReleaseAndRelPath) -> str | None:
+async def files(args: checks.FunctionArguments) -> str | None:
"""Check that the LICENSE and NOTICE files exist and are valid."""
- check = await checks.Check.create(
- checker=files, release_name=args.release_name,
primary_rel_path=args.primary_rel_path
- )
-
- if not (artifact_abs_path := await check.abs_path()):
+ recorder = await args.recorder()
+ if not (artifact_abs_path := await recorder.abs_path()):
return None
_LOGGER.info(f"Checking license files for {artifact_abs_path} (rel:
{args.primary_rel_path})")
@@ -188,27 +184,23 @@ async def files(args: checks.ReleaseAndRelPath) -> str |
None:
result_data = await asyncio.to_thread(_files_check_core_logic,
str(artifact_abs_path))
if result_data.get("error"):
- await check.failure(result_data["error"], result_data)
+ await recorder.failure(result_data["error"], result_data)
elif result_data["license_valid"] and result_data["notice_valid"]:
- await check.success("LICENSE and NOTICE files present and valid",
result_data)
+ await recorder.success("LICENSE and NOTICE files present and
valid", result_data)
else:
# TODO: Be more specific about the issues
- await check.failure("Issues found with LICENSE or NOTICE files",
result_data)
+ await recorder.failure("Issues found with LICENSE or NOTICE
files", result_data)
except Exception as e:
- await check.failure("Error checking license files", {"error": str(e)})
+ await recorder.failure("Error checking license files", {"error":
str(e)})
return None
[email protected]_model(checks.ReleaseAndRelPath)
-async def headers(args: checks.ReleaseAndRelPath) -> str | None:
+async def headers(args: checks.FunctionArguments) -> str | None:
"""Check that all source files have valid license headers."""
- check = await checks.Check.create(
- checker=headers, release_name=args.release_name,
primary_rel_path=args.primary_rel_path
- )
-
- if not (artifact_abs_path := await check.abs_path()):
+ recorder = await args.recorder()
+ if not (artifact_abs_path := await recorder.abs_path()):
return None
_LOGGER.info(f"Checking license headers for {artifact_abs_path} (rel:
{args.primary_rel_path})")
@@ -218,16 +210,16 @@ async def headers(args: checks.ReleaseAndRelPath) -> str
| None:
if result_data.get("error_message"):
# Handle errors during the check process itself
- await check.failure(result_data["error_message"], result_data)
+ await recorder.failure(result_data["error_message"], result_data)
elif not result_data["valid"]:
# Handle validation failures
- await check.failure(result_data["message"], result_data)
+ await recorder.failure(result_data["message"], result_data)
else:
# Handle success
- await check.success(result_data["message"], result_data)
+ await recorder.success(result_data["message"], result_data)
except Exception as e:
- await check.failure("Error checking license headers", {"error":
str(e)})
+ await recorder.failure("Error checking license headers", {"error":
str(e)})
return None
diff --git a/atr/tasks/checks/paths.py b/atr/tasks/checks/paths.py
index 661c53b..55eea0e 100644
--- a/atr/tasks/checks/paths.py
+++ b/atr/tasks/checks/paths.py
@@ -21,7 +21,6 @@ import re
from typing import Final
import aiofiles.os
-import pydantic
import atr.analysis as analysis
import atr.tasks.checks as checks
@@ -30,12 +29,6 @@ import atr.util as util
_LOGGER: Final = logging.getLogger(__name__)
-class Check(pydantic.BaseModel):
- """Arguments for the path structure and naming convention check."""
-
- release_name: str = pydantic.Field(..., description="Name of the release
being checked")
-
-
async def _check_artifact_rules(
base_path: pathlib.Path, relative_path: pathlib.Path, relative_paths:
set[str], errors: list[str]
) -> None:
@@ -96,9 +89,9 @@ async def _check_metadata_rules(
async def _check_path_process_single(
base_path: pathlib.Path,
relative_path: pathlib.Path,
- check_errors: checks.Check,
- check_warnings: checks.Check,
- check_success: checks.Check,
+ recorder_errors: checks.Recorder,
+ recorder_warnings: checks.Recorder,
+ recorder_success: checks.Recorder,
relative_paths: set[str],
) -> None:
"""Process and check a single path within the release directory."""
@@ -138,43 +131,45 @@ async def _check_path_process_single(
# Must aggregate errors and aggregate warnings otherwise they will be
removed by afresh=True
# Alternatively we could call Check.clear() manually
if errors:
- await check_errors.failure("; ".join(errors), {"errors": errors},
primary_rel_path=relative_path_str)
+ await recorder_errors.failure("; ".join(errors), {"errors": errors},
primary_rel_path=relative_path_str)
if warnings:
- await check_warnings.warning("; ".join(warnings), {"warnings":
warnings}, primary_rel_path=relative_path_str)
+ await recorder_warnings.warning("; ".join(warnings), {"warnings":
warnings}, primary_rel_path=relative_path_str)
if not (errors or warnings):
- await check_success.success(
+ await recorder_success.success(
"Path structure and naming conventions conform to policy", {},
primary_rel_path=relative_path_str
)
[email protected]_model(Check)
-async def check(args: Check) -> None:
+async def check(args: checks.FunctionArguments) -> None:
"""Check file path structure and naming conventions against ASF release
policy for all files in a release."""
# We refer to the following authoritative policies:
# - Release Creation Process (RCP)
# - Release Distribution Policy (RDP)
- check_errors = await checks.Check.create(
+ recorder_errors = await checks.Recorder.create(
checker=checks.function_key(check) + "_errors",
release_name=args.release_name,
+ draft_revision=args.draft_revision,
primary_rel_path=None,
afresh=True,
)
- check_warnings = await checks.Check.create(
+ recorder_warnings = await checks.Recorder.create(
checker=checks.function_key(check) + "_warnings",
release_name=args.release_name,
+ draft_revision=args.draft_revision,
primary_rel_path=None,
afresh=True,
)
- check_success = await checks.Check.create(
+ recorder_success = await checks.Recorder.create(
checker=checks.function_key(check) + "_success",
release_name=args.release_name,
+ draft_revision=args.draft_revision,
primary_rel_path=None,
afresh=True,
)
# As primary_rel_path is None, the base path is the release candidate
draft directory
- if not (base_path := await check_success.abs_path()):
+ if not (base_path := await recorder_success.abs_path()):
return
if not await aiofiles.os.path.isdir(base_path):
@@ -188,9 +183,9 @@ async def check(args: Check) -> None:
await _check_path_process_single(
base_path,
relative_path,
- check_errors,
- check_warnings,
- check_success,
+ recorder_errors,
+ recorder_warnings,
+ recorder_success,
relative_paths_set,
)
diff --git a/atr/tasks/checks/rat.py b/atr/tasks/checks/rat.py
index bcab4e8..8b16139 100644
--- a/atr/tasks/checks/rat.py
+++ b/atr/tasks/checks/rat.py
@@ -23,8 +23,6 @@ import tempfile
import xml.etree.ElementTree as ElementTree
from typing import Any, Final
-import pydantic
-
import atr.config as config
import atr.tasks.checks as checks
import atr.tasks.checks.targz as targz
@@ -44,27 +42,10 @@ _JAVA_MEMORY_ARGS: Final[list[str]] = []
_LOGGER: Final = logging.getLogger(__name__)
-class Check(pydantic.BaseModel):
- """Parameters for Apache RAT license checking."""
-
- release_name: str = pydantic.Field(..., description="Release name")
- primary_rel_path: str = pydantic.Field(..., description="Relative path to
the archive file to check")
- rat_jar_path: str = pydantic.Field(
- default=_CONFIG.APACHE_RAT_JAR_PATH, description="Path to the Apache
RAT JAR file"
- )
- max_extract_size: int = pydantic.Field(
- default=_CONFIG.MAX_EXTRACT_SIZE, description="Maximum extraction size
in bytes"
- )
- chunk_size: int = pydantic.Field(default=_CONFIG.EXTRACT_CHUNK_SIZE,
description="Chunk size for extraction")
-
-
[email protected]_model(Check)
-async def check(args: Check) -> str | None:
+async def check(args: checks.FunctionArguments) -> str | None:
"""Use Apache RAT to check the licenses of the files in the artifact."""
- check_obj = await checks.Check.create(
- checker=check, release_name=args.release_name,
primary_rel_path=args.primary_rel_path
- )
- if not (artifact_abs_path := await check_obj.abs_path()):
+ recorder = await args.recorder()
+ if not (artifact_abs_path := await recorder.abs_path()):
return None
_LOGGER.info(f"Checking RAT licenses for {artifact_abs_path} (rel:
{args.primary_rel_path})")
@@ -73,24 +54,24 @@ async def check(args: Check) -> str | None:
result_data = await asyncio.to_thread(
_check_core_logic,
artifact_path=str(artifact_abs_path),
- rat_jar_path=args.rat_jar_path,
- max_extract_size=args.max_extract_size,
- chunk_size=args.chunk_size,
+ rat_jar_path=args.extra_args.get("rat_jar_path",
_CONFIG.APACHE_RAT_JAR_PATH),
+ max_extract_size=args.extra_args.get("max_extract_size",
_CONFIG.MAX_EXTRACT_SIZE),
+ chunk_size=args.extra_args.get("chunk_size",
_CONFIG.EXTRACT_CHUNK_SIZE),
)
if result_data.get("error"):
# Handle errors from within the core logic
- await check_obj.failure(result_data["message"], result_data)
+ await recorder.failure(result_data["message"], result_data)
elif not result_data["valid"]:
# Handle RAT validation failures
- await check_obj.failure(result_data["message"], result_data)
+ await recorder.failure(result_data["message"], result_data)
else:
# Handle success
- await check_obj.success(result_data["message"], result_data)
+ await recorder.success(result_data["message"], result_data)
except Exception as e:
# TODO: Or bubble for task failure?
- await check_obj.failure("Error running Apache RAT check", {"error":
str(e)})
+ await recorder.failure("Error running Apache RAT check", {"error":
str(e)})
return None
diff --git a/atr/tasks/checks/signature.py b/atr/tasks/checks/signature.py
index 8f822bc..e975009 100644
--- a/atr/tasks/checks/signature.py
+++ b/atr/tasks/checks/signature.py
@@ -21,7 +21,6 @@ import tempfile
from typing import Any, Final
import gnupg
-import pydantic
import sqlmodel
import atr.db as db
@@ -31,49 +30,46 @@ import atr.tasks.checks as checks
_LOGGER = logging.getLogger(__name__)
-class Check(pydantic.BaseModel):
- """Parameters for signature checking."""
-
- release_name: str = pydantic.Field(..., description="Release name")
- committee_name: str = pydantic.Field(..., description="Name of the
committee whose keys should be used")
- signature_rel_path: str = pydantic.Field(..., description="Relative path
to the signature file (.asc)")
-
-
[email protected]_model(Check)
-async def check(args: Check) -> str | None:
+async def check(args: checks.FunctionArguments) -> str | None:
"""Check a signature file."""
- check_obj = await checks.Check.create(
- checker=check, release_name=args.release_name,
primary_rel_path=args.signature_rel_path
- )
+ recorder = await args.recorder()
+ if not (primary_abs_path := await recorder.abs_path()):
+ return None
+
+ if not (primary_rel_path := args.primary_rel_path):
+ await recorder.failure("Primary relative path is required",
{"primary_rel_path": primary_rel_path})
+ return None
- if not (signature_abs_path := await check_obj.abs_path()):
+ artifact_rel_path = primary_rel_path.removesuffix(".asc")
+ if not (artifact_abs_path := await recorder.abs_path(artifact_rel_path)):
return None
- artifact_rel_path = args.signature_rel_path.removesuffix(".asc")
- if not (artifact_abs_path := await check_obj.abs_path(artifact_rel_path)):
+ committee_name = args.extra_args.get("committee_name")
+ if not isinstance(committee_name, str):
+ await recorder.failure("Committee name is required",
{"committee_name": committee_name})
return None
_LOGGER.info(
- f"Checking signature {signature_abs_path} for {artifact_abs_path}"
- f" using {args.committee_name} keys (rel: {args.signature_rel_path})"
+ f"Checking signature {primary_abs_path} for {artifact_abs_path}"
+ f" using {committee_name} keys (rel: {primary_rel_path})"
)
try:
result_data = await _check_core_logic(
- committee_name=args.committee_name,
+ committee_name=committee_name,
artifact_path=str(artifact_abs_path),
- signature_path=str(signature_abs_path),
+ signature_path=str(primary_abs_path),
)
if result_data.get("error"):
- await check_obj.failure(result_data["error"], result_data)
+ await recorder.failure(result_data["error"], result_data)
elif result_data.get("verified"):
- await check_obj.success("Signature verified successfully",
result_data)
+ await recorder.success("Signature verified successfully",
result_data)
else:
# Shouldn't happen
- await check_obj.failure("Signature verification failed for unknown
reasons", result_data)
+ await recorder.failure("Signature verification failed for unknown
reasons", result_data)
except Exception as e:
- await check_obj.failure("Error during signature check execution",
{"error": str(e)})
+ await recorder.failure("Error during signature check execution",
{"error": str(e)})
return None
diff --git a/atr/tasks/checks/targz.py b/atr/tasks/checks/targz.py
index 7cfe019..6a3ed62 100644
--- a/atr/tasks/checks/targz.py
+++ b/atr/tasks/checks/targz.py
@@ -20,47 +20,32 @@ import logging
import tarfile
from typing import Final
-import pydantic
-
import atr.tasks.checks as checks
_LOGGER: Final = logging.getLogger(__name__)
-class Integrity(pydantic.BaseModel):
- """Parameters for archive integrity checking."""
-
- release_name: str = pydantic.Field(..., description="Release name")
- primary_rel_path: str = pydantic.Field(..., description="Relative path to
the .tar.gz file to check")
- chunk_size: int = pydantic.Field(default=4096, description="Size of chunks
to read when checking the file")
-
-
[email protected]_model(Integrity)
-async def integrity(args: Integrity) -> str | None:
+async def integrity(args: checks.FunctionArguments) -> str | None:
"""Check the integrity of a .tar.gz file."""
- check = await checks.Check.create(
- checker=integrity, release_name=args.release_name,
primary_rel_path=args.primary_rel_path
- )
- if not (artifact_abs_path := await check.abs_path()):
+ recorder = await args.recorder()
+ if not (artifact_abs_path := await recorder.abs_path()):
return None
_LOGGER.info(f"Checking integrity for {artifact_abs_path} (rel:
{args.primary_rel_path})")
+ chunk_size = 4096
try:
- size = await asyncio.to_thread(_integrity_core,
str(artifact_abs_path), args.chunk_size)
- await check.success("Able to read all entries of the archive using
tarfile", {"size": size})
+ size = await asyncio.to_thread(_integrity_core,
str(artifact_abs_path), chunk_size)
+ await recorder.success("Able to read all entries of the archive using
tarfile", {"size": size})
except Exception as e:
- await check.failure("Unable to read all entries of the archive using
tarfile", {"error": str(e)})
+ await recorder.failure("Unable to read all entries of the archive
using tarfile", {"error": str(e)})
return None
[email protected]_model(checks.ReleaseAndRelPath)
-async def structure(args: checks.ReleaseAndRelPath) -> str | None:
+async def structure(args: checks.FunctionArguments) -> str | None:
"""Check the structure of a .tar.gz file."""
- check = await checks.Check.create(
- checker=structure, release_name=args.release_name,
primary_rel_path=args.primary_rel_path
- )
- if not (artifact_abs_path := await check.abs_path()):
+ recorder = await args.recorder()
+ if not (artifact_abs_path := await recorder.abs_path()):
return None
filename = artifact_abs_path.name
@@ -74,17 +59,17 @@ async def structure(args: checks.ReleaseAndRelPath) -> str
| None:
try:
root = await asyncio.to_thread(root_directory, str(artifact_abs_path))
if root == expected_root:
- await check.success(
+ await recorder.success(
"Archive contains exactly one root directory matching the
expected name",
{"root": root, "expected": expected_root},
)
else:
- await check.failure(
+ await recorder.failure(
f"Root directory '{root}' does not match expected name
'{expected_root}'",
{"root": root, "expected": expected_root},
)
except Exception as e:
- await check.failure("Unable to verify archive structure", {"error":
str(e)})
+ await recorder.failure("Unable to verify archive structure", {"error":
str(e)})
return None
diff --git a/atr/tasks/checks/zipformat.py b/atr/tasks/checks/zipformat.py
index dbaffc7..d8f159a 100644
--- a/atr/tasks/checks/zipformat.py
+++ b/atr/tasks/checks/zipformat.py
@@ -27,13 +27,10 @@ import atr.tasks.checks.license as license
_LOGGER = logging.getLogger(__name__)
[email protected]_model(checks.ReleaseAndRelPath)
-async def integrity(args: checks.ReleaseAndRelPath) -> str | None:
+async def integrity(args: checks.FunctionArguments) -> str | None:
"""Check that the zip archive is not corrupted and can be opened."""
- check = await checks.Check.create(
- checker=integrity, release_name=args.release_name,
primary_rel_path=args.primary_rel_path
- )
- if not (artifact_abs_path := await check.abs_path()):
+ recorder = await args.recorder()
+ if not (artifact_abs_path := await recorder.abs_path()):
return None
_LOGGER.info(f"Checking zip integrity for {artifact_abs_path} (rel:
{args.primary_rel_path})")
@@ -41,22 +38,19 @@ async def integrity(args: checks.ReleaseAndRelPath) -> str
| None:
try:
result_data = await asyncio.to_thread(_integrity_check_core_logic,
str(artifact_abs_path))
if result_data.get("error"):
- await check.failure(result_data["error"], result_data)
+ await recorder.failure(result_data["error"], result_data)
else:
- await check.success(f"Zip archive integrity OK
({result_data['member_count']} members)", result_data)
+ await recorder.success(f"Zip archive integrity OK
({result_data['member_count']} members)", result_data)
except Exception as e:
- await check.failure("Error checking zip integrity", {"error": str(e)})
+ await recorder.failure("Error checking zip integrity", {"error":
str(e)})
return None
[email protected]_model(checks.ReleaseAndRelPath)
-async def license_files(args: checks.ReleaseAndRelPath) -> str | None:
+async def license_files(args: checks.FunctionArguments) -> str | None:
"""Check that the LICENSE and NOTICE files exist and are valid within the
zip."""
- check = await checks.Check.create(
- checker=license_files, release_name=args.release_name,
primary_rel_path=args.primary_rel_path
- )
- if not (artifact_abs_path := await check.abs_path()):
+ recorder = await args.recorder()
+ if not (artifact_abs_path := await recorder.abs_path()):
return None
_LOGGER.info(f"Checking zip license files for {artifact_abs_path} (rel:
{args.primary_rel_path})")
@@ -65,9 +59,9 @@ async def license_files(args: checks.ReleaseAndRelPath) ->
str | None:
result_data = await
asyncio.to_thread(_license_files_check_core_logic_zip, str(artifact_abs_path))
if result_data.get("error"):
- await check.failure(result_data["error"], result_data)
+ await recorder.failure(result_data["error"], result_data)
elif result_data.get("license_valid") and
result_data.get("notice_valid"):
- await check.success("LICENSE and NOTICE files present and valid in
zip", result_data)
+ await recorder.success("LICENSE and NOTICE files present and valid
in zip", result_data)
else:
issues = []
if not result_data.get("license_found"):
@@ -79,21 +73,18 @@ async def license_files(args: checks.ReleaseAndRelPath) ->
str | None:
elif not result_data.get("notice_valid"):
issues.append("NOTICE invalid or empty")
issue_str = ", ".join(issues) if issues else "Issues found with
LICENSE or NOTICE files"
- await check.failure(issue_str, result_data)
+ await recorder.failure(issue_str, result_data)
except Exception as e:
- await check.failure("Error checking zip license files", {"error":
str(e)})
+ await recorder.failure("Error checking zip license files", {"error":
str(e)})
return None
[email protected]_model(checks.ReleaseAndRelPath)
-async def license_headers(args: checks.ReleaseAndRelPath) -> str | None:
+async def license_headers(args: checks.FunctionArguments) -> str | None:
"""Check that all source files within the zip have valid license
headers."""
- check = await checks.Check.create(
- checker=license_headers, release_name=args.release_name,
primary_rel_path=args.primary_rel_path
- )
- if not (artifact_abs_path := await check.abs_path()):
+ recorder = await args.recorder()
+ if not (artifact_abs_path := await recorder.abs_path()):
return None
_LOGGER.info(f"Checking zip license headers for {artifact_abs_path} (rel:
{args.primary_rel_path})")
@@ -102,29 +93,26 @@ async def license_headers(args: checks.ReleaseAndRelPath)
-> str | None:
result_data = await
asyncio.to_thread(_license_headers_check_core_logic_zip, str(artifact_abs_path))
if result_data.get("error_message"):
- await check.failure(result_data["error_message"], result_data)
+ await recorder.failure(result_data["error_message"], result_data)
elif not result_data.get("valid"):
num_issues = len(result_data.get("files_without_headers", []))
failure_msg = f"{num_issues} file(s) missing or having invalid
license headers"
- await check.failure(failure_msg, result_data)
+ await recorder.failure(failure_msg, result_data)
else:
- await check.success(
+ await recorder.success(
f"License headers OK ({result_data.get('files_checked', 0)}
files checked)", result_data
)
except Exception as e:
- await check.failure("Error checking zip license headers", {"error":
str(e)})
+ await recorder.failure("Error checking zip license headers", {"error":
str(e)})
return None
[email protected]_model(checks.ReleaseAndRelPath)
-async def structure(args: checks.ReleaseAndRelPath) -> str | None:
+async def structure(args: checks.FunctionArguments) -> str | None:
"""Check that the zip archive has a single root directory matching the
artifact name."""
- check = await checks.Check.create(
- checker=structure, release_name=args.release_name,
primary_rel_path=args.primary_rel_path
- )
- if not (artifact_abs_path := await check.abs_path()):
+ recorder = await args.recorder()
+ if not (artifact_abs_path := await recorder.abs_path()):
return None
_LOGGER.info(f"Checking zip structure for {artifact_abs_path} (rel:
{args.primary_rel_path})")
@@ -132,11 +120,11 @@ async def structure(args: checks.ReleaseAndRelPath) ->
str | None:
try:
result_data = await asyncio.to_thread(_structure_check_core_logic,
str(artifact_abs_path))
if result_data.get("error"):
- await check.failure(result_data["error"], result_data)
+ await recorder.failure(result_data["error"], result_data)
else:
- await check.success(f"Zip structure OK (root:
{result_data['root_dir']})", result_data)
+ await recorder.success(f"Zip structure OK (root:
{result_data['root_dir']})", result_data)
except Exception as e:
- await check.failure("Error checking zip structure", {"error": str(e)})
+ await recorder.failure("Error checking zip structure", {"error":
str(e)})
return None
diff --git a/atr/tasks/rsync.py b/atr/tasks/rsync.py
deleted file mode 100644
index 046e6c5..0000000
--- a/atr/tasks/rsync.py
+++ /dev/null
@@ -1,51 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-import logging
-from typing import Final
-
-import pydantic
-
-import atr.tasks as tasks
-import atr.tasks.checks as checks
-
-# _CONFIG: Final = config.get()
-_LOGGER: Final = logging.getLogger(__name__)
-
-
-class Analyse(pydantic.BaseModel):
- """Parameters for rsync analysis."""
-
- project_name: str = pydantic.Field(..., description="Name of the project
to rsync")
- release_version: str = pydantic.Field(..., description="Version of the
release to rsync")
-
-
[email protected]_model(Analyse)
-async def analyse(args: Analyse) -> str | None:
- """Analyse an rsync upload by queuing specific checks for discovered
files."""
- _LOGGER.info(f"Starting rsync analysis for {args.project_name}
{args.release_version}")
- try:
- num_paths = await tasks.draft_checks(
- args.project_name,
- args.release_version,
- )
- _LOGGER.info(f"Finished rsync analysis for {args.project_name}
{args.release_version}, found {num_paths} paths")
- except Exception as e:
- _LOGGER.exception(f"Rsync analysis failed for {args.project_name}
{args.release_version}: {e}")
- raise e
-
- return None
diff --git a/atr/util.py b/atr/util.py
index 81541f0..8de22eb 100644
--- a/atr/util.py
+++ b/atr/util.py
@@ -37,6 +37,8 @@ import quart
import quart_wtf
import quart_wtf.typing
+# NOTE: The atr.db module imports this module
+# Therefore, this module must not import atr.db
import atr.config as config
import atr.db.models as models
@@ -46,13 +48,6 @@ T = TypeVar("T")
_LOGGER = logging.getLogger(__name__)
-async def get_asf_id_or_die() -> str:
- web_session = await session.read()
- if web_session is None or web_session.uid is None:
- raise base.ASFQuartException("Not authenticated", errorcode=401)
- return web_session.uid
-
-
# from
https://github.com/pydantic/pydantic/discussions/8755#discussioncomment-8417979
@dataclasses.dataclass
class DictToList:
@@ -101,23 +96,23 @@ class QuartFormTyped(quart_wtf.QuartForm):
return form
-# def abs_path_to_release_and_rel_path(abs_path: str) -> tuple[str, str]:
-# """Return the release name and relative path for a given path."""
-# conf = config.get()
-# phase_dir = pathlib.Path(conf.PHASE_STORAGE_DIR)
-# phase_sub_dir = pathlib.Path(abs_path).relative_to(phase_dir)
-# # Skip the first component, which is the phase name
-# # The next two components are the project name and version name
-# project_name = phase_sub_dir.parts[1]
-# version_name = phase_sub_dir.parts[2]
-# return models.release_name(project_name, version_name),
str(pathlib.Path(*phase_sub_dir.parts[3:]))
-
-
def as_url(func: Callable, **kwargs: Any) -> str:
"""Return the URL for a function."""
return quart.url_for(func.__annotations__["endpoint"], **kwargs)
[email protected]
+async def async_temporary_directory(
+ suffix: str | None = None, prefix: str | None = None, dir: str |
pathlib.Path | None = None
+) -> AsyncGenerator[pathlib.Path]:
+ """Create an async temporary directory similar to
tempfile.TemporaryDirectory."""
+ temp_dir_path: str = await asyncio.to_thread(tempfile.mkdtemp,
suffix=suffix, prefix=prefix, dir=dir)
+ try:
+ yield pathlib.Path(temp_dir_path)
+ finally:
+ await asyncio.to_thread(shutil.rmtree, temp_dir_path,
ignore_errors=True)
+
+
def compute_sha3_256(file_data: bytes) -> str:
"""Compute SHA3-256 hash of file data."""
return hashlib.sha3_256(file_data).hexdigest()
@@ -135,6 +130,8 @@ async def compute_sha512(file_path: pathlib.Path) -> str:
async def content_list(phase_subdir: pathlib.Path, project_name: str,
version_name: str) -> AsyncGenerator[FileStat]:
"""List all the files in the given path."""
base_path = phase_subdir / project_name / version_name
+ if phase_subdir.name == "release-candidate-draft":
+ base_path = base_path / "latest"
for path in await paths_recursive(base_path):
stat = await aiofiles.os.stat(base_path / path)
yield FileStat(
@@ -149,6 +146,7 @@ async def content_list(phase_subdir: pathlib.Path,
project_name: str, version_na
async def create_hard_link_clone(source_dir: pathlib.Path, dest_dir:
pathlib.Path) -> None:
"""Recursively create a clone of source_dir in dest_dir using hard links
for files."""
+ # TODO: We're currently using cp -al instead
# Ensure source exists and is a directory
if not await aiofiles.os.path.isdir(source_dir):
raise ValueError(f"Source path is not a directory or does not exist:
{source_dir}")
@@ -184,6 +182,13 @@ async def file_sha3(path: str) -> str:
return sha3.hexdigest()
+async def get_asf_id_or_die() -> str:
+ web_session = await session.read()
+ if web_session is None or web_session.uid is None:
+ raise base.ASFQuartException("Not authenticated", errorcode=401)
+ return web_session.uid
+
+
def get_phase_dir() -> pathlib.Path:
return pathlib.Path(config.get().PHASE_STORAGE_DIR)
@@ -233,6 +238,49 @@ async def paths_recursive(base_path: pathlib.Path, sort:
bool = True) -> list[pa
return paths
+async def read_file_for_viewer(full_path: pathlib.Path, max_size: int) ->
tuple[str | None, bool, bool, str | None]:
+ """Read file content for viewer."""
+ content: str | None = None
+ is_text = False
+ is_truncated = False
+ error_message: str | None = None
+
+ try:
+ if not await aiofiles.os.path.exists(full_path):
+ return None, False, False, "File does not exist"
+ if not await aiofiles.os.path.isfile(full_path):
+ return None, False, False, "Path is not a file"
+
+ file_size = await aiofiles.os.path.getsize(full_path)
+ read_size = min(file_size, max_size)
+
+ if file_size > max_size:
+ is_truncated = True
+
+ if file_size == 0:
+ is_text = True
+ content = "(Empty file)"
+ raw_content = b""
+ else:
+ async with aiofiles.open(full_path, "rb") as f:
+ raw_content = await f.read(read_size)
+
+ if file_size > 0:
+ try:
+ if b"\x00" in raw_content:
+ raise UnicodeDecodeError("utf-8", b"", 0, 1, "Null byte
found")
+ content = raw_content.decode("utf-8")
+ is_text = True
+ except UnicodeDecodeError:
+ is_text = False
+ content = _generate_hexdump(raw_content)
+
+ except Exception as e:
+ error_message = f"An error occurred reading the file: {e!s}"
+
+ return content, is_text, is_truncated, error_message
+
+
def release_directory(release: models.Release) -> pathlib.Path:
"""Determine the filesystem directory for a given release based on its
phase."""
phase = release.phase
@@ -315,6 +363,20 @@ def validate_as_type(value: Any, t: type[T]) -> T:
return value
+def _generate_hexdump(data: bytes) -> str:
+ """Generate a formatted hexdump string from bytes."""
+ hex_lines = []
+ for i in range(0, len(data), 16):
+ chunk = data[i : i + 16]
+ hex_part = binascii.hexlify(chunk).decode("ascii")
+ hex_part = hex_part.ljust(32)
+ hex_part_spaced = " ".join(hex_part[j : j + 2] for j in range(0,
len(hex_part), 2))
+ ascii_part = "".join(chr(b) if 32 <= b < 127 else "." for b in chunk)
+ line_num = f"{i:08x}"
+ hex_lines.append(f"{line_num} {hex_part_spaced} |{ascii_part}|")
+ return "\n".join(hex_lines)
+
+
def _get_dict_to_list_inner_type_adapter(source_type: Any, key: str) ->
pydantic.TypeAdapter[dict[Any, Any]]:
root_adapter = pydantic.TypeAdapter(source_type)
schema = root_adapter.core_schema
@@ -361,72 +423,3 @@ def _get_dict_to_list_validator(inner_adapter:
pydantic.TypeAdapter[dict[Any, An
return val
return validator
-
-
[email protected]
-async def async_temporary_directory(
- suffix: str | None = None, prefix: str | None = None, dir: str |
pathlib.Path | None = None
-) -> AsyncGenerator[pathlib.Path]:
- """Create an async temporary directory similar to
tempfile.TemporaryDirectory."""
- temp_dir_path: str = await asyncio.to_thread(tempfile.mkdtemp,
suffix=suffix, prefix=prefix, dir=dir)
- try:
- yield pathlib.Path(temp_dir_path)
- finally:
- await asyncio.to_thread(shutil.rmtree, temp_dir_path,
ignore_errors=True)
-
-
-async def read_file_for_viewer(full_path: pathlib.Path, max_size: int) ->
tuple[str | None, bool, bool, str | None]:
- """Read file content for viewer."""
- content: str | None = None
- is_text = False
- is_truncated = False
- error_message: str | None = None
-
- try:
- if not await aiofiles.os.path.exists(full_path):
- return None, False, False, "File does not exist"
- if not await aiofiles.os.path.isfile(full_path):
- return None, False, False, "Path is not a file"
-
- file_size = await aiofiles.os.path.getsize(full_path)
- read_size = min(file_size, max_size)
-
- if file_size > max_size:
- is_truncated = True
-
- if file_size == 0:
- is_text = True
- content = "(Empty file)"
- raw_content = b""
- else:
- async with aiofiles.open(full_path, "rb") as f:
- raw_content = await f.read(read_size)
-
- if file_size > 0:
- try:
- if b"\x00" in raw_content:
- raise UnicodeDecodeError("utf-8", b"", 0, 1, "Null byte
found")
- content = raw_content.decode("utf-8")
- is_text = True
- except UnicodeDecodeError:
- is_text = False
- content = _generate_hexdump(raw_content)
-
- except Exception as e:
- error_message = f"An error occurred reading the file: {e!s}"
-
- return content, is_text, is_truncated, error_message
-
-
-def _generate_hexdump(data: bytes) -> str:
- """Generate a formatted hexdump string from bytes."""
- hex_lines = []
- for i in range(0, len(data), 16):
- chunk = data[i : i + 16]
- hex_part = binascii.hexlify(chunk).decode("ascii")
- hex_part = hex_part.ljust(32)
- hex_part_spaced = " ".join(hex_part[j : j + 2] for j in range(0,
len(hex_part), 2))
- ascii_part = "".join(chr(b) if 32 <= b < 127 else "." for b in chunk)
- line_num = f"{i:08x}"
- hex_lines.append(f"{line_num} {hex_part_spaced} |{ascii_part}|")
- return "\n".join(hex_lines)
diff --git a/atr/worker.py b/atr/worker.py
index 769d08a..4cc55e9 100644
--- a/atr/worker.py
+++ b/atr/worker.py
@@ -24,6 +24,7 @@
import asyncio
import datetime
+import inspect
import json
import logging
import os
@@ -37,6 +38,7 @@ import sqlmodel
import atr.db as db
import atr.db.models as models
import atr.tasks as tasks
+import atr.tasks.checks as checks
import atr.tasks.task as task
_LOGGER: Final = logging.getLogger(__name__)
@@ -192,7 +194,7 @@ async def _task_result_process(
async def _task_process(task_id: int, task_type: str, task_args: list[str] |
dict[str, Any]) -> None:
"""Process a claimed task."""
- _LOGGER.info(f"Processing task {task_id} ({task_type}) with args
{task_args}")
+ _LOGGER.info(f"Processing task {task_id} ({task_type}) with raw args
{task_args}")
try:
task_type_member = models.TaskType(task_type)
except ValueError as e:
@@ -203,13 +205,58 @@ async def _task_process(task_id: int, task_type: str, task_args: list[str] | dic
task_results: tuple[Any, ...]
try:
handler = tasks.resolve(task_type_member)
- handler_result = await handler(task_args)
+ sig = inspect.signature(handler)
+ params = list(sig.parameters.values())
+
+ # Check whether the handler is a check handler
+ if (len(params) == 1) and (params[0].annotation == checks.FunctionArguments):
+ _LOGGER.debug(f"Handler {handler.__name__} expects checks.FunctionArguments, fetching full task details")
+ async with db.session() as data:
+ task_obj = await data.task(id=task_id).demand(
+ ValueError(f"Task {task_id} disappeared during processing")
+ )
+
+ # Validate required fields from the Task object itself
+ if task_obj.release_name is None:
+ raise ValueError(f"Task {task_id} is missing required release_name")
+ if task_obj.draft_revision is None:
+ raise ValueError(f"Task {task_id} is missing required draft_revision")
+
+ if not isinstance(task_args, dict):
+ raise TypeError(
+ f"Task {task_id} ({task_type}) has non-dict raw args"
+ f" {task_args} which should represent keyword_args"
+ )
+
+ async def recorder_factory() -> checks.Recorder:
+ return await checks.Recorder.create(
+ checker=handler,
+ release_name=task_obj.release_name or "",
+ draft_revision=task_obj.draft_revision or "",
+ primary_rel_path=task_obj.primary_rel_path,
+ )
+
+ function_arguments = checks.FunctionArguments(
+ recorder=recorder_factory,
+ release_name=task_obj.release_name,
+ draft_revision=task_obj.draft_revision,
+ primary_rel_path=task_obj.primary_rel_path,
+ extra_args=task_args,
+ )
+ _LOGGER.debug(f"Calling {handler.__name__} with structured arguments: {function_arguments}")
+ handler_result = await handler(function_arguments)
+ else:
+ # Otherwise, it's not a check handler
+ handler_result = await handler(task_args)
+
task_results = (handler_result,)
status = task.COMPLETED
error = None
except Exception as e:
task_results = tuple()
status = task.FAILED
+ error_details = traceback.format_exc()
+ _LOGGER.error(f"Task {task_id} failed processing: {error_details}")
error = str(e)
await _task_result_process(task_id, task_results, status, error)
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]