This is an automated email from the ASF dual-hosted git repository.
dstandish pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new ae96933761 Max active tasks to be evaluated per dag run (#42953)
ae96933761 is described below
commit ae96933761b545bf4e64abe20dd30a16f5742018
Author: Daniel Standish <[email protected]>
AuthorDate: Thu Oct 17 14:07:53 2024 -0700
Max active tasks to be evaluated per dag run (#42953)
This behavior change was accepted by lazy consensus here:
https://lists.apache.org/thread/9o84d3yn934m32gtlpokpwtbbmtxj47l.
Previously max_active_tasks was evaluated across all runs of a dag.
Co-authored-by: Wei Lee <[email protected]>
---
airflow/jobs/scheduler_job_runner.py | 67 +++++------
newsfragments/42953.significant | 3 +
tests/jobs/test_scheduler_job.py | 211 +++++++++++++++++++----------------
3 files changed, 144 insertions(+), 137 deletions(-)
diff --git a/airflow/jobs/scheduler_job_runner.py
b/airflow/jobs/scheduler_job_runner.py
index ed67e7a4ac..e085efc314 100644
--- a/airflow/jobs/scheduler_job_runner.py
+++ b/airflow/jobs/scheduler_job_runner.py
@@ -24,7 +24,6 @@ import signal
import sys
import time
from collections import Counter, defaultdict, deque
-from dataclasses import dataclass
from datetime import timedelta
from functools import lru_cache, partial
from pathlib import Path
@@ -83,7 +82,6 @@ if TYPE_CHECKING:
from datetime import datetime
from types import FrameType
- from sqlalchemy.engine import Result
from sqlalchemy.orm import Query, Session
from airflow.dag_processing.manager import DagFileProcessorAgent
@@ -99,7 +97,6 @@ DR = DagRun
DM = DagModel
-@dataclass
class ConcurrencyMap:
"""
Dataclass to represent concurrency maps.
@@ -109,17 +106,24 @@ class ConcurrencyMap:
to # of task instances in the given state list in each DAG run.
"""
- dag_active_tasks_map: dict[str, int]
- task_concurrency_map: dict[tuple[str, str], int]
- task_dagrun_concurrency_map: dict[tuple[str, str, str], int]
-
- @classmethod
- def from_concurrency_map(cls, mapping: dict[tuple[str, str, str], int]) ->
ConcurrencyMap:
- instance = cls(Counter(), Counter(), Counter(mapping))
- for (d, _, t), c in mapping.items():
- instance.dag_active_tasks_map[d] += c
- instance.task_concurrency_map[(d, t)] += c
- return instance
+ def __init__(self):
+ self.dag_run_active_tasks_map: Counter[tuple[str, str]] = Counter()
+ self.task_concurrency_map: Counter[tuple[str, str]] = Counter()
+ self.task_dagrun_concurrency_map: Counter[tuple[str, str, str]] =
Counter()
+
+ def load(self, session: Session) -> None:
+ self.dag_run_active_tasks_map.clear()
+ self.task_concurrency_map.clear()
+ self.task_dagrun_concurrency_map.clear()
+ query = session.execute(
+ select(TI.dag_id, TI.task_id, TI.run_id, func.count("*"))
+ .where(TI.state.in_(EXECUTION_STATES))
+ .group_by(TI.task_id, TI.run_id, TI.dag_id)
+ )
+ for dag_id, task_id, run_id, c in query:
+ self.dag_run_active_tasks_map[dag_id, run_id] += c
+ self.task_concurrency_map[(dag_id, task_id)] += c
+ self.task_dagrun_concurrency_map[(dag_id, run_id, task_id)] += c
def _is_parent_process() -> bool:
@@ -258,22 +262,6 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
executor.debug_dump()
self.log.info("-" * 80)
- def __get_concurrency_maps(self, states: Iterable[TaskInstanceState],
session: Session) -> ConcurrencyMap:
- """
- Get the concurrency maps.
-
- :param states: List of states to query for
- :return: Concurrency map
- """
- ti_concurrency_query: Result = session.execute(
- select(TI.task_id, TI.run_id, TI.dag_id, func.count("*"))
- .where(TI.state.in_(states))
- .group_by(TI.task_id, TI.run_id, TI.dag_id)
- )
- return ConcurrencyMap.from_concurrency_map(
- {(dag_id, run_id, task_id): count for task_id, run_id, dag_id,
count in ti_concurrency_query}
- )
-
def _executable_task_instances_to_queued(self, max_tis: int, session:
Session) -> list[TI]:
"""
Find TIs that are ready for execution based on conditions.
@@ -326,7 +314,8 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
starved_pools = {pool_name for pool_name, stats in pools.items() if
stats["open"] <= 0}
# dag_id to # of running tasks and (dag_id, task_id) to # of running
tasks.
- concurrency_map = self.__get_concurrency_maps(states=EXECUTION_STATES,
session=session)
+ concurrency_map = ConcurrencyMap()
+ concurrency_map.load(session=session)
# Number of tasks that cannot be scheduled because of no open slot in
pool
num_starving_tasks_total = 0
@@ -465,22 +454,22 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
# Check to make sure that the task max_active_tasks of the DAG
hasn't been
# reached.
dag_id = task_instance.dag_id
-
- current_active_tasks_per_dag =
concurrency_map.dag_active_tasks_map[dag_id]
- max_active_tasks_per_dag_limit =
task_instance.dag_model.max_active_tasks
+ dag_run_key = (dag_id, task_instance.run_id)
+ current_active_tasks_per_dag_run =
concurrency_map.dag_run_active_tasks_map[dag_run_key]
+ dag_max_active_tasks = task_instance.dag_model.max_active_tasks
self.log.info(
"DAG %s has %s/%s running and queued tasks",
dag_id,
- current_active_tasks_per_dag,
- max_active_tasks_per_dag_limit,
+ current_active_tasks_per_dag_run,
+ dag_max_active_tasks,
)
- if current_active_tasks_per_dag >=
max_active_tasks_per_dag_limit:
+ if current_active_tasks_per_dag_run >= dag_max_active_tasks:
self.log.info(
"Not executing %s since the number of tasks running or
queued "
"from DAG %s is >= to the DAG's max_active_tasks limit
of %s",
task_instance,
dag_id,
- max_active_tasks_per_dag_limit,
+ dag_max_active_tasks,
)
starved_dags.add(dag_id)
continue
@@ -571,7 +560,7 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
executable_tis.append(task_instance)
open_slots -= task_instance.pool_slots
- concurrency_map.dag_active_tasks_map[dag_id] += 1
+ concurrency_map.dag_run_active_tasks_map[dag_run_key] += 1
concurrency_map.task_concurrency_map[(task_instance.dag_id,
task_instance.task_id)] += 1
concurrency_map.task_dagrun_concurrency_map[
(task_instance.dag_id, task_instance.run_id,
task_instance.task_id)
diff --git a/newsfragments/42953.significant b/newsfragments/42953.significant
new file mode 100644
index 0000000000..20f25b4345
--- /dev/null
+++ b/newsfragments/42953.significant
@@ -0,0 +1,3 @@
+DAG.max_active_tasks now evaluated per-run
+
+Previously, this was evaluated across all runs of the dag. This behavior
change was passed by lazy consensus. Vote thread:
https://lists.apache.org/thread/9o84d3yn934m32gtlpokpwtbbmtxj47l.
diff --git a/tests/jobs/test_scheduler_job.py b/tests/jobs/test_scheduler_job.py
index 52ad2f6a8c..17fd851714 100644
--- a/tests/jobs/test_scheduler_job.py
+++ b/tests/jobs/test_scheduler_job.py
@@ -22,7 +22,7 @@ import datetime
import logging
import os
import sys
-from collections import deque
+from collections import Counter, deque
from datetime import timedelta
from importlib import reload
from typing import Generator
@@ -1169,86 +1169,78 @@ class TestSchedulerJob:
assert ti1.state == State.SCHEDULED
assert ti2.state == State.QUEUED
- def test_find_executable_task_instances_concurrency(self, dag_maker):
- dag_id =
"SchedulerJobTest.test_find_executable_task_instances_concurrency"
- session = settings.Session()
+ @pytest.mark.parametrize("active_state", [TaskInstanceState.RUNNING,
TaskInstanceState.QUEUED])
+ def test_find_executable_task_instances_concurrency(self, dag_maker,
active_state, session):
+ """We verify here that, with varying amounts of queued / running /
scheduled tasks,
+ the correct number of TIs are queued"""
+ dag_id = "check_MAT_dag"
with dag_maker(dag_id=dag_id, max_active_tasks=2, session=session):
- EmptyOperator(task_id="dummy")
+ EmptyOperator(task_id="task_1")
+ EmptyOperator(task_id="task_2")
+ EmptyOperator(task_id="task_3")
scheduler_job = Job()
self.job_runner = SchedulerJobRunner(job=scheduler_job,
subdir=os.devnull)
- dr1 = dag_maker.create_dagrun(run_type=DagRunType.SCHEDULED)
- dr2 = dag_maker.create_dagrun_after(dr1, run_type=DagRunType.SCHEDULED)
- dr3 = dag_maker.create_dagrun_after(dr2, run_type=DagRunType.SCHEDULED)
+ dr1 = dag_maker.create_dagrun(run_type=DagRunType.SCHEDULED,
run_id="run_1", session=session)
+ dr2 = dag_maker.create_dagrun_after(
+ dr1, run_type=DagRunType.SCHEDULED, run_id="run_2", session=session
+ )
+ dr3 = dag_maker.create_dagrun_after(
+ dr2, run_type=DagRunType.SCHEDULED, run_id="run_3", session=session
+ )
- ti1 = dr1.task_instances[0]
- ti2 = dr2.task_instances[0]
- ti3 = dr3.task_instances[0]
- ti1.state = State.RUNNING
- ti2.state = State.SCHEDULED
- ti3.state = State.SCHEDULED
- session.merge(ti1)
- session.merge(ti2)
- session.merge(ti3)
+ # set 2 tis in dr1 to running
+ # no more can be queued
+ t1, t2, t3 = dr1.get_task_instances(session=session)
+ t1.state = active_state
+ t2.state = active_state
+ t3.state = State.SCHEDULED
+ session.merge(t1)
+ session.merge(t2)
+ session.merge(t3)
+ # set 1 ti from dr2 to running
+ # one can be queued
+ t1, t2, t3 = dr2.get_task_instances(session=session)
+ t1.state = active_state
+ t2.state = State.SCHEDULED
+ t3.state = State.SCHEDULED
+ session.merge(t1)
+ session.merge(t2)
+ session.merge(t3)
+ # set 0 tis from dr3 to running
+ # two can be queued
+ t1, t2, t3 = dr3.get_task_instances(session=session)
+ t1.state = State.SCHEDULED
+ t2.state = State.SCHEDULED
+ t3.state = State.SCHEDULED
+ session.merge(t1)
+ session.merge(t2)
+ session.merge(t3)
session.flush()
- res = self.job_runner._executable_task_instances_to_queued(max_tis=32,
session=session)
-
- assert 1 == len(res)
- res_keys = (x.key for x in res)
- assert ti2.key in res_keys
-
- ti2.state = State.RUNNING
- session.merge(ti2)
- session.flush()
-
- res = self.job_runner._executable_task_instances_to_queued(max_tis=32,
session=session)
-
- assert 0 == len(res)
- session.rollback()
-
- def test_find_executable_task_instances_concurrency_queued(self,
dag_maker):
- dag_id =
"SchedulerJobTest.test_find_executable_task_instances_concurrency_queued"
- with dag_maker(dag_id=dag_id, max_active_tasks=3):
- task1 = EmptyOperator(task_id="dummy1")
- task2 = EmptyOperator(task_id="dummy2")
- task3 = EmptyOperator(task_id="dummy3")
-
- scheduler_job = Job()
- self.job_runner = SchedulerJobRunner(job=scheduler_job,
subdir=os.devnull)
- session = settings.Session()
-
- dag_run = dag_maker.create_dagrun()
-
- ti1 = dag_run.get_task_instance(task1.task_id)
- ti2 = dag_run.get_task_instance(task2.task_id)
- ti3 = dag_run.get_task_instance(task3.task_id)
- ti1.state = State.RUNNING
- ti2.state = State.QUEUED
- ti3.state = State.SCHEDULED
-
- session.merge(ti1)
- session.merge(ti2)
- session.merge(ti3)
+ queued_tis =
self.job_runner._executable_task_instances_to_queued(max_tis=32,
session=session)
+ queued_runs = Counter([x.run_id for x in queued_tis])
+ assert queued_runs["run_1"] == 0
+ assert queued_runs["run_2"] == 1
+ assert queued_runs["run_3"] == 2
- session.flush()
+ session.commit()
+ session.query(TaskInstance).all()
- res = self.job_runner._executable_task_instances_to_queued(max_tis=32,
session=session)
+ # now we still have max tis running so no more will be queued
+ queued_tis =
self.job_runner._executable_task_instances_to_queued(max_tis=32,
session=session)
+ assert queued_tis == []
- assert 1 == len(res)
- assert res[0].key == ti3.key
session.rollback()
# TODO: This is a hack, I think I need to just remove the setting and have
it on always
def test_find_executable_task_instances_max_active_tis_per_dag(self,
dag_maker):
dag_id =
"SchedulerJobTest.test_find_executable_task_instances_max_active_tis_per_dag"
- task_id_1 = "dummy"
- task_id_2 = "dummy2"
with dag_maker(dag_id=dag_id, max_active_tasks=16):
- task1 = EmptyOperator(task_id=task_id_1, max_active_tis_per_dag=2)
- task2 = EmptyOperator(task_id=task_id_2)
+ task1 = EmptyOperator(task_id="dummy", max_active_tis_per_dag=2)
+ task2 = EmptyOperator(task_id="dummy2")
executor = MockExecutor(do_update=True)
@@ -1653,65 +1645,88 @@ class TestSchedulerJob:
("secondary_exec", "secondary_exec"),
],
)
- def test_critical_section_enqueue_task_instances(self, task1_exec,
task2_exec, dag_maker, mock_executors):
+ def test_critical_section_enqueue_task_instances(
+ self, task1_exec, task2_exec, dag_maker, mock_executors, session
+ ):
dag_id = "SchedulerJobTest.test_execute_task_instances"
- task_id_1 = "dummy_task"
- task_id_2 = "dummy_task_nonexistent_queue"
- session = settings.Session()
# important that len(tasks) is less than max_active_tasks
# because before scheduler._execute_task_instances would only
# check the num tasks once so if max_active_tasks was 3,
# we could execute arbitrarily many tasks in the second run
with dag_maker(dag_id=dag_id, max_active_tasks=3, session=session) as
dag:
- task1 = EmptyOperator(task_id=task_id_1, executor=task1_exec)
- task2 = EmptyOperator(task_id=task_id_2, executor=task2_exec)
+ task1 = EmptyOperator(task_id="t1", executor=task1_exec)
+ task2 = EmptyOperator(task_id="t2", executor=task2_exec)
+ task3 = EmptyOperator(task_id="t3", executor=task2_exec)
+ task4 = EmptyOperator(task_id="t4", executor=task2_exec)
scheduler_job = Job()
self.job_runner = SchedulerJobRunner(job=scheduler_job,
subdir=os.devnull)
- # create first dag run with 2 running tasks
+ # create first dag run with 3 running tasks
- dr1 = dag_maker.create_dagrun(run_type=DagRunType.SCHEDULED)
+ dr1 = dag_maker.create_dagrun(run_type=DagRunType.SCHEDULED,
session=session)
- ti1 = dr1.get_task_instance(task1.task_id, session)
- ti2 = dr1.get_task_instance(task2.task_id, session)
- ti1.state = State.RUNNING
- ti2.state = State.RUNNING
+ dr1_ti1 = dr1.get_task_instance(task1.task_id, session)
+ dr1_ti2 = dr1.get_task_instance(task2.task_id, session)
+ dr1_ti3 = dr1.get_task_instance(task3.task_id, session)
+ dr1_ti4 = dr1.get_task_instance(task4.task_id, session)
+ dr1_ti1.state = State.RUNNING
+ dr1_ti2.state = State.RUNNING
+ dr1_ti3.state = State.RUNNING
+ dr1_ti4.state = State.SCHEDULED
session.flush()
- assert State.RUNNING == dr1.state
- assert 2 == DAG.get_num_task_instances(
- dag_id, task_ids=dag.task_ids, states=[State.RUNNING],
session=session
+ assert dr1.state == State.RUNNING
+ num_tis = DAG.get_num_task_instances(
+ dag_id=dag_id,
+ task_ids=dag.task_ids,
+ states=[State.RUNNING],
+ session=session,
)
+ assert num_tis == 3
# create second dag run
- dr2 = dag_maker.create_dagrun_after(dr1, run_type=DagRunType.SCHEDULED)
- ti3 = dr2.get_task_instance(task1.task_id, session)
- ti4 = dr2.get_task_instance(task2.task_id, session)
+ dr2 = dag_maker.create_dagrun_after(dr1,
run_type=DagRunType.SCHEDULED, session=session)
+ dr2_ti1 = dr2.get_task_instance(task1.task_id, session)
+ dr2_ti2 = dr2.get_task_instance(task2.task_id, session)
+ dr2_ti3 = dr2.get_task_instance(task3.task_id, session)
+ dr2_ti4 = dr2.get_task_instance(task4.task_id, session)
# manually set to scheduled so we can pick them up
- ti3.state = State.SCHEDULED
- ti4.state = State.SCHEDULED
+ dr2_ti1.state = State.SCHEDULED
+ dr2_ti2.state = State.SCHEDULED
+ dr2_ti3.state = State.SCHEDULED
+ dr2_ti4.state = State.SCHEDULED
session.flush()
- assert State.RUNNING == dr2.state
+ assert dr2.state == State.RUNNING
- res = self.job_runner._critical_section_enqueue_task_instances(session)
+ num_queued =
self.job_runner._critical_section_enqueue_task_instances(session=session)
+ assert num_queued == 3
# check that max_active_tasks is respected
- ti1.refresh_from_db()
- ti2.refresh_from_db()
- ti3.refresh_from_db()
- ti4.refresh_from_db()
- assert 3 == DAG.get_num_task_instances(
- dag_id, task_ids=dag.task_ids, states=[State.RUNNING,
State.QUEUED], session=session
- )
- assert State.RUNNING == ti1.state
- assert State.RUNNING == ti2.state
- assert {State.QUEUED, State.SCHEDULED} == {ti3.state, ti4.state}
- assert 1 == res
- res = self.job_runner._critical_section_enqueue_task_instances(session)
- assert 0 == res
+ num_tis = DAG.get_num_task_instances(
+ dag_id=dag_id,
+ task_ids=dag.task_ids,
+ states=[State.RUNNING, State.QUEUED],
+ session=session,
+ )
+ assert num_tis == 6
+
+ # this doesn't really tell us anything since we set these values
manually, but hey
+ dr1_counter = Counter(x.state for x in
dr1.get_task_instances(session=session))
+ assert dr1_counter[State.RUNNING] == 3
+ assert dr1_counter[State.SCHEDULED] == 1
+
+ # this is the more meaningful bit
+ # three of dr2's tasks should be queued since that's max active tasks
+ # and max active tasks is evaluated per-dag-run
+ dr2_counter = Counter(x.state for x in
dr2.get_task_instances(session=session))
+ assert dr2_counter[State.QUEUED] == 3
+ assert dr2_counter[State.SCHEDULED] == 1
+
+ num_queued =
self.job_runner._critical_section_enqueue_task_instances(session=session)
+ assert num_queued == 0
def test_execute_task_instances_limit_second_executor(self, dag_maker,
mock_executors):
dag_id = "SchedulerJobTest.test_execute_task_instances_limit"