This is an automated email from the ASF dual-hosted git repository. sbp pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/tooling-trusted-release.git
The following commit(s) were added to refs/heads/main by this push: new dc806dc Remove all uses of cast dc806dc is described below commit dc806dc59cbd207b3aa598ff5294a6e6f9d5d7ee Author: Sean B. Palmer <s...@miscoranda.com> AuthorDate: Wed Mar 12 15:04:23 2025 +0200 Remove all uses of cast --- atr/db/__init__.py | 24 ++++++++++++++++++++++++ atr/db/service.py | 14 ++++++-------- atr/routes/candidate.py | 22 ++++++---------------- atr/routes/download.py | 14 ++++++-------- atr/routes/keys.py | 9 ++------- atr/routes/package.py | 13 +++---------- atr/routes/project.py | 8 ++------ atr/routes/release.py | 5 +---- atr/tasks/signature.py | 6 ++---- migrations/env.py | 8 ++++---- 10 files changed, 56 insertions(+), 67 deletions(-) diff --git a/atr/db/__init__.py b/atr/db/__init__.py index 82d6e24..63f5e2b 100644 --- a/atr/db/__init__.py +++ b/atr/db/__init__.py @@ -17,6 +17,9 @@ import logging import os +from typing import Any + +import sqlalchemy.orm as orm # from alembic import command from alembic.config import Config @@ -103,3 +106,24 @@ def create_sync_db_session() -> Session: global _SYNC_ENGINE assert _SYNC_ENGINE is not None return Session(_SYNC_ENGINE) + + +def eager_load(*entities: Any) -> orm.strategy_options._AbstractLoad: + """Eagerly load the given entities from the query.""" + for entity in entities: + entity = instrumented_attribute(entity) + return orm.selectinload(*entities) + + +def eager_load2(a: Any, b: Any) -> orm.strategy_options._AbstractLoad: + """Eagerly load the given entities from the query.""" + a = instrumented_attribute(a) + b = instrumented_attribute(b) + return orm.selectinload(a).selectinload(b) + + +def instrumented_attribute(entity: Any) -> orm.InstrumentedAttribute: + """Check whether the object is an InstrumentedAttribute.""" + if not isinstance(entity, orm.InstrumentedAttribute): + raise ValueError(f"Object must be an orm.InstrumentedAttribute, got: {type(entity)}") + return entity diff --git a/atr/db/service.py b/atr/db/service.py index 16a5116..1e17dd5 100644 --- a/atr/db/service.py +++ b/atr/db/service.py @@ -17,15 +17,13 @@ from collections.abc import Sequence from contextlib import nullcontext -from typing import cast from sqlalchemy import func from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy.orm import selectinload -from sqlalchemy.orm.attributes import InstrumentedAttribute from sqlmodel import select -from atr.db.models import PMC, ProductLine, Release, Task +import atr.db as db +from atr.db.models import PMC, Release, Task from . import create_async_db_session @@ -54,8 +52,8 @@ async def get_release_by_key(storage_key: str) -> Release | None: query = ( select(Release) .where(Release.storage_key == storage_key) - .options(selectinload(cast(InstrumentedAttribute[PMC], Release.pmc))) - .options(selectinload(cast(InstrumentedAttribute[ProductLine], Release.product_line))) + .options(db.eager_load(Release.pmc)) + .options(db.eager_load(Release.product_line)) ) result = await db_session.execute(query) return result.scalar_one_or_none() @@ -70,8 +68,8 @@ def get_release_by_key_sync(storage_key: str) -> Release | None: query = ( select(Release) .where(Release.storage_key == storage_key) - .options(selectinload(cast(InstrumentedAttribute[PMC], Release.pmc))) - .options(selectinload(cast(InstrumentedAttribute[ProductLine], Release.product_line))) + .options(db.eager_load(Release.pmc)) + .options(db.eager_load(Release.product_line)) ) result = session.execute(query) return result.scalar_one_or_none() diff --git a/atr/routes/candidate.py b/atr/routes/candidate.py index 8b4cc19..674d0a0 100644 --- a/atr/routes/candidate.py +++ b/atr/routes/candidate.py @@ -19,10 +19,8 @@ import datetime import secrets -from typing import cast import quart -import sqlalchemy.orm as orm # import sqlalchemy.orm.attributes as attributes import sqlmodel @@ -140,10 +138,8 @@ async def root_candidate_create() -> response.Response | str: # Get PMC objects for all projects the user is a member of async with db.create_async_db_session() as db_session: - from sqlalchemy.sql.expression import ColumnElement - project_list = web_session.committees + web_session.projects - project_name: ColumnElement[str] = cast(ColumnElement[str], models.PMC.project_name) + project_name = db.instrumented_attribute(models.PMC.project_name) statement = sqlmodel.select(models.PMC).where(project_name.in_(project_list)) user_pmcs = (await db_session.execute(statement)).scalars().all() @@ -170,19 +166,13 @@ async def root_candidate_review() -> str: # TODO: We don't actually record who uploaded the release candidate # We should probably add that information! # TODO: This duplicates code in root_package_add - release_pmc = orm.selectinload(cast(orm.InstrumentedAttribute[models.PMC], models.Release.pmc)) - release_packages = orm.selectinload( - cast(orm.InstrumentedAttribute[list[models.Package]], models.Release.packages) - ) - package_tasks = release_packages.selectinload( - cast(orm.InstrumentedAttribute[list[models.Task]], models.Package.tasks) - ) - release_product_line = orm.selectinload( - cast(orm.InstrumentedAttribute[models.ProductLine], models.Release.product_line) - ) statement = ( sqlmodel.select(models.Release) - .options(release_pmc, release_packages, package_tasks, release_product_line) + .options( + db.eager_load(models.Release.pmc), + db.eager_load(models.Release.product_line), + db.eager_load2(models.Release.packages, models.Package.tasks), + ) .join(models.PMC) .where(models.Release.stage == models.ReleaseStage.CANDIDATE) ) diff --git a/atr/routes/download.py b/atr/routes/download.py index cbf0ee5..f94c7e9 100644 --- a/atr/routes/download.py +++ b/atr/routes/download.py @@ -18,12 +18,10 @@ """download.py""" import pathlib -from typing import cast import aiofiles import aiofiles.os import quart -import sqlalchemy.orm as orm import sqlmodel import werkzeug.wrappers.response as response @@ -48,12 +46,12 @@ async def root_download_artifact(release_key: str, artifact_sha3: str) -> respon async with db.create_async_db_session() as db_session: # Find the package - package_release = orm.selectinload(cast(orm.InstrumentedAttribute[models.Release], models.Package.release)) - release_pmc = package_release.selectinload(cast(orm.InstrumentedAttribute[models.PMC], models.Release.pmc)) + package_release = db.eager_load(models.Package.release) + release_pmc = db.eager_load(models.Release.pmc) package_statement = ( sqlmodel.select(models.Package) .where(models.Package.artifact_sha3 == artifact_sha3, models.Package.release_key == release_key) - .options(release_pmc) + .options(package_release, release_pmc) ) result = await db_session.execute(package_statement) package = result.scalar_one_or_none() @@ -94,12 +92,12 @@ async def root_download_signature(release_key: str, signature_sha3: str) -> quar async with db.create_async_db_session() as db_session: # Find the package that has this signature - package_release = orm.selectinload(cast(orm.InstrumentedAttribute[models.Release], models.Package.release)) - release_pmc = package_release.selectinload(cast(orm.InstrumentedAttribute[models.PMC], models.Release.pmc)) + package_release = db.eager_load(models.Package.release) + release_pmc = db.eager_load(models.Release.pmc) package_statement = ( sqlmodel.select(models.Package) .where(models.Package.signature_sha3 == signature_sha3, models.Package.release_key == release_key) - .options(release_pmc) + .options(package_release, release_pmc) ) result = await db_session.execute(package_statement) package = result.scalar_one_or_none() diff --git a/atr/routes/keys.py b/atr/routes/keys.py index 3765302..b7d967b 100644 --- a/atr/routes/keys.py +++ b/atr/routes/keys.py @@ -26,12 +26,10 @@ import pprint import shutil import tempfile from collections.abc import AsyncGenerator, Sequence -from typing import cast import gnupg import quart import sqlalchemy.ext.asyncio -import sqlalchemy.orm as orm import sqlmodel import werkzeug.wrappers.response as response @@ -187,10 +185,8 @@ async def root_keys_add() -> str: # Get PMC objects for all projects the user is a member of async with db.create_async_db_session() as db_session: - from sqlalchemy.sql.expression import ColumnElement - project_list = web_session.committees + web_session.projects - project_name = cast(ColumnElement[str], models.PMC.project_name) + project_name = db.instrumented_attribute(models.PMC.project_name) pmc_statement = sqlmodel.select(models.PMC).where(project_name.in_(project_list)) user_pmcs = (await db_session.execute(pmc_statement)).scalars().all() @@ -257,10 +253,9 @@ async def root_keys_review() -> str: # Get all existing keys for the user async with db.create_async_db_session() as db_session: - pmcs_loader = orm.selectinload(cast(orm.InstrumentedAttribute[list[models.PMC]], models.PublicSigningKey.pmcs)) psk_statement = ( sqlmodel.select(models.PublicSigningKey) - .options(pmcs_loader) + .options(db.eager_load(models.PublicSigningKey.pmcs)) .where(models.PublicSigningKey.apache_uid == web_session.uid) ) user_keys = (await db_session.execute(psk_statement)).scalars().all() diff --git a/atr/routes/package.py b/atr/routes/package.py index 836111b..33abfff 100644 --- a/atr/routes/package.py +++ b/atr/routes/package.py @@ -25,13 +25,11 @@ import logging.handlers import pathlib import secrets from collections.abc import Sequence -from typing import cast import aiofiles import aiofiles.os import quart import sqlalchemy.ext.asyncio -import sqlalchemy.orm as orm import sqlmodel import werkzeug.datastructures as datastructures import werkzeug.wrappers.response as response @@ -212,12 +210,9 @@ async def package_data_get( # raise FlashError("Package has no associated release") # if Release.pmc is None: # raise FlashError("Release has no associated PMC") - - pkg_release = cast(orm.InstrumentedAttribute[models.Release], models.Package.release) - rel_pmc = cast(orm.InstrumentedAttribute[models.PMC], models.Release.pmc) statement = ( sqlmodel.select(models.Package) - .options(orm.selectinload(pkg_release).selectinload(rel_pmc)) + .options(db.eager_load2(models.Package.release, models.Release.pmc)) .where(models.Package.artifact_sha3 == artifact_sha3) ) result = await db_session.execute(statement) @@ -378,10 +373,8 @@ async def root_package_add() -> response.Response | str: # Get all releases where the user is a PMC member or committer of the associated PMC async with db.create_async_db_session() as db_session: # TODO: This duplicates code in root_candidate_review - release_pmc = orm.selectinload(cast(orm.InstrumentedAttribute[models.PMC], models.Release.pmc)) - release_product_line = orm.selectinload( - cast(orm.InstrumentedAttribute[models.ProductLine], models.Release.product_line) - ) + release_pmc = db.eager_load(models.Release.pmc) + release_product_line = db.eager_load(models.Release.product_line) statement = ( sqlmodel.select(models.Release) .options(release_pmc, release_product_line) diff --git a/atr/routes/project.py b/atr/routes/project.py index 1fc53b3..95eedab 100644 --- a/atr/routes/project.py +++ b/atr/routes/project.py @@ -18,11 +18,9 @@ """project.py""" import http.client -from typing import cast import quart import quart_wtf -import sqlalchemy.orm as orm import sqlmodel import werkzeug.wrappers.response as response import wtforms @@ -51,10 +49,8 @@ async def root_project_view(project_name: str) -> str: sqlmodel.select(models.PMC) .where(models.PMC.project_name == project_name) .options( - orm.selectinload( - cast(orm.attributes.InstrumentedAttribute[models.PublicSigningKey], models.PMC.public_signing_keys) - ), - orm.selectinload(cast(orm.attributes.InstrumentedAttribute[models.VotePolicy], models.PMC.vote_policy)), + db.eager_load(models.PMC.public_signing_keys), + db.eager_load(models.PMC.vote_policy), ) ) diff --git a/atr/routes/release.py b/atr/routes/release.py index 5daa8ae..c06848d 100644 --- a/atr/routes/release.py +++ b/atr/routes/release.py @@ -20,11 +20,9 @@ import logging import logging.handlers import pathlib -from typing import cast import quart import sqlalchemy.ext.asyncio -import sqlalchemy.orm as orm import sqlmodel import werkzeug.wrappers.response as response @@ -52,10 +50,9 @@ async def release_delete_validate( # if Release.pmc is None: # raise FlashError("Release has no associated PMC") - rel_pmc = cast(orm.InstrumentedAttribute[models.PMC], models.Release.pmc) statement = ( sqlmodel.select(models.Release) - .options(orm.selectinload(rel_pmc)) + .options(db.eager_load(models.Release.pmc)) .where(models.Release.storage_key == release_key) ) result = await db_session.execute(statement) diff --git a/atr/tasks/signature.py b/atr/tasks/signature.py index 74886fe..7211772 100644 --- a/atr/tasks/signature.py +++ b/atr/tasks/signature.py @@ -20,7 +20,7 @@ import logging import shutil import tempfile from collections.abc import Generator -from typing import Any, BinaryIO, Final, cast +from typing import Any, BinaryIO, Final import gnupg import sqlalchemy.sql as sql @@ -46,13 +46,11 @@ def _check_core(pmc_name: str, artifact_path: str, signature_path: str) -> dict[ # Query only the signing keys associated with this PMC # TODO: Rename create_sync_db_session to create_session_sync with db.create_sync_db_session() as session: - from sqlalchemy.sql.expression import ColumnElement - statement = ( sql.select(models.PublicSigningKey) .join(models.PMCKeyLink) .join(models.PMC) - .where(cast(ColumnElement[bool], models.PMC.project_name == pmc_name)) + .where(db.instrumented_attribute(models.PMC.project_name) == pmc_name) ) result = session.execute(statement) public_keys = [key.ascii_armored_key for key in result.scalars().all()] diff --git a/migrations/env.py b/migrations/env.py index 7ea3ab1..8b8f586 100644 --- a/migrations/env.py +++ b/migrations/env.py @@ -1,5 +1,4 @@ from logging.config import fileConfig -from typing import Any, cast from alembic import context from sqlalchemy import engine_from_config, pool @@ -43,10 +42,11 @@ def run_migrations_online() -> None: raise RuntimeError("sqlalchemy.url is not set") # Create synchronous engine for migrations - configuration = config.get_section(config.config_ini_section) - if configuration is None: + section = config.get_section(config.config_ini_section) + if section is None: configuration = {} - configuration = cast(dict[str, Any], configuration) + else: + configuration = dict(section) configuration["sqlalchemy.url"] = sync_url connectable = engine_from_config( --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@tooling.apache.org For additional commands, e-mail: commits-h...@tooling.apache.org