This is an automated email from the ASF dual-hosted git repository.
ash pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 8164c32858 Priority order tasks even when using pools (#22483)
8164c32858 is described below
commit 8164c328588b21dbeb7194d88f2d9c694b8f3b2e
Author: Tanel Kiis <[email protected]>
AuthorDate: Tue Apr 12 16:34:25 2022 +0300
Priority order tasks even when using pools (#22483)

When picking tasks to queue, the scheduler_job groups candidate task
instances by their pools and picks tasks for queueing within each "pool
group". This way tasks with lower priority could be queued before tasks
with higher priority. This is demonstrated in the new unit test
`test_find_executable_task_instances_order_priority_with_pools` - before
this change `dummy3` and `dummy1` are queued instead of `dummy3` and
`dummy2`.
Co-authored-by: Tanel Kiis <[email protected]>
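
To make the failure mode concrete, the following minimal sketch reproduces the
test scenario with plain tuples. It is illustrative only, not scheduler code:
`max_tasks` stands in for the combined pool/max_active_tasks caps, and the two
functions contrast the old per-pool grouping with the new single pass over a
globally priority-sorted list.

    from collections import defaultdict

    tasks = [  # (task_id, priority_weight, pool)
        ("dummy1", 1, "pool1"),
        ("dummy2", 2, "pool2"),
        ("dummy3", 3, "pool1"),
    ]
    max_tasks = 2  # stand-in for the test's max_active_tasks=2 limit

    def old_selection(tasks):
        # Before this change: group by pool first, sort within each pool.
        by_pool = defaultdict(list)
        for t in tasks:
            by_pool[t[2]].append(t)
        queued = []
        for pool_tasks in by_pool.values():
            for t in sorted(pool_tasks, key=lambda x: -x[1]):
                if len(queued) < max_tasks:
                    queued.append(t[0])
        return queued

    def new_selection(tasks):
        # After this change: one pass over a globally priority-sorted list.
        return [t[0] for t in sorted(tasks, key=lambda x: -x[1])][:max_tasks]

    print(old_selection(tasks))  # ['dummy3', 'dummy1'] - 'dummy2' starves despite higher priority
    print(new_selection(tasks))  # ['dummy3', 'dummy2']

Grouping by pool lets pool1's low-priority `dummy1` claim the second slot
before pool2's `dummy2` is ever considered; sorting across pools first removes
that bias.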
---
airflow/jobs/scheduler_job.py | 204 ++++++++++++++++++---------------------
tests/jobs/test_scheduler_job.py | 39 ++++++++
2 files changed, 131 insertions(+), 112 deletions(-)
diff --git a/airflow/jobs/scheduler_job.py b/airflow/jobs/scheduler_job.py
index 60cda34c13..5dc3c56d36 100644
--- a/airflow/jobs/scheduler_job.py
+++ b/airflow/jobs/scheduler_job.py
@@ -355,141 +355,121 @@ class SchedulerJob(BaseJob):
                 "%s tasks up for execution:\n\t%s",
                 len(task_instances_to_examine), task_instance_str
             )
-            pool_to_task_instances: DefaultDict[str, List[TI]] = defaultdict(list)
             for task_instance in task_instances_to_examine:
-                pool_to_task_instances[task_instance.pool].append(task_instance)
+                pool_name = task_instance.pool
 
-            # Go through each pool, and queue up a task for execution if there are
-            # any open slots in the pool.
-
-            for pool, task_instances in pool_to_task_instances.items():
-                pool_name = pool
-                if pool not in pools:
-                    self.log.warning("Tasks using non-existent pool '%s' will not be scheduled", pool)
+                pool_stats = pools.get(pool_name)
+                if not pool_stats:
+                    self.log.warning("Tasks using non-existent pool '%s' will not be scheduled", pool_name)
                     starved_pools.add(pool_name)
                     continue
 
-                pool_total = pools[pool]["total"]
-                open_slots = pools[pool]["open"]
+                pool_total = pool_stats["total"]
+                open_slots = pool_stats["open"]
 
-                num_ready = len(task_instances)
-                self.log.info(
-                    "Figuring out tasks to run in Pool(name=%s) with %s open slots "
-                    "and %s task instances ready to be queued",
-                    pool,
-                    open_slots,
-                    num_ready,
-                )
-
-                priority_sorted_task_instances = sorted(
-                    task_instances, key=lambda ti: (-ti.priority_weight, ti.execution_date)
-                )
+                if open_slots <= 0:
+                    self.log.info(
+                        "Not scheduling since there are %s open slots in pool %s", open_slots, pool_name
+                    )
+                    # Can't schedule any more since there are no more open slots.
+                    pool_num_starving_tasks[pool_name] += 1
+                    num_starving_tasks_total += 1
+                    starved_pools.add(pool_name)
+                    continue
 
-                for current_index, task_instance in enumerate(priority_sorted_task_instances):
-                    if open_slots <= 0:
-                        self.log.info(
-                            "Not scheduling since there are %s open slots in pool %s", open_slots, pool
-                        )
-                        # Can't schedule any more since there are no more open slots.
-                        num_unhandled = len(priority_sorted_task_instances) - current_index
-                        pool_num_starving_tasks[pool_name] += num_unhandled
-                        num_starving_tasks_total += num_unhandled
-                        starved_pools.add(pool_name)
-                        break
-
-                    if task_instance.pool_slots > pool_total:
-                        self.log.warning(
-                            "Not executing %s. Requested pool slots (%s) are greater than "
-                            "total pool slots: '%s' for pool: %s.",
-                            task_instance,
-                            task_instance.pool_slots,
-                            pool_total,
-                            pool,
-                        )
+                if task_instance.pool_slots > pool_total:
+                    self.log.warning(
+                        "Not executing %s. Requested pool slots (%s) are greater than "
+                        "total pool slots: '%s' for pool: %s.",
+                        task_instance,
+                        task_instance.pool_slots,
+                        pool_total,
+                        pool_name,
+                    )
 
-                        starved_tasks.add((task_instance.dag_id, task_instance.task_id))
-                        continue
+                    starved_tasks.add((task_instance.dag_id, task_instance.task_id))
+                    continue
 
-                    if task_instance.pool_slots > open_slots:
-                        self.log.info(
-                            "Not executing %s since it requires %s slots "
-                            "but there are %s open slots in the pool %s.",
-                            task_instance,
-                            task_instance.pool_slots,
-                            open_slots,
-                            pool,
-                        )
-                        pool_num_starving_tasks[pool_name] += 1
-                        num_starving_tasks_total += 1
-                        starved_tasks.add((task_instance.dag_id, task_instance.task_id))
-                        # Though we can execute tasks with lower priority if there's enough room
-                        continue
+                if task_instance.pool_slots > open_slots:
+                    self.log.info(
+                        "Not executing %s since it requires %s slots "
+                        "but there are %s open slots in the pool %s.",
+                        task_instance,
+                        task_instance.pool_slots,
+                        open_slots,
+                        pool_name,
+                    )
+                    pool_num_starving_tasks[pool_name] += 1
+                    num_starving_tasks_total += 1
+                    starved_tasks.add((task_instance.dag_id, task_instance.task_id))
+                    # Though we can execute tasks with lower priority if there's enough room
+                    continue
 
-                    # Check to make sure that the task max_active_tasks of the DAG hasn't been
-                    # reached.
-                    dag_id = task_instance.dag_id
+                # Check to make sure that the task max_active_tasks of the DAG hasn't been
+                # reached.
+                dag_id = task_instance.dag_id
 
-                    current_active_tasks_per_dag = dag_active_tasks_map[dag_id]
-                    max_active_tasks_per_dag_limit = task_instance.dag_model.max_active_tasks
+                current_active_tasks_per_dag = dag_active_tasks_map[dag_id]
+                max_active_tasks_per_dag_limit = task_instance.dag_model.max_active_tasks
+                self.log.info(
+                    "DAG %s has %s/%s running and queued tasks",
+                    dag_id,
+                    current_active_tasks_per_dag,
+                    max_active_tasks_per_dag_limit,
+                )
+                if current_active_tasks_per_dag >= max_active_tasks_per_dag_limit:
                     self.log.info(
-                        "DAG %s has %s/%s running and queued tasks",
+                        "Not executing %s since the number of tasks running or queued "
+                        "from DAG %s is >= to the DAG's max_active_tasks limit of %s",
+                        task_instance,
                         dag_id,
-                        current_active_tasks_per_dag,
                         max_active_tasks_per_dag_limit,
                     )
-                    if current_active_tasks_per_dag >= max_active_tasks_per_dag_limit:
-                        self.log.info(
-                            "Not executing %s since the number of tasks running or queued "
-                            "from DAG %s is >= to the DAG's max_active_tasks limit of %s",
-                            task_instance,
+                    starved_dags.add(dag_id)
+                    continue
+
+                if task_instance.dag_model.has_task_concurrency_limits:
+                    # Many dags don't have a task_concurrency, so where we can avoid loading the full
+                    # serialized DAG the better.
+                    serialized_dag = self.dagbag.get_dag(dag_id, session=session)
+                    # If the dag is missing, fail the task and continue to the next task.
+                    if not serialized_dag:
+                        self.log.error(
+                            "DAG '%s' for task instance %s not found in serialized_dag table",
                             dag_id,
-                            max_active_tasks_per_dag_limit,
+                            task_instance,
+                        )
+                        session.query(TI).filter(TI.dag_id == dag_id, TI.state == State.SCHEDULED).update(
+                            {TI.state: State.FAILED}, synchronize_session='fetch'
                         )
-                        starved_dags.add(dag_id)
                         continue
 
-                    if task_instance.dag_model.has_task_concurrency_limits:
-                        # Many dags don't have a task_concurrency, so where we can avoid loading the full
-                        # serialized DAG the better.
-                        serialized_dag = self.dagbag.get_dag(dag_id, session=session)
-                        # If the dag is missing, fail the task and continue to the next task.
-                        if not serialized_dag:
-                            self.log.error(
-                                "DAG '%s' for task instance %s not found in serialized_dag table",
-                                dag_id,
+                    task_concurrency_limit: Optional[int] = None
+                    if serialized_dag.has_task(task_instance.task_id):
+                        task_concurrency_limit = serialized_dag.get_task(
+                            task_instance.task_id
+                        ).max_active_tis_per_dag
+
+                    if task_concurrency_limit is not None:
+                        current_task_concurrency = task_concurrency_map[
+                            (task_instance.dag_id, task_instance.task_id)
+                        ]
+
+                        if current_task_concurrency >= task_concurrency_limit:
+                            self.log.info(
+                                "Not executing %s since the task concurrency for"
+                                " this task has been reached.",
                                 task_instance,
                             )
-                            session.query(TI).filter(TI.dag_id == dag_id, TI.state == State.SCHEDULED).update(
-                                {TI.state: State.FAILED}, synchronize_session='fetch'
-                            )
+                            starved_tasks.add((task_instance.dag_id, task_instance.task_id))
                             continue
 
-                        task_concurrency_limit: Optional[int] = None
-                        if serialized_dag.has_task(task_instance.task_id):
-                            task_concurrency_limit = serialized_dag.get_task(
-                                task_instance.task_id
-                            ).max_active_tis_per_dag
-
-                        if task_concurrency_limit is not None:
-                            current_task_concurrency = task_concurrency_map[
-                                (task_instance.dag_id, task_instance.task_id)
-                            ]
-
-                            if current_task_concurrency >= task_concurrency_limit:
-                                self.log.info(
-                                    "Not executing %s since the task concurrency for"
-                                    " this task has been reached.",
-                                    task_instance,
-                                )
-                                starved_tasks.add((task_instance.dag_id, task_instance.task_id))
-                                continue
-
-                    executable_tis.append(task_instance)
-                    open_slots -= task_instance.pool_slots
-                    dag_active_tasks_map[dag_id] += 1
-                    task_concurrency_map[(task_instance.dag_id, task_instance.task_id)] += 1
-
-                pools[pool]["open"] = open_slots
+                executable_tis.append(task_instance)
+                open_slots -= task_instance.pool_slots
+                dag_active_tasks_map[dag_id] += 1
+                task_concurrency_map[(task_instance.dag_id, task_instance.task_id)] += 1
+
+                pool_stats["open"] = open_slots
 
             is_done = executable_tis or len(task_instances_to_examine) < max_tis
             # Check this to avoid accidental infinite loops
diff --git a/tests/jobs/test_scheduler_job.py b/tests/jobs/test_scheduler_job.py
index 80e3cb1f34..e5d077c86d 100644
--- a/tests/jobs/test_scheduler_job.py
+++ b/tests/jobs/test_scheduler_job.py
@@ -610,6 +610,45 @@ class TestSchedulerJob:
         assert [ti.key for ti in res] == [tis[1].key]
         session.rollback()
 
+    def test_find_executable_task_instances_order_priority_with_pools(self, dag_maker):
+        """
+        The scheduler job should pick tasks with higher priority for execution
+        even if different pools are involved.
+        """
+
+        self.scheduler_job = SchedulerJob(subdir=os.devnull)
+        session = settings.Session()
+
+        dag_id = 'SchedulerJobTest.test_find_executable_task_instances_order_priority_with_pools'
+
+        session.add(Pool(pool='pool1', slots=32))
+        session.add(Pool(pool='pool2', slots=32))
+
+        with dag_maker(dag_id=dag_id, max_active_tasks=2):
+            op1 = DummyOperator(task_id='dummy1', priority_weight=1, pool='pool1')
+            op2 = DummyOperator(task_id='dummy2', priority_weight=2, pool='pool2')
+            op3 = DummyOperator(task_id='dummy3', priority_weight=3, pool='pool1')
+
+        dag_run = dag_maker.create_dagrun(run_type=DagRunType.SCHEDULED)
+
+        ti1 = dag_run.get_task_instance(op1.task_id, session)
+        ti2 = dag_run.get_task_instance(op2.task_id, session)
+        ti3 = dag_run.get_task_instance(op3.task_id, session)
+
+        ti1.state = State.SCHEDULED
+        ti2.state = State.SCHEDULED
+        ti3.state = State.SCHEDULED
+
+        session.flush()
+
+        res = self.scheduler_job._executable_task_instances_to_queued(max_tis=32, session=session)
+
+        assert 2 == len(res)
+        assert ti3.key == res[0].key
+        assert ti2.key == res[1].key
+
+        session.rollback()
+
     def test_find_executable_task_instances_order_execution_date_and_priority(self, dag_maker):
         dag_id_1 = 'SchedulerJobTest.test_find_executable_task_instances_order_execution_date_and_priority-a'
         dag_id_2 = 'SchedulerJobTest.test_find_executable_task_instances_order_execution_date_and_priority-b'
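
Condensed to its core, the rewritten selection loop relies on
`task_instances_to_examine` already arriving priority-sorted from the
scheduler's query (the pre-existing
`test_find_executable_task_instances_order_execution_date_and_priority` test
covers that ordering), so pool capacity can be checked per candidate instead of
per pool group. A rough sketch of that shape, with the DAG max_active_tasks and
task-concurrency checks from the diff omitted and `candidates`/`pools` assumed
as inputs:

    def pick_executable(candidates, pools):
        # Sketch only, not the real method. 'candidates' must already be
        # sorted by descending priority_weight; 'pools' maps a pool name
        # to {"total": int, "open": int}.
        executable = []
        for ti in candidates:
            stats = pools.get(ti.pool)
            if stats is None:
                continue  # non-existent pool: the task starves
            if stats["open"] <= 0 or ti.pool_slots > stats["open"]:
                # Pool exhausted (or too full for this task). Unlike the old
                # per-pool loop, keep going: a lower-priority task may still
                # fit into a different pool.
                continue
            executable.append(ti)
            stats["open"] -= ti.pool_slots
        return executable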