Nataneljpwd commented on code in PR #53492:
URL: https://github.com/apache/airflow/pull/53492#discussion_r2237286013
##########
airflow-core/src/airflow/jobs/scheduler_job_runner.py:
##########
@@ -373,292 +375,165 @@ def _executable_task_instances_to_queued(self, max_tis:
int, session: Session) -
# If the pools are full, there is no point doing anything!
# If _somehow_ the pool is overfull, don't let the limit go negative -
it breaks SQL
pool_slots_free = sum(max(0, pool["open"]) for pool in pools.values())
+ starved_pools = {pool_name for pool_name, stats in pools.items() if
stats["open"] <= 0}
+ starved_tasks: set[tuple[str, str]] = set()
if pool_slots_free == 0:
self.log.debug("All pools are full!")
return []
- max_tis = min(max_tis, pool_slots_free)
-
- starved_pools = {pool_name for pool_name, stats in pools.items() if
stats["open"] <= 0}
-
- # dag_id to # of running tasks and (dag_id, task_id) to # of running
tasks.
- concurrency_map = ConcurrencyMap()
- concurrency_map.load(session=session)
+ priority_order = [-TI.priority_weight, DR.logical_date, TI.map_index]
- # Number of tasks that cannot be scheduled because of no open slot in
pool
- num_starving_tasks_total = 0
-
- # dag and task ids that can't be queued because of concurrency limits
- starved_dags: set[str] = set()
- starved_tasks: set[tuple[str, str]] = set()
- starved_tasks_task_dagrun_concurrency: set[tuple[str, str, str]] =
set()
-
- pool_num_starving_tasks: dict[str, int] = Counter()
+ query = (
+ select(TI)
+ .with_hint(TI, "USE INDEX (ti_state)", dialect_name="mysql")
+ .join(TI.dag_run)
+ .where(DR.state == DagRunState.RUNNING)
+ .join(TI.dag_model)
+ .where(~DM.is_paused)
+ .where(TI.state == TaskInstanceState.SCHEDULED)
+ .where(DM.bundle_name.is_not(None))
+ .options(selectinload(TI.dag_model))
+ )
- for loop_count in itertools.count(start=1):
- num_starved_pools = len(starved_pools)
- num_starved_dags = len(starved_dags)
- num_starved_tasks = len(starved_tasks)
- num_starved_tasks_task_dagrun_concurrency =
len(starved_tasks_task_dagrun_concurrency)
+ @dataclass
+ class Limit:
+ running_now_join: Subquery
+ max_units: Column
+ window: expression.ColumnElement
+
+ def running_tasks_group(*group_fields):
Review Comment:
```suggestion
        def running_tasks_group(*group_fields: Column) -> CTE:
```
(Note: with `*args`, the annotation applies to each argument, so it should be
`Column`, not `list[Column]`.)
What exactly does this method do — does it just group the running tasks by the
given fields? And what is `EXECUTION_STATES`? A short docstring answering those
questions would make this a little nicer to read.
##########
airflow-core/src/airflow/jobs/scheduler_job_runner.py:
##########
@@ -373,292 +375,165 @@ def _executable_task_instances_to_queued(self, max_tis:
int, session: Session) -
# If the pools are full, there is no point doing anything!
# If _somehow_ the pool is overfull, don't let the limit go negative -
it breaks SQL
pool_slots_free = sum(max(0, pool["open"]) for pool in pools.values())
+ starved_pools = {pool_name for pool_name, stats in pools.items() if
stats["open"] <= 0}
+ starved_tasks: set[tuple[str, str]] = set()
if pool_slots_free == 0:
self.log.debug("All pools are full!")
return []
- max_tis = min(max_tis, pool_slots_free)
-
- starved_pools = {pool_name for pool_name, stats in pools.items() if
stats["open"] <= 0}
-
- # dag_id to # of running tasks and (dag_id, task_id) to # of running
tasks.
- concurrency_map = ConcurrencyMap()
- concurrency_map.load(session=session)
+ priority_order = [-TI.priority_weight, DR.logical_date, TI.map_index]
- # Number of tasks that cannot be scheduled because of no open slot in
pool
- num_starving_tasks_total = 0
-
- # dag and task ids that can't be queued because of concurrency limits
- starved_dags: set[str] = set()
- starved_tasks: set[tuple[str, str]] = set()
- starved_tasks_task_dagrun_concurrency: set[tuple[str, str, str]] =
set()
-
- pool_num_starving_tasks: dict[str, int] = Counter()
+ query = (
+ select(TI)
+ .with_hint(TI, "USE INDEX (ti_state)", dialect_name="mysql")
+ .join(TI.dag_run)
+ .where(DR.state == DagRunState.RUNNING)
+ .join(TI.dag_model)
+ .where(~DM.is_paused)
+ .where(TI.state == TaskInstanceState.SCHEDULED)
+ .where(DM.bundle_name.is_not(None))
+ .options(selectinload(TI.dag_model))
+ )
- for loop_count in itertools.count(start=1):
- num_starved_pools = len(starved_pools)
- num_starved_dags = len(starved_dags)
- num_starved_tasks = len(starved_tasks)
- num_starved_tasks_task_dagrun_concurrency =
len(starved_tasks_task_dagrun_concurrency)
+ @dataclass
+ class Limit:
+ running_now_join: Subquery
+ max_units: Column
+ window: expression.ColumnElement
+
+ def running_tasks_group(*group_fields):
+ return (
+ select(TI, func.count("*").label("now_running"))
+ .where(TI.state.in_(EXECUTION_STATES))
+ .group_by(*group_fields)
+ .cte()
+ )
- query = (
+ def add_window_limit(query: Select, limit: Limit):
Review Comment:
```suggestion
def add_window_limit(query: Select, limit: Limit) -> Select:
```
##########
airflow-core/src/airflow/jobs/scheduler_job_runner.py:
##########
@@ -373,292 +375,165 @@ def _executable_task_instances_to_queued(self, max_tis:
int, session: Session) -
# If the pools are full, there is no point doing anything!
# If _somehow_ the pool is overfull, don't let the limit go negative -
it breaks SQL
pool_slots_free = sum(max(0, pool["open"]) for pool in pools.values())
+ starved_pools = {pool_name for pool_name, stats in pools.items() if
stats["open"] <= 0}
+ starved_tasks: set[tuple[str, str]] = set()
if pool_slots_free == 0:
self.log.debug("All pools are full!")
return []
- max_tis = min(max_tis, pool_slots_free)
-
- starved_pools = {pool_name for pool_name, stats in pools.items() if
stats["open"] <= 0}
-
- # dag_id to # of running tasks and (dag_id, task_id) to # of running
tasks.
- concurrency_map = ConcurrencyMap()
- concurrency_map.load(session=session)
+ priority_order = [-TI.priority_weight, DR.logical_date, TI.map_index]
- # Number of tasks that cannot be scheduled because of no open slot in
pool
- num_starving_tasks_total = 0
-
- # dag and task ids that can't be queued because of concurrency limits
- starved_dags: set[str] = set()
- starved_tasks: set[tuple[str, str]] = set()
- starved_tasks_task_dagrun_concurrency: set[tuple[str, str, str]] =
set()
-
- pool_num_starving_tasks: dict[str, int] = Counter()
+ query = (
+ select(TI)
+ .with_hint(TI, "USE INDEX (ti_state)", dialect_name="mysql")
+ .join(TI.dag_run)
+ .where(DR.state == DagRunState.RUNNING)
+ .join(TI.dag_model)
+ .where(~DM.is_paused)
+ .where(TI.state == TaskInstanceState.SCHEDULED)
+ .where(DM.bundle_name.is_not(None))
+ .options(selectinload(TI.dag_model))
+ )
- for loop_count in itertools.count(start=1):
- num_starved_pools = len(starved_pools)
- num_starved_dags = len(starved_dags)
- num_starved_tasks = len(starved_tasks)
- num_starved_tasks_task_dagrun_concurrency =
len(starved_tasks_task_dagrun_concurrency)
+ @dataclass
+ class Limit:
+ running_now_join: Subquery
+ max_units: Column
+ window: expression.ColumnElement
Review Comment:
Why does a `Limit` hold a column? And why does it hold a subquery as well?
Maybe the name should be something more along the lines of `WindowDescriptor`.
A limit should not hold columns or column elements. It would also help if you
added a docstring to this dataclass and gave the fields more general names, so
that it would be simpler to understand.
##########
airflow-core/src/airflow/jobs/scheduler_job_runner.py:
##########
@@ -373,292 +375,165 @@ def _executable_task_instances_to_queued(self, max_tis:
int, session: Session) -
# If the pools are full, there is no point doing anything!
# If _somehow_ the pool is overfull, don't let the limit go negative -
it breaks SQL
pool_slots_free = sum(max(0, pool["open"]) for pool in pools.values())
+ starved_pools = {pool_name for pool_name, stats in pools.items() if
stats["open"] <= 0}
+ starved_tasks: set[tuple[str, str]] = set()
if pool_slots_free == 0:
self.log.debug("All pools are full!")
return []
- max_tis = min(max_tis, pool_slots_free)
-
- starved_pools = {pool_name for pool_name, stats in pools.items() if
stats["open"] <= 0}
-
- # dag_id to # of running tasks and (dag_id, task_id) to # of running
tasks.
- concurrency_map = ConcurrencyMap()
- concurrency_map.load(session=session)
+ priority_order = [-TI.priority_weight, DR.logical_date, TI.map_index]
- # Number of tasks that cannot be scheduled because of no open slot in
pool
- num_starving_tasks_total = 0
-
- # dag and task ids that can't be queued because of concurrency limits
- starved_dags: set[str] = set()
- starved_tasks: set[tuple[str, str]] = set()
- starved_tasks_task_dagrun_concurrency: set[tuple[str, str, str]] =
set()
-
- pool_num_starving_tasks: dict[str, int] = Counter()
+ query = (
+ select(TI)
+ .with_hint(TI, "USE INDEX (ti_state)", dialect_name="mysql")
+ .join(TI.dag_run)
+ .where(DR.state == DagRunState.RUNNING)
+ .join(TI.dag_model)
+ .where(~DM.is_paused)
+ .where(TI.state == TaskInstanceState.SCHEDULED)
+ .where(DM.bundle_name.is_not(None))
+ .options(selectinload(TI.dag_model))
+ )
- for loop_count in itertools.count(start=1):
- num_starved_pools = len(starved_pools)
- num_starved_dags = len(starved_dags)
- num_starved_tasks = len(starved_tasks)
- num_starved_tasks_task_dagrun_concurrency =
len(starved_tasks_task_dagrun_concurrency)
+ @dataclass
+ class Limit:
+ running_now_join: Subquery
+ max_units: Column
+ window: expression.ColumnElement
+
+ def running_tasks_group(*group_fields):
+ return (
+ select(TI, func.count("*").label("now_running"))
Review Comment:
```suggestion
            select(TI, func.count(TI.id).label("now_running"))
```
Counting a concrete column instead of `"*"` enables projection in SQL and will
make the query slightly faster. (Note: `func.count("id")` would count the
literal string `'id'`, not the column — use the column attribute `TI.id`.)
##########
airflow-core/src/airflow/jobs/scheduler_job_runner.py:
##########
@@ -373,292 +375,165 @@ def _executable_task_instances_to_queued(self, max_tis:
int, session: Session) -
# If the pools are full, there is no point doing anything!
# If _somehow_ the pool is overfull, don't let the limit go negative -
it breaks SQL
pool_slots_free = sum(max(0, pool["open"]) for pool in pools.values())
+ starved_pools = {pool_name for pool_name, stats in pools.items() if
stats["open"] <= 0}
+ starved_tasks: set[tuple[str, str]] = set()
if pool_slots_free == 0:
self.log.debug("All pools are full!")
return []
- max_tis = min(max_tis, pool_slots_free)
-
- starved_pools = {pool_name for pool_name, stats in pools.items() if
stats["open"] <= 0}
-
- # dag_id to # of running tasks and (dag_id, task_id) to # of running
tasks.
- concurrency_map = ConcurrencyMap()
- concurrency_map.load(session=session)
+ priority_order = [-TI.priority_weight, DR.logical_date, TI.map_index]
- # Number of tasks that cannot be scheduled because of no open slot in
pool
- num_starving_tasks_total = 0
-
- # dag and task ids that can't be queued because of concurrency limits
- starved_dags: set[str] = set()
- starved_tasks: set[tuple[str, str]] = set()
- starved_tasks_task_dagrun_concurrency: set[tuple[str, str, str]] =
set()
-
- pool_num_starving_tasks: dict[str, int] = Counter()
+ query = (
+ select(TI)
+ .with_hint(TI, "USE INDEX (ti_state)", dialect_name="mysql")
+ .join(TI.dag_run)
+ .where(DR.state == DagRunState.RUNNING)
+ .join(TI.dag_model)
+ .where(~DM.is_paused)
+ .where(TI.state == TaskInstanceState.SCHEDULED)
+ .where(DM.bundle_name.is_not(None))
+ .options(selectinload(TI.dag_model))
+ )
- for loop_count in itertools.count(start=1):
- num_starved_pools = len(starved_pools)
- num_starved_dags = len(starved_dags)
- num_starved_tasks = len(starved_tasks)
- num_starved_tasks_task_dagrun_concurrency =
len(starved_tasks_task_dagrun_concurrency)
+ @dataclass
+ class Limit:
+ running_now_join: Subquery
+ max_units: Column
+ window: expression.ColumnElement
Review Comment:
```suggestion
class LimitWindowDescriptor:
join_on: Subquery
choose_up_to: Column
window_by: expression.ColumnElement
```
##########
airflow-core/src/airflow/jobs/scheduler_job_runner.py:
##########
@@ -373,292 +375,165 @@ def _executable_task_instances_to_queued(self, max_tis:
int, session: Session) -
# If the pools are full, there is no point doing anything!
# If _somehow_ the pool is overfull, don't let the limit go negative -
it breaks SQL
pool_slots_free = sum(max(0, pool["open"]) for pool in pools.values())
+ starved_pools = {pool_name for pool_name, stats in pools.items() if
stats["open"] <= 0}
+ starved_tasks: set[tuple[str, str]] = set()
if pool_slots_free == 0:
self.log.debug("All pools are full!")
return []
- max_tis = min(max_tis, pool_slots_free)
-
- starved_pools = {pool_name for pool_name, stats in pools.items() if
stats["open"] <= 0}
-
- # dag_id to # of running tasks and (dag_id, task_id) to # of running
tasks.
- concurrency_map = ConcurrencyMap()
- concurrency_map.load(session=session)
+ priority_order = [-TI.priority_weight, DR.logical_date, TI.map_index]
- # Number of tasks that cannot be scheduled because of no open slot in
pool
- num_starving_tasks_total = 0
-
- # dag and task ids that can't be queued because of concurrency limits
- starved_dags: set[str] = set()
- starved_tasks: set[tuple[str, str]] = set()
- starved_tasks_task_dagrun_concurrency: set[tuple[str, str, str]] =
set()
-
- pool_num_starving_tasks: dict[str, int] = Counter()
+ query = (
+ select(TI)
+ .with_hint(TI, "USE INDEX (ti_state)", dialect_name="mysql")
+ .join(TI.dag_run)
+ .where(DR.state == DagRunState.RUNNING)
+ .join(TI.dag_model)
+ .where(~DM.is_paused)
+ .where(TI.state == TaskInstanceState.SCHEDULED)
+ .where(DM.bundle_name.is_not(None))
+ .options(selectinload(TI.dag_model))
+ )
- for loop_count in itertools.count(start=1):
- num_starved_pools = len(starved_pools)
- num_starved_dags = len(starved_dags)
- num_starved_tasks = len(starved_tasks)
- num_starved_tasks_task_dagrun_concurrency =
len(starved_tasks_task_dagrun_concurrency)
+ @dataclass
+ class Limit:
+ running_now_join: Subquery
+ max_units: Column
+ window: expression.ColumnElement
+
+ def running_tasks_group(*group_fields):
+ return (
+ select(TI, func.count("*").label("now_running"))
+ .where(TI.state.in_(EXECUTION_STATES))
+ .group_by(*group_fields)
+ .cte()
+ )
- query = (
+ def add_window_limit(query: Select, limit: Limit):
+ inner_query = (
+ query.add_columns(limit.window)
+ .join(limit.running_now_join, TI.id ==
limit.running_now_join.c.id)
+ .order_by(*priority_order)
+ .subquery()
+ )
+ return (
select(TI)
- .with_hint(TI, "USE INDEX (ti_state)", dialect_name="mysql")
- .join(TI.dag_run)
- .where(DR.state == DagRunState.RUNNING)
- .join(TI.dag_model)
- .where(~DM.is_paused)
- .where(TI.state == TaskInstanceState.SCHEDULED)
- .where(DM.bundle_name.is_not(None))
- .options(selectinload(TI.dag_model))
- .order_by(-TI.priority_weight, DR.logical_date, TI.map_index)
+ .join(inner_query, TI.id == inner_query.c.id)
+ .where(
+ getattr(inner_query.c, limit.window.name) +
limit.running_now_join.c.now_running
+ < limit.max_units
+ )
)
- if starved_pools:
- query = query.where(TI.pool.not_in(starved_pools))
+ running_total_tis_per_dagrun = running_tasks_group(TI.dag_id,
TI.run_id)
+ running_tis_per_dag = running_tasks_group(TI.dag_id, TI.task_id)
+ running_total_tis_per_task_run = running_tasks_group(TI.dag_id,
TI.run_id, TI.task_id)
+ running_tis_per_pool = running_tasks_group(TI.pool)
- if starved_dags:
- query = query.where(TI.dag_id.not_in(starved_dags))
+ total_tis_per_dagrun_count = (
+ func.row_number()
+ .over(partition_by=(TI.dag_id, TI.run_id), order_by=priority_order)
+ .label("total_tis_per_dagrun_count")
+ )
+ tis_per_dag_count = (
+ func.row_number()
+ .over(partition_by=(TI.dag_id, TI.task_id),
order_by=priority_order)
+ .label("tis_per_dag_count")
+ )
+ mapped_tis_per_task_run_count = (
+ func.row_number()
+ .over(partition_by=(TI.dag_id, TI.run_id, TI.task_id),
order_by=priority_order)
+ .label("mapped_tis_per_dagrun_count")
+ )
+ pool_slots_taken = (
+ func.sum(TI.pool_slots)
+ .over(partition_by=(TI.pool), order_by=priority_order)
+ .label("pool_slots_taken")
+ )
- if starved_tasks:
- query = query.where(tuple_(TI.dag_id,
TI.task_id).not_in(starved_tasks))
+ limits = [
+ Limit(running_total_tis_per_dagrun, DagModel.max_active_tasks,
total_tis_per_dagrun_count),
+ Limit(
+ running_tis_per_dag, TaskInstance.max_active_tis_per_dag,
tis_per_dag_count
+ ), # TODO: Add to DB model: TaskInstance.max_active_tis_per_dag,
DUMMY: DagModel.max_active_tasks
+ Limit(
+ running_total_tis_per_task_run, TI.max_tries,
mapped_tis_per_task_run_count
+ ), # TODO: Add to DB model: TI.max_active_tis_per_dagrun, DUMMY:
TI.max_tries
+ Limit(running_tis_per_pool, Pool.slots, pool_slots_taken),
+ ]
- if starved_tasks_task_dagrun_concurrency:
- query = query.where(
- tuple_(TI.dag_id, TI.run_id,
TI.task_id).not_in(starved_tasks_task_dagrun_concurrency)
- )
+ for limit in limits:
+ query = add_window_limit(query, limit)
- query = query.limit(max_tis)
+ query = query.limit(max_tis)
- timer = Stats.timer("scheduler.critical_section_query_duration")
- timer.start()
+ timer = Stats.timer("scheduler.critical_section_query_duration")
+ timer.start()
- try:
- query = with_row_locks(query, of=TI, session=session,
skip_locked=True)
- task_instances_to_examine: list[TI] =
session.scalars(query).all()
+ try:
+ query = with_row_locks(query, of=TI, session=session,
skip_locked=True)
+ task_instances_to_examine: list[TI] = session.scalars(query).all()
- timer.stop(send=True)
- except OperationalError as e:
- timer.stop(send=False)
- raise e
+ timer.stop(send=True)
+ except OperationalError as e:
+ timer.stop(send=False)
+ raise e
- # TODO[HA]: This was wrong before anyway, as it only looked at a
sub-set of dags, not everything.
- # Stats.gauge('scheduler.tasks.pending',
len(task_instances_to_examine))
+ # TODO[HA]: This was wrong before anyway, as it only looked at a
sub-set of dags, not everything.
+ # Stats.gauge('scheduler.tasks.pending',
len(task_instances_to_examine))
- if not task_instances_to_examine:
- self.log.debug("No tasks to consider for execution.")
- break
+ if not task_instances_to_examine:
+ self.log.debug("No tasks to consider for execution.")
+ return []
- # Put one task instance on each line
- task_instance_str = "\n".join(f"\t{x!r}" for x in
task_instances_to_examine)
- self.log.info("%s tasks up for execution:\n%s",
len(task_instances_to_examine), task_instance_str)
+ # Put one task instance on each line
+ task_instance_str = "\n".join(f"\t{x!r}" for x in
task_instances_to_examine)
+ self.log.info("%s tasks up for execution:\n%s",
len(task_instances_to_examine), task_instance_str)
- executor_slots_available: dict[ExecutorName, int] = {}
- # First get a mapping of executor names to slots they have
available
- for executor in self.job.executors:
+ executor_slots_available: dict[ExecutorName, int] = {}
+ # First get a mapping of executor names to slots they have available
+ for executor in self.job.executors:
+ if TYPE_CHECKING:
+ # All executors should have a name if they are initted from
the executor_loader.
+ # But we need to check for None to make mypy happy.
+ assert executor.name
+ executor_slots_available[executor.name] = executor.slots_available
+
+ for task_instance in task_instances_to_examine:
+ pool_name = task_instance.pool
+
+ pool_stats = pools.get(pool_name)
+ if not pool_stats:
+ self.log.warning("Tasks using non-existent pool '%s' will not
be scheduled", pool_name)
+ starved_pools.add(pool_name)
+ continue
+
+ if executor_obj :=
self._try_to_load_executor(task_instance.executor):
if TYPE_CHECKING:
# All executors should have a name if they are initted
from the executor_loader.
# But we need to check for None to make mypy happy.
- assert executor.name
- executor_slots_available[executor.name] =
executor.slots_available
-
- for task_instance in task_instances_to_examine:
- pool_name = task_instance.pool
-
- pool_stats = pools.get(pool_name)
- if not pool_stats:
- self.log.warning("Tasks using non-existent pool '%s' will
not be scheduled", pool_name)
- starved_pools.add(pool_name)
- continue
-
- # Make sure to emit metrics if pool has no starving tasks
- pool_num_starving_tasks.setdefault(pool_name, 0)
-
- pool_total = pool_stats["total"]
- open_slots = pool_stats["open"]
-
- if open_slots <= 0:
- self.log.info(
- "Not scheduling since there are %s open slots in pool
%s", open_slots, pool_name
- )
- # Can't schedule any more since there are no more open
slots.
- pool_num_starving_tasks[pool_name] += 1
- num_starving_tasks_total += 1
- starved_pools.add(pool_name)
- continue
-
- if task_instance.pool_slots > pool_total:
- self.log.warning(
- "Not executing %s. Requested pool slots (%s) are
greater than "
- "total pool slots: '%s' for pool: %s.",
- task_instance,
- task_instance.pool_slots,
- pool_total,
- pool_name,
- )
-
- pool_num_starving_tasks[pool_name] += 1
- num_starving_tasks_total += 1
- starved_tasks.add((task_instance.dag_id,
task_instance.task_id))
- continue
-
- if task_instance.pool_slots > open_slots:
- self.log.info(
- "Not executing %s since it requires %s slots "
- "but there are %s open slots in the pool %s.",
- task_instance,
- task_instance.pool_slots,
- open_slots,
- pool_name,
+ assert executor_obj.name
+ if executor_slots_available[executor_obj.name] <= 0:
+ self.log.debug(
Review Comment:
I think this should be a warning rather than a debug message: if we start
getting a lot of these, we want to know about it and be able to adjust,
rather than only seeing it at debug level. Also note that `log.warn` is
deprecated in favor of `log.warning`, and the message's `%s` placeholders
need their arguments passed:
```suggestion
                if executor_slots_available[executor_obj.name] <= 0:
                    self.log.warning(
                        "Not scheduling %s since its executor %s does not currently have any more "
                        "available slots",
                        task_instance,
                        executor_obj.name,
                    )
```
##########
airflow-core/src/airflow/jobs/scheduler_job_runner.py:
##########
@@ -1023,15 +898,15 @@ def _execute(self) -> int | None:
def _update_dag_run_state_for_paused_dags(self, session: Session =
NEW_SESSION) -> None:
try:
paused_runs = session.scalars(
- select(DagRun)
+ select(DR)
Review Comment:
why not just keep it `DagRun` :)
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]