jason810496 commented on code in PR #62343:
URL: https://github.com/apache/airflow/pull/62343#discussion_r2876261697


##########
airflow-core/src/airflow/api_fastapi/core_api/routes/public/connections.py:
##########
@@ -212,6 +227,69 @@ def patch_connection(
     return connection
 
 
+@connections_router.patch(
+    "/{connection_id}/save-and-test",

Review Comment:
   Would `PUT` (or keeping `PATCH`) on `/test/{connection_id}/` or 
`/{connection_id}/connection_test` be more RESTful?



##########
airflow-core/src/airflow/executors/local_executor.py:
##########
@@ -168,6 +177,84 @@ def _execute_callback(log: Logger, workload: 
workloads.ExecuteCallback, team_con
         raise RuntimeError(error_msg or "Callback execution failed")
 
 
+def _execute_connection_test(log: Logger, workload: workloads.TestConnection, 
team_conf) -> None:

Review Comment:
   It seems we need to add a dedicated supervisor, like #62645, as a follow-up.



##########
airflow-core/src/airflow/api_fastapi/core_api/datamodels/connections.py:
##########
@@ -72,12 +73,46 @@ class ConnectionCollectionResponse(BaseModel):
 
 
 class ConnectionTestResponse(BaseModel):
-    """Connection Test serializer for responses."""
+    """Connection Test serializer for synchronous test responses."""
 
     status: bool
     message: str
 
 
+class ConnectionTestRequestBody(StrictBaseModel):
+    """Request body for async connection test."""
+
+    connection_id: str
+    queue: str | None = None

Review Comment:
   Maybe we should allow users to specify `executor` here.
   
   For example, specifying `executor` and specifying `queue` have different 
meanings at the task level.
   
   



##########
airflow-core/src/airflow/executors/base_executor.py:
##########
@@ -240,10 +242,14 @@ def queue_workload(self, workload: workloads.All, 
session: Session) -> None:
                     f"See LocalExecutor or CeleryExecutor for reference 
implementation."
                 )
             self.queued_callbacks[workload.callback.id] = workload
+        elif isinstance(workload, workloads.TestConnection):
+            if not self.supports_connection_test:
+                raise ValueError(f"Executor {type(self).__name__} does not 
support connection testing")

Review Comment:
   How about raising `NotImplementedError` like `ExecuteCallback` above?



##########
airflow-core/src/airflow/jobs/scheduler_job_runner.py:
##########
@@ -3179,6 +3189,91 @@ def _activate_assets_generate_warnings() -> 
Iterator[tuple[str, str]]:
             session.add(warning)
             existing_warned_dag_ids.add(warning.dag_id)
 
+    @provide_session
+    def _dispatch_connection_tests(self, *, session: Session = NEW_SESSION) -> 
None:
+        """Dispatch pending connection tests to executors that support them."""
+        max_concurrency = conf.getint("core", 
"max_connection_test_concurrency", fallback=4)
+        timeout = conf.getint("core", "connection_test_timeout", fallback=60)
+
+        active_count = session.scalar(
+            select(func.count(ConnectionTest.id)).where(
+                ConnectionTest.state.in_([ConnectionTestState.QUEUED, 
ConnectionTestState.RUNNING])
+            )
+        )
+        budget = max_concurrency - (active_count or 0)
+        if budget <= 0:
+            return
+
+        pending_stmt = (
+            select(ConnectionTest)
+            .where(ConnectionTest.state == ConnectionTestState.PENDING)
+            .order_by(ConnectionTest.created_at)
+            .limit(budget)
+        )
+        pending_stmt = with_row_locks(pending_stmt, session, 
of=ConnectionTest, skip_locked=True)
+        pending_tests = session.scalars(pending_stmt).all()
+
+        if not pending_tests:
+            return
+
+        for ct in pending_tests:
+            executor = self._find_executor_for_connection_test(ct.queue)
+            if executor is None:
+                reason = (
+                    f"No executor serves queue '{ct.queue}'"
+                    if ct.queue
+                    else "No executor supports connection testing"
+                )
+                ct.state = ConnectionTestState.FAILED
+                ct.result_message = reason
+                self.log.warning("Failing connection test %s: %s", ct.id, 
reason)
+                continue
+
+            workload = workloads.TestConnection.make(
+                connection_test_id=ct.id,
+                connection_id=ct.connection_id,
+                timeout=timeout,
+                generator=executor.jwt_generator,
+            )
+            executor.queue_workload(workload, session=session)
+            ct.state = ConnectionTestState.QUEUED
+
+        session.flush()
+
+    @provide_session
+    def _reap_stale_connection_tests(self, *, session: Session = NEW_SESSION) 
-> None:
+        """Mark connection tests that have exceeded their timeout as FAILED."""
+        timeout = conf.getint("core", "connection_test_timeout", fallback=60)
+        grace_period = max(30, timeout // 2)
+        cutoff = timezone.utcnow() - timedelta(seconds=timeout + grace_period)
+
+        stale_stmt = select(ConnectionTest).where(
+            ConnectionTest.state.in_([ConnectionTestState.QUEUED, 
ConnectionTestState.RUNNING]),
+            ConnectionTest.updated_at < cutoff,
+        )
+        stale_tests = session.scalars(stale_stmt).all()

Review Comment:
   Should we add a row-level lock with `skip_locked`, as in 
`_dispatch_connection_tests`?
   ```suggestion
           stale_stmt = with_row_locks(stale_stmt, session, 
of=ConnectionTest, skip_locked=True)
           stale_tests = session.scalars(stale_stmt).all()
   ```



##########
airflow-core/src/airflow/jobs/scheduler_job_runner.py:
##########
@@ -3179,6 +3189,91 @@ def _activate_assets_generate_warnings() -> 
Iterator[tuple[str, str]]:
             session.add(warning)
             existing_warned_dag_ids.add(warning.dag_id)
 
+    @provide_session
+    def _dispatch_connection_tests(self, *, session: Session = NEW_SESSION) -> 
None:
+        """Dispatch pending connection tests to executors that support them."""
+        max_concurrency = conf.getint("core", 
"max_connection_test_concurrency", fallback=4)
+        timeout = conf.getint("core", "connection_test_timeout", fallback=60)
+
+        active_count = session.scalar(
+            select(func.count(ConnectionTest.id)).where(
+                ConnectionTest.state.in_([ConnectionTestState.QUEUED, 
ConnectionTestState.RUNNING])
+            )
+        )
+        budget = max_concurrency - (active_count or 0)
+        if budget <= 0:
+            return
+
+        pending_stmt = (
+            select(ConnectionTest)
+            .where(ConnectionTest.state == ConnectionTestState.PENDING)
+            .order_by(ConnectionTest.created_at)
+            .limit(budget)
+        )
+        pending_stmt = with_row_locks(pending_stmt, session, 
of=ConnectionTest, skip_locked=True)
+        pending_tests = session.scalars(pending_stmt).all()
+
+        if not pending_tests:
+            return
+
+        for ct in pending_tests:
+            executor = self._find_executor_for_connection_test(ct.queue)
+            if executor is None:
+                reason = (
+                    f"No executor serves queue '{ct.queue}'"
+                    if ct.queue
+                    else "No executor supports connection testing"
+                )
+                ct.state = ConnectionTestState.FAILED
+                ct.result_message = reason
+                self.log.warning("Failing connection test %s: %s", ct.id, 
reason)
+                continue
+
+            workload = workloads.TestConnection.make(
+                connection_test_id=ct.id,
+                connection_id=ct.connection_id,
+                timeout=timeout,
+                generator=executor.jwt_generator,
+            )
+            executor.queue_workload(workload, session=session)
+            ct.state = ConnectionTestState.QUEUED
+
+        session.flush()
+
+    @provide_session
+    def _reap_stale_connection_tests(self, *, session: Session = NEW_SESSION) 
-> None:
+        """Mark connection tests that have exceeded their timeout as FAILED."""
+        timeout = conf.getint("core", "connection_test_timeout", fallback=60)
+        grace_period = max(30, timeout // 2)
+        cutoff = timezone.utcnow() - timedelta(seconds=timeout + grace_period)
+
+        stale_stmt = select(ConnectionTest).where(
+            ConnectionTest.state.in_([ConnectionTestState.QUEUED, 
ConnectionTestState.RUNNING]),
+            ConnectionTest.updated_at < cutoff,
+        )
+        stale_tests = session.scalars(stale_stmt).all()
+
+        for ct in stale_tests:
+            ct.state = ConnectionTestState.FAILED
+            ct.result_message = f"Connection test timed out (exceeded 
{timeout}s + {grace_period}s grace)"
+            self.log.warning("Reaped stale connection test %s", ct.id)
+            if ct.connection_snapshot:
+                attempt_revert(ct, session=session)
+
+        session.flush()
+
+    def _find_executor_for_connection_test(self, queue: str | None) -> 
BaseExecutor | None:
+        """Find an executor that supports connection testing, optionally 
matching a queue."""
+        if queue is not None:
+            for executor in self.executors:
+                if executor.supports_connection_test and executor.team_name == 
queue:
+                    return executor

Review Comment:
   I haven’t seen `team_name` mixed with `TI.queue` anywhere so far. I’m 
concerned this usage might create ambiguity for users.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to