This is an automated email from the ASF dual-hosted git repository.
jscheffl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 25e173abb1d AIP-81 Implement POST/Insert Multiple Connections in REST
API (FastAPI) (#44531)
25e173abb1d is described below
commit 25e173abb1dc2bcf04a56e06a130cfe421b68568
Author: Bugra Ozturk <[email protected]>
AuthorDate: Sun Dec 1 15:19:44 2024 +0100
AIP-81 Implement POST/Insert Multiple Connections in REST API (FastAPI)
(#44531)
* Move connection ID validation (previously helpers.validate_key) into the
pydantic model, leave the uniqueness check to the database session for POST
endpoints, and add a bulk connection insert endpoint
* Fix test naming
---
.../api_fastapi/core_api/datamodels/connections.py | 8 +-
.../api_fastapi/core_api/openapi/v1-generated.yaml | 58 ++++++
.../core_api/routes/public/connections.py | 35 ++--
airflow/ui/openapi-gen/queries/common.ts | 3 +
airflow/ui/openapi-gen/queries/queries.ts | 40 +++++
airflow/ui/openapi-gen/requests/schemas.gen.ts | 18 ++
airflow/ui/openapi-gen/requests/services.gen.ts | 27 +++
airflow/ui/openapi-gen/requests/types.gen.ts | 40 +++++
.../core_api/routes/public/test_connections.py | 199 ++++++++++++++++++++-
9 files changed, 405 insertions(+), 23 deletions(-)
diff --git a/airflow/api_fastapi/core_api/datamodels/connections.py
b/airflow/api_fastapi/core_api/datamodels/connections.py
index d74ced1ba4d..98ac5389e5d 100644
--- a/airflow/api_fastapi/core_api/datamodels/connections.py
+++ b/airflow/api_fastapi/core_api/datamodels/connections.py
@@ -79,7 +79,7 @@ class ConnectionTestResponse(BaseModel):
class ConnectionBody(BaseModel):
"""Connection Serializer for requests body."""
- connection_id: str = Field(serialization_alias="conn_id")
+ connection_id: str = Field(serialization_alias="conn_id", max_length=200,
pattern=r"^[\w.-]+$")
conn_type: str
description: str | None = Field(default=None)
host: str | None = Field(default=None)
@@ -88,3 +88,9 @@ class ConnectionBody(BaseModel):
port: int | None = Field(default=None)
password: str | None = Field(default=None)
extra: str | None = Field(default=None)
+
+
+class ConnectionBulkBody(BaseModel):
+ """Connections Serializer for requests body."""
+
+ connections: list[ConnectionBody]
diff --git a/airflow/api_fastapi/core_api/openapi/v1-generated.yaml
b/airflow/api_fastapi/core_api/openapi/v1-generated.yaml
index 6c34fc4cf41..a331a637c2e 100644
--- a/airflow/api_fastapi/core_api/openapi/v1-generated.yaml
+++ b/airflow/api_fastapi/core_api/openapi/v1-generated.yaml
@@ -1399,6 +1399,50 @@ paths:
application/json:
schema:
$ref: '#/components/schemas/HTTPValidationError'
+ /public/connections/bulk:
+ post:
+ tags:
+ - Connection
+ summary: Post Connections
+ description: Create connection entry.
+ operationId: post_connections
+ requestBody:
+ content:
+ application/json:
+ schema:
+ $ref: '#/components/schemas/ConnectionBulkBody'
+ required: true
+ responses:
+ '201':
+ description: Successful Response
+ content:
+ application/json:
+ schema:
+ $ref: '#/components/schemas/ConnectionCollectionResponse'
+ '401':
+ description: Unauthorized
+ content:
+ application/json:
+ schema:
+ $ref: '#/components/schemas/HTTPExceptionResponse'
+ '403':
+ description: Forbidden
+ content:
+ application/json:
+ schema:
+ $ref: '#/components/schemas/HTTPExceptionResponse'
+ '409':
+ description: Conflict
+ content:
+ application/json:
+ schema:
+ $ref: '#/components/schemas/HTTPExceptionResponse'
+ '422':
+ description: Validation Error
+ content:
+ application/json:
+ schema:
+ $ref: '#/components/schemas/HTTPValidationError'
/public/connections/test:
post:
tags:
@@ -6207,6 +6251,8 @@ components:
properties:
connection_id:
type: string
+ maxLength: 200
+ pattern: ^[\w.-]+$
title: Connection Id
conn_type:
type: string
@@ -6252,6 +6298,18 @@ components:
- conn_type
title: ConnectionBody
description: Connection Serializer for requests body.
+ ConnectionBulkBody:
+ properties:
+ connections:
+ items:
+ $ref: '#/components/schemas/ConnectionBody'
+ type: array
+ title: Connections
+ type: object
+ required:
+ - connections
+ title: ConnectionBulkBody
+ description: Connections Serializer for requests body.
ConnectionCollectionResponse:
properties:
connections:
diff --git a/airflow/api_fastapi/core_api/routes/public/connections.py
b/airflow/api_fastapi/core_api/routes/public/connections.py
index e2236addeda..37e94a98974 100644
--- a/airflow/api_fastapi/core_api/routes/public/connections.py
+++ b/airflow/api_fastapi/core_api/routes/public/connections.py
@@ -17,7 +17,7 @@
from __future__ import annotations
import os
-from typing import Annotated
+from typing import Annotated, cast
from fastapi import Depends, HTTPException, Query, status
from sqlalchemy import select
@@ -27,6 +27,7 @@ from airflow.api_fastapi.common.parameters import QueryLimit,
QueryOffset, SortP
from airflow.api_fastapi.common.router import AirflowRouter
from airflow.api_fastapi.core_api.datamodels.connections import (
ConnectionBody,
+ ConnectionBulkBody,
ConnectionCollectionResponse,
ConnectionResponse,
ConnectionTestResponse,
@@ -35,7 +36,6 @@ from airflow.api_fastapi.core_api.openapi.exceptions import
create_openapi_http_
from airflow.configuration import conf
from airflow.models import Connection
from airflow.secrets.environment_variables import CONN_ENV_PREFIX
-from airflow.utils import helpers
from airflow.utils.strings import get_random_string
connections_router = AirflowRouter(tags=["Connection"], prefix="/connections")
@@ -126,24 +126,29 @@ def post_connection(
session: SessionDep,
) -> ConnectionResponse:
"""Create connection entry."""
- try:
- helpers.validate_key(post_body.connection_id, max_length=200)
- except Exception as e:
- raise HTTPException(status.HTTP_400_BAD_REQUEST, f"{e}")
-
- connection =
session.scalar(select(Connection).filter_by(conn_id=post_body.connection_id))
- if connection is not None:
- raise HTTPException(
- status.HTTP_409_CONFLICT,
- f"Connection with connection_id: `{post_body.connection_id}`
already exists",
- )
-
connection = Connection(**post_body.model_dump(by_alias=True))
session.add(connection)
-
return connection
+@connections_router.post(
+ "/bulk",
+ status_code=status.HTTP_201_CREATED,
+ responses=create_openapi_http_exception_doc([status.HTTP_409_CONFLICT]),
+)
+def post_connections(
+ post_body: ConnectionBulkBody,
+ session: SessionDep,
+) -> ConnectionCollectionResponse:
+ """Create connection entry."""
+ connections = [Connection(**body.model_dump(by_alias=True)) for body in
post_body.connections]
+ session.add_all(connections)
+ return ConnectionCollectionResponse(
+ connections=cast(list[ConnectionResponse], connections),
+ total_entries=len(connections),
+ )
+
+
@connections_router.patch(
"/{connection_id}",
responses=create_openapi_http_exception_doc(
diff --git a/airflow/ui/openapi-gen/queries/common.ts
b/airflow/ui/openapi-gen/queries/common.ts
index f12eefb8322..f7e3576019e 100644
--- a/airflow/ui/openapi-gen/queries/common.ts
+++ b/airflow/ui/openapi-gen/queries/common.ts
@@ -1594,6 +1594,9 @@ export type BackfillServiceCreateBackfillMutationResult =
Awaited<
export type ConnectionServicePostConnectionMutationResult = Awaited<
ReturnType<typeof ConnectionService.postConnection>
>;
+export type ConnectionServicePostConnectionsMutationResult = Awaited<
+ ReturnType<typeof ConnectionService.postConnections>
+>;
export type ConnectionServiceTestConnectionMutationResult = Awaited<
ReturnType<typeof ConnectionService.testConnection>
>;
diff --git a/airflow/ui/openapi-gen/queries/queries.ts
b/airflow/ui/openapi-gen/queries/queries.ts
index 01e04aecdde..3992c3a2edf 100644
--- a/airflow/ui/openapi-gen/queries/queries.ts
+++ b/airflow/ui/openapi-gen/queries/queries.ts
@@ -38,6 +38,7 @@ import {
BackfillPostBody,
ClearTaskInstancesBody,
ConnectionBody,
+ ConnectionBulkBody,
CreateAssetEventsBody,
DAGPatchBody,
DAGRunClearBody,
@@ -2684,6 +2685,45 @@ export const useConnectionServicePostConnection = <
}) as unknown as Promise<TData>,
...options,
});
+/**
+ * Post Connections
+ * Create connection entry.
+ * @param data The data for the request.
+ * @param data.requestBody
+ * @returns ConnectionCollectionResponse Successful Response
+ * @throws ApiError
+ */
+export const useConnectionServicePostConnections = <
+ TData = Common.ConnectionServicePostConnectionsMutationResult,
+ TError = unknown,
+ TContext = unknown,
+>(
+ options?: Omit<
+ UseMutationOptions<
+ TData,
+ TError,
+ {
+ requestBody: ConnectionBulkBody;
+ },
+ TContext
+ >,
+ "mutationFn"
+ >,
+) =>
+ useMutation<
+ TData,
+ TError,
+ {
+ requestBody: ConnectionBulkBody;
+ },
+ TContext
+ >({
+ mutationFn: ({ requestBody }) =>
+ ConnectionService.postConnections({
+ requestBody,
+ }) as unknown as Promise<TData>,
+ ...options,
+ });
/**
* Test Connection
* Test an API connection.
diff --git a/airflow/ui/openapi-gen/requests/schemas.gen.ts
b/airflow/ui/openapi-gen/requests/schemas.gen.ts
index c2cf77baab6..503910d75ad 100644
--- a/airflow/ui/openapi-gen/requests/schemas.gen.ts
+++ b/airflow/ui/openapi-gen/requests/schemas.gen.ts
@@ -750,6 +750,8 @@ export const $ConnectionBody = {
properties: {
connection_id: {
type: "string",
+ maxLength: 200,
+ pattern: "^[\\w.-]+$",
title: "Connection Id",
},
conn_type: {
@@ -840,6 +842,22 @@ export const $ConnectionBody = {
description: "Connection Serializer for requests body.",
} as const;
+export const $ConnectionBulkBody = {
+ properties: {
+ connections: {
+ items: {
+ $ref: "#/components/schemas/ConnectionBody",
+ },
+ type: "array",
+ title: "Connections",
+ },
+ },
+ type: "object",
+ required: ["connections"],
+ title: "ConnectionBulkBody",
+ description: "Connections Serializer for requests body.",
+} as const;
+
export const $ConnectionCollectionResponse = {
properties: {
connections: {
diff --git a/airflow/ui/openapi-gen/requests/services.gen.ts
b/airflow/ui/openapi-gen/requests/services.gen.ts
index ee949d070c8..c5e494fe119 100644
--- a/airflow/ui/openapi-gen/requests/services.gen.ts
+++ b/airflow/ui/openapi-gen/requests/services.gen.ts
@@ -58,6 +58,8 @@ import type {
GetConnectionsResponse,
PostConnectionData,
PostConnectionResponse,
+ PostConnectionsData,
+ PostConnectionsResponse,
TestConnectionData,
TestConnectionResponse,
GetDagRunData,
@@ -997,6 +999,31 @@ export class ConnectionService {
});
}
+ /**
+ * Post Connections
+ * Create connection entry.
+ * @param data The data for the request.
+ * @param data.requestBody
+ * @returns ConnectionCollectionResponse Successful Response
+ * @throws ApiError
+ */
+ public static postConnections(
+ data: PostConnectionsData,
+ ): CancelablePromise<PostConnectionsResponse> {
+ return __request(OpenAPI, {
+ method: "POST",
+ url: "/public/connections/bulk",
+ body: data.requestBody,
+ mediaType: "application/json",
+ errors: {
+ 401: "Unauthorized",
+ 403: "Forbidden",
+ 409: "Conflict",
+ 422: "Validation Error",
+ },
+ });
+ }
+
/**
* Test Connection
* Test an API connection.
diff --git a/airflow/ui/openapi-gen/requests/types.gen.ts
b/airflow/ui/openapi-gen/requests/types.gen.ts
index dcb3ec94f95..48861f2bf42 100644
--- a/airflow/ui/openapi-gen/requests/types.gen.ts
+++ b/airflow/ui/openapi-gen/requests/types.gen.ts
@@ -214,6 +214,13 @@ export type ConnectionBody = {
extra?: string | null;
};
+/**
+ * Connections Serializer for requests body.
+ */
+export type ConnectionBulkBody = {
+ connections: Array<ConnectionBody>;
+};
+
/**
* Connection Collection serializer for responses.
*/
@@ -1499,6 +1506,12 @@ export type PostConnectionData = {
export type PostConnectionResponse = ConnectionResponse;
+export type PostConnectionsData = {
+ requestBody: ConnectionBulkBody;
+};
+
+export type PostConnectionsResponse = ConnectionCollectionResponse;
+
export type TestConnectionData = {
requestBody: ConnectionBody;
};
@@ -2766,6 +2779,33 @@ export type $OpenApiTs = {
};
};
};
+ "/public/connections/bulk": {
+ post: {
+ req: PostConnectionsData;
+ res: {
+ /**
+ * Successful Response
+ */
+ 201: ConnectionCollectionResponse;
+ /**
+ * Unauthorized
+ */
+ 401: HTTPExceptionResponse;
+ /**
+ * Forbidden
+ */
+ 403: HTTPExceptionResponse;
+ /**
+ * Conflict
+ */
+ 409: HTTPExceptionResponse;
+ /**
+ * Validation Error
+ */
+ 422: HTTPValidationError;
+ };
+ };
+ };
"/public/connections/test": {
post: {
req: TestConnectionData;
diff --git a/tests/api_fastapi/core_api/routes/public/test_connections.py
b/tests/api_fastapi/core_api/routes/public/test_connections.py
index 0410c307fb6..2e090081d2f 100644
--- a/tests/api_fastapi/core_api/routes/public/test_connections.py
+++ b/tests/api_fastapi/core_api/routes/public/test_connections.py
@@ -45,6 +45,10 @@ TEST_CONN_PORT_2 = 8081
TEST_CONN_LOGIN_2 = "some_login_b"
+TEST_CONN_ID_3 = "test_connection_id_3"
+TEST_CONN_TYPE_3 = "test_type_3"
+
+
@provide_session
def _create_connection(session) -> None:
connection_model = Connection(
@@ -199,7 +203,7 @@ class TestPostConnection(TestConnectionEndpoint):
},
],
)
- def test_post_should_respond_200(self, test_client, session, body):
+ def test_post_should_respond_201(self, test_client, session, body):
response = test_client.post("/public/connections", json=body)
assert response.status_code == 201
connection = session.query(Connection).all()
@@ -214,13 +218,20 @@ class TestPostConnection(TestConnectionEndpoint):
{"connection_id": "iam_not@#$_connection_id", "conn_type":
TEST_CONN_TYPE},
],
)
- def test_post_should_respond_400_for_invalid_conn_id(self, test_client,
body):
+ def test_post_should_respond_422_for_invalid_conn_id(self, test_client,
body):
response = test_client.post("/public/connections", json=body)
- assert response.status_code == 400
- connection_id = body["connection_id"]
+ assert response.status_code == 422
+ # This regex is used for validation in ConnectionBody
assert response.json() == {
- "detail": f"The key '{connection_id}' has to be made of "
- "alphanumeric characters, dashes, dots and underscores
exclusively",
+ "detail": [
+ {
+ "ctx": {"pattern": r"^[\w.-]+$"},
+ "input": f"{body['connection_id']}",
+ "loc": ["body", "connection_id"],
+ "msg": "String should match pattern '^[\\w.-]+$'",
+ "type": "string_pattern_mismatch",
+ }
+ ]
}
@pytest.mark.parametrize(
@@ -236,7 +247,7 @@ class TestPostConnection(TestConnectionEndpoint):
response = test_client.post("/public/connections", json=body)
assert response.status_code == 409
assert response.json() == {
- "detail": f"Connection with connection_id: `{TEST_CONN_ID}`
already exists",
+ "detail": "Unique constraint violation",
}
@pytest.mark.enable_redact
@@ -298,6 +309,180 @@ class TestPostConnection(TestConnectionEndpoint):
assert response.json() == expected_response
+class TestPostConnections(TestConnectionEndpoint):
+ @pytest.mark.parametrize(
+ "body",
+ [
+ {
+ "connections": [
+ {"connection_id": TEST_CONN_ID, "conn_type":
TEST_CONN_TYPE},
+ {"connection_id": TEST_CONN_ID_2, "conn_type":
TEST_CONN_TYPE_2, "extra": None},
+ ]
+ },
+ {
+ "connections": [
+ {"connection_id": TEST_CONN_ID, "conn_type":
TEST_CONN_TYPE, "extra": "{}"},
+ {
+ "connection_id": TEST_CONN_ID_2,
+ "conn_type": TEST_CONN_TYPE_2,
+ "extra": '{"key": "value"}',
+ },
+ {
+ "connection_id": TEST_CONN_ID_3,
+ "conn_type": TEST_CONN_ID_3,
+ "description": "test_description",
+ "host": "test_host",
+ "login": "test_login",
+ "schema": "test_schema",
+ "port": 8080,
+ "extra": '{"key": "value"}',
+ },
+ ]
+ },
+ ],
+ )
+ def test_post_should_respond_201(self, test_client, session, body):
+ response = test_client.post("/public/connections/bulk", json=body)
+ assert response.status_code == 201
+ connection = session.query(Connection).all()
+ assert len(connection) == len(body["connections"])
+
+ @pytest.mark.parametrize(
+ "body",
+ [
+ {
+ "connections": [
+ {"connection_id": "****", "conn_type": TEST_CONN_TYPE},
+ {"connection_id": "test()", "conn_type": TEST_CONN_TYPE},
+ ]
+ },
+ {
+ "connections": [
+ {"connection_id": "this_^$#is_invalid", "conn_type":
TEST_CONN_TYPE},
+ {"connection_id": "iam_not@#$_connection_id", "conn_type":
TEST_CONN_TYPE},
+ ]
+ },
+ ],
+ )
+ def test_post_should_respond_422_for_invalid_conn_id(self, test_client,
body):
+ response = test_client.post("/public/connections/bulk", json=body)
+ assert response.status_code == 422
+ expected_response_detail = [
+ {
+ "ctx": {"pattern": r"^[\w.-]+$"},
+ "input": f"{body['connections'][conn_index]['connection_id']}",
+ "loc": ["body", "connections", conn_index, "connection_id"],
+ "msg": "String should match pattern '^[\\w.-]+$'",
+ "type": "string_pattern_mismatch",
+ }
+ for conn_index in range(len(body["connections"]))
+ ]
+ assert response.json() == {"detail": expected_response_detail}
+
+ @pytest.mark.parametrize(
+ "body",
+ [
+ {
+ "connections": [
+ {"connection_id": TEST_CONN_ID, "conn_type":
TEST_CONN_TYPE},
+ {"connection_id": TEST_CONN_ID_2, "conn_type":
TEST_CONN_TYPE_2, "extra": None},
+ ]
+ },
+ ],
+ )
+ def test_post_should_respond_already_exist(self, test_client, body):
+ response = test_client.post("/public/connections/bulk", json=body)
+ assert response.status_code == 201
+ # Another request
+ response = test_client.post("/public/connections/bulk", json=body)
+ assert response.status_code == 409
+ assert response.json() == {
+ "detail": "Unique constraint violation",
+ }
+
+ @pytest.mark.enable_redact
+ @pytest.mark.parametrize(
+ "body, expected_response",
+ [
+ (
+ {
+ "connections": [
+ {
+ "connection_id": TEST_CONN_ID,
+ "conn_type": TEST_CONN_TYPE,
+ "password": "test-password",
+ },
+ {
+ "connection_id": TEST_CONN_ID_2,
+ "conn_type": TEST_CONN_TYPE_2,
+ "password": "?>@#+!_%()#",
+ },
+ ]
+ },
+ {
+ "connections": [
+ {
+ "connection_id": TEST_CONN_ID,
+ "conn_type": TEST_CONN_TYPE,
+ "description": None,
+ "extra": None,
+ "host": None,
+ "login": None,
+ "password": "***",
+ "port": None,
+ "schema": None,
+ },
+ {
+ "connection_id": TEST_CONN_ID_2,
+ "conn_type": TEST_CONN_TYPE_2,
+ "description": None,
+ "extra": None,
+ "host": None,
+ "login": None,
+ "password": "***",
+ "port": None,
+ "schema": None,
+ },
+ ],
+ "total_entries": 2,
+ },
+ ),
+ (
+ {
+ "connections": [
+ {
+ "connection_id": TEST_CONN_ID,
+ "conn_type": TEST_CONN_TYPE,
+ "password": "A!rF|0wi$aw3s0m3",
+ "extra": '{"password": "test-password"}',
+ }
+ ]
+ },
+ {
+ "connections": [
+ {
+ "connection_id": TEST_CONN_ID,
+ "conn_type": TEST_CONN_TYPE,
+ "description": None,
+ "extra": '{"password": "***"}',
+ "host": None,
+ "login": None,
+ "password": "***",
+ "port": None,
+ "schema": None,
+ },
+ ],
+ "total_entries": 1,
+ },
+ ),
+ ],
+ )
+ def test_post_should_response_201_redacted_password(self, test_client,
body, expected_response):
+ response = test_client.post("/public/connections/bulk", json=body)
+ assert response.status_code == 201
+ assert response.json() == expected_response
+
+
class TestPatchConnection(TestConnectionEndpoint):
@pytest.mark.parametrize(
"body",