This is an automated email from the ASF dual-hosted git repository.

sbp pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tooling-trusted-release.git


The following commit(s) were added to refs/heads/main by this push:
     new 0ef5011  Ensure promotion integrity and improve logging
0ef5011 is described below

commit 0ef501187610b9e74a0cb61c5f6cee08ca3096df
Author: Sean B. Palmer <[email protected]>
AuthorDate: Fri May 16 17:13:23 2025 +0100

    Ensure promotion integrity and improve logging
---
 atr/db/__init__.py                                 | 72 +++++++++++++++++++---
 atr/db/models.py                                   | 34 ++++++++++
 atr/revision.py                                    |  8 ++-
 atr/routes/__init__.py                             |  3 +
 atr/routes/announce.py                             | 52 ++++++++++++----
 atr/routes/voting.py                               | 49 ++++++++++-----
 migrations/env.py                                  |  5 +-
 ....15_32c59be6.py => 0001_2025.05.15_1d3ee5a0.py} |  8 ++-
 8 files changed, 185 insertions(+), 46 deletions(-)

diff --git a/atr/db/__init__.py b/atr/db/__init__.py
index 7cdecc7..4c643b3 100644
--- a/atr/db/__init__.py
+++ b/atr/db/__init__.py
@@ -17,6 +17,7 @@
 
 from __future__ import annotations
 
+import contextlib
 import logging
 import os
 from typing import TYPE_CHECKING, Any, Final, Generic, TypeGuard, TypeVar
@@ -37,12 +38,13 @@ import atr.util as util
 
 if TYPE_CHECKING:
     import datetime
-    from collections.abc import Sequence
+    from collections.abc import Iterator, Sequence
 
     import asfquart.base as base
 
 _LOGGER: Final = logging.getLogger(__name__)
 
+global_log_query: bool = False
 _global_atr_engine: sqlalchemy.ext.asyncio.AsyncEngine | None = None
 _global_atr_sessionmaker: sqlalchemy.ext.asyncio.async_sessionmaker | None = 
None
 
@@ -86,18 +88,30 @@ class Query(Generic[T]):
         self.query = self.query.order_by(*args, **kwargs)
         return self
 
-    async def get(self) -> T | None:
+    def log_query(self, method_name: str, log_query: bool) -> None:
+        if not (self.session.log_queries or global_log_query or log_query):
+            return
+        try:
+            compiled_query = self.query.compile(self.session.bind, 
compile_kwargs={"literal_binds": True})
+            _LOGGER.info(f"Executing query ({method_name}): {compiled_query}")
+        except Exception as e:
+            _LOGGER.error(f"Error compiling query for logging ({method_name}): 
{e}")
+
+    async def get(self, log_query: bool = False) -> T | None:
+        self.log_query("get", log_query)
         result = await self.session.execute(self.query)
         return result.scalar_one_or_none()
 
-    async def demand(self, error: Exception) -> T:
+    async def demand(self, error: Exception, log_query: bool = False) -> T:
+        self.log_query("demand", log_query)
         result = await self.session.execute(self.query)
         item = result.scalar_one_or_none()
         if item is None:
             raise error
         return item
 
-    async def all(self) -> Sequence[T]:
+    async def all(self, log_query: bool = False) -> Sequence[T]:
+        self.log_query("all", log_query)
         result = await self.session.execute(self.query)
         return result.scalars().all()
 
@@ -106,6 +120,14 @@ class Query(Generic[T]):
 
 
 class Session(sqlalchemy.ext.asyncio.AsyncSession):
+    def __init__(self, *args: Any, **kwargs: Any) -> None:
+        explicit_value_passed_by_sessionmaker = kwargs.pop("log_queries", None)
+        super().__init__(*args, **kwargs)
+
+        self.log_queries: bool = global_log_query
+        if explicit_value_passed_by_sessionmaker is not None:
+            self.log_queries = explicit_value_passed_by_sessionmaker
+
     # TODO: Need to type all of these arguments correctly
 
     def check_result(
@@ -191,6 +213,17 @@ class Session(sqlalchemy.ext.asyncio.AsyncSession):
 
         return Query(self, query)
 
+    async def execute_query(self, query: sqlalchemy.sql.expression.Executable) 
-> sqlalchemy.engine.Result:
+        if (self.log_queries or global_log_query) and isinstance(query, 
sqlalchemy.sql.expression.Select):
+            try:
+                dialect = self.bind.dialect if self.bind else 
sqlalchemy.dialects.sqlite.dialect()
+                compiled_query = query.compile(dialect=dialect, 
compile_kwargs={"literal_binds": True})
+                _LOGGER.info(f"Executing query (execute_query): 
{compiled_query}")
+            except Exception as e:
+                _LOGGER.error(f"Error compiling query for logging: {e}")
+        execution_result: sqlalchemy.engine.Result = await self.execute(query)
+        return execution_result
+
     async def ns_text_del(self, ns: str, key: str, commit: bool = True) -> 
None:
         stmt = sql.delete(models.TextValue).where(
             models.validate_instrumented_attribute(models.TextValue.ns) == ns,
@@ -320,6 +353,7 @@ class Session(sqlalchemy.ext.asyncio.AsyncSession):
         sboms: Opt[list[str]] = NOT_SET,
         release_policy_id: Opt[int] = NOT_SET,
         votes: Opt[list[models.VoteEntry]] = NOT_SET,
+        latest_revision_number: Opt[str | None] = NOT_SET,
         _project: bool = False,
         _release_policy: bool = False,
         _committee: bool = False,
@@ -348,6 +382,11 @@ class Session(sqlalchemy.ext.asyncio.AsyncSession):
             query = query.where(models.Release.release_policy_id == 
release_policy_id)
         if is_defined(votes):
             query = query.where(models.Release.votes == votes)
+        if is_defined(latest_revision_number):
+            # Must define the subquery explicitly, mirroring the 
column_property
+            # In other words, this doesn't work:
+            # query = query.where(models.Release.latest_revision_number == 
latest_revision_number)
+            query = query.where(models.latest_revision_number_query() == 
latest_revision_number)
 
         if _project:
             query = query.options(select_in_load(models.Release.project))
@@ -538,6 +577,7 @@ async def create_async_engine(app_config: 
type[config.AppConfig]) -> sqlalchemy.
         await conn.execute(sql.text("PRAGMA cache_size=-64000"))
         await conn.execute(sql.text("PRAGMA foreign_keys=ON"))
         await conn.execute(sql.text("PRAGMA busy_timeout=5000"))
+        await conn.execute(sql.text("PRAGMA strict=ON"))
 
     return engine
 
@@ -622,6 +662,18 @@ def is_undefined(v: object | NotSet) -> TypeGuard[NotSet]:
     return isinstance(v, NotSet)
 
 
+@contextlib.contextmanager
+def log_queries() -> Iterator[None]:
+    """A context manager to temporarily enable global query logging."""
+    global global_log_query
+    original_global_log_query_state = global_log_query
+    global_log_query = True
+    try:
+        yield
+    finally:
+        global_log_query = original_global_log_query_state
+
+
 # async def recent_tasks(data: Session, release_name: str, file_path: str, 
modified: int) -> dict[str, models.Task]:
 #     """Get the most recent task for each task type for a specific file."""
 #     tasks = await data.task(
@@ -664,9 +716,8 @@ def select_in_load_nested(parent: Any, *descendants: Any) 
-> orm.strategy_option
     return result
 
 
-def session() -> Session:
+def session(log_queries: bool | None = None) -> Session:
     """Create a new asynchronous database session."""
-
     # FIXME: occasionally you see this in the console output
     # <sys>:0: SAWarning: The garbage collector is trying to clean up 
non-checked-in connection <AdaptedConnection
     # <Connection(Thread-291, started daemon 138838634661440)>>, which will be 
dropped, as it cannot be safely
@@ -680,10 +731,15 @@ def session() -> Session:
 
     # from FastAPI documentation: 
https://fastapi-users.github.io/fastapi-users/latest/configuration/databases/sqlalchemy/
 
+    global _global_atr_sessionmaker
     if _global_atr_sessionmaker is None:
-        raise RuntimeError("database not initialized")
+        raise RuntimeError("Call db.init_database or 
db.init_database_for_worker first, before calling db.session")
+
+    if log_queries is not None:
+        session_instance = 
util.validate_as_type(_global_atr_sessionmaker(log_queries=log_queries), 
Session)
     else:
-        return util.validate_as_type(_global_atr_sessionmaker(), Session)
+        session_instance = util.validate_as_type(_global_atr_sessionmaker(), 
Session)
+    return session_instance
 
 
 async def shutdown_database() -> None:
diff --git a/atr/db/models.py b/atr/db/models.py
index 7feb2ce..6414abb 100644
--- a/atr/db/models.py
+++ b/atr/db/models.py
@@ -28,6 +28,7 @@ import pydantic
 import sqlalchemy
 import sqlalchemy.event as event
 import sqlalchemy.orm as orm
+import sqlalchemy.sql.expression as expression
 import sqlmodel
 
 import atr.schema as schema
@@ -342,6 +343,11 @@ class Revision(sqlmodel.SQLModel, table=True):
 
     description: str | None = sqlmodel.Field(default=None)
 
+    __table_args__ = (
+        sqlmodel.UniqueConstraint("release_name", "seq", 
name="uq_revision_release_seq"),
+        sqlmodel.UniqueConstraint("release_name", "number", 
name="uq_revision_release_number"),
+    )
+
 
 @event.listens_for(Revision, "before_insert")
 def populate_revision_sequence_and_name(
@@ -381,6 +387,7 @@ def populate_revision_sequence_and_name(
         # Do NOT set revision.parent directly here
 
     # Recalculate the Revision.name
+    # This field has a unique constraint, which eliminates the potential for 
race conditions
     revision.name = revision_name(revision.release_name, revision.number)
 
 
@@ -558,6 +565,19 @@ class Release(sqlmodel.SQLModel, table=True):
             raise ValueError("Latest revision number is not a str or None")
         return number
 
+    # NOTE: This does not work
+    # But if we set it with Release.latest_revision_number_query = ..., it might work
+    # Not clear that we'd want to do that, though
+    # @property
+    # def latest_revision_number_query(self) -> expression.ScalarSelect[str]:
+    #     return (
+    #         sqlmodel.select(validate_instrumented_attribute(Revision.number))
+    #         .where(validate_instrumented_attribute(Revision.release_name) == 
Release.name)
+    #         .order_by(validate_instrumented_attribute(Revision.seq).desc())
+    #         .limit(1)
+    #         .scalar_subquery()
+    #     )
+
 
 # https://github.com/fastapi/sqlmodel/issues/240#issuecomment-2074161775
 Release._latest_revision_number = orm.column_property(
@@ -570,6 +590,20 @@ Release._latest_revision_number = orm.column_property(
 )
 
 
+def latest_revision_number_query(release_name: str | None = None) -> 
expression.ScalarSelect[str]:
+    if release_name is None:
+        query_release_name = Release.name
+    else:
+        query_release_name = release_name
+    return (
+        sqlmodel.select(validate_instrumented_attribute(Revision.number))
+        .where(validate_instrumented_attribute(Revision.release_name) == 
query_release_name)
+        .order_by(validate_instrumented_attribute(Revision.seq).desc())
+        .limit(1)
+        .scalar_subquery()
+    )
+
+
 class SSHKey(sqlmodel.SQLModel, table=True):
     fingerprint: str = sqlmodel.Field(primary_key=True)
     key: str
diff --git a/atr/revision.py b/atr/revision.py
index e733d6d..54ac770 100644
--- a/atr/revision.py
+++ b/atr/revision.py
@@ -69,6 +69,7 @@ async def create_and_manage(
             description=description,
         )
         data.add(new_revision)
+        # TODO: Add a retry loop here in case of simultaneous creation of 
revisions?
         await data.commit()
 
         # After commit, new_revision has its .name, .seq, and .number 
populated by the listener
@@ -130,15 +131,16 @@ async def create_and_manage(
 
 async def latest_info(project_name: str, version_name: str) -> tuple[str, str, 
datetime.datetime] | None:
     """Get the name, editor, and timestamp of the latest revision."""
+    release_name = models.release_name(project_name, version_name)
     async with db.session() as data:
         # TODO: No need to get release here
         # Just use maximum seq from revisions
-        release = await data.release(name=models.release_name(project_name, 
version_name), _project=True).demand(
-            RuntimeError("Release does not exist")
+        release = await data.release(name=release_name, _project=True).demand(
+            RuntimeError(f"Release {release_name} does not exist")
         )
         if release.latest_revision_number is None:
             return None
-        revision = await data.revision(release_name=release.name, 
number=release.latest_revision_number).get()
+        revision = await data.revision(release_name=release_name, 
number=release.latest_revision_number).get()
         if not revision:
             return None
     return revision.number, revision.asfuid, revision.created
diff --git a/atr/routes/__init__.py b/atr/routes/__init__.py
index 9d0d1a1..55f9656 100644
--- a/atr/routes/__init__.py
+++ b/atr/routes/__init__.py
@@ -211,6 +211,7 @@ class CommitterSession:
         project_name: str,
         version_name: str,
         phase: models.ReleasePhase | db.NotSet | None = db.NOT_SET,
+        latest_revision_number: str | db.NotSet | None = db.NOT_SET,
         data: db.Session | None = None,
         with_committee: bool = False,
         with_project: bool = True,
@@ -231,6 +232,7 @@ class CommitterSession:
                 release = await data.release(
                     name=release_name,
                     phase=phase_value,
+                    latest_revision_number=latest_revision_number,
                     _committee=with_committee,
                     _project=with_project,
                     _tasks=with_tasks,
@@ -240,6 +242,7 @@ class CommitterSession:
             release = await data.release(
                 name=release_name,
                 phase=phase_value,
+                latest_revision_number=latest_revision_number,
                 _committee=with_committee,
                 _project=with_project,
                 _tasks=with_tasks,
diff --git a/atr/routes/announce.py b/atr/routes/announce.py
index ed8bfe0..ad5e680 100644
--- a/atr/routes/announce.py
+++ b/atr/routes/announce.py
@@ -22,6 +22,7 @@ from typing import TYPE_CHECKING, Any, Protocol
 import aiofiles.os
 import aioshutil
 import quart
+import sqlmodel
 import werkzeug.wrappers.response as response
 import wtforms
 
@@ -120,15 +121,21 @@ async def selected_post(
 
     subject = str(announce_form.subject.data)
     body = str(announce_form.body.data)
+    preview_revision_number = str(announce_form.preview_revision.data)
 
     source: str = ""
     target: str = ""
     source_base: pathlib.Path | None = None
 
-    async with db.session() as data:
+    async with db.session(log_queries=True) as data:
         try:
             release = await session.release(
-                project_name, version_name, 
phase=models.ReleasePhase.RELEASE_PREVIEW, with_revisions=True, data=data
+                project_name,
+                version_name,
+                phase=models.ReleasePhase.RELEASE_PREVIEW,
+                latest_revision_number=preview_revision_number,
+                with_revisions=True,
+                data=data,
             )
 
             test_list = "user-tests"
@@ -168,16 +175,7 @@ async def selected_post(
             source_base = util.release_directory_base(release)
             source = str(source_base / release.unwrap_revision_number)
 
-            # TODO: We should update only if the announcement email was sent
-            # That would require moving this, and the filesystem operations, 
into a task
-            release.phase = models.ReleasePhase.RELEASE
-            # Delete all revisions associated with this release
-            for revision in release.revisions:
-                await data.delete(revision)
-            # Essential to set revisions to [], otherwise release.revisions is 
still populated
-            # And util.release_directory() below checks for it
-            release.revisions = []
-            release.released = datetime.datetime.now(datetime.UTC)
+            await _promote(release, data, preview_revision_number)
             await data.commit()
 
         except (routes.FlashError, Exception) as e:
@@ -193,7 +191,7 @@ async def selected_post(
         # This must come after updating the release object
         # Do not put it in the data block after data.commit()
         # Otherwise util.release_directory() will not work
-        release = await 
data.release(name=release.name).demand(RuntimeError("Release does not exist"))
+        release = await 
data.release(name=release.name).demand(RuntimeError(f"Release {release.name} 
does not exist"))
         target = str(util.release_directory(release))
         if await aiofiles.os.path.exists(target):
             raise routes.FlashError("Release already exists")
@@ -245,3 +243,31 @@ async def _create_announce_form_instance(
 
     form_instance = await AnnounceForm.create_form(data=data)
     return form_instance
+
+
+async def _promote(release: models.Release, data: db.Session, 
preview_revision_number: str) -> None:
+    """Promote a release preview to a release and delete its old revisions."""
+    via = models.validate_instrumented_attribute
+
+    update_stmt = (
+        sqlmodel.update(models.Release)
+        .where(
+            via(models.Release.name) == release.name,
+            via(models.Release.phase) == models.ReleasePhase.RELEASE_PREVIEW,
+            models.latest_revision_number_query() == preview_revision_number,
+        )
+        .values(
+            stage=models.ReleaseStage.RELEASE,
+            phase=models.ReleasePhase.RELEASE,
+            released=datetime.datetime.now(datetime.UTC),
+        )
+    )
+    update_result = await data.execute_query(update_stmt)
+    # Avoid a type error with update_result.rowcount
+    # Could not find another way to do it, other than using a Protocol
+    rowcount: int = getattr(update_result, "rowcount", 0)
+    if rowcount != 1:
+        raise RuntimeError("A newer revision appeared, please refresh and try 
again.")
+
+    delete_revisions_stmt = 
sqlmodel.delete(models.Revision).where(via(models.Revision.release_name) == 
release.name)
+    await data.execute_query(delete_revisions_stmt)
diff --git a/atr/routes/voting.py b/atr/routes/voting.py
index 94e2f9a..e8529d6 100644
--- a/atr/routes/voting.py
+++ b/atr/routes/voting.py
@@ -20,6 +20,7 @@ import datetime
 import aiofiles.os
 import asfquart.base as base
 import quart
+import sqlmodel
 import werkzeug.wrappers.response as response
 import wtforms
 
@@ -54,6 +55,10 @@ async def selected_revision(
         committee = util.unwrap(release.committee)
         permitted_recipients = util.permitted_recipients(session.uid)
 
+        selected_revision_number = release.latest_revision_number
+        if selected_revision_number is None:
+            return await session.redirect(compose.selected, error="No revision 
found for this release")
+
         if release.release_policy:
             min_hours = release.release_policy.min_hours
         else:
@@ -121,7 +126,7 @@ async def selected_revision(
                 raise base.ASFQuartException("Invalid mailing list choice", 
errorcode=400)
 
             # This sets the phase to RELEASE_CANDIDATE
-            error = await _promote(data, release.name)
+            error = await _promote(data, release.name, 
selected_revision_number)
             if error:
                 return await session.redirect(root.index, error=error)
 
@@ -150,10 +155,7 @@ async def selected_revision(
                 ).model_dump(),
                 release_name=release.name,
             )
-
             data.add(task)
-            # Flush to get the task ID
-            await data.flush()
             await data.commit()
 
             # TODO: We should log all outgoing email and the session so that 
users can confirm
@@ -189,30 +191,45 @@ async def _keys_warning(
 async def _promote(
     data: db.Session,
     release_name: str,
+    selected_revision_number: str,
 ) -> str | None:
     """Promote a release candidate draft to a new phase."""
-    # Get the release
     # TODO: Use session.release here
-    release = await data.release(name=release_name, _project=True).demand(
+    release_for_pre_checks = await data.release(name=release_name, 
_project=True).demand(
         routes.FlashError("Release candidate draft not found")
     )
 
     # Verify that it's in the correct phase
-    if release.phase != models.ReleasePhase.RELEASE_CANDIDATE_DRAFT:
+    # The atomic update below will also check this
+    if release_for_pre_checks.phase != 
models.ReleasePhase.RELEASE_CANDIDATE_DRAFT:
         return "This release is not in the candidate draft phase"
 
-    # Count how many files are in the source directory
-    file_count = await util.number_of_release_files(release)
+    # Check that there is at least one file in the draft
+    # This is why we require _project=True above
+    file_count = await util.number_of_release_files(release_for_pre_checks)
     if file_count == 0:
         return "This candidate draft is empty, containing no files"
 
-    # Promote it to the target phase
-    # TODO: Obtain a lock for this
-    # NOTE: The functionality for skipping phases has been removed
-    release.stage = models.ReleaseStage.RELEASE_CANDIDATE
-    release.phase = models.ReleasePhase.RELEASE_CANDIDATE
+    # Promote it to RELEASE_CANDIDATE
+    # NOTE: We previously allowed skipping phases, but removed that 
functionality
+    # We don't need a lock here because we use an atomic update
+    via = models.validate_instrumented_attribute
+    stmt = (
+        sqlmodel.update(models.Release)
+        .where(
+            via(models.Release.name) == release_name,
+            via(models.Release.phase) == 
models.ReleasePhase.RELEASE_CANDIDATE_DRAFT,
+            models.latest_revision_number_query() == selected_revision_number,
+        )
+        .values(
+            stage=models.ReleaseStage.RELEASE_CANDIDATE,
+            phase=models.ReleasePhase.RELEASE_CANDIDATE,
+        )
+    )
 
-    # We updated the release
+    result = await data.execute(stmt)
+    if result.rowcount != 1:
+        await data.rollback()
+        return "A newer revision appeared, please refresh and try again."
     await data.commit()
-
     return None
diff --git a/migrations/env.py b/migrations/env.py
index a4d1386..6004efb 100644
--- a/migrations/env.py
+++ b/migrations/env.py
@@ -1,5 +1,4 @@
 import datetime
-import logging.config
 import os
 import re
 import subprocess
@@ -34,8 +33,8 @@ alembic_config = alembic.context.config
 
 # Interpret the config file for Python logging.
 # This line sets up loggers basically.
-if alembic_config.config_file_name is not None:
-    logging.config.fileConfig(alembic_config.config_file_name)
+# if alembic_config.config_file_name is not None:
+#     logging.config.fileConfig(alembic_config.config_file_name)
 
 # The SQLModel.metadata object as populated by the ATR models
 target_metadata = sqlmodel.SQLModel.metadata
diff --git a/migrations/versions/0001_2025.05.15_32c59be6.py 
b/migrations/versions/0001_2025.05.15_1d3ee5a0.py
similarity index 97%
rename from migrations/versions/0001_2025.05.15_32c59be6.py
rename to migrations/versions/0001_2025.05.15_1d3ee5a0.py
index 9437269..c0c4988 100644
--- a/migrations/versions/0001_2025.05.15_32c59be6.py
+++ b/migrations/versions/0001_2025.05.15_1d3ee5a0.py
@@ -1,8 +1,8 @@
 """Use the existing ATR schema
 
-Revision ID: 0001_2025.05.15_32c59be6
+Revision ID: 0001_2025.05.15_1d3ee5a0
 Revises:
-Create Date: 2025-05-15 15:44:04.208248+00:00
+Create Date: 2025-05-15 19:39:20.865550+00:00
 """
 
 from collections.abc import Sequence
@@ -13,7 +13,7 @@ from alembic import op
 import atr.db.models
 
 # Revision identifiers, used by Alembic
-revision: str = "0001_2025.05.15_32c59be6"
+revision: str = "0001_2025.05.15_1d3ee5a0"
 down_revision: str | None = None
 branch_labels: str | Sequence[str] | None = None
 depends_on: str | Sequence[str] | None = None
@@ -202,6 +202,8 @@ def upgrade() -> None:
         sa.ForeignKeyConstraint(["release_name"], ["release.name"], 
name=op.f("fk_revision_release_name_release")),
         sa.PrimaryKeyConstraint("name", name=op.f("pk_revision")),
         sa.UniqueConstraint("name", name=op.f("uq_revision_name")),
+        sa.UniqueConstraint("release_name", "number", 
name="uq_revision_release_number"),
+        sa.UniqueConstraint("release_name", "seq", 
name="uq_revision_release_seq"),
     )
     op.create_table(
         "task",


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to