anishgirianish commented on code in PR #62343:
URL: https://github.com/apache/airflow/pull/62343#discussion_r3316039628


##########
airflow-core/src/airflow/jobs/scheduler_job_runner.py:
##########
@@ -3247,6 +3264,97 @@ def _cleanup_orphaned_asset_state(*, session: Session) -> None:
         )
         session.execute(delete(AssetStateModel).where(AssetStateModel.asset_id.not_in(active_asset_ids)))
 
+    def _enqueue_connection_tests(self, *, session: Session) -> None:
+        """
+        Enqueue pending connection tests to executors that support them.
+
+        ``max_concurrency`` is per-scheduler, not global: with N HA schedulers
+        the worst-case per-tick dispatch is ``N * max_concurrency``. Connection
+        tests are user-initiated and rare, so the overshoot self-corrects via
+        the reaper. For a true global cap, wrap the budget+claim below in a
+        sentinel-row ``SELECT ... FOR UPDATE``.
+        """
+        max_concurrency = conf.getint("connection_test", "max_concurrency", fallback=4)
+        timeout = conf.getint("connection_test", "timeout", fallback=60)
+
+        active_count = session.scalar(
+            select(func.count(ConnectionTestRequest.id)).where(
+                ConnectionTestRequest.state.in_(DISPATCHED_STATES)
+            )
+        )
+        budget = max_concurrency - (active_count or 0)
+        if budget <= 0:
+            return
+
+        pending_stmt = (
+            select(ConnectionTestRequest)
+            .where(ConnectionTestRequest.state == ConnectionTestState.PENDING)
+            .order_by(ConnectionTestRequest.created_at)
+            .limit(budget)
+        )
+        pending_stmt = with_row_locks(pending_stmt, session, of=ConnectionTestRequest, skip_locked=True)
+        pending_tests = session.scalars(pending_stmt).all()
+
+        if not pending_tests:
+            return
+
+        for ct in pending_tests:
+            team_name = (
+                Connection.get_team_name(ct.connection_id, session=session) if self._multi_team else None
+            )
+            executor = self._try_to_load_executor(ct, session, team_name=team_name)
+            if executor is None:
+                reason = f"No executor matches '{ct.executor}'"
+                ct.state = ConnectionTestState.FAILED
+                ct.result_message = reason
+                self.log.warning("Failing connection test %s: %s", ct.id, reason)
+                continue
+            if not executor.supports_connection_test:
+                exec_name = executor.name
+                name = ct.executor or (exec_name and (exec_name.alias or exec_name.module_path))
+                reason = f"Executor '{name}' does not support connection testing"
+                ct.state = ConnectionTestState.FAILED
+                ct.result_message = reason
+                self.log.warning("Failing connection test %s: %s", ct.id, reason)
+                continue
+
+            workload = workloads.TestConnection.make(
+                connection_test_id=ct.id,
+                connection_id=ct.connection_id,
+                timeout=timeout,
+                queue=ct.queue,
+                generator=executor.jwt_generator,
+            )
+            executor.queue_workload(workload, session=session)
+            ct.state = ConnectionTestState.QUEUED
+
+        session.flush()
+
+    @provide_session
+    def _reap_stale_connection_tests(self, *, session: Session = NEW_SESSION) -> None:
+        """Mark connection tests that have exceeded their timeout as FAILED."""
+        timeout = conf.getint("connection_test", "timeout", fallback=60)
+        grace_period = max(30, timeout // 2)
+        cutoff = timezone.utcnow() - timedelta(seconds=timeout + grace_period)
+
+        stale_stmt = select(ConnectionTestRequest).where(
+            ConnectionTestRequest.state.in_(CONNECTION_TEST_ACTIVE_STATES),
+            ConnectionTestRequest.updated_at < cutoff,
+        )
+        stale_stmt = with_row_locks(stale_stmt, session, of=ConnectionTestRequest, skip_locked=True)
+        stale_tests = session.scalars(stale_stmt).all()
+
+        for ct in stale_tests:
+            ct.state = ConnectionTestState.FAILED
+            ct.result_message = f"Connection test timed out (exceeded {timeout}s + {grace_period}s grace)"

Review Comment:
   Done, thanks.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to