potiuk commented on code in PR #32991:
URL: https://github.com/apache/airflow/pull/32991#discussion_r1287789679


##########
airflow/jobs/backfill_job_runner.py:
##########
@@ -588,61 +588,81 @@ def _per_task_process(key, ti: TaskInstance, session):
             try:
                 for task in 
self.dag.topological_sort(include_subdag_tasks=True):
                     for key, ti in list(ti_status.to_run.items()):
-                        if task.task_id != ti.task_id:
-                            continue
-
-                        pool = session.scalar(
-                            select(models.Pool).where(models.Pool.pool == 
task.pool).limit(1)
-                        )
-                        if not pool:
-                            raise PoolNotFound(f"Unknown pool: {task.pool}")
-
-                        open_slots = pool.open_slots(session=session)
-                        if open_slots <= 0:
-                            raise NoAvailablePoolSlot(
-                                f"Not scheduling since there are {open_slots} 
open slots in pool {task.pool}"
-                            )
-
-                        num_running_task_instances_in_dag = 
DAG.get_num_task_instances(
-                            self.dag_id,
-                            states=self.STATES_COUNT_AS_RUNNING,
-                            session=session,
-                        )
-
-                        if num_running_task_instances_in_dag >= 
self.dag.max_active_tasks:
-                            raise DagConcurrencyLimitReached(
-                                "Not scheduling since DAG max_active_tasks 
limit is reached."
+                        # Attempt to workaround deadlock on backfill by 
attempting to commit the transaction
+                        # state update few times before giving up
+                        max_attempts = 5
+                        for i in range(max_attempts):
+                            if task.task_id != ti.task_id:
+                                continue
+
+                            pool = session.scalar(
+                                select(models.Pool).where(models.Pool.pool == 
task.pool).limit(1)
                             )
+                            if not pool:
+                                raise PoolNotFound(f"Unknown pool: 
{task.pool}")
+
+                            open_slots = pool.open_slots(session=session)
+                            if open_slots <= 0:
+                                raise NoAvailablePoolSlot(
+                                    f"Not scheduling since there are 
{open_slots} "
+                                    f"open slots in pool {task.pool}"
+                                )
 
-                        if task.max_active_tis_per_dag is not None:
-                            num_running_task_instances_in_task = 
DAG.get_num_task_instances(
-                                dag_id=self.dag_id,
-                                task_ids=[task.task_id],
+                            num_running_task_instances_in_dag = 
DAG.get_num_task_instances(
+                                self.dag_id,
                                 states=self.STATES_COUNT_AS_RUNNING,
                                 session=session,
                             )
 
-                            if num_running_task_instances_in_task >= 
task.max_active_tis_per_dag:
-                                raise TaskConcurrencyLimitReached(
-                                    "Not scheduling since Task concurrency 
limit is reached."
+                            if num_running_task_instances_in_dag >= 
self.dag.max_active_tasks:
+                                raise DagConcurrencyLimitReached(
+                                    "Not scheduling since DAG max_active_tasks 
limit is reached."
                                 )
 
-                        if task.max_active_tis_per_dagrun is not None:
-                            num_running_task_instances_in_task_dagrun = 
DAG.get_num_task_instances(
-                                dag_id=self.dag_id,
-                                run_id=ti.run_id,
-                                task_ids=[task.task_id],
-                                states=self.STATES_COUNT_AS_RUNNING,
-                                session=session,
-                            )
+                            if task.max_active_tis_per_dag is not None:
+                                num_running_task_instances_in_task = 
DAG.get_num_task_instances(
+                                    dag_id=self.dag_id,
+                                    task_ids=[task.task_id],
+                                    states=self.STATES_COUNT_AS_RUNNING,
+                                    session=session,
+                                )
 
-                            if num_running_task_instances_in_task_dagrun >= 
task.max_active_tis_per_dagrun:
-                                raise TaskConcurrencyLimitReached(
-                                    "Not scheduling since Task concurrency per 
DAG run limit is reached."
+                                if num_running_task_instances_in_task >= 
task.max_active_tis_per_dag:
+                                    raise TaskConcurrencyLimitReached(
+                                        "Not scheduling since Task concurrency 
limit is reached."
+                                    )
+
+                            if task.max_active_tis_per_dagrun is not None:
+                                num_running_task_instances_in_task_dagrun = 
DAG.get_num_task_instances(
+                                    dag_id=self.dag_id,
+                                    run_id=ti.run_id,
+                                    task_ids=[task.task_id],
+                                    states=self.STATES_COUNT_AS_RUNNING,
+                                    session=session,
                                 )
 
-                        _per_task_process(key, ti, session)
-                        session.commit()
+                                if (
+                                    num_running_task_instances_in_task_dagrun
+                                    >= task.max_active_tis_per_dagrun
+                                ):
+                                    raise TaskConcurrencyLimitReached(
+                                        "Not scheduling since Task concurrency 
per DAG run limit is reached."
+                                    )
+
+                            _per_task_process(key, ti, session)
+                            try:
+                                session.commit()
+                                # break the retry loop
+                                break
+                            except OperationalError:
+                                self.log.error(
+                                    "Failed to commit task state due to 
operational error. "
+                                    "The job will retry this operation so if 
your backfill succeeds, "
+                                    "you can safely ignore this message.",
+                                    exc_info=True,
+                                )
+                                session.rollback()
+                                # outer loop will retry

Review Comment:
   > is there a missing check for the current attempt number?
   
   I don't think so. Why?



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to