This is an automated email from the ASF dual-hosted git repository.

jscheffl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 25e173abb1d AIP-81 Implement POST/Insert Multiple Connections in REST API (FastAPI) (#44531)
25e173abb1d is described below

commit 25e173abb1dc2bcf04a56e06a130cfe421b68568
Author: Bugra Ozturk <[email protected]>
AuthorDate: Sun Dec 1 15:19:44 2024 +0100

    AIP-81 Implement POST/Insert Multiple Connections in REST API (FastAPI) (#44531)
    
    * Move validate key to pydantic model, leave unique check to database session for post endpoints, include bulk connection insert endpoint
    
    * Fix test naming
---
 .../api_fastapi/core_api/datamodels/connections.py |   8 +-
 .../api_fastapi/core_api/openapi/v1-generated.yaml |  58 ++++++
 .../core_api/routes/public/connections.py          |  35 ++--
 airflow/ui/openapi-gen/queries/common.ts           |   3 +
 airflow/ui/openapi-gen/queries/queries.ts          |  40 +++++
 airflow/ui/openapi-gen/requests/schemas.gen.ts     |  18 ++
 airflow/ui/openapi-gen/requests/services.gen.ts    |  27 +++
 airflow/ui/openapi-gen/requests/types.gen.ts       |  40 +++++
 .../core_api/routes/public/test_connections.py     | 199 ++++++++++++++++++++-
 9 files changed, 405 insertions(+), 23 deletions(-)

diff --git a/airflow/api_fastapi/core_api/datamodels/connections.py b/airflow/api_fastapi/core_api/datamodels/connections.py
index d74ced1ba4d..98ac5389e5d 100644
--- a/airflow/api_fastapi/core_api/datamodels/connections.py
+++ b/airflow/api_fastapi/core_api/datamodels/connections.py
@@ -79,7 +79,7 @@ class ConnectionTestResponse(BaseModel):
 class ConnectionBody(BaseModel):
     """Connection Serializer for requests body."""
 
-    connection_id: str = Field(serialization_alias="conn_id")
+    connection_id: str = Field(serialization_alias="conn_id", max_length=200, 
pattern=r"^[\w.-]+$")
     conn_type: str
     description: str | None = Field(default=None)
     host: str | None = Field(default=None)
@@ -88,3 +88,9 @@ class ConnectionBody(BaseModel):
     port: int | None = Field(default=None)
     password: str | None = Field(default=None)
     extra: str | None = Field(default=None)
+
+
+class ConnectionBulkBody(BaseModel):
+    """Connections Serializer for requests body."""
+
+    connections: list[ConnectionBody]
diff --git a/airflow/api_fastapi/core_api/openapi/v1-generated.yaml b/airflow/api_fastapi/core_api/openapi/v1-generated.yaml
index 6c34fc4cf41..a331a637c2e 100644
--- a/airflow/api_fastapi/core_api/openapi/v1-generated.yaml
+++ b/airflow/api_fastapi/core_api/openapi/v1-generated.yaml
@@ -1399,6 +1399,50 @@ paths:
             application/json:
               schema:
                 $ref: '#/components/schemas/HTTPValidationError'
+  /public/connections/bulk:
+    post:
+      tags:
+      - Connection
+      summary: Post Connections
+      description: Create connection entry.
+      operationId: post_connections
+      requestBody:
+        content:
+          application/json:
+            schema:
+              $ref: '#/components/schemas/ConnectionBulkBody'
+        required: true
+      responses:
+        '201':
+          description: Successful Response
+          content:
+            application/json:
+              schema:
+                $ref: '#/components/schemas/ConnectionCollectionResponse'
+        '401':
+          description: Unauthorized
+          content:
+            application/json:
+              schema:
+                $ref: '#/components/schemas/HTTPExceptionResponse'
+        '403':
+          description: Forbidden
+          content:
+            application/json:
+              schema:
+                $ref: '#/components/schemas/HTTPExceptionResponse'
+        '409':
+          description: Conflict
+          content:
+            application/json:
+              schema:
+                $ref: '#/components/schemas/HTTPExceptionResponse'
+        '422':
+          description: Validation Error
+          content:
+            application/json:
+              schema:
+                $ref: '#/components/schemas/HTTPValidationError'
   /public/connections/test:
     post:
       tags:
@@ -6207,6 +6251,8 @@ components:
       properties:
         connection_id:
           type: string
+          maxLength: 200
+          pattern: ^[\w.-]+$
           title: Connection Id
         conn_type:
           type: string
@@ -6252,6 +6298,18 @@ components:
       - conn_type
       title: ConnectionBody
       description: Connection Serializer for requests body.
+    ConnectionBulkBody:
+      properties:
+        connections:
+          items:
+            $ref: '#/components/schemas/ConnectionBody'
+          type: array
+          title: Connections
+      type: object
+      required:
+      - connections
+      title: ConnectionBulkBody
+      description: Connections Serializer for requests body.
     ConnectionCollectionResponse:
       properties:
         connections:
diff --git a/airflow/api_fastapi/core_api/routes/public/connections.py b/airflow/api_fastapi/core_api/routes/public/connections.py
index e2236addeda..37e94a98974 100644
--- a/airflow/api_fastapi/core_api/routes/public/connections.py
+++ b/airflow/api_fastapi/core_api/routes/public/connections.py
@@ -17,7 +17,7 @@
 from __future__ import annotations
 
 import os
-from typing import Annotated
+from typing import Annotated, cast
 
 from fastapi import Depends, HTTPException, Query, status
 from sqlalchemy import select
@@ -27,6 +27,7 @@ from airflow.api_fastapi.common.parameters import QueryLimit, QueryOffset, SortP
 from airflow.api_fastapi.common.router import AirflowRouter
 from airflow.api_fastapi.core_api.datamodels.connections import (
     ConnectionBody,
+    ConnectionBulkBody,
     ConnectionCollectionResponse,
     ConnectionResponse,
     ConnectionTestResponse,
@@ -35,7 +36,6 @@ from airflow.api_fastapi.core_api.openapi.exceptions import create_openapi_http_
 from airflow.configuration import conf
 from airflow.models import Connection
 from airflow.secrets.environment_variables import CONN_ENV_PREFIX
-from airflow.utils import helpers
 from airflow.utils.strings import get_random_string
 
 connections_router = AirflowRouter(tags=["Connection"], prefix="/connections")
@@ -126,24 +126,29 @@ def post_connection(
     session: SessionDep,
 ) -> ConnectionResponse:
     """Create connection entry."""
-    try:
-        helpers.validate_key(post_body.connection_id, max_length=200)
-    except Exception as e:
-        raise HTTPException(status.HTTP_400_BAD_REQUEST, f"{e}")
-
-    connection = 
session.scalar(select(Connection).filter_by(conn_id=post_body.connection_id))
-    if connection is not None:
-        raise HTTPException(
-            status.HTTP_409_CONFLICT,
-            f"Connection with connection_id: `{post_body.connection_id}` 
already exists",
-        )
-
     connection = Connection(**post_body.model_dump(by_alias=True))
     session.add(connection)
-
     return connection
 
 
+@connections_router.post(
+    "/bulk",
+    status_code=status.HTTP_201_CREATED,
+    responses=create_openapi_http_exception_doc([status.HTTP_409_CONFLICT]),
+)
+def post_connections(
+    post_body: ConnectionBulkBody,
+    session: SessionDep,
+) -> ConnectionCollectionResponse:
+    """Create connection entry."""
+    connections = [Connection(**body.model_dump(by_alias=True)) for body in 
post_body.connections]
+    session.add_all(connections)
+    return ConnectionCollectionResponse(
+        connections=cast(list[ConnectionResponse], connections),
+        total_entries=len(connections),
+    )
+
+
 @connections_router.patch(
     "/{connection_id}",
     responses=create_openapi_http_exception_doc(
diff --git a/airflow/ui/openapi-gen/queries/common.ts b/airflow/ui/openapi-gen/queries/common.ts
index f12eefb8322..f7e3576019e 100644
--- a/airflow/ui/openapi-gen/queries/common.ts
+++ b/airflow/ui/openapi-gen/queries/common.ts
@@ -1594,6 +1594,9 @@ export type BackfillServiceCreateBackfillMutationResult = Awaited<
 export type ConnectionServicePostConnectionMutationResult = Awaited<
   ReturnType<typeof ConnectionService.postConnection>
 >;
+export type ConnectionServicePostConnectionsMutationResult = Awaited<
+  ReturnType<typeof ConnectionService.postConnections>
+>;
 export type ConnectionServiceTestConnectionMutationResult = Awaited<
   ReturnType<typeof ConnectionService.testConnection>
 >;
diff --git a/airflow/ui/openapi-gen/queries/queries.ts b/airflow/ui/openapi-gen/queries/queries.ts
index 01e04aecdde..3992c3a2edf 100644
--- a/airflow/ui/openapi-gen/queries/queries.ts
+++ b/airflow/ui/openapi-gen/queries/queries.ts
@@ -38,6 +38,7 @@ import {
   BackfillPostBody,
   ClearTaskInstancesBody,
   ConnectionBody,
+  ConnectionBulkBody,
   CreateAssetEventsBody,
   DAGPatchBody,
   DAGRunClearBody,
@@ -2684,6 +2685,45 @@ export const useConnectionServicePostConnection = <
       }) as unknown as Promise<TData>,
     ...options,
   });
+/**
+ * Post Connections
+ * Create connection entry.
+ * @param data The data for the request.
+ * @param data.requestBody
+ * @returns ConnectionCollectionResponse Successful Response
+ * @throws ApiError
+ */
+export const useConnectionServicePostConnections = <
+  TData = Common.ConnectionServicePostConnectionsMutationResult,
+  TError = unknown,
+  TContext = unknown,
+>(
+  options?: Omit<
+    UseMutationOptions<
+      TData,
+      TError,
+      {
+        requestBody: ConnectionBulkBody;
+      },
+      TContext
+    >,
+    "mutationFn"
+  >,
+) =>
+  useMutation<
+    TData,
+    TError,
+    {
+      requestBody: ConnectionBulkBody;
+    },
+    TContext
+  >({
+    mutationFn: ({ requestBody }) =>
+      ConnectionService.postConnections({
+        requestBody,
+      }) as unknown as Promise<TData>,
+    ...options,
+  });
 /**
  * Test Connection
  * Test an API connection.
diff --git a/airflow/ui/openapi-gen/requests/schemas.gen.ts b/airflow/ui/openapi-gen/requests/schemas.gen.ts
index c2cf77baab6..503910d75ad 100644
--- a/airflow/ui/openapi-gen/requests/schemas.gen.ts
+++ b/airflow/ui/openapi-gen/requests/schemas.gen.ts
@@ -750,6 +750,8 @@ export const $ConnectionBody = {
   properties: {
     connection_id: {
       type: "string",
+      maxLength: 200,
+      pattern: "^[\\w.-]+$",
       title: "Connection Id",
     },
     conn_type: {
@@ -840,6 +842,22 @@ export const $ConnectionBody = {
   description: "Connection Serializer for requests body.",
 } as const;
 
+export const $ConnectionBulkBody = {
+  properties: {
+    connections: {
+      items: {
+        $ref: "#/components/schemas/ConnectionBody",
+      },
+      type: "array",
+      title: "Connections",
+    },
+  },
+  type: "object",
+  required: ["connections"],
+  title: "ConnectionBulkBody",
+  description: "Connections Serializer for requests body.",
+} as const;
+
 export const $ConnectionCollectionResponse = {
   properties: {
     connections: {
diff --git a/airflow/ui/openapi-gen/requests/services.gen.ts b/airflow/ui/openapi-gen/requests/services.gen.ts
index ee949d070c8..c5e494fe119 100644
--- a/airflow/ui/openapi-gen/requests/services.gen.ts
+++ b/airflow/ui/openapi-gen/requests/services.gen.ts
@@ -58,6 +58,8 @@ import type {
   GetConnectionsResponse,
   PostConnectionData,
   PostConnectionResponse,
+  PostConnectionsData,
+  PostConnectionsResponse,
   TestConnectionData,
   TestConnectionResponse,
   GetDagRunData,
@@ -997,6 +999,31 @@ export class ConnectionService {
     });
   }
 
+  /**
+   * Post Connections
+   * Create connection entry.
+   * @param data The data for the request.
+   * @param data.requestBody
+   * @returns ConnectionCollectionResponse Successful Response
+   * @throws ApiError
+   */
+  public static postConnections(
+    data: PostConnectionsData,
+  ): CancelablePromise<PostConnectionsResponse> {
+    return __request(OpenAPI, {
+      method: "POST",
+      url: "/public/connections/bulk",
+      body: data.requestBody,
+      mediaType: "application/json",
+      errors: {
+        401: "Unauthorized",
+        403: "Forbidden",
+        409: "Conflict",
+        422: "Validation Error",
+      },
+    });
+  }
+
   /**
    * Test Connection
    * Test an API connection.
diff --git a/airflow/ui/openapi-gen/requests/types.gen.ts b/airflow/ui/openapi-gen/requests/types.gen.ts
index dcb3ec94f95..48861f2bf42 100644
--- a/airflow/ui/openapi-gen/requests/types.gen.ts
+++ b/airflow/ui/openapi-gen/requests/types.gen.ts
@@ -214,6 +214,13 @@ export type ConnectionBody = {
   extra?: string | null;
 };
 
+/**
+ * Connections Serializer for requests body.
+ */
+export type ConnectionBulkBody = {
+  connections: Array<ConnectionBody>;
+};
+
 /**
  * Connection Collection serializer for responses.
  */
@@ -1499,6 +1506,12 @@ export type PostConnectionData = {
 
 export type PostConnectionResponse = ConnectionResponse;
 
+export type PostConnectionsData = {
+  requestBody: ConnectionBulkBody;
+};
+
+export type PostConnectionsResponse = ConnectionCollectionResponse;
+
 export type TestConnectionData = {
   requestBody: ConnectionBody;
 };
@@ -2766,6 +2779,33 @@ export type $OpenApiTs = {
       };
     };
   };
+  "/public/connections/bulk": {
+    post: {
+      req: PostConnectionsData;
+      res: {
+        /**
+         * Successful Response
+         */
+        201: ConnectionCollectionResponse;
+        /**
+         * Unauthorized
+         */
+        401: HTTPExceptionResponse;
+        /**
+         * Forbidden
+         */
+        403: HTTPExceptionResponse;
+        /**
+         * Conflict
+         */
+        409: HTTPExceptionResponse;
+        /**
+         * Validation Error
+         */
+        422: HTTPValidationError;
+      };
+    };
+  };
   "/public/connections/test": {
     post: {
       req: TestConnectionData;
diff --git a/tests/api_fastapi/core_api/routes/public/test_connections.py b/tests/api_fastapi/core_api/routes/public/test_connections.py
index 0410c307fb6..2e090081d2f 100644
--- a/tests/api_fastapi/core_api/routes/public/test_connections.py
+++ b/tests/api_fastapi/core_api/routes/public/test_connections.py
@@ -45,6 +45,10 @@ TEST_CONN_PORT_2 = 8081
 TEST_CONN_LOGIN_2 = "some_login_b"
 
 
+TEST_CONN_ID_3 = "test_connection_id_3"
+TEST_CONN_TYPE_3 = "test_type_3"
+
+
 @provide_session
 def _create_connection(session) -> None:
     connection_model = Connection(
@@ -199,7 +203,7 @@ class TestPostConnection(TestConnectionEndpoint):
             },
         ],
     )
-    def test_post_should_respond_200(self, test_client, session, body):
+    def test_post_should_respond_201(self, test_client, session, body):
         response = test_client.post("/public/connections", json=body)
         assert response.status_code == 201
         connection = session.query(Connection).all()
@@ -214,13 +218,20 @@ class TestPostConnection(TestConnectionEndpoint):
             {"connection_id": "iam_not@#$_connection_id", "conn_type": 
TEST_CONN_TYPE},
         ],
     )
-    def test_post_should_respond_400_for_invalid_conn_id(self, test_client, 
body):
+    def test_post_should_respond_422_for_invalid_conn_id(self, test_client, 
body):
         response = test_client.post("/public/connections", json=body)
-        assert response.status_code == 400
-        connection_id = body["connection_id"]
+        assert response.status_code == 422
+        # This regex is used for validation in ConnectionBody
         assert response.json() == {
-            "detail": f"The key '{connection_id}' has to be made of "
-            "alphanumeric characters, dashes, dots and underscores 
exclusively",
+            "detail": [
+                {
+                    "ctx": {"pattern": r"^[\w.-]+$"},
+                    "input": f"{body['connection_id']}",
+                    "loc": ["body", "connection_id"],
+                    "msg": "String should match pattern '^[\\w.-]+$'",
+                    "type": "string_pattern_mismatch",
+                }
+            ]
         }
 
     @pytest.mark.parametrize(
@@ -236,7 +247,7 @@ class TestPostConnection(TestConnectionEndpoint):
         response = test_client.post("/public/connections", json=body)
         assert response.status_code == 409
         assert response.json() == {
-            "detail": f"Connection with connection_id: `{TEST_CONN_ID}` 
already exists",
+            "detail": "Unique constraint violation",
         }
 
     @pytest.mark.enable_redact
@@ -298,6 +309,180 @@ class TestPostConnection(TestConnectionEndpoint):
         assert response.json() == expected_response
 
 
+class TestPostConnections(TestConnectionEndpoint):
+    @pytest.mark.parametrize(
+        "body",
+        [
+            {
+                "connections": [
+                    {"connection_id": TEST_CONN_ID, "conn_type": 
TEST_CONN_TYPE},
+                    {"connection_id": TEST_CONN_ID_2, "conn_type": 
TEST_CONN_TYPE_2, "extra": None},
+                ]
+            },
+            {
+                "connections": [
+                    {"connection_id": TEST_CONN_ID, "conn_type": 
TEST_CONN_TYPE, "extra": "{}"},
+                    {
+                        "connection_id": TEST_CONN_ID_2,
+                        "conn_type": TEST_CONN_TYPE_2,
+                        "extra": '{"key": "value"}',
+                    },
+                    {
+                        "connection_id": TEST_CONN_ID_3,
+                        "conn_type": TEST_CONN_ID_3,
+                        "description": "test_description",
+                        "host": "test_host",
+                        "login": "test_login",
+                        "schema": "test_schema",
+                        "port": 8080,
+                        "extra": '{"key": "value"}',
+                    },
+                ]
+            },
+        ],
+    )
+    def test_post_should_respond_201(self, test_client, session, body):
+        response = test_client.post("/public/connections/bulk", json=body)
+        assert response.status_code == 201
+        connection = session.query(Connection).all()
+        assert len(connection) == len(body["connections"])
+
+    @pytest.mark.parametrize(
+        "body",
+        [
+            {
+                "connections": [
+                    {"connection_id": "****", "conn_type": TEST_CONN_TYPE},
+                    {"connection_id": "test()", "conn_type": TEST_CONN_TYPE},
+                ]
+            },
+            {
+                "connections": [
+                    {"connection_id": "this_^$#is_invalid", "conn_type": 
TEST_CONN_TYPE},
+                    {"connection_id": "iam_not@#$_connection_id", "conn_type": 
TEST_CONN_TYPE},
+                ]
+            },
+        ],
+    )
+    def test_post_should_respond_422_for_invalid_conn_id(self, test_client, 
body):
+        response = test_client.post("/public/connections/bulk", json=body)
+        assert response.status_code == 422
+        expected_response_detail = [
+            {
+                "ctx": {"pattern": r"^[\w.-]+$"},
+                "input": f"{body['connections'][conn_index]['connection_id']}",
+                "loc": ["body", "connections", conn_index, "connection_id"],
+                "msg": "String should match pattern '^[\\w.-]+$'",
+                "type": "string_pattern_mismatch",
+            }
+            for conn_index in range(len(body["connections"]))
+        ]
+        assert response.json() == {"detail": expected_response_detail}
+
+    @pytest.mark.parametrize(
+        "body",
+        [
+            {
+                "connections": [
+                    {"connection_id": TEST_CONN_ID, "conn_type": 
TEST_CONN_TYPE},
+                    {"connection_id": TEST_CONN_ID_2, "conn_type": 
TEST_CONN_TYPE_2, "extra": None},
+                ]
+            },
+        ],
+    )
+    def test_post_should_respond_already_exist(self, test_client, body):
+        response = test_client.post("/public/connections/bulk", json=body)
+        assert response.status_code == 201
+        # Another request
+        response = test_client.post("/public/connections/bulk", json=body)
+        assert response.status_code == 409
+        assert response.json() == {
+            "detail": "Unique constraint violation",
+        }
+
+    @pytest.mark.enable_redact
+    @pytest.mark.parametrize(
+        "body, expected_response",
+        [
+            (
+                {
+                    "connections": [
+                        {
+                            "connection_id": TEST_CONN_ID,
+                            "conn_type": TEST_CONN_TYPE,
+                            "password": "test-password",
+                        },
+                        {
+                            "connection_id": TEST_CONN_ID_2,
+                            "conn_type": TEST_CONN_TYPE_2,
+                            "password": "?>@#+!_%()#",
+                        },
+                    ]
+                },
+                {
+                    "connections": [
+                        {
+                            "connection_id": TEST_CONN_ID,
+                            "conn_type": TEST_CONN_TYPE,
+                            "description": None,
+                            "extra": None,
+                            "host": None,
+                            "login": None,
+                            "password": "***",
+                            "port": None,
+                            "schema": None,
+                        },
+                        {
+                            "connection_id": TEST_CONN_ID_2,
+                            "conn_type": TEST_CONN_TYPE_2,
+                            "description": None,
+                            "extra": None,
+                            "host": None,
+                            "login": None,
+                            "password": "***",
+                            "port": None,
+                            "schema": None,
+                        },
+                    ],
+                    "total_entries": 2,
+                },
+            ),
+            (
+                {
+                    "connections": [
+                        {
+                            "connection_id": TEST_CONN_ID,
+                            "conn_type": TEST_CONN_TYPE,
+                            "password": "A!rF|0wi$aw3s0m3",
+                            "extra": '{"password": "test-password"}',
+                        }
+                    ]
+                },
+                {
+                    "connections": [
+                        {
+                            "connection_id": TEST_CONN_ID,
+                            "conn_type": TEST_CONN_TYPE,
+                            "description": None,
+                            "extra": '{"password": "***"}',
+                            "host": None,
+                            "login": None,
+                            "password": "***",
+                            "port": None,
+                            "schema": None,
+                        },
+                    ],
+                    "total_entries": 1,
+                },
+            ),
+        ],
+    )
+    def test_post_should_response_201_redacted_password(self, test_client, 
body, expected_response):
+        response = test_client.post("/public/connections/bulk", json=body)
+        assert response.status_code == 201
+        assert response.json() == expected_response
+
+
 class TestPatchConnection(TestConnectionEndpoint):
     @pytest.mark.parametrize(
         "body",

Reply via email to