This is an automated email from the ASF dual-hosted git repository.

sbp pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tooling-trusted-release.git
The following commit(s) were added to refs/heads/main by this push:
     new 67107ae  Make further files consistent with style guidelines
67107ae is described below

commit 67107ae2f2cb9e0b801ec90bae7c665c40623660
Author: Sean B. Palmer <s...@miscoranda.com>
AuthorDate: Wed Mar 12 19:24:03 2025 +0200

    Make further files consistent with style guidelines
---
 Makefile                                  |   6 +-
 atr/blueprints/admin/templates/tasks.html |   3 +-
 atr/datasources/apache.py                 |  54 +++----
 atr/db/__init__.py                        |  65 ++++----
 atr/db/models.py                          |   2 +
 atr/db/service.py                         |  62 ++++----
 atr/server.py                             |   6 +-
 atr/static/css/bootstrap.custom.css       |   4 +-
 atr/tasks/archive.py                      |   8 +-
 atr/tasks/bulk.py                         | 253 +++++++++++++++---------------
 atr/tasks/mailtest.py                     |  84 +++++-----
 atr/tasks/task.py                         |   4 +-
 atr/tasks/vote.py                         |  80 +++++-----
 atr/templates/candidate-review.html       |   2 +-
 atr/util.py                               |   9 ++
 15 files changed, 330 insertions(+), 312 deletions(-)

diff --git a/Makefile b/Makefile
index 715f5be..450f68f 100644
--- a/Makefile
+++ b/Makefile
@@ -1,4 +1,4 @@
-.PHONY: build build-alpine-ubuntu certs check checks docs report run serve sync sync-dev
+.PHONY: build build-alpine-ubuntu certs check docs report run serve sync sync-dev
 
 BIND ?= 127.0.0.1:8080
 PYTHON ?= $(which python3)
@@ -18,10 +18,6 @@ certs:
 check:
 	$(SCRIPTS)/run pre-commit run --all-files
 
-checks:
-	$(SCRIPTS)/run pre-commit run --all-files
-	poetry run pyright
-
 docs:
 	for fn in docs/*.md; \
 	do cmark "$$fn" > "$${fn%.md}.html"; \
diff --git a/atr/blueprints/admin/templates/tasks.html b/atr/blueprints/admin/templates/tasks.html
index 8563686..1e3143b 100644
--- a/atr/blueprints/admin/templates/tasks.html
+++ b/atr/blueprints/admin/templates/tasks.html
@@ -47,7 +47,8 @@
             ],
             style: {
                 table: {
-                    'white-space': 'nowrap'
+                    // TODO: Need a better fix here
+                    // 'white-space': 'nowrap'
                 }
             },
             search: true,
diff --git a/atr/datasources/apache.py b/atr/datasources/apache.py
index 7466d4e..2d272d7 100644
--- a/atr/datasources/apache.py
+++ b/atr/datasources/apache.py
@@ -19,13 +19,13 @@
 
 from __future__ import annotations
 
-from datetime import datetime
+import datetime
 from typing import TYPE_CHECKING, Annotated, TypeVar
 
 import httpx
-from pydantic import BaseModel, Field, RootModel
+import pydantic
 
-from atr.util import DictToList
+import atr.util as util
 
 if TYPE_CHECKING:
     from collections.abc import Generator, ItemsView
@@ -40,20 +40,20 @@ _PROJECTS_GROUPS_URL = "https://projects.apache.org/json/foundation/groups.json"
 VT = TypeVar("VT")
 
 
-class LDAPProjectsData(BaseModel):
-    last_timestamp: str = Field(alias="lastTimestamp")
+class LDAPProjectsData(pydantic.BaseModel):
+    last_timestamp: str = pydantic.Field(alias="lastTimestamp")
     project_count: int
-    projects: Annotated[list[LDAPProject], DictToList(key="name")]
+    projects: Annotated[list[LDAPProject], util.DictToList(key="name")]
 
     @property
-    def last_time(self) -> datetime:
-        return datetime.strptime(self.last_timestamp, "%Y%m%d%H%M%S%z")
+    def last_time(self) -> datetime.datetime:
+        return datetime.datetime.strptime(self.last_timestamp, "%Y%m%d%H%M%S%z")
 
 
-class LDAPProject(BaseModel):
+class LDAPProject(pydantic.BaseModel):
     name: str
-    create_timestamp: str = Field(alias="createTimestamp")
-    modify_timestamp: str = Field(alias="modifyTimestamp")
+    create_timestamp: str = pydantic.Field(alias="createTimestamp")
+    modify_timestamp: str = pydantic.Field(alias="modifyTimestamp")
     member_count: int
     owner_count: int
     members: list[str]
@@ -62,20 +62,20 @@ class LDAPProject(BaseModel):
     podling: str | None = None
 
 
-class CommitteeData(BaseModel):
+class 
CommitteeData(pydantic.BaseModel): last_updated: str committee_count: int pmc_count: int - committees: Annotated[list[Committee], DictToList(key="name")] + committees: Annotated[list[Committee], util.DictToList(key="name")] -class RetiredCommitteeData(BaseModel): +class RetiredCommitteeData(pydantic.BaseModel): last_updated: str retired_count: int - retired: Annotated[list[RetiredCommittee], DictToList(key="name")] + retired: Annotated[list[RetiredCommittee], util.DictToList(key="name")] -class Committee(BaseModel): +class Committee(pydantic.BaseModel): name: str display_name: str site: str @@ -83,29 +83,29 @@ class Committee(BaseModel): mail_list: str established: str report: list[str] - chair: Annotated[list[User], DictToList(key="id")] + chair: Annotated[list[User], util.DictToList(key="id")] roster_count: int - roster: Annotated[list[User], DictToList(key="id")] + roster: Annotated[list[User], util.DictToList(key="id")] pmc: bool -class User(BaseModel): +class User(pydantic.BaseModel): id: str name: str date: str | None = None -class RetiredCommittee(BaseModel): +class RetiredCommittee(pydantic.BaseModel): name: str display_name: str retired: str description: str -class PodlingStatus(BaseModel): +class PodlingStatus(pydantic.BaseModel): description: str homepage: str - name: str = Field(alias="name") + name: str = pydantic.Field(alias="name") pmc: str podling: bool started: str @@ -114,7 +114,7 @@ class PodlingStatus(BaseModel): resolution: str | None = None -class _DictRootModel(RootModel[dict[str, VT]]): +class _DictRootModel(pydantic.RootModel[dict[str, VT]]): def __iter__(self) -> Generator[tuple[str, VT]]: yield from self.root.items() @@ -136,13 +136,13 @@ class GroupsData(_DictRootModel[list[str]]): pass -class Release(BaseModel): +class Release(pydantic.BaseModel): created: str | None = None name: str revision: str | None = None -class ProjectStatus(BaseModel): +class ProjectStatus(pydantic.BaseModel): category: str | None = None created: str | None = None description: str | None = None @@ -151,8 +151,8 @@ class ProjectStatus(BaseModel): name: str pmc: str shortdesc: str | None = None - repository: list[str | dict] = Field(default_factory=list) - release: list[Release] = Field(default_factory=list) + repository: list[str | dict] = pydantic.Field(default_factory=list) + release: list[Release] = pydantic.Field(default_factory=list) class ProjectsData(_DictRootModel[ProjectStatus]): diff --git a/atr/db/__init__.py b/atr/db/__init__.py index a272180..7f59e86 100644 --- a/atr/db/__init__.py +++ b/atr/db/__init__.py @@ -17,22 +17,22 @@ import logging import os -from typing import Any +from typing import Any, Final +import alembic.config as config +import quart +import sqlalchemy +import sqlalchemy.ext.asyncio import sqlalchemy.orm as orm +import sqlalchemy.sql as sql +import sqlmodel -# from alembic import command -from alembic.config import Config -from quart import current_app -from sqlalchemy import Engine, create_engine -from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine -from sqlalchemy.orm import Session -from sqlalchemy.sql import text -from sqlmodel import SQLModel - +import atr.util as util from asfquart.base import QuartApp -_LOGGER = logging.getLogger(__name__) +_LOGGER: Final = logging.getLogger(__name__) + +_global_sync_engine: sqlalchemy.Engine | None = None def create_database(app: QuartApp) -> None: @@ -42,7 +42,7 @@ def create_database(app: QuartApp) -> None: sqlite_db_path = app.config["SQLITE_DB_PATH"] sqlite_url = 
f"sqlite+aiosqlite://{sqlite_db_path}" # Use aiosqlite for async SQLite access - engine = create_async_engine( + engine = sqlalchemy.ext.asyncio.create_async_engine( sqlite_url, connect_args={ "check_same_thread": False, @@ -51,24 +51,25 @@ def create_database(app: QuartApp) -> None: ) # Create async session factory - async_session = async_sessionmaker(bind=engine, class_=AsyncSession, expire_on_commit=False) - app.async_session = async_session # type: ignore + app.extensions["async_session"] = sqlalchemy.ext.asyncio.async_sessionmaker( + bind=engine, class_=sqlalchemy.ext.asyncio.AsyncSession, expire_on_commit=False + ) # Set SQLite pragmas for better performance # Use 64 MB for the cache_size, and 5000ms for busy_timeout async with engine.begin() as conn: - await conn.execute(text("PRAGMA journal_mode=WAL")) - await conn.execute(text("PRAGMA synchronous=NORMAL")) - await conn.execute(text("PRAGMA cache_size=-64000")) - await conn.execute(text("PRAGMA foreign_keys=ON")) - await conn.execute(text("PRAGMA busy_timeout=5000")) + await conn.execute(sql.text("PRAGMA journal_mode=WAL")) + await conn.execute(sql.text("PRAGMA synchronous=NORMAL")) + await conn.execute(sql.text("PRAGMA cache_size=-64000")) + await conn.execute(sql.text("PRAGMA foreign_keys=ON")) + await conn.execute(sql.text("PRAGMA busy_timeout=5000")) # Run any pending migrations # In dev we'd do this first: # poetry run alembic revision --autogenerate -m "description" # Then review the generated migration in migrations/versions/ and commit it alembic_ini_path = os.path.join(project_root, "alembic.ini") - alembic_cfg = Config(alembic_ini_path) + alembic_cfg = config.Config(alembic_ini_path) # Override the migrations directory location to use project root # TODO: Is it possible to set this in alembic.ini? 
alembic_cfg.set_main_option("script_location", os.path.join(project_root, "migrations")) @@ -78,34 +79,34 @@ def create_database(app: QuartApp) -> None: # Create any tables that might be missing async with engine.begin() as conn: - await conn.run_sync(SQLModel.metadata.create_all) + await conn.run_sync(sqlmodel.SQLModel.metadata.create_all) -def create_async_db_session() -> AsyncSession: +def create_async_db_session() -> sqlalchemy.ext.asyncio.AsyncSession: """Create a new asynchronous database session.""" - return current_app.async_session() # type: ignore - - -_SYNC_ENGINE: Engine | None = None + async_session = util.validate_as_type( + quart.current_app.extensions["async_session"](), sqlalchemy.ext.asyncio.AsyncSession + ) + return async_session def create_sync_db_engine() -> None: """Create a synchronous database engine.""" import atr.config as config - global _SYNC_ENGINE + global _global_sync_engine conf = config.get() sqlite_url = f"sqlite://{conf.SQLITE_DB_PATH}" _LOGGER.debug(f"Creating sync database engine in process {os.getpid()}") - _SYNC_ENGINE = create_engine(sqlite_url, echo=False) + _global_sync_engine = sqlalchemy.create_engine(sqlite_url, echo=False) -def create_sync_db_session() -> Session: +def create_sync_db_session() -> sqlalchemy.orm.Session: """Create a new synchronous database session.""" - global _SYNC_ENGINE - assert _SYNC_ENGINE is not None - return Session(_SYNC_ENGINE) + global _global_sync_engine + assert _global_sync_engine is not None + return sqlalchemy.orm.Session(_global_sync_engine) def select_in_load(*entities: Any) -> orm.strategy_options._AbstractLoad: diff --git a/atr/db/models.py b/atr/db/models.py index 4ca55f9..59ab1ef 100644 --- a/atr/db/models.py +++ b/atr/db/models.py @@ -17,6 +17,8 @@ """The data models to be persisted in the database.""" +# from __future__ import annotations + import datetime import enum from typing import Any diff --git a/atr/db/service.py b/atr/db/service.py index 514f6ef..9ec7970 100644 --- a/atr/db/service.py +++ b/atr/db/service.py @@ -15,70 +15,70 @@ # specific language governing permissions and limitations # under the License. +import contextlib from collections.abc import Sequence -from contextlib import nullcontext -from sqlalchemy import func -from sqlalchemy.ext.asyncio import AsyncSession -from sqlmodel import select +import sqlalchemy +import sqlalchemy.ext.asyncio +import sqlmodel import atr.db as db -from atr.db.models import PMC, Release, Task +import atr.db.models as models -from . 
import create_async_db_session - -async def get_pmc_by_name(project_name: str, session: AsyncSession | None = None) -> PMC | None: +async def get_pmc_by_name( + project_name: str, session: sqlalchemy.ext.asyncio.AsyncSession | None = None +) -> models.PMC | None: """Returns a PMC object by name.""" - async with create_async_db_session() if session is None else nullcontext(session) as db_session: - statement = select(PMC).where(PMC.project_name == project_name) + async with db.create_async_db_session() if session is None else contextlib.nullcontext(session) as db_session: + statement = sqlmodel.select(models.PMC).where(models.PMC.project_name == project_name) pmc = (await db_session.execute(statement)).scalar_one_or_none() return pmc -async def get_pmcs(session: AsyncSession | None = None) -> Sequence[PMC]: +async def get_pmcs(session: sqlalchemy.ext.asyncio.AsyncSession | None = None) -> Sequence[models.PMC]: """Returns a list of PMC objects.""" - async with create_async_db_session() if session is None else nullcontext(session) as db_session: + async with db.create_async_db_session() if session is None else contextlib.nullcontext(session) as db_session: # Get all PMCs and their latest releases - statement = select(PMC).order_by(PMC.project_name) + statement = sqlmodel.select(models.PMC).order_by(models.PMC.project_name) pmcs = (await db_session.execute(statement)).scalars().all() return pmcs -async def get_release_by_key(storage_key: str) -> Release | None: +async def get_release_by_key(storage_key: str) -> models.Release | None: """Get a release by its storage key.""" - async with create_async_db_session() as db_session: + async with db.create_async_db_session() as db_session: # Get the release with its PMC and product line query = ( - select(Release) - .where(Release.storage_key == storage_key) - .options(db.select_in_load(Release.pmc)) - .options(db.select_in_load(Release.product_line)) + sqlmodel.select(models.Release) + .where(models.Release.storage_key == storage_key) + .options(db.select_in_load(models.Release.pmc)) + .options(db.select_in_load(models.Release.product_line)) ) result = await db_session.execute(query) return result.scalar_one_or_none() -def get_release_by_key_sync(storage_key: str) -> Release | None: +def get_release_by_key_sync(storage_key: str) -> models.Release | None: """Synchronous version of get_release_by_key for use in background tasks.""" - from atr.db import create_sync_db_session - - with create_sync_db_session() as session: + with db.create_sync_db_session() as session: # Get the release with its PMC and product line query = ( - select(Release) - .where(Release.storage_key == storage_key) - .options(db.select_in_load(Release.pmc)) - .options(db.select_in_load(Release.product_line)) + sqlmodel.select(models.Release) + .where(models.Release.storage_key == storage_key) + .options(db.select_in_load(models.Release.pmc)) + .options(db.select_in_load(models.Release.product_line)) ) result = session.execute(query) return result.scalar_one_or_none() -async def get_tasks(limit: int, offset: int, session: AsyncSession | None = None) -> tuple[Sequence[Task], int]: +async def get_tasks( + limit: int, offset: int, session: sqlalchemy.ext.asyncio.AsyncSession | None = None +) -> tuple[Sequence[models.Task], int]: """Returns a list of Tasks based on limit and offset values together with the total count.""" - async with create_async_db_session() if session is None else nullcontext(session) as db_session: - statement = 
select(Task).limit(limit).offset(offset).order_by(Task.id.desc()) # type: ignore + async with db.create_async_db_session() if session is None else contextlib.nullcontext(session) as db_session: + statement = sqlmodel.select(models.Task).limit(limit).offset(offset).order_by(models.Task.id.desc()) # type: ignore tasks = (await db_session.execute(statement)).scalars().all() - count = (await db_session.execute(select(func.count(Task.id)))).scalar_one() # type: ignore + count = (await db_session.execute(sqlalchemy.select(sqlalchemy.func.count(models.Task.id)))).scalar_one() # type: ignore return tasks, count diff --git a/atr/server.py b/atr/server.py index 1601d69..3df1727 100644 --- a/atr/server.py +++ b/atr/server.py @@ -69,7 +69,11 @@ def register_routes(app: base.QuartApp) -> tuple[str, ...]: @app.errorhandler(ASFQuartException) async def handle_asfquart_exception(error: ASFQuartException) -> Any: - errorcode = error.errorcode # type: ignore + # TODO: Figure out why pyright doesn't know about this attribute + if not hasattr(error, "errorcode"): + errorcode = 500 + else: + errorcode = getattr(error, "errorcode") return await render_template("error.html", error=str(error), status_code=errorcode), errorcode # Add a global error handler in case a page does not exist. diff --git a/atr/static/css/bootstrap.custom.css b/atr/static/css/bootstrap.custom.css index 7fc649c..5af14d2 100644 --- a/atr/static/css/bootstrap.custom.css +++ b/atr/static/css/bootstrap.custom.css @@ -6986,7 +6986,7 @@ textarea.form-control-lg { border-top: 0 !important; } -.border-end, th { +.border-end, table.atr-data th { border-right: var(--bs-border-width) var(--bs-border-style) var(--bs-border-color) !important; } @@ -8321,7 +8321,7 @@ textarea.form-control-lg { background-color: rgba(var(--bs-secondary-bg-rgb), var(--bs-bg-opacity)) !important; } -.bg-body-tertiary, th { +.bg-body-tertiary, table.atr-data th { --bs-bg-opacity: 1; background-color: rgba(var(--bs-tertiary-bg-rgb), var(--bs-bg-opacity)) !important; } diff --git a/atr/tasks/archive.py b/atr/tasks/archive.py index 3f8ee59..ae057be 100644 --- a/atr/tasks/archive.py +++ b/atr/tasks/archive.py @@ -20,18 +20,18 @@ import os.path import tarfile from typing import Any, Final -from pydantic import BaseModel, Field +import pydantic import atr.tasks.task as task _LOGGER = logging.getLogger(__name__) -class CheckIntegrity(BaseModel): +class CheckIntegrity(pydantic.BaseModel): """Parameters for archive integrity checking.""" - path: str = Field(..., description="Path to the .tar.gz file to check") - chunk_size: int = Field(default=4096, description="Size of chunks to read when checking the file") + path: str = pydantic.Field(..., description="Path to the .tar.gz file to check") + chunk_size: int = pydantic.Field(default=4096, description="Size of chunks to read when checking the file") def check_integrity(args: dict[str, Any]) -> tuple[task.Status, str | None, tuple[Any, ...]]: diff --git a/atr/tasks/bulk.py b/atr/tasks/bulk.py index 2654920..537d5d4 100644 --- a/atr/tasks/bulk.py +++ b/atr/tasks/bulk.py @@ -16,24 +16,24 @@ # under the License. 
import asyncio +import dataclasses +import html.parser import json import logging import os -from dataclasses import dataclass -from html.parser import HTMLParser -from typing import Any -from urllib.parse import urljoin +import urllib.parse +from typing import Any, Final import aiofiles import aiohttp -from sqlalchemy import text -from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine +import sqlalchemy +import sqlalchemy.ext.asyncio import atr.tasks.task as task # Configure detailed logging -logger = logging.getLogger(__name__) -logger.setLevel(logging.DEBUG) +_LOGGER: Final = logging.getLogger(__name__) +_LOGGER.setLevel(logging.DEBUG) # Create file handler for test.log file_handler = logging.FileHandler("tasks-bulk.log") @@ -45,17 +45,18 @@ formatter = logging.Formatter( datefmt="%Y-%m-%d %H:%M:%S", ) file_handler.setFormatter(formatter) -logger.addHandler(file_handler) +_LOGGER.addHandler(file_handler) # Ensure parent loggers don't duplicate messages -logger.propagate = False +_LOGGER.propagate = False -logger.info("Bulk download module imported") +_LOGGER.info("Bulk download module imported") -global_db_connection: async_sessionmaker | None = None +global_db_connection: sqlalchemy.ext.asyncio.async_sessionmaker | None = None global_task_id: int | None = None -@dataclass +# TODO: Use a Pydantic model instead +@dataclasses.dataclass class Args: release_key: str base_url: str @@ -67,10 +68,10 @@ class Args: @staticmethod def from_list(args: list[str]) -> "Args": """Parse command line arguments.""" - logger.debug(f"Parsing arguments: {args}") + _LOGGER.debug(f"Parsing arguments: {args}") if len(args) != 6: - logger.error(f"Invalid number of arguments: {len(args)}, expected 6") + _LOGGER.error(f"Invalid number of arguments: {len(args)}, expected 6") raise ValueError("Invalid number of arguments") release_key = args[0] @@ -80,36 +81,36 @@ class Args: max_depth = args[4] max_concurrent = args[5] - logger.debug( + _LOGGER.debug( f"Extracted values - release_key: {release_key}, base_url: {base_url}, " f"file_types: {file_types}, require_sigs: {require_sigs}, " f"max_depth: {max_depth}, max_concurrent: {max_concurrent}" ) if not isinstance(release_key, str): - logger.error(f"Release key must be a string, got {type(release_key)}") + _LOGGER.error(f"Release key must be a string, got {type(release_key)}") raise ValueError("Release key must be a string") if not isinstance(base_url, str): - logger.error(f"Base URL must be a string, got {type(base_url)}") + _LOGGER.error(f"Base URL must be a string, got {type(base_url)}") raise ValueError("Base URL must be a string") if not isinstance(file_types, list): - logger.error(f"File types must be a list, got {type(file_types)}") + _LOGGER.error(f"File types must be a list, got {type(file_types)}") raise ValueError("File types must be a list") for arg in file_types: if not isinstance(arg, str): - logger.error(f"File types must be a list of strings, got {type(arg)}") + _LOGGER.error(f"File types must be a list of strings, got {type(arg)}") raise ValueError("File types must be a list of strings") if not isinstance(require_sigs, bool): - logger.error(f"Require sigs must be a boolean, got {type(require_sigs)}") + _LOGGER.error(f"Require sigs must be a boolean, got {type(require_sigs)}") raise ValueError("Require sigs must be a boolean") if not isinstance(max_depth, int): - logger.error(f"Max depth must be an integer, got {type(max_depth)}") + _LOGGER.error(f"Max depth must be an integer, got {type(max_depth)}") raise 
ValueError("Max depth must be an integer") if not isinstance(max_concurrent, int): - logger.error(f"Max concurrent must be an integer, got {type(max_concurrent)}") + _LOGGER.error(f"Max concurrent must be an integer, got {type(max_concurrent)}") raise ValueError("Max concurrent must be an integer") - logger.debug("All argument validations passed") + _LOGGER.debug("All argument validations passed") args_obj = Args( release_key=release_key, @@ -120,29 +121,29 @@ class Args: max_concurrent=max_concurrent, ) - logger.info(f"Args object created: {args_obj}") + _LOGGER.info(f"Args object created: {args_obj}") return args_obj async def database_message(msg: str, progress: tuple[int, int] | None = None) -> None: """Update database with message and progress.""" - logger.debug(f"Updating database with message: '{msg}', progress: {progress}") + _LOGGER.debug(f"Updating database with message: '{msg}', progress: {progress}") try: task_id = await database_task_id_get() if task_id: - logger.debug(f"Found task_id: {task_id}, updating with message") + _LOGGER.debug(f"Found task_id: {task_id}, updating with message") await database_task_update(task_id, msg, progress) else: - logger.warning("No task ID found, skipping database update") + _LOGGER.warning("No task ID found, skipping database update") except Exception as e: # We don't raise here # We continue even if database updates fail # But in this case, the user won't be informed on the update page - logger.exception(f"Failed to update database: {e}") - logger.info(f"Continuing despite database error. Message was: '{msg}'") + _LOGGER.exception(f"Failed to update database: {e}") + _LOGGER.info(f"Continuing despite database error. Message was: '{msg}'") -async def get_db_session() -> AsyncSession: +async def get_db_session() -> sqlalchemy.ext.asyncio.AsyncSession: """Get a reusable database session.""" global global_db_connection @@ -150,103 +151,105 @@ async def get_db_session() -> AsyncSession: # Create connection only if it doesn't exist already if global_db_connection is None: db_url = "sqlite+aiosqlite:///atr.db" - logger.debug(f"Creating database engine: {db_url}") + _LOGGER.debug(f"Creating database engine: {db_url}") - engine = create_async_engine(db_url) - global_db_connection = async_sessionmaker(engine, class_=AsyncSession, expire_on_commit=False) + engine = sqlalchemy.ext.asyncio.create_async_engine(db_url) + global_db_connection = sqlalchemy.ext.asyncio.async_sessionmaker( + engine, class_=sqlalchemy.ext.asyncio.AsyncSession, expire_on_commit=False + ) - connection: AsyncSession = global_db_connection() + connection: sqlalchemy.ext.asyncio.AsyncSession = global_db_connection() return connection except Exception as e: - logger.exception(f"Error creating database session: {e}") + _LOGGER.exception(f"Error creating database session: {e}") raise async def database_task_id_get() -> int | None: """Get current task ID asynchronously with caching.""" global global_task_id - logger.debug("Attempting to get current task ID") + _LOGGER.debug("Attempting to get current task ID") # Return cached ID if available if global_task_id is not None: - logger.debug(f"Using cached task ID: {global_task_id}") + _LOGGER.debug(f"Using cached task ID: {global_task_id}") return global_task_id try: from os import getpid process_id = getpid() - logger.debug(f"Current process ID: {process_id}") + _LOGGER.debug(f"Current process ID: {process_id}") task_id = await database_task_pid_lookup(process_id) if task_id: - logger.info(f"Found task ID: {task_id} for process ID: 
{process_id}") + _LOGGER.info(f"Found task ID: {task_id} for process ID: {process_id}") # Cache the task ID for future use global_task_id = task_id else: - logger.warning(f"No task found for process ID: {process_id}") + _LOGGER.warning(f"No task found for process ID: {process_id}") return task_id except Exception as e: - logger.exception(f"Error getting task ID: {e}") + _LOGGER.exception(f"Error getting task ID: {e}") return None async def database_task_pid_lookup(process_id: int) -> int | None: """Look up task ID by process ID asynchronously.""" - logger.debug(f"Looking up task ID for process ID: {process_id}") + _LOGGER.debug(f"Looking up task ID for process ID: {process_id}") try: async with await get_db_session() as session: - logger.debug(f"Executing SQL query to find task for PID: {process_id}") + _LOGGER.debug(f"Executing SQL query to find task for PID: {process_id}") # Look for ACTIVE task with our PID result = await session.execute( - text(""" + sqlalchemy.text(""" SELECT id FROM task WHERE pid = :pid AND status = 'ACTIVE' LIMIT 1 """), {"pid": process_id}, ) - logger.debug("SQL query executed, fetching results") + _LOGGER.debug("SQL query executed, fetching results") row = result.fetchone() if row: - logger.info(f"Found task ID: {row[0]} for process ID: {process_id}") + _LOGGER.info(f"Found task ID: {row[0]} for process ID: {process_id}") row_one = row[0] if not isinstance(row_one, int): - logger.error(f"Task ID is not an integer: {row_one}") + _LOGGER.error(f"Task ID is not an integer: {row_one}") raise ValueError("Task ID is not an integer") return row_one else: - logger.warning(f"No ACTIVE task found for process ID: {process_id}") + _LOGGER.warning(f"No ACTIVE task found for process ID: {process_id}") return None except Exception as e: - logger.exception(f"Error looking up task by PID: {e}") + _LOGGER.exception(f"Error looking up task by PID: {e}") return None async def database_task_update(task_id: int, msg: str, progress: tuple[int, int] | None) -> None: """Update task in database with message and progress.""" - logger.debug(f"Updating task {task_id} with message: '{msg}', progress: {progress}") + _LOGGER.debug(f"Updating task {task_id} with message: '{msg}', progress: {progress}") # Convert progress to percentage progress_pct = database_progress_percentage_calculate(progress) - logger.debug(f"Calculated progress percentage: {progress_pct}%") + _LOGGER.debug(f"Calculated progress percentage: {progress_pct}%") await database_task_update_execute(task_id, msg, progress_pct) async def database_task_update_execute(task_id: int, msg: str, progress_pct: int) -> None: """Execute database update with message and progress.""" - logger.debug(f"Executing database update for task {task_id}, message: '{msg}', progress: {progress_pct}%") + _LOGGER.debug(f"Executing database update for task {task_id}, message: '{msg}', progress: {progress_pct}%") try: async with await get_db_session() as session: - logger.debug(f"Executing SQL UPDATE for task ID: {task_id}") + _LOGGER.debug(f"Executing SQL UPDATE for task ID: {task_id}") # Store progress info in the result column as JSON result_data = json.dumps({"message": msg, "progress": progress_pct}) await session.execute( - text(""" + sqlalchemy.text(""" UPDATE task SET result = :result WHERE id = :task_id @@ -257,28 +260,28 @@ async def database_task_update_execute(task_id: int, msg: str, progress_pct: int }, ) await session.commit() - logger.info(f"Successfully updated task {task_id} with progress {progress_pct}%") + _LOGGER.info(f"Successfully 
updated task {task_id} with progress {progress_pct}%") except Exception as e: # Continue even if database update fails - logger.exception(f"Error updating task {task_id} in database: {e}") + _LOGGER.exception(f"Error updating task {task_id} in database: {e}") def database_progress_percentage_calculate(progress: tuple[int, int] | None) -> int: """Calculate percentage from progress tuple.""" - logger.debug(f"Calculating percentage from progress tuple: {progress}") + _LOGGER.debug(f"Calculating percentage from progress tuple: {progress}") if progress is None: - logger.debug("Progress is None, returning 0%") + _LOGGER.debug("Progress is None, returning 0%") return 0 current, total = progress # Avoid division by zero if total == 0: - logger.warning("Total is zero in progress tuple, avoiding division by zero") + _LOGGER.warning("Total is zero in progress tuple, avoiding division by zero") return 0 percentage = min(100, int((current / total) * 100)) - logger.debug(f"Calculated percentage: {percentage}% ({current}/{total})") + _LOGGER.debug(f"Calculated percentage: {percentage}% ({current}/{total})") return percentage @@ -287,28 +290,28 @@ def download(args: list[str]) -> tuple[task.Status, str | None, tuple[Any, ...]] # Returns (status, error, result) # This is the main task entry point, called by worker.py # This function should probably be called artifacts_download - logger.info(f"Starting bulk download task with args: {args}") + _LOGGER.info(f"Starting bulk download task with args: {args}") try: - logger.debug("Delegating to download_core function") + _LOGGER.debug("Delegating to download_core function") status, error, result = download_core(args) - logger.info(f"Download completed with status: {status}") + _LOGGER.info(f"Download completed with status: {status}") return status, error, result except Exception as e: - logger.exception(f"Error in download function: {e}") + _LOGGER.exception(f"Error in download function: {e}") # Return a tuple with a dictionary that matches what the template expects return task.FAILED, str(e), ({"message": f"Error: {e}", "progress": 0},) def download_core(args_list: list[str]) -> tuple[task.Status, str | None, tuple[Any, ...]]: """Download bulk package from URL.""" - logger.info("Starting download_core") + _LOGGER.info("Starting download_core") try: - logger.debug(f"Parsing arguments: {args_list}") + _LOGGER.debug(f"Parsing arguments: {args_list}") args = Args.from_list(args_list) - logger.info(f"Args parsed successfully: release_key={args.release_key}, base_url={args.base_url}") + _LOGGER.info(f"Args parsed successfully: release_key={args.release_key}, base_url={args.base_url}") # Create async resources - logger.debug("Creating async queue and semaphore") + _LOGGER.debug("Creating async queue and semaphore") queue: asyncio.Queue[str] = asyncio.Queue() semaphore = asyncio.Semaphore(args.max_concurrent) loop = asyncio.get_event_loop() @@ -316,15 +319,15 @@ def download_core(args_list: list[str]) -> tuple[task.Status, str | None, tuple[ # Start URL crawling loop.run_until_complete(database_message(f"Crawling URLs from {args.base_url}")) - logger.info("Starting artifact_urls coroutine") + _LOGGER.info("Starting artifact_urls coroutine") signatures, artifacts = loop.run_until_complete(artifact_urls(args, queue, semaphore)) - logger.info(f"Found {len(signatures)} signatures and {len(artifacts)} artifacts") + _LOGGER.info(f"Found {len(signatures)} signatures and {len(artifacts)} artifacts") # Update progress for download phase 
loop.run_until_complete(database_message(f"Found {len(artifacts)} artifacts to download")) # Download artifacts - logger.info("Starting artifacts_download coroutine") + _LOGGER.info("Starting artifacts_download coroutine") artifacts_downloaded = loop.run_until_complete(artifacts_download(artifacts, semaphore)) files_downloaded = len(artifacts_downloaded) @@ -345,7 +348,7 @@ def download_core(args_list: list[str]) -> tuple[task.Status, str | None, tuple[ ) except Exception as e: - logger.exception(f"Error in download_core: {e}") + _LOGGER.exception(f"Error in download_core: {e}") return ( task.FAILED, str(e), @@ -359,23 +362,23 @@ def download_core(args_list: list[str]) -> tuple[task.Status, str | None, tuple[ async def artifact_urls(args: Args, queue: asyncio.Queue, semaphore: asyncio.Semaphore) -> tuple[list[str], list[str]]: - logger.info(f"Starting URL crawling from {args.base_url}") + _LOGGER.info(f"Starting URL crawling from {args.base_url}") await database_message(f"Crawling artifact URLs from {args.base_url}") signatures: list[str] = [] artifacts: list[str] = [] seen: set[str] = set() - logger.debug(f"Adding base URL to queue: {args.base_url}") + _LOGGER.debug(f"Adding base URL to queue: {args.base_url}") await queue.put(args.base_url) - logger.debug("Starting crawl loop") + _LOGGER.debug("Starting crawl loop") depth = 0 # Start with just the base URL urls_at_current_depth = 1 urls_at_next_depth = 0 while (not queue.empty()) and (depth < args.max_depth): - logger.debug(f"Processing depth {depth + 1}/{args.max_depth}, queue size: {queue.qsize()}") + _LOGGER.debug(f"Processing depth {depth + 1}/{args.max_depth}, queue size: {queue.qsize()}") # Process all URLs at the current depth before moving to the next for _ in range(urls_at_current_depth): @@ -383,27 +386,27 @@ async def artifact_urls(args: Args, queue: asyncio.Queue, semaphore: asyncio.Sem break url = await queue.get() - logger.debug(f"Processing URL: {url}") + _LOGGER.debug(f"Processing URL: {url}") if url_excluded(seen, url, args): continue seen.add(url) - logger.debug(f"Checking URL for file types: {args.file_types}") + _LOGGER.debug(f"Checking URL for file types: {args.file_types}") # If not a target file type, try to parse HTML links if not check_matches(args, url, artifacts, signatures): - logger.debug(f"URL is not a target file, parsing HTML: {url}") + _LOGGER.debug(f"URL is not a target file, parsing HTML: {url}") try: new_urls = await download_html(url, semaphore) - logger.debug(f"Found {len(new_urls)} new URLs in {url}") + _LOGGER.debug(f"Found {len(new_urls)} new URLs in {url}") for new_url in new_urls: if new_url not in seen: - logger.debug(f"Adding new URL to queue: {new_url}") + _LOGGER.debug(f"Adding new URL to queue: {new_url}") await queue.put(new_url) urls_at_next_depth += 1 except Exception as e: - logger.warning(f"Error parsing HTML from {url}: {e}") + _LOGGER.warning(f"Error parsing HTML from {url}: {e}") # Move to next depth depth += 1 urls_at_current_depth = urls_at_next_depth @@ -412,20 +415,20 @@ async def artifact_urls(args: Args, queue: asyncio.Queue, semaphore: asyncio.Sem # Update database with progress message progress_msg = f"Crawled {len(seen)} URLs, found {len(artifacts)} artifacts (depth {depth}/{args.max_depth})" await database_message(progress_msg, progress=(30 + min(50, depth * 10), 100)) - logger.debug(f"Moving to depth {depth + 1}, {urls_at_current_depth} URLs to process") + _LOGGER.debug(f"Moving to depth {depth + 1}, {urls_at_current_depth} URLs to process") - logger.info(f"URL crawling 
complete. Found {len(artifacts)} artifacts and {len(signatures)} signatures") + _LOGGER.info(f"URL crawling complete. Found {len(artifacts)} artifacts and {len(signatures)} signatures") return signatures, artifacts def check_matches(args: Args, url: str, artifacts: list[str], signatures: list[str]) -> bool: for type in args.file_types: if url.endswith(type): - logger.info(f"Found artifact: {url}") + _LOGGER.info(f"Found artifact: {url}") artifacts.append(url) return True elif url.endswith(type + ".asc"): - logger.info(f"Found signature: {url}") + _LOGGER.info(f"Found signature: {url}") signatures.append(url) return True return False @@ -436,16 +439,16 @@ def url_excluded(seen: set[str], url: str, args: Args) -> bool: sorting_patterns = ["?C=N;O=", "?C=M;O=", "?C=S;O=", "?C=D;O="] if not url.startswith(args.base_url): - logger.debug(f"Skipping URL outside base URL scope: {url}") + _LOGGER.debug(f"Skipping URL outside base URL scope: {url}") return True if url in seen: - logger.debug(f"Skipping already seen URL: {url}") + _LOGGER.debug(f"Skipping already seen URL: {url}") return True # Skip sorting URLs to avoid redundant crawling if any(pattern in url for pattern in sorting_patterns): - logger.debug(f"Skipping sorting URL: {url}") + _LOGGER.debug(f"Skipping sorting URL: {url}") return True return False @@ -453,44 +456,44 @@ def url_excluded(seen: set[str], url: str, args: Args) -> bool: async def download_html(url: str, semaphore: asyncio.Semaphore) -> list[str]: """Download HTML and extract links.""" - logger.debug(f"Downloading HTML from: {url}") + _LOGGER.debug(f"Downloading HTML from: {url}") try: return await download_html_core(url, semaphore) except Exception as e: - logger.error(f"Error downloading HTML from {url}: {e}") + _LOGGER.error(f"Error downloading HTML from {url}: {e}") return [] async def download_html_core(url: str, semaphore: asyncio.Semaphore) -> list[str]: """Core HTML download and link extraction logic.""" - logger.debug(f"Starting HTML download core for {url}") + _LOGGER.debug(f"Starting HTML download core for {url}") async with semaphore: - logger.debug(f"Acquired semaphore for {url}") + _LOGGER.debug(f"Acquired semaphore for {url}") urls = [] async with aiohttp.ClientSession() as session: - logger.debug(f"Created HTTP session for {url}") + _LOGGER.debug(f"Created HTTP session for {url}") async with session.get(url) as response: if response.status != 200: - logger.warning(f"HTTP {response.status} for {url}") + _LOGGER.warning(f"HTTP {response.status} for {url}") return [] - logger.debug(f"Received HTTP 200 for {url}, content type: {response.content_type}") + _LOGGER.debug(f"Received HTTP 200 for {url}, content type: {response.content_type}") if not response.content_type.startswith("text/html"): - logger.debug(f"Not HTML content: {response.content_type}, skipping link extraction") + _LOGGER.debug(f"Not HTML content: {response.content_type}, skipping link extraction") return [] - logger.debug(f"Reading HTML content from {url}") + _LOGGER.debug(f"Reading HTML content from {url}") html = await response.text() urls = extract_links_from_html(html, url) - logger.debug(f"Extracted {len(urls)} processed links from {url}") + _LOGGER.debug(f"Extracted {len(urls)} processed links from {url}") return urls -class LinkExtractor(HTMLParser): +class LinkExtractor(html.parser.HTMLParser): def __init__(self) -> None: super().__init__() self.links: list[str] = [] @@ -507,18 +510,18 @@ def extract_links_from_html(html: str, base_url: str) -> list[str]: parser = LinkExtractor() 
parser.feed(html) raw_links = parser.links - logger.debug(f"Found {len(raw_links)} raw links in {base_url}") + _LOGGER.debug(f"Found {len(raw_links)} raw links in {base_url}") processed_urls = [] for link in raw_links: - processed_url = urljoin(base_url, link) + processed_url = urllib.parse.urljoin(base_url, link) # Filter out URLs that don't start with the base URL # We also check this elsewhere amongst other checks # But it's good to filter them early if processed_url.startswith(base_url): processed_urls.append(processed_url) else: - logger.debug(f"Skipping URL outside base URL scope: {processed_url}") + _LOGGER.debug(f"Skipping URL outside base URL scope: {processed_url}") return processed_urls @@ -526,45 +529,45 @@ def extract_links_from_html(html: str, base_url: str) -> list[str]: async def artifacts_download(artifacts: list[str], semaphore: asyncio.Semaphore) -> list[str]: """Download artifacts with progress tracking.""" size = len(artifacts) - logger.info(f"Starting download of {size} artifacts") + _LOGGER.info(f"Starting download of {size} artifacts") downloaded = [] for i, artifact in enumerate(artifacts): progress_percent = int((i / size) * 100) if (size > 0) else 100 progress_msg = f"Downloading {i + 1}/{size} artifacts" - logger.info(f"{progress_msg}: {artifact}") + _LOGGER.info(f"{progress_msg}: {artifact}") await database_message(progress_msg, progress=(progress_percent, 100)) success = await artifact_download(artifact, semaphore) if success: - logger.debug(f"Successfully downloaded: {artifact}") + _LOGGER.debug(f"Successfully downloaded: {artifact}") downloaded.append(artifact) else: - logger.warning(f"Failed to download: {artifact}") + _LOGGER.warning(f"Failed to download: {artifact}") - logger.info(f"Download complete. Successfully downloaded {len(downloaded)}/{size} artifacts") + _LOGGER.info(f"Download complete. 
Successfully downloaded {len(downloaded)}/{size} artifacts") await database_message(f"Downloaded {len(downloaded)} artifacts", progress=(100, 100)) return downloaded async def artifact_download(url: str, semaphore: asyncio.Semaphore) -> bool: - logger.debug(f"Starting download of artifact: {url}") + _LOGGER.debug(f"Starting download of artifact: {url}") try: success = await artifact_download_core(url, semaphore) if success: - logger.info(f"Successfully downloaded artifact: {url}") + _LOGGER.info(f"Successfully downloaded artifact: {url}") else: - logger.warning(f"Failed to download artifact: {url}") + _LOGGER.warning(f"Failed to download artifact: {url}") return success except Exception as e: - logger.exception(f"Error downloading artifact {url}: {e}") + _LOGGER.exception(f"Error downloading artifact {url}: {e}") return False async def artifact_download_core(url: str, semaphore: asyncio.Semaphore) -> bool: - logger.debug(f"Starting core download process for {url}") + _LOGGER.debug(f"Starting core download process for {url}") async with semaphore: - logger.debug(f"Acquired semaphore for {url}") + _LOGGER.debug(f"Acquired semaphore for {url}") # TODO: We flatten the hierarchy to get the filename # We should preserve the hierarchy filename = url.split("/")[-1] @@ -575,23 +578,23 @@ async def artifact_download_core(url: str, semaphore: asyncio.Semaphore) -> bool # Create download directory if it doesn't exist # TODO: Check whether local_path itself exists first os.makedirs("downloads", exist_ok=True) - logger.debug(f"Downloading {url} to {local_path}") + _LOGGER.debug(f"Downloading {url} to {local_path}") try: async with aiohttp.ClientSession() as session: - logger.debug(f"Created HTTP session for {url}") + _LOGGER.debug(f"Created HTTP session for {url}") async with session.get(url) as response: if response.status != 200: - logger.warning(f"Failed to download {url}: HTTP {response.status}") + _LOGGER.warning(f"Failed to download {url}: HTTP {response.status}") return False total_size = int(response.headers.get("Content-Length", 0)) if total_size: - logger.info(f"Content-Length: {total_size} bytes for {url}") + _LOGGER.info(f"Content-Length: {total_size} bytes for {url}") chunk_size = 8192 downloaded = 0 - logger.debug(f"Writing file to {local_path} with chunk size {chunk_size}") + _LOGGER.debug(f"Writing file to {local_path} with chunk size {chunk_size}") async with aiofiles.open(local_path, "wb") as f: async for chunk in response.content.iter_chunked(chunk_size): @@ -600,21 +603,21 @@ async def artifact_download_core(url: str, semaphore: asyncio.Semaphore) -> bool # if total_size: # progress = (downloaded / total_size) * 100 # if downloaded % (chunk_size * 128) == 0: - # logger.debug( + # _LOGGER.debug( # f"Download progress for {filename}:" # f" {progress:.1f}% ({downloaded}/{total_size} bytes)" # ) - logger.info(f"Download complete: {url} -> {local_path} ({downloaded} bytes)") + _LOGGER.info(f"Download complete: {url} -> {local_path} ({downloaded} bytes)") return True except Exception as e: - logger.exception(f"Error during download of {url}: {e}") + _LOGGER.exception(f"Error during download of {url}: {e}") # Remove partial download if an error occurred if os.path.exists(local_path): - logger.debug(f"Removing partial download: {local_path}") + _LOGGER.debug(f"Removing partial download: {local_path}") try: os.remove(local_path) except Exception as del_err: - logger.error(f"Error removing partial download {local_path}: {del_err}") + _LOGGER.error(f"Error removing partial download 
{local_path}: {del_err}") return False diff --git a/atr/tasks/mailtest.py b/atr/tasks/mailtest.py index 81d9e25..8e298bd 100644 --- a/atr/tasks/mailtest.py +++ b/atr/tasks/mailtest.py @@ -15,35 +15,37 @@ # specific language governing permissions and limitations # under the License. +import dataclasses import logging import os -from dataclasses import dataclass -from typing import Any +from typing import Any, Final import atr.tasks.task as task # Configure detailed logging -logger = logging.getLogger(__name__) -logger.setLevel(logging.DEBUG) +_LOGGER: Final = logging.getLogger(__name__) +_LOGGER.setLevel(logging.DEBUG) # Create file handler for test.log -file_handler = logging.FileHandler("tasks-mailtest.log") -file_handler.setLevel(logging.DEBUG) +_HANDLER: Final = logging.FileHandler("tasks-mailtest.log") +_HANDLER.setLevel(logging.DEBUG) # Create formatter with detailed information -formatter = logging.Formatter( - "[%(asctime)s.%(msecs)03d] [%(process)d] [%(levelname)s] [%(name)s:%(funcName)s:%(lineno)d] %(message)s", - datefmt="%Y-%m-%d %H:%M:%S", +_HANDLER.setFormatter( + logging.Formatter( + "[%(asctime)s.%(msecs)03d] [%(process)d] [%(levelname)s] [%(name)s:%(funcName)s:%(lineno)d] %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + ) ) -file_handler.setFormatter(formatter) -logger.addHandler(file_handler) +_LOGGER.addHandler(_HANDLER) # Ensure parent loggers don't duplicate messages -logger.propagate = False +_LOGGER.propagate = False -logger.info("Mail test module imported") +_LOGGER.info("Mail test module imported") -@dataclass +# TODO: Use a Pydantic model instead +@dataclasses.dataclass class Args: artifact_name: str email_recipient: str @@ -52,10 +54,10 @@ class Args: @staticmethod def from_list(args: list[str]) -> "Args": """Parse command line arguments.""" - logger.debug(f"Parsing arguments: {args}") + _LOGGER.debug(f"Parsing arguments: {args}") if len(args) != 3: - logger.error(f"Invalid number of arguments: {len(args)}, expected 3") + _LOGGER.error(f"Invalid number of arguments: {len(args)}, expected 3") raise ValueError("Invalid number of arguments") artifact_name = args[0] @@ -63,15 +65,15 @@ class Args: token = args[2] if not isinstance(artifact_name, str): - logger.error(f"Artifact name must be a string, got {type(artifact_name)}") + _LOGGER.error(f"Artifact name must be a string, got {type(artifact_name)}") raise ValueError("Artifact name must be a string") if not isinstance(email_recipient, str): - logger.error(f"Email recipient must be a string, got {type(email_recipient)}") + _LOGGER.error(f"Email recipient must be a string, got {type(email_recipient)}") raise ValueError("Email recipient must be a string") if not isinstance(token, str): - logger.error(f"Token must be a string, got {type(token)}") + _LOGGER.error(f"Token must be a string, got {type(token)}") raise ValueError("Token must be a string") - logger.debug("All argument validations passed") + _LOGGER.debug("All argument validations passed") args_obj = Args( artifact_name=artifact_name, @@ -79,20 +81,20 @@ class Args: token=token, ) - logger.info(f"Args object created: {args_obj}") + _LOGGER.info(f"Args object created: {args_obj}") return args_obj def send(args: list[str]) -> tuple[task.Status, str | None, tuple[Any, ...]]: """Send a test email.""" - logger.info(f"Sending with args: {args}") + _LOGGER.info(f"Sending with args: {args}") try: - logger.debug("Delegating to send_core function") + _LOGGER.debug("Delegating to send_core function") status, error, result = send_core(args) - logger.info(f"Send completed 
with status: {status}") + _LOGGER.info(f"Send completed with status: {status}") return status, error, result except Exception as e: - logger.exception(f"Error in send function: {e}") + _LOGGER.exception(f"Error in send function: {e}") return task.FAILED, str(e), tuple() @@ -103,10 +105,10 @@ def send_core(args_list: list[str]) -> tuple[task.Status, str | None, tuple[Any, import atr.mail from atr.db.service import get_pmc_by_name - logger.info("Starting send_core") + _LOGGER.info("Starting send_core") try: - # Configure root logger to also write to our log file - # This ensures logs from mail.py, using the root logger, are captured + # Configure root _LOGGER to also write to our log file + # This ensures logs from mail.py, using the root _LOGGER, are captured root_logger = logging.getLogger() # Check whether our file handler is already added, to avoid duplicates has_our_handler = any( @@ -114,13 +116,13 @@ def send_core(args_list: list[str]) -> tuple[task.Status, str | None, tuple[Any, for h in root_logger.handlers ) if not has_our_handler: - # Add our file handler to the root logger - root_logger.addHandler(file_handler) - logger.info("Added file handler to root logger to capture mail.py logs") + # Add our file handler to the root _LOGGER + root_logger.addHandler(_HANDLER) + _LOGGER.info("Added file handler to root _LOGGER to capture mail.py logs") - logger.debug(f"Parsing arguments: {args_list}") + _LOGGER.debug(f"Parsing arguments: {args_list}") args = Args.from_list(args_list) - logger.info( + _LOGGER.info( f"Args parsed successfully: artifact_name={args.artifact_name}, email_recipient={args.email_recipient}" ) @@ -138,20 +140,20 @@ def send_core(args_list: list[str]) -> tuple[task.Status, str | None, tuple[Any, if not tooling_pmc: error_msg = "Tooling PMC not found in database" - logger.error(error_msg) + _LOGGER.error(error_msg) return task.FAILED, error_msg, tuple() if domain != "apache.org": error_msg = f"Email domain must be apache.org, got {domain}" - logger.error(error_msg) + _LOGGER.error(error_msg) return task.FAILED, error_msg, tuple() if local_part not in tooling_pmc.pmc_members: error_msg = f"Email recipient {local_part} is not a member of the tooling PMC" - logger.error(error_msg) + _LOGGER.error(error_msg) return task.FAILED, error_msg, tuple() - logger.info(f"Recipient {email_recipient} is a tooling PMC member, allowed") + _LOGGER.info(f"Recipient {email_recipient} is a tooling PMC member, allowed") # Load and set DKIM key try: @@ -161,10 +163,10 @@ def send_core(args_list: list[str]) -> tuple[task.Status, str | None, tuple[Any, with open(dkim_path) as f: dkim_key = f.read() atr.mail.set_secret_key(dkim_key.strip()) - logger.info("DKIM key loaded and set successfully") + _LOGGER.info("DKIM key loaded and set successfully") except Exception as e: error_msg = f"Failed to load DKIM key: {e}" - logger.error(error_msg) + _LOGGER.error(error_msg) return task.FAILED, error_msg, tuple() event = atr.mail.ArtifactEvent( @@ -173,10 +175,10 @@ def send_core(args_list: list[str]) -> tuple[task.Status, str | None, tuple[Any, token=args.token, ) atr.mail.send(event) - logger.info(f"Email sent successfully to {args.email_recipient}") + _LOGGER.info(f"Email sent successfully to {args.email_recipient}") return task.COMPLETED, None, tuple() except Exception as e: - logger.exception(f"Error in send_core: {e}") + _LOGGER.exception(f"Error in send_core: {e}") return task.FAILED, str(e), tuple() diff --git a/atr/tasks/task.py b/atr/tasks/task.py index 5b71191..1628de6 100644 --- 
a/atr/tasks/task.py +++ b/atr/tasks/task.py @@ -15,11 +15,11 @@ # specific language governing permissions and limitations # under the License. -from enum import Enum +import enum from typing import Any, Final, Literal -class Status(Enum): +class Status(enum.Enum): COMPLETED = "completed" FAILED = "failed" diff --git a/atr/tasks/vote.py b/atr/tasks/vote.py index cfca283..1865c0b 100644 --- a/atr/tasks/vote.py +++ b/atr/tasks/vote.py @@ -15,37 +15,37 @@ # specific language governing permissions and limitations # under the License. +import dataclasses import datetime import logging import os -from dataclasses import dataclass -from datetime import UTC -from typing import Any +from typing import Any, Final import atr.tasks.task as task # Configure detailed logging -logger = logging.getLogger(__name__) -logger.setLevel(logging.DEBUG) +_LOGGER: Final = logging.getLogger(__name__) +_LOGGER.setLevel(logging.DEBUG) # Create file handler for tasks-vote.log -file_handler = logging.FileHandler("tasks-vote.log") -file_handler.setLevel(logging.DEBUG) +_HANDLER: Final = logging.FileHandler("tasks-vote.log") +_HANDLER.setLevel(logging.DEBUG) # Create formatter with detailed information -formatter = logging.Formatter( - "[%(asctime)s.%(msecs)03d] [%(process)d] [%(levelname)s] [%(name)s:%(funcName)s:%(lineno)d] %(message)s", - datefmt="%Y-%m-%d %H:%M:%S", +_HANDLER.setFormatter( + logging.Formatter( + "[%(asctime)s.%(msecs)03d] [%(process)d] [%(levelname)s] [%(name)s:%(funcName)s:%(lineno)d] %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + ) ) -file_handler.setFormatter(formatter) -logger.addHandler(file_handler) +_LOGGER.addHandler(_HANDLER) # Ensure parent loggers don't duplicate messages -logger.propagate = False +_LOGGER.propagate = False -logger.info("Vote module imported") +_LOGGER.info("Vote module imported") -@dataclass +@dataclasses.dataclass class Args: """Arguments for the vote_initiate task.""" @@ -59,10 +59,10 @@ class Args: @staticmethod def from_list(args: list[str]) -> "Args": """Parse task arguments.""" - logger.debug(f"Parsing arguments: {args}") + _LOGGER.debug(f"Parsing arguments: {args}") if len(args) != 6: - logger.error(f"Invalid number of arguments: {len(args)}, expected 6") + _LOGGER.error(f"Invalid number of arguments: {len(args)}, expected 6") raise ValueError("Invalid number of arguments") release_key = args[0] @@ -82,10 +82,10 @@ class Args: ("initiator_id", initiator_id), ]: if not isinstance(arg_value, str): - logger.error(f"{arg_name} must be a string, got {type(arg_value)}") + _LOGGER.error(f"{arg_name} must be a string, got {type(arg_value)}") raise ValueError(f"{arg_name} must be a string") - logger.debug("All argument validations passed") + _LOGGER.debug("All argument validations passed") args_obj = Args( release_key=release_key, @@ -96,20 +96,20 @@ class Args: initiator_id=initiator_id, ) - logger.info(f"Args object created: {args_obj}") + _LOGGER.info(f"Args object created: {args_obj}") return args_obj def initiate(args: list[str]) -> tuple[task.Status, str | None, tuple[Any, ...]]: """Initiate a vote for a release.""" - logger.info(f"Initiating vote with args: {args}") + _LOGGER.info(f"Initiating vote with args: {args}") try: - logger.debug("Delegating to initiate_core function") + _LOGGER.debug("Delegating to initiate_core function") status, error, result = initiate_core(args) - logger.info(f"Vote initiation completed with status: {status}") + _LOGGER.info(f"Vote initiation completed with status: {status}") return status, error, result except Exception as e: - 
logger.exception(f"Error in initiate function: {e}") + _LOGGER.exception(f"Error in initiate function: {e}") return task.FAILED, str(e), tuple() @@ -119,10 +119,10 @@ def initiate_core(args_list: list[str]) -> tuple[task.Status, str | None, tuple[ from atr.db.service import get_release_by_key_sync test_recipients = ["sbp"] - logger.info("Starting initiate_core") + _LOGGER.info("Starting initiate_core") try: - # Configure root logger to also write to our log file - # This ensures logs from mail.py, using the root logger, are captured + # Configure root _LOGGER to also write to our log file + # This ensures logs from mail.py, using the root _LOGGER, are captured root_logger = logging.getLogger() # Check whether our file handler is already added, to avoid duplicates has_our_handler = any( @@ -130,19 +130,19 @@ def initiate_core(args_list: list[str]) -> tuple[task.Status, str | None, tuple[ for h in root_logger.handlers ) if not has_our_handler: - # Add our file handler to the root logger - root_logger.addHandler(file_handler) - logger.info("Added file handler to root logger to capture mail.py logs") + # Add our file handler to the root _LOGGER + root_logger.addHandler(_HANDLER) + _LOGGER.info("Added file handler to root _LOGGER to capture mail.py logs") - logger.debug(f"Parsing arguments: {args_list}") + _LOGGER.debug(f"Parsing arguments: {args_list}") args = Args.from_list(args_list) - logger.info(f"Args parsed successfully: {args}") + _LOGGER.info(f"Args parsed successfully: {args}") # Get the release information release = get_release_by_key_sync(args.release_key) if not release: error_msg = f"Release with key {args.release_key} not found" - logger.error(error_msg) + _LOGGER.error(error_msg) return task.FAILED, error_msg, tuple() # GPG key ID, just for testing the UI @@ -150,7 +150,7 @@ def initiate_core(args_list: list[str]) -> tuple[task.Status, str | None, tuple[ # Calculate vote end date vote_duration_hours = int(args.vote_duration) - vote_start = datetime.datetime.now(UTC) + vote_start = datetime.datetime.now(datetime.UTC) vote_end = vote_start + datetime.timedelta(hours=vote_duration_hours) # Format dates for email @@ -164,16 +164,16 @@ def initiate_core(args_list: list[str]) -> tuple[task.Status, str | None, tuple[ with open(dkim_path) as f: dkim_key = f.read() atr.mail.set_secret_key(dkim_key.strip()) - logger.info("DKIM key loaded and set successfully") + _LOGGER.info("DKIM key loaded and set successfully") except Exception as e: error_msg = f"Failed to load DKIM key: {e}" - logger.error(error_msg) + _LOGGER.error(error_msg) return task.FAILED, error_msg, tuple() # Get PMC and product details if release.pmc is None: error_msg = "Release has no associated PMC" - logger.error(error_msg) + _LOGGER.error(error_msg) return task.FAILED, error_msg, tuple() pmc_name = release.pmc.project_name @@ -216,7 +216,7 @@ Thanks, original_recipient = args.email_to # Only one test recipient is required for now test_recipient = test_recipients[0] + "@apache.org" - logger.info(f"TEMPORARY: Overriding recipient from {original_recipient} to {test_recipient}") + _LOGGER.info(f"TEMPORARY: Overriding recipient from {original_recipient} to {test_recipient}") # Create mail event with test recipient # Use test account instead of actual PMC list @@ -230,7 +230,7 @@ Thanks, # Send the email atr.mail.send(event) - logger.info( + _LOGGER.info( f"Vote email sent successfully to test account {test_recipient} (would have been {original_recipient})" ) @@ -251,5 +251,5 @@ Thanks, ) except Exception as e: - 
logger.exception(f"Error in initiate_core: {e}") + _LOGGER.exception(f"Error in initiate_core: {e}") return task.FAILED, str(e), tuple() diff --git a/atr/templates/candidate-review.html b/atr/templates/candidate-review.html index 2892399..459316c 100644 --- a/atr/templates/candidate-review.html +++ b/atr/templates/candidate-review.html @@ -65,7 +65,7 @@ </div> {% for package in release.packages %} - <table class="table border border-1 mb-4 table-bordered candidate-table"> + <table class="table border border-1 mb-4 table-bordered candidate-table atr-data"> <tr> <th>Name</th> <td> diff --git a/atr/util.py b/atr/util.py index 8675f8f..c5266b6 100644 --- a/atr/util.py +++ b/atr/util.py @@ -82,6 +82,8 @@ def _get_dict_to_list_inner_type_adapter(source_type: Any, key: str) -> pydantic assert (other_fields := {k: v for k, v in fields.items() if k != key}) # noqa: RUF018 model_name = f"{cls.__name__}Inner" + + # Create proper field definitions for create_model inner_model = pydantic.create_model(model_name, **{k: (v.annotation, v) for k, v in other_fields.items()}) # type: ignore return pydantic.TypeAdapter(dict[Annotated[str, key_field], inner_model]) # type: ignore @@ -127,6 +129,13 @@ class DictToList: ) +def validate_as_type(value: Any, t: type[T]) -> T: + """Validate the given value as the given type.""" + if not isinstance(value, t): + raise ValueError(f"Expected {t}, got {type(value)}") + return value + + def unwrap(value: T | None, error_message: str = "unexpected None when unwrapping value") -> T: """ Will unwrap the given value or raise a ValueError if it is None --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@tooling.apache.org For additional commands, e-mail: commits-h...@tooling.apache.org