tanelk commented on a change in pull request #19747:
URL: https://github.com/apache/airflow/pull/19747#discussion_r799234159
##########
File path: airflow/jobs/scheduler_job.py
##########
@@ -266,192 +266,262 @@ def _executable_task_instances_to_queued(self, max_tis: int, session: Session =
         if pool_slots_free == 0:
             self.log.debug("All pools are full!")
-            return executable_tis
+            return []
         max_tis = min(max_tis, pool_slots_free)
-        # Get all task instances associated with scheduled
-        # DagRuns which are not backfilled, in the given states,
-        # and the dag is not paused
-        query = (
-            session.query(TI)
-            .join(TI.dag_run)
-            .filter(DR.run_type != DagRunType.BACKFILL_JOB, DR.state == DagRunState.RUNNING)
-            .join(TI.dag_model)
-            .filter(not_(DM.is_paused))
-            .filter(TI.state == TaskInstanceState.SCHEDULED)
-            .options(selectinload('dag_model'))
-            .order_by(-TI.priority_weight, DR.execution_date)
-        )
-        starved_pools = [pool_name for pool_name, stats in pools.items() if stats['open'] <= 0]
-        if starved_pools:
-            query = query.filter(not_(TI.pool.in_(starved_pools)))
-
-        query = query.limit(max_tis)
-
-        task_instances_to_examine: List[TI] = with_row_locks(
-            query,
-            of=TI,
-            session=session,
-            **skip_locked(session=session),
-        ).all()
-        # TODO[HA]: This was wrong before anyway, as it only looked at a sub-set of dags, not everything.
-        # Stats.gauge('scheduler.tasks.pending', len(task_instances_to_examine))
-
-        if len(task_instances_to_examine) == 0:
-            self.log.debug("No tasks to consider for execution.")
-            return executable_tis
-
-        # Put one task instance on each line
-        task_instance_str = "\n\t".join(repr(x) for x in task_instances_to_examine)
-        self.log.info("%s tasks up for execution:\n\t%s", len(task_instances_to_examine), task_instance_str)
-
-        pool_to_task_instances: DefaultDict[str, List[models.Pool]] = defaultdict(list)
-        for task_instance in task_instances_to_examine:
-            pool_to_task_instances[task_instance.pool].append(task_instance)
+        starved_pools = {pool_name for pool_name, stats in pools.items() if stats['open'] <= 0}
         # dag_id to # of running tasks and (dag_id, task_id) to # of running tasks.
-        dag_max_active_tasks_map: DefaultDict[str, int]
+        dag_active_tasks_map: DefaultDict[str, int]
         task_concurrency_map: DefaultDict[Tuple[str, str], int]
-        dag_max_active_tasks_map, task_concurrency_map = self.__get_concurrency_maps(
+        dag_active_tasks_map, task_concurrency_map = self.__get_concurrency_maps(
             states=list(EXECUTION_STATES), session=session
         )
         num_tasks_in_executor = 0
         # Number of tasks that cannot be scheduled because of no open slot in pool
         num_starving_tasks_total = 0
-        # Go through each pool, and queue up a task for execution if there are
-        # any open slots in the pool.
+        # dag and task ids that can't be queued because of concurrency limits
+        starved_dags: Set[str] = set()
+        starved_tasks: Set[Tuple[str, str]] = set()
-        for pool, task_instances in pool_to_task_instances.items():
-            pool_name = pool
-            if pool not in pools:
-                self.log.warning("Tasks using non-existent pool '%s' will not be scheduled", pool)
-                continue
+        pool_num_starving_tasks = defaultdict(int)
+
+        for loop_count in itertools.count(start=1):
+
+            num_starved_pools = len(starved_pools)
+            num_starved_dags = len(starved_dags)
+            num_starved_tasks = len(starved_tasks)
+
+            # Get task instances associated with scheduled
+            # DagRuns which are not backfilled, in the given states,
+            # and the dag is not paused
+            query = (
+                session.query(TI)
+                .join(TI.dag_run)
+                .filter(DR.run_type != DagRunType.BACKFILL_JOB, DR.state == DagRunState.RUNNING)
+                .join(TI.dag_model)
+                .filter(not_(DM.is_paused))
+                .filter(TI.state == TaskInstanceState.SCHEDULED)
+                .options(selectinload('dag_model'))
+                .order_by(-TI.priority_weight, DR.execution_date)
+            )
+
+            if starved_pools:
+                query = query.filter(not_(TI.pool.in_(starved_pools)))
+
+            if starved_dags:
+                query = query.filter(not_(TI.dag_id.in_(starved_dags)))
-            pool_total = pools[pool]["total"]
-            for task_instance in task_instances:
-                if task_instance.pool_slots > pool_total:
-                    self.log.warning(
-                        "Not executing %s. Requested pool slots (%s) are greater than "
-                        "total pool slots: '%s' for pool: %s.",
-                        task_instance,
-                        task_instance.pool_slots,
-                        pool_total,
-                        pool,
+            if starved_tasks:
+                if settings.Session.bind.dialect.name == 'mssql':
+                    task_filter = or_(
+                        and_(
+                            TaskInstance.dag_id == dag_id,
+                            TaskInstance.task_id == task_id,
+                        )
+                        for (dag_id, task_id) in starved_tasks
                     )
-                    task_instances.remove(task_instance)
+                else:
+                    task_filter = tuple_(TaskInstance.dag_id, TaskInstance.task_id).in_(starved_tasks)
+
+                query = query.filter(not_(task_filter))
+
+            query = query.limit(max_tis)
-            open_slots = pools[pool]["open"]
+            task_instances_to_examine: List[TI] = with_row_locks(
+                query,
+                of=TI,
+                session=session,
+                **skip_locked(session=session),
+            ).all()
+            # TODO[HA]: This was wrong before anyway, as it only looked at a sub-set of dags, not everything.
+            # Stats.gauge('scheduler.tasks.pending', len(task_instances_to_examine))
-            num_ready = len(task_instances)
+            if len(task_instances_to_examine) == 0:
+                self.log.debug("No tasks to consider for execution.")
+                break
+
+            # Put one task instance on each line
+            task_instance_str = "\n\t".join(repr(x) for x in task_instances_to_examine)
             self.log.info(
-                "Figuring out tasks to run in Pool(name=%s) with %s open slots "
-                "and %s task instances ready to be queued",
-                pool,
-                open_slots,
-                num_ready,
+                "%s tasks up for execution:\n\t%s", len(task_instances_to_examine), task_instance_str
             )
-            priority_sorted_task_instances = sorted(
-                task_instances, key=lambda ti: (-ti.priority_weight, ti.execution_date)
-            )
+            pool_to_task_instances: DefaultDict[str, List[TI]] = defaultdict(list)
+            for task_instance in task_instances_to_examine:
+                pool_to_task_instances[task_instance.pool].append(task_instance)
+
+            # Go through each pool, and queue up a task for execution if there are
+            # any open slots in the pool.
-            num_starving_tasks = 0
-            for current_index, task_instance in enumerate(priority_sorted_task_instances):
-                if open_slots <= 0:
-                    self.log.info("Not scheduling since there are %s open slots in pool %s", open_slots, pool)
-                    # Can't schedule any more since there are no more open slots.
-                    num_unhandled = len(priority_sorted_task_instances) - current_index
-                    num_starving_tasks += num_unhandled
-                    num_starving_tasks_total += num_unhandled
-                    break
-
-                # Check to make sure that the task max_active_tasks of the DAG hasn't been
-                # reached.
-                dag_id = task_instance.dag_id
-
-                current_max_active_tasks_per_dag = dag_max_active_tasks_map[dag_id]
-                max_active_tasks_per_dag_limit = task_instance.dag_model.max_active_tasks
+            for pool, task_instances in pool_to_task_instances.items():
+                pool_name = pool
+                if pool not in pools:
+                    self.log.warning("Tasks using non-existent pool '%s' will not be scheduled", pool)
+                    starved_pools.add(pool_name)
+                    continue
+
+                pool_total = pools[pool]["total"]
+                open_slots = pools[pool]["open"]
+
+                num_ready = len(task_instances)
                 self.log.info(
-                    "DAG %s has %s/%s running and queued tasks",
-                    dag_id,
-                    current_max_active_tasks_per_dag,
-                    max_active_tasks_per_dag_limit,
+                    "Figuring out tasks to run in Pool(name=%s) with %s open slots "
+                    "and %s task instances ready to be queued",
+                    pool,
+                    open_slots,
+                    num_ready,
+                )
+
+                priority_sorted_task_instances = sorted(
+                    task_instances, key=lambda ti: (-ti.priority_weight, ti.execution_date)
                 )
-                if current_max_active_tasks_per_dag >= max_active_tasks_per_dag_limit:
+
+                for current_index, task_instance in enumerate(priority_sorted_task_instances):
+                    if open_slots <= 0:
+                        self.log.info(
+                            "Not scheduling since there are %s open slots in pool %s", open_slots, pool
+                        )
+                        # Can't schedule any more since there are no more open slots.
+                        num_unhandled = len(priority_sorted_task_instances) - current_index
+                        pool_num_starving_tasks[pool_name] += num_unhandled
+                        num_starving_tasks_total += num_unhandled
+                        starved_pools.add(pool_name)
+                        break
+
+                    if task_instance.pool_slots > pool_total:
+                        self.log.warning(
+                            "Not executing %s. Requested pool slots (%s) are greater than "
+                            "total pool slots: '%s' for pool: %s.",
+                            task_instance,
+                            task_instance.pool_slots,
+                            pool_total,
+                            pool,
+                        )
+
+                        starved_tasks.add((task_instance.dag_id, task_instance.task_id))
+                        continue
+
+                    # Check to make sure that the task max_active_tasks of the DAG hasn't been
+                    # reached.
+                    dag_id = task_instance.dag_id
+
+                    current_active_tasks_per_dag = dag_active_tasks_map[dag_id]
+                    max_active_tasks_per_dag_limit = task_instance.dag_model.max_active_tasks
                     self.log.info(
-                        "Not executing %s since the number of tasks running or queued "
-                        "from DAG %s is >= to the DAG's max_active_tasks limit of %s",
-                        task_instance,
+                        "DAG %s has %s/%s running and queued tasks",
                         dag_id,
+                        current_active_tasks_per_dag,
                         max_active_tasks_per_dag_limit,
                     )
-                    continue
-
-                task_concurrency_limit: Optional[int] = None
-                if task_instance.dag_model.has_task_concurrency_limits:
-                    # Many dags don't have a task_concurrency, so where we can avoid loading the full
-                    # serialized DAG the better.
-                    serialized_dag = self.dagbag.get_dag(dag_id, session=session)
-                    # If the dag is missing, fail the task and continue to the next task.
-                    if not serialized_dag:
-                        self.log.error(
-                            "DAG '%s' for task instance %s not found in serialized_dag table",
-                            dag_id,
+                    if current_active_tasks_per_dag >= max_active_tasks_per_dag_limit:
+                        self.log.info(
+                            "Not executing %s since the number of tasks running or queued "
+                            "from DAG %s is >= to the DAG's max_active_tasks limit of %s",
                             task_instance,
+                            dag_id,
+                            max_active_tasks_per_dag_limit,
                         )
-                        session.query(TI).filter(TI.dag_id == dag_id, TI.state == State.SCHEDULED).update(
-                            {TI.state: State.FAILED}, synchronize_session='fetch'
-                        )
+                        starved_dags.add(dag_id)
                         continue
-                    if serialized_dag.has_task(task_instance.task_id):
-                        task_concurrency_limit = serialized_dag.get_task(
-                            task_instance.task_id
-                        ).max_active_tis_per_dag
-
-                if task_concurrency_limit is not None:
-                    current_task_concurrency = task_concurrency_map[
-                        (task_instance.dag_id, task_instance.task_id)
-                    ]
-
-                    if current_task_concurrency >= task_concurrency_limit:
-                        self.log.info(
-                            "Not executing %s since the task concurrency for"
-                            " this task has been reached.",
+
+                    task_concurrency_limit: Optional[int] = None
+                    if task_instance.dag_model.has_task_concurrency_limits:
+                        # Many dags don't have a task_concurrency, so where we can avoid loading the full
+                        # serialized DAG the better.
+                        serialized_dag = self.dagbag.get_dag(dag_id, session=session)
+                        # If the dag is missing, fail the task and continue to the next task.
+                        if not serialized_dag:
+                            self.log.error(
+                                "DAG '%s' for task instance %s not found in serialized_dag table",
+                                dag_id,
                                 task_instance,
                             )
+                            session.query(TI).filter(TI.dag_id == dag_id, TI.state == State.SCHEDULED).update(
+                                {TI.state: State.FAILED}, synchronize_session='fetch'
+                            )
                             continue
+                        if serialized_dag.has_task(task_instance.task_id):
+                            task_concurrency_limit = serialized_dag.get_task(
+                                task_instance.task_id
+                            ).max_active_tis_per_dag
+
+                    if task_concurrency_limit is not None:
+                        current_task_concurrency = task_concurrency_map[
+                            (task_instance.dag_id, task_instance.task_id)
+                        ]
+
+                        if current_task_concurrency >= task_concurrency_limit:
+                            self.log.info(
+                                "Not executing %s since the task concurrency for"
+                                " this task has been reached.",
+                                task_instance,
+                            )
+                            starved_tasks.add((task_instance.dag_id, task_instance.task_id))
+                            continue
+
+                    if task_instance.pool_slots > open_slots:
+                        self.log.info(
+                            "Not executing %s since it requires %s slots "
+                            "but there are %s open slots in the pool %s.",
+                            task_instance,
+                            task_instance.pool_slots,
+                            open_slots,
+                            pool,
+                        )
+                        pool_num_starving_tasks[pool_name] += 1
+                        num_starving_tasks_total += 1
+                        starved_tasks.add((task_instance.dag_id, task_instance.task_id))
+                        # Though we can execute tasks with lower priority if there's enough room
+                        continue
-                if task_instance.pool_slots > open_slots:
-                    self.log.info(
-                        "Not executing %s since it requires %s slots "
-                        "but there are %s open slots in the pool %s.",
-                        task_instance,
-                        task_instance.pool_slots,
-                        open_slots,
-                        pool,
-                    )
-                    num_starving_tasks += 1
-                    num_starving_tasks_total += 1
-                    # Though we can execute tasks with lower priority if there's enough room
-                    continue
+                    executable_tis.append(task_instance)
+                    open_slots -= task_instance.pool_slots
+                    dag_active_tasks_map[dag_id] += 1
+                    task_concurrency_map[(task_instance.dag_id, task_instance.task_id)] += 1
+
+                pools[pool]["open"] = open_slots
+
+            is_done = executable_tis or len(task_instances_to_examine) < max_tis
+            # Check this to avoid accidental infinite loops
+            found_new_filters = (
+                len(starved_pools) > num_starved_pools
+                or len(starved_dags) > num_starved_dags
+                or len(starved_tasks) > num_starved_tasks
+            )
-            executable_tis.append(task_instance)
-            open_slots -= task_instance.pool_slots
-            dag_max_active_tasks_map[dag_id] += 1
-            task_concurrency_map[(task_instance.dag_id, task_instance.task_id)] += 1
+            if is_done or not found_new_filters:
+                break
+
+            self.log.debug(
+                "Found no task instances to queue on the %s. iteration "
+                "but there could be more candidate task instances to check.",
+                loop_count,
+            )
+        for pool_name, num_starving_tasks in pool_num_starving_tasks.items():
             Stats.gauge(f'pool.starving_tasks.{pool_name}', num_starving_tasks)
         Stats.gauge('scheduler.tasks.starving', num_starving_tasks_total)
         Stats.gauge('scheduler.tasks.running', num_tasks_in_executor)
         Stats.gauge('scheduler.tasks.executable', len(executable_tis))
-        task_instance_str = "\n\t".join(repr(x) for x in executable_tis)
-        self.log.info("Setting the following tasks to queued state:\n\t%s", task_instance_str)
         if len(executable_tis) > 0:
+            task_instance_str = "\n\t".join(repr(x) for x in executable_tis)
+            self.log.info("Setting the following tasks to queued state:\n\t%s", task_instance_str)
+
Review comment:
Moved the noisy logging into the `if` block, so the joined per-task string is only built when there is something to queue.
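
Roughly, the guarded pattern looks like this (a minimal sketch only; `log_queued` is a hypothetical helper, not code from this PR). The `"\n\t".join(...)` over `repr()` of every task instance is linear work, so it is worth skipping entirely when the list is empty:

```python
import logging
from typing import List

log = logging.getLogger(__name__)


def log_queued(executable_tis: List) -> None:
    # Hypothetical helper sketching the change: build the (potentially
    # large) joined repr string only when there is something to log.
    if len(executable_tis) > 0:
        task_instance_str = "\n\t".join(repr(x) for x in executable_tis)
        log.info("Setting the following tasks to queued state:\n\t%s", task_instance_str)
```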