This is an automated email from the ASF dual-hosted git repository.

dstandish pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new ae96933761 Max active tasks to be evaluated per dag run (#42953)
ae96933761 is described below

commit ae96933761b545bf4e64abe20dd30a16f5742018
Author: Daniel Standish <[email protected]>
AuthorDate: Thu Oct 17 14:07:53 2024 -0700

    Max active tasks to be evaluated per dag run (#42953)
    
    This behavior change was accepted by lazy consensus here:
    https://lists.apache.org/thread/9o84d3yn934m32gtlpokpwtbbmtxj47l.
    
    Previously, max_active_tasks was evaluated across all runs of a DAG; it is now evaluated per DAG run.
    
    Co-authored-by: Wei Lee <[email protected]>
---
 airflow/jobs/scheduler_job_runner.py |  67 +++++------
 newsfragments/42953.significant      |   3 +
 tests/jobs/test_scheduler_job.py     | 211 +++++++++++++++++++----------------
 3 files changed, 144 insertions(+), 137 deletions(-)

diff --git a/airflow/jobs/scheduler_job_runner.py 
b/airflow/jobs/scheduler_job_runner.py
index ed67e7a4ac..e085efc314 100644
--- a/airflow/jobs/scheduler_job_runner.py
+++ b/airflow/jobs/scheduler_job_runner.py
@@ -24,7 +24,6 @@ import signal
 import sys
 import time
 from collections import Counter, defaultdict, deque
-from dataclasses import dataclass
 from datetime import timedelta
 from functools import lru_cache, partial
 from pathlib import Path
@@ -83,7 +82,6 @@ if TYPE_CHECKING:
     from datetime import datetime
     from types import FrameType
 
-    from sqlalchemy.engine import Result
     from sqlalchemy.orm import Query, Session
 
     from airflow.dag_processing.manager import DagFileProcessorAgent
@@ -99,7 +97,6 @@ DR = DagRun
 DM = DagModel
 
 
-@dataclass
 class ConcurrencyMap:
     """
     Dataclass to represent concurrency maps.
@@ -109,17 +106,24 @@ class ConcurrencyMap:
     to # of task instances in the given state list in each DAG run.
     """
 
-    dag_active_tasks_map: dict[str, int]
-    task_concurrency_map: dict[tuple[str, str], int]
-    task_dagrun_concurrency_map: dict[tuple[str, str, str], int]
-
-    @classmethod
-    def from_concurrency_map(cls, mapping: dict[tuple[str, str, str], int]) -> 
ConcurrencyMap:
-        instance = cls(Counter(), Counter(), Counter(mapping))
-        for (d, _, t), c in mapping.items():
-            instance.dag_active_tasks_map[d] += c
-            instance.task_concurrency_map[(d, t)] += c
-        return instance
+    def __init__(self):
+        self.dag_run_active_tasks_map: Counter[tuple[str, str]] = Counter()
+        self.task_concurrency_map: Counter[tuple[str, str]] = Counter()
+        self.task_dagrun_concurrency_map: Counter[tuple[str, str, str]] = 
Counter()
+
+    def load(self, session: Session) -> None:
+        self.dag_run_active_tasks_map.clear()
+        self.task_concurrency_map.clear()
+        self.task_dagrun_concurrency_map.clear()
+        query = session.execute(
+            select(TI.dag_id, TI.task_id, TI.run_id, func.count("*"))
+            .where(TI.state.in_(EXECUTION_STATES))
+            .group_by(TI.task_id, TI.run_id, TI.dag_id)
+        )
+        for dag_id, task_id, run_id, c in query:
+            self.dag_run_active_tasks_map[dag_id, run_id] += c
+            self.task_concurrency_map[(dag_id, task_id)] += c
+            self.task_dagrun_concurrency_map[(dag_id, run_id, task_id)] += c
 
 
 def _is_parent_process() -> bool:
@@ -258,22 +262,6 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
             executor.debug_dump()
             self.log.info("-" * 80)
 
-    def __get_concurrency_maps(self, states: Iterable[TaskInstanceState], 
session: Session) -> ConcurrencyMap:
-        """
-        Get the concurrency maps.
-
-        :param states: List of states to query for
-        :return: Concurrency map
-        """
-        ti_concurrency_query: Result = session.execute(
-            select(TI.task_id, TI.run_id, TI.dag_id, func.count("*"))
-            .where(TI.state.in_(states))
-            .group_by(TI.task_id, TI.run_id, TI.dag_id)
-        )
-        return ConcurrencyMap.from_concurrency_map(
-            {(dag_id, run_id, task_id): count for task_id, run_id, dag_id, 
count in ti_concurrency_query}
-        )
-
     def _executable_task_instances_to_queued(self, max_tis: int, session: 
Session) -> list[TI]:
         """
         Find TIs that are ready for execution based on conditions.
@@ -326,7 +314,8 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
         starved_pools = {pool_name for pool_name, stats in pools.items() if 
stats["open"] <= 0}
 
         # dag_id to # of running tasks and (dag_id, task_id) to # of running 
tasks.
-        concurrency_map = self.__get_concurrency_maps(states=EXECUTION_STATES, 
session=session)
+        concurrency_map = ConcurrencyMap()
+        concurrency_map.load(session=session)
 
         # Number of tasks that cannot be scheduled because of no open slot in 
pool
         num_starving_tasks_total = 0
@@ -465,22 +454,22 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
                 # Check to make sure that the task max_active_tasks of the DAG 
hasn't been
                 # reached.
                 dag_id = task_instance.dag_id
-
-                current_active_tasks_per_dag = 
concurrency_map.dag_active_tasks_map[dag_id]
-                max_active_tasks_per_dag_limit = 
task_instance.dag_model.max_active_tasks
+                dag_run_key = (dag_id, task_instance.run_id)
+                current_active_tasks_per_dag_run = 
concurrency_map.dag_run_active_tasks_map[dag_run_key]
+                dag_max_active_tasks = task_instance.dag_model.max_active_tasks
                 self.log.info(
                     "DAG %s has %s/%s running and queued tasks",
                     dag_id,
-                    current_active_tasks_per_dag,
-                    max_active_tasks_per_dag_limit,
+                    current_active_tasks_per_dag_run,
+                    dag_max_active_tasks,
                 )
-                if current_active_tasks_per_dag >= 
max_active_tasks_per_dag_limit:
+                if current_active_tasks_per_dag_run >= dag_max_active_tasks:
                     self.log.info(
                         "Not executing %s since the number of tasks running or 
queued "
                         "from DAG %s is >= to the DAG's max_active_tasks limit 
of %s",
                         task_instance,
                         dag_id,
-                        max_active_tasks_per_dag_limit,
+                        dag_max_active_tasks,
                     )
                     starved_dags.add(dag_id)
                     continue
@@ -571,7 +560,7 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
 
                 executable_tis.append(task_instance)
                 open_slots -= task_instance.pool_slots
-                concurrency_map.dag_active_tasks_map[dag_id] += 1
+                concurrency_map.dag_run_active_tasks_map[dag_run_key] += 1
                 concurrency_map.task_concurrency_map[(task_instance.dag_id, 
task_instance.task_id)] += 1
                 concurrency_map.task_dagrun_concurrency_map[
                     (task_instance.dag_id, task_instance.run_id, 
task_instance.task_id)
diff --git a/newsfragments/42953.significant b/newsfragments/42953.significant
new file mode 100644
index 0000000000..20f25b4345
--- /dev/null
+++ b/newsfragments/42953.significant
@@ -0,0 +1,3 @@
+DAG.max_active_tasks is now evaluated per DAG run
+
+Previously, this was evaluated across all runs of the DAG. This behavior change was passed by lazy consensus. Vote thread: https://lists.apache.org/thread/9o84d3yn934m32gtlpokpwtbbmtxj47l.
diff --git a/tests/jobs/test_scheduler_job.py b/tests/jobs/test_scheduler_job.py
index 52ad2f6a8c..17fd851714 100644
--- a/tests/jobs/test_scheduler_job.py
+++ b/tests/jobs/test_scheduler_job.py
@@ -22,7 +22,7 @@ import datetime
 import logging
 import os
 import sys
-from collections import deque
+from collections import Counter, deque
 from datetime import timedelta
 from importlib import reload
 from typing import Generator
@@ -1169,86 +1169,78 @@ class TestSchedulerJob:
         assert ti1.state == State.SCHEDULED
         assert ti2.state == State.QUEUED
 
-    def test_find_executable_task_instances_concurrency(self, dag_maker):
-        dag_id = 
"SchedulerJobTest.test_find_executable_task_instances_concurrency"
-        session = settings.Session()
+    @pytest.mark.parametrize("active_state", [TaskInstanceState.RUNNING, 
TaskInstanceState.QUEUED])
+    def test_find_executable_task_instances_concurrency(self, dag_maker, 
active_state, session):
+        """We verify here that, with varying amounts of queued / running / 
scheduled tasks,
+        the correct number of TIs are queued"""
+        dag_id = "check_MAT_dag"
         with dag_maker(dag_id=dag_id, max_active_tasks=2, session=session):
-            EmptyOperator(task_id="dummy")
+            EmptyOperator(task_id="task_1")
+            EmptyOperator(task_id="task_2")
+            EmptyOperator(task_id="task_3")
 
         scheduler_job = Job()
         self.job_runner = SchedulerJobRunner(job=scheduler_job, 
subdir=os.devnull)
 
-        dr1 = dag_maker.create_dagrun(run_type=DagRunType.SCHEDULED)
-        dr2 = dag_maker.create_dagrun_after(dr1, run_type=DagRunType.SCHEDULED)
-        dr3 = dag_maker.create_dagrun_after(dr2, run_type=DagRunType.SCHEDULED)
+        dr1 = dag_maker.create_dagrun(run_type=DagRunType.SCHEDULED, 
run_id="run_1", session=session)
+        dr2 = dag_maker.create_dagrun_after(
+            dr1, run_type=DagRunType.SCHEDULED, run_id="run_2", session=session
+        )
+        dr3 = dag_maker.create_dagrun_after(
+            dr2, run_type=DagRunType.SCHEDULED, run_id="run_3", session=session
+        )
 
-        ti1 = dr1.task_instances[0]
-        ti2 = dr2.task_instances[0]
-        ti3 = dr3.task_instances[0]
-        ti1.state = State.RUNNING
-        ti2.state = State.SCHEDULED
-        ti3.state = State.SCHEDULED
-        session.merge(ti1)
-        session.merge(ti2)
-        session.merge(ti3)
+        # set 2 tis in dr1 to running
+        # no more can be queued
+        t1, t2, t3 = dr1.get_task_instances(session=session)
+        t1.state = active_state
+        t2.state = active_state
+        t3.state = State.SCHEDULED
+        session.merge(t1)
+        session.merge(t2)
+        session.merge(t3)
+        # set 1 ti from dr2 to running
+        # one can be queued
+        t1, t2, t3 = dr2.get_task_instances(session=session)
+        t1.state = active_state
+        t2.state = State.SCHEDULED
+        t3.state = State.SCHEDULED
+        session.merge(t1)
+        session.merge(t2)
+        session.merge(t3)
+        # set 0 tis from dr3 to running
+        # two can be queued
+        t1, t2, t3 = dr3.get_task_instances(session=session)
+        t1.state = State.SCHEDULED
+        t2.state = State.SCHEDULED
+        t3.state = State.SCHEDULED
+        session.merge(t1)
+        session.merge(t2)
+        session.merge(t3)
 
         session.flush()
 
-        res = self.job_runner._executable_task_instances_to_queued(max_tis=32, 
session=session)
-
-        assert 1 == len(res)
-        res_keys = (x.key for x in res)
-        assert ti2.key in res_keys
-
-        ti2.state = State.RUNNING
-        session.merge(ti2)
-        session.flush()
-
-        res = self.job_runner._executable_task_instances_to_queued(max_tis=32, 
session=session)
-
-        assert 0 == len(res)
-        session.rollback()
-
-    def test_find_executable_task_instances_concurrency_queued(self, 
dag_maker):
-        dag_id = 
"SchedulerJobTest.test_find_executable_task_instances_concurrency_queued"
-        with dag_maker(dag_id=dag_id, max_active_tasks=3):
-            task1 = EmptyOperator(task_id="dummy1")
-            task2 = EmptyOperator(task_id="dummy2")
-            task3 = EmptyOperator(task_id="dummy3")
-
-        scheduler_job = Job()
-        self.job_runner = SchedulerJobRunner(job=scheduler_job, 
subdir=os.devnull)
-        session = settings.Session()
-
-        dag_run = dag_maker.create_dagrun()
-
-        ti1 = dag_run.get_task_instance(task1.task_id)
-        ti2 = dag_run.get_task_instance(task2.task_id)
-        ti3 = dag_run.get_task_instance(task3.task_id)
-        ti1.state = State.RUNNING
-        ti2.state = State.QUEUED
-        ti3.state = State.SCHEDULED
-
-        session.merge(ti1)
-        session.merge(ti2)
-        session.merge(ti3)
+        queued_tis = 
self.job_runner._executable_task_instances_to_queued(max_tis=32, 
session=session)
+        queued_runs = Counter([x.run_id for x in queued_tis])
+        assert queued_runs["run_1"] == 0
+        assert queued_runs["run_2"] == 1
+        assert queued_runs["run_3"] == 2
 
-        session.flush()
+        session.commit()
+        session.query(TaskInstance).all()
 
-        res = self.job_runner._executable_task_instances_to_queued(max_tis=32, 
session=session)
+        # now we still have max tis running so no more will be queued
+        queued_tis = 
self.job_runner._executable_task_instances_to_queued(max_tis=32, 
session=session)
+        assert queued_tis == []
 
-        assert 1 == len(res)
-        assert res[0].key == ti3.key
         session.rollback()
 
     # TODO: This is a hack, I think I need to just remove the setting and have 
it on always
     def test_find_executable_task_instances_max_active_tis_per_dag(self, 
dag_maker):
         dag_id = 
"SchedulerJobTest.test_find_executable_task_instances_max_active_tis_per_dag"
-        task_id_1 = "dummy"
-        task_id_2 = "dummy2"
         with dag_maker(dag_id=dag_id, max_active_tasks=16):
-            task1 = EmptyOperator(task_id=task_id_1, max_active_tis_per_dag=2)
-            task2 = EmptyOperator(task_id=task_id_2)
+            task1 = EmptyOperator(task_id="dummy", max_active_tis_per_dag=2)
+            task2 = EmptyOperator(task_id="dummy2")
 
         executor = MockExecutor(do_update=True)
 
@@ -1653,65 +1645,88 @@ class TestSchedulerJob:
             ("secondary_exec", "secondary_exec"),
         ],
     )
-    def test_critical_section_enqueue_task_instances(self, task1_exec, 
task2_exec, dag_maker, mock_executors):
+    def test_critical_section_enqueue_task_instances(
+        self, task1_exec, task2_exec, dag_maker, mock_executors, session
+    ):
         dag_id = "SchedulerJobTest.test_execute_task_instances"
-        task_id_1 = "dummy_task"
-        task_id_2 = "dummy_task_nonexistent_queue"
-        session = settings.Session()
         # important that len(tasks) is less than max_active_tasks
         # because before scheduler._execute_task_instances would only
         # check the num tasks once so if max_active_tasks was 3,
         # we could execute arbitrarily many tasks in the second run
         with dag_maker(dag_id=dag_id, max_active_tasks=3, session=session) as 
dag:
-            task1 = EmptyOperator(task_id=task_id_1, executor=task1_exec)
-            task2 = EmptyOperator(task_id=task_id_2, executor=task2_exec)
+            task1 = EmptyOperator(task_id="t1", executor=task1_exec)
+            task2 = EmptyOperator(task_id="t2", executor=task2_exec)
+            task3 = EmptyOperator(task_id="t3", executor=task2_exec)
+            task4 = EmptyOperator(task_id="t4", executor=task2_exec)
 
         scheduler_job = Job()
         self.job_runner = SchedulerJobRunner(job=scheduler_job, 
subdir=os.devnull)
 
-        # create first dag run with 2 running tasks
+        # create first dag run with 3 running tasks
 
-        dr1 = dag_maker.create_dagrun(run_type=DagRunType.SCHEDULED)
+        dr1 = dag_maker.create_dagrun(run_type=DagRunType.SCHEDULED, 
session=session)
 
-        ti1 = dr1.get_task_instance(task1.task_id, session)
-        ti2 = dr1.get_task_instance(task2.task_id, session)
-        ti1.state = State.RUNNING
-        ti2.state = State.RUNNING
+        dr1_ti1 = dr1.get_task_instance(task1.task_id, session)
+        dr1_ti2 = dr1.get_task_instance(task2.task_id, session)
+        dr1_ti3 = dr1.get_task_instance(task3.task_id, session)
+        dr1_ti4 = dr1.get_task_instance(task4.task_id, session)
+        dr1_ti1.state = State.RUNNING
+        dr1_ti2.state = State.RUNNING
+        dr1_ti3.state = State.RUNNING
+        dr1_ti4.state = State.SCHEDULED
         session.flush()
 
-        assert State.RUNNING == dr1.state
-        assert 2 == DAG.get_num_task_instances(
-            dag_id, task_ids=dag.task_ids, states=[State.RUNNING], 
session=session
+        assert dr1.state == State.RUNNING
+        num_tis = DAG.get_num_task_instances(
+            dag_id=dag_id,
+            task_ids=dag.task_ids,
+            states=[State.RUNNING],
+            session=session,
         )
+        assert num_tis == 3
 
         # create second dag run
-        dr2 = dag_maker.create_dagrun_after(dr1, run_type=DagRunType.SCHEDULED)
-        ti3 = dr2.get_task_instance(task1.task_id, session)
-        ti4 = dr2.get_task_instance(task2.task_id, session)
+        dr2 = dag_maker.create_dagrun_after(dr1, 
run_type=DagRunType.SCHEDULED, session=session)
+        dr2_ti1 = dr2.get_task_instance(task1.task_id, session)
+        dr2_ti2 = dr2.get_task_instance(task2.task_id, session)
+        dr2_ti3 = dr2.get_task_instance(task3.task_id, session)
+        dr2_ti4 = dr2.get_task_instance(task4.task_id, session)
         # manually set to scheduled so we can pick them up
-        ti3.state = State.SCHEDULED
-        ti4.state = State.SCHEDULED
+        dr2_ti1.state = State.SCHEDULED
+        dr2_ti2.state = State.SCHEDULED
+        dr2_ti3.state = State.SCHEDULED
+        dr2_ti4.state = State.SCHEDULED
         session.flush()
 
-        assert State.RUNNING == dr2.state
+        assert dr2.state == State.RUNNING
 
-        res = self.job_runner._critical_section_enqueue_task_instances(session)
+        num_queued = 
self.job_runner._critical_section_enqueue_task_instances(session=session)
+        assert num_queued == 3
 
         # check that max_active_tasks is respected
-        ti1.refresh_from_db()
-        ti2.refresh_from_db()
-        ti3.refresh_from_db()
-        ti4.refresh_from_db()
-        assert 3 == DAG.get_num_task_instances(
-            dag_id, task_ids=dag.task_ids, states=[State.RUNNING, 
State.QUEUED], session=session
-        )
-        assert State.RUNNING == ti1.state
-        assert State.RUNNING == ti2.state
-        assert {State.QUEUED, State.SCHEDULED} == {ti3.state, ti4.state}
-        assert 1 == res
 
-        res = self.job_runner._critical_section_enqueue_task_instances(session)
-        assert 0 == res
+        num_tis = DAG.get_num_task_instances(
+            dag_id=dag_id,
+            task_ids=dag.task_ids,
+            states=[State.RUNNING, State.QUEUED],
+            session=session,
+        )
+        assert num_tis == 6
+
+        # this doesn't really tell us anything since we set these values 
manually, but hey
+        dr1_counter = Counter(x.state for x in 
dr1.get_task_instances(session=session))
+        assert dr1_counter[State.RUNNING] == 3
+        assert dr1_counter[State.SCHEDULED] == 1
+
+        # this is the more meaningful bit
+        # three of dr2's tasks should be queued since that's max active tasks
+        # and max active tasks is evaluated per-dag-run
+        dr2_counter = Counter(x.state for x in 
dr2.get_task_instances(session=session))
+        assert dr2_counter[State.QUEUED] == 3
+        assert dr2_counter[State.SCHEDULED] == 1
+
+        num_queued = 
self.job_runner._critical_section_enqueue_task_instances(session=session)
+        assert num_queued == 0
 
     def test_execute_task_instances_limit_second_executor(self, dag_maker, 
mock_executors):
         dag_id = "SchedulerJobTest.test_execute_task_instances_limit"

Reply via email to