kaxil commented on code in PR #62343:
URL: https://github.com/apache/airflow/pull/62343#discussion_r3307402411
##########
airflow-core/src/airflow/api_fastapi/core_api/routes/public/connections.py:
##########
@@ -259,6 +321,78 @@ def test_connection(test_body: ConnectionBody) ->
ConnectionTestResponse:
os.environ.pop(conn_env_var, None)
+@connections_router.post(
+ "/enqueue-test",
+ status_code=status.HTTP_202_ACCEPTED,
+ responses=create_openapi_http_exception_doc(
+ [
+ status.HTTP_403_FORBIDDEN,
+ status.HTTP_409_CONFLICT,
+ status.HTTP_422_UNPROCESSABLE_ENTITY,
+ ]
+ ),
+ dependencies=[Depends(action_logging())],
+)
+def enqueue_connection_test(
+ test_body: ConnectionTestRequestBody,
+ session: SessionDep,
+ user: GetUserDep,
+) -> ConnectionTestQueuedResponse:
+ """Enqueue a connection test for deferred execution on a worker; returns a
polling token."""
+ _ensure_test_connection_enabled()
+ _ensure_executor_is_configured(test_body.executor)
+
+ existing =
session.scalar(select(Connection).filter_by(conn_id=test_body.connection_id))
+ if existing is not None:
+ effective_team = existing.team_name
+ if test_body.team_name is not None and test_body.team_name !=
effective_team:
+ raise HTTPException(
+ status.HTTP_403_FORBIDDEN,
+ f"team_name `{test_body.team_name}` does not match the team of
connection "
+ f"`{test_body.connection_id}`.",
+ )
+ else:
+ effective_team = test_body.team_name
Review Comment:
`effective_team` from `test_body.team_name` is used to authorize the POST when
there's no existing connection row (the `commit_on_success=True` flow for a
brand-new connection), but it's never persisted on `ConnectionTestRequest`. The
polling endpoint at line 141 then re-resolves the team via
`Connection.get_team_name(connection_test.connection_id)`. That call returns `None`
because the connection still doesn't exist, so authz collapses to the global
"read connections" permission. The scheduler dispatch at
`scheduler_job_runner.py:3302-3304` does the same lookup and routes the
workload to the global executor instead of the team-scoped one.
Cross-team isolation is therefore broken end-to-end for the new-connection path in
multi-team mode. Add a `team_name` column to `connection_test_request`, persist
`effective_team` here, and read it from the row in both the poll endpoint and the
scheduler dispatch path.
##########
airflow-core/src/airflow/models/connection_test.py:
##########
@@ -0,0 +1,216 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import secrets
+from dataclasses import dataclass
+from datetime import datetime
+from enum import Enum
+from typing import TYPE_CHECKING
+from uuid import UUID
+
+import structlog
+import uuid6
+from sqlalchemy import (
+ Boolean,
+ Index,
+ Integer,
+ String,
+ Text,
+ UniqueConstraint,
+ Uuid,
+ select,
+)
+from sqlalchemy.orm import Mapped, mapped_column, validates
+
+from airflow._shared.timezones import timezone
+from airflow.models.base import Base
+from airflow.models.connection import Connection
+from airflow.models.crypto import FernetFieldsMixin
+from airflow.utils.sqlalchemy import UtcDateTime
+
+if TYPE_CHECKING:
+ from sqlalchemy.orm import Session
+
+log = structlog.get_logger(__name__)
+
+
+class ConnectionTestState(str, Enum):
+ """All possible states of a connection test."""
+
+ PENDING = "pending"
+ QUEUED = "queued"
+ RUNNING = "running"
+ SUCCESS = "success"
+ FAILED = "failed"
+
+ def __str__(self) -> str:
+ return self.value
+
+
+ACTIVE_STATES = frozenset(
+ (ConnectionTestState.PENDING, ConnectionTestState.QUEUED,
ConnectionTestState.RUNNING)
+)
+DISPATCHED_STATES = frozenset((ConnectionTestState.QUEUED,
ConnectionTestState.RUNNING))
+TERMINAL_STATES = frozenset((ConnectionTestState.SUCCESS,
ConnectionTestState.FAILED))
+
+
+@dataclass(frozen=True, slots=True)
+class ConnectionTestKey:
+ """Typed key for connection-test workloads (wraps str(UUID))."""
+
+ id: str
+
+ def __str__(self) -> str:
+ return self.id
+
+
+class ConnectionTestRequest(Base, FernetFieldsMixin):
+ """
+ Tracks an async connection test request dispatched to a worker.
+
+ Stores the full connection details so the worker reads from this table
+ instead of the real ``connection`` table. The real ``connection`` table
+ is only modified if the test succeeds and ``commit_on_success`` is True.
+ """
+
+ __tablename__ = "connection_test_request"
+
+ id: Mapped[UUID] = mapped_column(Uuid(), primary_key=True,
default=uuid6.uuid7)
+ token: Mapped[str] = mapped_column(String(64), nullable=False, unique=True)
+ connection_id: Mapped[str] = mapped_column(String(250), nullable=False)
+ state: Mapped[str] = mapped_column(String(20), nullable=False,
default=ConnectionTestState.PENDING)
+ result_message: Mapped[str | None] = mapped_column(String(2000),
nullable=True)
+ created_at: Mapped[datetime] = mapped_column(UtcDateTime,
default=timezone.utcnow, nullable=False)
+ updated_at: Mapped[datetime] = mapped_column(
+ UtcDateTime, default=timezone.utcnow, onupdate=timezone.utcnow,
nullable=False
+ )
+ executor: Mapped[str | None] = mapped_column(String(256), nullable=True)
+ queue: Mapped[str | None] = mapped_column(String(256), nullable=True)
+
+ conn_type: Mapped[str] = mapped_column(String(500), nullable=False)
+ host: Mapped[str | None] = mapped_column(String(500), nullable=True)
+ login: Mapped[str | None] = mapped_column(Text, nullable=True)
+ schema: Mapped[str | None] = mapped_column("schema", String(500),
nullable=True)
+ port: Mapped[int | None] = mapped_column(Integer, nullable=True)
+ commit_on_success: Mapped[bool] = mapped_column(
+ Boolean, nullable=False, default=False, server_default="0"
+ )
+ is_encrypted: Mapped[bool] = mapped_column(
Review Comment:
`password` / `extra` on this model are encrypted through `FernetFieldsMixin`, and
their state is tracked by the `is_encrypted` / `is_extra_encrypted` columns
declared here. However, `airflow rotate-fernet-key`
(`cli/commands/rotate_fernet_key_command.py`) only walks `Connection`,
`Variable` and `Trigger`. After rotation, any in-flight or db-cleanup-retained
`connection_test_request` row with `is_encrypted=True` or
`is_extra_encrypted=True` is never re-encrypted. It becomes undecryptable as soon
as the old key is removed from `[core] fernet_key`.
Either add a `rotate_fernet_key()` method here and include it in the CLI
walk, or document that connection-test rows must be drained before rotation.
##########
task-sdk/src/airflow/sdk/api/client.py:
##########
@@ -1047,6 +1051,28 @@ def get_detail_response(self, ti_id: uuid.UUID) ->
HITLDetailResponse:
return HITLDetailResponse.model_validate_json(resp.read())
+class ConnectionTestOperations:
+ __slots__ = ("client",)
+
+ def __init__(self, client: Client):
+ self.client = client
+
+ def get_connection(self, connection_test_id: uuid.UUID) ->
ConnectionTestConnectionResponse:
+ """Fetch connection data for a test request from the API server."""
+ resp =
self.client.get(f"connection-tests/{connection_test_id}/connection")
+ return
ConnectionTestConnectionResponse.model_validate_json(resp.read())
+
+ def update_state(
+ self, id: uuid.UUID, state: ConnectionTestState, result_message: str |
None = None
+ ) -> None:
+ """Report the state of a connection test to the API server."""
+ body = ConnectionTestResultBody(
+ state=state,
+ result_message=ResultMessage(result_message) if result_message is
not None else None,
Review Comment:
`ResultMessage(result_message)` runs Pydantic validation with
`max_length=2000`. In the happy path at `connection_test_supervisor.py:87`,
`message` comes from `hook.test_connection()` and is user-controlled /
unbounded. A 3000-char message raises `ValidationError`. The supervisor then exits
without PATCHing, and the row sits in RUNNING until the reaper labels it "timed
out", with no signal that the test actually completed.
Truncate inside `update_state` before constructing `ResultMessage` (e.g.
`result_message[:2000]`, which also matches the `String(2000)` `result_message`
column on `ConnectionTestRequest`).
##########
airflow-core/src/airflow/jobs/scheduler_job_runner.py:
##########
@@ -3247,6 +3264,97 @@ def _cleanup_orphaned_asset_state(*, session: Session)
-> None:
)
session.execute(delete(AssetStateModel).where(AssetStateModel.asset_id.not_in(active_asset_ids)))
+ def _enqueue_connection_tests(self, *, session: Session) -> None:
+ """
+ Enqueue pending connection tests to executors that support them.
+
+ ``max_concurrency`` is per-scheduler, not global: with N HA schedulers
+ the worst-case per-tick dispatch is ``N * max_concurrency``. Connection
+ tests are user-initiated and rare, so the overshoot self-corrects via
+ the reaper. For a true global cap, wrap the budget+claim below in a
+ sentinel-row ``SELECT ... FOR UPDATE``.
+ """
+ max_concurrency = conf.getint("connection_test", "max_concurrency",
fallback=4)
+ timeout = conf.getint("connection_test", "timeout", fallback=60)
+
+ active_count = session.scalar(
+ select(func.count(ConnectionTestRequest.id)).where(
+ ConnectionTestRequest.state.in_(DISPATCHED_STATES)
+ )
+ )
+ budget = max_concurrency - (active_count or 0)
+ if budget <= 0:
+ return
+
+ pending_stmt = (
+ select(ConnectionTestRequest)
+ .where(ConnectionTestRequest.state == ConnectionTestState.PENDING)
+ .order_by(ConnectionTestRequest.created_at)
+ .limit(budget)
+ )
+ pending_stmt = with_row_locks(pending_stmt, session,
of=ConnectionTestRequest, skip_locked=True)
+ pending_tests = session.scalars(pending_stmt).all()
+
+ if not pending_tests:
+ return
+
+ for ct in pending_tests:
+ team_name = (
+ Connection.get_team_name(ct.connection_id, session=session) if
self._multi_team else None
+ )
+ executor = self._try_to_load_executor(ct, session,
team_name=team_name)
+ if executor is None:
+ reason = f"No executor matches '{ct.executor}'"
+ ct.state = ConnectionTestState.FAILED
+ ct.result_message = reason
+ self.log.warning("Failing connection test %s: %s", ct.id,
reason)
+ continue
+ if not executor.supports_connection_test:
+ exec_name = executor.name
+ name = ct.executor or (exec_name and (exec_name.alias or
exec_name.module_path))
+ reason = f"Executor '{name}' does not support connection
testing"
+ ct.state = ConnectionTestState.FAILED
+ ct.result_message = reason
+ self.log.warning("Failing connection test %s: %s", ct.id,
reason)
+ continue
+
+ workload = workloads.TestConnection.make(
+ connection_test_id=ct.id,
+ connection_id=ct.connection_id,
+ timeout=timeout,
+ queue=ct.queue,
+ generator=executor.jwt_generator,
+ )
+ executor.queue_workload(workload, session=session)
+ ct.state = ConnectionTestState.QUEUED
+
+ session.flush()
+
+ @provide_session
+ def _reap_stale_connection_tests(self, *, session: Session = NEW_SESSION)
-> None:
+ """Mark connection tests that have exceeded their timeout as FAILED."""
+ timeout = conf.getint("connection_test", "timeout", fallback=60)
+ grace_period = max(30, timeout // 2)
+ cutoff = timezone.utcnow() - timedelta(seconds=timeout + grace_period)
+
+ stale_stmt = select(ConnectionTestRequest).where(
+ ConnectionTestRequest.state.in_(CONNECTION_TEST_ACTIVE_STATES),
+ ConnectionTestRequest.updated_at < cutoff,
+ )
+ stale_stmt = with_row_locks(stale_stmt, session,
of=ConnectionTestRequest, skip_locked=True)
+ stale_tests = session.scalars(stale_stmt).all()
+
+ for ct in stale_tests:
+ ct.state = ConnectionTestState.FAILED
+ ct.result_message = f"Connection test timed out (exceeded
{timeout}s + {grace_period}s grace)"
Review Comment:
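`CONNECTION_TEST_ACTIVE_STATES` covers PENDING/QUEUED/RUNNING, and the
message is the same for all three. A PENDING row reaped here was never
dispatched: the scheduler did not hand it to an executor before the cutoff,
e.g. because the `max_concurrency` budget stayed exhausted. A RUNNING row, by
contrast, exceeded its timeout. Reusing the same "Connection test timed out
(exceeded ...)" message is misleading for the former.
Branch on `ct.state` before composing the message. For example, use "never
dispatched" for PENDING, "queued but never started" for QUEUED, and "timed out"
for RUNNING.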
##########
airflow-core/src/airflow/config_templates/config.yml:
##########
@@ -2699,6 +2699,42 @@ scheduler:
type: boolean
example: ~
default: "False"
+connection_test:
+ description: |
+ Configuration for the deferred connection-test workflow that dispatches
+ test requests to workers via the executor (instead of running them
+ in-process on the API server).
+ options:
+ timeout:
+ description: |
+ Maximum number of seconds a worker-dispatched connection test is
+ allowed to run before it is considered timed out. The scheduler
+ reaper uses this value plus a grace period to mark stale tests as
+ failed.
+ version_added: 3.3.0
+ type: integer
+ example: ~
+ default: "60"
+ max_concurrency:
+ description: |
+ Maximum number of connection tests that can be active
+ (QUEUED + RUNNING) at the same time. Excess tests will remain in
+ PENDING state until slots become available. With N HA schedulers and
+ the dispatch lock seeded by the migration, this is enforced as a
Review Comment:
This description promises that "with N HA schedulers and the dispatch lock seeded
by the migration, this is enforced as a true global cap". But the migration only
creates `connection_test_request`; there's no sentinel / dispatch-lock row.
`scheduler_job_runner.py:3271-3275` also explicitly documents the opposite
("per-scheduler, not global; worst-case `N * max_concurrency`"). The
`active_count` used for the budget is read without any lock, so concurrent
schedulers can each dispatch up to the full remaining budget.
Either align this description with the scheduler docstring, or implement the
sentinel-row `SELECT ... FOR UPDATE` the docstring suggests.
##########
airflow-core/src/airflow/models/connection_test.py:
##########
@@ -0,0 +1,216 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import secrets
+from dataclasses import dataclass
+from datetime import datetime
+from enum import Enum
+from typing import TYPE_CHECKING
+from uuid import UUID
+
+import structlog
+import uuid6
+from sqlalchemy import (
+ Boolean,
+ Index,
+ Integer,
+ String,
+ Text,
+ UniqueConstraint,
+ Uuid,
+ select,
+)
+from sqlalchemy.orm import Mapped, mapped_column, validates
+
+from airflow._shared.timezones import timezone
+from airflow.models.base import Base
+from airflow.models.connection import Connection
+from airflow.models.crypto import FernetFieldsMixin
+from airflow.utils.sqlalchemy import UtcDateTime
+
+if TYPE_CHECKING:
+ from sqlalchemy.orm import Session
+
+log = structlog.get_logger(__name__)
+
+
+class ConnectionTestState(str, Enum):
+ """All possible states of a connection test."""
+
+ PENDING = "pending"
+ QUEUED = "queued"
+ RUNNING = "running"
+ SUCCESS = "success"
+ FAILED = "failed"
+
+ def __str__(self) -> str:
+ return self.value
+
+
+ACTIVE_STATES = frozenset(
+ (ConnectionTestState.PENDING, ConnectionTestState.QUEUED,
ConnectionTestState.RUNNING)
+)
+DISPATCHED_STATES = frozenset((ConnectionTestState.QUEUED,
ConnectionTestState.RUNNING))
+TERMINAL_STATES = frozenset((ConnectionTestState.SUCCESS,
ConnectionTestState.FAILED))
+
+
+@dataclass(frozen=True, slots=True)
+class ConnectionTestKey:
+ """Typed key for connection-test workloads (wraps str(UUID))."""
+
+ id: str
+
+ def __str__(self) -> str:
+ return self.id
+
+
+class ConnectionTestRequest(Base, FernetFieldsMixin):
+ """
+ Tracks an async connection test request dispatched to a worker.
+
+ Stores the full connection details so the worker reads from this table
+ instead of the real ``connection`` table. The real ``connection`` table
+ is only modified if the test succeeds and ``commit_on_success`` is True.
+ """
+
+ __tablename__ = "connection_test_request"
+
+ id: Mapped[UUID] = mapped_column(Uuid(), primary_key=True,
default=uuid6.uuid7)
+ token: Mapped[str] = mapped_column(String(64), nullable=False, unique=True)
+ connection_id: Mapped[str] = mapped_column(String(250), nullable=False)
+ state: Mapped[str] = mapped_column(String(20), nullable=False,
default=ConnectionTestState.PENDING)
+ result_message: Mapped[str | None] = mapped_column(String(2000),
nullable=True)
+ created_at: Mapped[datetime] = mapped_column(UtcDateTime,
default=timezone.utcnow, nullable=False)
+ updated_at: Mapped[datetime] = mapped_column(
+ UtcDateTime, default=timezone.utcnow, onupdate=timezone.utcnow,
nullable=False
+ )
+ executor: Mapped[str | None] = mapped_column(String(256), nullable=True)
+ queue: Mapped[str | None] = mapped_column(String(256), nullable=True)
+
+ conn_type: Mapped[str] = mapped_column(String(500), nullable=False)
+ host: Mapped[str | None] = mapped_column(String(500), nullable=True)
+ login: Mapped[str | None] = mapped_column(Text, nullable=True)
+ schema: Mapped[str | None] = mapped_column("schema", String(500),
nullable=True)
+ port: Mapped[int | None] = mapped_column(Integer, nullable=True)
+ commit_on_success: Mapped[bool] = mapped_column(
+ Boolean, nullable=False, default=False, server_default="0"
+ )
+ is_encrypted: Mapped[bool] = mapped_column(
+ Boolean, unique=False, default=False, nullable=False,
server_default="0"
+ )
+ is_extra_encrypted: Mapped[bool] = mapped_column(
+ Boolean, unique=False, default=False, nullable=False,
server_default="0"
+ )
+
+ active_connection_id: Mapped[str | None] = mapped_column(String(250),
nullable=True)
+
+ __table_args__ = (
+ Index("idx_connection_test_request_state_created_at", state,
created_at),
+ UniqueConstraint(
+ "active_connection_id",
+ name="uq_connection_test_request_active_conn",
+ ),
+ )
+
+ def __init__(
+ self,
+ *,
+ connection_id: str,
+ conn_type: str,
+ host: str | None = None,
+ login: str | None = None,
+ password: str | None = None,
+ schema: str | None = None,
+ port: int | None = None,
+ extra: str | None = None,
+ commit_on_success: bool = False,
+ executor: str | None = None,
+ queue: str | None = None,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+ self.connection_id = connection_id
+ self.conn_type = conn_type
+ self.host = host
+ self.login = login
+ self.password = password
+ self.schema = schema
+ self.port = port
+ self.extra = extra
+ self.commit_on_success = commit_on_success
+ self.executor = executor
+ self.queue = queue
+ self.token = secrets.token_urlsafe(32)
+ self.state = ConnectionTestState.PENDING
+
+ @validates("state")
+ def _sync_active_connection_id(self, _key: str, value: str) -> str:
Review Comment:
`@validates("state")` runs on every `state` assignment. That includes the enum
assignments in `__init__` (`self.state = ConnectionTestState.PENDING`), in the
route (`ct.state = body.state`) and in the supervisor flow. The `value`
parameter annotation should be `str | ConnectionTestState` (or just the enum) so
callers, IDEs and type checkers see the real contract.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]