kaxil commented on code in PR #62343: URL: https://github.com/apache/airflow/pull/62343#discussion_r2963468962
########## airflow-core/src/airflow/models/connection_test.py: ########## @@ -0,0 +1,216 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import secrets +from datetime import datetime +from enum import Enum +from typing import TYPE_CHECKING +from uuid import UUID + +import structlog +import uuid6 +from sqlalchemy import JSON, Boolean, Index, String, Text, Uuid, select +from sqlalchemy.orm import Mapped, mapped_column + +from airflow._shared.timezones import timezone +from airflow.models.base import Base +from airflow.models.connection import Connection +from airflow.models.crypto import get_fernet +from airflow.utils.sqlalchemy import UtcDateTime + +if TYPE_CHECKING: + from sqlalchemy.orm import Session + +log = structlog.get_logger(__name__) + + +class ConnectionTestState(str, Enum): + """All possible states of a connection test.""" + + PENDING = "pending" + QUEUED = "queued" + RUNNING = "running" + SUCCESS = "success" + FAILED = "failed" + + def __str__(self) -> str: + return self.value + + +ACTIVE_STATES = frozenset((ConnectionTestState.QUEUED, ConnectionTestState.RUNNING)) +TERMINAL_STATES = frozenset((ConnectionTestState.SUCCESS, ConnectionTestState.FAILED)) + + +class 
ConnectionTest(Base): + """Tracks an async connection test dispatched to a worker via a TestConnection workload.""" + + __tablename__ = "connection_test" + + id: Mapped[UUID] = mapped_column(Uuid(), primary_key=True, default=uuid6.uuid7) + token: Mapped[str] = mapped_column(String(64), nullable=False, unique=True) + connection_id: Mapped[str] = mapped_column(String(250), nullable=False) + state: Mapped[str] = mapped_column(String(10), nullable=False, default=ConnectionTestState.PENDING) + result_message: Mapped[str | None] = mapped_column(Text, nullable=True) + created_at: Mapped[datetime] = mapped_column(UtcDateTime, default=timezone.utcnow, nullable=False) + updated_at: Mapped[datetime] = mapped_column( + UtcDateTime, default=timezone.utcnow, onupdate=timezone.utcnow, nullable=False + ) + executor: Mapped[str | None] = mapped_column(String(256), nullable=True) + queue: Mapped[str | None] = mapped_column(String(256), nullable=True) + connection_snapshot: Mapped[dict | None] = mapped_column(JSON(none_as_null=True), nullable=True) + reverted: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False, server_default="0") + + __table_args__ = (Index("idx_connection_test_state_created_at", state, created_at),) + + def __init__( + self, *, connection_id: str, executor: str | None = None, queue: str | None = None, **kwargs + ): + super().__init__(**kwargs) + self.connection_id = connection_id + self.executor = executor + self.queue = queue + self.token = secrets.token_urlsafe(32) + self.state = ConnectionTestState.PENDING + + def __repr__(self) -> str: + return f"<ConnectionTest id={self.id!r} connection_id={self.connection_id!r} state={self.state}>" + + def get_executor_name(self) -> str | None: + """Return the executor name for scheduler routing.""" + return self.executor + + def get_dag_id(self) -> None: + """Return None — connection tests are not associated with any DAG.""" + return None + + +def run_connection_test(*, conn: Connection) -> tuple[bool, str]: 
+ """ + Worker-side function to execute a connection test. + + Returns a (success, message) tuple. The caller is responsible for + reporting the result back via the Execution API. + """ + try: + return conn.test_connection() + except Exception as e: + log.exception("Connection test failed", connection_id=conn.conn_id) + return False, str(e) + + +_SNAPSHOT_FIELDS = ( + "conn_type", + "description", + "host", + "login", + "_password", + "schema", + "port", + "_extra", + "is_encrypted", + "is_extra_encrypted", +) + + +def snapshot_connection(conn: Connection) -> dict: Review Comment: **[Critical] Secrets in `connection_snapshot` JSON column** `snapshot_connection()` stores Fernet-encrypted `_password` and `_extra` ciphertext in a JSON column. This expands the attack surface — anyone with read access to the `connection_test` table gets credential ciphertext that was previously confined to the `connection` table. Two additional problems: 1. If Fernet keys rotate, stored snapshots become undecryptable and reverts will write garbage. 2. DB-level audit/masking rules on `connection._password` won't cover this secondary storage. If the snapshot approach survives (vs. Pierre's request-buffer redesign), consider storing only non-secret fields and re-fetching credentials at revert time, or encrypting the entire snapshot blob with a short-lived key. ########## airflow-core/src/airflow/models/connection_test.py: ########## @@ -0,0 +1,216 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import secrets +from datetime import datetime +from enum import Enum +from typing import TYPE_CHECKING +from uuid import UUID + +import structlog +import uuid6 +from sqlalchemy import JSON, Boolean, Index, String, Text, Uuid, select +from sqlalchemy.orm import Mapped, mapped_column + +from airflow._shared.timezones import timezone +from airflow.models.base import Base +from airflow.models.connection import Connection +from airflow.models.crypto import get_fernet +from airflow.utils.sqlalchemy import UtcDateTime + +if TYPE_CHECKING: + from sqlalchemy.orm import Session + +log = structlog.get_logger(__name__) + + +class ConnectionTestState(str, Enum): + """All possible states of a connection test.""" + + PENDING = "pending" + QUEUED = "queued" + RUNNING = "running" + SUCCESS = "success" + FAILED = "failed" + + def __str__(self) -> str: + return self.value + + +ACTIVE_STATES = frozenset((ConnectionTestState.QUEUED, ConnectionTestState.RUNNING)) Review Comment: **[High] PENDING tests never get reaped** `ACTIVE_STATES` includes QUEUED and RUNNING but not PENDING. The reaper (`_reap_stale_connection_tests`) only checks `ACTIVE_STATES`, so tests stuck in PENDING — e.g., scheduler restart between creation and dispatch, or no supporting executor found — accumulate forever. If those PENDING tests have snapshots (from save-and-test), the connection stays in its post-edit state with no resolution path. 
```suggestion ACTIVE_STATES = frozenset((ConnectionTestState.PENDING, ConnectionTestState.QUEUED, ConnectionTestState.RUNNING)) ``` ########## airflow-core/src/airflow/models/connection_test.py: ########## @@ -0,0 +1,216 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations + +import secrets +from datetime import datetime +from enum import Enum +from typing import TYPE_CHECKING +from uuid import UUID + +import structlog +import uuid6 +from sqlalchemy import JSON, Boolean, Index, String, Text, Uuid, select +from sqlalchemy.orm import Mapped, mapped_column + +from airflow._shared.timezones import timezone +from airflow.models.base import Base +from airflow.models.connection import Connection +from airflow.models.crypto import get_fernet +from airflow.utils.sqlalchemy import UtcDateTime + +if TYPE_CHECKING: + from sqlalchemy.orm import Session + +log = structlog.get_logger(__name__) + + +class ConnectionTestState(str, Enum): + """All possible states of a connection test.""" + + PENDING = "pending" + QUEUED = "queued" + RUNNING = "running" + SUCCESS = "success" + FAILED = "failed" + + def __str__(self) -> str: + return self.value + + +ACTIVE_STATES = frozenset((ConnectionTestState.QUEUED, ConnectionTestState.RUNNING)) +TERMINAL_STATES = frozenset((ConnectionTestState.SUCCESS, ConnectionTestState.FAILED)) + + +class ConnectionTest(Base): + """Tracks an async connection test dispatched to a worker via a TestConnection workload.""" + + __tablename__ = "connection_test" + + id: Mapped[UUID] = mapped_column(Uuid(), primary_key=True, default=uuid6.uuid7) + token: Mapped[str] = mapped_column(String(64), nullable=False, unique=True) + connection_id: Mapped[str] = mapped_column(String(250), nullable=False) + state: Mapped[str] = mapped_column(String(10), nullable=False, default=ConnectionTestState.PENDING) + result_message: Mapped[str | None] = mapped_column(Text, nullable=True) + created_at: Mapped[datetime] = mapped_column(UtcDateTime, default=timezone.utcnow, nullable=False) + updated_at: Mapped[datetime] = mapped_column( + UtcDateTime, default=timezone.utcnow, onupdate=timezone.utcnow, nullable=False + ) + executor: Mapped[str | None] = mapped_column(String(256), nullable=True) + queue: Mapped[str | None] = 
mapped_column(String(256), nullable=True) + connection_snapshot: Mapped[dict | None] = mapped_column(JSON(none_as_null=True), nullable=True) + reverted: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False, server_default="0") + + __table_args__ = (Index("idx_connection_test_state_created_at", state, created_at),) + + def __init__( + self, *, connection_id: str, executor: str | None = None, queue: str | None = None, **kwargs + ): + super().__init__(**kwargs) + self.connection_id = connection_id + self.executor = executor + self.queue = queue + self.token = secrets.token_urlsafe(32) + self.state = ConnectionTestState.PENDING + + def __repr__(self) -> str: + return f"<ConnectionTest id={self.id!r} connection_id={self.connection_id!r} state={self.state}>" + + def get_executor_name(self) -> str | None: + """Return the executor name for scheduler routing.""" + return self.executor + + def get_dag_id(self) -> None: + """Return None — connection tests are not associated with any DAG.""" + return None + + +def run_connection_test(*, conn: Connection) -> tuple[bool, str]: + """ + Worker-side function to execute a connection test. + + Returns a (success, message) tuple. The caller is responsible for + reporting the result back via the Execution API. + """ + try: + return conn.test_connection() + except Exception as e: + log.exception("Connection test failed", connection_id=conn.conn_id) + return False, str(e) + + +_SNAPSHOT_FIELDS = ( + "conn_type", + "description", + "host", + "login", + "_password", + "schema", + "port", + "_extra", + "is_encrypted", + "is_extra_encrypted", +) + + +def snapshot_connection(conn: Connection) -> dict: + """ + Capture raw DB column values from a Connection for later restore. + + Encrypted fields (``_password``, ``_extra``) are stored as ciphertext + so they can be written directly back without re-encryption. 
+ """ + return {field: getattr(conn, field) for field in _SNAPSHOT_FIELDS} + + +def _revert_connection(conn: Connection, snapshot: dict) -> None: + """ + Restore a Connection's columns from a snapshot dict. + + Writes directly to ``_password`` and ``_extra`` (bypassing the + encrypting property setters) so the stored ciphertext is preserved. + """ + for field, value in snapshot.items(): + setattr(conn, field, value) + + +def _decrypt_snapshot_field(snapshot: dict, field: str) -> str | None: + """Decrypt a single encrypted field from a snapshot dict using Fernet.""" + raw = snapshot.get(field) + if raw is None: + return None + encrypted_flag = "is_encrypted" if field == "_password" else "is_extra_encrypted" + if not snapshot.get(encrypted_flag, False): + return raw + fernet = get_fernet() + return fernet.decrypt(bytes(raw, "utf-8")).decode() + + +def _can_safely_revert(conn: Connection, post_snapshot: dict) -> bool: + """ + Check whether the connection's current state matches the post-edit snapshot. + + Compares **decrypted** values for encrypted fields and direct values for + non-encrypted fields. Returns ``False`` if any field differs, indicating + a concurrent edit has occurred and the revert should be skipped. 
+ """ + for field in _SNAPSHOT_FIELDS: + if field in ("is_encrypted", "is_extra_encrypted"): + continue + + if field == "_password": + current_val = conn.password + snapshot_val = _decrypt_snapshot_field(post_snapshot, "_password") + elif field == "_extra": + current_val = conn.extra + snapshot_val = _decrypt_snapshot_field(post_snapshot, "_extra") + else: + current_val = getattr(conn, field) + snapshot_val = post_snapshot.get(field) + + if current_val != snapshot_val: + return False + return True + + +def attempt_revert(ct: ConnectionTest, *, session: Session) -> None: + """Revert a connection to its pre-edit values if no concurrent edit has occurred.""" + if not ct.connection_snapshot: + log.warning("attempt_revert called without snapshot", connection_test_id=ct.id) + return + + pre_snapshot = ct.connection_snapshot["pre"] + post_snapshot = ct.connection_snapshot["post"] + + ct.connection_snapshot = None + + conn = session.scalar(select(Connection).filter_by(conn_id=ct.connection_id)) + if conn is None: + ct.result_message = (ct.result_message or "") + " | Revert skipped: connection no longer exists." + log.warning("Revert skipped: connection no longer exists", connection_id=ct.connection_id) + return + + if not _can_safely_revert(conn, post_snapshot): + ct.result_message = ( + ct.result_message or "" Review Comment: **[High] `attempt_revert` doesn't lock the Connection row** This `select(Connection)` reads without `with_for_update()`. A concurrent edit can land between the `_can_safely_revert` check and `_revert_connection`, causing the revert to silently overwrite the concurrent change — the exact scenario `_can_safely_revert` is meant to prevent. ```python conn = session.scalar( select(Connection).filter_by(conn_id=ct.connection_id).with_for_update() ) ``` ########## airflow-core/src/airflow/models/connection_test.py: ########## @@ -0,0 +1,216 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import secrets +from datetime import datetime +from enum import Enum +from typing import TYPE_CHECKING +from uuid import UUID + +import structlog +import uuid6 +from sqlalchemy import JSON, Boolean, Index, String, Text, Uuid, select +from sqlalchemy.orm import Mapped, mapped_column + +from airflow._shared.timezones import timezone +from airflow.models.base import Base +from airflow.models.connection import Connection +from airflow.models.crypto import get_fernet +from airflow.utils.sqlalchemy import UtcDateTime + +if TYPE_CHECKING: + from sqlalchemy.orm import Session + +log = structlog.get_logger(__name__) + + +class ConnectionTestState(str, Enum): + """All possible states of a connection test.""" + + PENDING = "pending" + QUEUED = "queued" + RUNNING = "running" + SUCCESS = "success" + FAILED = "failed" + + def __str__(self) -> str: + return self.value + + +ACTIVE_STATES = frozenset((ConnectionTestState.QUEUED, ConnectionTestState.RUNNING)) +TERMINAL_STATES = frozenset((ConnectionTestState.SUCCESS, ConnectionTestState.FAILED)) + + +class ConnectionTest(Base): + """Tracks an async connection test dispatched to a worker via a TestConnection workload.""" + + __tablename__ = "connection_test" + + id: Mapped[UUID] = mapped_column(Uuid(), 
primary_key=True, default=uuid6.uuid7) + token: Mapped[str] = mapped_column(String(64), nullable=False, unique=True) + connection_id: Mapped[str] = mapped_column(String(250), nullable=False) + state: Mapped[str] = mapped_column(String(10), nullable=False, default=ConnectionTestState.PENDING) + result_message: Mapped[str | None] = mapped_column(Text, nullable=True) + created_at: Mapped[datetime] = mapped_column(UtcDateTime, default=timezone.utcnow, nullable=False) + updated_at: Mapped[datetime] = mapped_column( + UtcDateTime, default=timezone.utcnow, onupdate=timezone.utcnow, nullable=False + ) + executor: Mapped[str | None] = mapped_column(String(256), nullable=True) + queue: Mapped[str | None] = mapped_column(String(256), nullable=True) + connection_snapshot: Mapped[dict | None] = mapped_column(JSON(none_as_null=True), nullable=True) + reverted: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False, server_default="0") + + __table_args__ = (Index("idx_connection_test_state_created_at", state, created_at),) + + def __init__( + self, *, connection_id: str, executor: str | None = None, queue: str | None = None, **kwargs + ): + super().__init__(**kwargs) + self.connection_id = connection_id + self.executor = executor + self.queue = queue + self.token = secrets.token_urlsafe(32) + self.state = ConnectionTestState.PENDING + + def __repr__(self) -> str: + return f"<ConnectionTest id={self.id!r} connection_id={self.connection_id!r} state={self.state}>" + + def get_executor_name(self) -> str | None: + """Return the executor name for scheduler routing.""" + return self.executor + + def get_dag_id(self) -> None: + """Return None — connection tests are not associated with any DAG.""" + return None + + +def run_connection_test(*, conn: Connection) -> tuple[bool, str]: + """ + Worker-side function to execute a connection test. + + Returns a (success, message) tuple. The caller is responsible for + reporting the result back via the Execution API. 
+ """ + try: + return conn.test_connection() + except Exception as e: + log.exception("Connection test failed", connection_id=conn.conn_id) + return False, str(e) + + +_SNAPSHOT_FIELDS = ( + "conn_type", + "description", + "host", + "login", + "_password", + "schema", + "port", + "_extra", + "is_encrypted", + "is_extra_encrypted", +) + + +def snapshot_connection(conn: Connection) -> dict: + """ + Capture raw DB column values from a Connection for later restore. + + Encrypted fields (``_password``, ``_extra``) are stored as ciphertext + so they can be written directly back without re-encryption. + """ + return {field: getattr(conn, field) for field in _SNAPSHOT_FIELDS} + + +def _revert_connection(conn: Connection, snapshot: dict) -> None: + """ + Restore a Connection's columns from a snapshot dict. + + Writes directly to ``_password`` and ``_extra`` (bypassing the + encrypting property setters) so the stored ciphertext is preserved. + """ + for field, value in snapshot.items(): + setattr(conn, field, value) + + +def _decrypt_snapshot_field(snapshot: dict, field: str) -> str | None: + """Decrypt a single encrypted field from a snapshot dict using Fernet.""" + raw = snapshot.get(field) + if raw is None: + return None + encrypted_flag = "is_encrypted" if field == "_password" else "is_extra_encrypted" + if not snapshot.get(encrypted_flag, False): + return raw + fernet = get_fernet() + return fernet.decrypt(bytes(raw, "utf-8")).decode() + + +def _can_safely_revert(conn: Connection, post_snapshot: dict) -> bool: + """ + Check whether the connection's current state matches the post-edit snapshot. + + Compares **decrypted** values for encrypted fields and direct values for + non-encrypted fields. Returns ``False`` if any field differs, indicating + a concurrent edit has occurred and the revert should be skipped. 
+ """ + for field in _SNAPSHOT_FIELDS: + if field in ("is_encrypted", "is_extra_encrypted"): + continue + + if field == "_password": + current_val = conn.password + snapshot_val = _decrypt_snapshot_field(post_snapshot, "_password") + elif field == "_extra": + current_val = conn.extra + snapshot_val = _decrypt_snapshot_field(post_snapshot, "_extra") + else: + current_val = getattr(conn, field) + snapshot_val = post_snapshot.get(field) + + if current_val != snapshot_val: + return False + return True + + +def attempt_revert(ct: ConnectionTest, *, session: Session) -> None: + """Revert a connection to its pre-edit values if no concurrent edit has occurred.""" + if not ct.connection_snapshot: + log.warning("attempt_revert called without snapshot", connection_test_id=ct.id) + return + + pre_snapshot = ct.connection_snapshot["pre"] + post_snapshot = ct.connection_snapshot["post"] + + ct.connection_snapshot = None + + conn = session.scalar(select(Connection).filter_by(conn_id=ct.connection_id)) + if conn is None: + ct.result_message = (ct.result_message or "") + " | Revert skipped: connection no longer exists." + log.warning("Revert skipped: connection no longer exists", connection_id=ct.connection_id) + return + + if not _can_safely_revert(conn, post_snapshot): Review Comment: **[High] Snapshot cleared before revert completes** `ct.connection_snapshot = None` is set here, before the actual revert logic runs. If the subsequent `select(Connection)` query or `_revert_connection` call fails (DB hiccup, connection dropped), the snapshot is lost and the revert can never be retried. Move `ct.connection_snapshot = None` to the end of the function, after the revert succeeds or is definitively skipped. 
########## airflow-core/src/airflow/executors/local_executor.py: ########## @@ -168,6 +177,84 @@ def _execute_callback(log: Logger, workload: workloads.ExecuteCallback, team_con raise RuntimeError(error_msg or "Callback execution failed") +def _execute_connection_test(log: Logger, workload: workloads.TestConnection, team_conf) -> None: + """ + Execute a connection test workload. + + Constructs an SDK ``Client``, fetches the connection via the Execution API, + enforces a timeout via ``signal.alarm``, and reports all outcomes back + through the Execution API. + + :param log: Logger instance + :param workload: The TestConnection workload to execute + :param team_conf: Team-specific executor configuration + """ + # Lazy import: SDK modules must not be loaded at module level to avoid + # coupling core (scheduler-loaded) code to the SDK. + from airflow.sdk.api.client import Client + from airflow.sdk.execution_time.comms import ErrorResponse + + setproctitle( + f"{_get_executor_process_title_prefix(team_conf.team_name)} connection-test {workload.connection_id}", + log, + ) + + base_url = team_conf.get("api", "base_url", fallback="/") + if base_url.startswith("/"): + base_url = f"http://localhost:8080{base_url}" + default_execution_api_server = f"{base_url.rstrip('/')}/execution/" + server = team_conf.get("core", "execution_api_server_url", fallback=default_execution_api_server) + + client = Client(base_url=server, token=workload.token) + + def _handle_timeout(signum, frame): + raise TimeoutError(f"Connection test timed out after {workload.timeout}s") + + signal.signal(signal.SIGALRM, _handle_timeout) + signal.alarm(workload.timeout) + try: + client.connection_tests.update_state(workload.connection_test_id, ConnectionTestState.RUNNING) + + conn_response = client.connections.get(workload.connection_id) + if isinstance(conn_response, ErrorResponse): + raise RuntimeError(f"Connection '{workload.connection_id}' not found via Execution API") + + conn = Connection( + 
conn_id=conn_response.conn_id, + conn_type=conn_response.conn_type, + host=conn_response.host, + login=conn_response.login, + password=conn_response.password, + schema=conn_response.schema_, + port=conn_response.port, + extra=conn_response.extra, + ) + success, message = run_connection_test(conn=conn) + + state = ConnectionTestState.SUCCESS if success else ConnectionTestState.FAILED + client.connection_tests.update_state(workload.connection_test_id, state, message) + except TimeoutError: + log.error( + "Connection test timed out after %ds", + workload.timeout, + connection_id=workload.connection_id, + ) + client.connection_tests.update_state( + workload.connection_test_id, + ConnectionTestState.FAILED, + f"Connection test timed out after {workload.timeout}s", + ) + except Exception as e: + log.exception("Connection test failed unexpectedly", connection_id=workload.connection_id) + client.connection_tests.update_state( + workload.connection_test_id, + ConnectionTestState.FAILED, Review Comment: **[Low] Exception messages in `result_message` may leak connection details** `f"Connection test failed unexpectedly: {e}"[:500]` stores raw exception text. Some DB drivers include connection strings (with credentials) in error messages. This gets stored in `result_message` and exposed via the polling API. 
Consider sanitizing or using a generic message with the exception type only: ```python f"Connection test failed unexpectedly: {type(e).__name__}" ``` ########## airflow-core/src/airflow/executors/base_executor.py: ########## @@ -305,10 +315,19 @@ def heartbeat(self) -> None: self._emit_metrics(open_slots, num_running_workloads, num_queued_workloads) self.trigger_tasks(open_slots) + self.trigger_connection_tests() + # Calling child class sync method self.log.debug("Calling the %s sync method", self.__class__) self.sync() + def trigger_connection_tests(self) -> None: + """Process queued connection tests.""" + if not self.supports_connection_test or not self.queued_connection_tests: + return + + self._process_workloads(list(self.queued_connection_tests.values())) Review Comment: **[Medium] `trigger_connection_tests` dispatches all queued tests without rechecking slots** This is called from `heartbeat()` after `trigger_tasks(open_slots)` has already consumed slots. `trigger_connection_tests` then dispatches *all* queued connection tests via `_process_workloads` without checking remaining capacity. The scheduler's `_enqueue_connection_tests` does slot accounting, but slots may have been consumed between enqueue time and executor heartbeat. With `max_connection_test_concurrency=4` the overshoot is bounded, but it'd be cleaner to pass remaining slots through or check `self.slots_available` here. ########## airflow-core/src/airflow/migrations/versions/0110_3_2_0_add_connection_test_table.py: ########## @@ -0,0 +1,71 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Add connection_test table for async connection testing. + +Revision ID: a7e6d4c3b2f1 +Revises: 1d6611b6ab7c +Create Date: 2026-02-22 00:00:00.000000 + +""" + +from __future__ import annotations + +import sqlalchemy as sa +from alembic import op + +from airflow.utils.sqlalchemy import UtcDateTime + +# revision identifiers, used by Alembic. +revision = "a7e6d4c3b2f1" +down_revision = "1d6611b6ab7c" +branch_labels = None +depends_on = None +airflow_version = "3.2.0" + + +def upgrade(): + """Create connection_test table.""" + op.create_table( + "connection_test", + sa.Column("id", sa.Uuid(), nullable=False), + sa.Column("token", sa.String(64), nullable=False), + sa.Column("connection_id", sa.String(250), nullable=False), + sa.Column("state", sa.String(10), nullable=False), + sa.Column("result_message", sa.Text(), nullable=True), + sa.Column("created_at", UtcDateTime(timezone=True), nullable=False), + sa.Column("updated_at", UtcDateTime(timezone=True), nullable=False), Review Comment: **[Medium] `state` column `String(10)` is tight** The longest current values are `"pending"`, `"running"`, and `"success"` (7 chars each). `String(10)` leaves almost no room — plausible future states like `"timed_out"` (9) or `"cancelled"` (9) would only just fit, and anything longer (e.g. `"upstream_failed"`, 15 chars) would need a new migration. Use `String(20)` for breathing room.
########## airflow-core/src/airflow/api_fastapi/core_api/routes/public/connections.py: ########## @@ -174,6 +189,70 @@ def bulk_connections( return BulkConnectionService(session=session, request=request).handle_request() +@connections_router.patch( + "/{connection_id}/test", + responses=create_openapi_http_exception_doc( + [ + status.HTTP_400_BAD_REQUEST, + status.HTTP_403_FORBIDDEN, + status.HTTP_404_NOT_FOUND, + ] + ), + dependencies=[Depends(requires_access_connection(method="PUT")), Depends(action_logging())], +) +def patch_connection_and_test( + connection_id: str, + patch_body: ConnectionBody, + session: SessionDep, + update_mask: list[str] | None = Query(None), + executor: str | None = Query(None, description="Executor to route the connection test to"), + queue: str | None = Query(None, description="Queue to route the connection test to"), +) -> ConnectionSaveAndTestResponse: + """ + Update a connection and queue an async test with revert-on-failure. + + Atomically saves the edit and creates a ConnectionTest with snapshots of the + pre-edit and post-edit state. If the test fails, the connection is automatically + reverted to its pre-edit values. + """ + _ensure_test_connection_enabled() + + if patch_body.connection_id != connection_id: + raise HTTPException( + status.HTTP_400_BAD_REQUEST, + "The connection_id in the request body does not match the URL parameter", + ) + + connection = session.scalar(select(Connection).filter_by(conn_id=connection_id).limit(1)) + if connection is None: + raise HTTPException( + status.HTTP_404_NOT_FOUND, + f"The Connection with connection_id: `{connection_id}` was not found", + ) + + try: + ConnectionBody(**patch_body.model_dump()) Review Comment: **[Medium] Redundant validation** `patch_body` is already a `ConnectionBody` instance — FastAPI validated it on parameter injection. Re-constructing `ConnectionBody(**patch_body.model_dump())` just validates the same data again, so this `try` block is redundant and can be removed.
-- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
