This is an automated email from the ASF dual-hosted git repository.
jscheffl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 24b52c535ff Remove findings from positional session check in Core Dag
Modules (#67789)
24b52c535ff is described below
commit 24b52c535ff2b7633ce399883e20155aea24abd2
Author: Jens Scheffler <[email protected]>
AuthorDate: Sun May 31 13:25:58 2026 +0200
Remove findings from positional session check in Core Dag Modules (#67789)
* Fix exceptions for positional session use in airflow-core models/dag*
modules
* Fix mypy
---
.../src/airflow/jobs/scheduler_job_runner.py | 6 +-
airflow-core/src/airflow/models/dag.py | 13 +-
airflow-core/src/airflow/models/dagcode.py | 16 +-
airflow-core/src/airflow/models/dagrun.py | 32 +--
airflow-core/src/airflow/models/dagwarning.py | 2 +-
.../src/airflow/ti_deps/deps/prev_dagrun_dep.py | 2 +-
airflow-core/tests/unit/jobs/test_scheduler_job.py | 242 ++++++++++-----------
airflow-core/tests/unit/models/test_cleartasks.py | 2 +-
airflow-core/tests/unit/models/test_dagrun.py | 18 +-
airflow-core/tests/unit/models/test_dagwarning.py | 4 +-
.../providers/standard/sensors/external_task.py | 2 +-
.../ci/prek/known_provide_session_positional.txt | 4 -
12 files changed, 173 insertions(+), 170 deletions(-)
diff --git a/airflow-core/src/airflow/jobs/scheduler_job_runner.py
b/airflow-core/src/airflow/jobs/scheduler_job_runner.py
index 596946268a6..f61e20840ed 100644
--- a/airflow-core/src/airflow/jobs/scheduler_job_runner.py
+++ b/airflow-core/src/airflow/jobs/scheduler_job_runner.py
@@ -2391,7 +2391,7 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
callback: DagCallbackRequest | None = None
dag = dag_run.dag =
self.scheduler_dag_bag.get_dag_for_run(dag_run=dag_run, session=session)
- dag_model = DM.get_dagmodel(dag_run.dag_id, session)
+ dag_model = DM.get_dagmodel(dag_run.dag_id, session=session)
if not dag_model:
self.log.error("Couldn't find DAG model %s in database!",
dag_run.dag_id)
return callback
@@ -2498,7 +2498,7 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
for ti in schedulable_tis
],
)
- dag_run.schedule_tis(schedulable_tis, session,
max_tis_per_query=self.job.max_tis_per_query)
+ dag_run.schedule_tis(schedulable_tis, session=session,
max_tis_per_query=self.job.max_tis_per_query)
return callback_to_run
@@ -2514,7 +2514,7 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
if TYPE_CHECKING:
assert latest_dag_version
- if dag_run.check_version_id_exists_in_dr(latest_dag_version.id,
session):
+ if dag_run.check_version_id_exists_in_dr(latest_dag_version.id,
session=session):
self.log.debug("DAG %s not changed structure, skipping
dagrun.verify_integrity", dag_run.dag_id)
return True
# Refresh the DAG
diff --git a/airflow-core/src/airflow/models/dag.py
b/airflow-core/src/airflow/models/dag.py
index ca84b7047b4..98bcef4cf00 100644
--- a/airflow-core/src/airflow/models/dag.py
+++ b/airflow-core/src/airflow/models/dag.py
@@ -524,7 +524,7 @@ class DagModel(Base):
@staticmethod
@provide_session
- def get_dagmodel(dag_id: str, session: Session = NEW_SESSION) -> DagModel
| None:
+ def get_dagmodel(dag_id: str, *, session: Session = NEW_SESSION) ->
DagModel | None:
return session.get(
DagModel,
dag_id,
@@ -532,12 +532,12 @@ class DagModel(Base):
@classmethod
@provide_session
- def get_current(cls, dag_id: str, session: Session = NEW_SESSION) ->
DagModel | None:
+ def get_current(cls, dag_id: str, *, session: Session = NEW_SESSION) ->
DagModel | None:
return session.scalar(select(cls).where(cls.dag_id == dag_id))
@provide_session
def get_last_dagrun(
- self, session: Session = NEW_SESSION, include_manually_triggered: bool
= False
+ self, *, session: Session = NEW_SESSION, include_manually_triggered:
bool = False
) -> DagRun | None:
return get_last_dagrun(
self.dag_id, session=session,
include_manually_triggered=include_manually_triggered
@@ -549,7 +549,7 @@ class DagModel(Base):
@staticmethod
@provide_session
- def get_paused_dag_ids(dag_ids: list[str], session: Session = NEW_SESSION)
-> set[str]:
+ def get_paused_dag_ids(dag_ids: list[str], *, session: Session =
NEW_SESSION) -> set[str]:
"""
Given a list of dag_ids, get a set of Paused Dag Ids.
@@ -591,6 +591,7 @@ class DagModel(Base):
cls,
bundle_name: str,
rel_filelocs: Collection[str],
+ *,
session: Session = NEW_SESSION,
) -> bool:
"""
@@ -814,7 +815,7 @@ class DagModel(Base):
@staticmethod
@provide_session
- def get_team_name(dag_id: str, session: Session = NEW_SESSION) -> str |
None:
+ def get_team_name(dag_id: str, *, session: Session = NEW_SESSION) -> str |
None:
"""Return the team name associated to a Dag or None if it is not owned
by a specific team."""
stmt = (
select(Team.name)
@@ -827,7 +828,7 @@ class DagModel(Base):
@staticmethod
@provide_session
def get_dag_id_to_team_name_mapping(
- dag_ids: list[str], session: Session = NEW_SESSION
+ dag_ids: list[str], *, session: Session = NEW_SESSION
) -> dict[str, str | None]:
stmt = (
select(DagModel.dag_id, Team.name)
diff --git a/airflow-core/src/airflow/models/dagcode.py
b/airflow-core/src/airflow/models/dagcode.py
index 60ee91c8b59..d591ddb25b7 100644
--- a/airflow-core/src/airflow/models/dagcode.py
+++ b/airflow-core/src/airflow/models/dagcode.py
@@ -80,7 +80,7 @@ class DagCode(Base):
@classmethod
@provide_session
- def write_code(cls, dag_version: DagVersion, fileloc: str, session:
Session = NEW_SESSION) -> DagCode:
+ def write_code(cls, dag_version: DagVersion, fileloc: str, *, session:
Session = NEW_SESSION) -> DagCode:
"""
Write code into database.
@@ -95,7 +95,7 @@ class DagCode(Base):
@classmethod
@provide_session
- def has_dag(cls, dag_id: str, session: Session = NEW_SESSION) -> bool:
+ def has_dag(cls, dag_id: str, *, session: Session = NEW_SESSION) -> bool:
"""
Check a dag exists in dag code table.
@@ -109,13 +109,13 @@ class DagCode(Base):
@classmethod
@provide_session
- def code(cls, dag_id, session: Session = NEW_SESSION) -> str:
+ def code(cls, dag_id, *, session: Session = NEW_SESSION) -> str:
"""
Return source code for this DagCode object.
:return: source code as string
"""
- return cls._get_code_from_db(dag_id, session)
+ return cls._get_code_from_db(dag_id, session=session)
@staticmethod
def get_code_from_file(fileloc):
@@ -131,7 +131,7 @@ class DagCode(Base):
@classmethod
@provide_session
- def _get_code_from_db(cls, dag_id, session: Session = NEW_SESSION) -> str:
+ def _get_code_from_db(cls, dag_id, *, session: Session = NEW_SESSION) ->
str:
dag_code = session.scalar(
select(cls).where(cls.dag_id ==
dag_id).order_by(cls.last_updated.desc()).limit(1)
)
@@ -161,7 +161,7 @@ class DagCode(Base):
@classmethod
@provide_session
- def get_latest_dagcode(cls, dag_id: str, session: Session = NEW_SESSION)
-> DagCode | None:
+ def get_latest_dagcode(cls, dag_id: str, *, session: Session =
NEW_SESSION) -> DagCode | None:
"""
Get the latest dagcode.
@@ -173,7 +173,7 @@ class DagCode(Base):
@classmethod
@provide_session
- def update_source_code(cls, dag_id: str, fileloc: str, session: Session =
NEW_SESSION) -> None:
+ def update_source_code(cls, dag_id: str, fileloc: str, *, session: Session
= NEW_SESSION) -> None:
"""
Check if the source code of the DAG has changed and update it if
needed.
@@ -182,7 +182,7 @@ class DagCode(Base):
:param session: The database session.
:return: None
"""
- latest_dagcode = cls.get_latest_dagcode(dag_id, session)
+ latest_dagcode = cls.get_latest_dagcode(dag_id, session=session)
if not latest_dagcode:
return
new_source_code = cls.get_code_from_file(fileloc)
diff --git a/airflow-core/src/airflow/models/dagrun.py
b/airflow-core/src/airflow/models/dagrun.py
index 36ed309feb0..25f6397c074 100644
--- a/airflow-core/src/airflow/models/dagrun.py
+++ b/airflow-core/src/airflow/models/dagrun.py
@@ -467,7 +467,7 @@ class DagRun(Base, LoggingMixin):
@duration.expression # type: ignore[no-redef]
@provide_session
- def duration(cls, session: Session = NEW_SESSION) -> Case:
+ def duration(cls, *, session: Session = NEW_SESSION) -> Case:
dialect_name = get_dialect_name(session)
if dialect_name == "mysql":
return func.timestampdiff(text("SECOND"), cls.start_date,
cls.end_date)
@@ -486,7 +486,7 @@ class DagRun(Base, LoggingMixin):
return case(when_condition, else_=None)
@provide_session
- def check_version_id_exists_in_dr(self, dag_version_id: UUID, session:
Session = NEW_SESSION):
+ def check_version_id_exists_in_dr(self, dag_version_id: UUID, *, session:
Session = NEW_SESSION):
select_stmt = (
select(TI.dag_version_id)
.where(TI.dag_id == self.dag_id, TI.dag_version_id ==
dag_version_id, TI.run_id == self.run_id)
@@ -584,7 +584,7 @@ class DagRun(Base, LoggingMixin):
return synonym("_state", descriptor=property(self.get_state,
self.set_state))
@provide_session
- def refresh_from_db(self, session: Session = NEW_SESSION) -> None:
+ def refresh_from_db(self, *, session: Session = NEW_SESSION) -> None:
"""
Reload the current dagrun from the database.
@@ -602,7 +602,7 @@ class DagRun(Base, LoggingMixin):
cls,
*,
dag_ids: Iterable[str],
- exclude_backfill,
+ exclude_backfill: bool,
session: Session = NEW_SESSION,
) -> dict[str, int]:
"""
@@ -753,9 +753,10 @@ class DagRun(Base, LoggingMixin):
state: DagRunState | None = None,
no_backfills: bool = False,
run_type: DagRunType | None = None,
- session: Session = NEW_SESSION,
logical_start_date: datetime | None = None,
logical_end_date: datetime | None = None,
+ *,
+ session: Session = NEW_SESSION,
) -> list[DagRun]:
"""
Return a set of dag runs for the given search criteria.
@@ -836,6 +837,7 @@ class DagRun(Base, LoggingMixin):
run_id: str | None = None,
task_ids: list[str] | None = None,
state: Iterable[TaskInstanceState | None] | None = None,
+ *,
session: Session = NEW_SESSION,
) -> list[TI]:
"""Return the task instances for this dag run."""
@@ -916,6 +918,7 @@ class DagRun(Base, LoggingMixin):
def get_task_instances(
self,
state: Iterable[TaskInstanceState | None] | None = None,
+ *,
session: Session = NEW_SESSION,
) -> list[TI]:
"""
@@ -933,9 +936,9 @@ class DagRun(Base, LoggingMixin):
def get_task_instance(
self,
task_id: str,
- session: Session = NEW_SESSION,
*,
map_index: int = -1,
+ session: Session = NEW_SESSION,
) -> TI | None:
"""
Return the task instance specified by task_id for this dag run.
@@ -957,8 +960,9 @@ class DagRun(Base, LoggingMixin):
dag_id: str,
dag_run_id: str,
task_id: str,
- session: Session = NEW_SESSION,
+ *,
map_index: int = -1,
+ session: Session = NEW_SESSION,
) -> TI | None:
"""
Return the task instance specified by task_id for this dag run.
@@ -986,7 +990,7 @@ class DagRun(Base, LoggingMixin):
@staticmethod
@provide_session
def get_previous_dagrun(
- dag_run: DagRun, state: DagRunState | None = None, session: Session =
NEW_SESSION
+ dag_run: DagRun, state: DagRunState | None = None, *, session: Session
= NEW_SESSION
) -> DagRun | None:
"""
Return the previous DagRun, if there is one.
@@ -1009,6 +1013,7 @@ class DagRun(Base, LoggingMixin):
@provide_session
def get_previous_scheduled_dagrun(
dag_run_id: int,
+ *,
session: Session = NEW_SESSION,
) -> DagRun | None:
"""
@@ -1100,7 +1105,7 @@ class DagRun(Base, LoggingMixin):
@provide_session
def update_state(
- self, session: Session = NEW_SESSION, execute_callbacks: bool = True
+ self, *, session: Session = NEW_SESSION, execute_callbacks: bool = True
) -> tuple[list[TI], DagCallbackRequest | None]:
"""
Determine the overall state of the DagRun based on the state of its
TaskInstances.
@@ -1145,7 +1150,7 @@ class DagRun(Base, LoggingMixin):
tags=self.stats_tags,
):
dag = self.get_dag()
- info = self.task_instance_scheduling_decisions(session)
+ info = self.task_instance_scheduling_decisions(session=session)
tis = info.tis
schedulable_tis = info.schedulable_tis
@@ -1295,7 +1300,7 @@ class DagRun(Base, LoggingMixin):
return schedulable_tis, callback
@provide_session
- def task_instance_scheduling_decisions(self, session: Session =
NEW_SESSION) -> TISchedulingDecision:
+ def task_instance_scheduling_decisions(self, *, session: Session =
NEW_SESSION) -> TISchedulingDecision:
tis = self.get_task_instances(session=session, state=State.task_states)
self.log.debug("number of tis tasks for %s: %s task(s)", self,
len(tis))
@@ -1967,7 +1972,7 @@ class DagRun(Base, LoggingMixin):
@classmethod
@provide_session
- def get_latest_runs(cls, session: Session = NEW_SESSION) -> list[DagRun]:
+ def get_latest_runs(cls, *, session: Session = NEW_SESSION) ->
list[DagRun]:
"""Return the latest DagRun for each DAG."""
subquery = (
select(cls.dag_id,
func.max(cls.logical_date).label("logical_date"))
@@ -1987,6 +1992,7 @@ class DagRun(Base, LoggingMixin):
def schedule_tis(
self,
schedulable_tis: Iterable[TI],
+ *,
session: Session = NEW_SESSION,
max_tis_per_query: int | None = None,
) -> int:
@@ -2126,7 +2132,7 @@ class DagRun(Base, LoggingMixin):
@staticmethod
@provide_session
- def _get_log_template(log_template_id: int | None, session: Session =
NEW_SESSION) -> LogTemplate:
+ def _get_log_template(log_template_id: int | None, *, session: Session =
NEW_SESSION) -> LogTemplate:
template: LogTemplate | None
if log_template_id is None: # DagRun created before LogTemplate
introduction.
template =
session.scalar(select(LogTemplate).order_by(LogTemplate.id).limit(1))
diff --git a/airflow-core/src/airflow/models/dagwarning.py
b/airflow-core/src/airflow/models/dagwarning.py
index 246eacee51f..d411a3b9386 100644
--- a/airflow-core/src/airflow/models/dagwarning.py
+++ b/airflow-core/src/airflow/models/dagwarning.py
@@ -77,7 +77,7 @@ class DagWarning(Base):
@classmethod
@provide_session
@retry_db_transaction
- def purge_inactive_dag_warnings(cls, session: Session = NEW_SESSION) ->
None:
+ def purge_inactive_dag_warnings(cls, *, session: Session = NEW_SESSION) ->
None:
"""
Deactivate DagWarning records for inactive dags.
diff --git a/airflow-core/src/airflow/ti_deps/deps/prev_dagrun_dep.py
b/airflow-core/src/airflow/ti_deps/deps/prev_dagrun_dep.py
index 397746d13a1..a1bf124be39 100644
--- a/airflow-core/src/airflow/ti_deps/deps/prev_dagrun_dep.py
+++ b/airflow-core/src/airflow/ti_deps/deps/prev_dagrun_dep.py
@@ -160,7 +160,7 @@ class PrevDagrunDep(BaseTIDep):
# Don't depend on the previous task instance if we are the first task.
catchup = ti.task.dag and ti.task.dag.catchup
if catchup:
- last_dagrun = DagRun.get_previous_scheduled_dagrun(dr.id, session)
+ last_dagrun = DagRun.get_previous_scheduled_dagrun(dr.id,
session=session)
else:
last_dagrun = DagRun.get_previous_dagrun(dr, session=session)
# First ever run for this DAG.
diff --git a/airflow-core/tests/unit/jobs/test_scheduler_job.py
b/airflow-core/tests/unit/jobs/test_scheduler_job.py
index f4695e66206..872d133c122 100644
--- a/airflow-core/tests/unit/jobs/test_scheduler_job.py
+++ b/airflow-core/tests/unit/jobs/test_scheduler_job.py
@@ -251,7 +251,7 @@ def task_maker(
tis_list = []
for i in range(task_num):
- ti = dag_run.get_task_instance(dag_tasks[f"op{i}"].task_id, session)
+ ti = dag_run.get_task_instance(dag_tasks[f"op{i}"].task_id,
session=session)
# e.g.
# If running_num is 2, then for i=0 and i=1, state will be RUNNING.
# If running_num is 0, then state will be SCHEDULED for all.
@@ -414,7 +414,7 @@ class TestSchedulerJob:
with dag_maker(dag_id="test_only_idle_one_task",
fileloc="test_only_idle_one_task.py"):
EmptyOperator(task_id="dummy")
dr = dag_maker.create_dagrun(run_type=DagRunType.SCHEDULED,
state=State.RUNNING)
- ti = dr.get_task_instance("dummy", session)
+ ti = dr.get_task_instance("dummy", session=session)
ti.state = State.SCHEDULED
session.merge(ti)
session.commit()
@@ -1495,11 +1495,11 @@ class TestSchedulerJob:
dag_run = dag_maker.create_dagrun(run_type=DagRunType.SCHEDULED)
- ti1 = dag_run.get_task_instance(op1.task_id, session)
- ti2 = dag_run.get_task_instance(op2.task_id, session)
- ti3 = dag_run.get_task_instance(op3.task_id, session)
- ti4 = dag_run.get_task_instance(op4.task_id, session)
- ti5 = dag_run.get_task_instance(op5.task_id, session)
+ ti1 = dag_run.get_task_instance(op1.task_id, session=session)
+ ti2 = dag_run.get_task_instance(op2.task_id, session=session)
+ ti3 = dag_run.get_task_instance(op3.task_id, session=session)
+ ti4 = dag_run.get_task_instance(op4.task_id, session=session)
+ ti5 = dag_run.get_task_instance(op5.task_id, session=session)
tis_tuple = (ti1, ti2, ti3, ti4, ti5)
for ti in tis_tuple:
@@ -1555,11 +1555,11 @@ class TestSchedulerJob:
dr3 = dag_maker.create_dagrun()
tis = [
- dr1.get_task_instance(op1.task_id, session),
- dr1.get_task_instance(op2.task_id, session),
- dr2.get_task_instance(op3.task_id, session),
- dr2.get_task_instance(op4.task_id, session),
- dr3.get_task_instance(op5.task_id, session),
+ dr1.get_task_instance(op1.task_id, session=session),
+ dr1.get_task_instance(op2.task_id, session=session),
+ dr2.get_task_instance(op3.task_id, session=session),
+ dr2.get_task_instance(op4.task_id, session=session),
+ dr3.get_task_instance(op5.task_id, session=session),
]
for ti in tis:
@@ -1849,9 +1849,9 @@ class TestSchedulerJob:
dag_run = dag_maker.create_dagrun(run_type=DagRunType.SCHEDULED)
- ti1 = dag_run.get_task_instance(op1.task_id, session)
- ti2 = dag_run.get_task_instance(op2.task_id, session)
- ti3 = dag_run.get_task_instance(op3.task_id, session)
+ ti1 = dag_run.get_task_instance(op1.task_id, session=session)
+ ti2 = dag_run.get_task_instance(op2.task_id, session=session)
+ ti3 = dag_run.get_task_instance(op3.task_id, session=session)
ti1.state = State.SCHEDULED
ti2.state = State.SCHEDULED
@@ -1910,8 +1910,8 @@ class TestSchedulerJob:
dr1 = dag_maker.create_dagrun(run_type=DagRunType.SCHEDULED)
dr2 = dag_maker.create_dagrun_after(dr1,
run_type=DagRunType.SCHEDULED, state=State.RUNNING)
- ti1 = dr1.get_task_instance(op1.task_id, session)
- ti2 = dr2.get_task_instance(op2.task_id, session)
+ ti1 = dr1.get_task_instance(op1.task_id, session=session)
+ ti2 = dr2.get_task_instance(op2.task_id, session=session)
ti1.state = State.SCHEDULED
ti2.state = State.SCHEDULED
@@ -2272,12 +2272,12 @@ class TestSchedulerJob:
)
# DR1's TI is deferred (waiting on a trigger)
- ti1 = dr1.get_task_instance(task1.task_id, session)
+ ti1 = dr1.get_task_instance(task1.task_id, session=session)
ti1.state = TaskInstanceState.DEFERRED
session.merge(ti1)
# DR2's TI is scheduled and wants to run
- ti2 = dr2.get_task_instance(task1.task_id, session)
+ ti2 = dr2.get_task_instance(task1.task_id, session=session)
ti2.state = State.SCHEDULED
session.merge(ti2)
session.flush()
@@ -2311,15 +2311,15 @@ class TestSchedulerJob:
dr2, run_type=DagRunType.SCHEDULED, run_id="run_3", session=session
)
- ti1 = dr1.get_task_instance(task1.task_id, session)
+ ti1 = dr1.get_task_instance(task1.task_id, session=session)
ti1.state = TaskInstanceState.RUNNING
session.merge(ti1)
- ti2 = dr2.get_task_instance(task1.task_id, session)
+ ti2 = dr2.get_task_instance(task1.task_id, session=session)
ti2.state = TaskInstanceState.DEFERRED
session.merge(ti2)
- ti3 = dr3.get_task_instance(task1.task_id, session)
+ ti3 = dr3.get_task_instance(task1.task_id, session=session)
ti3.state = State.SCHEDULED
session.merge(ti3)
session.flush()
@@ -2351,15 +2351,15 @@ class TestSchedulerJob:
dr2, run_type=DagRunType.SCHEDULED, run_id="run_3", session=session
)
- ti1 = dr1.get_task_instance(task1.task_id, session)
+ ti1 = dr1.get_task_instance(task1.task_id, session=session)
ti1.state = TaskInstanceState.DEFERRED
session.merge(ti1)
- ti2 = dr2.get_task_instance(task1.task_id, session)
+ ti2 = dr2.get_task_instance(task1.task_id, session=session)
ti2.state = State.SCHEDULED
session.merge(ti2)
- ti3 = dr3.get_task_instance(task1.task_id, session)
+ ti3 = dr3.get_task_instance(task1.task_id, session=session)
ti3.state = State.SCHEDULED
session.merge(ti3)
session.flush()
@@ -2389,17 +2389,17 @@ class TestSchedulerJob:
)
# task_a in DR1 is deferred
- ti_a1 = dr1.get_task_instance(task_a.task_id, session)
+ ti_a1 = dr1.get_task_instance(task_a.task_id, session=session)
ti_a1.state = TaskInstanceState.DEFERRED
session.merge(ti_a1)
# task_a in DR2 is scheduled (should be blocked by deferred ti_a1)
- ti_a2 = dr2.get_task_instance(task_a.task_id, session)
+ ti_a2 = dr2.get_task_instance(task_a.task_id, session=session)
ti_a2.state = State.SCHEDULED
session.merge(ti_a2)
# task_b in DR1 is scheduled (should NOT be blocked)
- ti_b1 = dr1.get_task_instance(task_b.task_id, session)
+ ti_b1 = dr1.get_task_instance(task_b.task_id, session=session)
ti_b1.state = State.SCHEDULED
session.merge(ti_b1)
session.flush()
@@ -2427,8 +2427,8 @@ class TestSchedulerJob:
dr1, run_type=DagRunType.SCHEDULED, run_id="run_2", session=session
)
- ti1 = dr1.get_task_instance(task1.task_id, session)
- ti2 = dr2.get_task_instance(task1.task_id, session)
+ ti1 = dr1.get_task_instance(task1.task_id, session=session)
+ ti2 = dr2.get_task_instance(task1.task_id, session=session)
# Step 1: ti1 is deferred, ti2 scheduled -> ti2 blocked
ti1.state = TaskInstanceState.DEFERRED
@@ -2469,8 +2469,8 @@ class TestSchedulerJob:
dr = dag_maker.create_dagrun(run_type=DagRunType.SCHEDULED,
session=session)
- ti_a0 = dr.get_task_instance(task_a.task_id, session, map_index=0)
- ti_a1 = dr.get_task_instance(task_a.task_id, session, map_index=1)
+ ti_a0 = dr.get_task_instance(task_a.task_id, session=session,
map_index=0)
+ ti_a1 = dr.get_task_instance(task_a.task_id, session=session,
map_index=1)
ti_a0.state = TaskInstanceState.DEFERRED
ti_a1.state = State.SCHEDULED
@@ -2564,8 +2564,8 @@ class TestSchedulerJob:
dr1 = dag_maker.create_dagrun(run_type=DagRunType.SCHEDULED)
- ti1 = dr1.get_task_instance(op1.task_id, session)
- ti2 = dr1.get_task_instance(op2.task_id, session)
+ ti1 = dr1.get_task_instance(op1.task_id, session=session)
+ ti2 = dr1.get_task_instance(op2.task_id, session=session)
ti1.state = State.SCHEDULED
ti2.state = State.SCHEDULED
session.flush()
@@ -2599,9 +2599,9 @@ class TestSchedulerJob:
op2 = EmptyOperator(task_id="dummy2", priority_weight=1)
dr2 = dag_maker.create_dagrun(run_type=DagRunType.SCHEDULED)
- ti1a = dr1.get_task_instance(op1a.task_id, session)
- ti1b = dr1.get_task_instance(op1b.task_id, session)
- ti2 = dr2.get_task_instance(op2.task_id, session)
+ ti1a = dr1.get_task_instance(op1a.task_id, session=session)
+ ti1b = dr1.get_task_instance(op1b.task_id, session=session)
+ ti2 = dr2.get_task_instance(op2.task_id, session=session)
ti1a.state = State.RUNNING
ti1b.state = State.SCHEDULED
ti2.state = State.SCHEDULED
@@ -2628,9 +2628,9 @@ class TestSchedulerJob:
dr1 = dag_maker.create_dagrun(run_type=DagRunType.SCHEDULED)
dr2 = dag_maker.create_dagrun_after(dr1, run_type=DagRunType.SCHEDULED)
- ti1a = dr1.get_task_instance(op1a.task_id, session)
- ti1b = dr1.get_task_instance(op1b.task_id, session)
- ti2a = dr2.get_task_instance(op1a.task_id, session)
+ ti1a = dr1.get_task_instance(op1a.task_id, session=session)
+ ti1b = dr1.get_task_instance(op1b.task_id, session=session)
+ ti2a = dr2.get_task_instance(op1a.task_id, session=session)
ti1a.state = State.RUNNING
ti1b.state = State.SCHEDULED
ti2a.state = State.SCHEDULED
@@ -2657,9 +2657,9 @@ class TestSchedulerJob:
dr1 = dag_maker.create_dagrun(run_type=DagRunType.SCHEDULED)
dr2 = dag_maker.create_dagrun_after(dr1, run_type=DagRunType.SCHEDULED)
- ti1a = dr1.get_task_instance(op1a.task_id, session)
- ti1b = dr1.get_task_instance(op1b.task_id, session)
- ti2a = dr2.get_task_instance(op1a.task_id, session)
+ ti1a = dr1.get_task_instance(op1a.task_id, session=session)
+ ti1b = dr1.get_task_instance(op1b.task_id, session=session)
+ ti2a = dr2.get_task_instance(op1a.task_id, session=session)
ti1a.state = State.RUNNING
ti1b.state = State.SCHEDULED
ti2a.state = State.SCHEDULED
@@ -2690,9 +2690,9 @@ class TestSchedulerJob:
op1b = EmptyOperator(task_id="dummy1-b", priority_weight=1)
dr = dag_maker.create_dagrun(run_type=DagRunType.SCHEDULED)
- ti1a0 = dr.get_task_instance(op1a.task_id, session, map_index=0)
- ti1a1 = dr.get_task_instance(op1a.task_id, session, map_index=1)
- ti1b = dr.get_task_instance(op1b.task_id, session)
+ ti1a0 = dr.get_task_instance(op1a.task_id, session=session,
map_index=0)
+ ti1a1 = dr.get_task_instance(op1a.task_id, session=session,
map_index=1)
+ ti1b = dr.get_task_instance(op1b.task_id, session=session)
ti1a0.state = State.RUNNING
ti1a1.state = State.SCHEDULED
ti1b.state = State.SCHEDULED
@@ -2731,8 +2731,8 @@ class TestSchedulerJob:
dr1 = dag_maker.create_dagrun(run_type=DagRunType.SCHEDULED)
- ti1 = dr1.get_task_instance(op1.task_id, session)
- ti2 = dr1.get_task_instance(op2.task_id, session)
+ ti1 = dr1.get_task_instance(op1.task_id, session=session)
+ ti2 = dr1.get_task_instance(op2.task_id, session=session)
ti1.state = State.SCHEDULED
ti2.state = State.RUNNING
session.flush()
@@ -2757,7 +2757,7 @@ class TestSchedulerJob:
dr = dag_maker.create_dagrun(run_type=DagRunType.SCHEDULED)
- ti = dr.get_task_instance(op.task_id, session)
+ ti = dr.get_task_instance(op.task_id, session=session)
ti.state = State.SCHEDULED
set_default_pool_slots(1)
@@ -2805,7 +2805,7 @@ class TestSchedulerJob:
self.job_runner = SchedulerJobRunner(job=scheduler_job,
executors=[self.null_exec])
dr1 = dag_maker.create_dagrun()
- ti1 = dr1.get_task_instance(task1.task_id, session)
+ ti1 = dr1.get_task_instance(task1.task_id, session=session)
with patch.object(BaseExecutor, "queue_workload") as
mock_queue_workload:
self.job_runner._enqueue_task_instances_with_queued_state(
@@ -2836,8 +2836,8 @@ class TestSchedulerJob:
self.job_runner = SchedulerJobRunner(job=Job(),
executors=(regular_exec, pre_assigning_exec))
dr = dag_maker.create_dagrun()
- ti_pre_assign = dr.get_task_instance("a_task_pre_assign", session)
- ti_regular = dr.get_task_instance("b_task_regular", session)
+ ti_pre_assign = dr.get_task_instance("a_task_pre_assign",
session=session)
+ ti_regular = dr.get_task_instance("b_task_regular", session=session)
ti_regular.state = State.SCHEDULED
ti_regular.executor = regular_exec.name.module_path
@@ -2880,7 +2880,7 @@ class TestSchedulerJob:
self.job_runner = SchedulerJobRunner(job=scheduler_job)
dr1 = dag_maker.create_dagrun(state=state)
- ti = dr1.get_task_instance(task1.task_id, session)
+ ti = dr1.get_task_instance(task1.task_id, session=session)
ti.state = State.SCHEDULED
session.merge(ti)
session.commit()
@@ -2906,7 +2906,7 @@ class TestSchedulerJob:
self.job_runner = SchedulerJobRunner(job=scheduler_job,
executors=[self.null_exec])
dr1 = dag_maker.create_dagrun()
- ti = dr1.get_task_instance(task1.task_id, session)
+ ti = dr1.get_task_instance(task1.task_id, session=session)
ti.state = State.SCHEDULED
ti.dag_version_id = None
session.merge(ti)
@@ -2952,10 +2952,10 @@ class TestSchedulerJob:
dr1 = dag_maker.create_dagrun(run_type=DagRunType.SCHEDULED,
session=session)
- dr1_ti1 = dr1.get_task_instance(task1.task_id, session)
- dr1_ti2 = dr1.get_task_instance(task2.task_id, session)
- dr1_ti3 = dr1.get_task_instance(task3.task_id, session)
- dr1_ti4 = dr1.get_task_instance(task4.task_id, session)
+ dr1_ti1 = dr1.get_task_instance(task1.task_id, session=session)
+ dr1_ti2 = dr1.get_task_instance(task2.task_id, session=session)
+ dr1_ti3 = dr1.get_task_instance(task3.task_id, session=session)
+ dr1_ti4 = dr1.get_task_instance(task4.task_id, session=session)
dr1_ti1.state = State.RUNNING
dr1_ti2.state = State.RUNNING
dr1_ti3.state = State.RUNNING
@@ -2975,10 +2975,10 @@ class TestSchedulerJob:
# create second dag run
dr2 = dag_maker.create_dagrun_after(dr1,
run_type=DagRunType.SCHEDULED, session=session)
- dr2_ti1 = dr2.get_task_instance(task1.task_id, session)
- dr2_ti2 = dr2.get_task_instance(task2.task_id, session)
- dr2_ti3 = dr2.get_task_instance(task3.task_id, session)
- dr2_ti4 = dr2.get_task_instance(task4.task_id, session)
+ dr2_ti1 = dr2.get_task_instance(task1.task_id, session=session)
+ dr2_ti2 = dr2.get_task_instance(task2.task_id, session=session)
+ dr2_ti3 = dr2.get_task_instance(task3.task_id, session=session)
+ dr2_ti4 = dr2.get_task_instance(task4.task_id, session=session)
# manually set to scheduled so we can pick them up
dr2_ti1.state = State.SCHEDULED
dr2_ti2.state = State.SCHEDULED
@@ -3039,9 +3039,9 @@ class TestSchedulerJob:
tis1 = []
tis2 = []
for dr in _create_dagruns():
- ti1 = dr.get_task_instance(task1.task_id, session)
+ ti1 = dr.get_task_instance(task1.task_id, session=session)
tis1.append(ti1)
- ti2 = dr.get_task_instance(task2.task_id, session)
+ ti2 = dr.get_task_instance(task2.task_id, session=session)
tis2.append(ti2)
ti1.state = State.SCHEDULED
ti2.state = State.SCHEDULED
@@ -3104,9 +3104,9 @@ class TestSchedulerJob:
tis = []
for dr in _create_dagruns():
- ti1 = dr.get_task_instance(task1.task_id, session)
+ ti1 = dr.get_task_instance(task1.task_id, session=session)
tis.append(ti1)
- ti2 = dr.get_task_instance(task2.task_id, session)
+ ti2 = dr.get_task_instance(task2.task_id, session=session)
tis.append(ti2)
ti1.state = State.SCHEDULED
ti2.state = State.SCHEDULED
@@ -3145,9 +3145,9 @@ class TestSchedulerJob:
tis = []
for dr in _create_dagruns():
- ti1 = dr.get_task_instance(task1.task_id, session)
+ ti1 = dr.get_task_instance(task1.task_id, session=session)
tis.append(ti1)
- ti2 = dr.get_task_instance(task2.task_id, session)
+ ti2 = dr.get_task_instance(task2.task_id, session=session)
tis.append(ti2)
ti1.state = State.SCHEDULED
ti2.state = State.SCHEDULED
@@ -3192,8 +3192,8 @@ class TestSchedulerJob:
yield dagrun
for dr in _create_dagruns():
- ti1 = dr.get_task_instance(task1.task_id, session)
- ti2 = dr.get_task_instance(task2.task_id, session)
+ ti1 = dr.get_task_instance(task1.task_id, session=session)
+ ti2 = dr.get_task_instance(task2.task_id, session=session)
ti1.state = State.SCHEDULED
ti2.state = State.SCHEDULED
session.flush()
@@ -3245,8 +3245,8 @@ class TestSchedulerJob:
yield dagrun
for dr in _create_dagruns():
- ti1 = dr.get_task_instance(task1.task_id, session)
- ti2 = dr.get_task_instance(task2.task_id, session)
+ ti1 = dr.get_task_instance(task1.task_id, session=session)
+ ti2 = dr.get_task_instance(task2.task_id, session=session)
ti1.state = State.SCHEDULED
ti2.state = State.SCHEDULED
session.flush()
@@ -3793,7 +3793,7 @@ class TestSchedulerJob:
bundle_version=orm_dag.bundle_version,
context_from_server=DagRunContext(
dag_run=dr,
- last_ti=dr.get_task_instance("dummy", session),
+ last_ti=dr.get_task_instance("dummy", session=session),
),
msg="timed_out",
)
@@ -3891,8 +3891,8 @@ class TestSchedulerJob:
self.job_runner = SchedulerJobRunner(job=scheduler_job,
executors=[self.null_exec])
dr = dag_maker.create_dagrun()
- ti = dr.get_task_instance("dummy", session)
- ti.set_state(state, session)
+ ti = dr.get_task_instance("dummy", session=session)
+ ti.set_state(state, session=session)
session.flush()
self.job_runner._do_scheduling(session)
@@ -3942,8 +3942,8 @@ class TestSchedulerJob:
dr = dag_maker.create_dagrun()
- ti = dr.get_task_instance("dummy", session)
- ti.set_state(state, session)
+ ti = dr.get_task_instance("dummy", session=session)
+ ti.set_state(state, session=session)
self.job_runner._do_scheduling(session)
@@ -3989,7 +3989,7 @@ class TestSchedulerJob:
bundle_version=None,
context_from_server=DagRunContext(
dag_run=dr,
- last_ti=dr.get_task_instance("empty", session),
+ last_ti=dr.get_task_instance("empty", session=session),
),
)
@@ -4058,8 +4058,8 @@ class TestSchedulerJob:
self.job_runner._send_dag_callbacks_to_processor = mock.Mock()
dr = dag_maker.create_dagrun()
- ti = dr.get_task_instance("test_task", session)
- ti.set_state(state, session)
+ ti = dr.get_task_instance("test_task", session=session)
+ ti.set_state(state, session=session)
self.job_runner._do_scheduling(session)
@@ -5818,7 +5818,7 @@ class TestSchedulerJob:
triggered_by=DagRunTriggeredByType.TEST,
)
- run1_ti = run1.get_task_instance(task1.task_id, session)
+ run1_ti = run1.get_task_instance(task1.task_id, session=session)
run1_ti.state = State.RUNNING
logical_date_2 = DEFAULT_DATE + timedelta(seconds=10)
@@ -5853,7 +5853,7 @@ class TestSchedulerJob:
assert run2.state == State.RUNNING
self.job_runner._schedule_dag_run(run2, session)
session.expunge_all()
- run2_ti = run2.get_task_instance(task1.task_id, session)
+ run2_ti = run2.get_task_instance(task1.task_id, session=session)
assert run2_ti.state == State.SCHEDULED
def test_do_schedule_max_active_runs_task_removed(self, session,
dag_maker):
@@ -5922,7 +5922,7 @@ class TestSchedulerJob:
# set dagrun to success
dr = session.scalars(select(DagRun)).one()
dr.state = DagRunState.SUCCESS
- ti = dr.get_task_instance("task", session)
+ ti = dr.get_task_instance("task", session=session)
ti.state = TaskInstanceState.SUCCESS
session.merge(ti)
session.merge(dr)
@@ -7226,8 +7226,8 @@ class TestSchedulerJob:
dr2 = dag_maker.create_dagrun(
run_id="test2", logical_date=DEFAULT_DATE +
datetime.timedelta(seconds=1)
)
- ti1 = dr1.get_task_instance("dummy1", session)
- ti2 = dr2.get_task_instance("dummy1", session)
+ ti1 = dr1.get_task_instance("dummy1", session=session)
+ ti2 = dr2.get_task_instance("dummy1", session=session)
ti1.state = State.DEFERRED
ti1.trigger_timeout = timezone.utcnow() -
datetime.timedelta(seconds=60)
ti2.state = State.DEFERRED
@@ -7292,8 +7292,8 @@ class TestSchedulerJob:
dr2 = dag_maker.create_dagrun(
run_id="test2", logical_date=DEFAULT_DATE +
datetime.timedelta(seconds=1)
)
- ti1 = dr1.get_task_instance("dummy1", session)
- ti2 = dr2.get_task_instance("dummy1", session)
+ ti1 = dr1.get_task_instance("dummy1", session=session)
+ ti2 = dr2.get_task_instance("dummy1", session=session)
ti1.state = State.DEFERRED
ti1.trigger_timeout = timezone.utcnow() -
datetime.timedelta(seconds=60)
ti2.state = State.DEFERRED
@@ -7700,7 +7700,7 @@ class TestSchedulerJob:
scheduled_run.last_scheduling_decision =
datetime.datetime.now(timezone.utc) - timedelta(minutes=1)
ti = scheduled_run.get_task_instances(session=session)[0]
ti.set_state(TaskInstanceState.RUNNING)
- dm = DagModel.get_dagmodel(dag.dag_id, session)
+ dm = DagModel.get_dagmodel(dag.dag_id, session=session)
dm.is_paused = True
session.flush()
@@ -8650,12 +8650,12 @@ class TestSchedulerJob:
task = EmptyOperator(task_id="task_a")
dr = dag_maker.create_dagrun()
- ti = dr.get_task_instance(task.task_id, session)
+ ti = dr.get_task_instance(task.task_id, session=session)
scheduler_job = Job()
self.job_runner = SchedulerJobRunner(job=scheduler_job)
- result = self.job_runner._get_workload_team_name(ti, session)
+ result = self.job_runner._get_workload_team_name(ti, session=session)
assert result == "team_a"
def test_multi_team_get_workload_team_name_no_team(self, dag_maker,
session):
@@ -8664,12 +8664,12 @@ class TestSchedulerJob:
task = EmptyOperator(task_id="task_no_team")
dr = dag_maker.create_dagrun()
- ti = dr.get_task_instance(task.task_id, session)
+ ti = dr.get_task_instance(task.task_id, session=session)
scheduler_job = Job()
self.job_runner = SchedulerJobRunner(job=scheduler_job)
- result = self.job_runner._get_workload_team_name(ti, session)
+ result = self.job_runner._get_workload_team_name(ti, session=session)
assert result is None
def test_multi_team_get_workload_team_name_database_error(self, dag_maker,
session):
@@ -8678,14 +8678,14 @@ class TestSchedulerJob:
task = EmptyOperator(task_id="task_test")
dr = dag_maker.create_dagrun()
- ti = dr.get_task_instance(task.task_id, session)
+ ti = dr.get_task_instance(task.task_id, session=session)
scheduler_job = Job()
self.job_runner = SchedulerJobRunner(job=scheduler_job)
# Mock _get_team_names_for_dag_ids to return empty dict (simulates
database error handling in that function)
with mock.patch.object(self.job_runner, "_get_team_names_for_dag_ids",
return_value={}) as mock_batch:
- result = self.job_runner._get_workload_team_name(ti, session)
+ result = self.job_runner._get_workload_team_name(ti,
session=session)
mock_batch.assert_called_once_with([ti.dag_id], session)
# Should return None when batch function returns empty dict
@@ -8698,13 +8698,13 @@ class TestSchedulerJob:
task = EmptyOperator(task_id="test_task",
executor="secondary_exec")
dr = dag_maker.create_dagrun()
- ti = dr.get_task_instance(task.task_id, session)
+ ti = dr.get_task_instance(task.task_id, session=session)
scheduler_job = Job()
self.job_runner = SchedulerJobRunner(job=scheduler_job)
with mock.patch.object(self.job_runner, "_get_workload_team_name") as
mock_team_resolve:
- result = self.job_runner._try_to_load_executor(ti, session)
+ result = self.job_runner._try_to_load_executor(ti, session=session)
# Should not call team resolution when multi_team is disabled
mock_team_resolve.assert_not_called()
@@ -8719,12 +8719,12 @@ class TestSchedulerJob:
task = EmptyOperator(task_id="test_task") # No explicit executor
dr = dag_maker.create_dagrun()
- ti = dr.get_task_instance(task.task_id, session)
+ ti = dr.get_task_instance(task.task_id, session=session)
scheduler_job = Job()
self.job_runner = SchedulerJobRunner(job=scheduler_job)
- result = self.job_runner._try_to_load_executor(ti, session)
+ result = self.job_runner._try_to_load_executor(ti, session=session)
# Should return the global default executor (first executor in Job)
assert result == self.job_runner.executor
@@ -8753,12 +8753,12 @@ class TestSchedulerJob:
task = EmptyOperator(task_id="test_task") # No explicit executor
dr = dag_maker.create_dagrun()
- ti = dr.get_task_instance(task.task_id, session)
+ ti = dr.get_task_instance(task.task_id, session=session)
scheduler_job = Job()
self.job_runner = SchedulerJobRunner(job=scheduler_job)
- result = self.job_runner._try_to_load_executor(ti, session)
+ result = self.job_runner._try_to_load_executor(ti, session=session)
# Should return the team-specific default executor set above
assert result == mock_executors[1]
@@ -8785,12 +8785,12 @@ class TestSchedulerJob:
task = EmptyOperator(task_id="test_task") # No explicit executor
dr = dag_maker.create_dagrun()
- ti = dr.get_task_instance(task.task_id, session)
+ ti = dr.get_task_instance(task.task_id, session=session)
scheduler_job = Job()
self.job_runner = SchedulerJobRunner(job=scheduler_job)
- result = self.job_runner._try_to_load_executor(ti, session)
+ result = self.job_runner._try_to_load_executor(ti, session=session)
# Should return the team-specific default executor set above
assert result == mock_executors[0]
@@ -8819,12 +8819,12 @@ class TestSchedulerJob:
task = EmptyOperator(task_id="test_task",
executor="secondary_exec")
dr = dag_maker.create_dagrun()
- ti = dr.get_task_instance(task.task_id, session)
+ ti = dr.get_task_instance(task.task_id, session=session)
scheduler_job = Job()
self.job_runner = SchedulerJobRunner(job=scheduler_job)
- result = self.job_runner._try_to_load_executor(ti, session)
+ result = self.job_runner._try_to_load_executor(ti, session=session)
# Should return the team-specific executor that matches the explicit
executor name
assert result == mock_executors[1]
@@ -8853,12 +8853,12 @@ class TestSchedulerJob:
task = EmptyOperator(task_id="test_task", executor="default_exec")
# Global executor
dr = dag_maker.create_dagrun()
- ti = dr.get_task_instance(task.task_id, session)
+ ti = dr.get_task_instance(task.task_id, session=session)
scheduler_job = Job()
self.job_runner = SchedulerJobRunner(job=scheduler_job)
- result = self.job_runner._try_to_load_executor(ti, session)
+ result = self.job_runner._try_to_load_executor(ti, session=session)
# Should return the global executor (default) even though task has a
team
assert result == mock_executors[0]
@@ -8890,7 +8890,7 @@ class TestSchedulerJob:
) # Executor for different team
dr = dag_maker.create_dagrun()
- ti = dr.get_task_instance(task.task_id, session)
+ ti = dr.get_task_instance(task.task_id, session=session)
scheduler_job = Job()
self.job_runner = SchedulerJobRunner(job=scheduler_job)
@@ -8915,7 +8915,7 @@ class TestSchedulerJob:
task = EmptyOperator(task_id="test_task",
executor="nonexistent_executor")
dr = dag_maker.create_dagrun()
- ti = dr.get_task_instance(task.task_id, session)
+ ti = dr.get_task_instance(task.task_id, session=session)
scheduler_job = Job()
self.job_runner = SchedulerJobRunner(job=scheduler_job)
@@ -8947,14 +8947,14 @@ class TestSchedulerJob:
task = EmptyOperator(task_id="test_task")
dr = dag_maker.create_dagrun()
- ti = dr.get_task_instance(task.task_id, session)
+ ti = dr.get_task_instance(task.task_id, session=session)
scheduler_job = Job()
self.job_runner = SchedulerJobRunner(job=scheduler_job)
# Call with pre-resolved team name (as done in the scheduling loop)
with mock.patch.object(self.job_runner, "_get_workload_team_name") as
mock_team_resolve:
- result = self.job_runner._try_to_load_executor(ti, session,
team_name="team_a")
+ result = self.job_runner._try_to_load_executor(ti,
session=session, team_name="team_a")
mock_team_resolve.assert_not_called() # We don't query for the
team if it is pre-resolved
assert result == mock_executors[1]
@@ -8980,12 +8980,12 @@ class TestSchedulerJob:
task = EmptyOperator(task_id="test_task", executor="LocalExecutor")
dr = dag_maker.create_dagrun()
- ti = dr.get_task_instance(task.task_id, session)
+ ti = dr.get_task_instance(task.task_id, session=session)
scheduler_job = Job()
self.job_runner = SchedulerJobRunner(job=scheduler_job)
- result = self.job_runner._try_to_load_executor(ti, session)
+ result = self.job_runner._try_to_load_executor(ti, session=session)
# Should match by classname (last component of module_path) and return
the global executor
assert result == mock_executors[0]
@@ -9019,8 +9019,8 @@ class TestSchedulerJob:
EmptyOperator(task_id="task_b")
dr2 = dag_maker.create_dagrun()
- ti1 = dr1.get_task_instance("task_a", session)
- ti2 = dr2.get_task_instance("task_b", session)
+ ti1 = dr1.get_task_instance("task_a", session=session)
+ ti2 = dr2.get_task_instance("task_b", session=session)
ti1.state = State.SCHEDULED
ti2.state = State.SCHEDULED
session.flush()
@@ -9067,8 +9067,8 @@ class TestSchedulerJob:
EmptyOperator(task_id="task_b")
dr2 = dag_maker.create_dagrun()
- ti1 = dr1.get_task_instance("task_a", session)
- ti2 = dr2.get_task_instance("task_b", session)
+ ti1 = dr1.get_task_instance("task_a", session=session)
+ ti2 = dr2.get_task_instance("task_b", session=session)
scheduler_job = Job()
self.job_runner = SchedulerJobRunner(job=scheduler_job)
@@ -9077,7 +9077,7 @@ class TestSchedulerJob:
assert_queries_count(1, session=session),
mock.patch.object(self.job_runner, "_get_workload_team_name") as
mock_single,
):
- executor_to_workloads =
self.job_runner._executor_to_workloads([ti1, ti2], session)
+ executor_to_workloads =
self.job_runner._executor_to_workloads([ti1, ti2], session=session)
mock_single.assert_not_called()
assert executor_to_workloads[mock_executors[0]] == [ti1]
@@ -9091,15 +9091,15 @@ class TestSchedulerJob:
task2 = EmptyOperator(task_id="test_task2",
executor="secondary_exec")
dr = dag_maker.create_dagrun()
- ti1 = dr.get_task_instance(task1.task_id, session)
- ti2 = dr.get_task_instance(task2.task_id, session)
+ ti1 = dr.get_task_instance(task1.task_id, session=session)
+ ti2 = dr.get_task_instance(task2.task_id, session=session)
scheduler_job = Job()
self.job_runner = SchedulerJobRunner(job=scheduler_job)
with mock.patch.object(self.job_runner, "_get_workload_team_name") as
mock_team_resolve:
- result1 = self.job_runner._try_to_load_executor(ti1, session)
- result2 = self.job_runner._try_to_load_executor(ti2, session)
+ result1 = self.job_runner._try_to_load_executor(ti1,
session=session)
+ result2 = self.job_runner._try_to_load_executor(ti2,
session=session)
# Should use legacy logic without calling team resolution
mock_team_resolve.assert_not_called()
diff --git a/airflow-core/tests/unit/models/test_cleartasks.py
b/airflow-core/tests/unit/models/test_cleartasks.py
index e82cb46f0f4..58eb1b37b12 100644
--- a/airflow-core/tests/unit/models/test_cleartasks.py
+++ b/airflow-core/tests/unit/models/test_cleartasks.py
@@ -903,7 +903,7 @@ class TestClearTasks:
assert sorted(new_tis) == ["2", "3"]
session.rollback()
- dr.refresh_from_db(session)
+ dr.refresh_from_db(session=session)
assert dr.created_dag_version_id == old_dag_version.id
assert len(dr.task_instances) == 2 # should be only the 2 earlier
tasks
diff --git a/airflow-core/tests/unit/models/test_dagrun.py
b/airflow-core/tests/unit/models/test_dagrun.py
index 7ac2514cbf7..e8a946a6aa1 100644
--- a/airflow-core/tests/unit/models/test_dagrun.py
+++ b/airflow-core/tests/unit/models/test_dagrun.py
@@ -787,7 +787,7 @@ class TestDagRun:
...
self.create_dag_run(dag, logical_date=timezone.datetime(2015, 1, 1),
session=session)
self.create_dag_run(dag, logical_date=timezone.datetime(2015, 1, 2),
session=session)
- dagruns = DagRun.get_latest_runs(session)
+ dagruns = DagRun.get_latest_runs(session=session)
session.close()
for dagrun in dagruns:
if dagrun.dag_id == "test_latest_runs_1":
@@ -1121,12 +1121,12 @@ class TestDagRun:
triggered_by=DagRunTriggeredByType.TEST,
session=session,
)
- ti = dag_run.get_task_instance(dag_task.task_id, session)
- ti.set_state(TaskInstanceState.SUCCESS, session)
+ ti = dag_run.get_task_instance(dag_task.task_id, session=session)
+ ti.set_state(TaskInstanceState.SUCCESS, session=session)
session.flush()
with
mock.patch("airflow._shared.observability.metrics.stats.timing") as stats_mock:
- dag_run.update_state(session)
+ dag_run.update_state(session=session)
metric_name = f"dagrun.{dag.dag_id}.first_task_scheduling_delay"
@@ -1209,13 +1209,13 @@ class TestDagRun:
session=session,
)
dag_run.queued_at = queued_at
- ti = dag_run.get_task_instance(dag_task.task_id, session)
- ti.set_state(TaskInstanceState.SUCCESS, session)
+ ti = dag_run.get_task_instance(dag_task.task_id, session=session)
+ ti.set_state(TaskInstanceState.SUCCESS, session=session)
ti.start_date = ti_start_date
session.flush()
with
mock.patch("airflow._shared.observability.metrics.stats.timing") as stats_mock:
- dag_run.update_state(session)
+ dag_run.update_state(session=session)
start_delay_call = call("dagrun.first_task_start_delay", mock.ANY,
tags=expected_stat_tags)
if expected:
@@ -3178,7 +3178,7 @@ def
test_clearing_task_and_moving_from_non_mapped_to_mapped(dag_maker, session):
session.commit()
for table in [TaskInstanceNote, TaskReschedule, XComModel]:
assert session.scalar(select(func.count()).select_from(table)) == 1
- dr1.task_instance_scheduling_decisions(session)
+ dr1.task_instance_scheduling_decisions(session=session)
for table in [TaskInstanceNote, TaskReschedule, XComModel]:
assert session.scalar(select(func.count()).select_from(table)) == 0
@@ -3408,7 +3408,7 @@ def test_tis_considered_for_state(dag_maker, session,
input, expected):
reduce(lambda x, y: x >> y, tasks)
dr = dag_maker.create_dagrun()
- tis = dr.task_instance_scheduling_decisions(session).tis
+ tis = dr.task_instance_scheduling_decisions(session=session).tis
tis_for_state = {x.task_id for x in dr._tis_for_dagrun_state(dag=dag,
tis=tis)}
assert tis_for_state == expected
diff --git a/airflow-core/tests/unit/models/test_dagwarning.py
b/airflow-core/tests/unit/models/test_dagwarning.py
index 167244afbc1..0c5a3d3f969 100644
--- a/airflow-core/tests/unit/models/test_dagwarning.py
+++ b/airflow-core/tests/unit/models/test_dagwarning.py
@@ -54,7 +54,7 @@ class TestDagWarning:
session.add_all(dag_warnings)
session.commit()
- DagWarning.purge_inactive_dag_warnings(session)
+ DagWarning.purge_inactive_dag_warnings(session=session)
remaining_dag_warnings = session.scalars(select(DagWarning)).all()
assert len(remaining_dag_warnings) == 1
@@ -70,7 +70,7 @@ class TestDagWarning:
self.session_mock.execute.side_effect = [OperationalError(None, None,
"database timeout"), None]
- DagWarning.purge_inactive_dag_warnings(self.session_mock)
+ DagWarning.purge_inactive_dag_warnings(session=self.session_mock)
# Assert that the delete method was called twice
assert delete_mock.call_count == 2
diff --git
a/providers/standard/src/airflow/providers/standard/sensors/external_task.py
b/providers/standard/src/airflow/providers/standard/sensors/external_task.py
index cfc4c0c3faf..8270386754f 100644
--- a/providers/standard/src/airflow/providers/standard/sensors/external_task.py
+++ b/providers/standard/src/airflow/providers/standard/sensors/external_task.py
@@ -511,7 +511,7 @@ class ExternalTaskSensor(BaseSensorOperator):
)
def _check_for_existence(self, session: Session) -> None:
- dag_to_wait = DagModel.get_current(self.external_dag_id, session)
+ dag_to_wait = DagModel.get_current(self.external_dag_id,
session=session)
if not dag_to_wait:
raise ExternalDagNotFoundError(f"The external DAG
{self.external_dag_id} does not exist.")
diff --git a/scripts/ci/prek/known_provide_session_positional.txt
b/scripts/ci/prek/known_provide_session_positional.txt
index c6351fe5bd7..164938f1b8e 100644
--- a/scripts/ci/prek/known_provide_session_positional.txt
+++ b/scripts/ci/prek/known_provide_session_positional.txt
@@ -1,8 +1,4 @@
airflow-core/src/airflow/models/connection.py::2
-airflow-core/src/airflow/models/dag.py::7
-airflow-core/src/airflow/models/dagcode.py::6
-airflow-core/src/airflow/models/dagrun.py::15
-airflow-core/src/airflow/models/dagwarning.py::1
airflow-core/src/airflow/models/deadline.py::1
airflow-core/src/airflow/models/deadline_alert.py::1
airflow-core/src/airflow/models/pool.py::11