This is an automated email from the ASF dual-hosted git repository.

jscheffl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new c81610b7425 AIP-81 Add Overwrite for Bulk Insert Connection API (#45396)
c81610b7425 is described below

commit c81610b7425ee65b7118e2256d5a9efe087dd080
Author: LIU ZHE YOU <[email protected]>
AuthorDate: Sun Jan 5 01:05:34 2025 +0800

    AIP-81 Add Overwrite for Bulk Insert Connection API (#45396)
---
 .../api_fastapi/core_api/datamodels/connections.py |   1 +
 .../api_fastapi/core_api/openapi/v1-generated.yaml |  22 +-
 .../core_api/routes/public/connections.py          |  44 +++-
 airflow/ui/openapi-gen/queries/common.ts           |   6 +-
 airflow/ui/openapi-gen/queries/queries.ts          |  75 +++---
 airflow/ui/openapi-gen/requests/schemas.gen.ts     |  12 +
 airflow/ui/openapi-gen/requests/services.gen.ts    |  13 +-
 airflow/ui/openapi-gen/requests/types.gen.ts       |  15 +-
 .../core_api/routes/public/test_connections.py     | 269 ++++++++++++++++++++-
 9 files changed, 384 insertions(+), 73 deletions(-)

diff --git a/airflow/api_fastapi/core_api/datamodels/connections.py b/airflow/api_fastapi/core_api/datamodels/connections.py
index 98ac5389e5d..04f47841e60 100644
--- a/airflow/api_fastapi/core_api/datamodels/connections.py
+++ b/airflow/api_fastapi/core_api/datamodels/connections.py
@@ -94,3 +94,4 @@ class ConnectionBulkBody(BaseModel):
     """Connections Serializer for requests body."""
 
     connections: list[ConnectionBody]
+    overwrite: bool | None = Field(default=False)
diff --git a/airflow/api_fastapi/core_api/openapi/v1-generated.yaml b/airflow/api_fastapi/core_api/openapi/v1-generated.yaml
index a1177ec2409..ceda67d98c1 100644
--- a/airflow/api_fastapi/core_api/openapi/v1-generated.yaml
+++ b/airflow/api_fastapi/core_api/openapi/v1-generated.yaml
@@ -1738,12 +1738,12 @@ paths:
               schema:
                 $ref: '#/components/schemas/HTTPValidationError'
   /public/connections/bulk:
-    post:
+    put:
       tags:
       - Connection
-      summary: Post Connections
+      summary: Put Connections
       description: Create connection entry.
-      operationId: post_connections
+      operationId: put_connections
       requestBody:
         content:
           application/json:
@@ -1751,8 +1751,8 @@ paths:
               $ref: '#/components/schemas/ConnectionBulkBody'
         required: true
       responses:
-        '201':
-          description: Successful Response
+        '200':
+          description: Created with overwrite
           content:
             application/json:
               schema:
@@ -1775,6 +1775,12 @@ paths:
             application/json:
               schema:
                 $ref: '#/components/schemas/HTTPExceptionResponse'
+        '201':
+          description: Created
+          content:
+            application/json:
+              schema:
+                $ref: '#/components/schemas/ConnectionCollectionResponse'
         '422':
           description: Validation Error
           content:
@@ -6761,6 +6767,12 @@ components:
             $ref: '#/components/schemas/ConnectionBody'
           type: array
           title: Connections
+        overwrite:
+          anyOf:
+          - type: boolean
+          - type: 'null'
+          title: Overwrite
+          default: false
       type: object
       required:
       - connections
diff --git a/airflow/api_fastapi/core_api/routes/public/connections.py b/airflow/api_fastapi/core_api/routes/public/connections.py
index 61fc76832c6..081fe7b0dd5 100644
--- a/airflow/api_fastapi/core_api/routes/public/connections.py
+++ b/airflow/api_fastapi/core_api/routes/public/connections.py
@@ -19,7 +19,7 @@ from __future__ import annotations
 import os
 from typing import Annotated, cast
 
-from fastapi import Depends, HTTPException, Query, status
+from fastapi import Depends, HTTPException, Query, Response, status
 from fastapi.exceptions import RequestValidationError
 from pydantic import ValidationError
 from sqlalchemy import select
@@ -135,18 +135,48 @@ def post_connection(
     return connection
 
 
-@connections_router.post(
+@connections_router.put(
     "/bulk",
-    status_code=status.HTTP_201_CREATED,
-    responses=create_openapi_http_exception_doc([status.HTTP_409_CONFLICT]),
+    responses={
+        **create_openapi_http_exception_doc([status.HTTP_409_CONFLICT]),
+        status.HTTP_201_CREATED: {
+            "description": "Created",
+            "model": ConnectionCollectionResponse,
+        },
+        status.HTTP_200_OK: {
+            "description": "Created with overwrite",
+            "model": ConnectionCollectionResponse,
+        },
+    },
 )
-def post_connections(
+def put_connections(
+    response: Response,
     post_body: ConnectionBulkBody,
     session: SessionDep,
 ) -> ConnectionCollectionResponse:
     """Create connection entry."""
-    connections = [Connection(**body.model_dump(by_alias=True)) for body in 
post_body.connections]
-    session.add_all(connections)
+    response.status_code = status.HTTP_201_CREATED if not post_body.overwrite 
else status.HTTP_200_OK
+    connections: list[Connection]
+    if not post_body.overwrite:
+        connections = [Connection(**body.model_dump(by_alias=True)) for body 
in post_body.connections]
+        session.add_all(connections)
+    else:
+        connection_ids = [conn.connection_id for conn in post_body.connections]
+        existed_connections = session.execute(
+            select(Connection).filter(Connection.conn_id.in_(connection_ids))
+        ).scalars()
+        existed_connections_dict = {conn.conn_id: conn for conn in 
existed_connections}
+        connections = []
+        # if conn_id exists, update the corresponding connection, else add a 
new connection
+        for body in post_body.connections:
+            if body.connection_id in existed_connections_dict:
+                connection = existed_connections_dict[body.connection_id]
+                for key, val in body.model_dump(by_alias=True).items():
+                    setattr(connection, key, val)
+                connections.append(connection)
+            else:
+                
connections.append(Connection(**body.model_dump(by_alias=True)))
+        session.add_all(connections)
     return ConnectionCollectionResponse(
         connections=cast(list[ConnectionResponse], connections),
         total_entries=len(connections),
diff --git a/airflow/ui/openapi-gen/queries/common.ts b/airflow/ui/openapi-gen/queries/common.ts
index 06633922578..82afa776c7a 100644
--- a/airflow/ui/openapi-gen/queries/common.ts
+++ b/airflow/ui/openapi-gen/queries/common.ts
@@ -1599,9 +1599,6 @@ export type BackfillServiceCreateBackfillMutationResult = Awaited<
 export type ConnectionServicePostConnectionMutationResult = Awaited<
   ReturnType<typeof ConnectionService.postConnection>
 >;
-export type ConnectionServicePostConnectionsMutationResult = Awaited<
-  ReturnType<typeof ConnectionService.postConnections>
->;
 export type ConnectionServiceTestConnectionMutationResult = Awaited<
   ReturnType<typeof ConnectionService.testConnection>
 >;
@@ -1635,6 +1632,9 @@ export type BackfillServiceUnpauseBackfillMutationResult = Awaited<
 export type BackfillServiceCancelBackfillMutationResult = Awaited<
   ReturnType<typeof BackfillService.cancelBackfill>
 >;
+export type ConnectionServicePutConnectionsMutationResult = Awaited<
+  ReturnType<typeof ConnectionService.putConnections>
+>;
 export type DagParsingServiceReparseDagFileMutationResult = Awaited<
   ReturnType<typeof DagParsingService.reparseDagFile>
 >;
diff --git a/airflow/ui/openapi-gen/queries/queries.ts b/airflow/ui/openapi-gen/queries/queries.ts
index 4e4d39f96ef..e8908553c34 100644
--- a/airflow/ui/openapi-gen/queries/queries.ts
+++ b/airflow/ui/openapi-gen/queries/queries.ts
@@ -2747,43 +2747,6 @@ export const useConnectionServicePostConnection = <
       ConnectionService.postConnection({ requestBody }) as unknown as 
Promise<TData>,
     ...options,
   });
-/**
- * Post Connections
- * Create connection entry.
- * @param data The data for the request.
- * @param data.requestBody
- * @returns ConnectionCollectionResponse Successful Response
- * @throws ApiError
- */
-export const useConnectionServicePostConnections = <
-  TData = Common.ConnectionServicePostConnectionsMutationResult,
-  TError = unknown,
-  TContext = unknown,
->(
-  options?: Omit<
-    UseMutationOptions<
-      TData,
-      TError,
-      {
-        requestBody: ConnectionBulkBody;
-      },
-      TContext
-    >,
-    "mutationFn"
-  >,
-) =>
-  useMutation<
-    TData,
-    TError,
-    {
-      requestBody: ConnectionBulkBody;
-    },
-    TContext
-  >({
-    mutationFn: ({ requestBody }) =>
-      ConnectionService.postConnections({ requestBody }) as unknown as 
Promise<TData>,
-    ...options,
-  });
 /**
  * Test Connection
  * Test an API connection.
@@ -3291,6 +3254,44 @@ export const useBackfillServiceCancelBackfill = <
       BackfillService.cancelBackfill({ backfillId }) as unknown as 
Promise<TData>,
     ...options,
   });
+/**
+ * Put Connections
+ * Create connection entry.
+ * @param data The data for the request.
+ * @param data.requestBody
+ * @returns ConnectionCollectionResponse Created with overwrite
+ * @returns ConnectionCollectionResponse Created
+ * @throws ApiError
+ */
+export const useConnectionServicePutConnections = <
+  TData = Common.ConnectionServicePutConnectionsMutationResult,
+  TError = unknown,
+  TContext = unknown,
+>(
+  options?: Omit<
+    UseMutationOptions<
+      TData,
+      TError,
+      {
+        requestBody: ConnectionBulkBody;
+      },
+      TContext
+    >,
+    "mutationFn"
+  >,
+) =>
+  useMutation<
+    TData,
+    TError,
+    {
+      requestBody: ConnectionBulkBody;
+    },
+    TContext
+  >({
+    mutationFn: ({ requestBody }) =>
+      ConnectionService.putConnections({ requestBody }) as unknown as 
Promise<TData>,
+    ...options,
+  });
 /**
  * Reparse Dag File
  * Request re-parsing a DAG file.
diff --git a/airflow/ui/openapi-gen/requests/schemas.gen.ts b/airflow/ui/openapi-gen/requests/schemas.gen.ts
index 1e61380d82d..b714cf07993 100644
--- a/airflow/ui/openapi-gen/requests/schemas.gen.ts
+++ b/airflow/ui/openapi-gen/requests/schemas.gen.ts
@@ -877,6 +877,18 @@ export const $ConnectionBulkBody = {
       type: "array",
       title: "Connections",
     },
+    overwrite: {
+      anyOf: [
+        {
+          type: "boolean",
+        },
+        {
+          type: "null",
+        },
+      ],
+      title: "Overwrite",
+      default: false,
+    },
   },
   type: "object",
   required: ["connections"],
diff --git a/airflow/ui/openapi-gen/requests/services.gen.ts b/airflow/ui/openapi-gen/requests/services.gen.ts
index 998767e1f62..c16005edcf2 100644
--- a/airflow/ui/openapi-gen/requests/services.gen.ts
+++ b/airflow/ui/openapi-gen/requests/services.gen.ts
@@ -66,8 +66,8 @@ import type {
   GetConnectionsResponse,
   PostConnectionData,
   PostConnectionResponse,
-  PostConnectionsData,
-  PostConnectionsResponse,
+  PutConnectionsData,
+  PutConnectionsResponse,
   TestConnectionData,
   TestConnectionResponse,
   GetDagRunData,
@@ -1108,16 +1108,17 @@ export class ConnectionService {
   }
 
   /**
-   * Post Connections
+   * Put Connections
    * Create connection entry.
    * @param data The data for the request.
    * @param data.requestBody
-   * @returns ConnectionCollectionResponse Successful Response
+   * @returns ConnectionCollectionResponse Created with overwrite
+   * @returns ConnectionCollectionResponse Created
    * @throws ApiError
    */
-  public static postConnections(data: PostConnectionsData): 
CancelablePromise<PostConnectionsResponse> {
+  public static putConnections(data: PutConnectionsData): 
CancelablePromise<PutConnectionsResponse> {
     return __request(OpenAPI, {
-      method: "POST",
+      method: "PUT",
       url: "/public/connections/bulk",
       body: data.requestBody,
       mediaType: "application/json",
diff --git a/airflow/ui/openapi-gen/requests/types.gen.ts b/airflow/ui/openapi-gen/requests/types.gen.ts
index cefee5ea908..70c3f532429 100644
--- a/airflow/ui/openapi-gen/requests/types.gen.ts
+++ b/airflow/ui/openapi-gen/requests/types.gen.ts
@@ -231,6 +231,7 @@ export type ConnectionBody = {
  */
 export type ConnectionBulkBody = {
   connections: Array<ConnectionBody>;
+  overwrite?: boolean | null;
 };
 
 /**
@@ -1617,11 +1618,11 @@ export type PostConnectionData = {
 
 export type PostConnectionResponse = ConnectionResponse;
 
-export type PostConnectionsData = {
+export type PutConnectionsData = {
   requestBody: ConnectionBulkBody;
 };
 
-export type PostConnectionsResponse = ConnectionCollectionResponse;
+export type PutConnectionsResponse = ConnectionCollectionResponse;
 
 export type TestConnectionData = {
   requestBody: ConnectionBody;
@@ -2989,11 +2990,15 @@ export type $OpenApiTs = {
     };
   };
   "/public/connections/bulk": {
-    post: {
-      req: PostConnectionsData;
+    put: {
+      req: PutConnectionsData;
       res: {
         /**
-         * Successful Response
+         * Created with overwrite
+         */
+        200: ConnectionCollectionResponse;
+        /**
+         * Created
          */
         201: ConnectionCollectionResponse;
         /**
diff --git a/tests/api_fastapi/core_api/routes/public/test_connections.py b/tests/api_fastapi/core_api/routes/public/test_connections.py
index 0cada4b9429..dd5732a7154 100644
--- a/tests/api_fastapi/core_api/routes/public/test_connections.py
+++ b/tests/api_fastapi/core_api/routes/public/test_connections.py
@@ -309,7 +309,7 @@ class TestPostConnection(TestConnectionEndpoint):
         assert response.json() == expected_response
 
 
-class TestPostConnections(TestConnectionEndpoint):
+class TestPutConnections(TestConnectionEndpoint):
     @pytest.mark.parametrize(
         "body",
         [
@@ -341,12 +341,170 @@ class TestPostConnections(TestConnectionEndpoint):
             },
         ],
     )
-    def test_post_should_respond_201(self, test_client, session, body):
-        response = test_client.post("/public/connections/bulk", json=body)
+    def test_put_should_respond_201(self, test_client, session, body):
+        response = test_client.put("/public/connections/bulk", json=body)
         assert response.status_code == 201
         connection = session.query(Connection).all()
         assert len(connection) == len(body["connections"])
 
+    @pytest.mark.parametrize(
+        "first_request_body, first_expected_entries_count, 
second_request_body, second_expected_entries_count, 
second_request_expected_response",
+        [
+            pytest.param(
+                {
+                    "connections": [
+                        {"connection_id": TEST_CONN_ID, "conn_type": 
TEST_CONN_TYPE},
+                        {"connection_id": TEST_CONN_ID_2, "conn_type": 
TEST_CONN_TYPE_2, "extra": None},
+                    ]
+                },
+                2,
+                {
+                    "connections": [
+                        {"connection_id": TEST_CONN_ID, "conn_type": 
f"new_{TEST_CONN_TYPE}"},
+                        {
+                            "connection_id": TEST_CONN_ID_3,
+                            "conn_type": TEST_CONN_TYPE_3,
+                            "port": 8080,
+                            "schema": "test_schema",
+                        },
+                    ],
+                    "overwrite": True,
+                },
+                3,
+                {
+                    "connections": [
+                        {
+                            "connection_id": TEST_CONN_ID,
+                            "conn_type": f"new_{TEST_CONN_TYPE}",
+                            "description": None,
+                            "extra": None,
+                            "host": None,
+                            "login": None,
+                            "password": None,
+                            "port": None,
+                            "schema": None,
+                        },
+                        {
+                            "connection_id": TEST_CONN_ID_3,
+                            "conn_type": TEST_CONN_TYPE_3,
+                            "description": None,
+                            "extra": None,
+                            "host": None,
+                            "login": None,
+                            "password": None,
+                            "port": 8080,
+                            "schema": "test_schema",
+                        },
+                    ],
+                    "total_entries": 2,
+                },
+                id="overwrite_with_partial_existing_request_body",
+            ),
+            pytest.param(
+                {
+                    "connections": [
+                        {"connection_id": TEST_CONN_ID, "conn_type": 
TEST_CONN_TYPE, "extra": "{}"},
+                        {
+                            "connection_id": TEST_CONN_ID_2,
+                            "conn_type": TEST_CONN_TYPE_2,
+                            "extra": '{"key": "value"}',
+                        },
+                        {
+                            "connection_id": TEST_CONN_ID_3,
+                            "conn_type": TEST_CONN_ID_3,
+                            "description": "test_description",
+                            "host": "test_host",
+                            "login": "test_login",
+                            "schema": "test_schema",
+                            "port": 8080,
+                            "extra": '{"key": "value"}',
+                        },
+                    ]
+                },
+                3,
+                {
+                    "connections": [
+                        {"connection_id": TEST_CONN_ID, "conn_type": 
f"new_{TEST_CONN_TYPE}", "extra": "{}"},
+                        {
+                            "connection_id": TEST_CONN_ID_2,
+                            "conn_type": f"new_{TEST_CONN_TYPE_2}",
+                            "extra": '{"key": "new_value"}',
+                        },
+                        {
+                            "connection_id": TEST_CONN_ID_3,
+                            "conn_type": TEST_CONN_ID_3,
+                            "description": "new_test_description",
+                            "host": "new_test_host",
+                            "login": "new_test_login",
+                            "schema": "new_test_schema",
+                            "port": 28080,
+                            "extra": '{"key": "new_value"}',
+                        },
+                    ],
+                    "overwrite": True,
+                },
+                3,
+                {
+                    "connections": [
+                        {
+                            "connection_id": TEST_CONN_ID,
+                            "conn_type": f"new_{TEST_CONN_TYPE}",
+                            "description": None,
+                            "extra": "{}",
+                            "host": None,
+                            "login": None,
+                            "password": None,
+                            "port": None,
+                            "schema": None,
+                        },
+                        {
+                            "connection_id": TEST_CONN_ID_2,
+                            "conn_type": f"new_{TEST_CONN_TYPE_2}",
+                            "description": None,
+                            "extra": '{"key": "new_value"}',
+                            "host": None,
+                            "login": None,
+                            "password": None,
+                            "port": None,
+                            "schema": None,
+                        },
+                        {
+                            "connection_id": TEST_CONN_ID_3,
+                            "conn_type": TEST_CONN_ID_3,
+                            "description": "new_test_description",
+                            "host": "new_test_host",
+                            "login": "new_test_login",
+                            "password": None,
+                            "schema": "new_test_schema",
+                            "port": 28080,
+                            "extra": '{"key": "new_value"}',
+                        },
+                    ],
+                    "total_entries": 3,
+                },
+                id="overwrite_with_extra_request_body",
+            ),
+        ],
+    )
+    def test_put_should_respond_200_overwrite(
+        self,
+        test_client,
+        session,
+        first_request_body,
+        first_expected_entries_count,
+        second_request_body,
+        second_expected_entries_count,
+        second_request_expected_response,
+    ):
+        response = test_client.put("/public/connections/bulk", 
json=first_request_body)
+        assert response.status_code == 201
+        assert session.query(Connection).count() == 
first_expected_entries_count
+        # Another request
+        response = test_client.put("/public/connections/bulk", 
json=second_request_body)
+        assert response.status_code == 200
+        assert response.json() == second_request_expected_response
+        assert session.query(Connection).count() == 
second_expected_entries_count
+
     @pytest.mark.parametrize(
         "body",
         [
@@ -364,8 +522,8 @@ class TestPostConnections(TestConnectionEndpoint):
             },
         ],
     )
-    def test_post_should_respond_422_for_invalid_conn_id(self, test_client, 
body):
-        response = test_client.post("/public/connections/bulk", json=body)
+    def test_put_should_respond_422_for_invalid_conn_id(self, test_client, 
body):
+        response = test_client.put("/public/connections/bulk", json=body)
         assert response.status_code == 422
         expected_response_detail = [
             {
@@ -390,11 +548,11 @@ class TestPostConnections(TestConnectionEndpoint):
             },
         ],
     )
-    def test_post_should_respond_already_exist(self, test_client, body):
-        response = test_client.post("/public/connections/bulk", json=body)
+    def test_put_should_respond_409_already_exist(self, test_client, body):
+        response = test_client.put("/public/connections/bulk", json=body)
         assert response.status_code == 201
         # Another request
-        response = test_client.post("/public/connections/bulk", json=body)
+        response = test_client.put("/public/connections/bulk", json=body)
         assert response.status_code == 409
         response_json = response.json()
         assert "detail" in response_json
@@ -477,11 +635,102 @@ class TestPostConnections(TestConnectionEndpoint):
             ),
         ],
     )
-    def test_post_should_response_201_redacted_password(self, test_client, 
body, expected_response):
-        response = test_client.post("/public/connections/bulk", json=body)
+    def test_put_should_response_201_redacted_password(self, test_client, 
body, expected_response):
+        response = test_client.put("/public/connections/bulk", json=body)
         assert response.status_code == 201
         assert response.json() == expected_response
 
+    @pytest.mark.enable_redact
+    @pytest.mark.parametrize(
+        "body, expected_response",
+        [
+            pytest.param(
+                {
+                    "connections": [
+                        {
+                            "connection_id": TEST_CONN_ID,
+                            "conn_type": TEST_CONN_TYPE_2,
+                            "password": "new-test-password",
+                            "description": "new-description",
+                        },
+                        {
+                            "connection_id": TEST_CONN_ID_2,
+                            "conn_type": TEST_CONN_TYPE,
+                            "password": "new-?>@#+!_%()#",
+                            "port": 80,
+                        },
+                    ],
+                    "overwrite": True,
+                },
+                {
+                    "connections": [
+                        {
+                            "connection_id": TEST_CONN_ID,
+                            "conn_type": TEST_CONN_TYPE_2,
+                            "description": "new-description",
+                            "extra": None,
+                            "host": None,
+                            "login": None,
+                            "password": "***",
+                            "port": None,
+                            "schema": None,
+                        },
+                        {
+                            "connection_id": TEST_CONN_ID_2,
+                            "conn_type": TEST_CONN_TYPE,
+                            "description": None,
+                            "extra": None,
+                            "host": None,
+                            "login": None,
+                            "password": "***",
+                            "port": 80,
+                            "schema": None,
+                        },
+                    ],
+                    "total_entries": 2,
+                },
+                id="redact_password_with_overwrite",
+            ),
+            pytest.param(
+                {
+                    "connections": [
+                        {
+                            "connection_id": TEST_CONN_ID,
+                            "conn_type": TEST_CONN_TYPE,
+                            "password": "A!rF|0wi$aw3s0m3",
+                            "extra": '{"password": "test-password"}',
+                        }
+                    ],
+                    "overwrite": True,
+                },
+                {
+                    "connections": [
+                        {
+                            "connection_id": TEST_CONN_ID,
+                            "conn_type": TEST_CONN_TYPE,
+                            "description": None,
+                            "extra": '{"password": "***"}',
+                            "host": None,
+                            "login": None,
+                            "password": "***",
+                            "port": None,
+                            "schema": None,
+                        },
+                    ],
+                    "total_entries": 1,
+                },
+                id="redact_extra_with_overwrite",
+            ),
+        ],
+    )
+    def test_put_should_response_200_redacted_password_with_overwrite(
+        self, test_client, body, expected_response
+    ):
+        self.create_connections()
+        response = test_client.put("/public/connections/bulk", json=body)
+        assert response.status_code == 200
+        assert response.json() == expected_response
+
 
 class TestPatchConnection(TestConnectionEndpoint):
     @pytest.mark.parametrize(

Reply via email to