kaxil commented on code in PR #54103:
URL: https://github.com/apache/airflow/pull/54103#discussion_r2705351655


##########
airflow-core/src/airflow/jobs/scheduler_job_runner.py:
##########
@@ -490,10 +515,60 @@ def _executable_task_instances_to_queued(self, max_tis: 
int, session: Session) -
                 .where(~DM.is_paused)
                 .where(TI.state == TaskInstanceState.SCHEDULED)
                 .where(DM.bundle_name.is_not(None))
+                .join(
+                    dr_task_concurrency_subquery,
+                    and_(
+                        TI.dag_id == dr_task_concurrency_subquery.c.dag_id,
+                        TI.run_id == dr_task_concurrency_subquery.c.run_id,
+                    ),
+                    isouter=True,
+                )
+                .where(
+                    
func.coalesce(dr_task_concurrency_subquery.c.task_per_dr_count, 0) < 
DM.max_active_tasks
+                )
                 .options(selectinload(TI.dag_model))
                 .order_by(-TI.priority_weight, DR.logical_date, TI.map_index)
             )
 
+            # Create a subquery with row numbers partitioned by dag_id and 
run_id.
+            # Different dags can have the same run_id but
+            # the dag_id combined with the run_id uniquely identify a run.
+            ranked_query = (
+                query.add_columns(
+                    func.row_number()
+                    .over(
+                        partition_by=[TI.dag_id, TI.run_id],
+                        order_by=[-TI.priority_weight, DR.logical_date, 
TI.map_index],
+                    )
+                    .label("row_num"),
+                    DM.max_active_tasks.label("dr_max_active_tasks"),
+                    # Create columns for the order_by checks here for sqlite.
+                    TI.priority_weight.label("priority_weight_for_ordering"),
+                    DR.logical_date.label("logical_date_for_ordering"),
+                    TI.map_index.label("map_index_for_ordering"),
+                )
+            ).subquery()
+
+            # Select only rows where row_number <= max_active_tasks.
+            query = (
+                select(TI)
+                .select_from(ranked_query)
+                .join(
+                    TI,
+                    (TI.dag_id == ranked_query.c.dag_id)
+                    & (TI.task_id == ranked_query.c.task_id)
+                    & (TI.run_id == ranked_query.c.run_id)
+                    & (TI.map_index == ranked_query.c.map_index),
+                )
+                .where(ranked_query.c.row_num <= 
ranked_query.c.dr_max_active_tasks)
+                # Add the order_by columns from the ranked query for sqlite.
+                .order_by(
+                    -ranked_query.c.priority_weight_for_ordering,
+                    ranked_query.c.logical_date_for_ordering,
+                    ranked_query.c.map_index_for_ordering,
+                )
+            )
+
             if starved_pools:
                 query = query.where(TI.pool.not_in(starved_pools))

Review Comment:
   This new query is missing `.options(selectinload(TI.dag_model))`, which was 
on the original query above. Because we rebuild the query here, we lose the eager 
loading, so every later access to `ti.dag_model` will trigger a separate 
query. With 50 TIs that's 50+ extra queries per loop, which partially negates the 
perf gains.
   
   It is also missing `.with_hint(TI, "USE INDEX (ti_state)", dialect_name="mysql")`. 
We should add both here.



##########
airflow-core/tests/unit/jobs/test_scheduler_job.py:
##########
@@ -1266,6 +1312,71 @@ def 
test_find_executable_task_instances_executor_with_teams(self, dag_maker, moc
         ]
         assert len(b_tis_in_wrong_executor) == 0
 
+    @conf_vars(
+        {
+            ("scheduler", "max_tis_per_query"): "100",
+            ("scheduler", "max_dagruns_to_create_per_loop"): "10",
+            ("scheduler", "max_dagruns_per_loop_to_schedule"): "20",
+            ("core", "parallelism"): "100",
+            ("core", "max_active_tasks_per_dag"): "4",
+            ("core", "max_active_runs_per_dag"): "10",
+            ("core", "default_pool_task_slot_count"): "64",
+        }
+    )
+    def test_per_dr_limit_applied_in_task_query(self, dag_maker, 
mock_executors):
+        scheduler_job = Job()
+        scheduler_job.executor.parallelism = 100
+        scheduler_job.executor.slots_available = 70
+        scheduler_job.max_tis_per_query = 100
+        self.job_runner = SchedulerJobRunner(job=scheduler_job)
+        session = settings.Session()
+
+        # Use the same run_id.
+        task_maker(dag_maker, session, "dag_1300_tasks", 1300, 4, "run1")
+        task_maker(dag_maker, session, "dag_1200_tasks", 1200, 4, "run1")
+        task_maker(dag_maker, session, "dag_1100_tasks", 1100, 4, "run1")
+        task_maker(dag_maker, session, "dag_100_tasks", 100, 4, "run1")
+        task_maker(dag_maker, session, "dag_90_tasks", 90, 4, "run1")
+        task_maker(dag_maker, session, "dag_80_tasks", 80, 4, "run1")
+
+        count = 0
+        iterations = 0
+
+        from airflow.configuration import conf
+
+        task_num = conf.getint("core", "max_active_tasks_per_dag") * 6
+
+        # 6 dags * 4 = 24.
+        assert task_num == 24
+
+        queued_tis = None
+        while count < task_num:
+            # Use `_executable_task_instances_to_queued` because it returns a 
list of TIs
+            # while `_critical_section_enqueue_task_instances` just returns 
the number of the TIs.
+            queued_tis = self.job_runner._executable_task_instances_to_queued(
+                max_tis=scheduler_job.executor.slots_available, session=session
+            )
+            count += len(queued_tis)
+            iterations += 1
+
+        assert iterations == 1
+        assert count == task_num
+
+        assert queued_tis is not None
+
+        dag_counts = Counter(ti.dag_id for ti in queued_tis)
+
+        # Tasks from all 6 dags should have been queued.
+        assert len(dag_counts) == 6
+        assert dag_counts == {
+            "dag_1300_tasks": 4,
+            "dag_1200_tasks": 4,
+            "dag_1100_tasks": 4,
+            "dag_100_tasks": 4,
+            "dag_90_tasks": 4,
+            "dag_80_tasks": 4,
+        }, "Count for each dag_id should be 4 but it isn't"
+
     def test_find_executable_task_instances_order_priority_with_pools(self, 
dag_maker):

Review Comment:
   A couple of test cases worth adding:
   
   1. **Starvation filter ordering**: a dag run with tasks in mixed pools (some 
starved, some not). Verify that non-starved pool tasks aren't excluded because 
starved-pool tasks consumed row_number slots.
   
   2. **Partial capacity**: a dag run with `max_active_tasks=4` where 2 tasks are 
already RUNNING and 10 are SCHEDULED. Verify that the query returns only 2 (not 4) 
for that run.



##########
airflow-core/src/airflow/jobs/scheduler_job_runner.py:
##########
@@ -490,10 +515,60 @@ def _executable_task_instances_to_queued(self, max_tis: 
int, session: Session) -
                 .where(~DM.is_paused)
                 .where(TI.state == TaskInstanceState.SCHEDULED)
                 .where(DM.bundle_name.is_not(None))
+                .join(
+                    dr_task_concurrency_subquery,
+                    and_(
+                        TI.dag_id == dr_task_concurrency_subquery.c.dag_id,
+                        TI.run_id == dr_task_concurrency_subquery.c.run_id,
+                    ),
+                    isouter=True,
+                )
+                .where(
+                    
func.coalesce(dr_task_concurrency_subquery.c.task_per_dr_count, 0) < 
DM.max_active_tasks
+                )
                 .options(selectinload(TI.dag_model))
                 .order_by(-TI.priority_weight, DR.logical_date, TI.map_index)
             )
 
+            # Create a subquery with row numbers partitioned by dag_id and 
run_id.
+            # Different dags can have the same run_id but
+            # the dag_id combined with the run_id uniquely identify a run.
+            ranked_query = (
+                query.add_columns(
+                    func.row_number()
+                    .over(
+                        partition_by=[TI.dag_id, TI.run_id],
+                        order_by=[-TI.priority_weight, DR.logical_date, 
TI.map_index],
+                    )
+                    .label("row_num"),
+                    DM.max_active_tasks.label("dr_max_active_tasks"),
+                    # Create columns for the order_by checks here for sqlite.

Review Comment:
   The row_number ranking happens here, before the starvation filters 
(starved_pools, etc.) are applied below. In the original code, those filters 
were applied BEFORE the limit.
   
   Tasks in starved pools will consume row_number slots and then get filtered 
out, potentially excluding schedulable tasks from the same dag run. Should we 
apply the starvation filters to the base query before building ranked_query?



##########
airflow-core/src/airflow/jobs/scheduler_job_runner.py:
##########
@@ -490,10 +515,60 @@ def _executable_task_instances_to_queued(self, max_tis: 
int, session: Session) -
                 .where(~DM.is_paused)
                 .where(TI.state == TaskInstanceState.SCHEDULED)
                 .where(DM.bundle_name.is_not(None))
+                .join(
+                    dr_task_concurrency_subquery,
+                    and_(
+                        TI.dag_id == dr_task_concurrency_subquery.c.dag_id,
+                        TI.run_id == dr_task_concurrency_subquery.c.run_id,
+                    ),
+                    isouter=True,
+                )
+                .where(
+                    
func.coalesce(dr_task_concurrency_subquery.c.task_per_dr_count, 0) < 
DM.max_active_tasks
+                )
                 .options(selectinload(TI.dag_model))
                 .order_by(-TI.priority_weight, DR.logical_date, TI.map_index)
             )
 
+            # Create a subquery with row numbers partitioned by dag_id and 
run_id.
+            # Different dags can have the same run_id but
+            # the dag_id combined with the run_id uniquely identify a run.
+            ranked_query = (
+                query.add_columns(
+                    func.row_number()
+                    .over(
+                        partition_by=[TI.dag_id, TI.run_id],
+                        order_by=[-TI.priority_weight, DR.logical_date, 
TI.map_index],
+                    )
+                    .label("row_num"),
+                    DM.max_active_tasks.label("dr_max_active_tasks"),
+                    # Create columns for the order_by checks here for sqlite.
+                    TI.priority_weight.label("priority_weight_for_ordering"),
+                    DR.logical_date.label("logical_date_for_ordering"),
+                    TI.map_index.label("map_index_for_ordering"),
+                )
+            ).subquery()
+
+            # Select only rows where row_number <= max_active_tasks.
+            query = (
+                select(TI)
+                .select_from(ranked_query)
+                .join(
+                    TI,
+                    (TI.dag_id == ranked_query.c.dag_id)
+                    & (TI.task_id == ranked_query.c.task_id)
+                    & (TI.run_id == ranked_query.c.run_id)
+                    & (TI.map_index == ranked_query.c.map_index),
+                )
+                .where(ranked_query.c.row_num <= 
ranked_query.c.dr_max_active_tasks)
+                # Add the order_by columns from the ranked query for sqlite.
+                .order_by(
+                    -ranked_query.c.priority_weight_for_ordering,

Review Comment:
   This doesn't account for tasks that are already running or queued. If 
`max_active_tasks = 4` and 2 tasks are already running, we still return up to 4 
scheduled tasks, but we can only queue 2.
   
   Should be:
   ```python
   .where(ranked_query.c.row_num <= (ranked_query.c.dr_max_active_tasks - 
func.coalesce(ranked_query.c.task_per_dr_count, 0)))
   ```
   
   You'd need to add `task_per_dr_count` as a column in ranked_query.



##########
airflow-core/src/airflow/jobs/scheduler_job_runner.py:
##########
@@ -515,7 +590,13 @@ def _executable_task_instances_to_queued(self, max_tis: 
int, session: Session) -
 
             try:
                 locked_query = with_row_locks(query, of=TI, session=session, 
skip_locked=True)
-                task_instances_to_examine: list[TI] = 
list(session.scalars(locked_query).all())
+                task_instances_to_examine = session.scalars(locked_query).all()
+
+                self.log.debug("Length of the tis to examine is %d", 
len(task_instances_to_examine))
+                self.log.debug(
+                    "TaskInstance selection is: %s",
+                    dict(Counter(ti.dag_id for ti in 
task_instances_to_examine)),
+                )
 
                 timer.stop(send=True)

Review Comment:
   nit: The `Counter()` iteration runs even when debug logging is disabled. Since 
this change is about perf, consider guarding it:
   ```python
   if self.log.isEnabledFor(logging.DEBUG):
       self.log.debug(...)
   ```



##########
airflow-core/src/airflow/jobs/scheduler_job_runner.py:
##########
@@ -194,6 +207,16 @@ def _is_parent_process() -> bool:
     return multiprocessing.current_process().name == "MainProcess"
 

Review Comment:
   This queries the same data as `ConcurrencyMap.load()`, which is still called 
and used for the check at lines ~665-680. With SQL-level filtering now in 
place, that Python check should mostly pass (barring race conditions). Would it be 
worth adding a comment explaining why we keep both?



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to