This is an automated email from the ASF dual-hosted git repository.
sbp pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tooling-trusted-release.git
The following commit(s) were added to refs/heads/main by this push:
new 82ae2a0 Fix the order of functions in the remaining unfixed modules
82ae2a0 is described below
commit 82ae2a0a090b0e27eff6e13d66688071188cbf79
Author: Sean B. Palmer <[email protected]>
AuthorDate: Tue Apr 15 15:10:28 2025 +0100
Fix the order of functions in the remaining unfixed modules
---
atr/blueprints/api/api.py | 34 ++--
atr/datasources/apache.py | 40 ++--
atr/db/__init__.py | 77 +++----
atr/db/models.py | 40 ++--
atr/tasks/bulk.py | 452 +++++++++++++++++++++---------------------
atr/tasks/checks/license.py | 56 +++---
atr/tasks/checks/paths.py | 104 +++++-----
atr/tasks/checks/targz.py | 38 ++--
atr/tasks/checks/zipformat.py | 40 ++--
atr/tasks/sbom.py | 88 ++++----
10 files changed, 487 insertions(+), 482 deletions(-)
diff --git a/atr/blueprints/api/api.py b/atr/blueprints/api/api.py
index 9bc1118..51563be 100644
--- a/atr/blueprints/api/api.py
+++ b/atr/blueprints/api/api.py
@@ -35,23 +35,6 @@ import atr.db.models as models
# For now, just explicitly dump the model.
[email protected]("/projects/<name>")
-@quart_schema.validate_response(models.Committee, 200)
-async def project_by_name(name: str) -> tuple[Mapping, int]:
- async with db.session() as data:
- committee = await
data.committee(name=name).demand(exceptions.NotFound())
- return committee.model_dump(), 200
-
-
[email protected]("/projects")
-@quart_schema.validate_response(list[models.Committee], 200)
-async def projects() -> tuple[list[Mapping], int]:
- """List all projects in the database."""
- async with db.session() as data:
- committees = await data.committee().all()
- return [committee.model_dump() for committee in committees], 200
-
-
@dataclasses.dataclass
class Pagination:
offset: int = 0
@@ -72,3 +55,20 @@ async def api_tasks(query_args: Pagination) ->
quart.Response:
count = (await
data.execute(sqlalchemy.select(sqlalchemy.func.count(models.Task.id)))).scalar_one()
# type: ignore
result = {"data": [x.model_dump(exclude={"result"}) for x in
paged_tasks], "count": count}
return quart.jsonify(result)
+
+
[email protected]("/projects/<name>")
+@quart_schema.validate_response(models.Committee, 200)
+async def project_by_name(name: str) -> tuple[Mapping, int]:
+ async with db.session() as data:
+ committee = await
data.committee(name=name).demand(exceptions.NotFound())
+ return committee.model_dump(), 200
+
+
[email protected]("/projects")
+@quart_schema.validate_response(list[models.Committee], 200)
+async def projects() -> tuple[list[Mapping], int]:
+ """List all projects in the database."""
+ async with db.session() as data:
+ committees = await data.committee().all()
+ return [committee.model_dump() for committee in committees], 200
diff --git a/atr/datasources/apache.py b/atr/datasources/apache.py
index 7ad6112..06fe97b 100644
--- a/atr/datasources/apache.py
+++ b/atr/datasources/apache.py
@@ -160,15 +160,6 @@ class ProjectsData(_DictRootModel[ProjectStatus]):
pass
-async def get_ldap_projects_data() -> LDAPProjectsData:
- async with httpx.AsyncClient() as client:
- response = await client.get(_WHIMSY_PROJECTS_URL)
- response.raise_for_status()
- data = response.json()
-
- return LDAPProjectsData.model_validate(data)
-
-
async def get_active_committee_data() -> CommitteeData:
"""Returns the list of currently active committees."""
@@ -180,17 +171,6 @@ async def get_active_committee_data() -> CommitteeData:
return CommitteeData.model_validate(data)
-async def get_retired_committee_data() -> RetiredCommitteeData:
- """Returns the list of retired committees."""
-
- async with httpx.AsyncClient() as client:
- response = await client.get(_WHIMSY_COMMITTEE_RETIRED_URL)
- response.raise_for_status()
- data = response.json()
-
- return RetiredCommitteeData.model_validate(data)
-
-
async def get_current_podlings_data() -> PodlingsData:
"""Returns the list of current podlings."""
@@ -211,6 +191,15 @@ async def get_groups_data() -> GroupsData:
return GroupsData.model_validate(data)
+async def get_ldap_projects_data() -> LDAPProjectsData:
+ async with httpx.AsyncClient() as client:
+ response = await client.get(_WHIMSY_PROJECTS_URL)
+ response.raise_for_status()
+ data = response.json()
+
+ return LDAPProjectsData.model_validate(data)
+
+
async def get_projects_data() -> ProjectsData:
"""Returns the list of projects."""
@@ -219,3 +208,14 @@ async def get_projects_data() -> ProjectsData:
response.raise_for_status()
data = response.json()
return ProjectsData.model_validate(data)
+
+
+async def get_retired_committee_data() -> RetiredCommitteeData:
+ """Returns the list of retired committees."""
+
+ async with httpx.AsyncClient() as client:
+ response = await client.get(_WHIMSY_COMMITTEE_RETIRED_URL)
+ response.raise_for_status()
+ data = response.json()
+
+ return RetiredCommitteeData.model_validate(data)
diff --git a/atr/db/__init__.py b/atr/db/__init__.py
index 4943d00..4c829dc 100644
--- a/atr/db/__init__.py
+++ b/atr/db/__init__.py
@@ -46,6 +46,9 @@ _global_atr_sessionmaker:
sqlalchemy.ext.asyncio.async_sessionmaker | None = Non
T = TypeVar("T")
+# TODO: The not set class should be NotSet
+# And the constant should be _NOT_SET
+
class _NotSetType:
"""
@@ -70,18 +73,12 @@ class _NotSetType:
return NotSet
-NotSet = _NotSetType()
+# TODO: Technically a constant, and only used in this file
+# Should be _NOT_SET
+NotSet: Final[_NotSetType] = _NotSetType()
type Opt[T] = T | _NotSetType
-def is_defined(v: T | _NotSetType) -> TypeGuard[T]:
- return not isinstance(v, _NotSetType)
-
-
-def is_undefined(v: T | _NotSetType) -> TypeGuard[_NotSetType]: # pyright:
ignore [reportInvalidTypeVarUse]
- return isinstance(v, _NotSetType)
-
-
class Query(Generic[T]):
def __init__(self, session: Session, query: expression.SelectOfScalar[T]):
self.query = query
@@ -482,6 +479,29 @@ class Session(sqlalchemy.ext.asyncio.AsyncSession):
return Query(self, query)
+async def create_async_engine(app_config: type[config.AppConfig]) ->
sqlalchemy.ext.asyncio.AsyncEngine:
+ sqlite_url = f"sqlite+aiosqlite://{app_config.SQLITE_DB_PATH}"
+ # Use aiosqlite for async SQLite access
+ engine = sqlalchemy.ext.asyncio.create_async_engine(
+ sqlite_url,
+ connect_args={
+ "check_same_thread": False,
+ "timeout": 30,
+ },
+ )
+
+ # Set SQLite pragmas for better performance
+ # Use 64 MB for the cache_size, and 5000ms for busy_timeout
+ async with engine.begin() as conn:
+ await conn.execute(sql.text("PRAGMA journal_mode=WAL"))
+ await conn.execute(sql.text("PRAGMA synchronous=NORMAL"))
+ await conn.execute(sql.text("PRAGMA cache_size=-64000"))
+ await conn.execute(sql.text("PRAGMA foreign_keys=ON"))
+ await conn.execute(sql.text("PRAGMA busy_timeout=5000"))
+
+ return engine
+
+
def init_database(app: base.QuartApp) -> None:
"""
Creates and initializes the database for a QuartApp.
@@ -532,35 +552,12 @@ async def init_database_for_worker() -> None:
)
-async def shutdown_database() -> None:
- if _global_atr_engine:
- _LOGGER.info("Closing database")
- await _global_atr_engine.dispose()
- else:
- _LOGGER.info("No database to close")
-
+def is_defined(v: T | _NotSetType) -> TypeGuard[T]:
+ return not isinstance(v, _NotSetType)
-async def create_async_engine(app_config: type[config.AppConfig]) ->
sqlalchemy.ext.asyncio.AsyncEngine:
- sqlite_url = f"sqlite+aiosqlite://{app_config.SQLITE_DB_PATH}"
- # Use aiosqlite for async SQLite access
- engine = sqlalchemy.ext.asyncio.create_async_engine(
- sqlite_url,
- connect_args={
- "check_same_thread": False,
- "timeout": 30,
- },
- )
- # Set SQLite pragmas for better performance
- # Use 64 MB for the cache_size, and 5000ms for busy_timeout
- async with engine.begin() as conn:
- await conn.execute(sql.text("PRAGMA journal_mode=WAL"))
- await conn.execute(sql.text("PRAGMA synchronous=NORMAL"))
- await conn.execute(sql.text("PRAGMA cache_size=-64000"))
- await conn.execute(sql.text("PRAGMA foreign_keys=ON"))
- await conn.execute(sql.text("PRAGMA busy_timeout=5000"))
-
- return engine
+def is_undefined(v: T | _NotSetType) -> TypeGuard[_NotSetType]: # pyright:
ignore [reportInvalidTypeVarUse]
+ return isinstance(v, _NotSetType)
# async def recent_tasks(data: Session, release_name: str, file_path: str,
modified: int) -> dict[str, models.Task]:
@@ -627,6 +624,14 @@ def session() -> Session:
return util.validate_as_type(_global_atr_sessionmaker(), Session)
+async def shutdown_database() -> None:
+ if _global_atr_engine:
+ _LOGGER.info("Closing database")
+ await _global_atr_engine.dispose()
+ else:
+ _LOGGER.info("No database to close")
+
+
def validate_instrumented_attribute(obj: Any) -> orm.InstrumentedAttribute:
"""Check if the given object is an InstrumentedAttribute."""
if not isinstance(obj, orm.InstrumentedAttribute):
diff --git a/atr/db/models.py b/atr/db/models.py
index 7e2a656..9a1b3a3 100644
--- a/atr/db/models.py
+++ b/atr/db/models.py
@@ -413,26 +413,6 @@ class Release(sqlmodel.SQLModel, table=True):
return project.committee
-def release_name(project_name: str, version_name: str) -> str:
- """Return the release name for a given project and version."""
- return f"{project_name}-{version_name}"
-
-
-def project_version(release_name: str) -> tuple[str, str]:
- """Return the project and version for a given release name."""
- try:
- project_name, version_name = release_name.rsplit("-", 1)
- return (project_name, version_name)
- except ValueError:
- raise ValueError(f"Invalid release name: {release_name}")
-
-
[email protected]_for(Release, "before_insert")
-def check_release_name(_mapper: sqlalchemy.orm.Mapper, _connection:
sqlalchemy.Connection, release: Release) -> None:
- if release.name == "":
- release.name = release_name(release.project.name, release.version)
-
-
class SSHKey(sqlmodel.SQLModel, table=True):
fingerprint: str = sqlmodel.Field(primary_key=True)
key: str
@@ -474,3 +454,23 @@ class TextValue(sqlmodel.SQLModel, table=True):
ns: str = sqlmodel.Field(primary_key=True, index=True)
key: str = sqlmodel.Field(primary_key=True, index=True)
value: str = sqlmodel.Field()
+
+
[email protected]_for(Release, "before_insert")
+def check_release_name(_mapper: sqlalchemy.orm.Mapper, _connection:
sqlalchemy.Connection, release: Release) -> None:
+ if release.name == "":
+ release.name = release_name(release.project.name, release.version)
+
+
+def project_version(release_name: str) -> tuple[str, str]:
+ """Return the project and version for a given release name."""
+ try:
+ project_name, version_name = release_name.rsplit("-", 1)
+ return (project_name, version_name)
+ except ValueError:
+ raise ValueError(f"Invalid release name: {release_name}")
+
+
+def release_name(project_name: str, version_name: str) -> str:
+ """Return the release name for a given project and version."""
+ return f"{project_name}-{version_name}"
diff --git a/atr/tasks/bulk.py b/atr/tasks/bulk.py
index d2f9993..76e221f 100644
--- a/atr/tasks/bulk.py
+++ b/atr/tasks/bulk.py
@@ -126,6 +126,188 @@ class Args:
return args_obj
+class LinkExtractor(html.parser.HTMLParser):
+ def __init__(self) -> None:
+ super().__init__()
+ self.links: list[str] = []
+
+ def handle_starttag(self, tag: str, attrs: list[tuple[str, str | None]])
-> None:
+ if tag == "a":
+ for attr, value in attrs:
+ if attr == "href" and value:
+ self.links.append(value)
+
+
+async def artifact_download(url: str, semaphore: asyncio.Semaphore) -> bool:
+ _LOGGER.debug(f"Starting download of artifact: {url}")
+ try:
+ success = await artifact_download_core(url, semaphore)
+ if success:
+ _LOGGER.info(f"Successfully downloaded artifact: {url}")
+ else:
+ _LOGGER.warning(f"Failed to download artifact: {url}")
+ return success
+ except Exception as e:
+ _LOGGER.exception(f"Error downloading artifact {url}: {e}")
+ return False
+
+
+async def artifact_download_core(url: str, semaphore: asyncio.Semaphore) ->
bool:
+ _LOGGER.debug(f"Starting core download process for {url}")
+ async with semaphore:
+ _LOGGER.debug(f"Acquired semaphore for {url}")
+ # TODO: We flatten the hierarchy to get the filename
+ # We should preserve the hierarchy
+ filename = url.split("/")[-1]
+ if filename.startswith("."):
+ raise ValueError(f"Invalid filename: {filename}")
+ local_path = os.path.join("downloads", filename)
+
+ # Create download directory if it doesn't exist
+ # TODO: Check whether local_path itself exists first
+ os.makedirs("downloads", exist_ok=True)
+ _LOGGER.debug(f"Downloading {url} to {local_path}")
+
+ try:
+ async with aiohttp.ClientSession() as session:
+ _LOGGER.debug(f"Created HTTP session for {url}")
+ async with session.get(url) as response:
+ if response.status != 200:
+ _LOGGER.warning(f"Failed to download {url}: HTTP
{response.status}")
+ return False
+
+ total_size = int(response.headers.get("Content-Length", 0))
+ if total_size:
+ _LOGGER.info(f"Content-Length: {total_size} bytes for
{url}")
+
+ chunk_size = 8192
+ downloaded = 0
+ _LOGGER.debug(f"Writing file to {local_path} with chunk
size {chunk_size}")
+
+ async with aiofiles.open(local_path, "wb") as f:
+ async for chunk in
response.content.iter_chunked(chunk_size):
+ await f.write(chunk)
+ downloaded += len(chunk)
+ # if total_size:
+ # progress = (downloaded / total_size) * 100
+ # if downloaded % (chunk_size * 128) == 0:
+ # _LOGGER.debug(
+ # f"Download progress for {filename}:"
+ # f" {progress:.1f}%
({downloaded}/{total_size} bytes)"
+ # )
+
+ _LOGGER.info(f"Download complete: {url} -> {local_path}
({downloaded} bytes)")
+ return True
+
+ except Exception as e:
+ _LOGGER.exception(f"Error during download of {url}: {e}")
+ # Remove partial download if an error occurred
+ if os.path.exists(local_path):
+ _LOGGER.debug(f"Removing partial download: {local_path}")
+ try:
+ os.remove(local_path)
+ except Exception as del_err:
+ _LOGGER.error(f"Error removing partial download
{local_path}: {del_err}")
+ return False
+
+
+async def artifact_urls(args: Args, queue: asyncio.Queue, semaphore:
asyncio.Semaphore) -> tuple[list[str], list[str]]:
+ _LOGGER.info(f"Starting URL crawling from {args.base_url}")
+ await database_message(f"Crawling artifact URLs from {args.base_url}")
+ signatures: list[str] = []
+ artifacts: list[str] = []
+ seen: set[str] = set()
+
+ _LOGGER.debug(f"Adding base URL to queue: {args.base_url}")
+ await queue.put(args.base_url)
+
+ _LOGGER.debug("Starting crawl loop")
+ depth = 0
+ # Start with just the base URL
+ urls_at_current_depth = 1
+ urls_at_next_depth = 0
+
+ while (not queue.empty()) and (depth < args.max_depth):
+ _LOGGER.debug(f"Processing depth {depth + 1}/{args.max_depth}, queue
size: {queue.qsize()}")
+
+ # Process all URLs at the current depth before moving to the next
+ for _ in range(urls_at_current_depth):
+ if queue.empty():
+ break
+
+ url = await queue.get()
+ _LOGGER.debug(f"Processing URL: {url}")
+
+ if url_excluded(seen, url, args):
+ continue
+
+ seen.add(url)
+ _LOGGER.debug(f"Checking URL for file types: {args.file_types}")
+
+ # If not a target file type, try to parse HTML links
+ if not check_matches(args, url, artifacts, signatures):
+ _LOGGER.debug(f"URL is not a target file, parsing HTML: {url}")
+ try:
+ new_urls = await download_html(url, semaphore)
+ _LOGGER.debug(f"Found {len(new_urls)} new URLs in {url}")
+ for new_url in new_urls:
+ if new_url not in seen:
+ _LOGGER.debug(f"Adding new URL to queue:
{new_url}")
+ await queue.put(new_url)
+ urls_at_next_depth += 1
+ except Exception as e:
+ _LOGGER.warning(f"Error parsing HTML from {url}: {e}")
+ # Move to next depth
+ depth += 1
+ urls_at_current_depth = urls_at_next_depth
+ urls_at_next_depth = 0
+
+ # Update database with progress message
+ progress_msg = f"Crawled {len(seen)} URLs, found {len(artifacts)}
artifacts (depth {depth}/{args.max_depth})"
+ await database_message(progress_msg, progress=(30 + min(50, depth *
10), 100))
+ _LOGGER.debug(f"Moving to depth {depth + 1}, {urls_at_current_depth}
URLs to process")
+
+ _LOGGER.info(f"URL crawling complete. Found {len(artifacts)} artifacts and
{len(signatures)} signatures")
+ return signatures, artifacts
+
+
+async def artifacts_download(artifacts: list[str], semaphore:
asyncio.Semaphore) -> list[str]:
+ """Download artifacts with progress tracking."""
+ size = len(artifacts)
+ _LOGGER.info(f"Starting download of {size} artifacts")
+ downloaded = []
+
+ for i, artifact in enumerate(artifacts):
+ progress_percent = int((i / size) * 100) if (size > 0) else 100
+ progress_msg = f"Downloading {i + 1}/{size} artifacts"
+ _LOGGER.info(f"{progress_msg}: {artifact}")
+ await database_message(progress_msg, progress=(progress_percent, 100))
+
+ success = await artifact_download(artifact, semaphore)
+ if success:
+ _LOGGER.debug(f"Successfully downloaded: {artifact}")
+ downloaded.append(artifact)
+ else:
+ _LOGGER.warning(f"Failed to download: {artifact}")
+
+ _LOGGER.info(f"Download complete. Successfully downloaded
{len(downloaded)}/{size} artifacts")
+ await database_message(f"Downloaded {len(downloaded)} artifacts",
progress=(100, 100))
+ return downloaded
+
+
+def check_matches(args: Args, url: str, artifacts: list[str], signatures:
list[str]) -> bool:
+ for type in args.file_types:
+ if url.endswith(type):
+ _LOGGER.info(f"Found artifact: {url}")
+ artifacts.append(url)
+ return True
+ elif url.endswith(type + ".asc"):
+ _LOGGER.info(f"Found signature: {url}")
+ signatures.append(url)
+ return True
+ return False
+
+
async def database_message(msg: str, progress: tuple[int, int] | None = None)
-> None:
"""Update database with message and progress."""
_LOGGER.debug(f"Updating database with message: '{msg}', progress:
{progress}")
@@ -144,26 +326,23 @@ async def database_message(msg: str, progress: tuple[int,
int] | None = None) ->
_LOGGER.info(f"Continuing despite database error. Message was:
'{msg}'")
-async def get_db_session() -> sqlalchemy.ext.asyncio.AsyncSession:
- """Get a reusable database session."""
- global global_db_connection
+def database_progress_percentage_calculate(progress: tuple[int, int] | None)
-> int:
+ """Calculate percentage from progress tuple."""
+ _LOGGER.debug(f"Calculating percentage from progress tuple: {progress}")
+ if progress is None:
+ _LOGGER.debug("Progress is None, returning 0%")
+ return 0
- try:
- # Create connection only if it doesn't exist already
- if global_db_connection is None:
- db_url = "sqlite+aiosqlite:///atr.db"
- _LOGGER.debug(f"Creating database engine: {db_url}")
+ current, total = progress
- engine = sqlalchemy.ext.asyncio.create_async_engine(db_url)
- global_db_connection = sqlalchemy.ext.asyncio.async_sessionmaker(
- engine, class_=sqlalchemy.ext.asyncio.AsyncSession,
expire_on_commit=False
- )
+ # Avoid division by zero
+ if total == 0:
+ _LOGGER.warning("Total is zero in progress tuple, avoiding division by
zero")
+ return 0
- connection: sqlalchemy.ext.asyncio.AsyncSession =
global_db_connection()
- return connection
- except Exception as e:
- _LOGGER.exception(f"Error creating database session: {e}")
- raise
+ percentage = min(100, int((current / total) * 100))
+ _LOGGER.debug(f"Calculated percentage: {percentage}% ({current}/{total})")
+ return percentage
async def database_task_id_get() -> int | None:
@@ -265,25 +444,6 @@ async def database_task_update_execute(task_id: int, msg:
str, progress_pct: int
_LOGGER.exception(f"Error updating task {task_id} in database: {e}")
-def database_progress_percentage_calculate(progress: tuple[int, int] | None)
-> int:
- """Calculate percentage from progress tuple."""
- _LOGGER.debug(f"Calculating percentage from progress tuple: {progress}")
- if progress is None:
- _LOGGER.debug("Progress is None, returning 0%")
- return 0
-
- current, total = progress
-
- # Avoid division by zero
- if total == 0:
- _LOGGER.warning("Total is zero in progress tuple, avoiding division by
zero")
- return 0
-
- percentage = min(100, int((current / total) * 100))
- _LOGGER.debug(f"Calculated percentage: {percentage}% ({current}/{total})")
- return percentage
-
-
async def download(args: dict[str, Any]) -> tuple[models.TaskStatus, str |
None, tuple[Any, ...]]:
"""Download bulk package from URL."""
# Returns (status, error, result)
@@ -360,99 +520,6 @@ async def download_core(args_dict: dict[str, Any]) ->
tuple[models.TaskStatus, s
)
-async def artifact_urls(args: Args, queue: asyncio.Queue, semaphore:
asyncio.Semaphore) -> tuple[list[str], list[str]]:
- _LOGGER.info(f"Starting URL crawling from {args.base_url}")
- await database_message(f"Crawling artifact URLs from {args.base_url}")
- signatures: list[str] = []
- artifacts: list[str] = []
- seen: set[str] = set()
-
- _LOGGER.debug(f"Adding base URL to queue: {args.base_url}")
- await queue.put(args.base_url)
-
- _LOGGER.debug("Starting crawl loop")
- depth = 0
- # Start with just the base URL
- urls_at_current_depth = 1
- urls_at_next_depth = 0
-
- while (not queue.empty()) and (depth < args.max_depth):
- _LOGGER.debug(f"Processing depth {depth + 1}/{args.max_depth}, queue
size: {queue.qsize()}")
-
- # Process all URLs at the current depth before moving to the next
- for _ in range(urls_at_current_depth):
- if queue.empty():
- break
-
- url = await queue.get()
- _LOGGER.debug(f"Processing URL: {url}")
-
- if url_excluded(seen, url, args):
- continue
-
- seen.add(url)
- _LOGGER.debug(f"Checking URL for file types: {args.file_types}")
-
- # If not a target file type, try to parse HTML links
- if not check_matches(args, url, artifacts, signatures):
- _LOGGER.debug(f"URL is not a target file, parsing HTML: {url}")
- try:
- new_urls = await download_html(url, semaphore)
- _LOGGER.debug(f"Found {len(new_urls)} new URLs in {url}")
- for new_url in new_urls:
- if new_url not in seen:
- _LOGGER.debug(f"Adding new URL to queue:
{new_url}")
- await queue.put(new_url)
- urls_at_next_depth += 1
- except Exception as e:
- _LOGGER.warning(f"Error parsing HTML from {url}: {e}")
- # Move to next depth
- depth += 1
- urls_at_current_depth = urls_at_next_depth
- urls_at_next_depth = 0
-
- # Update database with progress message
- progress_msg = f"Crawled {len(seen)} URLs, found {len(artifacts)}
artifacts (depth {depth}/{args.max_depth})"
- await database_message(progress_msg, progress=(30 + min(50, depth *
10), 100))
- _LOGGER.debug(f"Moving to depth {depth + 1}, {urls_at_current_depth}
URLs to process")
-
- _LOGGER.info(f"URL crawling complete. Found {len(artifacts)} artifacts and
{len(signatures)} signatures")
- return signatures, artifacts
-
-
-def check_matches(args: Args, url: str, artifacts: list[str], signatures:
list[str]) -> bool:
- for type in args.file_types:
- if url.endswith(type):
- _LOGGER.info(f"Found artifact: {url}")
- artifacts.append(url)
- return True
- elif url.endswith(type + ".asc"):
- _LOGGER.info(f"Found signature: {url}")
- signatures.append(url)
- return True
- return False
-
-
-def url_excluded(seen: set[str], url: str, args: Args) -> bool:
- # Filter for sorting URLs to avoid redundant crawling
- sorting_patterns = ["?C=N;O=", "?C=M;O=", "?C=S;O=", "?C=D;O="]
-
- if not url.startswith(args.base_url):
- _LOGGER.debug(f"Skipping URL outside base URL scope: {url}")
- return True
-
- if url in seen:
- _LOGGER.debug(f"Skipping already seen URL: {url}")
- return True
-
- # Skip sorting URLs to avoid redundant crawling
- if any(pattern in url for pattern in sorting_patterns):
- _LOGGER.debug(f"Skipping sorting URL: {url}")
- return True
-
- return False
-
-
async def download_html(url: str, semaphore: asyncio.Semaphore) -> list[str]:
"""Download HTML and extract links."""
_LOGGER.debug(f"Downloading HTML from: {url}")
@@ -492,18 +559,6 @@ async def download_html_core(url: str, semaphore:
asyncio.Semaphore) -> list[str
return urls
-class LinkExtractor(html.parser.HTMLParser):
- def __init__(self) -> None:
- super().__init__()
- self.links: list[str] = []
-
- def handle_starttag(self, tag: str, attrs: list[tuple[str, str | None]])
-> None:
- if tag == "a":
- for attr, value in attrs:
- if attr == "href" and value:
- self.links.append(value)
-
-
def extract_links_from_html(html: str, base_url: str) -> list[str]:
"""Extract links from HTML content using html.parser."""
parser = LinkExtractor()
@@ -525,98 +580,43 @@ def extract_links_from_html(html: str, base_url: str) ->
list[str]:
return processed_urls
-async def artifacts_download(artifacts: list[str], semaphore:
asyncio.Semaphore) -> list[str]:
- """Download artifacts with progress tracking."""
- size = len(artifacts)
- _LOGGER.info(f"Starting download of {size} artifacts")
- downloaded = []
-
- for i, artifact in enumerate(artifacts):
- progress_percent = int((i / size) * 100) if (size > 0) else 100
- progress_msg = f"Downloading {i + 1}/{size} artifacts"
- _LOGGER.info(f"{progress_msg}: {artifact}")
- await database_message(progress_msg, progress=(progress_percent, 100))
-
- success = await artifact_download(artifact, semaphore)
- if success:
- _LOGGER.debug(f"Successfully downloaded: {artifact}")
- downloaded.append(artifact)
- else:
- _LOGGER.warning(f"Failed to download: {artifact}")
-
- _LOGGER.info(f"Download complete. Successfully downloaded
{len(downloaded)}/{size} artifacts")
- await database_message(f"Downloaded {len(downloaded)} artifacts",
progress=(100, 100))
- return downloaded
-
+async def get_db_session() -> sqlalchemy.ext.asyncio.AsyncSession:
+ """Get a reusable database session."""
+ global global_db_connection
-async def artifact_download(url: str, semaphore: asyncio.Semaphore) -> bool:
- _LOGGER.debug(f"Starting download of artifact: {url}")
try:
- success = await artifact_download_core(url, semaphore)
- if success:
- _LOGGER.info(f"Successfully downloaded artifact: {url}")
- else:
- _LOGGER.warning(f"Failed to download artifact: {url}")
- return success
- except Exception as e:
- _LOGGER.exception(f"Error downloading artifact {url}: {e}")
- return False
-
+ # Create connection only if it doesn't exist already
+ if global_db_connection is None:
+ db_url = "sqlite+aiosqlite:///atr.db"
+ _LOGGER.debug(f"Creating database engine: {db_url}")
-async def artifact_download_core(url: str, semaphore: asyncio.Semaphore) ->
bool:
- _LOGGER.debug(f"Starting core download process for {url}")
- async with semaphore:
- _LOGGER.debug(f"Acquired semaphore for {url}")
- # TODO: We flatten the hierarchy to get the filename
- # We should preserve the hierarchy
- filename = url.split("/")[-1]
- if filename.startswith("."):
- raise ValueError(f"Invalid filename: {filename}")
- local_path = os.path.join("downloads", filename)
+ engine = sqlalchemy.ext.asyncio.create_async_engine(db_url)
+ global_db_connection = sqlalchemy.ext.asyncio.async_sessionmaker(
+ engine, class_=sqlalchemy.ext.asyncio.AsyncSession,
expire_on_commit=False
+ )
- # Create download directory if it doesn't exist
- # TODO: Check whether local_path itself exists first
- os.makedirs("downloads", exist_ok=True)
- _LOGGER.debug(f"Downloading {url} to {local_path}")
+ connection: sqlalchemy.ext.asyncio.AsyncSession =
global_db_connection()
+ return connection
+ except Exception as e:
+ _LOGGER.exception(f"Error creating database session: {e}")
+ raise
- try:
- async with aiohttp.ClientSession() as session:
- _LOGGER.debug(f"Created HTTP session for {url}")
- async with session.get(url) as response:
- if response.status != 200:
- _LOGGER.warning(f"Failed to download {url}: HTTP
{response.status}")
- return False
- total_size = int(response.headers.get("Content-Length", 0))
- if total_size:
- _LOGGER.info(f"Content-Length: {total_size} bytes for
{url}")
+def url_excluded(seen: set[str], url: str, args: Args) -> bool:
+ # Filter for sorting URLs to avoid redundant crawling
+ sorting_patterns = ["?C=N;O=", "?C=M;O=", "?C=S;O=", "?C=D;O="]
- chunk_size = 8192
- downloaded = 0
- _LOGGER.debug(f"Writing file to {local_path} with chunk
size {chunk_size}")
+ if not url.startswith(args.base_url):
+ _LOGGER.debug(f"Skipping URL outside base URL scope: {url}")
+ return True
- async with aiofiles.open(local_path, "wb") as f:
- async for chunk in
response.content.iter_chunked(chunk_size):
- await f.write(chunk)
- downloaded += len(chunk)
- # if total_size:
- # progress = (downloaded / total_size) * 100
- # if downloaded % (chunk_size * 128) == 0:
- # _LOGGER.debug(
- # f"Download progress for {filename}:"
- # f" {progress:.1f}%
({downloaded}/{total_size} bytes)"
- # )
+ if url in seen:
+ _LOGGER.debug(f"Skipping already seen URL: {url}")
+ return True
- _LOGGER.info(f"Download complete: {url} -> {local_path}
({downloaded} bytes)")
- return True
+ # Skip sorting URLs to avoid redundant crawling
+ if any(pattern in url for pattern in sorting_patterns):
+ _LOGGER.debug(f"Skipping sorting URL: {url}")
+ return True
- except Exception as e:
- _LOGGER.exception(f"Error during download of {url}: {e}")
- # Remove partial download if an error occurred
- if os.path.exists(local_path):
- _LOGGER.debug(f"Removing partial download: {local_path}")
- try:
- os.remove(local_path)
- except Exception as del_err:
- _LOGGER.error(f"Error removing partial download
{local_path}: {del_err}")
- return False
+ return False
diff --git a/atr/tasks/checks/license.py b/atr/tasks/checks/license.py
index 7a22a39..cacee1e 100644
--- a/atr/tasks/checks/license.py
+++ b/atr/tasks/checks/license.py
@@ -331,6 +331,27 @@ def _files_check_core_logic_license(tf: tarfile.TarFile,
member: tarfile.TarInfo
return sha3.hexdigest() ==
"8a0a8fb6c73ef27e4322391c7b28e5b38639e64e58c40a2c7a51cec6e7915a6a"
+def _files_check_core_logic_notice(tf: tarfile.TarFile, member:
tarfile.TarInfo) -> tuple[bool, list[str]]:
+ """Verify that the NOTICE file follows the required format."""
+ f = tf.extractfile(member)
+ if not f:
+ return False, ["Could not read NOTICE file"]
+
+ content = f.read().decode("utf-8")
+ issues = []
+
+ if not re.search(r"Apache\s+[\w\-\.]+", content, re.MULTILINE):
+ issues.append("Missing or invalid Apache product header")
+ if not re.search(r"Copyright\s+(?:\d{4}|\d{4}-\d{4})\s+The Apache Software
Foundation", content, re.MULTILINE):
+ issues.append("Missing or invalid copyright statement")
+ if not re.search(
+ r"This product includes software developed at\s*\nThe Apache Software
Foundation \(.*?\)", content, re.DOTALL
+ ):
+ issues.append("Missing or invalid foundation attribution")
+
+ return len(issues) == 0, issues
+
+
def _files_messages_build(
root_dir: str,
files_found: list[str],
@@ -356,28 +377,15 @@ def _files_messages_build(
return messages
-def _files_check_core_logic_notice(tf: tarfile.TarFile, member:
tarfile.TarInfo) -> tuple[bool, list[str]]:
- """Verify that the NOTICE file follows the required format."""
- f = tf.extractfile(member)
- if not f:
- return False, ["Could not read NOTICE file"]
-
- content = f.read().decode("utf-8")
- issues = []
-
- if not re.search(r"Apache\s+[\w\-\.]+", content, re.MULTILINE):
- issues.append("Missing or invalid Apache product header")
- if not re.search(r"Copyright\s+(?:\d{4}|\d{4}-\d{4})\s+The Apache Software
Foundation", content, re.MULTILINE):
- issues.append("Missing or invalid copyright statement")
- if not re.search(
- r"This product includes software developed at\s*\nThe Apache Software
Foundation \(.*?\)", content, re.DOTALL
- ):
- issues.append("Missing or invalid foundation attribution")
-
- return len(issues) == 0, issues
+# Header helpers
-# Header helpers
+def _get_file_extension(filename: str) -> str | None:
+ """Get the file extension without the dot."""
+ _, ext = os.path.splitext(filename)
+ if not ext:
+ return None
+ return ext[1:].lower()
def _headers_check_core_logic(artifact_path: str) -> dict[str, Any]:
@@ -489,14 +497,6 @@ def _headers_check_core_logic_should_check(filepath: str)
-> bool:
return False
-def _get_file_extension(filename: str) -> str | None:
- """Get the file extension without the dot."""
- _, ext = os.path.splitext(filename)
- if not ext:
- return None
- return ext[1:].lower()
-
-
def _headers_validate(content: bytes, filename: str) -> tuple[bool, str |
None]:
"""Validate that the content contains the Apache License header after
removing comments."""
# Get the file extension from the filename
diff --git a/atr/tasks/checks/paths.py b/atr/tasks/checks/paths.py
index 55eea0e..9f17564 100644
--- a/atr/tasks/checks/paths.py
+++ b/atr/tasks/checks/paths.py
@@ -29,6 +29,58 @@ import atr.util as util
_LOGGER: Final = logging.getLogger(__name__)
+async def check(args: checks.FunctionArguments) -> None:
+ """Check file path structure and naming conventions against ASF release
policy for all files in a release."""
+ # We refer to the following authoritative policies:
+ # - Release Creation Process (RCP)
+ # - Release Distribution Policy (RDP)
+
+ recorder_errors = await checks.Recorder.create(
+ checker=checks.function_key(check) + "_errors",
+ release_name=args.release_name,
+ draft_revision=args.draft_revision,
+ primary_rel_path=None,
+ afresh=True,
+ )
+ recorder_warnings = await checks.Recorder.create(
+ checker=checks.function_key(check) + "_warnings",
+ release_name=args.release_name,
+ draft_revision=args.draft_revision,
+ primary_rel_path=None,
+ afresh=True,
+ )
+ recorder_success = await checks.Recorder.create(
+ checker=checks.function_key(check) + "_success",
+ release_name=args.release_name,
+ draft_revision=args.draft_revision,
+ primary_rel_path=None,
+ afresh=True,
+ )
+
+ # As primary_rel_path is None, the base path is the release candidate
draft directory
+ if not (base_path := await recorder_success.abs_path()):
+ return
+
+ if not await aiofiles.os.path.isdir(base_path):
+ _LOGGER.error("Base release directory does not exist or is not a
directory: %s", base_path)
+ return
+
+ relative_paths = await util.paths_recursive(base_path)
+ relative_paths_set = set(str(p) for p in relative_paths)
+ for relative_path in relative_paths:
+ # Delegate processing of each path to the helper function
+ await _check_path_process_single(
+ base_path,
+ relative_path,
+ recorder_errors,
+ recorder_warnings,
+ recorder_success,
+ relative_paths_set,
+ )
+
+ return None
+
+
async def _check_artifact_rules(
base_path: pathlib.Path, relative_path: pathlib.Path, relative_paths:
set[str], errors: list[str]
) -> None:
@@ -138,55 +190,3 @@ async def _check_path_process_single(
await recorder_success.success(
"Path structure and naming conventions conform to policy", {},
primary_rel_path=relative_path_str
)
-
-
-async def check(args: checks.FunctionArguments) -> None:
- """Check file path structure and naming conventions against ASF release
policy for all files in a release."""
- # We refer to the following authoritative policies:
- # - Release Creation Process (RCP)
- # - Release Distribution Policy (RDP)
-
- recorder_errors = await checks.Recorder.create(
- checker=checks.function_key(check) + "_errors",
- release_name=args.release_name,
- draft_revision=args.draft_revision,
- primary_rel_path=None,
- afresh=True,
- )
- recorder_warnings = await checks.Recorder.create(
- checker=checks.function_key(check) + "_warnings",
- release_name=args.release_name,
- draft_revision=args.draft_revision,
- primary_rel_path=None,
- afresh=True,
- )
- recorder_success = await checks.Recorder.create(
- checker=checks.function_key(check) + "_success",
- release_name=args.release_name,
- draft_revision=args.draft_revision,
- primary_rel_path=None,
- afresh=True,
- )
-
- # As primary_rel_path is None, the base path is the release candidate
draft directory
- if not (base_path := await recorder_success.abs_path()):
- return
-
- if not await aiofiles.os.path.isdir(base_path):
- _LOGGER.error("Base release directory does not exist or is not a
directory: %s", base_path)
- return
-
- relative_paths = await util.paths_recursive(base_path)
- relative_paths_set = set(str(p) for p in relative_paths)
- for relative_path in relative_paths:
- # Delegate processing of each path to the helper function
- await _check_path_process_single(
- base_path,
- relative_path,
- recorder_errors,
- recorder_warnings,
- recorder_success,
- relative_paths_set,
- )
-
- return None
diff --git a/atr/tasks/checks/targz.py b/atr/tasks/checks/targz.py
index 6a3ed62..3c54836 100644
--- a/atr/tasks/checks/targz.py
+++ b/atr/tasks/checks/targz.py
@@ -42,6 +42,25 @@ async def integrity(args: checks.FunctionArguments) -> str |
None:
return None
+def root_directory(tgz_path: str) -> str:
+ """Find the root directory in a tar archive and validate that it has only
one root dir."""
+ root = None
+
+ with tarfile.open(tgz_path, mode="r|gz") as tf:
+ for member in tf:
+ parts = member.name.split("/", 1)
+ if len(parts) >= 1:
+ if not root:
+ root = parts[0]
+ elif parts[0] != root:
+ raise ValueError(f"Multiple root directories found:
{root}, {parts[0]}")
+
+ if not root:
+ raise ValueError("No root directory found in archive")
+
+ return root
+
+
async def structure(args: checks.FunctionArguments) -> str | None:
"""Check the structure of a .tar.gz file."""
recorder = await args.recorder()
@@ -73,25 +92,6 @@ async def structure(args: checks.FunctionArguments) -> str |
None:
return None
-def root_directory(tgz_path: str) -> str:
- """Find the root directory in a tar archive and validate that it has only
one root dir."""
- root = None
-
- with tarfile.open(tgz_path, mode="r|gz") as tf:
- for member in tf:
- parts = member.name.split("/", 1)
- if len(parts) >= 1:
- if not root:
- root = parts[0]
- elif parts[0] != root:
- raise ValueError(f"Multiple root directories found:
{root}, {parts[0]}")
-
- if not root:
- raise ValueError("No root directory found in archive")
-
- return root
-
-
def _integrity_core(tgz_path: str, chunk_size: int = 4096) -> int:
"""Verify a .tar.gz file and compute its uncompressed size."""
total_size = 0
diff --git a/atr/tasks/checks/zipformat.py b/atr/tasks/checks/zipformat.py
index 8ce63d5..d95fee1 100644
--- a/atr/tasks/checks/zipformat.py
+++ b/atr/tasks/checks/zipformat.py
@@ -145,26 +145,6 @@ def _integrity_check_core_logic(artifact_path: str) ->
dict[str, Any]:
return {"error": f"Unexpected error: {e}"}
-def _license_files_check_file_zip(zf: zipfile.ZipFile, artifact_path: str,
expected_path: str) -> tuple[bool, bool]:
- """Check for the presence and basic validity of a specific file in a
zip."""
- found = False
- valid = False
- try:
- with zf.open(expected_path) as file_handle:
- found = True
- content = file_handle.read().strip()
- if content:
- # TODO: Add more specific NOTICE checks if needed
- valid = True
- except KeyError:
- # File not found in zip
- ...
- except Exception as e:
- filename = os.path.basename(expected_path)
- _LOGGER.warning(f"Error reading {filename} in zip {artifact_path}:
{e}")
- return found, valid
-
-
def _license_files_check_core_logic_zip(artifact_path: str) -> dict[str, Any]:
"""Verify LICENSE and NOTICE files within a zip archive."""
# TODO: Obviously we want to reuse the license files check logic from
license.py
@@ -212,6 +192,26 @@ def _license_files_check_core_logic_zip(artifact_path:
str) -> dict[str, Any]:
return {"error": f"Unexpected error: {e}"}
+def _license_files_check_file_zip(zf: zipfile.ZipFile, artifact_path: str,
expected_path: str) -> tuple[bool, bool]:
+ """Check for the presence and basic validity of a specific file in a
zip."""
+ found = False
+ valid = False
+ try:
+ with zf.open(expected_path) as file_handle:
+ found = True
+ content = file_handle.read().strip()
+ if content:
+ # TODO: Add more specific NOTICE checks if needed
+ valid = True
+ except KeyError:
+ # File not found in zip
+ ...
+ except Exception as e:
+ filename = os.path.basename(expected_path)
+ _LOGGER.warning(f"Error reading {filename} in zip {artifact_path}:
{e}")
+ return found, valid
+
+
def _license_files_find_root_dir_zip(members: list[str]) -> str | None:
"""Find the root directory in a list of zip members."""
for member in members:
diff --git a/atr/tasks/sbom.py b/atr/tasks/sbom.py
index d8aa89f..c3d0c91 100644
--- a/atr/tasks/sbom.py
+++ b/atr/tasks/sbom.py
@@ -49,50 +49,6 @@ class GenerateCycloneDX(pydantic.BaseModel):
output_path: str = pydantic.Field(..., description="Absolute path where
the generated SBOM JSON should be written")
-def _archive_extract_safe_process_file(
- tf: tarfile.TarFile,
- member: tarfile.TarInfo,
- extract_dir: str,
- total_extracted: int,
- max_size: int,
- chunk_size: int,
-) -> int:
- """Process a single file member during safe archive extraction."""
- target_path = os.path.join(extract_dir, member.name)
- if not
os.path.abspath(target_path).startswith(os.path.abspath(extract_dir)):
- _LOGGER.warning(f"Skipping potentially unsafe path: {member.name}")
- return 0
-
- os.makedirs(os.path.dirname(target_path), exist_ok=True)
-
- source = tf.extractfile(member)
- if source is None:
- # Should not happen if member.isreg() is true
- _LOGGER.warning(f"Could not extract file object for member:
{member.name}")
- return 0
-
- extracted_file_size = 0
- try:
- with open(target_path, "wb") as target:
- while chunk := source.read(chunk_size):
- target.write(chunk)
- extracted_file_size += len(chunk)
-
- # Check size limits during extraction
- if (total_extracted + extracted_file_size) > max_size:
- # Clean up the partial file before raising
- target.close()
- os.unlink(target_path)
- raise SBOMGenerationError(
- f"Extraction exceeded maximum size limit of {max_size}
bytes",
- {"max_size": max_size, "current_size":
total_extracted},
- )
- finally:
- source.close()
-
- return extracted_file_size
-
-
def archive_extract_safe(
archive_path: str,
extract_dir: str,
@@ -155,6 +111,50 @@ async def generate_cyclonedx(args: GenerateCycloneDX) ->
str | None:
raise
+def _archive_extract_safe_process_file(
+ tf: tarfile.TarFile,
+ member: tarfile.TarInfo,
+ extract_dir: str,
+ total_extracted: int,
+ max_size: int,
+ chunk_size: int,
+) -> int:
+ """Process a single file member during safe archive extraction."""
+ target_path = os.path.join(extract_dir, member.name)
+ if not
os.path.abspath(target_path).startswith(os.path.abspath(extract_dir)):
+ _LOGGER.warning(f"Skipping potentially unsafe path: {member.name}")
+ return 0
+
+ os.makedirs(os.path.dirname(target_path), exist_ok=True)
+
+ source = tf.extractfile(member)
+ if source is None:
+ # Should not happen if member.isreg() is true
+ _LOGGER.warning(f"Could not extract file object for member:
{member.name}")
+ return 0
+
+ extracted_file_size = 0
+ try:
+ with open(target_path, "wb") as target:
+ while chunk := source.read(chunk_size):
+ target.write(chunk)
+ extracted_file_size += len(chunk)
+
+ # Check size limits during extraction
+ if (total_extracted + extracted_file_size) > max_size:
+ # Clean up the partial file before raising
+ target.close()
+ os.unlink(target_path)
+ raise SBOMGenerationError(
+ f"Extraction exceeded maximum size limit of {max_size}
bytes",
+ {"max_size": max_size, "current_size":
total_extracted},
+ )
+ finally:
+ source.close()
+
+ return extracted_file_size
+
+
async def _generate_cyclonedx_core(artifact_path: str, output_path: str) ->
dict[str, Any]:
"""Core logic to generate CycloneDX SBOM, raising SBOMGenerationError on
failure."""
_LOGGER.info(f"Generating CycloneDX SBOM for {artifact_path} ->
{output_path}")
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]