This is an automated email from the ASF dual-hosted git repository.
ephraimanierobi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new cd68540 Fix Tasks getting stuck in scheduled state (#19747)
cd68540 is described below
commit cd68540ef19b36180fdd1ebe38435637586747d4
Author: Tanel Kiis <[email protected]>
AuthorDate: Tue Mar 22 19:30:37 2022 +0200
Fix Tasks getting stuck in scheduled state (#19747)
The scheduler_job can get stuck in a state where it is not able to queue
new tasks. It will get out of this state on its own, but the time taken
depends on the runtime of the currently running tasks - this could be
several hours or even days.

If the scheduler can't queue any tasks because of the various concurrency
limits (per pool, dag or task), then on the next iterations of the
scheduler loop it will try to queue the same tasks again. Meanwhile there
could be scheduled tasks with a lower priority_weight that could be
queued, but they will remain waiting.

The proposed solution is to keep track of the dag and task ids that are
concurrency limited, and then repeat the query with those dags and tasks
filtered out.
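
To make the mechanism concrete, here is a self-contained toy sketch of the
retry loop (toy data structures and names, not Airflow's real API; the
actual change, which also handles starved pools and per-task limits, is in
the diff below):

    from typing import Dict, List, Set


    class TI:
        def __init__(self, dag_id: str, task_id: str, priority: int):
            self.dag_id, self.task_id, self.priority = dag_id, task_id, priority


    def find_executable_tis(
        scheduled: List[TI],
        dag_limit: Dict[str, int],   # dag_id -> max_active_tasks
        dag_active: Dict[str, int],  # dag_id -> running/queued task count
        max_tis: int,
    ) -> List[TI]:
        executable: List[TI] = []
        starved_dags: Set[str] = set()
        while True:
            # Stand-in for the SQL query: highest priority first, starved
            # dags filtered out, at most max_tis rows.
            candidates = sorted(
                (ti for ti in scheduled if ti.dag_id not in starved_dags),
                key=lambda ti: -ti.priority,
            )[:max_tis]
            if not candidates:
                break
            num_starved_dags = len(starved_dags)
            for ti in candidates:
                if dag_active.get(ti.dag_id, 0) >= dag_limit.get(ti.dag_id, max_tis):
                    starved_dags.add(ti.dag_id)  # skip this dag on the next pass
                else:
                    executable.append(ti)
                    dag_active[ti.dag_id] = dag_active.get(ti.dag_id, 0) + 1
            # Done if anything was queued or the query was exhausted; also stop
            # if no new filter was learned, to avoid an infinite loop.
            if executable or len(candidates) < max_tis:
                break
            if len(starved_dags) == num_starved_dags:
                break
        return executable

For example, with max_tis=2, one dag already at its max_active_tasks limit
holding two high-priority scheduled tasks, and another dag holding one
low-priority task: the first pass queues nothing but marks the busy dag as
starved, and the second pass queues the low-priority task instead of
returning empty-handed.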
Co-authored-by: Tanel Kiis <[email protected]>
---
airflow/jobs/scheduler_job.py | 348 +++++++++++++++++++++++----------------
tests/jobs/test_scheduler_job.py | 94 ++++++++++-
2 files changed, 298 insertions(+), 144 deletions(-)
diff --git a/airflow/jobs/scheduler_job.py b/airflow/jobs/scheduler_job.py
index a09b39a..5b02e25 100644
--- a/airflow/jobs/scheduler_job.py
+++ b/airflow/jobs/scheduler_job.py
@@ -26,7 +26,7 @@ import time
import warnings
from collections import defaultdict
from datetime import timedelta
-from typing import Collection, DefaultDict, Dict, Iterator, List, Optional, Tuple
+from typing import Collection, DefaultDict, Dict, Iterator, List, Optional, Set, Tuple
from sqlalchemy import and_, func, not_, or_, text, tuple_
from sqlalchemy.exc import OperationalError
@@ -271,54 +271,16 @@ class SchedulerJob(BaseJob):
if pool_slots_free == 0:
self.log.debug("All pools are full!")
- return executable_tis
+ return []
max_tis = min(max_tis, pool_slots_free)
- # Get all task instances associated with scheduled
- # DagRuns which are not backfilled, in the given states,
- # and the dag is not paused
- query = (
- session.query(TI)
- .join(TI.dag_run)
- .filter(DR.run_type != DagRunType.BACKFILL_JOB, DR.state == DagRunState.RUNNING)
- .join(TI.dag_model)
- .filter(not_(DM.is_paused))
- .filter(TI.state == TaskInstanceState.SCHEDULED)
- .options(selectinload('dag_model'))
- .order_by(-TI.priority_weight, DR.execution_date)
- )
- starved_pools = [pool_name for pool_name, stats in pools.items() if stats['open'] <= 0]
- if starved_pools:
- query = query.filter(not_(TI.pool.in_(starved_pools)))
-
- query = query.limit(max_tis)
-
- task_instances_to_examine: List[TI] = with_row_locks(
- query,
- of=TI,
- session=session,
- **skip_locked(session=session),
- ).all()
- # TODO[HA]: This was wrong before anyway, as it only looked at a sub-set of dags, not everything.
- # Stats.gauge('scheduler.tasks.pending', len(task_instances_to_examine))
-
- if len(task_instances_to_examine) == 0:
- self.log.debug("No tasks to consider for execution.")
- return executable_tis
-
- # Put one task instance on each line
- task_instance_str = "\n\t".join(repr(x) for x in
task_instances_to_examine)
- self.log.info("%s tasks up for execution:\n\t%s",
len(task_instances_to_examine), task_instance_str)
-
- pool_to_task_instances: DefaultDict[str, List[models.Pool]] = defaultdict(list)
- for task_instance in task_instances_to_examine:
- pool_to_task_instances[task_instance.pool].append(task_instance)
+ starved_pools = {pool_name for pool_name, stats in pools.items() if stats['open'] <= 0}
# dag_id to # of running tasks and (dag_id, task_id) to # of running tasks.
- dag_max_active_tasks_map: DefaultDict[str, int]
+ dag_active_tasks_map: DefaultDict[str, int]
task_concurrency_map: DefaultDict[Tuple[str, str], int]
- dag_max_active_tasks_map, task_concurrency_map = self.__get_concurrency_maps(
+ dag_active_tasks_map, task_concurrency_map = self.__get_concurrency_maps(
states=list(EXECUTION_STATES), session=session
)
@@ -326,137 +288,237 @@ class SchedulerJob(BaseJob):
# Number of tasks that cannot be scheduled because of no open slot in pool
num_starving_tasks_total = 0
- # Go through each pool, and queue up a task for execution if there are
- # any open slots in the pool.
+ # dag and task ids that can't be queued because of concurrency limits
+ starved_dags: Set[str] = set()
+ starved_tasks: Set[Tuple[str, str]] = set()
- for pool, task_instances in pool_to_task_instances.items():
- pool_name = pool
- if pool not in pools:
- self.log.warning("Tasks using non-existent pool '%s' will not
be scheduled", pool)
- continue
+ pool_num_starving_tasks: DefaultDict[str, int] = defaultdict(int)
+
+ for loop_count in itertools.count(start=1):
+
+ num_starved_pools = len(starved_pools)
+ num_starved_dags = len(starved_dags)
+ num_starved_tasks = len(starved_tasks)
+
+ # Get task instances associated with scheduled
+ # DagRuns which are not backfilled, in the given states,
+ # and the dag is not paused
+ query = (
+ session.query(TI)
+ .join(TI.dag_run)
+ .filter(DR.run_type != DagRunType.BACKFILL_JOB, DR.state == DagRunState.RUNNING)
+ .join(TI.dag_model)
+ .filter(not_(DM.is_paused))
+ .filter(TI.state == TaskInstanceState.SCHEDULED)
+ .options(selectinload('dag_model'))
+ .order_by(-TI.priority_weight, DR.execution_date)
+ )
+
+ if starved_pools:
+ query = query.filter(not_(TI.pool.in_(starved_pools)))
+
+ if starved_dags:
+ query = query.filter(not_(TI.dag_id.in_(starved_dags)))
- pool_total = pools[pool]["total"]
- for task_instance in task_instances:
- if task_instance.pool_slots > pool_total:
- self.log.warning(
- "Not executing %s. Requested pool slots (%s) are
greater than "
- "total pool slots: '%s' for pool: %s.",
- task_instance,
- task_instance.pool_slots,
- pool_total,
- pool,
+ if starved_tasks:
+ if settings.engine.dialect.name == 'mssql':
+ task_filter = or_(
+ and_(
+ TaskInstance.dag_id == dag_id,
+ TaskInstance.task_id == task_id,
+ )
+ for (dag_id, task_id) in starved_tasks
)
- task_instances.remove(task_instance)
+ else:
+ task_filter = tuple_(TaskInstance.dag_id, TaskInstance.task_id).in_(starved_tasks)
+
+ query = query.filter(not_(task_filter))
+
+ query = query.limit(max_tis)
- open_slots = pools[pool]["open"]
+ task_instances_to_examine: List[TI] = with_row_locks(
+ query,
+ of=TI,
+ session=session,
+ **skip_locked(session=session),
+ ).all()
+ # TODO[HA]: This was wrong before anyway, as it only looked at a sub-set of dags, not everything.
+ # Stats.gauge('scheduler.tasks.pending', len(task_instances_to_examine))
- num_ready = len(task_instances)
+ if len(task_instances_to_examine) == 0:
+ self.log.debug("No tasks to consider for execution.")
+ break
+
+ # Put one task instance on each line
+ task_instance_str = "\n\t".join(repr(x) for x in
task_instances_to_examine)
self.log.info(
- "Figuring out tasks to run in Pool(name=%s) with %s open slots
"
- "and %s task instances ready to be queued",
- pool,
- open_slots,
- num_ready,
+ "%s tasks up for execution:\n\t%s",
len(task_instances_to_examine), task_instance_str
)
- priority_sorted_task_instances = sorted(
- task_instances, key=lambda ti: (-ti.priority_weight, ti.execution_date)
- )
+ pool_to_task_instances: DefaultDict[str, List[TI]] = defaultdict(list)
+ for task_instance in task_instances_to_examine:
+ pool_to_task_instances[task_instance.pool].append(task_instance)
+
+ # Go through each pool, and queue up a task for execution if there are
+ # any open slots in the pool.
- num_starving_tasks = 0
- for current_index, task_instance in enumerate(priority_sorted_task_instances):
- if open_slots <= 0:
- self.log.info("Not scheduling since there are %s open
slots in pool %s", open_slots, pool)
- # Can't schedule any more since there are no more open
slots.
- num_unhandled = len(priority_sorted_task_instances) -
current_index
- num_starving_tasks += num_unhandled
- num_starving_tasks_total += num_unhandled
- break
-
- # Check to make sure that the task max_active_tasks of the DAG hasn't been
- # reached.
- dag_id = task_instance.dag_id
-
- current_max_active_tasks_per_dag = dag_max_active_tasks_map[dag_id]
- max_active_tasks_per_dag_limit = task_instance.dag_model.max_active_tasks
+ for pool, task_instances in pool_to_task_instances.items():
+ pool_name = pool
+ if pool not in pools:
+ self.log.warning("Tasks using non-existent pool '%s' will
not be scheduled", pool)
+ starved_pools.add(pool_name)
+ continue
+
+ pool_total = pools[pool]["total"]
+ open_slots = pools[pool]["open"]
+
+ num_ready = len(task_instances)
self.log.info(
- "DAG %s has %s/%s running and queued tasks",
- dag_id,
- current_max_active_tasks_per_dag,
- max_active_tasks_per_dag_limit,
+ "Figuring out tasks to run in Pool(name=%s) with %s open
slots "
+ "and %s task instances ready to be queued",
+ pool,
+ open_slots,
+ num_ready,
)
- if current_max_active_tasks_per_dag >= max_active_tasks_per_dag_limit:
+
+ priority_sorted_task_instances = sorted(
+ task_instances, key=lambda ti: (-ti.priority_weight, ti.execution_date)
+ )
+
+ for current_index, task_instance in enumerate(priority_sorted_task_instances):
+ if open_slots <= 0:
+ self.log.info(
+ "Not scheduling since there are %s open slots in
pool %s", open_slots, pool
+ )
+ # Can't schedule any more since there are no more open slots.
+ num_unhandled = len(priority_sorted_task_instances) - current_index
+ pool_num_starving_tasks[pool_name] += num_unhandled
+ num_starving_tasks_total += num_unhandled
+ starved_pools.add(pool_name)
+ break
+
+ if task_instance.pool_slots > pool_total:
+ self.log.warning(
+ "Not executing %s. Requested pool slots (%s) are
greater than "
+ "total pool slots: '%s' for pool: %s.",
+ task_instance,
+ task_instance.pool_slots,
+ pool_total,
+ pool,
+ )
+
+ starved_tasks.add((task_instance.dag_id, task_instance.task_id))
+ continue
+
+ if task_instance.pool_slots > open_slots:
+ self.log.info(
+ "Not executing %s since it requires %s slots "
+ "but there are %s open slots in the pool %s.",
+ task_instance,
+ task_instance.pool_slots,
+ open_slots,
+ pool,
+ )
+ pool_num_starving_tasks[pool_name] += 1
+ num_starving_tasks_total += 1
+ starved_tasks.add((task_instance.dag_id, task_instance.task_id))
+ # Though we can execute tasks with lower priority if there's enough room
+ continue
+
+ # Check to make sure that the task max_active_tasks of the DAG hasn't been
+ # reached.
+ dag_id = task_instance.dag_id
+
+ current_active_tasks_per_dag = dag_active_tasks_map[dag_id]
+ max_active_tasks_per_dag_limit = task_instance.dag_model.max_active_tasks
self.log.info(
- "Not executing %s since the number of tasks running or
queued "
- "from DAG %s is >= to the DAG's max_active_tasks limit
of %s",
- task_instance,
+ "DAG %s has %s/%s running and queued tasks",
dag_id,
+ current_active_tasks_per_dag,
max_active_tasks_per_dag_limit,
)
- continue
-
- task_concurrency_limit: Optional[int] = None
- if task_instance.dag_model.has_task_concurrency_limits:
- # Many dags don't have a task_concurrency, so where we can avoid loading the full
- # serialized DAG the better.
- serialized_dag = self.dagbag.get_dag(dag_id, session=session)
- # If the dag is missing, fail the task and continue to the next task.
- if not serialized_dag:
- self.log.error(
- "DAG '%s' for task instance %s not found in
serialized_dag table",
- dag_id,
+ if current_active_tasks_per_dag >= max_active_tasks_per_dag_limit:
+ self.log.info(
+ "Not executing %s since the number of tasks
running or queued "
+ "from DAG %s is >= to the DAG's max_active_tasks
limit of %s",
task_instance,
+ dag_id,
+ max_active_tasks_per_dag_limit,
)
- session.query(TI).filter(TI.dag_id == dag_id, TI.state == State.SCHEDULED).update(
- {TI.state: State.FAILED}, synchronize_session='fetch'
- )
+ starved_dags.add(dag_id)
continue
- if serialized_dag.has_task(task_instance.task_id):
- task_concurrency_limit = serialized_dag.get_task(
- task_instance.task_id
- ).max_active_tis_per_dag
-
- if task_concurrency_limit is not None:
- current_task_concurrency = task_concurrency_map[
- (task_instance.dag_id, task_instance.task_id)
- ]
-
- if current_task_concurrency >= task_concurrency_limit:
- self.log.info(
- "Not executing %s since the task concurrency
for"
- " this task has been reached.",
+
+ if task_instance.dag_model.has_task_concurrency_limits:
+ # Many dags don't have a task_concurrency, so where we can avoid loading the full
+ # serialized DAG the better.
+ serialized_dag = self.dagbag.get_dag(dag_id, session=session)
+ # If the dag is missing, fail the task and continue to the next task.
+ if not serialized_dag:
+ self.log.error(
+ "DAG '%s' for task instance %s not found in
serialized_dag table",
+ dag_id,
task_instance,
)
+ session.query(TI).filter(TI.dag_id == dag_id, TI.state == State.SCHEDULED).update(
+ {TI.state: State.FAILED}, synchronize_session='fetch'
+ )
continue
- if task_instance.pool_slots > open_slots:
- self.log.info(
- "Not executing %s since it requires %s slots "
- "but there are %s open slots in the pool %s.",
- task_instance,
- task_instance.pool_slots,
- open_slots,
- pool,
- )
- num_starving_tasks += 1
- num_starving_tasks_total += 1
- # Though we can execute tasks with lower priority if
there's enough room
- continue
+ task_concurrency_limit: Optional[int] = None
+ if serialized_dag.has_task(task_instance.task_id):
+ task_concurrency_limit = serialized_dag.get_task(
+ task_instance.task_id
+ ).max_active_tis_per_dag
+
+ if task_concurrency_limit is not None:
+ current_task_concurrency = task_concurrency_map[
+ (task_instance.dag_id, task_instance.task_id)
+ ]
+
+ if current_task_concurrency >= task_concurrency_limit:
+ self.log.info(
+ "Not executing %s since the task
concurrency for"
+ " this task has been reached.",
+ task_instance,
+ )
+ starved_tasks.add((task_instance.dag_id, task_instance.task_id))
+ continue
+
+ executable_tis.append(task_instance)
+ open_slots -= task_instance.pool_slots
+ dag_active_tasks_map[dag_id] += 1
+ task_concurrency_map[(task_instance.dag_id, task_instance.task_id)] += 1
+
+ pools[pool]["open"] = open_slots
+
+ is_done = executable_tis or len(task_instances_to_examine) < max_tis
+ # Check this to avoid accidental infinite loops
+ found_new_filters = (
+ len(starved_pools) > num_starved_pools
+ or len(starved_dags) > num_starved_dags
+ or len(starved_tasks) > num_starved_tasks
+ )
- executable_tis.append(task_instance)
- open_slots -= task_instance.pool_slots
- dag_max_active_tasks_map[dag_id] += 1
- task_concurrency_map[(task_instance.dag_id, task_instance.task_id)] += 1
+ if is_done or not found_new_filters:
+ break
+ self.log.debug(
+ "Found no task instances to queue on the %s. iteration "
+ "but there could be more candidate task instances to check.",
+ loop_count,
+ )
+
+ for pool_name, num_starving_tasks in pool_num_starving_tasks.items():
+ Stats.gauge(f'pool.starving_tasks.{pool_name}', num_starving_tasks)
Stats.gauge('scheduler.tasks.starving', num_starving_tasks_total)
Stats.gauge('scheduler.tasks.running', num_tasks_in_executor)
Stats.gauge('scheduler.tasks.executable', len(executable_tis))
- task_instance_str = "\n\t".join(repr(x) for x in executable_tis)
- self.log.info("Setting the following tasks to queued state:\n\t%s",
task_instance_str)
if len(executable_tis) > 0:
+ task_instance_str = "\n\t".join(repr(x) for x in executable_tis)
+ self.log.info("Setting the following tasks to queued
state:\n\t%s", task_instance_str)
+
# set TIs to queued state
filter_for_tis = TI.filter_for_tis(executable_tis)
session.query(TI).filter(filter_for_tis).update(
diff --git a/tests/jobs/test_scheduler_job.py b/tests/jobs/test_scheduler_job.py
index eedd45e..d576e30 100644
--- a/tests/jobs/test_scheduler_job.py
+++ b/tests/jobs/test_scheduler_job.py
@@ -486,6 +486,7 @@ class TestSchedulerJob:
dr2.get_task_instance(task_id_1, session=session),
dr2.get_task_instance(task_id_2, session=session),
]
+ tis = sorted(tis, key=lambda ti: ti.key)
for ti in tis:
ti.state = State.SCHEDULED
session.merge(ti)
@@ -502,7 +503,7 @@ class TestSchedulerJob:
for ti in res:
res_keys.append(ti.key)
assert tis[0].key in res_keys
- assert tis[1].key in res_keys
+ assert tis[2].key in res_keys
assert tis[3].key in res_keys
session.rollback()
@@ -996,6 +997,97 @@ class TestSchedulerJob:
session.rollback()
+ def test_find_executable_task_instances_not_enough_pool_slots_for_first(self, dag_maker):
+ set_default_pool_slots(1)
+
+ self.scheduler_job = SchedulerJob(subdir=os.devnull)
+ session = settings.Session()
+
+ dag_id = 'SchedulerJobTest.test_find_executable_task_instances_not_enough_pool_slots_for_first'
+ with dag_maker(dag_id=dag_id):
+ op1 = DummyOperator(task_id='dummy1', priority_weight=2, pool_slots=2)
+ op2 = DummyOperator(task_id='dummy2', priority_weight=1, pool_slots=1)
+
+ dr1 = dag_maker.create_dagrun(run_type=DagRunType.SCHEDULED)
+
+ ti1 = dr1.get_task_instance(op1.task_id, session)
+ ti2 = dr1.get_task_instance(op2.task_id, session)
+ ti1.state = State.SCHEDULED
+ ti2.state = State.SCHEDULED
+ session.flush()
+
+ # Schedule ti with lower priority,
+ # because the one with higher priority is limited by a concurrency limit
+ res = self.scheduler_job._executable_task_instances_to_queued(max_tis=32, session=session)
+ assert 1 == len(res)
+ assert res[0].key == ti2.key
+
+ session.rollback()
+
+ def test_find_executable_task_instances_not_enough_dag_concurrency_for_first(self, dag_maker):
+ self.scheduler_job = SchedulerJob(subdir=os.devnull)
+ session = settings.Session()
+
+ dag_id_1 = (
+ 'SchedulerJobTest.test_find_executable_task_instances_not_enough_dag_concurrency_for_first-a'
+ )
+ dag_id_2 = (
+ 'SchedulerJobTest.test_find_executable_task_instances_not_enough_dag_concurrency_for_first-b'
+ )
+
+ with dag_maker(dag_id=dag_id_1, max_active_tasks=1):
+ op1a = DummyOperator(task_id='dummy1-a', priority_weight=2)
+ op1b = DummyOperator(task_id='dummy1-b', priority_weight=2)
+ dr1 = dag_maker.create_dagrun(run_type=DagRunType.SCHEDULED)
+
+ with dag_maker(dag_id=dag_id_2):
+ op2 = DummyOperator(task_id='dummy2', priority_weight=1)
+ dr2 = dag_maker.create_dagrun(run_type=DagRunType.SCHEDULED)
+
+ ti1a = dr1.get_task_instance(op1a.task_id, session)
+ ti1b = dr1.get_task_instance(op1b.task_id, session)
+ ti2 = dr2.get_task_instance(op2.task_id, session)
+ ti1a.state = State.RUNNING
+ ti1b.state = State.SCHEDULED
+ ti2.state = State.SCHEDULED
+ session.flush()
+
+ # Schedule ti with lower priority,
+ # because the one with higher priority is limited by a concurrency limit
+ res = self.scheduler_job._executable_task_instances_to_queued(max_tis=1, session=session)
+ assert 1 == len(res)
+ assert res[0].key == ti2.key
+
+ session.rollback()
+
+ def test_find_executable_task_instances_not_enough_task_concurrency_for_first(self, dag_maker):
+ self.scheduler_job = SchedulerJob(subdir=os.devnull)
+ session = settings.Session()
+
+ dag_id = 'SchedulerJobTest.test_find_executable_task_instances_not_enough_task_concurrency_for_first'
+
+ with dag_maker(dag_id=dag_id):
+ op1a = DummyOperator(task_id='dummy1-a', priority_weight=2, max_active_tis_per_dag=1)
+ op1b = DummyOperator(task_id='dummy1-b', priority_weight=1)
+ dr1 = dag_maker.create_dagrun(run_type=DagRunType.SCHEDULED)
+ dr2 = dag_maker.create_dagrun_after(dr1, run_type=DagRunType.SCHEDULED)
+
+ ti1a = dr1.get_task_instance(op1a.task_id, session)
+ ti1b = dr1.get_task_instance(op1b.task_id, session)
+ ti2a = dr2.get_task_instance(op1a.task_id, session)
+ ti1a.state = State.RUNNING
+ ti1b.state = State.SCHEDULED
+ ti2a.state = State.SCHEDULED
+ session.flush()
+
+ # Schedule ti with lower priority,
+ # because the one with higher priority is limited by a concurrency limit
+ res = self.scheduler_job._executable_task_instances_to_queued(max_tis=1, session=session)
+ assert 1 == len(res)
+ assert res[0].key == ti1b.key
+
+ session.rollback()
+
def test_enqueue_task_instances_with_queued_state(self, dag_maker):
dag_id = 'SchedulerJobTest.test_enqueue_task_instances_with_queued_state'
task_id_1 = 'dummy'