This is an automated email from the ASF dual-hosted git repository.

arm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tooling-trusted-releases.git

commit a9f6090275d49dee785110b49f6b36a95054e676
Author: Alastair McFarlane <[email protected]>
AuthorDate: Thu Jan 15 16:13:31 2026 +0000

    Add scheduled column for tasks, allow asf_uid to be passed in task arguments
---
 atr/models/sql.py                               |   4 +
 atr/server.py                                   |  39 +++-----
 atr/tasks/__init__.py                           |  14 +--
 atr/tasks/gha.py                                |  14 +--
 atr/tasks/metadata.py                           |   6 +-
 atr/worker.py                                   | 118 +++++++++++++-----------
 migrations/versions/0040_2026.01.15_31d91cc5.py |  31 +++++++
 7 files changed, 130 insertions(+), 96 deletions(-)

diff --git a/atr/models/sql.py b/atr/models/sql.py
index 0a80912..5891a04 100644
--- a/atr/models/sql.py
+++ b/atr/models/sql.py
@@ -358,6 +358,10 @@ class Task(sqlmodel.SQLModel, table=True):
         default_factory=lambda: datetime.datetime.now(datetime.UTC),
         sa_column=sqlalchemy.Column(UTCDateTime, index=True),
     )
+    scheduled: datetime.datetime = sqlmodel.Field(
+        default=None,
+        sa_column=sqlalchemy.Column(UTCDateTime, index=True),
+    )
     started: datetime.datetime | None = sqlmodel.Field(
         default=None,
         sa_column=sqlalchemy.Column(UTCDateTime),
diff --git a/atr/server.py b/atr/server.py
index d64f7d1..9dde3d0 100644
--- a/atr/server.py
+++ b/atr/server.py
@@ -208,7 +208,8 @@ def _app_setup_lifecycle(app: base.QuartApp) -> None:
         await worker_manager.start()
 
         # Register recurring tasks (metadata updates, workflow status checks, etc.)
-        await _register_recurrent_tasks()
+        scheduler_task = asyncio.create_task(_register_recurrent_tasks())
+        app.extensions["scheduler_task"] = scheduler_task
 
         await _initialise_test_environment()
 
@@ -250,13 +251,13 @@ def _app_setup_lifecycle(app: base.QuartApp) -> None:
         await worker_manager.stop()
 
         # Stop the metadata scheduler
-        # metadata_scheduler = app.extensions.get("metadata_scheduler")
-        # if metadata_scheduler:
-        #     metadata_scheduler.cancel()
-        #     try:
-        #         await metadata_scheduler
-        #     except asyncio.CancelledError:
-        #         ...
+        scheduler_task = app.extensions.get("scheduler_task")
+        if scheduler_task:
+            scheduler_task.cancel()
+            try:
+                await scheduler_task
+            except asyncio.CancelledError:
+                ...
 
         ssh_server = app.extensions.get("ssh_server")
         if ssh_server:
@@ -514,31 +515,15 @@ async def _initialise_test_environment() -> None:
             await data.commit()
 
 
-#
-# async def _metadata_update_scheduler() -> None:
-#     """Periodically schedule remote metadata updates."""
-#     # Wait one minute to allow the server to start
-#     await asyncio.sleep(60)
-#
-#     while True:
-#         try:
-#             task = await tasks.metadata_update(asf_uid="system")
-#             log.info(f"Scheduled remote metadata update with ID {task.id}")
-#         except Exception as e:
-#             log.exception(f"Failed to schedule remote metadata update: {e!s}")
-#
-#         # Schedule next update in 24 hours
-#         await asyncio.sleep(86400)
-
-
 async def _register_recurrent_tasks() -> None:
     """Schedule recurring tasks"""
-    # Wait one minute to allow the server to start
-    await asyncio.sleep(30)
+    # Start scheduled tasks 5 min after server start
+    await asyncio.sleep(300)
     try:
         await tasks.clear_scheduled()
         metadata = await tasks.metadata_update(asf_uid="system", schedule_next=True)
         log.info(f"Scheduled remote metadata update with ID {metadata.id}")
+        await asyncio.sleep(60)
         workflow = await tasks.workflow_update(asf_uid="system", schedule_next=True)
         log.info(f"Scheduled workflow status update with ID {workflow.id}")
 
diff --git a/atr/tasks/__init__.py b/atr/tasks/__init__.py
index 99f2328..b575a83 100644
--- a/atr/tasks/__init__.py
+++ b/atr/tasks/__init__.py
@@ -74,7 +74,7 @@ async def clear_scheduled(caller_data: db.Session | None = None):
                 ]
             ),
             via(sql.Task.status) == sql.TaskStatus.QUEUED,
-            via(sql.Task.added) > now,
+            sqlmodel.or_(via(sql.Task.scheduled).is_(None), via(sql.Task.scheduled) > now),
         )
 
         await data.execute(delete_stmt)
@@ -181,9 +181,9 @@ async def metadata_update(
     schedule_next: bool = False,
 ) -> sql.Task:
     """Queue a metadata update task."""
-    args = metadata.Update(asf_uid=asf_uid, next_schedule=0)
+    args = metadata.Update(asf_uid=asf_uid, next_schedule_seconds=0)
     if schedule_next:
-        args.next_schedule = 60 * 24
+        args.next_schedule_seconds = 60 * 60 * 24
     async with db.ensure_session(caller_data) as data:
         task = sql.Task(
             status=sql.TaskStatus.QUEUED,
@@ -194,7 +194,7 @@ async def metadata_update(
             primary_rel_path=None,
         )
         if schedule:
-            task.added = schedule
+            task.scheduled = schedule
         data.add(task)
         await data.commit()
         await data.flush()
@@ -302,9 +302,9 @@ async def workflow_update(
     schedule_next: bool = False,
 ) -> sql.Task:
     """Queue a workflow status update task."""
-    args = gha.WorkflowStatusCheck(next_schedule=0, run_id=0)
+    args = gha.WorkflowStatusCheck(next_schedule_seconds=0, run_id=0)
     if schedule_next:
-        args.next_schedule = 2
+        args.next_schedule_seconds = 2 * 60
     async with db.ensure_session(caller_data) as data:
         task = sql.Task(
             status=sql.TaskStatus.QUEUED,
@@ -315,7 +315,7 @@ async def workflow_update(
             primary_rel_path=None,
         )
         if schedule:
-            task.added = schedule
+            task.scheduled = schedule
         data.add(task)
         await data.commit()
         await data.flush()
diff --git a/atr/tasks/gha.py b/atr/tasks/gha.py
index 9ccaef8..1e8eda1 100644
--- a/atr/tasks/gha.py
+++ b/atr/tasks/gha.py
@@ -63,7 +63,7 @@ class DistributionWorkflow(schema.Strict):
 
 class WorkflowStatusCheck(schema.Strict):
     run_id: int | None = schema.description("Run ID")
-    next_schedule: int = pydantic.Field(default=0, description="The next scheduled time (in minutes)")
+    next_schedule_seconds: int = pydantic.Field(default=0, description="The next scheduled time")
 
 
 @checks.with_model(DistributionWorkflow)
@@ -123,7 +123,7 @@ async def trigger_workflow(args: DistributionWorkflow, *, task_id: int | None =
 
 
 @checks.with_model(WorkflowStatusCheck)
-async def status_check(args: WorkflowStatusCheck) -> DistributionWorkflowStatus:
+async def status_check(args: WorkflowStatusCheck, asf_uid: str) -> DistributionWorkflowStatus:
     """Check remote workflow statuses."""
 
     headers = {"Accept": "application/vnd.github+json", "Authorization": f"Bearer {config.get().GITHUB_TOKEN}"}
@@ -182,7 +182,7 @@ async def status_check(args: WorkflowStatusCheck) -> DistributionWorkflowStatus:
         )
 
         # Schedule next update
-        await _schedule_next(args)
+        await _schedule_next(args, asf_uid)
 
         return results.DistributionWorkflowStatus(
             kind="distribution_workflow_status",
@@ -274,10 +274,10 @@ async def _request_and_retry(
     return None
 
 
-async def _schedule_next(args: WorkflowStatusCheck):
-    if args.next_schedule:
-        next_schedule = datetime.datetime.now(datetime.UTC) + datetime.timedelta(minutes=args.next_schedule)
-        await tasks.workflow_update("system", schedule=next_schedule, schedule_next=True)
+async def _schedule_next(args: WorkflowStatusCheck, asf_uid: str) -> None:
+    if args.next_schedule_seconds:
+        next_schedule = datetime.datetime.now(datetime.UTC) + datetime.timedelta(seconds=args.next_schedule_seconds)
+        await tasks.workflow_update(asf_uid, schedule=next_schedule, schedule_next=True)
         log.info(
             f"Scheduled next workflow status update for: {next_schedule.strftime('%Y-%m-%d %H:%M:%S')}",
         )
diff --git a/atr/tasks/metadata.py b/atr/tasks/metadata.py
index ba2b6e5..81e070a 100644
--- a/atr/tasks/metadata.py
+++ b/atr/tasks/metadata.py
@@ -32,7 +32,7 @@ class Update(schema.Strict):
     """Arguments for the task to update metadata from remote data sources."""
 
     asf_uid: str = schema.description("The ASF UID of the user triggering the update")
-    next_schedule: int = pydantic.Field(default=0, description="The next scheduled time (in minutes)")
+    next_schedule_seconds: int = pydantic.Field(default=0, description="The next scheduled time")
 
 
 class UpdateError(Exception):
@@ -52,8 +52,8 @@ async def update(args: Update) -> results.Results | None:
         )
 
         # Schedule next update
-        if args.next_schedule and args.next_schedule > 0:
-            next_schedule = datetime.datetime.now(datetime.UTC) + datetime.timedelta(minutes=args.next_schedule)
+        if args.next_schedule_seconds and args.next_schedule_seconds > 0:
+            next_schedule = datetime.datetime.now(datetime.UTC) + datetime.timedelta(seconds=args.next_schedule_seconds)
             await tasks.metadata_update(args.asf_uid, schedule=next_schedule, schedule_next=True)
             log.info(
                 f"Scheduled next metadata update for: {next_schedule.strftime('%Y-%m-%d %H:%M:%S')}",
diff --git a/atr/worker.py b/atr/worker.py
index db3a9ac..aa19e28 100644
--- a/atr/worker.py
+++ b/atr/worker.py
@@ -28,6 +28,7 @@ import inspect
 import os
 import signal
 import traceback
+from collections.abc import Awaitable, Callable
 from typing import Any, Final
 
 import sqlmodel
@@ -103,12 +104,13 @@ def _setup_logging() -> None:
 # Task functions
 
 
-async def _task_next_claim() -> tuple[int, str, list[str] | dict[str, Any]] | None:
+async def _task_next_claim() -> tuple[int, str, list[str] | dict[str, Any], str] | None:
     """
     Attempt to claim the oldest unclaimed task.
     Returns (task_id, task_type, task_args) if successful.
     Returns None if no tasks are available.
     """
+    via = sql.validate_instrumented_attribute
     async with db.session() as data:
         async with data.begin():
             # Get the ID of the oldest queued task
@@ -117,10 +119,12 @@ async def _task_next_claim() -> tuple[int, str, list[str] | dict[str, Any]] | No
                 .where(
                     sqlmodel.and_(
                         sql.Task.status == task.QUEUED,
-                        sql.Task.added <= datetime.datetime.now(datetime.UTC) - datetime.timedelta(seconds=2),
+                        sqlmodel.or_(
+                            via(sql.Task.scheduled).is_(None), sql.Task.scheduled <= datetime.datetime.now(datetime.UTC)
+                        ),
                     )
                 )
-                .order_by(sql.validate_instrumented_attribute(sql.Task.added).asc())
+                .order_by(via(sql.Task.added).asc())
                 .limit(1)
             )
 
@@ -135,6 +139,7 @@ async def _task_next_claim() -> tuple[int, str, list[str] | dict[str, Any]] | No
                     sql.validate_instrumented_attribute(sql.Task.id),
                     sql.validate_instrumented_attribute(sql.Task.task_type),
                     sql.validate_instrumented_attribute(sql.Task.task_args),
+                    sql.validate_instrumented_attribute(sql.Task.asf_uid),
                 )
             )
 
@@ -142,14 +147,14 @@ async def _task_next_claim() -> tuple[int, str, list[str] | dict[str, Any]] | No
             claimed_task = result.first()
 
             if claimed_task:
-                task_id, task_type, task_args = claimed_task
+                task_id, task_type, task_args, asf_uid = claimed_task
                 log.info(f"Claimed task {task_id} ({task_type}) with args {task_args}")
-                return task_id, task_type, task_args
+                return task_id, task_type, task_args, asf_uid
 
             return None
 
 
-async def _task_process(task_id: int, task_type: str, task_args: list[str] | dict[str, Any]) -> None:
+async def _task_process(task_id: int, task_type: str, task_args: list[str] | dict[str, Any], asf_uid: str) -> None:
     """Process a claimed task."""
     log.info(f"Processing task {task_id} ({task_type}) with raw args {task_args}")
     try:
@@ -167,52 +172,15 @@ async def _task_process(task_id: int, task_type: str, task_args: list[str] | dic
 
         # Check whether the handler is a check handler
         if (len(params) == 1) and (params[0].annotation == checks.FunctionArguments):
-            log.debug(f"Handler {handler.__name__} expects checks.FunctionArguments, fetching full task details")
-            async with db.session() as data:
-                task_obj = await data.task(id=task_id).demand(
-                    ValueError(f"Task {task_id} disappeared during processing")
-                )
-
-            # Validate required fields from the Task object itself
-            if task_obj.project_name is None:
-                raise ValueError(f"Task {task_id} is missing required project_name")
-            if task_obj.version_name is None:
-                raise ValueError(f"Task {task_id} is missing required version_name")
-            if task_obj.revision_number is None:
-                raise ValueError(f"Task {task_id} is missing required revision_number")
-
-            if not isinstance(task_args, dict):
-                raise TypeError(
-                    f"Task {task_id} ({task_type}) has non-dict raw args"
-                    f" {task_args} which should represent keyword_args"
-                )
-
-            async def recorder_factory() -> checks.Recorder:
-                return await checks.Recorder.create(
-                    checker=handler,
-                    project_name=task_obj.project_name or "",
-                    version_name=task_obj.version_name or "",
-                    revision_number=task_obj.revision_number or "",
-                    primary_rel_path=task_obj.primary_rel_path,
-                )
-
-            function_arguments = checks.FunctionArguments(
-                recorder=recorder_factory,
-                asf_uid=task_obj.asf_uid,
-                project_name=task_obj.project_name or "",
-                version_name=task_obj.version_name or "",
-                revision_number=task_obj.revision_number,
-                primary_rel_path=task_obj.primary_rel_path,
-                extra_args=task_args,
-            )
-            log.debug(f"Calling {handler.__name__} with structured arguments: {function_arguments}")
-            handler_result = await handler(function_arguments)
+            handler_result = await _execute_check_task(handler, task_args, task_id, task_type)
         else:
             # Otherwise, it's not a check handler
-            if sig.parameters.get("task_id") is None:
-                handler_result = await handler(task_args)
-            else:
-                handler_result = await handler(task_args, task_id=task_id)
+            additional_kwargs = {}
+            if sig.parameters.get("task_id") is not None:
+                additional_kwargs["task_id"] = task_id
+            if sig.parameters.get("asf_uid") is not None:
+                additional_kwargs["asf_uid"] = asf_uid
+            handler_result = await handler(task_args, **additional_kwargs)
 
         task_results = handler_result
         status = task.COMPLETED
@@ -226,6 +194,52 @@ async def _task_process(task_id: int, task_type: str, task_args: list[str] | dic
     await _task_result_process(task_id, task_results, status, error)
 
 
+async def _execute_check_task(
+    handler: Callable[..., Awaitable[results.Results | None]],
+    task_args: list[str] | dict[str, Any],
+    task_id: int,
+    task_type: str,
+) -> results.Results | None:
+    log.debug(f"Handler {handler.__name__} expects checks.FunctionArguments, fetching full task details")
+    async with db.session() as data:
+        task_obj = await data.task(id=task_id).demand(ValueError(f"Task {task_id} disappeared during processing"))
+
+    # Validate required fields from the Task object itself
+    if task_obj.project_name is None:
+        raise ValueError(f"Task {task_id} is missing required project_name")
+    if task_obj.version_name is None:
+        raise ValueError(f"Task {task_id} is missing required version_name")
+    if task_obj.revision_number is None:
+        raise ValueError(f"Task {task_id} is missing required revision_number")
+
+    if not isinstance(task_args, dict):
+        raise TypeError(
+            f"Task {task_id} ({task_type}) has non-dict raw args {task_args} which should represent keyword_args"
+        )
+
+    async def recorder_factory() -> checks.Recorder:
+        return await checks.Recorder.create(
+            checker=handler,
+            project_name=task_obj.project_name or "",
+            version_name=task_obj.version_name or "",
+            revision_number=task_obj.revision_number or "",
+            primary_rel_path=task_obj.primary_rel_path,
+        )
+
+    function_arguments = checks.FunctionArguments(
+        recorder=recorder_factory,
+        asf_uid=task_obj.asf_uid,
+        project_name=task_obj.project_name or "",
+        version_name=task_obj.version_name or "",
+        revision_number=task_obj.revision_number,
+        primary_rel_path=task_obj.primary_rel_path,
+        extra_args=task_args,
+    )
+    log.debug(f"Calling {handler.__name__} with structured arguments: {function_arguments}")
+    handler_result = await handler(function_arguments)
+    return handler_result
+
+
 async def _task_result_process(
     task_id: int, task_results: results.Results | None, status: sql.TaskStatus, error: str | None = None
 ) -> None:
@@ -255,8 +269,8 @@ async def _worker_loop_run() -> None:
         try:
             task = await _task_next_claim()
             if task:
-                task_id, task_type, task_args = task
-                await _task_process(task_id, task_type, task_args)
+                task_id, task_type, task_args, asf_uid = task
+                await _task_process(task_id, task_type, task_args, asf_uid)
                 processed += 1
                 # Only process max_to_process tasks and then exit
                 # This prevents memory leaks from accumulating
diff --git a/migrations/versions/0040_2026.01.15_31d91cc5.py b/migrations/versions/0040_2026.01.15_31d91cc5.py
new file mode 100644
index 0000000..ceac8fa
--- /dev/null
+++ b/migrations/versions/0040_2026.01.15_31d91cc5.py
@@ -0,0 +1,31 @@
+"""Add schedule column for tasks
+
+Revision ID: 0040_2026.01.15_31d91cc5
+Revises: 0039_2026.01.14_cd44f0ea
+Create Date: 2026-01-15 15:34:00.515650+00:00
+"""
+
+from collections.abc import Sequence
+
+import sqlalchemy as sa
+from alembic import op
+
+import atr.models.sql
+
+# Revision identifiers, used by Alembic
+revision: str = "0040_2026.01.15_31d91cc5"
+down_revision: str | None = "0039_2026.01.14_cd44f0ea"
+branch_labels: str | Sequence[str] | None = None
+depends_on: str | Sequence[str] | None = None
+
+
+def upgrade() -> None:
+    with op.batch_alter_table("task", schema=None) as batch_op:
+        batch_op.add_column(sa.Column("scheduled", atr.models.sql.UTCDateTime(timezone=True), nullable=True))
+        batch_op.create_index(batch_op.f("ix_task_scheduled"), ["scheduled"], unique=False)
+
+
+def downgrade() -> None:
+    with op.batch_alter_table("task", schema=None) as batch_op:
+        batch_op.drop_index(batch_op.f("ix_task_scheduled"))
+        batch_op.drop_column("scheduled")


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to