This is an automated email from the ASF dual-hosted git repository.

sbp pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tooling-trusted-release.git


The following commit(s) were added to refs/heads/main by this push:
     new dc806dc  Remove all uses of cast
dc806dc is described below

commit dc806dc59cbd207b3aa598ff5294a6e6f9d5d7ee
Author: Sean B. Palmer <s...@miscoranda.com>
AuthorDate: Wed Mar 12 15:04:23 2025 +0200

    Remove all uses of cast
---
 atr/db/__init__.py      | 24 ++++++++++++++++++++++++
 atr/db/service.py       | 14 ++++++--------
 atr/routes/candidate.py | 22 ++++++----------------
 atr/routes/download.py  | 14 ++++++--------
 atr/routes/keys.py      |  9 ++-------
 atr/routes/package.py   | 13 +++----------
 atr/routes/project.py   |  8 ++------
 atr/routes/release.py   |  5 +----
 atr/tasks/signature.py  |  6 ++----
 migrations/env.py       |  8 ++++----
 10 files changed, 56 insertions(+), 67 deletions(-)

diff --git a/atr/db/__init__.py b/atr/db/__init__.py
index 82d6e24..63f5e2b 100644
--- a/atr/db/__init__.py
+++ b/atr/db/__init__.py
@@ -17,6 +17,9 @@
 
 import logging
 import os
+from typing import Any
+
+import sqlalchemy.orm as orm
 
 # from alembic import command
 from alembic.config import Config
@@ -103,3 +106,35 @@ def create_sync_db_session() -> Session:
     global _SYNC_ENGINE
     assert _SYNC_ENGINE is not None
     return Session(_SYNC_ENGINE)
+
+
+def eager_load(*entities: Any) -> orm.strategy_options._AbstractLoad:
+    """Return a selectinload() loader option for the given relationship attribute(s).
+
+    Each entity must be an orm.InstrumentedAttribute (e.g. models.Release.pmc).
+    When several are given they are passed through to orm.selectinload(), which
+    treats them as one chained path rather than as independent loads; prefer
+    eager_load2, or separate eager_load calls, to make the intent explicit.
+    """
+    if not entities:
+        raise ValueError("eager_load requires at least one attribute")
+    attributes = [instrumented_attribute(entity) for entity in entities]
+    return orm.selectinload(*attributes)
+
+
+def eager_load2(a: Any, b: Any) -> orm.strategy_options._AbstractLoad:
+    """Return a chained loader option: selectinload a, then selectinload b from it.
+
+    For example, eager_load2(models.Package.release, models.Release.pmc) loads
+    each package's release and that release's PMC.
+    """
+    a = instrumented_attribute(a)
+    b = instrumented_attribute(b)
+    return orm.selectinload(a).selectinload(b)
+
+
+def instrumented_attribute(entity: Any) -> orm.InstrumentedAttribute:
+    """Return entity narrowed to orm.InstrumentedAttribute, or raise ValueError."""
+    if not isinstance(entity, orm.InstrumentedAttribute):
+        raise ValueError(f"Object must be an orm.InstrumentedAttribute, got: {type(entity)}")
+    return entity
diff --git a/atr/db/service.py b/atr/db/service.py
index 16a5116..1e17dd5 100644
--- a/atr/db/service.py
+++ b/atr/db/service.py
@@ -17,15 +17,13 @@
 
 from collections.abc import Sequence
 from contextlib import nullcontext
-from typing import cast
 
 from sqlalchemy import func
 from sqlalchemy.ext.asyncio import AsyncSession
-from sqlalchemy.orm import selectinload
-from sqlalchemy.orm.attributes import InstrumentedAttribute
 from sqlmodel import select
 
-from atr.db.models import PMC, ProductLine, Release, Task
+import atr.db as db
+from atr.db.models import PMC, Release, Task
 
 from . import create_async_db_session
 
@@ -54,8 +52,8 @@ async def get_release_by_key(storage_key: str) -> Release | 
None:
         query = (
             select(Release)
             .where(Release.storage_key == storage_key)
-            .options(selectinload(cast(InstrumentedAttribute[PMC], 
Release.pmc)))
-            .options(selectinload(cast(InstrumentedAttribute[ProductLine], 
Release.product_line)))
+            .options(db.eager_load(Release.pmc))
+            .options(db.eager_load(Release.product_line))
         )
         result = await db_session.execute(query)
         return result.scalar_one_or_none()
@@ -70,8 +68,8 @@ def get_release_by_key_sync(storage_key: str) -> Release | 
None:
         query = (
             select(Release)
             .where(Release.storage_key == storage_key)
-            .options(selectinload(cast(InstrumentedAttribute[PMC], 
Release.pmc)))
-            .options(selectinload(cast(InstrumentedAttribute[ProductLine], 
Release.product_line)))
+            .options(db.eager_load(Release.pmc))
+            .options(db.eager_load(Release.product_line))
         )
         result = session.execute(query)
         return result.scalar_one_or_none()
diff --git a/atr/routes/candidate.py b/atr/routes/candidate.py
index 8b4cc19..674d0a0 100644
--- a/atr/routes/candidate.py
+++ b/atr/routes/candidate.py
@@ -19,10 +19,8 @@
 
 import datetime
 import secrets
-from typing import cast
 
 import quart
-import sqlalchemy.orm as orm
 
 # import sqlalchemy.orm.attributes as attributes
 import sqlmodel
@@ -140,10 +138,8 @@ async def root_candidate_create() -> response.Response | 
str:
 
     # Get PMC objects for all projects the user is a member of
     async with db.create_async_db_session() as db_session:
-        from sqlalchemy.sql.expression import ColumnElement
-
         project_list = web_session.committees + web_session.projects
-        project_name: ColumnElement[str] = cast(ColumnElement[str], 
models.PMC.project_name)
+        project_name = db.instrumented_attribute(models.PMC.project_name)
         statement = 
sqlmodel.select(models.PMC).where(project_name.in_(project_list))
         user_pmcs = (await db_session.execute(statement)).scalars().all()
 
@@ -170,19 +166,13 @@ async def root_candidate_review() -> str:
         # TODO: We don't actually record who uploaded the release candidate
         # We should probably add that information!
         # TODO: This duplicates code in root_package_add
-        release_pmc = 
orm.selectinload(cast(orm.InstrumentedAttribute[models.PMC], 
models.Release.pmc))
-        release_packages = orm.selectinload(
-            cast(orm.InstrumentedAttribute[list[models.Package]], 
models.Release.packages)
-        )
-        package_tasks = release_packages.selectinload(
-            cast(orm.InstrumentedAttribute[list[models.Task]], 
models.Package.tasks)
-        )
-        release_product_line = orm.selectinload(
-            cast(orm.InstrumentedAttribute[models.ProductLine], 
models.Release.product_line)
-        )
         statement = (
             sqlmodel.select(models.Release)
-            .options(release_pmc, release_packages, package_tasks, 
release_product_line)
+            .options(
+                db.eager_load(models.Release.pmc),
+                db.eager_load(models.Release.product_line),
+                db.eager_load2(models.Release.packages, models.Package.tasks),
+            )
             .join(models.PMC)
             .where(models.Release.stage == models.ReleaseStage.CANDIDATE)
         )
diff --git a/atr/routes/download.py b/atr/routes/download.py
index cbf0ee5..f94c7e9 100644
--- a/atr/routes/download.py
+++ b/atr/routes/download.py
@@ -18,12 +18,10 @@
 """download.py"""
 
 import pathlib
-from typing import cast
 
 import aiofiles
 import aiofiles.os
 import quart
-import sqlalchemy.orm as orm
 import sqlmodel
 import werkzeug.wrappers.response as response
 
@@ -48,12 +46,12 @@ async def root_download_artifact(release_key: str, 
artifact_sha3: str) -> respon
 
     async with db.create_async_db_session() as db_session:
         # Find the package
-        package_release = 
orm.selectinload(cast(orm.InstrumentedAttribute[models.Release], 
models.Package.release))
-        release_pmc = 
package_release.selectinload(cast(orm.InstrumentedAttribute[models.PMC], 
models.Release.pmc))
+        # Chain the loads so Release.pmc is reached through Package.release
+        release_pmc = db.eager_load2(models.Package.release, models.Release.pmc)
         package_statement = (
             sqlmodel.select(models.Package)
             .where(models.Package.artifact_sha3 == artifact_sha3, models.Package.release_key == release_key)
             .options(release_pmc)
         )
         result = await db_session.execute(package_statement)
         package = result.scalar_one_or_none()
@@ -94,12 +92,12 @@ async def root_download_signature(release_key: str, 
signature_sha3: str) -> quar
 
     async with db.create_async_db_session() as db_session:
         # Find the package that has this signature
-        package_release = 
orm.selectinload(cast(orm.InstrumentedAttribute[models.Release], 
models.Package.release))
-        release_pmc = 
package_release.selectinload(cast(orm.InstrumentedAttribute[models.PMC], 
models.Release.pmc))
+        # Chain the loads so Release.pmc is reached through Package.release
+        release_pmc = db.eager_load2(models.Package.release, models.Release.pmc)
         package_statement = (
             sqlmodel.select(models.Package)
             .where(models.Package.signature_sha3 == signature_sha3, models.Package.release_key == release_key)
             .options(release_pmc)
         )
         result = await db_session.execute(package_statement)
         package = result.scalar_one_or_none()
diff --git a/atr/routes/keys.py b/atr/routes/keys.py
index 3765302..b7d967b 100644
--- a/atr/routes/keys.py
+++ b/atr/routes/keys.py
@@ -26,12 +26,10 @@ import pprint
 import shutil
 import tempfile
 from collections.abc import AsyncGenerator, Sequence
-from typing import cast
 
 import gnupg
 import quart
 import sqlalchemy.ext.asyncio
-import sqlalchemy.orm as orm
 import sqlmodel
 import werkzeug.wrappers.response as response
 
@@ -187,10 +185,8 @@ async def root_keys_add() -> str:
 
     # Get PMC objects for all projects the user is a member of
     async with db.create_async_db_session() as db_session:
-        from sqlalchemy.sql.expression import ColumnElement
-
         project_list = web_session.committees + web_session.projects
-        project_name = cast(ColumnElement[str], models.PMC.project_name)
+        project_name = db.instrumented_attribute(models.PMC.project_name)
         pmc_statement = 
sqlmodel.select(models.PMC).where(project_name.in_(project_list))
         user_pmcs = (await db_session.execute(pmc_statement)).scalars().all()
 
@@ -257,10 +253,9 @@ async def root_keys_review() -> str:
 
     # Get all existing keys for the user
     async with db.create_async_db_session() as db_session:
-        pmcs_loader = 
orm.selectinload(cast(orm.InstrumentedAttribute[list[models.PMC]], 
models.PublicSigningKey.pmcs))
         psk_statement = (
             sqlmodel.select(models.PublicSigningKey)
-            .options(pmcs_loader)
+            .options(db.eager_load(models.PublicSigningKey.pmcs))
             .where(models.PublicSigningKey.apache_uid == web_session.uid)
         )
         user_keys = (await db_session.execute(psk_statement)).scalars().all()
diff --git a/atr/routes/package.py b/atr/routes/package.py
index 836111b..33abfff 100644
--- a/atr/routes/package.py
+++ b/atr/routes/package.py
@@ -25,13 +25,11 @@ import logging.handlers
 import pathlib
 import secrets
 from collections.abc import Sequence
-from typing import cast
 
 import aiofiles
 import aiofiles.os
 import quart
 import sqlalchemy.ext.asyncio
-import sqlalchemy.orm as orm
 import sqlmodel
 import werkzeug.datastructures as datastructures
 import werkzeug.wrappers.response as response
@@ -212,12 +210,9 @@ async def package_data_get(
     #     raise FlashError("Package has no associated release")
     # if Release.pmc is None:
     #     raise FlashError("Release has no associated PMC")
-
-    pkg_release = cast(orm.InstrumentedAttribute[models.Release], 
models.Package.release)
-    rel_pmc = cast(orm.InstrumentedAttribute[models.PMC], models.Release.pmc)
     statement = (
         sqlmodel.select(models.Package)
-        .options(orm.selectinload(pkg_release).selectinload(rel_pmc))
+        .options(db.eager_load2(models.Package.release, models.Release.pmc))
         .where(models.Package.artifact_sha3 == artifact_sha3)
     )
     result = await db_session.execute(statement)
@@ -378,10 +373,8 @@ async def root_package_add() -> response.Response | str:
     # Get all releases where the user is a PMC member or committer of the 
associated PMC
     async with db.create_async_db_session() as db_session:
         # TODO: This duplicates code in root_candidate_review
-        release_pmc = 
orm.selectinload(cast(orm.InstrumentedAttribute[models.PMC], 
models.Release.pmc))
-        release_product_line = orm.selectinload(
-            cast(orm.InstrumentedAttribute[models.ProductLine], 
models.Release.product_line)
-        )
+        release_pmc = db.eager_load(models.Release.pmc)
+        release_product_line = db.eager_load(models.Release.product_line)
         statement = (
             sqlmodel.select(models.Release)
             .options(release_pmc, release_product_line)
diff --git a/atr/routes/project.py b/atr/routes/project.py
index 1fc53b3..95eedab 100644
--- a/atr/routes/project.py
+++ b/atr/routes/project.py
@@ -18,11 +18,9 @@
 """project.py"""
 
 import http.client
-from typing import cast
 
 import quart
 import quart_wtf
-import sqlalchemy.orm as orm
 import sqlmodel
 import werkzeug.wrappers.response as response
 import wtforms
@@ -51,10 +49,8 @@ async def root_project_view(project_name: str) -> str:
             sqlmodel.select(models.PMC)
             .where(models.PMC.project_name == project_name)
             .options(
-                orm.selectinload(
-                    
cast(orm.attributes.InstrumentedAttribute[models.PublicSigningKey], 
models.PMC.public_signing_keys)
-                ),
-                
orm.selectinload(cast(orm.attributes.InstrumentedAttribute[models.VotePolicy], 
models.PMC.vote_policy)),
+                db.eager_load(models.PMC.public_signing_keys),
+                db.eager_load(models.PMC.vote_policy),
             )
         )
 
diff --git a/atr/routes/release.py b/atr/routes/release.py
index 5daa8ae..c06848d 100644
--- a/atr/routes/release.py
+++ b/atr/routes/release.py
@@ -20,11 +20,9 @@
 import logging
 import logging.handlers
 import pathlib
-from typing import cast
 
 import quart
 import sqlalchemy.ext.asyncio
-import sqlalchemy.orm as orm
 import sqlmodel
 import werkzeug.wrappers.response as response
 
@@ -52,10 +50,9 @@ async def release_delete_validate(
     # if Release.pmc is None:
     #     raise FlashError("Release has no associated PMC")
 
-    rel_pmc = cast(orm.InstrumentedAttribute[models.PMC], models.Release.pmc)
     statement = (
         sqlmodel.select(models.Release)
-        .options(orm.selectinload(rel_pmc))
+        .options(db.eager_load(models.Release.pmc))
         .where(models.Release.storage_key == release_key)
     )
     result = await db_session.execute(statement)
diff --git a/atr/tasks/signature.py b/atr/tasks/signature.py
index 74886fe..7211772 100644
--- a/atr/tasks/signature.py
+++ b/atr/tasks/signature.py
@@ -20,7 +20,7 @@ import logging
 import shutil
 import tempfile
 from collections.abc import Generator
-from typing import Any, BinaryIO, Final, cast
+from typing import Any, BinaryIO, Final
 
 import gnupg
 import sqlalchemy.sql as sql
@@ -46,13 +46,11 @@ def _check_core(pmc_name: str, artifact_path: str, 
signature_path: str) -> dict[
     # Query only the signing keys associated with this PMC
     # TODO: Rename create_sync_db_session to create_session_sync
     with db.create_sync_db_session() as session:
-        from sqlalchemy.sql.expression import ColumnElement
-
         statement = (
             sql.select(models.PublicSigningKey)
             .join(models.PMCKeyLink)
             .join(models.PMC)
-            .where(cast(ColumnElement[bool], models.PMC.project_name == 
pmc_name))
+            .where(db.instrumented_attribute(models.PMC.project_name) == 
pmc_name)
         )
         result = session.execute(statement)
         public_keys = [key.ascii_armored_key for key in result.scalars().all()]
diff --git a/migrations/env.py b/migrations/env.py
index 7ea3ab1..8b8f586 100644
--- a/migrations/env.py
+++ b/migrations/env.py
@@ -1,5 +1,4 @@
 from logging.config import fileConfig
-from typing import Any, cast
 
 from alembic import context
 from sqlalchemy import engine_from_config, pool
@@ -43,10 +42,11 @@ def run_migrations_online() -> None:
         raise RuntimeError("sqlalchemy.url is not set")
 
     # Create synchronous engine for migrations
-    configuration = config.get_section(config.config_ini_section)
-    if configuration is None:
+    section = config.get_section(config.config_ini_section)
+    if section is None:
         configuration = {}
-    configuration = cast(dict[str, Any], configuration)
+    else:
+        configuration = dict(section)
     configuration["sqlalchemy.url"] = sync_url
 
     connectable = engine_from_config(


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@tooling.apache.org
For additional commands, e-mail: commits-h...@tooling.apache.org

Reply via email to