This is an automated email from the ASF dual-hosted git repository.

ash pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 8164c32858 Priority order tasks even when using pools (#22483)
8164c32858 is described below

commit 8164c328588b21dbeb7194d88f2d9c694b8f3b2e
Author: Tanel Kiis <[email protected]>
AuthorDate: Tue Apr 12 16:34:25 2022 +0300

    Priority order tasks even when using pools (#22483)
    
    Previously, when picking tasks to queue, the scheduler_job grouped candidate task instances
    by their pools and picked tasks for queueing separately within each "pool group".
    
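    As a result, tasks with lower priority could be queued before tasks with higher
    priority. This is demonstrated in the new unit test
    `test_find_executable_task_instances_order_priority_with_pools`: its DAG has
    `max_active_tasks=2`, with `dummy1` (priority 1) and `dummy3` (priority 3) in
    `pool1` and `dummy2` (priority 2) in `pool2`. Before this change, processing
    `pool1` as a group queued `dummy3` and `dummy1`, which used up the DAG's limit,
    so `dummy3` and `dummy1` were queued instead of `dummy3` and `dummy2`. The
    scheduler now makes a single pass over all candidate task instances and looks up
    each task's pool as it goes, rather than iterating pool by pool.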
    
    Co-authored-by: Tanel Kiis <[email protected]>
---
 airflow/jobs/scheduler_job.py    | 204 ++++++++++++++++++---------------------
 tests/jobs/test_scheduler_job.py |  39 ++++++++
 2 files changed, 131 insertions(+), 112 deletions(-)

diff --git a/airflow/jobs/scheduler_job.py b/airflow/jobs/scheduler_job.py
index 60cda34c13..5dc3c56d36 100644
--- a/airflow/jobs/scheduler_job.py
+++ b/airflow/jobs/scheduler_job.py
@@ -355,141 +355,121 @@ class SchedulerJob(BaseJob):
                 "%s tasks up for execution:\n\t%s", 
len(task_instances_to_examine), task_instance_str
             )
 
-            pool_to_task_instances: DefaultDict[str, List[TI]] = 
defaultdict(list)
             for task_instance in task_instances_to_examine:
-                
pool_to_task_instances[task_instance.pool].append(task_instance)
+                pool_name = task_instance.pool
 
-            # Go through each pool, and queue up a task for execution if there 
are
-            # any open slots in the pool.
-
-            for pool, task_instances in pool_to_task_instances.items():
-                pool_name = pool
-                if pool not in pools:
-                    self.log.warning("Tasks using non-existent pool '%s' will 
not be scheduled", pool)
+                pool_stats = pools.get(pool_name)
+                if not pool_stats:
+                    self.log.warning("Tasks using non-existent pool '%s' will 
not be scheduled", pool_name)
                     starved_pools.add(pool_name)
                     continue
 
-                pool_total = pools[pool]["total"]
-                open_slots = pools[pool]["open"]
+                pool_total = pool_stats["total"]
+                open_slots = pool_stats["open"]
 
-                num_ready = len(task_instances)
-                self.log.info(
-                    "Figuring out tasks to run in Pool(name=%s) with %s open 
slots "
-                    "and %s task instances ready to be queued",
-                    pool,
-                    open_slots,
-                    num_ready,
-                )
-
-                priority_sorted_task_instances = sorted(
-                    task_instances, key=lambda ti: (-ti.priority_weight, 
ti.execution_date)
-                )
+                if open_slots <= 0:
+                    self.log.info(
+                        "Not scheduling since there are %s open slots in pool 
%s", open_slots, pool_name
+                    )
+                    # Can't schedule any more since there are no more open 
slots.
+                    pool_num_starving_tasks[pool_name] += 1
+                    num_starving_tasks_total += 1
+                    starved_pools.add(pool_name)
+                    continue
 
-                for current_index, task_instance in 
enumerate(priority_sorted_task_instances):
-                    if open_slots <= 0:
-                        self.log.info(
-                            "Not scheduling since there are %s open slots in 
pool %s", open_slots, pool
-                        )
-                        # Can't schedule any more since there are no more open 
slots.
-                        num_unhandled = len(priority_sorted_task_instances) - 
current_index
-                        pool_num_starving_tasks[pool_name] += num_unhandled
-                        num_starving_tasks_total += num_unhandled
-                        starved_pools.add(pool_name)
-                        break
-
-                    if task_instance.pool_slots > pool_total:
-                        self.log.warning(
-                            "Not executing %s. Requested pool slots (%s) are 
greater than "
-                            "total pool slots: '%s' for pool: %s.",
-                            task_instance,
-                            task_instance.pool_slots,
-                            pool_total,
-                            pool,
-                        )
+                if task_instance.pool_slots > pool_total:
+                    self.log.warning(
+                        "Not executing %s. Requested pool slots (%s) are 
greater than "
+                        "total pool slots: '%s' for pool: %s.",
+                        task_instance,
+                        task_instance.pool_slots,
+                        pool_total,
+                        pool_name,
+                    )
 
-                        starved_tasks.add((task_instance.dag_id, 
task_instance.task_id))
-                        continue
+                    starved_tasks.add((task_instance.dag_id, 
task_instance.task_id))
+                    continue
 
-                    if task_instance.pool_slots > open_slots:
-                        self.log.info(
-                            "Not executing %s since it requires %s slots "
-                            "but there are %s open slots in the pool %s.",
-                            task_instance,
-                            task_instance.pool_slots,
-                            open_slots,
-                            pool,
-                        )
-                        pool_num_starving_tasks[pool_name] += 1
-                        num_starving_tasks_total += 1
-                        starved_tasks.add((task_instance.dag_id, 
task_instance.task_id))
-                        # Though we can execute tasks with lower priority if 
there's enough room
-                        continue
+                if task_instance.pool_slots > open_slots:
+                    self.log.info(
+                        "Not executing %s since it requires %s slots "
+                        "but there are %s open slots in the pool %s.",
+                        task_instance,
+                        task_instance.pool_slots,
+                        open_slots,
+                        pool_name,
+                    )
+                    pool_num_starving_tasks[pool_name] += 1
+                    num_starving_tasks_total += 1
+                    starved_tasks.add((task_instance.dag_id, 
task_instance.task_id))
+                    # Though we can execute tasks with lower priority if 
there's enough room
+                    continue
 
-                    # Check to make sure that the task max_active_tasks of the 
DAG hasn't been
-                    # reached.
-                    dag_id = task_instance.dag_id
+                # Check to make sure that the task max_active_tasks of the DAG 
hasn't been
+                # reached.
+                dag_id = task_instance.dag_id
 
-                    current_active_tasks_per_dag = dag_active_tasks_map[dag_id]
-                    max_active_tasks_per_dag_limit = 
task_instance.dag_model.max_active_tasks
+                current_active_tasks_per_dag = dag_active_tasks_map[dag_id]
+                max_active_tasks_per_dag_limit = 
task_instance.dag_model.max_active_tasks
+                self.log.info(
+                    "DAG %s has %s/%s running and queued tasks",
+                    dag_id,
+                    current_active_tasks_per_dag,
+                    max_active_tasks_per_dag_limit,
+                )
+                if current_active_tasks_per_dag >= 
max_active_tasks_per_dag_limit:
                     self.log.info(
-                        "DAG %s has %s/%s running and queued tasks",
+                        "Not executing %s since the number of tasks running or 
queued "
+                        "from DAG %s is >= to the DAG's max_active_tasks limit 
of %s",
+                        task_instance,
                         dag_id,
-                        current_active_tasks_per_dag,
                         max_active_tasks_per_dag_limit,
                     )
-                    if current_active_tasks_per_dag >= 
max_active_tasks_per_dag_limit:
-                        self.log.info(
-                            "Not executing %s since the number of tasks 
running or queued "
-                            "from DAG %s is >= to the DAG's max_active_tasks 
limit of %s",
-                            task_instance,
+                    starved_dags.add(dag_id)
+                    continue
+
+                if task_instance.dag_model.has_task_concurrency_limits:
+                    # Many dags don't have a task_concurrency, so where we can 
avoid loading the full
+                    # serialized DAG the better.
+                    serialized_dag = self.dagbag.get_dag(dag_id, 
session=session)
+                    # If the dag is missing, fail the task and continue to the 
next task.
+                    if not serialized_dag:
+                        self.log.error(
+                            "DAG '%s' for task instance %s not found in 
serialized_dag table",
                             dag_id,
-                            max_active_tasks_per_dag_limit,
+                            task_instance,
+                        )
+                        session.query(TI).filter(TI.dag_id == dag_id, TI.state 
== State.SCHEDULED).update(
+                            {TI.state: State.FAILED}, 
synchronize_session='fetch'
                         )
-                        starved_dags.add(dag_id)
                         continue
 
-                    if task_instance.dag_model.has_task_concurrency_limits:
-                        # Many dags don't have a task_concurrency, so where we 
can avoid loading the full
-                        # serialized DAG the better.
-                        serialized_dag = self.dagbag.get_dag(dag_id, 
session=session)
-                        # If the dag is missing, fail the task and continue to 
the next task.
-                        if not serialized_dag:
-                            self.log.error(
-                                "DAG '%s' for task instance %s not found in 
serialized_dag table",
-                                dag_id,
+                    task_concurrency_limit: Optional[int] = None
+                    if serialized_dag.has_task(task_instance.task_id):
+                        task_concurrency_limit = serialized_dag.get_task(
+                            task_instance.task_id
+                        ).max_active_tis_per_dag
+
+                    if task_concurrency_limit is not None:
+                        current_task_concurrency = task_concurrency_map[
+                            (task_instance.dag_id, task_instance.task_id)
+                        ]
+
+                        if current_task_concurrency >= task_concurrency_limit:
+                            self.log.info(
+                                "Not executing %s since the task concurrency 
for"
+                                " this task has been reached.",
                                 task_instance,
                             )
-                            session.query(TI).filter(TI.dag_id == dag_id, 
TI.state == State.SCHEDULED).update(
-                                {TI.state: State.FAILED}, 
synchronize_session='fetch'
-                            )
+                            starved_tasks.add((task_instance.dag_id, 
task_instance.task_id))
                             continue
 
-                        task_concurrency_limit: Optional[int] = None
-                        if serialized_dag.has_task(task_instance.task_id):
-                            task_concurrency_limit = serialized_dag.get_task(
-                                task_instance.task_id
-                            ).max_active_tis_per_dag
-
-                        if task_concurrency_limit is not None:
-                            current_task_concurrency = task_concurrency_map[
-                                (task_instance.dag_id, task_instance.task_id)
-                            ]
-
-                            if current_task_concurrency >= 
task_concurrency_limit:
-                                self.log.info(
-                                    "Not executing %s since the task 
concurrency for"
-                                    " this task has been reached.",
-                                    task_instance,
-                                )
-                                starved_tasks.add((task_instance.dag_id, 
task_instance.task_id))
-                                continue
-
-                    executable_tis.append(task_instance)
-                    open_slots -= task_instance.pool_slots
-                    dag_active_tasks_map[dag_id] += 1
-                    task_concurrency_map[(task_instance.dag_id, 
task_instance.task_id)] += 1
-
-                pools[pool]["open"] = open_slots
+                executable_tis.append(task_instance)
+                open_slots -= task_instance.pool_slots
+                dag_active_tasks_map[dag_id] += 1
+                task_concurrency_map[(task_instance.dag_id, 
task_instance.task_id)] += 1
+
+                pool_stats["open"] = open_slots
 
             is_done = executable_tis or len(task_instances_to_examine) < 
max_tis
             # Check this to avoid accidental infinite loops
diff --git a/tests/jobs/test_scheduler_job.py b/tests/jobs/test_scheduler_job.py
index 80e3cb1f34..e5d077c86d 100644
--- a/tests/jobs/test_scheduler_job.py
+++ b/tests/jobs/test_scheduler_job.py
@@ -610,6 +610,45 @@ class TestSchedulerJob:
         assert [ti.key for ti in res] == [tis[1].key]
         session.rollback()
 
+    def test_find_executable_task_instances_order_priority_with_pools(self, 
dag_maker):
+        """
+        The scheduler job should pick tasks with higher priority for execution
+        even if different pools are involved.
+        """
+
+        self.scheduler_job = SchedulerJob(subdir=os.devnull)
+        session = settings.Session()
+
+        dag_id = 
'SchedulerJobTest.test_find_executable_task_instances_order_priority_with_pools'
+
+        session.add(Pool(pool='pool1', slots=32))
+        session.add(Pool(pool='pool2', slots=32))
+
+        with dag_maker(dag_id=dag_id, max_active_tasks=2):
+            op1 = DummyOperator(task_id='dummy1', priority_weight=1, 
pool='pool1')
+            op2 = DummyOperator(task_id='dummy2', priority_weight=2, 
pool='pool2')
+            op3 = DummyOperator(task_id='dummy3', priority_weight=3, 
pool='pool1')
+
+        dag_run = dag_maker.create_dagrun(run_type=DagRunType.SCHEDULED)
+
+        ti1 = dag_run.get_task_instance(op1.task_id, session)
+        ti2 = dag_run.get_task_instance(op2.task_id, session)
+        ti3 = dag_run.get_task_instance(op3.task_id, session)
+
+        ti1.state = State.SCHEDULED
+        ti2.state = State.SCHEDULED
+        ti3.state = State.SCHEDULED
+
+        session.flush()
+
+        res = 
self.scheduler_job._executable_task_instances_to_queued(max_tis=32, 
session=session)
+
+        assert 2 == len(res)
+        assert ti3.key == res[0].key
+        assert ti2.key == res[1].key
+
+        session.rollback()
+
     def 
test_find_executable_task_instances_order_execution_date_and_priority(self, 
dag_maker):
         dag_id_1 = 
'SchedulerJobTest.test_find_executable_task_instances_order_execution_date_and_priority-a'
         dag_id_2 = 
'SchedulerJobTest.test_find_executable_task_instances_order_execution_date_and_priority-b'

Reply via email to