This is an automated email from the ASF dual-hosted git repository.
uranusjr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 9f250d15ef2 remove N+1 db queries for team names (#61471)
9f250d15ef2 is described below
commit 9f250d15ef2a857c789fa3a58736a03afd765d29
Author: Steve Ahn <[email protected]>
AuthorDate: Mon Feb 9 21:21:49 2026 -0800
remove N+1 db queries for team names (#61471)
---
.../src/airflow/jobs/scheduler_job_runner.py | 42 ++++++++++++++++++--
airflow-core/tests/unit/jobs/test_scheduler_job.py | 45 ++++++++++++++++++++++
2 files changed, 83 insertions(+), 4 deletions(-)
diff --git a/airflow-core/src/airflow/jobs/scheduler_job_runner.py b/airflow-core/src/airflow/jobs/scheduler_job_runner.py
index 96b865f75bb..765bb9e5db1 100644
--- a/airflow-core/src/airflow/jobs/scheduler_job_runner.py
+++ b/airflow-core/src/airflow/jobs/scheduler_job_runner.py
@@ -2881,6 +2881,12 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
def _purge_task_instances_without_heartbeats(
self, task_instances_without_heartbeats: list[TI], *, session: Session
) -> None:
+ if conf.getboolean("core", "multi_team"):
+        unique_dag_ids = {ti.dag_id for ti in task_instances_without_heartbeats}
+        dag_id_to_team_name = self._get_team_names_for_dag_ids(unique_dag_ids, session)
+ else:
+ dag_id_to_team_name = {}
+
for ti in task_instances_without_heartbeats:
task_instance_heartbeat_timeout_message_details = (
self._generate_task_instance_heartbeat_timeout_message_details(ti)
@@ -2925,7 +2931,10 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
request,
)
self.job.executor.send_callback(request)
- if (executor := self._try_to_load_executor(ti, session)) is None:
+            executor = self._try_to_load_executor(
+                ti, session, team_name=dag_id_to_team_name.get(ti.dag_id, NOTSET)
+            )
+ if executor is None:
self.log.warning(
"Cannot clean up task instance without heartbeat %r with non-existent executor %s",
ti,
@@ -3099,12 +3108,37 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
existing_warned_dag_ids.add(warning.dag_id)
def _executor_to_tis(
- self, tis: Iterable[TaskInstance], session
+ self,
+ tis: Iterable[TaskInstance],
+ session,
+ dag_id_to_team_name: dict[str, str | None] | None = None,
) -> dict[BaseExecutor, list[TaskInstance]]:
"""Organize TIs into lists per their respective executor."""
+ tis_iter: Iterable[TaskInstance]
+ if conf.getboolean("core", "multi_team"):
+ if dag_id_to_team_name is None:
+ if isinstance(tis, list):
+ tis_list = tis
+ else:
+ tis_list = list(tis)
+ if tis_list:
+ dag_id_to_team_name = self._get_team_names_for_dag_ids(
+ {ti.dag_id for ti in tis_list}, session
+ )
+ else:
+ dag_id_to_team_name = {}
+ tis_iter = tis_list
+ else:
+ tis_iter = tis
+ else:
+ dag_id_to_team_name = {}
+ tis_iter = tis
+
_executor_to_tis: defaultdict[BaseExecutor, list[TaskInstance]] = defaultdict(list)
- for ti in tis:
- if executor_obj := self._try_to_load_executor(ti, session):
+ for ti in tis_iter:
+            if executor_obj := self._try_to_load_executor(
+                ti, session, team_name=dag_id_to_team_name.get(ti.dag_id, NOTSET)
+            ):
_executor_to_tis[executor_obj].append(ti)
return _executor_to_tis
diff --git a/airflow-core/tests/unit/jobs/test_scheduler_job.py b/airflow-core/tests/unit/jobs/test_scheduler_job.py
index 069d72b5eca..0130a04352f 100644
--- a/airflow-core/tests/unit/jobs/test_scheduler_job.py
+++ b/airflow-core/tests/unit/jobs/test_scheduler_job.py
@@ -8233,6 +8233,51 @@ class TestSchedulerJob:
mock_batch.assert_called_once_with({"dag_a", "dag_b"}, session)
assert len(res) == 2
+ @conf_vars({("core", "multi_team"): "true"})
+    def test_multi_team_executor_to_tis_batch_optimization(self, dag_maker, mock_executors, session):
+        """Test that executor mapping batches team resolution for task instances."""
+ clear_db_teams()
+ clear_db_dag_bundles()
+
+ team1 = Team(name="team_a")
+ team2 = Team(name="team_b")
+ session.add_all([team1, team2])
+ session.flush()
+
+ bundle1 = DagBundleModel(name="bundle_a")
+ bundle2 = DagBundleModel(name="bundle_b")
+ bundle1.teams.append(team1)
+ bundle2.teams.append(team2)
+ session.add_all([bundle1, bundle2])
+ session.flush()
+
+ mock_executors[0].team_name = "team_a"
+ mock_executors[1].team_name = "team_b"
+
+        with dag_maker(dag_id="dag_a", bundle_name="bundle_a", session=session):
+ EmptyOperator(task_id="task_a")
+ dr1 = dag_maker.create_dagrun()
+
+        with dag_maker(dag_id="dag_b", bundle_name="bundle_b", session=session):
+ EmptyOperator(task_id="task_b")
+ dr2 = dag_maker.create_dagrun()
+
+ ti1 = dr1.get_task_instance("task_a", session)
+ ti2 = dr2.get_task_instance("task_b", session)
+
+ scheduler_job = Job()
+ self.job_runner = SchedulerJobRunner(job=scheduler_job)
+
+ with (
+ assert_queries_count(1, session=session),
+            mock.patch.object(self.job_runner, "_get_task_team_name") as mock_single,
+ ):
+            executor_to_tis = self.job_runner._executor_to_tis([ti1, ti2], session)
+
+ mock_single.assert_not_called()
+ assert executor_to_tis[mock_executors[0]] == [ti1]
+ assert executor_to_tis[mock_executors[1]] == [ti2]
+
@conf_vars({("core", "multi_team"): "false"})
    def test_multi_team_config_disabled_uses_legacy_behavior(self, dag_maker, mock_executors, session):
        """Test that when multi_team config is disabled, legacy behavior is preserved."""