Asquator commented on code in PR #53492:
URL: https://github.com/apache/airflow/pull/53492#discussion_r2238938830
##########
airflow-core/src/airflow/jobs/scheduler_job_runner.py:
##########
@@ -373,292 +375,165 @@ def _executable_task_instances_to_queued(self, max_tis:
int, session: Session) -
# If the pools are full, there is no point doing anything!
# If _somehow_ the pool is overfull, don't let the limit go negative -
it breaks SQL
pool_slots_free = sum(max(0, pool["open"]) for pool in pools.values())
+ starved_pools = {pool_name for pool_name, stats in pools.items() if
stats["open"] <= 0}
+ starved_tasks: set[tuple[str, str]] = set()
if pool_slots_free == 0:
self.log.debug("All pools are full!")
return []
- max_tis = min(max_tis, pool_slots_free)
-
- starved_pools = {pool_name for pool_name, stats in pools.items() if
stats["open"] <= 0}
-
- # dag_id to # of running tasks and (dag_id, task_id) to # of running
tasks.
- concurrency_map = ConcurrencyMap()
- concurrency_map.load(session=session)
+ priority_order = [-TI.priority_weight, DR.logical_date, TI.map_index]
- # Number of tasks that cannot be scheduled because of no open slot in
pool
- num_starving_tasks_total = 0
-
- # dag and task ids that can't be queued because of concurrency limits
- starved_dags: set[str] = set()
- starved_tasks: set[tuple[str, str]] = set()
- starved_tasks_task_dagrun_concurrency: set[tuple[str, str, str]] =
set()
-
- pool_num_starving_tasks: dict[str, int] = Counter()
+ query = (
+ select(TI)
+ .with_hint(TI, "USE INDEX (ti_state)", dialect_name="mysql")
+ .join(TI.dag_run)
+ .where(DR.state == DagRunState.RUNNING)
+ .join(TI.dag_model)
+ .where(~DM.is_paused)
+ .where(TI.state == TaskInstanceState.SCHEDULED)
+ .where(DM.bundle_name.is_not(None))
+ .options(selectinload(TI.dag_model))
+ )
- for loop_count in itertools.count(start=1):
- num_starved_pools = len(starved_pools)
- num_starved_dags = len(starved_dags)
- num_starved_tasks = len(starved_tasks)
- num_starved_tasks_task_dagrun_concurrency =
len(starved_tasks_task_dagrun_concurrency)
+ @dataclass
+ class Limit:
+ running_now_join: Subquery
+ max_units: Column
+ window: expression.ColumnElement
+
+ def running_tasks_group(*group_fields):
+ return (
+ select(TI, func.count("*").label("now_running"))
+ .where(TI.state.in_(EXECUTION_STATES))
+ .group_by(*group_fields)
+ .cte()
+ )
- query = (
+ def add_window_limit(query: Select, limit: Limit):
+ inner_query = (
+ query.add_columns(limit.window)
+ .join(limit.running_now_join, TI.id ==
limit.running_now_join.c.id)
+ .order_by(*priority_order)
+ .subquery()
+ )
+ return (
select(TI)
- .with_hint(TI, "USE INDEX (ti_state)", dialect_name="mysql")
- .join(TI.dag_run)
- .where(DR.state == DagRunState.RUNNING)
- .join(TI.dag_model)
- .where(~DM.is_paused)
- .where(TI.state == TaskInstanceState.SCHEDULED)
- .where(DM.bundle_name.is_not(None))
- .options(selectinload(TI.dag_model))
- .order_by(-TI.priority_weight, DR.logical_date, TI.map_index)
+ .join(inner_query, TI.id == inner_query.c.id)
+ .where(
+ getattr(inner_query.c, limit.window.name) +
limit.running_now_join.c.now_running
+ < limit.max_units
+ )
)
- if starved_pools:
- query = query.where(TI.pool.not_in(starved_pools))
+ running_total_tis_per_dagrun = running_tasks_group(TI.dag_id,
TI.run_id)
+ running_tis_per_dag = running_tasks_group(TI.dag_id, TI.task_id)
+ running_total_tis_per_task_run = running_tasks_group(TI.dag_id,
TI.run_id, TI.task_id)
+ running_tis_per_pool = running_tasks_group(TI.pool)
- if starved_dags:
- query = query.where(TI.dag_id.not_in(starved_dags))
+ total_tis_per_dagrun_count = (
+ func.row_number()
+ .over(partition_by=(TI.dag_id, TI.run_id), order_by=priority_order)
+ .label("total_tis_per_dagrun_count")
+ )
+ tis_per_dag_count = (
+ func.row_number()
+ .over(partition_by=(TI.dag_id, TI.task_id),
order_by=priority_order)
+ .label("tis_per_dag_count")
+ )
+ mapped_tis_per_task_run_count = (
+ func.row_number()
+ .over(partition_by=(TI.dag_id, TI.run_id, TI.task_id),
order_by=priority_order)
+ .label("mapped_tis_per_dagrun_count")
+ )
+ pool_slots_taken = (
+ func.sum(TI.pool_slots)
+ .over(partition_by=(TI.pool), order_by=priority_order)
+ .label("pool_slots_taken")
+ )
- if starved_tasks:
- query = query.where(tuple_(TI.dag_id,
TI.task_id).not_in(starved_tasks))
+ limits = [
+ Limit(running_total_tis_per_dagrun, DagModel.max_active_tasks,
total_tis_per_dagrun_count),
+ Limit(
+ running_tis_per_dag, TaskInstance.max_active_tis_per_dag,
tis_per_dag_count
+ ), # TODO: Add to DB model: TaskInstance.max_active_tis_per_dag,
DUMMY: DagModel.max_active_tasks
+ Limit(
+ running_total_tis_per_task_run, TI.max_tries,
mapped_tis_per_task_run_count
+ ), # TODO: Add to DB model: TI.max_active_tis_per_dagrun, DUMMY:
TI.max_tries
+ Limit(running_tis_per_pool, Pool.slots, pool_slots_taken),
+ ]
- if starved_tasks_task_dagrun_concurrency:
- query = query.where(
- tuple_(TI.dag_id, TI.run_id,
TI.task_id).not_in(starved_tasks_task_dagrun_concurrency)
- )
+ for limit in limits:
+ query = add_window_limit(query, limit)
- query = query.limit(max_tis)
+ query = query.limit(max_tis)
- timer = Stats.timer("scheduler.critical_section_query_duration")
- timer.start()
+ timer = Stats.timer("scheduler.critical_section_query_duration")
+ timer.start()
- try:
- query = with_row_locks(query, of=TI, session=session,
skip_locked=True)
- task_instances_to_examine: list[TI] =
session.scalars(query).all()
+ try:
+ query = with_row_locks(query, of=TI, session=session,
skip_locked=True)
+ task_instances_to_examine: list[TI] = session.scalars(query).all()
- timer.stop(send=True)
- except OperationalError as e:
- timer.stop(send=False)
- raise e
+ timer.stop(send=True)
+ except OperationalError as e:
+ timer.stop(send=False)
+ raise e
- # TODO[HA]: This was wrong before anyway, as it only looked at a
sub-set of dags, not everything.
- # Stats.gauge('scheduler.tasks.pending',
len(task_instances_to_examine))
+ # TODO[HA]: This was wrong before anyway, as it only looked at a
sub-set of dags, not everything.
+ # Stats.gauge('scheduler.tasks.pending',
len(task_instances_to_examine))
- if not task_instances_to_examine:
- self.log.debug("No tasks to consider for execution.")
- break
+ if not task_instances_to_examine:
+ self.log.debug("No tasks to consider for execution.")
+ return []
- # Put one task instance on each line
- task_instance_str = "\n".join(f"\t{x!r}" for x in
task_instances_to_examine)
- self.log.info("%s tasks up for execution:\n%s",
len(task_instances_to_examine), task_instance_str)
+ # Put one task instance on each line
+ task_instance_str = "\n".join(f"\t{x!r}" for x in
task_instances_to_examine)
+ self.log.info("%s tasks up for execution:\n%s",
len(task_instances_to_examine), task_instance_str)
- executor_slots_available: dict[ExecutorName, int] = {}
- # First get a mapping of executor names to slots they have
available
- for executor in self.job.executors:
+ executor_slots_available: dict[ExecutorName, int] = {}
+ # First get a mapping of executor names to slots they have available
+ for executor in self.job.executors:
+ if TYPE_CHECKING:
+ # All executors should have a name if they are initted from
the executor_loader.
+ # But we need to check for None to make mypy happy.
+ assert executor.name
+ executor_slots_available[executor.name] = executor.slots_available
+
+ for task_instance in task_instances_to_examine:
+ pool_name = task_instance.pool
+
+ pool_stats = pools.get(pool_name)
+ if not pool_stats:
+ self.log.warning("Tasks using non-existent pool '%s' will not
be scheduled", pool_name)
+ starved_pools.add(pool_name)
+ continue
+
+ if executor_obj :=
self._try_to_load_executor(task_instance.executor):
if TYPE_CHECKING:
# All executors should have a name if they are initted
from the executor_loader.
# But we need to check for None to make mypy happy.
- assert executor.name
- executor_slots_available[executor.name] =
executor.slots_available
-
- for task_instance in task_instances_to_examine:
- pool_name = task_instance.pool
-
- pool_stats = pools.get(pool_name)
- if not pool_stats:
- self.log.warning("Tasks using non-existent pool '%s' will
not be scheduled", pool_name)
- starved_pools.add(pool_name)
- continue
-
- # Make sure to emit metrics if pool has no starving tasks
- pool_num_starving_tasks.setdefault(pool_name, 0)
-
- pool_total = pool_stats["total"]
- open_slots = pool_stats["open"]
-
- if open_slots <= 0:
- self.log.info(
- "Not scheduling since there are %s open slots in pool
%s", open_slots, pool_name
- )
- # Can't schedule any more since there are no more open
slots.
- pool_num_starving_tasks[pool_name] += 1
- num_starving_tasks_total += 1
- starved_pools.add(pool_name)
- continue
-
- if task_instance.pool_slots > pool_total:
- self.log.warning(
- "Not executing %s. Requested pool slots (%s) are
greater than "
- "total pool slots: '%s' for pool: %s.",
- task_instance,
- task_instance.pool_slots,
- pool_total,
- pool_name,
- )
-
- pool_num_starving_tasks[pool_name] += 1
- num_starving_tasks_total += 1
- starved_tasks.add((task_instance.dag_id,
task_instance.task_id))
- continue
-
- if task_instance.pool_slots > open_slots:
- self.log.info(
- "Not executing %s since it requires %s slots "
- "but there are %s open slots in the pool %s.",
- task_instance,
- task_instance.pool_slots,
- open_slots,
- pool_name,
+ assert executor_obj.name
+ if executor_slots_available[executor_obj.name] <= 0:
+ self.log.debug(
Review Comment:
Not in the scope of this change
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]