kaxil commented on code in PR #54103:
URL: https://github.com/apache/airflow/pull/54103#discussion_r2705400293
##########
airflow-core/src/airflow/jobs/scheduler_job_runner.py:
##########
@@ -490,10 +515,60 @@ def _executable_task_instances_to_queued(self, max_tis:
int, session: Session) -
.where(~DM.is_paused)
.where(TI.state == TaskInstanceState.SCHEDULED)
.where(DM.bundle_name.is_not(None))
+ .join(
+ dr_task_concurrency_subquery,
+ and_(
+ TI.dag_id == dr_task_concurrency_subquery.c.dag_id,
+ TI.run_id == dr_task_concurrency_subquery.c.run_id,
+ ),
+ isouter=True,
+ )
+ .where(
+
func.coalesce(dr_task_concurrency_subquery.c.task_per_dr_count, 0) <
DM.max_active_tasks
+ )
.options(selectinload(TI.dag_model))
.order_by(-TI.priority_weight, DR.logical_date, TI.map_index)
)
+ # Create a subquery with row numbers partitioned by dag_id and
run_id.
+ # Different dags can have the same run_id but
+ # the dag_id combined with the run_id uniquely identify a run.
+ ranked_query = (
+ query.add_columns(
+ func.row_number()
+ .over(
+ partition_by=[TI.dag_id, TI.run_id],
+ order_by=[-TI.priority_weight, DR.logical_date,
TI.map_index],
+ )
+ .label("row_num"),
+ DM.max_active_tasks.label("dr_max_active_tasks"),
+ # Create columns for the order_by checks here for sqlite.
+ TI.priority_weight.label("priority_weight_for_ordering"),
+ DR.logical_date.label("logical_date_for_ordering"),
+ TI.map_index.label("map_index_for_ordering"),
+ )
+ ).subquery()
+
+ # Select only rows where row_number <= max_active_tasks.
+ query = (
+ select(TI)
+ .select_from(ranked_query)
+ .join(
+ TI,
+ (TI.dag_id == ranked_query.c.dag_id)
+ & (TI.task_id == ranked_query.c.task_id)
+ & (TI.run_id == ranked_query.c.run_id)
+ & (TI.map_index == ranked_query.c.map_index),
+ )
+ .where(ranked_query.c.row_num <=
ranked_query.c.dr_max_active_tasks)
+ # Add the order_by columns from the ranked query for sqlite.
+ .order_by(
+ -ranked_query.c.priority_weight_for_ordering,
+ ranked_query.c.logical_date_for_ordering,
+ ranked_query.c.map_index_for_ordering,
+ )
+ )
Review Comment:
Not sure what happened to my earlier comment, so re-adding: this new query is
missing `.options(selectinload(TI.dag_model))` and `.with_hint(TI, "USE INDEX
(ti_state)", dialect_name="mysql")`, which were on the original query. Without
the `selectinload`, every later `ti.dag_model` access will trigger a separate
query - with 50 TIs that's 50+ extra queries per loop.
##########
airflow-core/src/airflow/jobs/scheduler_job_runner.py:
##########
@@ -194,6 +207,16 @@ def _is_parent_process() -> bool:
return multiprocessing.current_process().name == "MainProcess"
+def _get_current_dr_task_concurrency(states: Iterable[TaskInstanceState]) ->
Subquery:
+ """Get the dag_run IDs and how many tasks are in the provided states for
each one."""
Review Comment:
This queries the same data as `ConcurrencyMap.load()`, which is still called
and used for the check at lines ~680-695. Worth adding a comment explaining why
we keep both (e.g. race condition protection between query time and check time)?
##########
airflow-core/tests/unit/jobs/test_scheduler_job.py:
##########
@@ -194,6 +194,52 @@ def _create_dagrun(
return _create_dagrun
+def task_maker(
Review Comment:
A couple of test cases worth adding:
1. **Starvation filter ordering**: dag run with tasks in mixed pools (some
starved, some not). Verify non-starved pool tasks aren't excluded because
starved-pool tasks consumed row_number slots.
2. **Partial capacity**: dag run with `max_active_tasks=4` where 2 are
already RUNNING + 10 SCHEDULED. Verify query returns only 2 (not 4) for that
run.
##########
airflow-core/src/airflow/jobs/scheduler_job_runner.py:
##########
@@ -490,10 +515,60 @@ def _executable_task_instances_to_queued(self, max_tis:
int, session: Session) -
.where(~DM.is_paused)
.where(TI.state == TaskInstanceState.SCHEDULED)
.where(DM.bundle_name.is_not(None))
+ .join(
+ dr_task_concurrency_subquery,
+ and_(
+ TI.dag_id == dr_task_concurrency_subquery.c.dag_id,
+ TI.run_id == dr_task_concurrency_subquery.c.run_id,
+ ),
+ isouter=True,
+ )
+ .where(
+
func.coalesce(dr_task_concurrency_subquery.c.task_per_dr_count, 0) <
DM.max_active_tasks
+ )
.options(selectinload(TI.dag_model))
.order_by(-TI.priority_weight, DR.logical_date, TI.map_index)
)
+ # Create a subquery with row numbers partitioned by dag_id and
run_id.
+ # Different dags can have the same run_id but
+ # the dag_id combined with the run_id uniquely identify a run.
+ ranked_query = (
+ query.add_columns(
+ func.row_number()
+ .over(
+ partition_by=[TI.dag_id, TI.run_id],
+ order_by=[-TI.priority_weight, DR.logical_date,
TI.map_index],
+ )
+ .label("row_num"),
+ DM.max_active_tasks.label("dr_max_active_tasks"),
+ # Create columns for the order_by checks here for sqlite.
+ TI.priority_weight.label("priority_weight_for_ordering"),
+ DR.logical_date.label("logical_date_for_ordering"),
+ TI.map_index.label("map_index_for_ordering"),
+ )
+ ).subquery()
Review Comment:
Row numbers are assigned here before the starvation filters (starved_pools,
starved_dags, etc.) are applied below. In the original code, those filters were
applied BEFORE the limit. Tasks in starved pools will now consume row_number
slots and then get filtered out, potentially excluding schedulable tasks from
the same dag run.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]