This is an automated email from the ASF dual-hosted git repository.
husseinawala pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 1ebeb19bf7 Add `max_active_tis_per_dagrun` for Dynamic Task Mapping
(#29094)
1ebeb19bf7 is described below
commit 1ebeb19bf7542850fff2f1e2f9795ad70c1b24e2
Author: Hussein Awala <[email protected]>
AuthorDate: Fri Apr 14 12:36:08 2023 +0200
Add `max_active_tis_per_dagrun` for Dynamic Task Mapping (#29094)
* add max_active_tis_per_dagrun param to BaseOperator
* set has_task_concurrency_limits when max_active_tis_per_dagrun is not None
* check if max_active_tis_per_dagrun is reached in the task deps
* check if all the tasks have None max_active_tis_per_dagrun before auto
schedule the dagrun
* check if the max_active_tis_per_dagrun is reached before queuing the ti
* check max_active_tis_per_dagrun in backfill job
* fix current tests and ensure everything is ok before adding new tests
* refactor TestTaskConcurrencyDep
* fix a bug in TaskConcurrencyDep
* test max_active_tis_per_dagrun in TaskConcurrencyDep
* tests max_active_tis_per_dagrun in TestTaskInstance
* test dag_file_processor with max_active_tis_per_dagrun
* test scheduling with max_active_tis_per_dagrun on different DAG runs
* test scheduling mapped task with max_active_tis_per_dagrun
* test max_active_tis_per_dagrun with backfill CLI
* add new starved_tasks filter to avoid affecting the scheduling perf
* unify the usage of TaskInstance filters and use TI
* refactor concurrency map type and create a new dataclass
* move docstring to ConcurrencyMap class and create a method for
default_factory
* move concurrency_map creation to ConcurrencyMap class
* replace default dicts by counters
* replace all default dicts by counters in the scheduler_job_runner module
* suggestions from review
---
airflow/jobs/backfill_job_runner.py | 16 +++-
airflow/jobs/scheduler_job_runner.py | 126 ++++++++++++++++++--------
airflow/models/baseoperator.py | 6 ++
airflow/models/dag.py | 12 ++-
airflow/models/dagrun.py | 1 +
airflow/models/mappedoperator.py | 4 +
airflow/models/taskinstance.py | 17 ++--
airflow/ti_deps/deps/task_concurrency_dep.py | 17 +++-
tests/conftest.py | 2 +
tests/jobs/test_backfill_job.py | 83 +++++++++++++++++
tests/jobs/test_scheduler_job.py | 116 +++++++++++++++++++++++-
tests/models/test_dag.py | 18 ++--
tests/models/test_taskinstance.py | 13 +++
tests/serialization/test_dag_serialization.py | 1 +
tests/ti_deps/deps/test_task_concurrency.py | 41 +++++----
15 files changed, 390 insertions(+), 83 deletions(-)
diff --git a/airflow/jobs/backfill_job_runner.py
b/airflow/jobs/backfill_job_runner.py
index c99cae2d21..4a78890d3b 100644
--- a/airflow/jobs/backfill_job_runner.py
+++ b/airflow/jobs/backfill_job_runner.py
@@ -618,7 +618,7 @@ class BackfillJobRunner(BaseJobRunner, LoggingMixin):
"Not scheduling since DAG max_active_tasks
limit is reached."
)
- if task.max_active_tis_per_dag:
+ if task.max_active_tis_per_dag is not None:
num_running_task_instances_in_task =
DAG.get_num_task_instances(
dag_id=self.dag_id,
task_ids=[task.task_id],
@@ -631,6 +631,20 @@ class BackfillJobRunner(BaseJobRunner, LoggingMixin):
"Not scheduling since Task concurrency
limit is reached."
)
+ if task.max_active_tis_per_dagrun is not None:
+ num_running_task_instances_in_task_dagrun =
DAG.get_num_task_instances(
+ dag_id=self.dag_id,
+ run_id=ti.run_id,
+ task_ids=[task.task_id],
+ states=self.STATES_COUNT_AS_RUNNING,
+ session=session,
+ )
+
+ if num_running_task_instances_in_task_dagrun >=
task.max_active_tis_per_dagrun:
+ raise TaskConcurrencyLimitReached(
+ "Not scheduling since Task concurrency per
DAG run limit is reached."
+ )
+
_per_task_process(key, ti, session)
session.commit()
except (NoAvailablePoolSlot, DagConcurrencyLimitReached,
TaskConcurrencyLimitReached) as e:
diff --git a/airflow/jobs/scheduler_job_runner.py
b/airflow/jobs/scheduler_job_runner.py
index 706d980253..aae98373b6 100644
--- a/airflow/jobs/scheduler_job_runner.py
+++ b/airflow/jobs/scheduler_job_runner.py
@@ -25,10 +25,11 @@ import signal
import sys
import time
import warnings
-from collections import defaultdict
+from collections import Counter
+from dataclasses import dataclass
from datetime import datetime, timedelta
from pathlib import Path
-from typing import TYPE_CHECKING, Collection, DefaultDict, Iterator
+from typing import TYPE_CHECKING, Collection, Iterable, Iterator
from sqlalchemy import and_, func, not_, or_, text
from sqlalchemy.exc import OperationalError
@@ -85,6 +86,29 @@ DR = DagRun
DM = DagModel
+@dataclass
+class ConcurrencyMap:
+ """
+ Dataclass to represent concurrency maps
+
+ It contains a map from (dag_id, task_id) to # of task instances, a map
from (dag_id, task_id)
+ to # of task instances in the given state list and a map from (dag_id,
run_id, task_id)
+ to # of task instances in the given state list in each DAG run.
+ """
+
+ dag_active_tasks_map: dict[str, int]
+ task_concurrency_map: dict[tuple[str, str], int]
+ task_dagrun_concurrency_map: dict[tuple[str, str, str], int]
+
+ @classmethod
+ def from_concurrency_map(cls, mapping: dict[tuple[str, str, str], int]) ->
ConcurrencyMap:
+ instance = cls(Counter(), Counter(), Counter(mapping))
+ for (d, r, t), c in mapping.items():
+ instance.dag_active_tasks_map[d] += c
+ instance.task_concurrency_map[(d, t)] += c
+ return instance
+
+
def _is_parent_process() -> bool:
"""
Whether this is a parent process.
@@ -231,28 +255,21 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
< scheduler_health_check_threshold
)
- def __get_concurrency_maps(
- self, states: list[TaskInstanceState], session: Session
- ) -> tuple[DefaultDict[str, int], DefaultDict[tuple[str, str], int]]:
+ def __get_concurrency_maps(self, states: Iterable[TaskInstanceState],
session: Session) -> ConcurrencyMap:
"""
Get the concurrency maps.
:param states: List of states to query for
- :return: A map from (dag_id, task_id) to # of task instances and
- a map from (dag_id, task_id) to # of task instances in the given
state list
+ :return: Concurrency map
"""
- ti_concurrency_query: list[tuple[str, str, int]] = (
- session.query(TI.task_id, TI.dag_id, func.count("*"))
+ ti_concurrency_query: list[tuple[str, str, str, int]] = (
+ session.query(TI.task_id, TI.run_id, TI.dag_id, func.count("*"))
.filter(TI.state.in_(states))
- .group_by(TI.task_id, TI.dag_id)
- ).all()
- dag_map: DefaultDict[str, int] = defaultdict(int)
- task_map: DefaultDict[tuple[str, str], int] = defaultdict(int)
- for result in ti_concurrency_query:
- task_id, dag_id, count = result
- dag_map[dag_id] += count
- task_map[(dag_id, task_id)] = count
- return dag_map, task_map
+ .group_by(TI.task_id, TI.run_id, TI.dag_id)
+ )
+ return ConcurrencyMap.from_concurrency_map(
+ {(dag_id, run_id, task_id): count for task_id, run_id, dag_id,
count in ti_concurrency_query}
+ )
def _executable_task_instances_to_queued(self, max_tis: int, session:
Session) -> list[TI]:
"""
@@ -263,6 +280,8 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
- DAG max_active_tasks
- executor state
- priority
+ - max active tis per DAG
+ - max active tis per DAG run
:param max_tis: Maximum number of TIs to queue in this loop.
:return: list[airflow.models.TaskInstance]
@@ -304,11 +323,7 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
starved_pools = {pool_name for pool_name, stats in pools.items() if
stats["open"] <= 0}
# dag_id to # of running tasks and (dag_id, task_id) to # of running
tasks.
- dag_active_tasks_map: DefaultDict[str, int]
- task_concurrency_map: DefaultDict[tuple[str, str], int]
- dag_active_tasks_map, task_concurrency_map =
self.__get_concurrency_maps(
- states=list(EXECUTION_STATES), session=session
- )
+ concurrency_map = self.__get_concurrency_maps(states=EXECUTION_STATES,
session=session)
# Number of tasks that cannot be scheduled because of no open slot in
pool
num_starving_tasks_total = 0
@@ -316,14 +331,16 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
# dag and task ids that can't be queued because of concurrency limits
starved_dags: set[str] = set()
starved_tasks: set[tuple[str, str]] = set()
+ starved_tasks_task_dagrun_concurrency: set[tuple[str, str, str]] =
set()
- pool_num_starving_tasks: DefaultDict[str, int] = defaultdict(int)
+ pool_num_starving_tasks: dict[str, int] = Counter()
for loop_count in itertools.count(start=1):
num_starved_pools = len(starved_pools)
num_starved_dags = len(starved_dags)
num_starved_tasks = len(starved_tasks)
+ num_starved_tasks_task_dagrun_concurrency =
len(starved_tasks_task_dagrun_concurrency)
# Get task instances associated with scheduled
# DagRuns which are not backfilled, in the given states,
@@ -347,7 +364,14 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
query = query.filter(not_(TI.dag_id.in_(starved_dags)))
if starved_tasks:
- task_filter = tuple_in_condition((TaskInstance.dag_id,
TaskInstance.task_id), starved_tasks)
+ task_filter = tuple_in_condition((TI.dag_id, TI.task_id),
starved_tasks)
+ query = query.filter(not_(task_filter))
+
+ if starved_tasks_task_dagrun_concurrency:
+ task_filter = tuple_in_condition(
+ (TI.dag_id, TI.run_id, TI.task_id),
+ starved_tasks_task_dagrun_concurrency,
+ )
query = query.filter(not_(task_filter))
query = query.limit(max_tis)
@@ -439,7 +463,7 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
# reached.
dag_id = task_instance.dag_id
- current_active_tasks_per_dag = dag_active_tasks_map[dag_id]
+ current_active_tasks_per_dag =
concurrency_map.dag_active_tasks_map[dag_id]
max_active_tasks_per_dag_limit =
task_instance.dag_model.max_active_tasks
self.log.info(
"DAG %s has %s/%s running and queued tasks",
@@ -481,7 +505,7 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
).max_active_tis_per_dag
if task_concurrency_limit is not None:
- current_task_concurrency = task_concurrency_map[
+ current_task_concurrency =
concurrency_map.task_concurrency_map[
(task_instance.dag_id, task_instance.task_id)
]
@@ -494,10 +518,35 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
starved_tasks.add((task_instance.dag_id,
task_instance.task_id))
continue
+ task_dagrun_concurrency_limit: int | None = None
+ if serialized_dag.has_task(task_instance.task_id):
+ task_dagrun_concurrency_limit =
serialized_dag.get_task(
+ task_instance.task_id
+ ).max_active_tis_per_dagrun
+
+ if task_dagrun_concurrency_limit is not None:
+ current_task_dagrun_concurrency =
concurrency_map.task_dagrun_concurrency_map[
+ (task_instance.dag_id, task_instance.run_id,
task_instance.task_id)
+ ]
+
+ if current_task_dagrun_concurrency >=
task_dagrun_concurrency_limit:
+ self.log.info(
+ "Not executing %s since the task concurrency
per DAG run for"
+ " this task has been reached.",
+ task_instance,
+ )
+ starved_tasks_task_dagrun_concurrency.add(
+ (task_instance.dag_id, task_instance.run_id,
task_instance.task_id)
+ )
+ continue
+
executable_tis.append(task_instance)
open_slots -= task_instance.pool_slots
- dag_active_tasks_map[dag_id] += 1
- task_concurrency_map[(task_instance.dag_id,
task_instance.task_id)] += 1
+ concurrency_map.dag_active_tasks_map[dag_id] += 1
+ concurrency_map.task_concurrency_map[(task_instance.dag_id,
task_instance.task_id)] += 1
+ concurrency_map.task_dagrun_concurrency_map[
+ (task_instance.dag_id, task_instance.run_id,
task_instance.task_id)
+ ] += 1
pool_stats["open"] = open_slots
@@ -507,6 +556,7 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
len(starved_pools) > num_starved_pools
or len(starved_dags) > num_starved_dags
or len(starved_tasks) > num_starved_tasks
+ or len(starved_tasks_task_dagrun_concurrency) >
num_starved_tasks_task_dagrun_concurrency
)
if is_done or not found_new_filters:
@@ -816,13 +866,13 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
paused_runs = (
session.query(DagRun)
.join(DagRun.dag_model)
- .join(TaskInstance)
+ .join(TI)
.filter(
DagModel.is_paused == expression.true(),
DagRun.state == DagRunState.RUNNING,
DagRun.run_type != DagRunType.BACKFILL_JOB,
)
- .having(DagRun.last_scheduling_decision <=
func.max(TaskInstance.updated_at))
+ .having(DagRun.last_scheduling_decision <=
func.max(TI.updated_at))
.group_by(DagRun)
)
for dag_run in paused_runs:
@@ -1079,8 +1129,7 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
.all()
)
- active_runs_of_dags = defaultdict(
- int,
+ active_runs_of_dags = Counter(
DagRun.active_runs_of_dags(dag_ids=(dm.dag_id for dm in
dag_models), session=session),
)
@@ -1237,8 +1286,7 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
"""Find DagRuns in queued state and decide moving them to running
state."""
dag_runs = self._get_next_dagruns_to_examine(DagRunState.QUEUED,
session)
- active_runs_of_dags = defaultdict(
- int,
+ active_runs_of_dags = Counter(
DagRun.active_runs_of_dags((dr.dag_id for dr in dag_runs),
only_running=True, session=session),
)
@@ -1533,10 +1581,10 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
or execution timeout has passed, so they can be marked as failed.
"""
num_timed_out_tasks = (
- session.query(TaskInstance)
+ session.query(TI)
.filter(
- TaskInstance.state == TaskInstanceState.DEFERRED,
- TaskInstance.trigger_timeout < timezone.utcnow(),
+ TI.state == TaskInstanceState.DEFERRED,
+ TI.trigger_timeout < timezone.utcnow(),
)
.update(
# We have to schedule these to fail themselves so it doesn't
@@ -1599,7 +1647,7 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
)
@staticmethod
- def _generate_zombie_message_details(ti: TaskInstance):
+ def _generate_zombie_message_details(ti: TI):
zombie_message_details = {
"DAG Id": ti.dag_id,
"Task Id": ti.task_id,
diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py
index 1760bd1f42..d81d45dc7e 100644
--- a/airflow/models/baseoperator.py
+++ b/airflow/models/baseoperator.py
@@ -214,6 +214,7 @@ def partial(
weight_rule: str = DEFAULT_WEIGHT_RULE,
sla: timedelta | None = None,
max_active_tis_per_dag: int | None = None,
+ max_active_tis_per_dagrun: int | None = None,
on_execute_callback: None | TaskStateChangeCallback |
list[TaskStateChangeCallback] = None,
on_failure_callback: None | TaskStateChangeCallback |
list[TaskStateChangeCallback] = None,
on_success_callback: None | TaskStateChangeCallback |
list[TaskStateChangeCallback] = None,
@@ -274,6 +275,7 @@ def partial(
partial_kwargs.setdefault("weight_rule", weight_rule)
partial_kwargs.setdefault("sla", sla)
partial_kwargs.setdefault("max_active_tis_per_dag", max_active_tis_per_dag)
+ partial_kwargs.setdefault("max_active_tis_per_dagrun",
max_active_tis_per_dagrun)
partial_kwargs.setdefault("on_execute_callback", on_execute_callback)
partial_kwargs.setdefault("on_failure_callback", on_failure_callback)
partial_kwargs.setdefault("on_retry_callback", on_retry_callback)
@@ -578,6 +580,8 @@ class BaseOperator(AbstractOperator,
metaclass=BaseOperatorMeta):
:param run_as_user: unix username to impersonate while running the task
:param max_active_tis_per_dag: When set, a task will be able to limit the
concurrent
runs across execution_dates.
+ :param max_active_tis_per_dagrun: When set, a task will be able to limit
the concurrent
+ task instances per DAG run.
:param executor_config: Additional task-level configuration parameters
that are
interpreted by a specific executor. Parameters are namespaced by the
name of
executor.
@@ -729,6 +733,7 @@ class BaseOperator(AbstractOperator,
metaclass=BaseOperatorMeta):
run_as_user: str | None = None,
task_concurrency: int | None = None,
max_active_tis_per_dag: int | None = None,
+ max_active_tis_per_dagrun: int | None = None,
executor_config: dict | None = None,
do_xcom_push: bool = True,
inlets: Any | None = None,
@@ -872,6 +877,7 @@ class BaseOperator(AbstractOperator,
metaclass=BaseOperatorMeta):
)
max_active_tis_per_dag = task_concurrency
self.max_active_tis_per_dag: int | None = max_active_tis_per_dag
+ self.max_active_tis_per_dagrun: int | None = max_active_tis_per_dagrun
self.do_xcom_push = do_xcom_push
self.doc_md = doc_md
diff --git a/airflow/models/dag.py b/airflow/models/dag.py
index 8b259cd3e8..9888e5cbd8 100644
--- a/airflow/models/dag.py
+++ b/airflow/models/dag.py
@@ -2789,7 +2789,10 @@ class DAG(LoggingMixin):
orm_dag.description = dag.description
orm_dag.max_active_tasks = dag.max_active_tasks
orm_dag.max_active_runs = dag.max_active_runs
- orm_dag.has_task_concurrency_limits = any(t.max_active_tis_per_dag
is not None for t in dag.tasks)
+ orm_dag.has_task_concurrency_limits = any(
+ t.max_active_tis_per_dag is not None or
t.max_active_tis_per_dagrun is not None
+ for t in dag.tasks
+ )
orm_dag.schedule_interval = dag.schedule_interval
orm_dag.timetable_description = dag.timetable.description
orm_dag.processor_subdir = processor_subdir
@@ -2990,12 +2993,13 @@ class DAG(LoggingMixin):
@staticmethod
@provide_session
- def get_num_task_instances(dag_id, task_ids=None, states=None,
session=NEW_SESSION) -> int:
+ def get_num_task_instances(dag_id, run_id=None, task_ids=None,
states=None, session=NEW_SESSION) -> int:
"""
Returns the number of task instances in the given DAG.
:param session: ORM session
:param dag_id: ID of the DAG to get the task concurrency of
+ :param run_id: ID of the DAG run to get the task concurrency of
:param task_ids: A list of valid task IDs for the given DAG
:param states: A list of states to filter by if supplied
:return: The number of running tasks
@@ -3003,6 +3007,10 @@ class DAG(LoggingMixin):
qry = session.query(func.count(TaskInstance.task_id)).filter(
TaskInstance.dag_id == dag_id,
)
+ if run_id:
+ qry = qry.filter(
+ TaskInstance.run_id == run_id,
+ )
if task_ids:
qry = qry.filter(
TaskInstance.task_id.in_(task_ids),
diff --git a/airflow/models/dagrun.py b/airflow/models/dagrun.py
index 101bda0a47..edb0ec78ac 100644
--- a/airflow/models/dagrun.py
+++ b/airflow/models/dagrun.py
@@ -553,6 +553,7 @@ class DagRun(Base, LoggingMixin):
bool(self.tis)
and all(not t.task.depends_on_past for t in self.tis)
and all(t.task.max_active_tis_per_dag is None for t in
self.tis)
+ and all(t.task.max_active_tis_per_dagrun is None for t in
self.tis)
and all(t.state != TaskInstanceState.DEFERRED for t in
self.tis)
)
diff --git a/airflow/models/mappedoperator.py b/airflow/models/mappedoperator.py
index a10fd10cdc..345329ef6d 100644
--- a/airflow/models/mappedoperator.py
+++ b/airflow/models/mappedoperator.py
@@ -451,6 +451,10 @@ class MappedOperator(AbstractOperator):
def max_active_tis_per_dag(self) -> int | None:
return self.partial_kwargs.get("max_active_tis_per_dag")
+ @property
+ def max_active_tis_per_dagrun(self) -> int | None:
+ return self.partial_kwargs.get("max_active_tis_per_dagrun")
+
@property
def resources(self) -> Resources | None:
return self.partial_kwargs.get("resources")
diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index c2ed7b7868..b02076c076 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -2479,18 +2479,17 @@ class TaskInstance(Base, LoggingMixin):
return LazyXComAccess.build_from_xcom_query(query)
@provide_session
- def get_num_running_task_instances(self, session: Session) -> int:
+ def get_num_running_task_instances(self, session: Session,
same_dagrun=False) -> int:
"""Return Number of running TIs from the DB"""
# .count() is inefficient
- return (
- session.query(func.count())
- .filter(
- TaskInstance.dag_id == self.dag_id,
- TaskInstance.task_id == self.task_id,
- TaskInstance.state == State.RUNNING,
- )
- .scalar()
+ num_running_task_instances_query = session.query(func.count()).filter(
+ TaskInstance.dag_id == self.dag_id,
+ TaskInstance.task_id == self.task_id,
+ TaskInstance.state == State.RUNNING,
)
+ if same_dagrun:
+ num_running_task_instances_query.filter(TaskInstance.run_id ==
self.run_id)
+ return num_running_task_instances_query.scalar()
def init_run_context(self, raw: bool = False) -> None:
"""Sets the log context."""
diff --git a/airflow/ti_deps/deps/task_concurrency_dep.py
b/airflow/ti_deps/deps/task_concurrency_dep.py
index 5b5f4f515a..1f1416214c 100644
--- a/airflow/ti_deps/deps/task_concurrency_dep.py
+++ b/airflow/ti_deps/deps/task_concurrency_dep.py
@@ -30,13 +30,22 @@ class TaskConcurrencyDep(BaseTIDep):
@provide_session
def _get_dep_statuses(self, ti, session, dep_context):
- if ti.task.max_active_tis_per_dag is None:
+ if ti.task.max_active_tis_per_dag is None and
ti.task.max_active_tis_per_dagrun is None:
yield self._passing_status(reason="Task concurrency is not set.")
return
- if ti.get_num_running_task_instances(session) >=
ti.task.max_active_tis_per_dag:
+ if (
+ ti.task.max_active_tis_per_dag is not None
+ and ti.get_num_running_task_instances(session) >=
ti.task.max_active_tis_per_dag
+ ):
yield self._failing_status(reason="The max task concurrency has
been reached.")
return
- else:
- yield self._passing_status(reason="The max task concurrency has
not been reached.")
+ if (
+ ti.task.max_active_tis_per_dagrun is not None
+ and ti.get_num_running_task_instances(session, same_dagrun=True)
+ >= ti.task.max_active_tis_per_dagrun
+ ):
+ yield self._failing_status(reason="The max task concurrency per
run has been reached.")
return
+ yield self._passing_status(reason="The max task concurrency has not
been reached.")
+ return
diff --git a/tests/conftest.py b/tests/conftest.py
index cfae846895..f169e2b62e 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -726,6 +726,7 @@ def create_dummy_dag(dag_maker):
dag_id="dag",
task_id="op1",
max_active_tis_per_dag=16,
+ max_active_tis_per_dagrun=None,
pool="default_pool",
executor_config={},
trigger_rule="all_done",
@@ -741,6 +742,7 @@ def create_dummy_dag(dag_maker):
op = EmptyOperator(
task_id=task_id,
max_active_tis_per_dag=max_active_tis_per_dag,
+ max_active_tis_per_dagrun=max_active_tis_per_dagrun,
executor_config=executor_config,
on_success_callback=on_success_callback,
on_execute_callback=on_execute_callback,
diff --git a/tests/jobs/test_backfill_job.py b/tests/jobs/test_backfill_job.py
index 7001f6f1b9..164d5af1df 100644
--- a/tests/jobs/test_backfill_job.py
+++ b/tests/jobs/test_backfill_job.py
@@ -21,6 +21,7 @@ import datetime
import json
import logging
import threading
+from collections import defaultdict
from unittest import mock
from unittest.mock import patch
@@ -376,6 +377,88 @@ class TestBackfillJob:
assert 0 == times_dag_concurrency_limit_reached_in_debug
assert times_task_concurrency_limit_reached_in_debug > 0
+ @pytest.mark.parametrize("with_max_active_tis_per_dag", [False, True])
+ @patch("airflow.jobs.backfill_job_runner.BackfillJobRunner.log")
+ def test_backfill_respect_max_active_tis_per_dagrun_limit(
+ self, mock_log, dag_maker, with_max_active_tis_per_dag
+ ):
+ max_active_tis_per_dag = 3
+ max_active_tis_per_dagrun = 2
+ kwargs = {"max_active_tis_per_dagrun": max_active_tis_per_dagrun}
+ if with_max_active_tis_per_dag:
+ kwargs["max_active_tis_per_dag"] = max_active_tis_per_dag
+
+ with
dag_maker(dag_id="test_backfill_respect_max_active_tis_per_dag_limit",
schedule="@daily") as dag:
+ EmptyOperator.partial(task_id="task1",
**kwargs).expand_kwargs([{"x": i} for i in range(10)])
+
+ dag_maker.create_dagrun(state=None)
+
+ executor = MockExecutor()
+
+ job = Job(executor=executor)
+ job_runner = BackfillJobRunner(
+ job=job,
+ dag=dag,
+ start_date=DEFAULT_DATE,
+ end_date=DEFAULT_DATE + datetime.timedelta(days=7),
+ )
+
+ run_job(job=job, execute_callable=job_runner._execute)
+
+ assert len(executor.history) > 0
+
+ task_concurrency_limit_reached_at_least_once = False
+
+ def get_running_tis_per_dagrun(running_tis):
+ running_tis_per_dagrun_dict = defaultdict(int)
+ for running_ti in running_tis:
+ running_tis_per_dagrun_dict[running_ti[3].dag_run.id] += 1
+ return running_tis_per_dagrun_dict
+
+ num_running_task_instances = 0
+ for running_task_instances in executor.history:
+ if with_max_active_tis_per_dag:
+ assert len(running_task_instances) <= max_active_tis_per_dag
+ running_tis_per_dagrun_dict =
get_running_tis_per_dagrun(running_task_instances)
+ assert all(
+ [
+ num_running_tis <= max_active_tis_per_dagrun
+ for num_running_tis in running_tis_per_dagrun_dict.values()
+ ]
+ )
+ num_running_task_instances += len(running_task_instances)
+ task_concurrency_limit_reached_at_least_once = (
+ task_concurrency_limit_reached_at_least_once
+ or any(
+ [
+ num_running_tis == max_active_tis_per_dagrun
+ for num_running_tis in
running_tis_per_dagrun_dict.values()
+ ]
+ )
+ )
+
+ assert 80 == num_running_task_instances # (7 backfill run + 1 manual
run ) * 10 mapped task per run
+ assert task_concurrency_limit_reached_at_least_once
+
+ times_dag_concurrency_limit_reached_in_debug = self._times_called_with(
+ mock_log.debug,
+ DagConcurrencyLimitReached,
+ )
+
+ times_pool_limit_reached_in_debug = self._times_called_with(
+ mock_log.debug,
+ NoAvailablePoolSlot,
+ )
+
+ times_task_concurrency_limit_reached_in_debug =
self._times_called_with(
+ mock_log.debug,
+ TaskConcurrencyLimitReached,
+ )
+
+ assert 0 == times_pool_limit_reached_in_debug
+ assert 0 == times_dag_concurrency_limit_reached_in_debug
+ assert times_task_concurrency_limit_reached_in_debug > 0
+
@patch("airflow.jobs.backfill_job_runner.BackfillJobRunner.log")
def test_backfill_respect_dag_concurrency_limit(self, mock_log, dag_maker):
dag = self._get_dummy_dag(dag_maker,
dag_id="test_backfill_respect_concurrency_limit")
diff --git a/tests/jobs/test_scheduler_job.py b/tests/jobs/test_scheduler_job.py
index 4870a2aa13..fafc5ee2b4 100644
--- a/tests/jobs/test_scheduler_job.py
+++ b/tests/jobs/test_scheduler_job.py
@@ -1264,6 +1264,68 @@ class TestSchedulerJob:
session.rollback()
+ def
test_find_executable_task_instances_task_concurrency_per_dagrun_for_first(self,
dag_maker):
+ scheduler_job = Job()
+ self.job_runner = SchedulerJobRunner(job=scheduler_job,
subdir=os.devnull)
+ session = settings.Session()
+
+ dag_id =
"SchedulerJobTest.test_find_executable_task_instances_task_concurrency_per_dagrun_for_first"
+
+ with dag_maker(dag_id=dag_id):
+ op1a = EmptyOperator(task_id="dummy1-a", priority_weight=2,
max_active_tis_per_dagrun=1)
+ op1b = EmptyOperator(task_id="dummy1-b", priority_weight=1)
+ dr1 = dag_maker.create_dagrun(run_type=DagRunType.SCHEDULED)
+ dr2 = dag_maker.create_dagrun_after(dr1, run_type=DagRunType.SCHEDULED)
+
+ ti1a = dr1.get_task_instance(op1a.task_id, session)
+ ti1b = dr1.get_task_instance(op1b.task_id, session)
+ ti2a = dr2.get_task_instance(op1a.task_id, session)
+ ti1a.state = State.RUNNING
+ ti1b.state = State.SCHEDULED
+ ti2a.state = State.SCHEDULED
+ session.flush()
+
+ # Schedule ti with higher priority,
+ # because it's running in a different DAG run with 0 active tis
+ res = self.job_runner._executable_task_instances_to_queued(max_tis=1,
session=session)
+ assert 1 == len(res)
+ assert res[0].key == ti2a.key
+
+ session.rollback()
+
+ def
test_find_executable_task_instances_not_enough_task_concurrency_per_dagrun_for_first(self,
dag_maker):
+ scheduler_job = Job()
+ self.job_runner = SchedulerJobRunner(job=scheduler_job,
subdir=os.devnull)
+ session = settings.Session()
+
+ dag_id = (
+ "SchedulerJobTest"
+
".test_find_executable_task_instances_not_enough_task_concurrency_per_dagrun_for_first"
+ )
+
+ with dag_maker(dag_id=dag_id):
+ op1a = EmptyOperator.partial(
+ task_id="dummy1-a", priority_weight=2,
max_active_tis_per_dagrun=1
+ ).expand_kwargs([{"inputs": 1}, {"inputs": 2}])
+ op1b = EmptyOperator(task_id="dummy1-b", priority_weight=1)
+ dr = dag_maker.create_dagrun(run_type=DagRunType.SCHEDULED)
+
+ ti1a0 = dr.get_task_instance(op1a.task_id, session, map_index=0)
+ ti1a1 = dr.get_task_instance(op1a.task_id, session, map_index=1)
+ ti1b = dr.get_task_instance(op1b.task_id, session)
+ ti1a0.state = State.RUNNING
+ ti1a1.state = State.SCHEDULED
+ ti1b.state = State.SCHEDULED
+ session.flush()
+
+ # Schedule ti with lower priority,
+ # because the one with higher priority is limited by a concurrency
limit
+ res = self.job_runner._executable_task_instances_to_queued(max_tis=1,
session=session)
+ assert 1 == len(res)
+ assert res[0].key == ti1b.key
+
+ session.rollback()
+
def test_find_executable_task_instances_negative_open_pool_slots(self,
dag_maker):
"""
Pools with negative open slots should not block other pools.
@@ -1419,7 +1481,9 @@ class TestSchedulerJob:
session.flush()
assert State.RUNNING == dr1.state
- assert 2 == DAG.get_num_task_instances(dag_id, dag.task_ids,
states=[State.RUNNING], session=session)
+ assert 2 == DAG.get_num_task_instances(
+ dag_id, task_ids=dag.task_ids, states=[State.RUNNING],
session=session
+ )
# create second dag run
dr2 = dag_maker.create_dagrun_after(dr1, run_type=DagRunType.SCHEDULED)
@@ -1440,7 +1504,7 @@ class TestSchedulerJob:
ti3.refresh_from_db()
ti4.refresh_from_db()
assert 3 == DAG.get_num_task_instances(
- dag_id, dag.task_ids, states=[State.RUNNING, State.QUEUED],
session=session
+ dag_id, task_ids=dag.task_ids, states=[State.RUNNING,
State.QUEUED], session=session
)
assert State.RUNNING == ti1.state
assert State.RUNNING == ti2.state
@@ -4127,6 +4191,54 @@ class TestSchedulerJob:
session.refresh(ti)
assert ti.state == State.SCHEDULED
+ @pytest.mark.parametrize(
+ "state,start_date,end_date",
+ [
+ [State.NONE, None, None],
+ [
+ State.UP_FOR_RETRY,
+ timezone.utcnow() - datetime.timedelta(minutes=30),
+ timezone.utcnow() - datetime.timedelta(minutes=15),
+ ],
+ [
+ State.UP_FOR_RESCHEDULE,
+ timezone.utcnow() - datetime.timedelta(minutes=30),
+ timezone.utcnow() - datetime.timedelta(minutes=15),
+ ],
+ ],
+ )
+ def
test_dag_file_processor_process_task_instances_with_max_active_tis_per_dagrun(
+ self, state, start_date, end_date, dag_maker
+ ):
+ """
+ Test if _process_task_instances puts the right task instances into the
+ mock_list.
+ """
+ with
dag_maker(dag_id="test_scheduler_process_execute_task_with_max_active_tis_per_dagrun"):
+ BashOperator(task_id="dummy", max_active_tis_per_dagrun=2,
bash_command="echo Hi")
+
+ scheduler_job = Job()
+ self.job_runner = SchedulerJobRunner(job=scheduler_job,
subdir=os.devnull)
+
+ self.job_runner.processor_agent = mock.MagicMock()
+
+ dr = dag_maker.create_dagrun(
+ run_type=DagRunType.SCHEDULED,
+ )
+ assert dr is not None
+
+ with create_session() as session:
+ ti = dr.get_task_instances(session=session)[0]
+ ti.state = state
+ ti.start_date = start_date
+ ti.end_date = end_date
+
+ self.job_runner._schedule_dag_run(dr, session)
+ assert
session.query(TaskInstance).filter_by(state=State.SCHEDULED).count() == 1
+
+ session.refresh(ti)
+ assert ti.state == State.SCHEDULED
+
@pytest.mark.parametrize(
"state, start_date, end_date",
[
diff --git a/tests/models/test_dag.py b/tests/models/test_dag.py
index 571dde87a6..b4ff9a0bed 100644
--- a/tests/models/test_dag.py
+++ b/tests/models/test_dag.py
@@ -461,18 +461,22 @@ class TestDag:
session.merge(ti4)
session.commit()
- assert 0 == DAG.get_num_task_instances(test_dag_id, ["fakename"],
session=session)
- assert 4 == DAG.get_num_task_instances(test_dag_id, [test_task_id],
session=session)
- assert 4 == DAG.get_num_task_instances(test_dag_id, ["fakename",
test_task_id], session=session)
- assert 1 == DAG.get_num_task_instances(test_dag_id, [test_task_id],
states=[None], session=session)
+ assert 0 == DAG.get_num_task_instances(test_dag_id,
task_ids=["fakename"], session=session)
+ assert 4 == DAG.get_num_task_instances(test_dag_id,
task_ids=[test_task_id], session=session)
+ assert 4 == DAG.get_num_task_instances(
+ test_dag_id, task_ids=["fakename", test_task_id], session=session
+ )
+ assert 1 == DAG.get_num_task_instances(
+ test_dag_id, task_ids=[test_task_id], states=[None],
session=session
+ )
assert 2 == DAG.get_num_task_instances(
- test_dag_id, [test_task_id], states=[State.RUNNING],
session=session
+ test_dag_id, task_ids=[test_task_id], states=[State.RUNNING],
session=session
)
assert 3 == DAG.get_num_task_instances(
- test_dag_id, [test_task_id], states=[None, State.RUNNING],
session=session
+ test_dag_id, task_ids=[test_task_id], states=[None,
State.RUNNING], session=session
)
assert 4 == DAG.get_num_task_instances(
- test_dag_id, [test_task_id], states=[None, State.QUEUED,
State.RUNNING], session=session
+ test_dag_id, task_ids=[test_task_id], states=[None, State.QUEUED,
State.RUNNING], session=session
)
session.close()
diff --git a/tests/models/test_taskinstance.py
b/tests/models/test_taskinstance.py
index 2d86cdfde5..fb91982c80 100644
--- a/tests/models/test_taskinstance.py
+++ b/tests/models/test_taskinstance.py
@@ -288,6 +288,19 @@ class TestTaskInstance:
ti.run()
assert ti.state == State.NONE
+    def test_requeue_over_max_active_tis_per_dagrun(self, create_task_instance):
+        ti = create_task_instance(
+            dag_id="test_requeue_over_max_active_tis_per_dagrun",
+            task_id="test_requeue_over_max_active_tis_per_dagrun_op",
+            max_active_tis_per_dagrun=0,
+            max_active_runs=1,
+            max_active_tasks=2,
+            dagrun_state=State.QUEUED,
+        )
+
+        ti.run()
+        assert ti.state == State.NONE
+
     def test_requeue_over_pool_concurrency(self, create_task_instance, test_pool):
         ti = create_task_instance(
             dag_id="test_requeue_over_pool_concurrency",
diff --git a/tests/serialization/test_dag_serialization.py b/tests/serialization/test_dag_serialization.py
index 2d9a5fa72e..b025a61ae5 100644
--- a/tests/serialization/test_dag_serialization.py
+++ b/tests/serialization/test_dag_serialization.py
@@ -1189,6 +1189,7 @@ class TestStringifiedDAGs:
"ignore_first_depends_on_past": True,
"inlets": [],
"max_active_tis_per_dag": None,
+ "max_active_tis_per_dagrun": None,
"max_retry_delay": None,
"on_execute_callback": None,
"on_failure_callback": None,
diff --git a/tests/ti_deps/deps/test_task_concurrency.py b/tests/ti_deps/deps/test_task_concurrency.py
index 5d208a650c..aa6c8e116c 100644
--- a/tests/ti_deps/deps/test_task_concurrency.py
+++ b/tests/ti_deps/deps/test_task_concurrency.py
@@ -20,6 +20,8 @@ from __future__ import annotations
from datetime import datetime
from unittest.mock import Mock
+import pytest
+
from airflow.models import DAG
from airflow.models.baseoperator import BaseOperator
from airflow.ti_deps.dep_context import DepContext
@@ -30,24 +32,25 @@ class TestTaskConcurrencyDep:
def _get_task(self, **kwargs):
return BaseOperator(task_id="test_task", dag=DAG("test_dag"), **kwargs)
-    def test_not_task_concurrency(self):
-        task = self._get_task(start_date=datetime(2016, 1, 1))
-        dep_context = DepContext()
-        ti = Mock(task=task, execution_date=datetime(2016, 1, 1))
-        assert TaskConcurrencyDep().is_met(ti=ti, dep_context=dep_context)
-
-    def test_not_reached_concurrency(self):
-        task = self._get_task(start_date=datetime(2016, 1, 1), max_active_tis_per_dag=1)
-        dep_context = DepContext()
-        ti = Mock(task=task, execution_date=datetime(2016, 1, 1))
-        ti.get_num_running_task_instances = lambda x: 0
-        assert TaskConcurrencyDep().is_met(ti=ti, dep_context=dep_context)
-
-    def test_reached_concurrency(self):
-        task = self._get_task(start_date=datetime(2016, 1, 1), max_active_tis_per_dag=2)
+    @pytest.mark.parametrize(
+        "kwargs, num_running_tis, is_task_concurrency_dep_met",
+        [
+            ({}, None, True),
+            ({"max_active_tis_per_dag": 1}, 0, True),
+            ({"max_active_tis_per_dag": 2}, 1, True),
+            ({"max_active_tis_per_dag": 2}, 2, False),
+            ({"max_active_tis_per_dagrun": 2}, 1, True),
+            ({"max_active_tis_per_dagrun": 2}, 2, False),
+            ({"max_active_tis_per_dag": 2, "max_active_tis_per_dagrun": 2}, 1, True),
+            ({"max_active_tis_per_dag": 1, "max_active_tis_per_dagrun": 2}, 1, False),
+            ({"max_active_tis_per_dag": 2, "max_active_tis_per_dagrun": 1}, 1, False),
+            ({"max_active_tis_per_dag": 1, "max_active_tis_per_dagrun": 1}, 1, False),
+        ],
+    )
+    def test_concurrency(self, kwargs, num_running_tis, is_task_concurrency_dep_met):
+        task = self._get_task(start_date=datetime(2016, 1, 1), **kwargs)
         dep_context = DepContext()
         ti = Mock(task=task, execution_date=datetime(2016, 1, 1))
-        ti.get_num_running_task_instances = lambda x: 1
-        assert TaskConcurrencyDep().is_met(ti=ti, dep_context=dep_context)
-        ti.get_num_running_task_instances = lambda x: 2
-        assert not TaskConcurrencyDep().is_met(ti=ti, dep_context=dep_context)
+        if num_running_tis is not None:
+            ti.get_num_running_task_instances.return_value = num_running_tis
+        assert TaskConcurrencyDep().is_met(ti=ti, dep_context=dep_context) == is_task_concurrency_dep_met