This is an automated email from the ASF dual-hosted git repository.
jscheffl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 7ea060d7292 Remove findings from positional session check in Core Jobs
(#67773)
7ea060d7292 is described below
commit 7ea060d72920f329766eb763c5468d1ba1dab00b
Author: Jens Scheffler <[email protected]>
AuthorDate: Sun May 31 08:42:24 2026 +0200
Remove findings from positional session check in Core Jobs (#67773)
* Fix exceptions of positional session use in airflow-core jobs processing
* Fix CI tests
---
airflow-core/src/airflow/jobs/base_job_runner.py | 4 ++--
airflow-core/src/airflow/jobs/job.py | 22 ++++++++------------
.../src/airflow/jobs/scheduler_job_runner.py | 24 +++++++++++-----------
.../src/airflow/jobs/triggerer_job_runner.py | 2 +-
airflow-core/tests/unit/jobs/test_base_job.py | 2 +-
airflow-core/tests/unit/jobs/test_scheduler_job.py | 4 ++--
.../ci/prek/known_provide_session_positional.txt | 5 -----
7 files changed, 27 insertions(+), 36 deletions(-)
diff --git a/airflow-core/src/airflow/jobs/base_job_runner.py
b/airflow-core/src/airflow/jobs/base_job_runner.py
index 05671e2050a..8ddde316dbf 100644
--- a/airflow-core/src/airflow/jobs/base_job_runner.py
+++ b/airflow-core/src/airflow/jobs/base_job_runner.py
@@ -54,7 +54,7 @@ class BaseJobRunner:
raise NotImplementedError()
@provide_session
- def heartbeat_callback(self, session: Session = NEW_SESSION) -> None:
+ def heartbeat_callback(self, *, session: Session = NEW_SESSION) -> None:
"""
Execute callback during heartbeat.
@@ -63,7 +63,7 @@ class BaseJobRunner:
@classmethod
@provide_session
- def most_recent_job(cls, session: Session = NEW_SESSION) -> Job | None:
+ def most_recent_job(cls, *, session: Session = NEW_SESSION) -> Job | None:
"""Return the most recent job of this type, if any, based on last
heartbeat received."""
from airflow.jobs.job import most_recent_job
diff --git a/airflow-core/src/airflow/jobs/job.py
b/airflow-core/src/airflow/jobs/job.py
index 4ab2defd81b..b19cd103de9 100644
--- a/airflow-core/src/airflow/jobs/job.py
+++ b/airflow-core/src/airflow/jobs/job.py
@@ -169,7 +169,7 @@ class Job(Base, LoggingMixin):
)
@provide_session
- def kill(self, session: Session = NEW_SESSION) -> NoReturn:
+ def kill(self, *, session: Session = NEW_SESSION) -> NoReturn:
"""Handle on_kill callback and updates state in database."""
try:
self.on_kill()
@@ -187,9 +187,7 @@ class Job(Base, LoggingMixin):
"""Will be called when an external kill command is received."""
@provide_session
- def heartbeat(
- self, heartbeat_callback: Callable[[Session], None], session: Session
= NEW_SESSION
- ) -> None:
+ def heartbeat(self, heartbeat_callback: Callable[..., None], *, session:
Session = NEW_SESSION) -> None:
"""
Update the job's entry in the database with the latest_heartbeat
timestamp.
@@ -241,7 +239,7 @@ class Job(Base, LoggingMixin):
self.log.info("Heartbeat recovered after %.2f seconds",
time_since_last_heartbeat)
# At this point, the DB has updated.
previous_heartbeat = self.latest_heartbeat
- heartbeat_callback(session)
+ heartbeat_callback(session=session)
self.log.debug("[heartbeat]")
self.heartbeat_failed = False
except OperationalError:
@@ -266,7 +264,7 @@ class Job(Base, LoggingMixin):
self.latest_heartbeat = previous_heartbeat
@provide_session
- def prepare_for_execution(self, session: Session = NEW_SESSION):
+ def prepare_for_execution(self, *, session: Session = NEW_SESSION):
"""Prepare the job for execution."""
stats.incr(self.__class__.__name__.lower() + "_start", 1, 1)
self.state = JobState.RUNNING
@@ -276,7 +274,7 @@ class Job(Base, LoggingMixin):
make_transient(self)
@provide_session
- def complete_execution(self, session: Session = NEW_SESSION):
+ def complete_execution(self, *, session: Session = NEW_SESSION):
try:
get_listener_manager().hook.before_stopping(component=self)
except Exception:
@@ -287,7 +285,7 @@ class Job(Base, LoggingMixin):
stats.incr(self.__class__.__name__.lower() + "_end", 1, 1)
@provide_session
- def most_recent_job(self, session: Session = NEW_SESSION) -> Job | None:
+ def most_recent_job(self, *, session: Session = NEW_SESSION) -> Job | None:
"""Return the most recent job of this type, if any, based on last
heartbeat received."""
return most_recent_job(str(self.job_type), session=session)
@@ -316,7 +314,7 @@ class Job(Base, LoggingMixin):
@provide_session
-def most_recent_job(job_type: str, session: Session = NEW_SESSION) -> Job |
None:
+def most_recent_job(job_type: str, *, session: Session = NEW_SESSION) -> Job |
None:
"""
Return the most recent job of this type, if any, based on last heartbeat
received.
@@ -340,7 +338,7 @@ def most_recent_job(job_type: str, session: Session =
NEW_SESSION) -> Job | None
@provide_session
def run_job(
- job: Job, execute_callable: Callable[[], int | None], session: Session =
NEW_SESSION
+ job: Job, execute_callable: Callable[[], int | None], *, session: Session
= NEW_SESSION
) -> int | None:
"""
Run the job.
@@ -393,9 +391,7 @@ def execute_job(job: Job, execute_callable: Callable[[],
int | None]) -> int | N
return ret
-def perform_heartbeat(
- job: Job, heartbeat_callback: Callable[[Session], None],
only_if_necessary: bool
-) -> None:
+def perform_heartbeat(job: Job, heartbeat_callback: Callable[..., None],
only_if_necessary: bool) -> None:
"""
Perform heartbeat for the Job passed to it,optionally checking if it is
necessary.
diff --git a/airflow-core/src/airflow/jobs/scheduler_job_runner.py
b/airflow-core/src/airflow/jobs/scheduler_job_runner.py
index 224659c4c4d..596946268a6 100644
--- a/airflow-core/src/airflow/jobs/scheduler_job_runner.py
+++ b/airflow-core/src/airflow/jobs/scheduler_job_runner.py
@@ -330,7 +330,7 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
self.scheduler_dag_bag = DBDagBag(load_op_links=False)
@provide_session
- def heartbeat_callback(self, session: Session = NEW_SESSION) -> None:
+ def heartbeat_callback(self, *, session: Session = NEW_SESSION) -> None:
stats.incr("scheduler_heartbeat", 1, 1)
def _get_current_dag(self, dag_id: str, session: Session) -> SerializedDAG
| None:
@@ -1567,7 +1567,7 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
return None
@provide_session
- def _update_dag_run_state_for_paused_dags(self, session: Session =
NEW_SESSION) -> None:
+ def _update_dag_run_state_for_paused_dags(self, *, session: Session =
NEW_SESSION) -> None:
try:
paused_runs = list(
session.scalars(
@@ -1955,7 +1955,7 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
# END: create dagruns
@provide_session
- def _mark_backfills_complete(self, session: Session = NEW_SESSION) -> None:
+ def _mark_backfills_complete(self, *, session: Session = NEW_SESSION) ->
None:
"""Mark completed backfills as completed."""
self.log.debug("checking for completed backfills.")
unfinished_states = (DagRunState.RUNNING, DagRunState.QUEUED)
@@ -2551,7 +2551,7 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
self.log.debug("callback is empty")
@provide_session
- def _handle_tasks_stuck_in_queued(self, session: Session = NEW_SESSION) ->
None:
+ def _handle_tasks_stuck_in_queued(self, *, session: Session = NEW_SESSION)
-> None:
"""
Handle the scenario where a task is queued for longer than
`task_queued_timeout`.
@@ -2592,7 +2592,7 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
Otherwise, fail it.
"""
- num_times_stuck = self._get_num_times_stuck_in_queued(ti, session)
+ num_times_stuck = self._get_num_times_stuck_in_queued(ti,
session=session)
if num_times_stuck < self._num_stuck_queued_retries:
self.log.info("Task stuck in queued; will try to requeue.
task_instance=%s", ti)
session.add(
@@ -2684,7 +2684,7 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
)
@provide_session
- def _get_num_times_stuck_in_queued(self, ti: TaskInstance, session:
Session = NEW_SESSION) -> int:
+ def _get_num_times_stuck_in_queued(self, ti: TaskInstance, *, session:
Session = NEW_SESSION) -> int:
"""
Check the Log table to see how many times a task instance has been
stuck in queued.
@@ -2726,7 +2726,7 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
previous_ti_metrics: dict[TaskInstanceState, dict[tuple[str, str, str],
int]] = {}
@provide_session
- def _emit_ti_metrics(self, session: Session = NEW_SESSION) -> None:
+ def _emit_ti_metrics(self, *, session: Session = NEW_SESSION) -> None:
metric_states = {State.SCHEDULED, State.QUEUED, State.RUNNING,
State.DEFERRED}
stmt = (
select(
@@ -2771,13 +2771,13 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
self.previous_ti_metrics[state] = ti_metrics
@provide_session
- def _emit_running_dags_metric(self, session: Session = NEW_SESSION) ->
None:
+ def _emit_running_dags_metric(self, *, session: Session = NEW_SESSION) ->
None:
stmt = select(func.count()).select_from(DagRun).where(DagRun.state ==
DagRunState.RUNNING)
running_dags = float(session.scalar(stmt) or 0)
stats.gauge("scheduler.dagruns.running", running_dags)
@provide_session
- def _emit_pool_metrics(self, session: Session = NEW_SESSION) -> None:
+ def _emit_pool_metrics(self, *, session: Session = NEW_SESSION) -> None:
from airflow.models.pool import Pool
pools = Pool.slots_stats(session=session)
@@ -2810,7 +2810,7 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
)
@provide_session
- def adopt_or_reset_orphaned_tasks(self, session: Session = NEW_SESSION) ->
int:
+ def adopt_or_reset_orphaned_tasks(self, *, session: Session = NEW_SESSION)
-> int:
"""
Adopt or reset any TaskInstance in resettable state if its
SchedulerJob is no longer running.
@@ -2912,7 +2912,7 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
@provide_session
def check_trigger_timeouts(
- self, max_retries: int = MAX_DB_RETRIES, session: Session = NEW_SESSION
+ self, max_retries: int = MAX_DB_RETRIES, *, session: Session =
NEW_SESSION
) -> None:
"""Mark any "deferred" task as failed if the trigger or execution
timeout has passed."""
for attempt in run_with_db_retries(max_retries, logger=self.log):
@@ -3092,7 +3092,7 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
)
@provide_session
- def _update_asset_orphanage(self, session: Session = NEW_SESSION) -> None:
+ def _update_asset_orphanage(self, *, session: Session = NEW_SESSION) ->
None:
"""
Check assets orphanization and update their active entry.
diff --git a/airflow-core/src/airflow/jobs/triggerer_job_runner.py
b/airflow-core/src/airflow/jobs/triggerer_job_runner.py
index 38f67a9d43e..e8e1f66e787 100644
--- a/airflow-core/src/airflow/jobs/triggerer_job_runner.py
+++ b/airflow-core/src/airflow/jobs/triggerer_job_runner.py
@@ -204,7 +204,7 @@ class TriggererJobRunner(BaseJobRunner, LoggingMixin):
@classmethod
@provide_session
- def is_needed(cls, session) -> bool:
+ def is_needed(cls, *, session: Session) -> bool:
"""
Test if the triggerer job needs to be run (i.e., if there are triggers
in the trigger table).
diff --git a/airflow-core/tests/unit/jobs/test_base_job.py
b/airflow-core/tests/unit/jobs/test_base_job.py
index aca8c6e7c82..a38956c61bd 100644
--- a/airflow-core/tests/unit/jobs/test_base_job.py
+++ b/airflow-core/tests/unit/jobs/test_base_job.py
@@ -268,7 +268,7 @@ class TestJob:
hb_callback = Mock()
job.heartbeat(heartbeat_callback=hb_callback)
- hb_callback.assert_called_once_with(ANY)
+ hb_callback.assert_called_once_with(session=ANY)
hb_callback.reset_mock()
perform_heartbeat(job=job, heartbeat_callback=hb_callback,
only_if_necessary=True)
diff --git a/airflow-core/tests/unit/jobs/test_scheduler_job.py
b/airflow-core/tests/unit/jobs/test_scheduler_job.py
index 05803f5fd5b..f4695e66206 100644
--- a/airflow-core/tests/unit/jobs/test_scheduler_job.py
+++ b/airflow-core/tests/unit/jobs/test_scheduler_job.py
@@ -102,7 +102,7 @@ from airflow.sdk.definitions.timetables.assets import
PartitionedAssetTimetable
from airflow.serialization.definitions.dag import SerializedDAG
from airflow.serialization.serialized_objects import LazyDeserializedDAG
from airflow.timetables.base import DagRunInfo, DataInterval
-from airflow.utils.session import create_session, provide_session
+from airflow.utils.session import NEW_SESSION, create_session, provide_session
from airflow.utils.sqlalchemy import with_row_locks
from airflow.utils.state import CallbackState, DagRunState, State,
TaskInstanceState
from airflow.utils.types import DagRunTriggeredByType, DagRunType
@@ -4800,7 +4800,7 @@ class TestSchedulerJob:
dag_maker.dag_model.calculate_dagrun_date_fields(dag,
last_automated_run=None)
@provide_session
- def do_schedule(session):
+ def do_schedule(*, session: Session = NEW_SESSION):
# Use a empty file since the above mock will return the
# expected DAGs. Also specify only a single file so that it doesn't
# try to schedule the above DAG repeatedly.
diff --git a/scripts/ci/prek/known_provide_session_positional.txt
b/scripts/ci/prek/known_provide_session_positional.txt
index d2ba9508cae..c6351fe5bd7 100644
--- a/scripts/ci/prek/known_provide_session_positional.txt
+++ b/scripts/ci/prek/known_provide_session_positional.txt
@@ -1,7 +1,3 @@
-airflow-core/src/airflow/jobs/base_job_runner.py::2
-airflow-core/src/airflow/jobs/job.py::7
-airflow-core/src/airflow/jobs/scheduler_job_runner.py::11
-airflow-core/src/airflow/jobs/triggerer_job_runner.py::1
airflow-core/src/airflow/models/connection.py::2
airflow-core/src/airflow/models/dag.py::7
airflow-core/src/airflow/models/dagcode.py::6
@@ -20,7 +16,6 @@ airflow-core/src/airflow/models/trigger.py::7
airflow-core/src/airflow/models/variable.py::2
airflow-core/src/airflow/secrets/metastore.py::2
airflow-core/src/airflow/serialization/definitions/dag.py::2
-airflow-core/tests/unit/jobs/test_scheduler_job.py::1
airflow-core/tests/unit/listeners/test_listeners.py::7
airflow-core/tests/unit/models/test_taskinstance.py::4
airflow-core/tests/unit/models/test_timestamp.py::2