This is an automated email from the ASF dual-hosted git repository.

husseinawala pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 1ebeb19bf7 Add `max_active_tis_per_dagrun` for Dynamic Task Mapping (#29094)
1ebeb19bf7 is described below

commit 1ebeb19bf7542850fff2f1e2f9795ad70c1b24e2
Author: Hussein Awala <[email protected]>
AuthorDate: Fri Apr 14 12:36:08 2023 +0200

    Add `max_active_tis_per_dagrun` for Dynamic Task Mapping (#29094)
    
    * add max_active_tis_per_dagrun param to BaseOperator
    
    * set has_task_concurrency_limits when max_active_tis_per_dagrun is not None
    
    * check if max_active_tis_per_dagrun is reached in the task deps
    
    * check if all the tasks have None max_active_tis_per_dagrun before auto-scheduling the dagrun
    
    * check if the max_active_tis_per_dagrun is reached before queuing the ti
    
    * check max_active_tis_per_dagrun in backfill job
    
    * fix current tests and ensure everything is ok before adding new tests
    
    * refactor TestTaskConcurrencyDep
    
    * fix a bug in TaskConcurrencyDep
    
    * test max_active_tis_per_dagrun in TaskConcurrencyDep
    
    * tests max_active_tis_per_dagrun in TestTaskInstance
    
    * test dag_file_processor with max_active_tis_per_dagrun
    
    * test scheduling with max_active_tis_per_dagrun on different DAG runs
    
    * test scheduling mapped task with max_active_tis_per_dagrun
    
    * test max_active_tis_per_dagrun with backfill CLI
    
    * add new starved_tasks filter to avoid affecting the scheduling perf
    
    * unify the usage of TaskInstance filters and use TI
    
    * refactor concurrency map type and create a new dataclass
    
    * move docstring to ConcurrencyMap class and create a method for default_factory
    
    * move concurrency_map creation to ConcurrencyMap class
    
    * replace default dicts by counters
    
    * replace all default dicts by counters in the scheduler_job_runner module
    
    * suggestions from review
---
 airflow/jobs/backfill_job_runner.py           |  16 +++-
 airflow/jobs/scheduler_job_runner.py          | 126 ++++++++++++++++++--------
 airflow/models/baseoperator.py                |   6 ++
 airflow/models/dag.py                         |  12 ++-
 airflow/models/dagrun.py                      |   1 +
 airflow/models/mappedoperator.py              |   4 +
 airflow/models/taskinstance.py                |  17 ++--
 airflow/ti_deps/deps/task_concurrency_dep.py  |  17 +++-
 tests/conftest.py                             |   2 +
 tests/jobs/test_backfill_job.py               |  83 +++++++++++++++++
 tests/jobs/test_scheduler_job.py              | 116 +++++++++++++++++++++++-
 tests/models/test_dag.py                      |  18 ++--
 tests/models/test_taskinstance.py             |  13 +++
 tests/serialization/test_dag_serialization.py |   1 +
 tests/ti_deps/deps/test_task_concurrency.py   |  41 +++++----
 15 files changed, 390 insertions(+), 83 deletions(-)

diff --git a/airflow/jobs/backfill_job_runner.py 
b/airflow/jobs/backfill_job_runner.py
index c99cae2d21..4a78890d3b 100644
--- a/airflow/jobs/backfill_job_runner.py
+++ b/airflow/jobs/backfill_job_runner.py
@@ -618,7 +618,7 @@ class BackfillJobRunner(BaseJobRunner, LoggingMixin):
                                 "Not scheduling since DAG max_active_tasks 
limit is reached."
                             )
 
-                        if task.max_active_tis_per_dag:
+                        if task.max_active_tis_per_dag is not None:
                             num_running_task_instances_in_task = 
DAG.get_num_task_instances(
                                 dag_id=self.dag_id,
                                 task_ids=[task.task_id],
@@ -631,6 +631,20 @@ class BackfillJobRunner(BaseJobRunner, LoggingMixin):
                                     "Not scheduling since Task concurrency 
limit is reached."
                                 )
 
+                        if task.max_active_tis_per_dagrun is not None:
+                            num_running_task_instances_in_task_dagrun = 
DAG.get_num_task_instances(
+                                dag_id=self.dag_id,
+                                run_id=ti.run_id,
+                                task_ids=[task.task_id],
+                                states=self.STATES_COUNT_AS_RUNNING,
+                                session=session,
+                            )
+
+                            if num_running_task_instances_in_task_dagrun >= 
task.max_active_tis_per_dagrun:
+                                raise TaskConcurrencyLimitReached(
+                                    "Not scheduling since Task concurrency per 
DAG run limit is reached."
+                                )
+
                         _per_task_process(key, ti, session)
                         session.commit()
             except (NoAvailablePoolSlot, DagConcurrencyLimitReached, 
TaskConcurrencyLimitReached) as e:
diff --git a/airflow/jobs/scheduler_job_runner.py 
b/airflow/jobs/scheduler_job_runner.py
index 706d980253..aae98373b6 100644
--- a/airflow/jobs/scheduler_job_runner.py
+++ b/airflow/jobs/scheduler_job_runner.py
@@ -25,10 +25,11 @@ import signal
 import sys
 import time
 import warnings
-from collections import defaultdict
+from collections import Counter
+from dataclasses import dataclass
 from datetime import datetime, timedelta
 from pathlib import Path
-from typing import TYPE_CHECKING, Collection, DefaultDict, Iterator
+from typing import TYPE_CHECKING, Collection, Iterable, Iterator
 
 from sqlalchemy import and_, func, not_, or_, text
 from sqlalchemy.exc import OperationalError
@@ -85,6 +86,29 @@ DR = DagRun
 DM = DagModel
 
 
+@dataclass
+class ConcurrencyMap:
+    """
+    Dataclass to represent concurrency maps
+
+    It contains a map from (dag_id, task_id) to # of task instances, a map 
from (dag_id, task_id)
+    to # of task instances in the given state list and a map from (dag_id, 
run_id, task_id)
+    to # of task instances in the given state list in each DAG run.
+    """
+
+    dag_active_tasks_map: dict[str, int]
+    task_concurrency_map: dict[tuple[str, str], int]
+    task_dagrun_concurrency_map: dict[tuple[str, str, str], int]
+
+    @classmethod
+    def from_concurrency_map(cls, mapping: dict[tuple[str, str, str], int]) -> 
ConcurrencyMap:
+        instance = cls(Counter(), Counter(), Counter(mapping))
+        for (d, r, t), c in mapping.items():
+            instance.dag_active_tasks_map[d] += c
+            instance.task_concurrency_map[(d, t)] += c
+        return instance
+
+
 def _is_parent_process() -> bool:
     """
     Whether this is a parent process.
@@ -231,28 +255,21 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
             < scheduler_health_check_threshold
         )
 
-    def __get_concurrency_maps(
-        self, states: list[TaskInstanceState], session: Session
-    ) -> tuple[DefaultDict[str, int], DefaultDict[tuple[str, str], int]]:
+    def __get_concurrency_maps(self, states: Iterable[TaskInstanceState], 
session: Session) -> ConcurrencyMap:
         """
         Get the concurrency maps.
 
         :param states: List of states to query for
-        :return: A map from (dag_id, task_id) to # of task instances and
-         a map from (dag_id, task_id) to # of task instances in the given 
state list
+        :return: Concurrency map
         """
-        ti_concurrency_query: list[tuple[str, str, int]] = (
-            session.query(TI.task_id, TI.dag_id, func.count("*"))
+        ti_concurrency_query: list[tuple[str, str, str, int]] = (
+            session.query(TI.task_id, TI.run_id, TI.dag_id, func.count("*"))
             .filter(TI.state.in_(states))
-            .group_by(TI.task_id, TI.dag_id)
-        ).all()
-        dag_map: DefaultDict[str, int] = defaultdict(int)
-        task_map: DefaultDict[tuple[str, str], int] = defaultdict(int)
-        for result in ti_concurrency_query:
-            task_id, dag_id, count = result
-            dag_map[dag_id] += count
-            task_map[(dag_id, task_id)] = count
-        return dag_map, task_map
+            .group_by(TI.task_id, TI.run_id, TI.dag_id)
+        )
+        return ConcurrencyMap.from_concurrency_map(
+            {(dag_id, run_id, task_id): count for task_id, run_id, dag_id, 
count in ti_concurrency_query}
+        )
 
     def _executable_task_instances_to_queued(self, max_tis: int, session: 
Session) -> list[TI]:
         """
@@ -263,6 +280,8 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
         - DAG max_active_tasks
         - executor state
         - priority
+        - max active tis per DAG
+        - max active tis per DAG run
 
         :param max_tis: Maximum number of TIs to queue in this loop.
         :return: list[airflow.models.TaskInstance]
@@ -304,11 +323,7 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
         starved_pools = {pool_name for pool_name, stats in pools.items() if 
stats["open"] <= 0}
 
         # dag_id to # of running tasks and (dag_id, task_id) to # of running 
tasks.
-        dag_active_tasks_map: DefaultDict[str, int]
-        task_concurrency_map: DefaultDict[tuple[str, str], int]
-        dag_active_tasks_map, task_concurrency_map = 
self.__get_concurrency_maps(
-            states=list(EXECUTION_STATES), session=session
-        )
+        concurrency_map = self.__get_concurrency_maps(states=EXECUTION_STATES, 
session=session)
 
         # Number of tasks that cannot be scheduled because of no open slot in 
pool
         num_starving_tasks_total = 0
@@ -316,14 +331,16 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
         # dag and task ids that can't be queued because of concurrency limits
         starved_dags: set[str] = set()
         starved_tasks: set[tuple[str, str]] = set()
+        starved_tasks_task_dagrun_concurrency: set[tuple[str, str, str]] = 
set()
 
-        pool_num_starving_tasks: DefaultDict[str, int] = defaultdict(int)
+        pool_num_starving_tasks: dict[str, int] = Counter()
 
         for loop_count in itertools.count(start=1):
 
             num_starved_pools = len(starved_pools)
             num_starved_dags = len(starved_dags)
             num_starved_tasks = len(starved_tasks)
+            num_starved_tasks_task_dagrun_concurrency = 
len(starved_tasks_task_dagrun_concurrency)
 
             # Get task instances associated with scheduled
             # DagRuns which are not backfilled, in the given states,
@@ -347,7 +364,14 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
                 query = query.filter(not_(TI.dag_id.in_(starved_dags)))
 
             if starved_tasks:
-                task_filter = tuple_in_condition((TaskInstance.dag_id, 
TaskInstance.task_id), starved_tasks)
+                task_filter = tuple_in_condition((TI.dag_id, TI.task_id), 
starved_tasks)
+                query = query.filter(not_(task_filter))
+
+            if starved_tasks_task_dagrun_concurrency:
+                task_filter = tuple_in_condition(
+                    (TI.dag_id, TI.run_id, TI.task_id),
+                    starved_tasks_task_dagrun_concurrency,
+                )
                 query = query.filter(not_(task_filter))
 
             query = query.limit(max_tis)
@@ -439,7 +463,7 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
                 # reached.
                 dag_id = task_instance.dag_id
 
-                current_active_tasks_per_dag = dag_active_tasks_map[dag_id]
+                current_active_tasks_per_dag = 
concurrency_map.dag_active_tasks_map[dag_id]
                 max_active_tasks_per_dag_limit = 
task_instance.dag_model.max_active_tasks
                 self.log.info(
                     "DAG %s has %s/%s running and queued tasks",
@@ -481,7 +505,7 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
                         ).max_active_tis_per_dag
 
                     if task_concurrency_limit is not None:
-                        current_task_concurrency = task_concurrency_map[
+                        current_task_concurrency = 
concurrency_map.task_concurrency_map[
                             (task_instance.dag_id, task_instance.task_id)
                         ]
 
@@ -494,10 +518,35 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
                             starved_tasks.add((task_instance.dag_id, 
task_instance.task_id))
                             continue
 
+                    task_dagrun_concurrency_limit: int | None = None
+                    if serialized_dag.has_task(task_instance.task_id):
+                        task_dagrun_concurrency_limit = 
serialized_dag.get_task(
+                            task_instance.task_id
+                        ).max_active_tis_per_dagrun
+
+                    if task_dagrun_concurrency_limit is not None:
+                        current_task_dagrun_concurrency = 
concurrency_map.task_dagrun_concurrency_map[
+                            (task_instance.dag_id, task_instance.run_id, 
task_instance.task_id)
+                        ]
+
+                        if current_task_dagrun_concurrency >= 
task_dagrun_concurrency_limit:
+                            self.log.info(
+                                "Not executing %s since the task concurrency 
per DAG run for"
+                                " this task has been reached.",
+                                task_instance,
+                            )
+                            starved_tasks_task_dagrun_concurrency.add(
+                                (task_instance.dag_id, task_instance.run_id, 
task_instance.task_id)
+                            )
+                            continue
+
                 executable_tis.append(task_instance)
                 open_slots -= task_instance.pool_slots
-                dag_active_tasks_map[dag_id] += 1
-                task_concurrency_map[(task_instance.dag_id, 
task_instance.task_id)] += 1
+                concurrency_map.dag_active_tasks_map[dag_id] += 1
+                concurrency_map.task_concurrency_map[(task_instance.dag_id, 
task_instance.task_id)] += 1
+                concurrency_map.task_dagrun_concurrency_map[
+                    (task_instance.dag_id, task_instance.run_id, 
task_instance.task_id)
+                ] += 1
 
                 pool_stats["open"] = open_slots
 
@@ -507,6 +556,7 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
                 len(starved_pools) > num_starved_pools
                 or len(starved_dags) > num_starved_dags
                 or len(starved_tasks) > num_starved_tasks
+                or len(starved_tasks_task_dagrun_concurrency) > 
num_starved_tasks_task_dagrun_concurrency
             )
 
             if is_done or not found_new_filters:
@@ -816,13 +866,13 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
             paused_runs = (
                 session.query(DagRun)
                 .join(DagRun.dag_model)
-                .join(TaskInstance)
+                .join(TI)
                 .filter(
                     DagModel.is_paused == expression.true(),
                     DagRun.state == DagRunState.RUNNING,
                     DagRun.run_type != DagRunType.BACKFILL_JOB,
                 )
-                .having(DagRun.last_scheduling_decision <= 
func.max(TaskInstance.updated_at))
+                .having(DagRun.last_scheduling_decision <= 
func.max(TI.updated_at))
                 .group_by(DagRun)
             )
             for dag_run in paused_runs:
@@ -1079,8 +1129,7 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
             .all()
         )
 
-        active_runs_of_dags = defaultdict(
-            int,
+        active_runs_of_dags = Counter(
             DagRun.active_runs_of_dags(dag_ids=(dm.dag_id for dm in 
dag_models), session=session),
         )
 
@@ -1237,8 +1286,7 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
         """Find DagRuns in queued state and decide moving them to running 
state."""
         dag_runs = self._get_next_dagruns_to_examine(DagRunState.QUEUED, 
session)
 
-        active_runs_of_dags = defaultdict(
-            int,
+        active_runs_of_dags = Counter(
             DagRun.active_runs_of_dags((dr.dag_id for dr in dag_runs), 
only_running=True, session=session),
         )
 
@@ -1533,10 +1581,10 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
         or execution timeout has passed, so they can be marked as failed.
         """
         num_timed_out_tasks = (
-            session.query(TaskInstance)
+            session.query(TI)
             .filter(
-                TaskInstance.state == TaskInstanceState.DEFERRED,
-                TaskInstance.trigger_timeout < timezone.utcnow(),
+                TI.state == TaskInstanceState.DEFERRED,
+                TI.trigger_timeout < timezone.utcnow(),
             )
             .update(
                 # We have to schedule these to fail themselves so it doesn't
@@ -1599,7 +1647,7 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
             )
 
     @staticmethod
-    def _generate_zombie_message_details(ti: TaskInstance):
+    def _generate_zombie_message_details(ti: TI):
         zombie_message_details = {
             "DAG Id": ti.dag_id,
             "Task Id": ti.task_id,
diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py
index 1760bd1f42..d81d45dc7e 100644
--- a/airflow/models/baseoperator.py
+++ b/airflow/models/baseoperator.py
@@ -214,6 +214,7 @@ def partial(
     weight_rule: str = DEFAULT_WEIGHT_RULE,
     sla: timedelta | None = None,
     max_active_tis_per_dag: int | None = None,
+    max_active_tis_per_dagrun: int | None = None,
     on_execute_callback: None | TaskStateChangeCallback | 
list[TaskStateChangeCallback] = None,
     on_failure_callback: None | TaskStateChangeCallback | 
list[TaskStateChangeCallback] = None,
     on_success_callback: None | TaskStateChangeCallback | 
list[TaskStateChangeCallback] = None,
@@ -274,6 +275,7 @@ def partial(
     partial_kwargs.setdefault("weight_rule", weight_rule)
     partial_kwargs.setdefault("sla", sla)
     partial_kwargs.setdefault("max_active_tis_per_dag", max_active_tis_per_dag)
+    partial_kwargs.setdefault("max_active_tis_per_dagrun", 
max_active_tis_per_dagrun)
     partial_kwargs.setdefault("on_execute_callback", on_execute_callback)
     partial_kwargs.setdefault("on_failure_callback", on_failure_callback)
     partial_kwargs.setdefault("on_retry_callback", on_retry_callback)
@@ -578,6 +580,8 @@ class BaseOperator(AbstractOperator, 
metaclass=BaseOperatorMeta):
     :param run_as_user: unix username to impersonate while running the task
     :param max_active_tis_per_dag: When set, a task will be able to limit the 
concurrent
         runs across execution_dates.
+    :param max_active_tis_per_dagrun: When set, a task will be able to limit 
the concurrent
+        task instances per DAG run.
     :param executor_config: Additional task-level configuration parameters 
that are
         interpreted by a specific executor. Parameters are namespaced by the 
name of
         executor.
@@ -729,6 +733,7 @@ class BaseOperator(AbstractOperator, 
metaclass=BaseOperatorMeta):
         run_as_user: str | None = None,
         task_concurrency: int | None = None,
         max_active_tis_per_dag: int | None = None,
+        max_active_tis_per_dagrun: int | None = None,
         executor_config: dict | None = None,
         do_xcom_push: bool = True,
         inlets: Any | None = None,
@@ -872,6 +877,7 @@ class BaseOperator(AbstractOperator, 
metaclass=BaseOperatorMeta):
             )
             max_active_tis_per_dag = task_concurrency
         self.max_active_tis_per_dag: int | None = max_active_tis_per_dag
+        self.max_active_tis_per_dagrun: int | None = max_active_tis_per_dagrun
         self.do_xcom_push = do_xcom_push
 
         self.doc_md = doc_md
diff --git a/airflow/models/dag.py b/airflow/models/dag.py
index 8b259cd3e8..9888e5cbd8 100644
--- a/airflow/models/dag.py
+++ b/airflow/models/dag.py
@@ -2789,7 +2789,10 @@ class DAG(LoggingMixin):
             orm_dag.description = dag.description
             orm_dag.max_active_tasks = dag.max_active_tasks
             orm_dag.max_active_runs = dag.max_active_runs
-            orm_dag.has_task_concurrency_limits = any(t.max_active_tis_per_dag 
is not None for t in dag.tasks)
+            orm_dag.has_task_concurrency_limits = any(
+                t.max_active_tis_per_dag is not None or 
t.max_active_tis_per_dagrun is not None
+                for t in dag.tasks
+            )
             orm_dag.schedule_interval = dag.schedule_interval
             orm_dag.timetable_description = dag.timetable.description
             orm_dag.processor_subdir = processor_subdir
@@ -2990,12 +2993,13 @@ class DAG(LoggingMixin):
 
     @staticmethod
     @provide_session
-    def get_num_task_instances(dag_id, task_ids=None, states=None, 
session=NEW_SESSION) -> int:
+    def get_num_task_instances(dag_id, run_id=None, task_ids=None, 
states=None, session=NEW_SESSION) -> int:
         """
         Returns the number of task instances in the given DAG.
 
         :param session: ORM session
         :param dag_id: ID of the DAG to get the task concurrency of
+        :param run_id: ID of the DAG run to get the task concurrency of
         :param task_ids: A list of valid task IDs for the given DAG
         :param states: A list of states to filter by if supplied
         :return: The number of running tasks
@@ -3003,6 +3007,10 @@ class DAG(LoggingMixin):
         qry = session.query(func.count(TaskInstance.task_id)).filter(
             TaskInstance.dag_id == dag_id,
         )
+        if run_id:
+            qry = qry.filter(
+                TaskInstance.run_id == run_id,
+            )
         if task_ids:
             qry = qry.filter(
                 TaskInstance.task_id.in_(task_ids),
diff --git a/airflow/models/dagrun.py b/airflow/models/dagrun.py
index 101bda0a47..edb0ec78ac 100644
--- a/airflow/models/dagrun.py
+++ b/airflow/models/dagrun.py
@@ -553,6 +553,7 @@ class DagRun(Base, LoggingMixin):
                     bool(self.tis)
                     and all(not t.task.depends_on_past for t in self.tis)
                     and all(t.task.max_active_tis_per_dag is None for t in 
self.tis)
+                    and all(t.task.max_active_tis_per_dagrun is None for t in 
self.tis)
                     and all(t.state != TaskInstanceState.DEFERRED for t in 
self.tis)
                 )
 
diff --git a/airflow/models/mappedoperator.py b/airflow/models/mappedoperator.py
index a10fd10cdc..345329ef6d 100644
--- a/airflow/models/mappedoperator.py
+++ b/airflow/models/mappedoperator.py
@@ -451,6 +451,10 @@ class MappedOperator(AbstractOperator):
     def max_active_tis_per_dag(self) -> int | None:
         return self.partial_kwargs.get("max_active_tis_per_dag")
 
+    @property
+    def max_active_tis_per_dagrun(self) -> int | None:
+        return self.partial_kwargs.get("max_active_tis_per_dagrun")
+
     @property
     def resources(self) -> Resources | None:
         return self.partial_kwargs.get("resources")
diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index c2ed7b7868..b02076c076 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -2479,18 +2479,17 @@ class TaskInstance(Base, LoggingMixin):
         return LazyXComAccess.build_from_xcom_query(query)
 
     @provide_session
-    def get_num_running_task_instances(self, session: Session) -> int:
+    def get_num_running_task_instances(self, session: Session, 
same_dagrun=False) -> int:
         """Return Number of running TIs from the DB"""
         # .count() is inefficient
-        return (
-            session.query(func.count())
-            .filter(
-                TaskInstance.dag_id == self.dag_id,
-                TaskInstance.task_id == self.task_id,
-                TaskInstance.state == State.RUNNING,
-            )
-            .scalar()
+        num_running_task_instances_query = session.query(func.count()).filter(
+            TaskInstance.dag_id == self.dag_id,
+            TaskInstance.task_id == self.task_id,
+            TaskInstance.state == State.RUNNING,
         )
+        if same_dagrun:
+            num_running_task_instances_query.filter(TaskInstance.run_id == 
self.run_id)
+        return num_running_task_instances_query.scalar()
 
     def init_run_context(self, raw: bool = False) -> None:
         """Sets the log context."""
diff --git a/airflow/ti_deps/deps/task_concurrency_dep.py 
b/airflow/ti_deps/deps/task_concurrency_dep.py
index 5b5f4f515a..1f1416214c 100644
--- a/airflow/ti_deps/deps/task_concurrency_dep.py
+++ b/airflow/ti_deps/deps/task_concurrency_dep.py
@@ -30,13 +30,22 @@ class TaskConcurrencyDep(BaseTIDep):
 
     @provide_session
     def _get_dep_statuses(self, ti, session, dep_context):
-        if ti.task.max_active_tis_per_dag is None:
+        if ti.task.max_active_tis_per_dag is None and 
ti.task.max_active_tis_per_dagrun is None:
             yield self._passing_status(reason="Task concurrency is not set.")
             return
 
-        if ti.get_num_running_task_instances(session) >= 
ti.task.max_active_tis_per_dag:
+        if (
+            ti.task.max_active_tis_per_dag is not None
+            and ti.get_num_running_task_instances(session) >= 
ti.task.max_active_tis_per_dag
+        ):
             yield self._failing_status(reason="The max task concurrency has 
been reached.")
             return
-        else:
-            yield self._passing_status(reason="The max task concurrency has 
not been reached.")
+        if (
+            ti.task.max_active_tis_per_dagrun is not None
+            and ti.get_num_running_task_instances(session, same_dagrun=True)
+            >= ti.task.max_active_tis_per_dagrun
+        ):
+            yield self._failing_status(reason="The max task concurrency per 
run has been reached.")
             return
+        yield self._passing_status(reason="The max task concurrency has not 
been reached.")
+        return
diff --git a/tests/conftest.py b/tests/conftest.py
index cfae846895..f169e2b62e 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -726,6 +726,7 @@ def create_dummy_dag(dag_maker):
         dag_id="dag",
         task_id="op1",
         max_active_tis_per_dag=16,
+        max_active_tis_per_dagrun=None,
         pool="default_pool",
         executor_config={},
         trigger_rule="all_done",
@@ -741,6 +742,7 @@ def create_dummy_dag(dag_maker):
             op = EmptyOperator(
                 task_id=task_id,
                 max_active_tis_per_dag=max_active_tis_per_dag,
+                max_active_tis_per_dagrun=max_active_tis_per_dagrun,
                 executor_config=executor_config,
                 on_success_callback=on_success_callback,
                 on_execute_callback=on_execute_callback,
diff --git a/tests/jobs/test_backfill_job.py b/tests/jobs/test_backfill_job.py
index 7001f6f1b9..164d5af1df 100644
--- a/tests/jobs/test_backfill_job.py
+++ b/tests/jobs/test_backfill_job.py
@@ -21,6 +21,7 @@ import datetime
 import json
 import logging
 import threading
+from collections import defaultdict
 from unittest import mock
 from unittest.mock import patch
 
@@ -376,6 +377,88 @@ class TestBackfillJob:
         assert 0 == times_dag_concurrency_limit_reached_in_debug
         assert times_task_concurrency_limit_reached_in_debug > 0
 
+    @pytest.mark.parametrize("with_max_active_tis_per_dag", [False, True])
+    @patch("airflow.jobs.backfill_job_runner.BackfillJobRunner.log")
+    def test_backfill_respect_max_active_tis_per_dagrun_limit(
+        self, mock_log, dag_maker, with_max_active_tis_per_dag
+    ):
+        max_active_tis_per_dag = 3
+        max_active_tis_per_dagrun = 2
+        kwargs = {"max_active_tis_per_dagrun": max_active_tis_per_dagrun}
+        if with_max_active_tis_per_dag:
+            kwargs["max_active_tis_per_dag"] = max_active_tis_per_dag
+
+        with 
dag_maker(dag_id="test_backfill_respect_max_active_tis_per_dag_limit", 
schedule="@daily") as dag:
+            EmptyOperator.partial(task_id="task1", 
**kwargs).expand_kwargs([{"x": i} for i in range(10)])
+
+        dag_maker.create_dagrun(state=None)
+
+        executor = MockExecutor()
+
+        job = Job(executor=executor)
+        job_runner = BackfillJobRunner(
+            job=job,
+            dag=dag,
+            start_date=DEFAULT_DATE,
+            end_date=DEFAULT_DATE + datetime.timedelta(days=7),
+        )
+
+        run_job(job=job, execute_callable=job_runner._execute)
+
+        assert len(executor.history) > 0
+
+        task_concurrency_limit_reached_at_least_once = False
+
+        def get_running_tis_per_dagrun(running_tis):
+            running_tis_per_dagrun_dict = defaultdict(int)
+            for running_ti in running_tis:
+                running_tis_per_dagrun_dict[running_ti[3].dag_run.id] += 1
+            return running_tis_per_dagrun_dict
+
+        num_running_task_instances = 0
+        for running_task_instances in executor.history:
+            if with_max_active_tis_per_dag:
+                assert len(running_task_instances) <= max_active_tis_per_dag
+            running_tis_per_dagrun_dict = 
get_running_tis_per_dagrun(running_task_instances)
+            assert all(
+                [
+                    num_running_tis <= max_active_tis_per_dagrun
+                    for num_running_tis in running_tis_per_dagrun_dict.values()
+                ]
+            )
+            num_running_task_instances += len(running_task_instances)
+            task_concurrency_limit_reached_at_least_once = (
+                task_concurrency_limit_reached_at_least_once
+                or any(
+                    [
+                        num_running_tis == max_active_tis_per_dagrun
+                        for num_running_tis in 
running_tis_per_dagrun_dict.values()
+                    ]
+                )
+            )
+
+        assert 80 == num_running_task_instances  # (7 backfill run + 1 manual 
run ) * 10 mapped task per run
+        assert task_concurrency_limit_reached_at_least_once
+
+        times_dag_concurrency_limit_reached_in_debug = self._times_called_with(
+            mock_log.debug,
+            DagConcurrencyLimitReached,
+        )
+
+        times_pool_limit_reached_in_debug = self._times_called_with(
+            mock_log.debug,
+            NoAvailablePoolSlot,
+        )
+
+        times_task_concurrency_limit_reached_in_debug = 
self._times_called_with(
+            mock_log.debug,
+            TaskConcurrencyLimitReached,
+        )
+
+        assert 0 == times_pool_limit_reached_in_debug
+        assert 0 == times_dag_concurrency_limit_reached_in_debug
+        assert times_task_concurrency_limit_reached_in_debug > 0
+
     @patch("airflow.jobs.backfill_job_runner.BackfillJobRunner.log")
     def test_backfill_respect_dag_concurrency_limit(self, mock_log, dag_maker):
         dag = self._get_dummy_dag(dag_maker, 
dag_id="test_backfill_respect_concurrency_limit")
diff --git a/tests/jobs/test_scheduler_job.py b/tests/jobs/test_scheduler_job.py
index 4870a2aa13..fafc5ee2b4 100644
--- a/tests/jobs/test_scheduler_job.py
+++ b/tests/jobs/test_scheduler_job.py
@@ -1264,6 +1264,68 @@ class TestSchedulerJob:
 
         session.rollback()
 
+    def 
test_find_executable_task_instances_task_concurrency_per_dagrun_for_first(self, 
dag_maker):
+        scheduler_job = Job()
+        self.job_runner = SchedulerJobRunner(job=scheduler_job, 
subdir=os.devnull)
+        session = settings.Session()
+
+        dag_id = 
"SchedulerJobTest.test_find_executable_task_instances_task_concurrency_per_dagrun_for_first"
+
+        with dag_maker(dag_id=dag_id):
+            op1a = EmptyOperator(task_id="dummy1-a", priority_weight=2, 
max_active_tis_per_dagrun=1)
+            op1b = EmptyOperator(task_id="dummy1-b", priority_weight=1)
+        dr1 = dag_maker.create_dagrun(run_type=DagRunType.SCHEDULED)
+        dr2 = dag_maker.create_dagrun_after(dr1, run_type=DagRunType.SCHEDULED)
+
+        ti1a = dr1.get_task_instance(op1a.task_id, session)
+        ti1b = dr1.get_task_instance(op1b.task_id, session)
+        ti2a = dr2.get_task_instance(op1a.task_id, session)
+        ti1a.state = State.RUNNING
+        ti1b.state = State.SCHEDULED
+        ti2a.state = State.SCHEDULED
+        session.flush()
+
+        # Schedule ti with higher priority,
+        # because it's running in a different DAG run with 0 active tis
+        res = self.job_runner._executable_task_instances_to_queued(max_tis=1, 
session=session)
+        assert 1 == len(res)
+        assert res[0].key == ti2a.key
+
+        session.rollback()
+
+    def 
test_find_executable_task_instances_not_enough_task_concurrency_per_dagrun_for_first(self,
 dag_maker):
+        scheduler_job = Job()
+        self.job_runner = SchedulerJobRunner(job=scheduler_job, 
subdir=os.devnull)
+        session = settings.Session()
+
+        dag_id = (
+            "SchedulerJobTest"
+            
".test_find_executable_task_instances_not_enough_task_concurrency_per_dagrun_for_first"
+        )
+
+        with dag_maker(dag_id=dag_id):
+            op1a = EmptyOperator.partial(
+                task_id="dummy1-a", priority_weight=2, 
max_active_tis_per_dagrun=1
+            ).expand_kwargs([{"inputs": 1}, {"inputs": 2}])
+            op1b = EmptyOperator(task_id="dummy1-b", priority_weight=1)
+        dr = dag_maker.create_dagrun(run_type=DagRunType.SCHEDULED)
+
+        ti1a0 = dr.get_task_instance(op1a.task_id, session, map_index=0)
+        ti1a1 = dr.get_task_instance(op1a.task_id, session, map_index=1)
+        ti1b = dr.get_task_instance(op1b.task_id, session)
+        ti1a0.state = State.RUNNING
+        ti1a1.state = State.SCHEDULED
+        ti1b.state = State.SCHEDULED
+        session.flush()
+
+        # Schedule ti with lower priority,
+        # because the one with higher priority is limited by a concurrency 
limit
+        res = self.job_runner._executable_task_instances_to_queued(max_tis=1, 
session=session)
+        assert 1 == len(res)
+        assert res[0].key == ti1b.key
+
+        session.rollback()
+
     def test_find_executable_task_instances_negative_open_pool_slots(self, 
dag_maker):
         """
         Pools with negative open slots should not block other pools.
@@ -1419,7 +1481,9 @@ class TestSchedulerJob:
         session.flush()
 
         assert State.RUNNING == dr1.state
-        assert 2 == DAG.get_num_task_instances(dag_id, dag.task_ids, 
states=[State.RUNNING], session=session)
+        assert 2 == DAG.get_num_task_instances(
+            dag_id, task_ids=dag.task_ids, states=[State.RUNNING], 
session=session
+        )
 
         # create second dag run
         dr2 = dag_maker.create_dagrun_after(dr1, run_type=DagRunType.SCHEDULED)
@@ -1440,7 +1504,7 @@ class TestSchedulerJob:
         ti3.refresh_from_db()
         ti4.refresh_from_db()
         assert 3 == DAG.get_num_task_instances(
-            dag_id, dag.task_ids, states=[State.RUNNING, State.QUEUED], 
session=session
+            dag_id, task_ids=dag.task_ids, states=[State.RUNNING, 
State.QUEUED], session=session
         )
         assert State.RUNNING == ti1.state
         assert State.RUNNING == ti2.state
@@ -4127,6 +4191,54 @@ class TestSchedulerJob:
             session.refresh(ti)
             assert ti.state == State.SCHEDULED
 
+    @pytest.mark.parametrize(
+        "state,start_date,end_date",
+        [
+            [State.NONE, None, None],
+            [
+                State.UP_FOR_RETRY,
+                timezone.utcnow() - datetime.timedelta(minutes=30),
+                timezone.utcnow() - datetime.timedelta(minutes=15),
+            ],
+            [
+                State.UP_FOR_RESCHEDULE,
+                timezone.utcnow() - datetime.timedelta(minutes=30),
+                timezone.utcnow() - datetime.timedelta(minutes=15),
+            ],
+        ],
+    )
+    def 
test_dag_file_processor_process_task_instances_with_max_active_tis_per_dagrun(
+        self, state, start_date, end_date, dag_maker
+    ):
+        """
+        Test if _process_task_instances puts the right task instances into the
+        mock_list.
+        """
+        with 
dag_maker(dag_id="test_scheduler_process_execute_task_with_max_active_tis_per_dagrun"):
+            BashOperator(task_id="dummy", max_active_tis_per_dagrun=2, 
bash_command="echo Hi")
+
+        scheduler_job = Job()
+        self.job_runner = SchedulerJobRunner(job=scheduler_job, 
subdir=os.devnull)
+
+        self.job_runner.processor_agent = mock.MagicMock()
+
+        dr = dag_maker.create_dagrun(
+            run_type=DagRunType.SCHEDULED,
+        )
+        assert dr is not None
+
+        with create_session() as session:
+            ti = dr.get_task_instances(session=session)[0]
+            ti.state = state
+            ti.start_date = start_date
+            ti.end_date = end_date
+
+            self.job_runner._schedule_dag_run(dr, session)
+            assert 
session.query(TaskInstance).filter_by(state=State.SCHEDULED).count() == 1
+
+            session.refresh(ti)
+            assert ti.state == State.SCHEDULED
+
     @pytest.mark.parametrize(
         "state, start_date, end_date",
         [
diff --git a/tests/models/test_dag.py b/tests/models/test_dag.py
index 571dde87a6..b4ff9a0bed 100644
--- a/tests/models/test_dag.py
+++ b/tests/models/test_dag.py
@@ -461,18 +461,22 @@ class TestDag:
         session.merge(ti4)
         session.commit()
 
-        assert 0 == DAG.get_num_task_instances(test_dag_id, ["fakename"], 
session=session)
-        assert 4 == DAG.get_num_task_instances(test_dag_id, [test_task_id], 
session=session)
-        assert 4 == DAG.get_num_task_instances(test_dag_id, ["fakename", 
test_task_id], session=session)
-        assert 1 == DAG.get_num_task_instances(test_dag_id, [test_task_id], 
states=[None], session=session)
+        assert 0 == DAG.get_num_task_instances(test_dag_id, 
task_ids=["fakename"], session=session)
+        assert 4 == DAG.get_num_task_instances(test_dag_id, 
task_ids=[test_task_id], session=session)
+        assert 4 == DAG.get_num_task_instances(
+            test_dag_id, task_ids=["fakename", test_task_id], session=session
+        )
+        assert 1 == DAG.get_num_task_instances(
+            test_dag_id, task_ids=[test_task_id], states=[None], 
session=session
+        )
         assert 2 == DAG.get_num_task_instances(
-            test_dag_id, [test_task_id], states=[State.RUNNING], 
session=session
+            test_dag_id, task_ids=[test_task_id], states=[State.RUNNING], 
session=session
         )
         assert 3 == DAG.get_num_task_instances(
-            test_dag_id, [test_task_id], states=[None, State.RUNNING], 
session=session
+            test_dag_id, task_ids=[test_task_id], states=[None, 
State.RUNNING], session=session
         )
         assert 4 == DAG.get_num_task_instances(
-            test_dag_id, [test_task_id], states=[None, State.QUEUED, 
State.RUNNING], session=session
+            test_dag_id, task_ids=[test_task_id], states=[None, State.QUEUED, 
State.RUNNING], session=session
         )
         session.close()
 
diff --git a/tests/models/test_taskinstance.py 
b/tests/models/test_taskinstance.py
index 2d86cdfde5..fb91982c80 100644
--- a/tests/models/test_taskinstance.py
+++ b/tests/models/test_taskinstance.py
@@ -288,6 +288,19 @@ class TestTaskInstance:
         ti.run()
         assert ti.state == State.NONE
 
+    def test_requeue_over_max_active_tis_per_dagrun(self, 
create_task_instance):
+        ti = create_task_instance(
+            dag_id="test_requeue_over_max_active_tis_per_dagrun",
+            task_id="test_requeue_over_max_active_tis_per_dagrun_op",
+            max_active_tis_per_dagrun=0,
+            max_active_runs=1,
+            max_active_tasks=2,
+            dagrun_state=State.QUEUED,
+        )
+
+        ti.run()
+        assert ti.state == State.NONE
+
     def test_requeue_over_pool_concurrency(self, create_task_instance, 
test_pool):
         ti = create_task_instance(
             dag_id="test_requeue_over_pool_concurrency",
diff --git a/tests/serialization/test_dag_serialization.py 
b/tests/serialization/test_dag_serialization.py
index 2d9a5fa72e..b025a61ae5 100644
--- a/tests/serialization/test_dag_serialization.py
+++ b/tests/serialization/test_dag_serialization.py
@@ -1189,6 +1189,7 @@ class TestStringifiedDAGs:
             "ignore_first_depends_on_past": True,
             "inlets": [],
             "max_active_tis_per_dag": None,
+            "max_active_tis_per_dagrun": None,
             "max_retry_delay": None,
             "on_execute_callback": None,
             "on_failure_callback": None,
diff --git a/tests/ti_deps/deps/test_task_concurrency.py 
b/tests/ti_deps/deps/test_task_concurrency.py
index 5d208a650c..aa6c8e116c 100644
--- a/tests/ti_deps/deps/test_task_concurrency.py
+++ b/tests/ti_deps/deps/test_task_concurrency.py
@@ -20,6 +20,8 @@ from __future__ import annotations
 from datetime import datetime
 from unittest.mock import Mock
 
+import pytest
+
 from airflow.models import DAG
 from airflow.models.baseoperator import BaseOperator
 from airflow.ti_deps.dep_context import DepContext
@@ -30,24 +32,25 @@ class TestTaskConcurrencyDep:
     def _get_task(self, **kwargs):
         return BaseOperator(task_id="test_task", dag=DAG("test_dag"), **kwargs)
 
-    def test_not_task_concurrency(self):
-        task = self._get_task(start_date=datetime(2016, 1, 1))
-        dep_context = DepContext()
-        ti = Mock(task=task, execution_date=datetime(2016, 1, 1))
-        assert TaskConcurrencyDep().is_met(ti=ti, dep_context=dep_context)
-
-    def test_not_reached_concurrency(self):
-        task = self._get_task(start_date=datetime(2016, 1, 1), 
max_active_tis_per_dag=1)
-        dep_context = DepContext()
-        ti = Mock(task=task, execution_date=datetime(2016, 1, 1))
-        ti.get_num_running_task_instances = lambda x: 0
-        assert TaskConcurrencyDep().is_met(ti=ti, dep_context=dep_context)
-
-    def test_reached_concurrency(self):
-        task = self._get_task(start_date=datetime(2016, 1, 1), 
max_active_tis_per_dag=2)
+    @pytest.mark.parametrize(
+        "kwargs, num_running_tis, is_task_concurrency_dep_met",
+        [
+            ({}, None, True),
+            ({"max_active_tis_per_dag": 1}, 0, True),
+            ({"max_active_tis_per_dag": 2}, 1, True),
+            ({"max_active_tis_per_dag": 2}, 2, False),
+            ({"max_active_tis_per_dagrun": 2}, 1, True),
+            ({"max_active_tis_per_dagrun": 2}, 2, False),
+            ({"max_active_tis_per_dag": 2, "max_active_tis_per_dagrun": 2}, 1, 
True),
+            ({"max_active_tis_per_dag": 1, "max_active_tis_per_dagrun": 2}, 1, 
False),
+            ({"max_active_tis_per_dag": 2, "max_active_tis_per_dagrun": 1}, 1, 
False),
+            ({"max_active_tis_per_dag": 1, "max_active_tis_per_dagrun": 1}, 1, 
False),
+        ],
+    )
+    def test_concurrency(self, kwargs, num_running_tis, 
is_task_concurrency_dep_met):
+        task = self._get_task(start_date=datetime(2016, 1, 1), **kwargs)
         dep_context = DepContext()
         ti = Mock(task=task, execution_date=datetime(2016, 1, 1))
-        ti.get_num_running_task_instances = lambda x: 1
-        assert TaskConcurrencyDep().is_met(ti=ti, dep_context=dep_context)
-        ti.get_num_running_task_instances = lambda x: 2
-        assert not TaskConcurrencyDep().is_met(ti=ti, dep_context=dep_context)
+        if num_running_tis is not None:
+            ti.get_num_running_task_instances.return_value = num_running_tis
+        assert TaskConcurrencyDep().is_met(ti=ti, dep_context=dep_context) == 
is_task_concurrency_dep_met


Reply via email to