This is an automated email from the ASF dual-hosted git repository.

pierrejeambrun pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 3ca175bb53e AIP 84 - Add auth to connections (#47194)
3ca175bb53e is described below

commit 3ca175bb53efb1d7de518e814de5035376f35081
Author: Kalyan R <[email protected]>
AuthorDate: Mon Mar 3 22:21:39 2025 +0530

    AIP 84 - Add auth to connections (#47194)
    
    * add auth to connections
    
    * fix
    
    * fix failing test
---
 .../api_fastapi/core_api/openapi/v1-generated.yaml | 16 +++++
 .../core_api/routes/public/connections.py          | 19 +++---
 airflow/api_fastapi/core_api/security.py           | 26 ++++++++-
 .../core_api/routes/public/test_connections.py     | 68 ++++++++++++++++++++++
 4 files changed, 120 insertions(+), 9 deletions(-)

diff --git a/airflow/api_fastapi/core_api/openapi/v1-generated.yaml 
b/airflow/api_fastapi/core_api/openapi/v1-generated.yaml
index f3c88ad6618..259efb6cd49 100644
--- a/airflow/api_fastapi/core_api/openapi/v1-generated.yaml
+++ b/airflow/api_fastapi/core_api/openapi/v1-generated.yaml
@@ -1662,6 +1662,8 @@ paths:
       summary: Delete Connection
       description: Delete a connection entry.
       operationId: delete_connection
+      security:
+      - OAuth2PasswordBearer: []
       parameters:
       - name: connection_id
         in: path
@@ -1702,6 +1704,8 @@ paths:
       summary: Get Connection
       description: Get a connection entry.
       operationId: get_connection
+      security:
+      - OAuth2PasswordBearer: []
       parameters:
       - name: connection_id
         in: path
@@ -1746,6 +1750,8 @@ paths:
       summary: Patch Connection
       description: Update a connection entry.
       operationId: patch_connection
+      security:
+      - OAuth2PasswordBearer: []
       parameters:
       - name: connection_id
         in: path
@@ -1813,6 +1819,8 @@ paths:
       summary: Get Connections
       description: Get all connection entries.
       operationId: get_connections
+      security:
+      - OAuth2PasswordBearer: []
       parameters:
       - name: limit
         in: query
@@ -1874,6 +1882,8 @@ paths:
       summary: Post Connection
       description: Create connection entry.
       operationId: post_connection
+      security:
+      - OAuth2PasswordBearer: []
       requestBody:
         required: true
         content:
@@ -1917,6 +1927,8 @@ paths:
       summary: Bulk Connections
       description: Bulk create, update, and delete connections.
       operationId: bulk_connections
+      security:
+      - OAuth2PasswordBearer: []
       requestBody:
         required: true
         content:
@@ -1995,6 +2007,8 @@ paths:
             application/json:
               schema:
                 $ref: '#/components/schemas/HTTPValidationError'
+      security:
+      - OAuth2PasswordBearer: []
   /public/connections/defaults:
     post:
       tags:
@@ -2017,6 +2031,8 @@ paths:
             application/json:
               schema:
                 $ref: '#/components/schemas/HTTPExceptionResponse'
+      security:
+      - OAuth2PasswordBearer: []
   /public/dags/{dag_id}/dagRuns/{dag_run_id}:
     get:
       tags:
diff --git a/airflow/api_fastapi/core_api/routes/public/connections.py 
b/airflow/api_fastapi/core_api/routes/public/connections.py
index 1ecf55665cf..351f3175263 100644
--- a/airflow/api_fastapi/core_api/routes/public/connections.py
+++ b/airflow/api_fastapi/core_api/routes/public/connections.py
@@ -38,6 +38,7 @@ from airflow.api_fastapi.core_api.datamodels.connections 
import (
     ConnectionTestResponse,
 )
 from airflow.api_fastapi.core_api.openapi.exceptions import 
create_openapi_http_exception_doc
+from airflow.api_fastapi.core_api.security import requires_access_connection
 from airflow.api_fastapi.core_api.services.public.connections import 
BulkConnectionService
 from airflow.api_fastapi.logging.decorators import action_logging
 from airflow.configuration import conf
@@ -53,7 +54,7 @@ connections_router = AirflowRouter(tags=["Connection"], 
prefix="/connections")
     "/{connection_id}",
     status_code=status.HTTP_204_NO_CONTENT,
     responses=create_openapi_http_exception_doc([status.HTTP_404_NOT_FOUND]),
-    dependencies=[Depends(action_logging())],
+    dependencies=[Depends(requires_access_connection(method="DELETE")), 
Depends(action_logging())],
 )
 def delete_connection(
     connection_id: str,
@@ -73,6 +74,7 @@ def delete_connection(
 @connections_router.get(
     "/{connection_id}",
     responses=create_openapi_http_exception_doc([status.HTTP_404_NOT_FOUND]),
+    dependencies=[Depends(requires_access_connection(method="GET"))],
 )
 def get_connection(
     connection_id: str,
@@ -92,6 +94,7 @@ def get_connection(
 @connections_router.get(
     "",
     responses=create_openapi_http_exception_doc([status.HTTP_404_NOT_FOUND]),
+    dependencies=[Depends(requires_access_connection(method="GET"))],
 )
 def get_connections(
     limit: QueryLimit,
@@ -131,7 +134,7 @@ def get_connections(
     responses=create_openapi_http_exception_doc(
         [status.HTTP_409_CONFLICT]
     ),  # handled by global exception handler
-    dependencies=[Depends(action_logging())],
+    dependencies=[Depends(requires_access_connection(method="POST")), 
Depends(action_logging())],
 )
 def post_connection(
     post_body: ConnectionBody,
@@ -143,7 +146,9 @@ def post_connection(
     return connection
 
 
-@connections_router.patch("", dependencies=[Depends(action_logging())])
+@connections_router.patch(
+    "", dependencies=[Depends(requires_access_connection(method="PUT")), 
Depends(action_logging())]
+)
 def bulk_connections(
     request: BulkBody[ConnectionBody],
     session: SessionDep,
@@ -160,7 +165,7 @@ def bulk_connections(
             status.HTTP_404_NOT_FOUND,
         ]
     ),
-    dependencies=[Depends(action_logging())],
+    dependencies=[Depends(requires_access_connection(method="PUT")), 
Depends(action_logging())],
 )
 def patch_connection(
     connection_id: str,
@@ -201,9 +206,7 @@ def patch_connection(
     return connection
 
 
-@connections_router.post(
-    "/test",
-)
+@connections_router.post("/test", 
dependencies=[Depends(requires_access_connection(method="POST"))])
 def test_connection(
     test_body: ConnectionBody,
 ) -> ConnectionTestResponse:
@@ -237,7 +240,7 @@ def test_connection(
 @connections_router.post(
     "/defaults",
     status_code=status.HTTP_204_NO_CONTENT,
-    dependencies=[Depends(action_logging())],
+    dependencies=[Depends(requires_access_connection(method="POST")), 
Depends(action_logging())],
 )
 def create_default_connections(
     session: SessionDep,
diff --git a/airflow/api_fastapi/core_api/security.py 
b/airflow/api_fastapi/core_api/security.py
index 6a073c1c61f..27a91eb2c49 100644
--- a/airflow/api_fastapi/core_api/security.py
+++ b/airflow/api_fastapi/core_api/security.py
@@ -25,7 +25,12 @@ from jwt import InvalidTokenError
 
 from airflow.api_fastapi.app import get_auth_manager
 from airflow.auth.managers.models.base_user import BaseUser
-from airflow.auth.managers.models.resource_details import DagAccessEntity, 
DagDetails, PoolDetails
+from airflow.auth.managers.models.resource_details import (
+    ConnectionDetails,
+    DagAccessEntity,
+    DagDetails,
+    PoolDetails,
+)
 from airflow.configuration import conf
 from airflow.utils.jwt_signer import JWTSigner, get_signing_key
 
@@ -101,6 +106,34 @@ def requires_access_pool(method: ResourceMethod) -> Callable:
     return inner
 
 
+def requires_access_connection(method: ResourceMethod) -> Callable:
+    """
+    Build a FastAPI dependency that authorizes access to connection resources.
+
+    :param method: the resource method (e.g. ``"GET"``, ``"POST"``) checked by the auth manager.
+    :return: a dependency callable meant to be wrapped in ``Depends(...)`` on a route.
+    """
+
+    def inner(
+        request: Request,
+        user: Annotated[BaseUser | None, Depends(get_user)] = None,
+    ) -> None:
+        # ``connection_id`` only exists on item routes such as ``/{connection_id}``;
+        # on collection routes ``.get`` returns ``None``.
+        connection_id = request.path_params.get("connection_id")
+
+        def callback():
+            return get_auth_manager().is_authorized_connection(
+                method=method, details=ConnectionDetails(conn_id=connection_id), user=user
+            )
+
+        _requires_access(
+            is_authorized_callback=callback,
+        )
+
+    return inner
+
+
 def _requires_access(
     *,
     is_authorized_callback: Callable[[], bool],
diff --git a/tests/api_fastapi/core_api/routes/public/test_connections.py 
b/tests/api_fastapi/core_api/routes/public/test_connections.py
index a4ecdef94da..1df47350ce9 100644
--- a/tests/api_fastapi/core_api/routes/public/test_connections.py
+++ b/tests/api_fastapi/core_api/routes/public/test_connections.py
@@ -104,6 +104,14 @@ class TestDeleteConnection(TestConnectionEndpoint):
         assert len(connection) == 0
         _check_last_log(session, dag_id=None, event="delete_connection", 
logical_date=None)
 
+    def test_should_respond_401(self, unauthenticated_test_client):
+        response = 
unauthenticated_test_client.delete(f"/public/connections/{TEST_CONN_ID}")
+        assert response.status_code == 401
+
+    def test_should_respond_403(self, unauthorized_test_client):
+        response = 
unauthorized_test_client.delete(f"/public/connections/{TEST_CONN_ID}")
+        assert response.status_code == 403
+
     def test_delete_should_respond_404(self, test_client):
         response = test_client.delete(f"/public/connections/{TEST_CONN_ID}")
         assert response.status_code == 404
@@ -120,6 +128,14 @@ class TestGetConnection(TestConnectionEndpoint):
         assert body["connection_id"] == TEST_CONN_ID
         assert body["conn_type"] == TEST_CONN_TYPE
 
+    def test_should_respond_401(self, unauthenticated_test_client):
+        response = 
unauthenticated_test_client.get(f"/public/connections/{TEST_CONN_ID}")
+        assert response.status_code == 401
+
+    def test_should_respond_403(self, unauthorized_test_client):
+        response = 
unauthorized_test_client.get(f"/public/connections/{TEST_CONN_ID}")
+        assert response.status_code == 403
+
     def test_get_should_respond_404(self, test_client):
         response = test_client.get(f"/public/connections/{TEST_CONN_ID}")
         assert response.status_code == 404
@@ -185,6 +201,14 @@ class TestGetConnections(TestConnectionEndpoint):
         assert body["total_entries"] == expected_total_entries
         assert [connection["connection_id"] for connection in 
body["connections"]] == expected_ids
 
+    def test_should_respond_401(self, unauthenticated_test_client):
+        response = unauthenticated_test_client.get("/public/connections", 
params={})
+        assert response.status_code == 401
+
+    def test_should_respond_403(self, unauthorized_test_client):
+        response = unauthorized_test_client.get("/public/connections", 
params={})
+        assert response.status_code == 403
+
 
 class TestPostConnection(TestConnectionEndpoint):
     @pytest.mark.parametrize(
@@ -213,6 +237,14 @@ class TestPostConnection(TestConnectionEndpoint):
         assert len(connection) == 1
         _check_last_log(session, dag_id=None, event="post_connection", 
logical_date=None)
 
+    def test_should_respond_401(self, unauthenticated_test_client):
+        response = unauthenticated_test_client.post("/public/connections", 
json={})
+        assert response.status_code == 401
+
+    def test_should_respond_403(self, unauthorized_test_client):
+        response = unauthorized_test_client.post("/public/connections", 
json={})
+        assert response.status_code == 403
+
     @pytest.mark.parametrize(
         "body",
         [
@@ -344,6 +376,14 @@ class TestPatchConnection(TestConnectionEndpoint):
         assert response.status_code == 200
         _check_last_log(session, dag_id=None, event="patch_connection", 
logical_date=None)
 
+    def test_should_respond_401(self, unauthenticated_test_client):
+        response = 
unauthenticated_test_client.patch(f"/public/connections/{TEST_CONN_ID}", 
json={})
+        assert response.status_code == 401
+
+    def test_should_respond_403(self, unauthorized_test_client):
+        response = 
unauthorized_test_client.patch(f"/public/connections/{TEST_CONN_ID}", json={})
+        assert response.status_code == 403
+
     @pytest.mark.parametrize(
         "body, updated_connection, update_mask",
         [
@@ -603,6 +643,18 @@ class TestConnection(TestConnectionEndpoint):
             "message": "Connection successfully tested",
         }
 
+    def test_should_respond_401(self, unauthenticated_test_client):
+        response = unauthenticated_test_client.post(
+            "/public/connections/test", json={"connection_id": TEST_CONN_ID, 
"conn_type": "sqlite"}
+        )
+        assert response.status_code == 401
+
+    def test_should_respond_403(self, unauthorized_test_client):
+        response = unauthorized_test_client.post(
+            "/public/connections/test", json={"connection_id": TEST_CONN_ID, 
"conn_type": "sqlite"}
+        )
+        assert response.status_code == 403
+
     @mock.patch.dict(os.environ, {"AIRFLOW__CORE__TEST_CONNECTION": "Enabled"})
     @pytest.mark.parametrize(
         "body",
@@ -638,6 +690,14 @@ class TestCreateDefaultConnections(TestConnectionEndpoint):
         assert response.content == b""
         _check_last_log(session, dag_id=None, 
event="create_default_connections", logical_date=None)
 
+    def test_should_respond_401(self, unauthenticated_test_client):
+        response = 
unauthenticated_test_client.post("/public/connections/defaults")
+        assert response.status_code == 401
+
+    def test_should_respond_403(self, unauthorized_test_client):
+        response = 
unauthorized_test_client.post("/public/connections/defaults")
+        assert response.status_code == 403
+
     
@mock.patch("airflow.api_fastapi.core_api.routes.public.connections.db_create_default_connections")
     def test_should_call_db_create_default_connections(self, 
mock_db_create_default_connections, test_client):
         response = test_client.post("/public/connections/defaults")
@@ -944,3 +1004,11 @@ class TestBulkConnections(TestConnectionEndpoint):
         for connection_id, value in expected_results.items():
             assert response_data[connection_id] == value
         _check_last_log(session, dag_id=None, event="bulk_connections", 
logical_date=None)
+
+    def test_should_respond_401(self, unauthenticated_test_client):
+        response = unauthenticated_test_client.patch("/public/connections", 
json={})
+        assert response.status_code == 401
+
+    def test_should_respond_403(self, unauthorized_test_client):
+        response = unauthorized_test_client.patch("/public/connections", 
json={})
+        assert response.status_code == 403

Reply via email to