This is an automated email from the ASF dual-hosted git repository.

jscheffl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 7ea060d7292 Remove findings from positional session check in Core Jobs 
(#67773)
7ea060d7292 is described below

commit 7ea060d72920f329766eb763c5468d1ba1dab00b
Author: Jens Scheffler <[email protected]>
AuthorDate: Sun May 31 08:42:24 2026 +0200

    Remove findings from positional session check in Core Jobs (#67773)
    
    * Fix exceptions of positional session use in airflow-core jobs processing
    
    * Fix CI tests
---
 airflow-core/src/airflow/jobs/base_job_runner.py   |  4 ++--
 airflow-core/src/airflow/jobs/job.py               | 22 ++++++++------------
 .../src/airflow/jobs/scheduler_job_runner.py       | 24 +++++++++++-----------
 .../src/airflow/jobs/triggerer_job_runner.py       |  2 +-
 airflow-core/tests/unit/jobs/test_base_job.py      |  2 +-
 airflow-core/tests/unit/jobs/test_scheduler_job.py |  4 ++--
 .../ci/prek/known_provide_session_positional.txt   |  5 -----
 7 files changed, 27 insertions(+), 36 deletions(-)

diff --git a/airflow-core/src/airflow/jobs/base_job_runner.py 
b/airflow-core/src/airflow/jobs/base_job_runner.py
index 05671e2050a..8ddde316dbf 100644
--- a/airflow-core/src/airflow/jobs/base_job_runner.py
+++ b/airflow-core/src/airflow/jobs/base_job_runner.py
@@ -54,7 +54,7 @@ class BaseJobRunner:
         raise NotImplementedError()
 
     @provide_session
-    def heartbeat_callback(self, session: Session = NEW_SESSION) -> None:
+    def heartbeat_callback(self, *, session: Session = NEW_SESSION) -> None:
         """
         Execute callback during heartbeat.
 
@@ -63,7 +63,7 @@ class BaseJobRunner:
 
     @classmethod
     @provide_session
-    def most_recent_job(cls, session: Session = NEW_SESSION) -> Job | None:
+    def most_recent_job(cls, *, session: Session = NEW_SESSION) -> Job | None:
         """Return the most recent job of this type, if any, based on last 
heartbeat received."""
         from airflow.jobs.job import most_recent_job
 
diff --git a/airflow-core/src/airflow/jobs/job.py 
b/airflow-core/src/airflow/jobs/job.py
index 4ab2defd81b..b19cd103de9 100644
--- a/airflow-core/src/airflow/jobs/job.py
+++ b/airflow-core/src/airflow/jobs/job.py
@@ -169,7 +169,7 @@ class Job(Base, LoggingMixin):
         )
 
     @provide_session
-    def kill(self, session: Session = NEW_SESSION) -> NoReturn:
+    def kill(self, *, session: Session = NEW_SESSION) -> NoReturn:
         """Handle on_kill callback and updates state in database."""
         try:
             self.on_kill()
@@ -187,9 +187,7 @@ class Job(Base, LoggingMixin):
         """Will be called when an external kill command is received."""
 
     @provide_session
-    def heartbeat(
-        self, heartbeat_callback: Callable[[Session], None], session: Session 
= NEW_SESSION
-    ) -> None:
+    def heartbeat(self, heartbeat_callback: Callable[..., None], *, session: 
Session = NEW_SESSION) -> None:
         """
         Update the job's entry in the database with the latest_heartbeat 
timestamp.
 
@@ -241,7 +239,7 @@ class Job(Base, LoggingMixin):
                     self.log.info("Heartbeat recovered after %.2f seconds", 
time_since_last_heartbeat)
                 # At this point, the DB has updated.
                 previous_heartbeat = self.latest_heartbeat
-                heartbeat_callback(session)
+                heartbeat_callback(session=session)
             self.log.debug("[heartbeat]")
             self.heartbeat_failed = False
         except OperationalError:
@@ -266,7 +264,7 @@ class Job(Base, LoggingMixin):
             self.latest_heartbeat = previous_heartbeat
 
     @provide_session
-    def prepare_for_execution(self, session: Session = NEW_SESSION):
+    def prepare_for_execution(self, *, session: Session = NEW_SESSION):
         """Prepare the job for execution."""
         stats.incr(self.__class__.__name__.lower() + "_start", 1, 1)
         self.state = JobState.RUNNING
@@ -276,7 +274,7 @@ class Job(Base, LoggingMixin):
         make_transient(self)
 
     @provide_session
-    def complete_execution(self, session: Session = NEW_SESSION):
+    def complete_execution(self, *, session: Session = NEW_SESSION):
         try:
             get_listener_manager().hook.before_stopping(component=self)
         except Exception:
@@ -287,7 +285,7 @@ class Job(Base, LoggingMixin):
         stats.incr(self.__class__.__name__.lower() + "_end", 1, 1)
 
     @provide_session
-    def most_recent_job(self, session: Session = NEW_SESSION) -> Job | None:
+    def most_recent_job(self, *, session: Session = NEW_SESSION) -> Job | None:
         """Return the most recent job of this type, if any, based on last 
heartbeat received."""
         return most_recent_job(str(self.job_type), session=session)
 
@@ -316,7 +314,7 @@ class Job(Base, LoggingMixin):
 
 
 @provide_session
-def most_recent_job(job_type: str, session: Session = NEW_SESSION) -> Job | 
None:
+def most_recent_job(job_type: str, *, session: Session = NEW_SESSION) -> Job | 
None:
     """
     Return the most recent job of this type, if any, based on last heartbeat 
received.
 
@@ -340,7 +338,7 @@ def most_recent_job(job_type: str, session: Session = 
NEW_SESSION) -> Job | None
 
 @provide_session
 def run_job(
-    job: Job, execute_callable: Callable[[], int | None], session: Session = 
NEW_SESSION
+    job: Job, execute_callable: Callable[[], int | None], *, session: Session 
= NEW_SESSION
 ) -> int | None:
     """
     Run the job.
@@ -393,9 +391,7 @@ def execute_job(job: Job, execute_callable: Callable[[], 
int | None]) -> int | N
     return ret
 
 
-def perform_heartbeat(
-    job: Job, heartbeat_callback: Callable[[Session], None], 
only_if_necessary: bool
-) -> None:
+def perform_heartbeat(job: Job, heartbeat_callback: Callable[..., None], 
only_if_necessary: bool) -> None:
     """
     Perform heartbeat for the Job passed to it,optionally checking if it is 
necessary.
 
diff --git a/airflow-core/src/airflow/jobs/scheduler_job_runner.py 
b/airflow-core/src/airflow/jobs/scheduler_job_runner.py
index 224659c4c4d..596946268a6 100644
--- a/airflow-core/src/airflow/jobs/scheduler_job_runner.py
+++ b/airflow-core/src/airflow/jobs/scheduler_job_runner.py
@@ -330,7 +330,7 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
         self.scheduler_dag_bag = DBDagBag(load_op_links=False)
 
     @provide_session
-    def heartbeat_callback(self, session: Session = NEW_SESSION) -> None:
+    def heartbeat_callback(self, *, session: Session = NEW_SESSION) -> None:
         stats.incr("scheduler_heartbeat", 1, 1)
 
     def _get_current_dag(self, dag_id: str, session: Session) -> SerializedDAG 
| None:
@@ -1567,7 +1567,7 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
         return None
 
     @provide_session
-    def _update_dag_run_state_for_paused_dags(self, session: Session = 
NEW_SESSION) -> None:
+    def _update_dag_run_state_for_paused_dags(self, *, session: Session = 
NEW_SESSION) -> None:
         try:
             paused_runs = list(
                 session.scalars(
@@ -1955,7 +1955,7 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
         # END: create dagruns
 
     @provide_session
-    def _mark_backfills_complete(self, session: Session = NEW_SESSION) -> None:
+    def _mark_backfills_complete(self, *, session: Session = NEW_SESSION) -> 
None:
         """Mark completed backfills as completed."""
         self.log.debug("checking for completed backfills.")
         unfinished_states = (DagRunState.RUNNING, DagRunState.QUEUED)
@@ -2551,7 +2551,7 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
             self.log.debug("callback is empty")
 
     @provide_session
-    def _handle_tasks_stuck_in_queued(self, session: Session = NEW_SESSION) -> 
None:
+    def _handle_tasks_stuck_in_queued(self, *, session: Session = NEW_SESSION) 
-> None:
         """
         Handle the scenario where a task is queued for longer than 
`task_queued_timeout`.
 
@@ -2592,7 +2592,7 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
 
         Otherwise, fail it.
         """
-        num_times_stuck = self._get_num_times_stuck_in_queued(ti, session)
+        num_times_stuck = self._get_num_times_stuck_in_queued(ti, 
session=session)
         if num_times_stuck < self._num_stuck_queued_retries:
             self.log.info("Task stuck in queued; will try to requeue. 
task_instance=%s", ti)
             session.add(
@@ -2684,7 +2684,7 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
         )
 
     @provide_session
-    def _get_num_times_stuck_in_queued(self, ti: TaskInstance, session: 
Session = NEW_SESSION) -> int:
+    def _get_num_times_stuck_in_queued(self, ti: TaskInstance, *, session: 
Session = NEW_SESSION) -> int:
         """
         Check the Log table to see how many times a task instance has been 
stuck in queued.
 
@@ -2726,7 +2726,7 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
     previous_ti_metrics: dict[TaskInstanceState, dict[tuple[str, str, str], 
int]] = {}
 
     @provide_session
-    def _emit_ti_metrics(self, session: Session = NEW_SESSION) -> None:
+    def _emit_ti_metrics(self, *, session: Session = NEW_SESSION) -> None:
         metric_states = {State.SCHEDULED, State.QUEUED, State.RUNNING, 
State.DEFERRED}
         stmt = (
             select(
@@ -2771,13 +2771,13 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
             self.previous_ti_metrics[state] = ti_metrics
 
     @provide_session
-    def _emit_running_dags_metric(self, session: Session = NEW_SESSION) -> 
None:
+    def _emit_running_dags_metric(self, *, session: Session = NEW_SESSION) -> 
None:
         stmt = select(func.count()).select_from(DagRun).where(DagRun.state == 
DagRunState.RUNNING)
         running_dags = float(session.scalar(stmt) or 0)
         stats.gauge("scheduler.dagruns.running", running_dags)
 
     @provide_session
-    def _emit_pool_metrics(self, session: Session = NEW_SESSION) -> None:
+    def _emit_pool_metrics(self, *, session: Session = NEW_SESSION) -> None:
         from airflow.models.pool import Pool
 
         pools = Pool.slots_stats(session=session)
@@ -2810,7 +2810,7 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
             )
 
     @provide_session
-    def adopt_or_reset_orphaned_tasks(self, session: Session = NEW_SESSION) -> 
int:
+    def adopt_or_reset_orphaned_tasks(self, *, session: Session = NEW_SESSION) 
-> int:
         """
         Adopt or reset any TaskInstance in resettable state if its 
SchedulerJob is no longer running.
 
@@ -2912,7 +2912,7 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
 
     @provide_session
     def check_trigger_timeouts(
-        self, max_retries: int = MAX_DB_RETRIES, session: Session = NEW_SESSION
+        self, max_retries: int = MAX_DB_RETRIES, *, session: Session = 
NEW_SESSION
     ) -> None:
         """Mark any "deferred" task as failed if the trigger or execution 
timeout has passed."""
         for attempt in run_with_db_retries(max_retries, logger=self.log):
@@ -3092,7 +3092,7 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
         )
 
     @provide_session
-    def _update_asset_orphanage(self, session: Session = NEW_SESSION) -> None:
+    def _update_asset_orphanage(self, *, session: Session = NEW_SESSION) -> 
None:
         """
         Check assets orphanization and update their active entry.
 
diff --git a/airflow-core/src/airflow/jobs/triggerer_job_runner.py 
b/airflow-core/src/airflow/jobs/triggerer_job_runner.py
index 38f67a9d43e..e8e1f66e787 100644
--- a/airflow-core/src/airflow/jobs/triggerer_job_runner.py
+++ b/airflow-core/src/airflow/jobs/triggerer_job_runner.py
@@ -204,7 +204,7 @@ class TriggererJobRunner(BaseJobRunner, LoggingMixin):
 
     @classmethod
     @provide_session
-    def is_needed(cls, session) -> bool:
+    def is_needed(cls, *, session: Session) -> bool:
         """
         Test if the triggerer job needs to be run (i.e., if there are triggers 
in the trigger table).
 
diff --git a/airflow-core/tests/unit/jobs/test_base_job.py 
b/airflow-core/tests/unit/jobs/test_base_job.py
index aca8c6e7c82..a38956c61bd 100644
--- a/airflow-core/tests/unit/jobs/test_base_job.py
+++ b/airflow-core/tests/unit/jobs/test_base_job.py
@@ -268,7 +268,7 @@ class TestJob:
             hb_callback = Mock()
             job.heartbeat(heartbeat_callback=hb_callback)
 
-            hb_callback.assert_called_once_with(ANY)
+            hb_callback.assert_called_once_with(session=ANY)
 
             hb_callback.reset_mock()
             perform_heartbeat(job=job, heartbeat_callback=hb_callback, 
only_if_necessary=True)
diff --git a/airflow-core/tests/unit/jobs/test_scheduler_job.py 
b/airflow-core/tests/unit/jobs/test_scheduler_job.py
index 05803f5fd5b..f4695e66206 100644
--- a/airflow-core/tests/unit/jobs/test_scheduler_job.py
+++ b/airflow-core/tests/unit/jobs/test_scheduler_job.py
@@ -102,7 +102,7 @@ from airflow.sdk.definitions.timetables.assets import 
PartitionedAssetTimetable
 from airflow.serialization.definitions.dag import SerializedDAG
 from airflow.serialization.serialized_objects import LazyDeserializedDAG
 from airflow.timetables.base import DagRunInfo, DataInterval
-from airflow.utils.session import create_session, provide_session
+from airflow.utils.session import NEW_SESSION, create_session, provide_session
 from airflow.utils.sqlalchemy import with_row_locks
 from airflow.utils.state import CallbackState, DagRunState, State, 
TaskInstanceState
 from airflow.utils.types import DagRunTriggeredByType, DagRunType
@@ -4800,7 +4800,7 @@ class TestSchedulerJob:
         dag_maker.dag_model.calculate_dagrun_date_fields(dag, 
last_automated_run=None)
 
         @provide_session
-        def do_schedule(session):
+        def do_schedule(*, session: Session = NEW_SESSION):
             # Use a empty file since the above mock will return the
             # expected DAGs. Also specify only a single file so that it doesn't
             # try to schedule the above DAG repeatedly.
diff --git a/scripts/ci/prek/known_provide_session_positional.txt 
b/scripts/ci/prek/known_provide_session_positional.txt
index d2ba9508cae..c6351fe5bd7 100644
--- a/scripts/ci/prek/known_provide_session_positional.txt
+++ b/scripts/ci/prek/known_provide_session_positional.txt
@@ -1,7 +1,3 @@
-airflow-core/src/airflow/jobs/base_job_runner.py::2
-airflow-core/src/airflow/jobs/job.py::7
-airflow-core/src/airflow/jobs/scheduler_job_runner.py::11
-airflow-core/src/airflow/jobs/triggerer_job_runner.py::1
 airflow-core/src/airflow/models/connection.py::2
 airflow-core/src/airflow/models/dag.py::7
 airflow-core/src/airflow/models/dagcode.py::6
@@ -20,7 +16,6 @@ airflow-core/src/airflow/models/trigger.py::7
 airflow-core/src/airflow/models/variable.py::2
 airflow-core/src/airflow/secrets/metastore.py::2
 airflow-core/src/airflow/serialization/definitions/dag.py::2
-airflow-core/tests/unit/jobs/test_scheduler_job.py::1
 airflow-core/tests/unit/listeners/test_listeners.py::7
 airflow-core/tests/unit/models/test_taskinstance.py::4
 airflow-core/tests/unit/models/test_timestamp.py::2

Reply via email to