This is an automated email from the ASF dual-hosted git repository.
pierrejeambrun pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 3ca175bb53e AIP 84 - Add auth to connections (#47194)
3ca175bb53e is described below
commit 3ca175bb53efb1d7de518e814de5035376f35081
Author: Kalyan R <[email protected]>
AuthorDate: Mon Mar 3 22:21:39 2025 +0530
AIP 84 - Add auth to connections (#47194)
* add auth to connections
* fix
* fix failing test
---
.../api_fastapi/core_api/openapi/v1-generated.yaml | 16 +++++
.../core_api/routes/public/connections.py | 19 +++---
airflow/api_fastapi/core_api/security.py | 26 ++++++++-
.../core_api/routes/public/test_connections.py | 68 ++++++++++++++++++++++
4 files changed, 120 insertions(+), 9 deletions(-)
diff --git a/airflow/api_fastapi/core_api/openapi/v1-generated.yaml
b/airflow/api_fastapi/core_api/openapi/v1-generated.yaml
index f3c88ad6618..259efb6cd49 100644
--- a/airflow/api_fastapi/core_api/openapi/v1-generated.yaml
+++ b/airflow/api_fastapi/core_api/openapi/v1-generated.yaml
@@ -1662,6 +1662,8 @@ paths:
summary: Delete Connection
description: Delete a connection entry.
operationId: delete_connection
+ security:
+ - OAuth2PasswordBearer: []
parameters:
- name: connection_id
in: path
@@ -1702,6 +1704,8 @@ paths:
summary: Get Connection
description: Get a connection entry.
operationId: get_connection
+ security:
+ - OAuth2PasswordBearer: []
parameters:
- name: connection_id
in: path
@@ -1746,6 +1750,8 @@ paths:
summary: Patch Connection
description: Update a connection entry.
operationId: patch_connection
+ security:
+ - OAuth2PasswordBearer: []
parameters:
- name: connection_id
in: path
@@ -1813,6 +1819,8 @@ paths:
summary: Get Connections
description: Get all connection entries.
operationId: get_connections
+ security:
+ - OAuth2PasswordBearer: []
parameters:
- name: limit
in: query
@@ -1874,6 +1882,8 @@ paths:
summary: Post Connection
description: Create connection entry.
operationId: post_connection
+ security:
+ - OAuth2PasswordBearer: []
requestBody:
required: true
content:
@@ -1917,6 +1927,8 @@ paths:
summary: Bulk Connections
description: Bulk create, update, and delete connections.
operationId: bulk_connections
+ security:
+ - OAuth2PasswordBearer: []
requestBody:
required: true
content:
@@ -1995,6 +2007,8 @@ paths:
application/json:
schema:
$ref: '#/components/schemas/HTTPValidationError'
+ security:
+ - OAuth2PasswordBearer: []
/public/connections/defaults:
post:
tags:
@@ -2017,6 +2031,8 @@ paths:
application/json:
schema:
$ref: '#/components/schemas/HTTPExceptionResponse'
+ security:
+ - OAuth2PasswordBearer: []
/public/dags/{dag_id}/dagRuns/{dag_run_id}:
get:
tags:
diff --git a/airflow/api_fastapi/core_api/routes/public/connections.py
b/airflow/api_fastapi/core_api/routes/public/connections.py
index 1ecf55665cf..351f3175263 100644
--- a/airflow/api_fastapi/core_api/routes/public/connections.py
+++ b/airflow/api_fastapi/core_api/routes/public/connections.py
@@ -38,6 +38,7 @@ from airflow.api_fastapi.core_api.datamodels.connections
import (
ConnectionTestResponse,
)
from airflow.api_fastapi.core_api.openapi.exceptions import
create_openapi_http_exception_doc
+from airflow.api_fastapi.core_api.security import requires_access_connection
from airflow.api_fastapi.core_api.services.public.connections import
BulkConnectionService
from airflow.api_fastapi.logging.decorators import action_logging
from airflow.configuration import conf
@@ -53,7 +54,7 @@ connections_router = AirflowRouter(tags=["Connection"],
prefix="/connections")
"/{connection_id}",
status_code=status.HTTP_204_NO_CONTENT,
responses=create_openapi_http_exception_doc([status.HTTP_404_NOT_FOUND]),
- dependencies=[Depends(action_logging())],
+ dependencies=[Depends(requires_access_connection(method="DELETE")),
Depends(action_logging())],
)
def delete_connection(
connection_id: str,
@@ -73,6 +74,7 @@ def delete_connection(
@connections_router.get(
"/{connection_id}",
responses=create_openapi_http_exception_doc([status.HTTP_404_NOT_FOUND]),
+ dependencies=[Depends(requires_access_connection(method="GET"))],
)
def get_connection(
connection_id: str,
@@ -92,6 +94,7 @@ def get_connection(
@connections_router.get(
"",
responses=create_openapi_http_exception_doc([status.HTTP_404_NOT_FOUND]),
+ dependencies=[Depends(requires_access_connection(method="GET"))],
)
def get_connections(
limit: QueryLimit,
@@ -131,7 +134,7 @@ def get_connections(
responses=create_openapi_http_exception_doc(
[status.HTTP_409_CONFLICT]
), # handled by global exception handler
- dependencies=[Depends(action_logging())],
+ dependencies=[Depends(requires_access_connection(method="POST")),
Depends(action_logging())],
)
def post_connection(
post_body: ConnectionBody,
@@ -143,7 +146,9 @@ def post_connection(
return connection
-@connections_router.patch("", dependencies=[Depends(action_logging())])
+@connections_router.patch(
+ "", dependencies=[Depends(requires_access_connection(method="PUT")),
Depends(action_logging())]
+)
def bulk_connections(
request: BulkBody[ConnectionBody],
session: SessionDep,
@@ -160,7 +165,7 @@ def bulk_connections(
status.HTTP_404_NOT_FOUND,
]
),
- dependencies=[Depends(action_logging())],
+ dependencies=[Depends(requires_access_connection(method="PUT")),
Depends(action_logging())],
)
def patch_connection(
connection_id: str,
@@ -201,9 +206,7 @@ def patch_connection(
return connection
-@connections_router.post(
- "/test",
-)
+@connections_router.post("/test",
dependencies=[Depends(requires_access_connection(method="POST"))])
def test_connection(
test_body: ConnectionBody,
) -> ConnectionTestResponse:
@@ -237,7 +240,7 @@ def test_connection(
@connections_router.post(
"/defaults",
status_code=status.HTTP_204_NO_CONTENT,
- dependencies=[Depends(action_logging())],
+ dependencies=[Depends(requires_access_connection(method="POST")),
Depends(action_logging())],
)
def create_default_connections(
session: SessionDep,
diff --git a/airflow/api_fastapi/core_api/security.py
b/airflow/api_fastapi/core_api/security.py
index 6a073c1c61f..27a91eb2c49 100644
--- a/airflow/api_fastapi/core_api/security.py
+++ b/airflow/api_fastapi/core_api/security.py
@@ -25,7 +25,12 @@ from jwt import InvalidTokenError
from airflow.api_fastapi.app import get_auth_manager
from airflow.auth.managers.models.base_user import BaseUser
-from airflow.auth.managers.models.resource_details import DagAccessEntity,
DagDetails, PoolDetails
+from airflow.auth.managers.models.resource_details import (
+ ConnectionDetails,
+ DagAccessEntity,
+ DagDetails,
+ PoolDetails,
+)
from airflow.configuration import conf
from airflow.utils.jwt_signer import JWTSigner, get_signing_key
@@ -101,6 +106,25 @@ def requires_access_pool(method: ResourceMethod) ->
Callable:
return inner
+def requires_access_connection(method: ResourceMethod) -> Callable:  # FastAPI dependency factory
+    def inner(
+        request: Request,
+        user: Annotated[BaseUser | None, Depends(get_user)] = None,
+    ) -> None:
+        connection_id = request.path_params.get("connection_id")  # None on collection routes
+
+        def callback():  # authorize against connection (not pool) permissions
+            return get_auth_manager().is_authorized_connection(
+                method=method, details=ConnectionDetails(conn_id=connection_id), user=user
+            )
+
+        _requires_access(
+            is_authorized_callback=callback,
+        )
+
+    return inner
+
+
def _requires_access(
*,
is_authorized_callback: Callable[[], bool],
diff --git a/tests/api_fastapi/core_api/routes/public/test_connections.py
b/tests/api_fastapi/core_api/routes/public/test_connections.py
index a4ecdef94da..1df47350ce9 100644
--- a/tests/api_fastapi/core_api/routes/public/test_connections.py
+++ b/tests/api_fastapi/core_api/routes/public/test_connections.py
@@ -104,6 +104,14 @@ class TestDeleteConnection(TestConnectionEndpoint):
assert len(connection) == 0
_check_last_log(session, dag_id=None, event="delete_connection",
logical_date=None)
+ def test_should_respond_401(self, unauthenticated_test_client):
+ response =
unauthenticated_test_client.delete(f"/public/connections/{TEST_CONN_ID}")
+ assert response.status_code == 401
+
+ def test_should_respond_403(self, unauthorized_test_client):
+ response =
unauthorized_test_client.delete(f"/public/connections/{TEST_CONN_ID}")
+ assert response.status_code == 403
+
def test_delete_should_respond_404(self, test_client):
response = test_client.delete(f"/public/connections/{TEST_CONN_ID}")
assert response.status_code == 404
@@ -120,6 +128,14 @@ class TestGetConnection(TestConnectionEndpoint):
assert body["connection_id"] == TEST_CONN_ID
assert body["conn_type"] == TEST_CONN_TYPE
+ def test_should_respond_401(self, unauthenticated_test_client):
+ response =
unauthenticated_test_client.get(f"/public/connections/{TEST_CONN_ID}")
+ assert response.status_code == 401
+
+ def test_should_respond_403(self, unauthorized_test_client):
+ response =
unauthorized_test_client.get(f"/public/connections/{TEST_CONN_ID}")
+ assert response.status_code == 403
+
def test_get_should_respond_404(self, test_client):
response = test_client.get(f"/public/connections/{TEST_CONN_ID}")
assert response.status_code == 404
@@ -185,6 +201,14 @@ class TestGetConnections(TestConnectionEndpoint):
assert body["total_entries"] == expected_total_entries
assert [connection["connection_id"] for connection in
body["connections"]] == expected_ids
+ def test_should_respond_401(self, unauthenticated_test_client):
+ response = unauthenticated_test_client.get("/public/connections",
params={})
+ assert response.status_code == 401
+
+ def test_should_respond_403(self, unauthorized_test_client):
+ response = unauthorized_test_client.get("/public/connections",
params={})
+ assert response.status_code == 403
+
class TestPostConnection(TestConnectionEndpoint):
@pytest.mark.parametrize(
@@ -213,6 +237,14 @@ class TestPostConnection(TestConnectionEndpoint):
assert len(connection) == 1
_check_last_log(session, dag_id=None, event="post_connection",
logical_date=None)
+ def test_should_respond_401(self, unauthenticated_test_client):
+ response = unauthenticated_test_client.post("/public/connections",
json={})
+ assert response.status_code == 401
+
+ def test_should_respond_403(self, unauthorized_test_client):
+ response = unauthorized_test_client.post("/public/connections",
json={})
+ assert response.status_code == 403
+
@pytest.mark.parametrize(
"body",
[
@@ -344,6 +376,14 @@ class TestPatchConnection(TestConnectionEndpoint):
assert response.status_code == 200
_check_last_log(session, dag_id=None, event="patch_connection",
logical_date=None)
+ def test_should_respond_401(self, unauthenticated_test_client):
+ response =
unauthenticated_test_client.patch(f"/public/connections/{TEST_CONN_ID}",
json={})
+ assert response.status_code == 401
+
+ def test_should_respond_403(self, unauthorized_test_client):
+ response =
unauthorized_test_client.patch(f"/public/connections/{TEST_CONN_ID}", json={})
+ assert response.status_code == 403
+
@pytest.mark.parametrize(
"body, updated_connection, update_mask",
[
@@ -603,6 +643,18 @@ class TestConnection(TestConnectionEndpoint):
"message": "Connection successfully tested",
}
+ def test_should_respond_401(self, unauthenticated_test_client):
+ response = unauthenticated_test_client.post(
+ "/public/connections/test", json={"connection_id": TEST_CONN_ID,
"conn_type": "sqlite"}
+ )
+ assert response.status_code == 401
+
+ def test_should_respond_403(self, unauthorized_test_client):
+ response = unauthorized_test_client.post(
+ "/public/connections/test", json={"connection_id": TEST_CONN_ID,
"conn_type": "sqlite"}
+ )
+ assert response.status_code == 403
+
@mock.patch.dict(os.environ, {"AIRFLOW__CORE__TEST_CONNECTION": "Enabled"})
@pytest.mark.parametrize(
"body",
@@ -638,6 +690,14 @@ class TestCreateDefaultConnections(TestConnectionEndpoint):
assert response.content == b""
_check_last_log(session, dag_id=None,
event="create_default_connections", logical_date=None)
+ def test_should_respond_401(self, unauthenticated_test_client):
+ response =
unauthenticated_test_client.post("/public/connections/defaults")
+ assert response.status_code == 401
+
+ def test_should_respond_403(self, unauthorized_test_client):
+ response =
unauthorized_test_client.post("/public/connections/defaults")
+ assert response.status_code == 403
+
@mock.patch("airflow.api_fastapi.core_api.routes.public.connections.db_create_default_connections")
def test_should_call_db_create_default_connections(self,
mock_db_create_default_connections, test_client):
response = test_client.post("/public/connections/defaults")
@@ -944,3 +1004,11 @@ class TestBulkConnections(TestConnectionEndpoint):
for connection_id, value in expected_results.items():
assert response_data[connection_id] == value
_check_last_log(session, dag_id=None, event="bulk_connections",
logical_date=None)
+
+ def test_should_respond_401(self, unauthenticated_test_client):
+ response = unauthenticated_test_client.patch("/public/connections",
json={})
+ assert response.status_code == 401
+
+ def test_should_respond_403(self, unauthorized_test_client):
+ response = unauthorized_test_client.patch("/public/connections",
json={})
+ assert response.status_code == 403