This is an automated email from the ASF dual-hosted git repository.

jscheffl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new b66ce98264d Connections API is able to patch all fields correctly (#48109)
b66ce98264d is described below

commit b66ce98264d6a65e8827225b7a6ca36a7e989fef
Author: Jens Scheffler <[email protected]>
AuthorDate: Tue Apr 1 00:17:48 2025 +0200

    Connections API is able to patch all fields correctly (#48109)
    
    * Connections API is able to patch all fields correctly
    
    * Fix implementation and pytest, apply to bulk service as well
    
    * Consolidate Pydantic to ORM update in connections API
    
    * Review feedback - generalize via model_dump where possible
    
    * Fix pytest
    
    * Review feedback
    
    * Review feedback, extended pytests
    
    * Add an explicit test if None is passed in body
---
 .../core_api/routes/public/connections.py          |  27 +--
 .../core_api/services/public/connections.py        |  40 +++-
 airflow-core/src/airflow/models/connection.py      |   2 +-
 .../core_api/routes/public/test_connections.py     | 239 ++++++++++++++++++---
 4 files changed, 260 insertions(+), 48 deletions(-)

diff --git 
a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/connections.py 
b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/connections.py
index 966e59e587b..b4a300d1393 100644
--- a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/connections.py
+++ b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/connections.py
@@ -44,7 +44,10 @@ from airflow.api_fastapi.core_api.datamodels.connections 
import (
 )
 from airflow.api_fastapi.core_api.openapi.exceptions import 
create_openapi_http_exception_doc
 from airflow.api_fastapi.core_api.security import requires_access_connection
-from airflow.api_fastapi.core_api.services.public.connections import 
BulkConnectionService
+from airflow.api_fastapi.core_api.services.public.connections import (
+    BulkConnectionService,
+    update_orm_from_pydantic,
+)
 from airflow.api_fastapi.logging.decorators import action_logging
 from airflow.configuration import conf
 from airflow.models import Connection
@@ -187,29 +190,19 @@ def patch_connection(
             "The connection_id in the request body does not match the URL 
parameter",
         )
 
-    non_update_fields = {"connection_id", "conn_id"}
-    connection = 
session.scalar(select(Connection).filter_by(conn_id=connection_id).limit(1))
+    connection: Connection = 
session.scalar(select(Connection).filter_by(conn_id=connection_id).limit(1))
 
     if connection is None:
         raise HTTPException(
             status.HTTP_404_NOT_FOUND, f"The Connection with connection_id: 
`{connection_id}` was not found"
         )
 
-    fields_to_update = patch_body.model_fields_set
-
-    if update_mask:
-        fields_to_update = fields_to_update.intersection(update_mask)
-    else:
-        try:
-            ConnectionBody(**patch_body.model_dump())
-        except ValidationError as e:
-            raise RequestValidationError(errors=e.errors())
-
-    data = patch_body.model_dump(include=fields_to_update - non_update_fields, 
by_alias=True)
-
-    for key, val in data.items():
-        setattr(connection, key, val)
+    try:
+        ConnectionBody(**patch_body.model_dump())
+    except ValidationError as e:
+        raise RequestValidationError(errors=e.errors())
 
+    update_orm_from_pydantic(connection, patch_body, update_mask)
     return connection
 
 
diff --git 
a/airflow-core/src/airflow/api_fastapi/core_api/services/public/connections.py 
b/airflow-core/src/airflow/api_fastapi/core_api/services/public/connections.py
index 7aab305af5c..3883c3f5e84 100644
--- 
a/airflow-core/src/airflow/api_fastapi/core_api/services/public/connections.py
+++ 
b/airflow-core/src/airflow/api_fastapi/core_api/services/public/connections.py
@@ -34,6 +34,36 @@ from airflow.api_fastapi.core_api.services.public.common 
import BulkService
 from airflow.models.connection import Connection
 
 
+def update_orm_from_pydantic(
+    orm_conn: Connection, pydantic_conn: ConnectionBody, update_mask: list[str] | None = None
+) -> None:
+    """Update ORM object from Pydantic object, honouring explicitly-set fields and ``update_mask``."""
+    # Not all fields match and some need setters, therefore copy partly manually via setters
+    non_update_fields = {"connection_id", "conn_id"}
+    setter_fields = {"password", "extra"}
+    fields_set = set(pydantic_conn.model_fields_set)  # copy: never mutate the model's own fields set
+    if "schema_" in fields_set:  # Alias is not resolved correctly, need to patch
+        fields_set.remove("schema_")
+        fields_set.add("schema")
+    fields_to_update = fields_set - non_update_fields - setter_fields
+    if update_mask:
+        # Only apply fields that were both sent in the body and listed in the mask
+        fields_to_update = fields_to_update.intersection(update_mask)
+    conn_data = pydantic_conn.model_dump(by_alias=True)
+    for key, val in conn_data.items():
+        if key in fields_to_update:
+            setattr(orm_conn, key, val)
+    # password/extra must go through the model setters (set_extra validates + encrypts)
+    if (not update_mask and "password" in pydantic_conn.model_fields_set) or (
+        update_mask and "password" in update_mask
+    ):
+        orm_conn.set_password(pydantic_conn.password)
+    if (not update_mask and "extra" in pydantic_conn.model_fields_set) or (
+        update_mask and "extra" in update_mask
+    ):
+        orm_conn.set_extra(pydantic_conn.extra)
+
+
 class BulkConnectionService(BulkService[ConnectionBody]):
     """Service for handling bulk operations on connections."""
 
@@ -108,12 +138,16 @@ class BulkConnectionService(BulkService[ConnectionBody]):
 
             for connection in action.entities:
                 if connection.connection_id in update_connection_ids:
-                    old_connection = self.session.scalar(
+                    old_connection: Connection = self.session.scalar(
                         select(Connection).filter(Connection.conn_id == 
connection.connection_id).limit(1)
                     )
+                    if old_connection is None:
+                        raise ValidationError(
+                            f"The Connection with connection_id: 
`{connection.connection_id}` was not found"
+                        )
                     ConnectionBody(**connection.model_dump())
-                    for key, val in 
connection.model_dump(by_alias=True).items():
-                        setattr(old_connection, key, val)
+
+                    update_orm_from_pydantic(old_connection, connection)
                     results.success.append(connection.connection_id)
 
         except HTTPException as e:
diff --git a/airflow-core/src/airflow/models/connection.py 
b/airflow-core/src/airflow/models/connection.py
index 92917aaf865..0f32133949e 100644
--- a/airflow-core/src/airflow/models/connection.py
+++ b/airflow-core/src/airflow/models/connection.py
@@ -353,7 +353,7 @@ class Connection(Base, LoggingMixin):
             self._validate_extra(extra_val, self.conn_id)
         return extra_val
 
-    def set_extra(self, value: str):
+    def set_extra(self, value: str | None):
         """Encrypt extra-data and save in object attribute to object."""
         if value:
             self._validate_extra(value, self.conn_id)
diff --git 
a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_connections.py
 
b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_connections.py
index f59969bb440..2e32470fc8f 100644
--- 
a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_connections.py
+++ 
b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_connections.py
@@ -36,6 +36,8 @@ TEST_CONN_DESCRIPTION = "some_description_a"
 TEST_CONN_HOST = "some_host_a"
 TEST_CONN_PORT = 8080
 TEST_CONN_LOGIN = "some_login"
+TEST_CONN_SCHEMA = "https"
+TEST_CONN_EXTRA = '{"extra_key": "extra_value"}'
 
 
 TEST_CONN_ID_2 = "test_connection_id_2"
@@ -350,34 +352,160 @@ class TestPostConnection(TestConnectionEndpoint):
 
 class TestPatchConnection(TestConnectionEndpoint):
     @pytest.mark.parametrize(
-        "body",
+        "body, expected_result",
         [
-            {"connection_id": TEST_CONN_ID, "conn_type": TEST_CONN_TYPE, 
"extra": '{"key": "var"}'},
-            {"connection_id": TEST_CONN_ID, "conn_type": TEST_CONN_TYPE, 
"host": "test_host_patch"},
-            {
-                "connection_id": TEST_CONN_ID,
-                "conn_type": TEST_CONN_TYPE,
-                "host": "test_host_patch",
-                "port": 80,
-            },
-            {"connection_id": TEST_CONN_ID, "conn_type": TEST_CONN_TYPE, 
"login": "test_login_patch"},
-            {"connection_id": TEST_CONN_ID, "conn_type": TEST_CONN_TYPE, 
"port": 80},
-            {
-                "connection_id": TEST_CONN_ID,
-                "conn_type": TEST_CONN_TYPE,
-                "port": 80,
-                "login": "test_login_patch",
-            },
+            (
+                {"connection_id": TEST_CONN_ID, "conn_type": "new_type", 
"extra": '{"key": "var"}'},
+                {
+                    "conn_type": "new_type",
+                    "connection_id": TEST_CONN_ID,
+                    "description": TEST_CONN_DESCRIPTION,
+                    "extra": '{"key": "var"}',
+                    "host": TEST_CONN_HOST,
+                    "login": TEST_CONN_LOGIN,
+                    "password": None,
+                    "port": TEST_CONN_PORT,
+                    "schema": None,
+                },
+            ),
+            (
+                {"connection_id": TEST_CONN_ID, "conn_type": "type_patch", 
"host": "test_host_patch"},
+                {
+                    "conn_type": "type_patch",
+                    "connection_id": TEST_CONN_ID,
+                    "description": TEST_CONN_DESCRIPTION,
+                    "extra": None,
+                    "host": "test_host_patch",
+                    "login": TEST_CONN_LOGIN,
+                    "password": None,
+                    "port": TEST_CONN_PORT,
+                    "schema": None,
+                },
+            ),
+            (
+                {
+                    "connection_id": TEST_CONN_ID,
+                    "conn_type": "surprise",
+                    "host": "test_host_patch",
+                    "port": 80,
+                },
+                {
+                    "conn_type": "surprise",
+                    "connection_id": TEST_CONN_ID,
+                    "description": TEST_CONN_DESCRIPTION,
+                    "extra": None,
+                    "host": "test_host_patch",
+                    "login": TEST_CONN_LOGIN,
+                    "password": None,
+                    "port": 80,
+                    "schema": None,
+                },
+            ),
+            (
+                {"connection_id": TEST_CONN_ID, "conn_type": 
"really_new_type", "login": "test_login_patch"},
+                {
+                    "conn_type": "really_new_type",
+                    "connection_id": TEST_CONN_ID,
+                    "description": TEST_CONN_DESCRIPTION,
+                    "extra": None,
+                    "host": TEST_CONN_HOST,
+                    "login": "test_login_patch",
+                    "password": None,
+                    "port": TEST_CONN_PORT,
+                    "schema": None,
+                },
+            ),
+            (
+                {"connection_id": TEST_CONN_ID, "conn_type": TEST_CONN_TYPE, 
"port": 80},
+                {
+                    "conn_type": TEST_CONN_TYPE,
+                    "connection_id": TEST_CONN_ID,
+                    "description": TEST_CONN_DESCRIPTION,
+                    "extra": None,
+                    "host": TEST_CONN_HOST,
+                    "login": TEST_CONN_LOGIN,
+                    "password": None,
+                    "port": 80,
+                    "schema": None,
+                },
+            ),
+            (
+                {
+                    "connection_id": TEST_CONN_ID,
+                    "conn_type": TEST_CONN_TYPE,
+                    "port": 80,
+                    "login": "test_login_patch",
+                    "password": "test_password_patch",
+                },
+                {
+                    "conn_type": TEST_CONN_TYPE,
+                    "connection_id": TEST_CONN_ID,
+                    "description": TEST_CONN_DESCRIPTION,
+                    "extra": None,
+                    "host": TEST_CONN_HOST,
+                    "login": "test_login_patch",
+                    "password": "test_password_patch",
+                    "port": 80,
+                    "schema": None,
+                },
+            ),
+            (
+                {
+                    "connection_id": TEST_CONN_ID,
+                    "conn_type": TEST_CONN_TYPE,
+                    "schema": "http_patch",
+                    "extra": '{"extra_key_patch": "extra_value_patch"}',
+                },
+                {
+                    "conn_type": TEST_CONN_TYPE,
+                    "connection_id": TEST_CONN_ID,
+                    "description": TEST_CONN_DESCRIPTION,
+                    "extra": '{"extra_key_patch": "extra_value_patch"}',
+                    "host": TEST_CONN_HOST,
+                    "login": TEST_CONN_LOGIN,
+                    "password": None,
+                    "port": TEST_CONN_PORT,
+                    "schema": "http_patch",
+                },
+            ),
+            (
+                {  # Explicitly test that None is applied compared to if not 
provided
+                    "conn_type": TEST_CONN_TYPE,
+                    "connection_id": TEST_CONN_ID,
+                    "description": None,
+                    "extra": None,
+                    "host": None,
+                    "login": None,
+                    "password": None,
+                    "port": None,
+                    "schema": None,
+                },
+                {
+                    "conn_type": TEST_CONN_TYPE,
+                    "connection_id": TEST_CONN_ID,
+                    "description": None,
+                    "extra": None,
+                    "host": None,
+                    "login": None,
+                    "password": None,
+                    "port": None,
+                    "schema": None,
+                },
+            ),
         ],
     )
     @provide_session
-    def test_patch_should_respond_200(self, test_client, body, session):
+    def test_patch_should_respond_200(
+        self, test_client, body: dict[str, str], expected_result: dict[str, 
str], session
+    ):
         self.create_connection()
 
         response = test_client.patch(f"/connections/{TEST_CONN_ID}", json=body)
         assert response.status_code == 200
         _check_last_log(session, dag_id=None, event="patch_connection", 
logical_date=None)
 
+        assert response.json() == expected_result
+
     def test_should_respond_401(self, unauthenticated_test_client):
         response = 
unauthenticated_test_client.patch(f"/connections/{TEST_CONN_ID}", json={})
         assert response.status_code == 401
@@ -390,7 +518,13 @@ class TestPatchConnection(TestConnectionEndpoint):
         "body, updated_connection, update_mask",
         [
             (
-                {"connection_id": TEST_CONN_ID, "conn_type": TEST_CONN_TYPE, 
"extra": '{"key": "var"}'},
+                {
+                    "connection_id": TEST_CONN_ID,
+                    "conn_type": TEST_CONN_TYPE,
+                    "extra": '{"key": "var"}',
+                    "login": TEST_CONN_LOGIN,
+                    "port": TEST_CONN_PORT,
+                },
                 {
                     "connection_id": TEST_CONN_ID,
                     "conn_type": TEST_CONN_TYPE,
@@ -404,6 +538,27 @@ class TestPatchConnection(TestConnectionEndpoint):
                 },
                 {"update_mask": ["login", "port"]},
             ),
+            (
+                {
+                    "connection_id": TEST_CONN_ID,
+                    "conn_type": TEST_CONN_TYPE,
+                    "extra": '{"key": "var"}',
+                    "login": None,
+                    "port": None,
+                },
+                {
+                    "connection_id": TEST_CONN_ID,
+                    "conn_type": TEST_CONN_TYPE,
+                    "extra": None,
+                    "host": TEST_CONN_HOST,
+                    "login": None,
+                    "port": None,
+                    "schema": None,
+                    "password": None,
+                    "description": TEST_CONN_DESCRIPTION,
+                },
+                {"update_mask": ["login", "port"]},
+            ),
             (
                 {"connection_id": TEST_CONN_ID, "conn_type": TEST_CONN_TYPE, 
"host": "test_host_patch"},
                 {
@@ -474,6 +629,28 @@ class TestPatchConnection(TestConnectionEndpoint):
                 },
                 {"update_mask": ["host"]},
             ),
+            (
+                {
+                    "connection_id": TEST_CONN_ID,
+                    "conn_type": TEST_CONN_TYPE,
+                    "extra": '{"new_extra_key": "new_extra_value"}',
+                    "host": TEST_CONN_HOST,
+                    "schema": "new_schema",
+                    "port": 80,
+                },
+                {
+                    "connection_id": TEST_CONN_ID,
+                    "conn_type": TEST_CONN_TYPE,
+                    "extra": '{"new_extra_key": "new_extra_value"}',
+                    "host": TEST_CONN_HOST,
+                    "login": TEST_CONN_LOGIN,
+                    "port": TEST_CONN_PORT,
+                    "password": None,
+                    "schema": "new_schema",
+                    "description": TEST_CONN_DESCRIPTION,
+                },
+                {"update_mask": ["schema", "extra"]},
+            ),
         ],
     )
     def test_patch_should_respond_200_with_update_mask(
@@ -569,7 +746,7 @@ class TestPatchConnection(TestConnectionEndpoint):
 
     @pytest.mark.enable_redact
     @pytest.mark.parametrize(
-        "body, expected_response",
+        "body, expected_response, update_mask",
         [
             (
                 {"connection_id": TEST_CONN_ID, "conn_type": TEST_CONN_TYPE, 
"password": "test-password"},
@@ -584,6 +761,7 @@ class TestPatchConnection(TestConnectionEndpoint):
                     "port": 8080,
                     "schema": None,
                 },
+                {"update_mask": ["password"]},
             ),
             (
                 {"connection_id": TEST_CONN_ID, "conn_type": TEST_CONN_TYPE, 
"password": "?>@#+!_%()#"},
@@ -598,6 +776,7 @@ class TestPatchConnection(TestConnectionEndpoint):
                     "port": 8080,
                     "schema": None,
                 },
+                {"update_mask": ["password"]},
             ),
             (
                 {
@@ -617,12 +796,15 @@ class TestPatchConnection(TestConnectionEndpoint):
                     "port": 8080,
                     "schema": None,
                 },
+                {"update_mask": ["password", "extra"]},
             ),
         ],
     )
-    def test_patch_should_response_200_redacted_password(self, test_client, 
session, body, expected_response):
+    def test_patch_should_response_200_redacted_password(
+        self, test_client, session, body, expected_response, update_mask
+    ):
         self.create_connections()
-        response = test_client.patch(f"/connections/{TEST_CONN_ID}", json=body)
+        response = test_client.patch(f"/connections/{TEST_CONN_ID}", 
json=body, params=update_mask)
         assert response.status_code == 200
         assert response.json() == expected_response
         _check_last_log(session, dag_id=None, event="patch_connection", 
logical_date=None, check_masked=True)
@@ -631,18 +813,21 @@ class TestPatchConnection(TestConnectionEndpoint):
 class TestConnection(TestConnectionEndpoint):
     @mock.patch.dict(os.environ, {"AIRFLOW__CORE__TEST_CONNECTION": "Enabled"})
     @pytest.mark.parametrize(
-        "body",
+        "body, message",
         [
-            {"connection_id": TEST_CONN_ID, "conn_type": "sqlite"},
-            {"connection_id": TEST_CONN_ID, "conn_type": "ftp"},
+            ({"connection_id": TEST_CONN_ID, "conn_type": "sqlite"}, 
"Connection successfully tested"),
+            (
+                {"connection_id": TEST_CONN_ID, "conn_type": "fs", "extra": 
'{"path": "/"}'},
+                "Path / is existing.",
+            ),
         ],
     )
-    def test_should_respond_200(self, test_client, body):
+    def test_should_respond_200(self, test_client, body, message):
         response = test_client.post("/connections/test", json=body)
         assert response.status_code == 200
         assert response.json() == {
             "status": True,
-            "message": "Connection successfully tested",
+            "message": message,
         }
 
     def test_should_respond_401(self, unauthenticated_test_client):

Reply via email to