This is an automated email from the ASF dual-hosted git repository.
jscheffl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new b66ce98264d Connections API is able to patch all fields correctly
(#48109)
b66ce98264d is described below
commit b66ce98264d6a65e8827225b7a6ca36a7e989fef
Author: Jens Scheffler <[email protected]>
AuthorDate: Tue Apr 1 00:17:48 2025 +0200
Connections API is able to patch all fields correctly (#48109)
* Connections API is able to patch all fields correctly
* Fix implementation and pytest, apply to bulk service as well
* Consolidate Pydantic to ORM update in connections API
* Review feedback - generalize via model_dump where possible
* Fix pytest
* Review feedback
* Review feedback, extended pytests
* Add an explicit test if None is passed in body
---
.../core_api/routes/public/connections.py | 27 +--
.../core_api/services/public/connections.py | 40 +++-
airflow-core/src/airflow/models/connection.py | 2 +-
.../core_api/routes/public/test_connections.py | 239 ++++++++++++++++++---
4 files changed, 260 insertions(+), 48 deletions(-)
diff --git
a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/connections.py
b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/connections.py
index 966e59e587b..b4a300d1393 100644
--- a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/connections.py
+++ b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/connections.py
@@ -44,7 +44,10 @@ from airflow.api_fastapi.core_api.datamodels.connections
import (
)
from airflow.api_fastapi.core_api.openapi.exceptions import
create_openapi_http_exception_doc
from airflow.api_fastapi.core_api.security import requires_access_connection
-from airflow.api_fastapi.core_api.services.public.connections import
BulkConnectionService
+from airflow.api_fastapi.core_api.services.public.connections import (
+ BulkConnectionService,
+ update_orm_from_pydantic,
+)
from airflow.api_fastapi.logging.decorators import action_logging
from airflow.configuration import conf
from airflow.models import Connection
@@ -187,29 +190,19 @@ def patch_connection(
"The connection_id in the request body does not match the URL
parameter",
)
- non_update_fields = {"connection_id", "conn_id"}
- connection =
session.scalar(select(Connection).filter_by(conn_id=connection_id).limit(1))
+ connection: Connection =
session.scalar(select(Connection).filter_by(conn_id=connection_id).limit(1))
if connection is None:
raise HTTPException(
status.HTTP_404_NOT_FOUND, f"The Connection with connection_id:
`{connection_id}` was not found"
)
- fields_to_update = patch_body.model_fields_set
-
- if update_mask:
- fields_to_update = fields_to_update.intersection(update_mask)
- else:
- try:
- ConnectionBody(**patch_body.model_dump())
- except ValidationError as e:
- raise RequestValidationError(errors=e.errors())
-
- data = patch_body.model_dump(include=fields_to_update - non_update_fields,
by_alias=True)
-
- for key, val in data.items():
- setattr(connection, key, val)
+ try:
+ ConnectionBody(**patch_body.model_dump())
+ except ValidationError as e:
+ raise RequestValidationError(errors=e.errors())
+ update_orm_from_pydantic(connection, patch_body, update_mask)
return connection
diff --git
a/airflow-core/src/airflow/api_fastapi/core_api/services/public/connections.py
b/airflow-core/src/airflow/api_fastapi/core_api/services/public/connections.py
index 7aab305af5c..3883c3f5e84 100644
---
a/airflow-core/src/airflow/api_fastapi/core_api/services/public/connections.py
+++
b/airflow-core/src/airflow/api_fastapi/core_api/services/public/connections.py
@@ -34,6 +34,36 @@ from airflow.api_fastapi.core_api.services.public.common
import BulkService
from airflow.models.connection import Connection
+def update_orm_from_pydantic(
+ orm_conn: Connection, pydantic_conn: ConnectionBody, update_mask:
list[str] | None = None
+) -> None:
+ """Update ORM object from Pydantic object."""
+ # Not all fields match and some need setters, therefore copy partly
manually via setters
+ non_update_fields = {"connection_id", "conn_id"}
+ setter_fields = {"password", "extra"}
+ fields_set = pydantic_conn.model_fields_set
+ if "schema_" in fields_set: # Alias is not resolved correctly, need to
patch
+ fields_set.remove("schema_")
+ fields_set.add("schema")
+ fields_to_update = fields_set - non_update_fields - setter_fields
+ if update_mask:
+ fields_to_update = fields_to_update.intersection(update_mask)
+ print(fields_to_update)
+ conn_data = pydantic_conn.model_dump(by_alias=True)
+ for key, val in conn_data.items():
+ if key in fields_to_update:
+ setattr(orm_conn, key, val)
+
+ if (not update_mask and "password" in pydantic_conn.model_fields_set) or (
+ update_mask and "password" in update_mask
+ ):
+ orm_conn.set_password(pydantic_conn.password)
+ if (not update_mask and "extra" in pydantic_conn.model_fields_set) or (
+ update_mask and "extra" in update_mask
+ ):
+ orm_conn.set_extra(pydantic_conn.extra)
+
+
class BulkConnectionService(BulkService[ConnectionBody]):
"""Service for handling bulk operations on connections."""
@@ -108,12 +138,16 @@ class BulkConnectionService(BulkService[ConnectionBody]):
for connection in action.entities:
if connection.connection_id in update_connection_ids:
- old_connection = self.session.scalar(
+ old_connection: Connection = self.session.scalar(
select(Connection).filter(Connection.conn_id ==
connection.connection_id).limit(1)
)
+ if old_connection is None:
+ raise ValidationError(
+ f"The Connection with connection_id:
`{connection.connection_id}` was not found"
+ )
ConnectionBody(**connection.model_dump())
- for key, val in
connection.model_dump(by_alias=True).items():
- setattr(old_connection, key, val)
+
+ update_orm_from_pydantic(old_connection, connection)
results.success.append(connection.connection_id)
except HTTPException as e:
diff --git a/airflow-core/src/airflow/models/connection.py
b/airflow-core/src/airflow/models/connection.py
index 92917aaf865..0f32133949e 100644
--- a/airflow-core/src/airflow/models/connection.py
+++ b/airflow-core/src/airflow/models/connection.py
@@ -353,7 +353,7 @@ class Connection(Base, LoggingMixin):
self._validate_extra(extra_val, self.conn_id)
return extra_val
- def set_extra(self, value: str):
+ def set_extra(self, value: str | None):
"""Encrypt extra-data and save in object attribute to object."""
if value:
self._validate_extra(value, self.conn_id)
diff --git
a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_connections.py
b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_connections.py
index f59969bb440..2e32470fc8f 100644
---
a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_connections.py
+++
b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_connections.py
@@ -36,6 +36,8 @@ TEST_CONN_DESCRIPTION = "some_description_a"
TEST_CONN_HOST = "some_host_a"
TEST_CONN_PORT = 8080
TEST_CONN_LOGIN = "some_login"
+TEST_CONN_SCHEMA = "https"
+TEST_CONN_EXTRA = '{"extra_key": "extra_value"}'
TEST_CONN_ID_2 = "test_connection_id_2"
@@ -350,34 +352,160 @@ class TestPostConnection(TestConnectionEndpoint):
class TestPatchConnection(TestConnectionEndpoint):
@pytest.mark.parametrize(
- "body",
+ "body, expected_result",
[
- {"connection_id": TEST_CONN_ID, "conn_type": TEST_CONN_TYPE,
"extra": '{"key": "var"}'},
- {"connection_id": TEST_CONN_ID, "conn_type": TEST_CONN_TYPE,
"host": "test_host_patch"},
- {
- "connection_id": TEST_CONN_ID,
- "conn_type": TEST_CONN_TYPE,
- "host": "test_host_patch",
- "port": 80,
- },
- {"connection_id": TEST_CONN_ID, "conn_type": TEST_CONN_TYPE,
"login": "test_login_patch"},
- {"connection_id": TEST_CONN_ID, "conn_type": TEST_CONN_TYPE,
"port": 80},
- {
- "connection_id": TEST_CONN_ID,
- "conn_type": TEST_CONN_TYPE,
- "port": 80,
- "login": "test_login_patch",
- },
+ (
+ {"connection_id": TEST_CONN_ID, "conn_type": "new_type",
"extra": '{"key": "var"}'},
+ {
+ "conn_type": "new_type",
+ "connection_id": TEST_CONN_ID,
+ "description": TEST_CONN_DESCRIPTION,
+ "extra": '{"key": "var"}',
+ "host": TEST_CONN_HOST,
+ "login": TEST_CONN_LOGIN,
+ "password": None,
+ "port": TEST_CONN_PORT,
+ "schema": None,
+ },
+ ),
+ (
+ {"connection_id": TEST_CONN_ID, "conn_type": "type_patch",
"host": "test_host_patch"},
+ {
+ "conn_type": "type_patch",
+ "connection_id": TEST_CONN_ID,
+ "description": TEST_CONN_DESCRIPTION,
+ "extra": None,
+ "host": "test_host_patch",
+ "login": TEST_CONN_LOGIN,
+ "password": None,
+ "port": TEST_CONN_PORT,
+ "schema": None,
+ },
+ ),
+ (
+ {
+ "connection_id": TEST_CONN_ID,
+ "conn_type": "surprise",
+ "host": "test_host_patch",
+ "port": 80,
+ },
+ {
+ "conn_type": "surprise",
+ "connection_id": TEST_CONN_ID,
+ "description": TEST_CONN_DESCRIPTION,
+ "extra": None,
+ "host": "test_host_patch",
+ "login": TEST_CONN_LOGIN,
+ "password": None,
+ "port": 80,
+ "schema": None,
+ },
+ ),
+ (
+ {"connection_id": TEST_CONN_ID, "conn_type":
"really_new_type", "login": "test_login_patch"},
+ {
+ "conn_type": "really_new_type",
+ "connection_id": TEST_CONN_ID,
+ "description": TEST_CONN_DESCRIPTION,
+ "extra": None,
+ "host": TEST_CONN_HOST,
+ "login": "test_login_patch",
+ "password": None,
+ "port": TEST_CONN_PORT,
+ "schema": None,
+ },
+ ),
+ (
+ {"connection_id": TEST_CONN_ID, "conn_type": TEST_CONN_TYPE,
"port": 80},
+ {
+ "conn_type": TEST_CONN_TYPE,
+ "connection_id": TEST_CONN_ID,
+ "description": TEST_CONN_DESCRIPTION,
+ "extra": None,
+ "host": TEST_CONN_HOST,
+ "login": TEST_CONN_LOGIN,
+ "password": None,
+ "port": 80,
+ "schema": None,
+ },
+ ),
+ (
+ {
+ "connection_id": TEST_CONN_ID,
+ "conn_type": TEST_CONN_TYPE,
+ "port": 80,
+ "login": "test_login_patch",
+ "password": "test_password_patch",
+ },
+ {
+ "conn_type": TEST_CONN_TYPE,
+ "connection_id": TEST_CONN_ID,
+ "description": TEST_CONN_DESCRIPTION,
+ "extra": None,
+ "host": TEST_CONN_HOST,
+ "login": "test_login_patch",
+ "password": "test_password_patch",
+ "port": 80,
+ "schema": None,
+ },
+ ),
+ (
+ {
+ "connection_id": TEST_CONN_ID,
+ "conn_type": TEST_CONN_TYPE,
+ "schema": "http_patch",
+ "extra": '{"extra_key_patch": "extra_value_patch"}',
+ },
+ {
+ "conn_type": TEST_CONN_TYPE,
+ "connection_id": TEST_CONN_ID,
+ "description": TEST_CONN_DESCRIPTION,
+ "extra": '{"extra_key_patch": "extra_value_patch"}',
+ "host": TEST_CONN_HOST,
+ "login": TEST_CONN_LOGIN,
+ "password": None,
+ "port": TEST_CONN_PORT,
+ "schema": "http_patch",
+ },
+ ),
+ (
+ { # Explicitly test that None is applied compared to if not
provided
+ "conn_type": TEST_CONN_TYPE,
+ "connection_id": TEST_CONN_ID,
+ "description": None,
+ "extra": None,
+ "host": None,
+ "login": None,
+ "password": None,
+ "port": None,
+ "schema": None,
+ },
+ {
+ "conn_type": TEST_CONN_TYPE,
+ "connection_id": TEST_CONN_ID,
+ "description": None,
+ "extra": None,
+ "host": None,
+ "login": None,
+ "password": None,
+ "port": None,
+ "schema": None,
+ },
+ ),
],
)
@provide_session
- def test_patch_should_respond_200(self, test_client, body, session):
+ def test_patch_should_respond_200(
+ self, test_client, body: dict[str, str], expected_result: dict[str,
str], session
+ ):
self.create_connection()
response = test_client.patch(f"/connections/{TEST_CONN_ID}", json=body)
assert response.status_code == 200
_check_last_log(session, dag_id=None, event="patch_connection",
logical_date=None)
+ assert response.json() == expected_result
+
def test_should_respond_401(self, unauthenticated_test_client):
response =
unauthenticated_test_client.patch(f"/connections/{TEST_CONN_ID}", json={})
assert response.status_code == 401
@@ -390,7 +518,13 @@ class TestPatchConnection(TestConnectionEndpoint):
"body, updated_connection, update_mask",
[
(
- {"connection_id": TEST_CONN_ID, "conn_type": TEST_CONN_TYPE,
"extra": '{"key": "var"}'},
+ {
+ "connection_id": TEST_CONN_ID,
+ "conn_type": TEST_CONN_TYPE,
+ "extra": '{"key": "var"}',
+ "login": TEST_CONN_LOGIN,
+ "port": TEST_CONN_PORT,
+ },
{
"connection_id": TEST_CONN_ID,
"conn_type": TEST_CONN_TYPE,
@@ -404,6 +538,27 @@ class TestPatchConnection(TestConnectionEndpoint):
},
{"update_mask": ["login", "port"]},
),
+ (
+ {
+ "connection_id": TEST_CONN_ID,
+ "conn_type": TEST_CONN_TYPE,
+ "extra": '{"key": "var"}',
+ "login": None,
+ "port": None,
+ },
+ {
+ "connection_id": TEST_CONN_ID,
+ "conn_type": TEST_CONN_TYPE,
+ "extra": None,
+ "host": TEST_CONN_HOST,
+ "login": None,
+ "port": None,
+ "schema": None,
+ "password": None,
+ "description": TEST_CONN_DESCRIPTION,
+ },
+ {"update_mask": ["login", "port"]},
+ ),
(
{"connection_id": TEST_CONN_ID, "conn_type": TEST_CONN_TYPE,
"host": "test_host_patch"},
{
@@ -474,6 +629,28 @@ class TestPatchConnection(TestConnectionEndpoint):
},
{"update_mask": ["host"]},
),
+ (
+ {
+ "connection_id": TEST_CONN_ID,
+ "conn_type": TEST_CONN_TYPE,
+ "extra": '{"new_extra_key": "new_extra_value"}',
+ "host": TEST_CONN_HOST,
+ "schema": "new_schema",
+ "port": 80,
+ },
+ {
+ "connection_id": TEST_CONN_ID,
+ "conn_type": TEST_CONN_TYPE,
+ "extra": '{"new_extra_key": "new_extra_value"}',
+ "host": TEST_CONN_HOST,
+ "login": TEST_CONN_LOGIN,
+ "port": TEST_CONN_PORT,
+ "password": None,
+ "schema": "new_schema",
+ "description": TEST_CONN_DESCRIPTION,
+ },
+ {"update_mask": ["schema", "extra"]},
+ ),
],
)
def test_patch_should_respond_200_with_update_mask(
@@ -569,7 +746,7 @@ class TestPatchConnection(TestConnectionEndpoint):
@pytest.mark.enable_redact
@pytest.mark.parametrize(
- "body, expected_response",
+ "body, expected_response, update_mask",
[
(
{"connection_id": TEST_CONN_ID, "conn_type": TEST_CONN_TYPE,
"password": "test-password"},
@@ -584,6 +761,7 @@ class TestPatchConnection(TestConnectionEndpoint):
"port": 8080,
"schema": None,
},
+ {"update_mask": ["password"]},
),
(
{"connection_id": TEST_CONN_ID, "conn_type": TEST_CONN_TYPE,
"password": "?>@#+!_%()#"},
@@ -598,6 +776,7 @@ class TestPatchConnection(TestConnectionEndpoint):
"port": 8080,
"schema": None,
},
+ {"update_mask": ["password"]},
),
(
{
@@ -617,12 +796,15 @@ class TestPatchConnection(TestConnectionEndpoint):
"port": 8080,
"schema": None,
},
+ {"update_mask": ["password", "extra"]},
),
],
)
- def test_patch_should_response_200_redacted_password(self, test_client,
session, body, expected_response):
+ def test_patch_should_response_200_redacted_password(
+ self, test_client, session, body, expected_response, update_mask
+ ):
self.create_connections()
- response = test_client.patch(f"/connections/{TEST_CONN_ID}", json=body)
+ response = test_client.patch(f"/connections/{TEST_CONN_ID}",
json=body, params=update_mask)
assert response.status_code == 200
assert response.json() == expected_response
_check_last_log(session, dag_id=None, event="patch_connection",
logical_date=None, check_masked=True)
@@ -631,18 +813,21 @@ class TestPatchConnection(TestConnectionEndpoint):
class TestConnection(TestConnectionEndpoint):
@mock.patch.dict(os.environ, {"AIRFLOW__CORE__TEST_CONNECTION": "Enabled"})
@pytest.mark.parametrize(
- "body",
+ "body, message",
[
- {"connection_id": TEST_CONN_ID, "conn_type": "sqlite"},
- {"connection_id": TEST_CONN_ID, "conn_type": "ftp"},
+ ({"connection_id": TEST_CONN_ID, "conn_type": "sqlite"},
"Connection successfully tested"),
+ (
+ {"connection_id": TEST_CONN_ID, "conn_type": "fs", "extra":
'{"path": "/"}'},
+ "Path / is existing.",
+ ),
],
)
- def test_should_respond_200(self, test_client, body):
+ def test_should_respond_200(self, test_client, body, message):
response = test_client.post("/connections/test", json=body)
assert response.status_code == 200
assert response.json() == {
"status": True,
- "message": "Connection successfully tested",
+ "message": message,
}
def test_should_respond_401(self, unauthenticated_test_client):