This is an automated email from the ASF dual-hosted git repository. arm pushed a commit to branch check_caching in repository https://gitbox.apache.org/repos/asf/tooling-trusted-releases.git
commit 74b5038f59beb4e9b1bf4f4400592e09ba9b1719 Author: Alastair McFarlane <[email protected]> AuthorDate: Wed Feb 11 16:13:22 2026 +0000 Start to move caching out of check tasks --- atr/attestable.py | 16 +- atr/db/__init__.py | 3 + atr/docs/checks.md | 2 +- atr/docs/tasks.md | 2 +- atr/file_paths.py | 28 ++++ atr/hashing.py | 42 +++++ atr/models/sql.py | 2 +- atr/server.py | 42 ++--- atr/storage/readers/checks.py | 3 + atr/tasks/__init__.py | 200 +++++++++++++++++------- atr/tasks/checks/__init__.py | 159 ++++++++++++------- atr/tasks/checks/{hashing.py => file_hash.py} | 0 atr/tasks/checks/license.py | 11 +- migrations/versions/0049_2026.02.11_5b874ed2.py | 37 +++++ tests/unit/recorders.py | 2 +- 15 files changed, 388 insertions(+), 161 deletions(-) diff --git a/atr/attestable.py b/atr/attestable.py index d4d6d15..cac950c 100644 --- a/atr/attestable.py +++ b/atr/attestable.py @@ -18,13 +18,13 @@ from __future__ import annotations import json -from typing import TYPE_CHECKING, Any, Final +from typing import TYPE_CHECKING, Any import aiofiles import aiofiles.os -import blake3 import pydantic +import atr.hashing as hashing import atr.log as log import atr.models.attestable as models import atr.util as util @@ -32,8 +32,6 @@ import atr.util as util if TYPE_CHECKING: import pathlib -_HASH_CHUNK_SIZE: Final[int] = 4 * 1024 * 1024 - def attestable_path(project_name: str, version_name: str, revision_number: str) -> pathlib.Path: return util.get_attestable_dir() / project_name / version_name / f"{revision_number}.json" @@ -43,14 +41,6 @@ def attestable_paths_path(project_name: str, version_name: str, revision_number: return util.get_attestable_dir() / project_name / version_name / f"{revision_number}.paths.json" -async def compute_file_hash(path: pathlib.Path) -> str: - hasher = blake3.blake3() - async with aiofiles.open(path, "rb") as f: - while chunk := await f.read(_HASH_CHUNK_SIZE): - hasher.update(chunk) - return f"blake3:{hasher.hexdigest()}" - - def 
github_tp_payload_path(project_name: str, version_name: str, revision_number: str) -> pathlib.Path: return util.get_attestable_dir() / project_name / version_name / f"{revision_number}.github-tp.json" @@ -140,7 +130,7 @@ async def paths_to_hashes_and_sizes(directory: pathlib.Path) -> tuple[dict[str, if "\\" in path_key: # TODO: We should centralise this, and forbid some other characters too raise ValueError(f"Backslash in path is forbidden: {path_key}") - path_to_hash[path_key] = await compute_file_hash(full_path) + path_to_hash[path_key] = await hashing.compute_file_hash(full_path) path_to_size[path_key] = (await aiofiles.os.stat(full_path)).st_size return path_to_hash, path_to_size diff --git a/atr/db/__init__.py b/atr/db/__init__.py index eb454e0..ed4a76a 100644 --- a/atr/db/__init__.py +++ b/atr/db/__init__.py @@ -164,6 +164,7 @@ class Session(sqlalchemy.ext.asyncio.AsyncSession): status: Opt[sql.CheckResultStatus] = NOT_SET, message: Opt[str] = NOT_SET, data: Opt[Any] = NOT_SET, + inputs_hash: Opt[str] = NOT_SET, _release: bool = False, ) -> Query[sql.CheckResult]: query = sqlmodel.select(sql.CheckResult) @@ -188,6 +189,8 @@ class Session(sqlalchemy.ext.asyncio.AsyncSession): query = query.where(sql.CheckResult.message == message) if is_defined(data): query = query.where(sql.CheckResult.data == data) + if is_defined(inputs_hash): + query = query.where(sql.CheckResult.inputs_hash == inputs_hash) if _release: query = query.options(joined_load(sql.CheckResult.release)) diff --git a/atr/docs/checks.md b/atr/docs/checks.md index 7a34817..30badd0 100644 --- a/atr/docs/checks.md +++ b/atr/docs/checks.md @@ -52,7 +52,7 @@ This check records separate checker keys for errors, warnings, and success. Use For each `.sha256` or `.sha512` file, ATR computes the hash of the referenced artifact and compares it with the expected value. It supports files that contain just the hash as well as files that include a filename and hash on the same line. 
If the suffix does not indicate `sha256` or `sha512`, the check fails. -The checker key is `atr.tasks.checks.hashing.check`. +The checker key is `atr.tasks.checks.file_hash.check`. ### Signature verification diff --git a/atr/docs/tasks.md b/atr/docs/tasks.md index 98c409a..43c8ae2 100644 --- a/atr/docs/tasks.md +++ b/atr/docs/tasks.md @@ -41,7 +41,7 @@ In `atr/tasks/checks` you will find several modules that perform these check tas In `atr/tasks/__init__.py` you will see imports for existing modules where you can add an import for new check task, for example: ```python -import atr.tasks.checks.hashing as hashing +import atr.tasks.checks.file_hash as file_hash import atr.tasks.checks.license as license ``` diff --git a/atr/file_paths.py b/atr/file_paths.py new file mode 100644 index 0000000..d29d6b9 --- /dev/null +++ b/atr/file_paths.py @@ -0,0 +1,28 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +import pathlib + +import atr.util as util + + +def base_path_for_revision(project_name: str, version_name: str, revision: str) -> pathlib.Path: + return pathlib.Path(util.get_unfinished_dir(), project_name, version_name, revision) + + +def revision_path_for_file(project_name: str, version_name: str, revision: str, file_name: str) -> pathlib.Path: + return base_path_for_revision(project_name, version_name, revision) / file_name diff --git a/atr/hashing.py b/atr/hashing.py new file mode 100644 index 0000000..2970e08 --- /dev/null +++ b/atr/hashing.py @@ -0,0 +1,42 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +import pathlib +from typing import Any, Final + +import aiofiles +import aiofiles.os +import blake3 + +_HASH_CHUNK_SIZE: Final[int] = 4 * 1024 * 1024 + + +async def compute_file_hash(path: str | pathlib.Path) -> str: + path = pathlib.Path(path) + hasher = blake3.blake3() + async with aiofiles.open(path, "rb") as f: + while chunk := await f.read(_HASH_CHUNK_SIZE): + hasher.update(chunk) + return f"blake3:{hasher.hexdigest()}" + + +def compute_dict_hash(to_hash: dict[Any, Any]) -> str: + hasher = blake3.blake3() + for k in sorted(to_hash.keys()): + hasher.update(str(k).encode("utf-8")) + hasher.update(str(to_hash[k]).encode("utf-8")) + return f"blake3:{hasher.hexdigest()}" diff --git a/atr/models/sql.py b/atr/models/sql.py index a6e9aed..c03a395 100644 --- a/atr/models/sql.py +++ b/atr/models/sql.py @@ -946,7 +946,7 @@ class CheckResult(sqlmodel.SQLModel, table=True): data: Any = sqlmodel.Field( sa_column=sqlalchemy.Column(sqlalchemy.JSON), **example({"expected": "...", "found": "..."}) ) - input_hash: str | None = sqlmodel.Field(default=None, index=True, **example("blake3:7f83b1657ff1fc...")) + inputs_hash: str | None = sqlmodel.Field(default=None, index=True, **example("blake3:7f83b1657ff1fc...")) cached: bool = sqlmodel.Field(default=False, **example(False)) diff --git a/atr/server.py b/atr/server.py index c09357c..2499b7d 100644 --- a/atr/server.py +++ b/atr/server.py @@ -437,27 +437,27 @@ def _app_setup_request_lifecycle(app: base.QuartApp) -> None: # Check if session has a creation timestamp in metadata created_at_str = session.metadata.get("created_at") - if created_at_str is None: - # First time seeing this session, record creation time - session.metadata["created_at"] = datetime.datetime.now(datetime.UTC).isoformat() - pmcs = util.cookie_pmcs_or_session_pmcs(session) - session_data = util.session_cookie_data_from_client(session, pmcs) - util.write_quart_session_cookie(session_data) - return - - # Parse the creation timestamp and check session age - try: - 
created_at = datetime.datetime.fromisoformat(created_at_str) - except (ValueError, TypeError): - # Invalid timestamp, treat as expired - asfquart.session.clear() - raise base.ASFQuartException("Session expired", errorcode=401) - - session_age = datetime.datetime.now(datetime.UTC) - created_at - - if session_age > max_lifetime: - asfquart.session.clear() - raise base.ASFQuartException("Session expired", errorcode=401) + # if created_at_str is None: + # # First time seeing this session, record creation time + # session.metadata["created_at"] = datetime.datetime.now(datetime.UTC).isoformat() + # pmcs = util.cookie_pmcs_or_session_pmcs(session) + # session_data = util.session_cookie_data_from_client(session, pmcs) + # util.write_quart_session_cookie(session_data) + # return + + # # Parse the creation timestamp and check session age + # try: + # created_at = datetime.datetime.fromisoformat(created_at_str) + # except (ValueError, TypeError): + # # Invalid timestamp, treat as expired + # asfquart.session.clear() + # raise base.ASFQuartException("Session expired", errorcode=401) + # + # session_age = datetime.datetime.now(datetime.UTC) - created_at + + # if session_age > max_lifetime: + # asfquart.session.clear() + # raise base.ASFQuartException("Session expired", errorcode=401) @app.after_request async def log_request(response: quart.Response) -> quart.Response: diff --git a/atr/storage/readers/checks.py b/atr/storage/readers/checks.py index b48daf3..00eb6eb 100644 --- a/atr/storage/readers/checks.py +++ b/atr/storage/readers/checks.py @@ -48,6 +48,9 @@ class GeneralPublic: if release.latest_revision_number is None: raise ValueError("Release has no revision - Invalid state") + # TODO: Consider here - we want to find by file / policy hash but we don't want to get all the task types policy + # keys in order to get all the possible hashes, do we? 
+ query = self.__data.check_result( release_name=release.name, revision_number=release.latest_revision_number, diff --git a/atr/tasks/__init__.py b/atr/tasks/__init__.py index 8030727..9835315 100644 --- a/atr/tasks/__init__.py +++ b/atr/tasks/__init__.py @@ -15,17 +15,22 @@ # specific language governing permissions and limitations # under the License. +import asyncio import datetime +import logging +import pathlib from collections.abc import Awaitable, Callable, Coroutine from typing import Any, Final import sqlmodel import atr.db as db +import atr.hashing as hashing import atr.models.results as results import atr.models.sql as sql +import atr.tasks.checks as checks import atr.tasks.checks.compare as compare -import atr.tasks.checks.hashing as hashing +import atr.tasks.checks.file_hash as file_hash import atr.tasks.checks.license as license import atr.tasks.checks.paths as paths import atr.tasks.checks.rat as rat @@ -43,17 +48,20 @@ import atr.tasks.vote as vote import atr.util as util -async def asc_checks(asf_uid: str, release: sql.Release, revision: str, signature_path: str) -> list[sql.Task]: +async def asc_checks( + asf_uid: str, release: sql.Release, revision: str, signature_path: str, data: db.Session +) -> list[sql.Task | None]: """Create signature check task for a .asc file.""" tasks = [] if release.committee: tasks.append( - queued( + await queued( asf_uid, sql.TaskType.SIGNATURE_CHECK, release, revision, + data, signature_path, {"committee_name": release.committee.name}, ) @@ -120,9 +128,12 @@ async def draft_checks( relative_paths = [path async for path in util.paths_recursive(revision_path)] async with db.ensure_session(caller_data) as data: - release = await data.release(name=sql.release_name(project_name, release_version), _committee=True).demand( - RuntimeError("Release not found") - ) + release = await data.release( + name=sql.release_name(project_name, release_version), + _committee=True, + _release_policy=True, + _project_release_policy=True, + 
).demand(RuntimeError("Release not found")) other_releases = ( await data.release(project_name=project_name, phase=sql.ReleasePhase.RELEASE) .order_by(sql.Release.released) @@ -136,43 +147,29 @@ async def draft_checks( (v for v in release_versions if util.version_sort_key(v.version) < release_version_sortable), None ) for path in relative_paths: - path_str = str(path) - task_function: Callable[[str, sql.Release, str, str], Awaitable[list[sql.Task]]] | None = None - for suffix, func in TASK_FUNCTIONS.items(): - if path.name.endswith(suffix): - task_function = func - break - if task_function: - for task in await task_function(asf_uid, release, revision_number, path_str): - task.revision_number = revision_number - data.add(task) - # TODO: Should we check .json files for their content? - # Ideally we would not have to do that - if path.name.endswith(".cdx.json"): - data.add( - queued( - asf_uid, - sql.TaskType.SBOM_TOOL_SCORE, - release, - revision_number, - path_str, - extra_args={ - "project_name": project_name, - "version_name": release_version, - "revision_number": revision_number, - "previous_release_version": previous_version.version if previous_version else None, - "file_path": path_str, - "asf_uid": asf_uid, - }, - ) - ) + await _draft_file_checks( + asf_uid, + caller_data, + data, + path, + previous_version, + project_name, + release, + release_version, + revision_number, + ) is_podling = False if release.project.committee is not None: if release.project.committee.is_podling: is_podling = True - path_check_task = queued( - asf_uid, sql.TaskType.PATHS_CHECK, release, revision_number, extra_args={"is_podling": is_podling} + path_check_task = await queued( + asf_uid, + sql.TaskType.PATHS_CHECK, + release, + revision_number, + caller_data, + extra_args={"is_podling": is_podling}, ) data.add(path_check_task) if caller_data is None: @@ -181,6 +178,51 @@ async def draft_checks( return len(relative_paths) +async def _draft_file_checks( + asf_uid: str, + caller_data: 
db.Session | None, + data: db.Session, + path: pathlib.Path, + previous_version: sql.Release | None, + project_name: str, + release: sql.Release, + release_version: str, + revision_number: str, +): + path_str = str(path) + task_function: Callable[[str, sql.Release, str, str, db.Session], Awaitable[list[sql.Task | None]]] | None = None + for suffix, func in TASK_FUNCTIONS.items(): + if path.name.endswith(suffix): + task_function = func + break + if task_function: + for task in await task_function(asf_uid, release, revision_number, path_str, data): + if task: + task.revision_number = revision_number + data.add(task) + # TODO: Should we check .json files for their content? + # Ideally we would not have to do that + if path.name.endswith(".cdx.json"): + data.add( + await queued( + asf_uid, + sql.TaskType.SBOM_TOOL_SCORE, + release, + revision_number, + caller_data, + path_str, + extra_args={ + "project_name": project_name, + "version_name": release_version, + "revision_number": revision_number, + "previous_release_version": previous_version.version if previous_version else None, + "file_path": path_str, + "asf_uid": asf_uid, + }, + ) + ) + + async def keys_import_file( asf_uid: str, project_name: str, version_name: str, revision_number: str, caller_data: db.Session | None = None ) -> None: @@ -230,14 +272,24 @@ async def metadata_update( return task -def queued( +async def queued( asf_uid: str, task_type: sql.TaskType, release: sql.Release, revision_number: str, + data: db.Session | None = None, primary_rel_path: str | None = None, extra_args: dict[str, Any] | None = None, -) -> sql.Task: + check_cache_key: dict[str, Any] | None = None, +) -> sql.Task | None: + if check_cache_key is not None: + logging.info("cache key", check_cache_key) + hash_val = hashing.compute_dict_hash(check_cache_key) + if not data: + raise RuntimeError("DB Session is required for check_cache_key") + existing = await data.check_result(inputs_hash=hash_val).all() + if existing: + return None 
return sql.Task( status=sql.TaskStatus.QUEUED, task_type=task_type, @@ -259,7 +311,7 @@ def resolve(task_type: sql.TaskType) -> Callable[..., Awaitable[results.Results case sql.TaskType.DISTRIBUTION_WORKFLOW: return gha.trigger_workflow case sql.TaskType.HASHING_CHECK: - return hashing.check + return file_hash.check case sql.TaskType.KEYS_IMPORT_FILE: return keys.import_file case sql.TaskType.LICENSE_FILES: @@ -304,29 +356,53 @@ def resolve(task_type: sql.TaskType) -> Callable[..., Awaitable[results.Results # Otherwise we lose exhaustiveness checking -async def sha_checks(asf_uid: str, release: sql.Release, revision: str, hash_file: str) -> list[sql.Task]: +async def sha_checks( + asf_uid: str, release: sql.Release, revision: str, hash_file: str, data: db.Session +) -> list[sql.Task | None]: """Create hash check task for a .sha256 or .sha512 file.""" tasks = [] - tasks.append(queued(asf_uid, sql.TaskType.HASHING_CHECK, release, revision, hash_file)) + tasks.append(queued(asf_uid, sql.TaskType.HASHING_CHECK, release, revision, data, hash_file)) return tasks -async def tar_gz_checks(asf_uid: str, release: sql.Release, revision: str, path: str) -> list[sql.Task]: +async def tar_gz_checks( + asf_uid: str, release: sql.Release, revision: str, path: str, data: db.Session +) -> list[sql.Task | None]: """Create check tasks for a .tar.gz or .tgz file.""" # This release has committee, as guaranteed in draft_checks is_podling = (release.project.committee is not None) and release.project.committee.is_podling + tasks = [ - queued(asf_uid, sql.TaskType.COMPARE_SOURCE_TREES, release, revision, path), - queued(asf_uid, sql.TaskType.LICENSE_FILES, release, revision, path, extra_args={"is_podling": is_podling}), - queued(asf_uid, sql.TaskType.LICENSE_HEADERS, release, revision, path), - queued(asf_uid, sql.TaskType.RAT_CHECK, release, revision, path), - queued(asf_uid, sql.TaskType.TARGZ_INTEGRITY, release, revision, path), - queued(asf_uid, sql.TaskType.TARGZ_STRUCTURE, release, 
revision, path), + queued(asf_uid, sql.TaskType.COMPARE_SOURCE_TREES, release, revision, data, path), + queued( + asf_uid, + sql.TaskType.LICENSE_FILES, + release, + revision, + data, + path, + check_cache_key=await checks.resolve_cache_key( + license.INPUT_POLICY_KEYS, release, revision, {**{"is_podling": is_podling}}, file=path + ), + extra_args={"is_podling": is_podling}, + ), + queued( + asf_uid, + sql.TaskType.LICENSE_HEADERS, + release, + revision, + data, + path, + check_cache_key=await checks.resolve_cache_key(license.INPUT_POLICY_KEYS, release, revision, file=path), + ), + queued(asf_uid, sql.TaskType.RAT_CHECK, release, revision, data, path), + queued(asf_uid, sql.TaskType.TARGZ_INTEGRITY, release, revision, data, path), + queued(asf_uid, sql.TaskType.TARGZ_STRUCTURE, release, revision, data, path), ] - return tasks + return await asyncio.gather(*tasks) async def workflow_update( @@ -356,22 +432,26 @@ async def workflow_update( return task -async def zip_checks(asf_uid: str, release: sql.Release, revision: str, path: str) -> list[sql.Task]: +async def zip_checks( + asf_uid: str, release: sql.Release, revision: str, path: str, data: db.Session +) -> list[sql.Task | None]: """Create check tasks for a .zip file.""" # This release has committee, as guaranteed in draft_checks is_podling = (release.project.committee is not None) and release.project.committee.is_podling tasks = [ - queued(asf_uid, sql.TaskType.COMPARE_SOURCE_TREES, release, revision, path), - queued(asf_uid, sql.TaskType.LICENSE_FILES, release, revision, path, extra_args={"is_podling": is_podling}), - queued(asf_uid, sql.TaskType.LICENSE_HEADERS, release, revision, path), - queued(asf_uid, sql.TaskType.RAT_CHECK, release, revision, path), - queued(asf_uid, sql.TaskType.ZIPFORMAT_INTEGRITY, release, revision, path), - queued(asf_uid, sql.TaskType.ZIPFORMAT_STRUCTURE, release, revision, path), + queued(asf_uid, sql.TaskType.COMPARE_SOURCE_TREES, release, revision, data, path), + queued( + asf_uid, 
sql.TaskType.LICENSE_FILES, release, revision, data, path, extra_args={"is_podling": is_podling} + ), + queued(asf_uid, sql.TaskType.LICENSE_HEADERS, release, revision, data, path), + queued(asf_uid, sql.TaskType.RAT_CHECK, release, revision, data, path), + queued(asf_uid, sql.TaskType.ZIPFORMAT_INTEGRITY, release, revision, data, path), + queued(asf_uid, sql.TaskType.ZIPFORMAT_STRUCTURE, release, revision, data, path), ] - return tasks + return await asyncio.gather(*tasks) -TASK_FUNCTIONS: Final[dict[str, Callable[..., Coroutine[Any, Any, list[sql.Task]]]]] = { +TASK_FUNCTIONS: Final[dict[str, Callable[..., Coroutine[Any, Any, list[sql.Task | None]]]]] = { ".asc": asc_checks, ".sha256": sha_checks, ".sha512": sha_checks, diff --git a/atr/tasks/checks/__init__.py b/atr/tasks/checks/__init__.py index 08d2c4c..9bc0fdf 100644 --- a/atr/tasks/checks/__init__.py +++ b/atr/tasks/checks/__init__.py @@ -20,26 +20,25 @@ from __future__ import annotations import dataclasses import datetime import functools -import pathlib -from typing import TYPE_CHECKING, Any, Final +from typing import TYPE_CHECKING, Any import aiofiles import aiofiles.os -import blake3 import sqlmodel if TYPE_CHECKING: + import pathlib from collections.abc import Awaitable, Callable import atr.models.schema as schema import atr.config as config import atr.db as db +import atr.file_paths as file_paths +import atr.hashing as hashing import atr.models.sql as sql import atr.util as util -_HASH_CHUNK_SIZE: Final[int] = 4 * 1024 * 1024 - # Pydantic does not like Callable types, so we use a dataclass instead # It says: "you should define `Callable`, then call `FunctionArguments.model_rebuild()`" @@ -61,7 +60,7 @@ class Recorder: version_name: str primary_rel_path: str | None member_rel_path: str | None - revision: str + revision_number: str afresh: bool __cached: bool __input_hash: str | None @@ -142,7 +141,7 @@ class Recorder: message=message, data=data, cached=False, - input_hash=self.__input_hash, + 
inputs_hash=self.__input_hash, ) # It would be more efficient to keep a session open @@ -167,7 +166,7 @@ class Recorder: return self.abs_path_base() / rel_path_part def abs_path_base(self) -> pathlib.Path: - return pathlib.Path(util.get_unfinished_dir(), self.project_name, self.version_name, self.revision_number) + return file_paths.base_path_for_revision(self.project_name, self.version_name, self.revision_number) async def project(self) -> sql.Project: # TODO: Cache project @@ -200,11 +199,8 @@ class Recorder: abs_path = await self.abs_path() return matches(str(abs_path)) - @property - def cached(self) -> bool: - return self.__cached - - async def check_cache(self, path: pathlib.Path) -> bool: + async def cache_key_set(self, policy_keys: list[str], input_args: dict[str, Any] | None = None) -> bool: + path = await self.abs_path() if not await aiofiles.os.path.isfile(path): return False @@ -218,48 +214,74 @@ class Recorder: if await aiofiles.os.path.exists(no_cache_file): return False - self.__input_hash = await _compute_file_hash(path) - async with db.session() as data: - via = sql.validate_instrumented_attribute - subquery = ( - sqlmodel.select( - sql.CheckResult.member_rel_path, - sqlmodel.func.max(via(sql.CheckResult.id)).label("max_id"), - ) - .where(sql.CheckResult.checker == self.checker) - .where(sql.CheckResult.input_hash == self.__input_hash) - .where(sql.CheckResult.primary_rel_path == self.primary_rel_path) - .group_by(sql.CheckResult.member_rel_path) - .subquery() - ) - stmt = sqlmodel.select(sql.CheckResult).join(subquery, via(sql.CheckResult.id) == subquery.c.max_id) - results = await data.execute(stmt) - cached_results = results.scalars().all() - - if not cached_results: - return False - - for cached in cached_results: - new_result = sql.CheckResult( - release_name=self.release_name, - revision_number=self.revision_number, - checker=self.checker, - primary_rel_path=self.primary_rel_path, - member_rel_path=cached.member_rel_path, - 
created=datetime.datetime.now(datetime.UTC), - status=cached.status, - message=cached.message, - data=cached.data, - cached=True, - input_hash=self.__input_hash, - ) - data.add(new_result) - await data.commit() - - self.__cached = True + release = await data.release( + name=self.release_name, _release_policy=True, _project_release_policy=True + ).demand(RuntimeError(f"Release {self.release_name} not found")) + cache_key = await resolve_cache_key(policy_keys, release, self.revision_number, input_args, path=path) + self.__input_hash = hashing.compute_dict_hash(cache_key) if cache_key else None return True + @property + def cached(self) -> bool: + return self.__cached + + # async def check_cache(self, path: pathlib.Path) -> bool: + # if not await aiofiles.os.path.isfile(path): + # return False + # + # if config.get().DISABLE_CHECK_CACHE: + # return False + # + # if not await self.use_check_cache(): + # return False + # + # no_cache_file = self.abs_path_base() / ".atr-no-cache" + # if await aiofiles.os.path.exists(no_cache_file): + # return False + # + # self.__input_hash = await hashing.compute_file_hash(path) + # + # async with db.session() as data: + # via = sql.validate_instrumented_attribute + # subquery = ( + # sqlmodel.select( + # sql.CheckResult.member_rel_path, + # sqlmodel.func.max(via(sql.CheckResult.id)).label("max_id"), + # ) + # .where(sql.CheckResult.checker == self.checker) + # .where(sql.CheckResult.input_hash == self.__input_hash) + # .where(sql.CheckResult.primary_rel_path == self.primary_rel_path) + # .group_by(sql.CheckResult.member_rel_path) + # .subquery() + # ) + # stmt = sqlmodel.select(sql.CheckResult).join(subquery, via(sql.CheckResult.id) == subquery.c.max_id) + # results = await data.execute(stmt) + # cached_results = results.scalars().all() + # + # if not cached_results: + # return False + # + # for cached in cached_results: + # new_result = sql.CheckResult( + # release_name=self.release_name, + # revision_number=self.revision_number, + # 
checker=self.checker, + # primary_rel_path=self.primary_rel_path, + # member_rel_path=cached.member_rel_path, + # created=datetime.datetime.now(datetime.UTC), + # status=cached.status, + # message=cached.message, + # data=cached.data, + # cached=True, + # inputs_hash=self.__input_hash, + # ) + # data.add(new_result) + # await data.commit() + # + # self.__cached = True + # return True + async def clear(self, primary_rel_path: str | None = None, member_rel_path: str | None = None) -> None: async with db.session() as data: stmt = sqlmodel.delete(sql.CheckResult).where( @@ -348,6 +370,31 @@ def function_key(func: Callable[..., Any]) -> str: return func.__module__ + "." + func.__name__ +async def resolve_cache_key( + policy_keys: list[str], + release: sql.Release, + revision: str, + args: dict[str, Any] | None = None, + file: str | None = None, + path: pathlib.Path | None = None, +) -> dict[str, Any] | None: + if file is None and path is None: + raise ValueError("Must specify either file or path") + if not args: + args = {} + if path is None: + path = file_paths.revision_path_for_file(release.project_name, release.version, revision, file) + file_hash = await hashing.compute_file_hash(path) + cache_key = {"file_hash": file_hash} + + policy = release.release_policy or release.project.release_policy + if len(policy_keys) > 0 and policy is not None: + policy_dict = policy.model_dump(exclude_none=True) + return {**cache_key, **args, **{k: policy_dict[k] for k in policy_keys if k in policy_dict}} + else: + return {**cache_key, **args} + + def with_model(cls: type[schema.Strict]) -> Callable[[Callable[..., Any]], Callable[..., Any]]: """Decorator to specify the parameters for a check.""" @@ -360,11 +407,3 @@ def with_model(cls: type[schema.Strict]) -> Callable[[Callable[..., Any]], Calla return wrapper return decorator - - -async def _compute_file_hash(path: pathlib.Path) -> str: - hasher = blake3.blake3() - async with aiofiles.open(path, "rb") as f: - while chunk := await 
f.read(_HASH_CHUNK_SIZE): - hasher.update(chunk) - return f"blake3:{hasher.hexdigest()}" diff --git a/atr/tasks/checks/hashing.py b/atr/tasks/checks/file_hash.py similarity index 100% rename from atr/tasks/checks/hashing.py rename to atr/tasks/checks/file_hash.py diff --git a/atr/tasks/checks/license.py b/atr/tasks/checks/license.py index 066fef3..3e46f11 100644 --- a/atr/tasks/checks/license.py +++ b/atr/tasks/checks/license.py @@ -79,6 +79,9 @@ INCLUDED_PATTERNS: Final[list[str]] = [ r"\.(pl|pm|t)$", # Perl ] +# Release policy fields which this task relies on - used for result caching +INPUT_POLICY_KEYS: Final[list[str]] = [""] + # Types @@ -166,9 +169,11 @@ async def headers(args: checks.FunctionArguments) -> results.Results | None: if project.policy_license_check_mode == sql.LicenseCheckMode.RAT: return None - if await recorder.check_cache(artifact_abs_path): - log.info(f"Using cached license headers result for {artifact_abs_path} (rel: {args.primary_rel_path})") - return None + await recorder.cache_key_set(INPUT_POLICY_KEYS) + + # if await recorder.check_cache(artifact_abs_path): + # log.info(f"Using cached license headers result for {artifact_abs_path} (rel: {args.primary_rel_path})") + # return None log.info(f"Checking license headers for {artifact_abs_path} (rel: {args.primary_rel_path})") diff --git a/migrations/versions/0049_2026.02.11_5b874ed2.py b/migrations/versions/0049_2026.02.11_5b874ed2.py new file mode 100644 index 0000000..e4730cd --- /dev/null +++ b/migrations/versions/0049_2026.02.11_5b874ed2.py @@ -0,0 +1,37 @@ +"""Rename input_hash to inputs_hash + +Revision ID: 0049_2026.02.11_5b874ed2 +Revises: 0048_2026.02.06_blocking_to_blocker +Create Date: 2026-02-11 13:42:59.712570+00:00 +""" + +from collections.abc import Sequence + +import sqlalchemy as sa +from alembic import op + +# Revision identifiers, used by Alembic +revision: str = "0049_2026.02.11_5b874ed2" +down_revision: str | None = 
"0048_2026.02.06_blocking_to_blocker" +branch_labels: str |
Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +def upgrade() -> None: + with op.batch_alter_table("checkresult", schema=None) as batch_op: + batch_op.add_column(sa.Column("inputs_hash", sa.String(), nullable=True)) + batch_op.drop_index(batch_op.f("ix_checkresult_input_hash")) + batch_op.create_index(batch_op.f("ix_checkresult_inputs_hash"), ["inputs_hash"], unique=False) + with op.batch_alter_table("checkresult", schema=None) as batch_op: + batch_op.execute("UPDATE checkresult SET inputs_hash = input_hash") + batch_op.drop_column("input_hash") + + +def downgrade() -> None: + with op.batch_alter_table("checkresult", schema=None) as batch_op: + batch_op.add_column(sa.Column("input_hash", sa.VARCHAR(), nullable=True)) + batch_op.drop_index(batch_op.f("ix_checkresult_inputs_hash")) + batch_op.create_index(batch_op.f("ix_checkresult_input_hash"), ["input_hash"], unique=False) + with op.batch_alter_table("checkresult", schema=None) as batch_op: + batch_op.execute("UPDATE checkresult SET input_hash = inputs_hash") + batch_op.drop_column("inputs_hash") diff --git a/tests/unit/recorders.py b/tests/unit/recorders.py index 33e5af0..47c772e 100644 --- a/tests/unit/recorders.py +++ b/tests/unit/recorders.py @@ -63,7 +63,7 @@ class RecorderStub(checks.Recorder): status=status, message=message, data=data, - input_hash=None, + inputs_hash=None, ) --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
