This is an automated email from the ASF dual-hosted git repository.
pierrejeambrun pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 1ca620896bf AIP-84 Migrate get connections to FastAPI API #42571
(#42782)
1ca620896bf is described below
commit 1ca620896bfeff2b79f5ac4241b4b9ede8a02778
Author: Bugra Ozturk <[email protected]>
AuthorDate: Tue Oct 15 18:05:58 2024 +0200
AIP-84 Migrate get connections to FastAPI API #42571 (#42782)
* Make SortParam the parent for model-specific SortParams; include the get
connections endpoint in FastAPI
* Change depends() to a regular method in SortParam, since the parent class
already declares it as abstract
* Remove the subclass, derive the default order_by from the primary key, and
change the alias strategy for backward compatibility
* Apply pre-commit hook fixes
* Generate SortParam's dynamic return value within the OpenAPI specs and
remove unnecessary attribute mapping keys
* Add connection_id back to attr_mapping
* Implement dynamic depends with correct documentation
* Add more tests
---------
Co-authored-by: pierrejeambrun <[email protected]>
---
.../api_connexion/endpoints/connection_endpoint.py | 1 +
airflow/api_fastapi/openapi/v1-generated.yaml | 82 +++++++++++++++++++++-
airflow/api_fastapi/parameters.py | 48 ++++++++++---
airflow/api_fastapi/serializers/connections.py | 9 ++-
airflow/api_fastapi/views/public/connections.py | 42 ++++++++++-
airflow/api_fastapi/views/public/dags.py | 5 +-
airflow/ui/openapi-gen/queries/common.ts | 24 +++++++
airflow/ui/openapi-gen/queries/prefetch.ts | 30 ++++++++
airflow/ui/openapi-gen/queries/queries.ts | 36 ++++++++++
airflow/ui/openapi-gen/queries/suspense.ts | 36 ++++++++++
airflow/ui/openapi-gen/requests/schemas.gen.ts | 26 ++++++-
airflow/ui/openapi-gen/requests/services.gen.ts | 32 +++++++++
airflow/ui/openapi-gen/requests/types.gen.ts | 45 +++++++++++-
tests/api_fastapi/views/public/test_connections.py | 74 +++++++++++++++++--
14 files changed, 466 insertions(+), 24 deletions(-)
diff --git a/airflow/api_connexion/endpoints/connection_endpoint.py
b/airflow/api_connexion/endpoints/connection_endpoint.py
index 7cea4cf9837..37c91c44eb6 100644
--- a/airflow/api_connexion/endpoints/connection_endpoint.py
+++ b/airflow/api_connexion/endpoints/connection_endpoint.py
@@ -92,6 +92,7 @@ def get_connection(*, connection_id: str, session: Session =
NEW_SESSION) -> API
@security.requires_access_connection("GET")
@format_parameters({"limit": check_limit})
@provide_session
+@mark_fastapi_migration_done
def get_connections(
*,
limit: int,
diff --git a/airflow/api_fastapi/openapi/v1-generated.yaml
b/airflow/api_fastapi/openapi/v1-generated.yaml
index 56f48c73e98..d9339fdb56c 100644
--- a/airflow/api_fastapi/openapi/v1-generated.yaml
+++ b/airflow/api_fastapi/openapi/v1-generated.yaml
@@ -593,6 +593,66 @@ paths:
application/json:
schema:
$ref: '#/components/schemas/HTTPValidationError'
+ /public/connections/:
+ get:
+ tags:
+ - Connection
+ summary: Get Connections
+ description: Get all connection entries.
+ operationId: get_connections
+ parameters:
+ - name: limit
+ in: query
+ required: false
+ schema:
+ type: integer
+ default: 100
+ title: Limit
+ - name: offset
+ in: query
+ required: false
+ schema:
+ type: integer
+ default: 0
+ title: Offset
+ - name: order_by
+ in: query
+ required: false
+ schema:
+ type: string
+ default: id
+ title: Order By
+ responses:
+ '200':
+ description: Successful Response
+ content:
+ application/json:
+ schema:
+ $ref: '#/components/schemas/ConnectionCollectionResponse'
+ '401':
+ content:
+ application/json:
+ schema:
+ $ref: '#/components/schemas/HTTPExceptionResponse'
+ description: Unauthorized
+ '403':
+ content:
+ application/json:
+ schema:
+ $ref: '#/components/schemas/HTTPExceptionResponse'
+ description: Forbidden
+ '404':
+ content:
+ application/json:
+ schema:
+ $ref: '#/components/schemas/HTTPExceptionResponse'
+ description: Not Found
+ '422':
+ description: Validation Error
+ content:
+ application/json:
+ schema:
+ $ref: '#/components/schemas/HTTPValidationError'
/public/variables/{variable_key}:
delete:
tags:
@@ -886,11 +946,27 @@ paths:
$ref: '#/components/schemas/HTTPValidationError'
components:
schemas:
+ ConnectionCollectionResponse:
+ properties:
+ connections:
+ items:
+ $ref: '#/components/schemas/ConnectionResponse'
+ type: array
+ title: Connections
+ total_entries:
+ type: integer
+ title: Total Entries
+ type: object
+ required:
+ - connections
+ - total_entries
+ title: ConnectionCollectionResponse
+ description: DAG Collection serializer for responses.
ConnectionResponse:
properties:
- conn_id:
+ connection_id:
type: string
- title: Conn Id
+ title: Connection Id
conn_type:
type: string
title: Conn Type
@@ -926,7 +1002,7 @@ components:
title: Extra
type: object
required:
- - conn_id
+ - connection_id
- conn_type
- description
- host
diff --git a/airflow/api_fastapi/parameters.py
b/airflow/api_fastapi/parameters.py
index 59d61ad6860..07d8a76c79c 100644
--- a/airflow/api_fastapi/parameters.py
+++ b/airflow/api_fastapi/parameters.py
@@ -17,16 +17,19 @@
from __future__ import annotations
+import importlib
from abc import ABC, abstractmethod
from datetime import datetime
-from typing import TYPE_CHECKING, Any, Generic, List, TypeVar
+from typing import TYPE_CHECKING, Any, Callable, Generic, List, TypeVar
from fastapi import Depends, HTTPException, Query
from pendulum.parsing.exceptions import ParserError
from pydantic import AfterValidator
-from sqlalchemy import case, or_
+from sqlalchemy import Column, case, or_
+from sqlalchemy.inspection import inspect
from typing_extensions import Annotated, Self
+from airflow.models import Base, Connection
from airflow.models.dag import DagModel, DagTag
from airflow.models.dagrun import DagRun
from airflow.utils import timezone
@@ -154,11 +157,17 @@ class SortParam(BaseParam[str]):
attr_mapping = {
"last_run_state": DagRun.state,
"last_run_start_date": DagRun.start_date,
+ "connection_id": Connection.conn_id,
}
- def __init__(self, allowed_attrs: list[str]) -> None:
+ def __init__(
+ self,
+ allowed_attrs: list[str],
+ model: Base,
+ ) -> None:
super().__init__()
self.allowed_attrs = allowed_attrs
+ self.model = model
def to_orm(self, select: Select) -> Select:
if self.skip_none is False:
@@ -175,7 +184,9 @@ class SortParam(BaseParam[str]):
f"the attribute does not exist on the model",
)
- column = self.attr_mapping.get(lstriped_orderby, None) or
getattr(DagModel, lstriped_orderby)
+ column: Column = self.attr_mapping.get(lstriped_orderby, None) or
getattr(
+ self.model, lstriped_orderby
+ )
# MySQL does not support `nullslast`, and True/False ordering depends
on the
# database implementation.
@@ -185,12 +196,33 @@ class SortParam(BaseParam[str]):
select = select.order_by(None)
if self.value[0] == "-":
- return select.order_by(nullscheck, column.desc(),
DagModel.dag_id.desc())
+ return select.order_by(nullscheck, column.desc(), column.desc())
else:
- return select.order_by(nullscheck, column.asc(),
DagModel.dag_id.asc())
+ return select.order_by(nullscheck, column.asc(), column.asc())
+
+ def get_primary_key(self) -> str:
+ """Get the primary key of the model of SortParam object."""
+ return inspect(self.model).primary_key[0].name
+
+ @staticmethod
+ def get_primary_key_of_given_model_string(model_string: str) -> str:
+ """
+ Get the primary key of given 'airflow.models' class as a string. The
class should have driven be from 'airflow.models.base'.
+
+ :param model_string: The string representation of the model class.
+ :return: The primary key of the model class.
+ """
+ dynamic_return_model =
getattr(importlib.import_module("airflow.models"), model_string)
+ return inspect(dynamic_return_model).primary_key[0].name
+
+ def depends(self, *args: Any, **kwargs: Any) -> Self:
+ raise NotImplementedError("Use dynamic_depends, depends not
implemented.")
+
+ def dynamic_depends(self) -> Callable:
+ def inner(order_by: str = self.get_primary_key()) -> SortParam:
+ return self.set_value(self.get_primary_key() if order_by == ""
else order_by)
- def depends(self, order_by: str = "dag_id") -> SortParam:
- return self.set_value(order_by)
+ return inner
class _TagsFilter(BaseParam[List[str]]):
diff --git a/airflow/api_fastapi/serializers/connections.py
b/airflow/api_fastapi/serializers/connections.py
index e40b2fa1b21..1c801607299 100644
--- a/airflow/api_fastapi/serializers/connections.py
+++ b/airflow/api_fastapi/serializers/connections.py
@@ -27,7 +27,7 @@ from airflow.utils.log.secrets_masker import redact
class ConnectionResponse(BaseModel):
"""Connection serializer for responses."""
- connection_id: str = Field(alias="conn_id")
+ connection_id: str = Field(serialization_alias="connection_id",
validation_alias="conn_id")
conn_type: str
description: str | None
host: str | None
@@ -48,3 +48,10 @@ class ConnectionResponse(BaseModel):
except json.JSONDecodeError:
# we can't redact fields in an unstructured `extra`
return v
+
+
+class ConnectionCollectionResponse(BaseModel):
+ """DAG Collection serializer for responses."""
+
+ connections: list[ConnectionResponse]
+ total_entries: int
diff --git a/airflow/api_fastapi/views/public/connections.py
b/airflow/api_fastapi/views/public/connections.py
index 94e9b614e9c..6fca43aca26 100644
--- a/airflow/api_fastapi/views/public/connections.py
+++ b/airflow/api_fastapi/views/public/connections.py
@@ -21,9 +21,10 @@ from sqlalchemy import select
from sqlalchemy.orm import Session
from typing_extensions import Annotated
-from airflow.api_fastapi.db.common import get_session
+from airflow.api_fastapi.db.common import get_session, paginated_select
from airflow.api_fastapi.openapi.exceptions import
create_openapi_http_exception_doc
-from airflow.api_fastapi.serializers.connections import ConnectionResponse
+from airflow.api_fastapi.parameters import QueryLimit, QueryOffset, SortParam
+from airflow.api_fastapi.serializers.connections import
ConnectionCollectionResponse, ConnectionResponse
from airflow.api_fastapi.views.router import AirflowRouter
from airflow.models import Connection
@@ -63,3 +64,40 @@ async def get_connection(
raise HTTPException(404, f"The Connection with connection_id:
`{connection_id}` was not found")
return ConnectionResponse.model_validate(connection, from_attributes=True)
+
+
+@connections_router.get(
+ "/",
+ responses=create_openapi_http_exception_doc([401, 403, 404]),
+)
+async def get_connections(
+ limit: QueryLimit,
+ offset: QueryOffset,
+ order_by: Annotated[
+ SortParam,
+ Depends(
+ SortParam(
+ ["connection_id", "conn_type", "description", "host", "port",
"id"], Connection
+ ).dynamic_depends()
+ ),
+ ],
+ session: Annotated[Session, Depends(get_session)],
+) -> ConnectionCollectionResponse:
+ """Get all connection entries."""
+ connection_select, total_entries = paginated_select(
+ select(Connection),
+ [],
+ order_by=order_by,
+ offset=offset,
+ limit=limit,
+ session=session,
+ )
+
+ connections = session.scalars(connection_select).all()
+
+ return ConnectionCollectionResponse(
+ connections=[
+ ConnectionResponse.model_validate(connection,
from_attributes=True) for connection in connections
+ ],
+ total_entries=total_entries,
+ )
diff --git a/airflow/api_fastapi/views/public/dags.py
b/airflow/api_fastapi/views/public/dags.py
index eb8233a7f70..46d4d8b3540 100644
--- a/airflow/api_fastapi/views/public/dags.py
+++ b/airflow/api_fastapi/views/public/dags.py
@@ -70,8 +70,9 @@ async def get_dags(
SortParam,
Depends(
SortParam(
- ["dag_id", "dag_display_name", "next_dagrun",
"last_run_state", "last_run_start_date"]
- ).depends
+ ["dag_id", "dag_display_name", "next_dagrun",
"last_run_state", "last_run_start_date"],
+ DagModel,
+ ).dynamic_depends()
),
],
session: Annotated[Session, Depends(get_session)],
diff --git a/airflow/ui/openapi-gen/queries/common.ts
b/airflow/ui/openapi-gen/queries/common.ts
index 2f1c6a78d92..30a20826109 100644
--- a/airflow/ui/openapi-gen/queries/common.ts
+++ b/airflow/ui/openapi-gen/queries/common.ts
@@ -151,6 +151,30 @@ export const UseConnectionServiceGetConnectionKeyFn = (
useConnectionServiceGetConnectionKey,
...(queryKey ?? [{ connectionId }]),
];
+export type ConnectionServiceGetConnectionsDefaultResponse = Awaited<
+ ReturnType<typeof ConnectionService.getConnections>
+>;
+export type ConnectionServiceGetConnectionsQueryResult<
+ TData = ConnectionServiceGetConnectionsDefaultResponse,
+ TError = unknown,
+> = UseQueryResult<TData, TError>;
+export const useConnectionServiceGetConnectionsKey =
+ "ConnectionServiceGetConnections";
+export const UseConnectionServiceGetConnectionsKeyFn = (
+ {
+ limit,
+ offset,
+ orderBy,
+ }: {
+ limit?: number;
+ offset?: number;
+ orderBy?: string;
+ } = {},
+ queryKey?: Array<unknown>,
+) => [
+ useConnectionServiceGetConnectionsKey,
+ ...(queryKey ?? [{ limit, offset, orderBy }]),
+];
export type VariableServiceGetVariableDefaultResponse = Awaited<
ReturnType<typeof VariableService.getVariable>
>;
diff --git a/airflow/ui/openapi-gen/queries/prefetch.ts
b/airflow/ui/openapi-gen/queries/prefetch.ts
index 3e194302f4b..79b310107d2 100644
--- a/airflow/ui/openapi-gen/queries/prefetch.ts
+++ b/airflow/ui/openapi-gen/queries/prefetch.ts
@@ -187,6 +187,36 @@ export const prefetchUseConnectionServiceGetConnection = (
queryKey: Common.UseConnectionServiceGetConnectionKeyFn({ connectionId }),
queryFn: () => ConnectionService.getConnection({ connectionId }),
});
+/**
+ * Get Connections
+ * Get all connection entries.
+ * @param data The data for the request.
+ * @param data.limit
+ * @param data.offset
+ * @param data.orderBy
+ * @returns ConnectionCollectionResponse Successful Response
+ * @throws ApiError
+ */
+export const prefetchUseConnectionServiceGetConnections = (
+ queryClient: QueryClient,
+ {
+ limit,
+ offset,
+ orderBy,
+ }: {
+ limit?: number;
+ offset?: number;
+ orderBy?: string;
+ } = {},
+) =>
+ queryClient.prefetchQuery({
+ queryKey: Common.UseConnectionServiceGetConnectionsKeyFn({
+ limit,
+ offset,
+ orderBy,
+ }),
+ queryFn: () => ConnectionService.getConnections({ limit, offset, orderBy
}),
+ });
/**
* Get Variable
* Get a variable entry.
diff --git a/airflow/ui/openapi-gen/queries/queries.ts
b/airflow/ui/openapi-gen/queries/queries.ts
index a16bdf165b1..1e9586e4360 100644
--- a/airflow/ui/openapi-gen/queries/queries.ts
+++ b/airflow/ui/openapi-gen/queries/queries.ts
@@ -235,6 +235,42 @@ export const useConnectionServiceGetConnection = <
queryFn: () => ConnectionService.getConnection({ connectionId }) as TData,
...options,
});
+/**
+ * Get Connections
+ * Get all connection entries.
+ * @param data The data for the request.
+ * @param data.limit
+ * @param data.offset
+ * @param data.orderBy
+ * @returns ConnectionCollectionResponse Successful Response
+ * @throws ApiError
+ */
+export const useConnectionServiceGetConnections = <
+ TData = Common.ConnectionServiceGetConnectionsDefaultResponse,
+ TError = unknown,
+ TQueryKey extends Array<unknown> = unknown[],
+>(
+ {
+ limit,
+ offset,
+ orderBy,
+ }: {
+ limit?: number;
+ offset?: number;
+ orderBy?: string;
+ } = {},
+ queryKey?: TQueryKey,
+ options?: Omit<UseQueryOptions<TData, TError>, "queryKey" | "queryFn">,
+) =>
+ useQuery<TData, TError>({
+ queryKey: Common.UseConnectionServiceGetConnectionsKeyFn(
+ { limit, offset, orderBy },
+ queryKey,
+ ),
+ queryFn: () =>
+ ConnectionService.getConnections({ limit, offset, orderBy }) as TData,
+ ...options,
+ });
/**
* Get Variable
* Get a variable entry.
diff --git a/airflow/ui/openapi-gen/queries/suspense.ts
b/airflow/ui/openapi-gen/queries/suspense.ts
index 79ad479f0a4..af2a571871e 100644
--- a/airflow/ui/openapi-gen/queries/suspense.ts
+++ b/airflow/ui/openapi-gen/queries/suspense.ts
@@ -230,6 +230,42 @@ export const useConnectionServiceGetConnectionSuspense = <
queryFn: () => ConnectionService.getConnection({ connectionId }) as TData,
...options,
});
+/**
+ * Get Connections
+ * Get all connection entries.
+ * @param data The data for the request.
+ * @param data.limit
+ * @param data.offset
+ * @param data.orderBy
+ * @returns ConnectionCollectionResponse Successful Response
+ * @throws ApiError
+ */
+export const useConnectionServiceGetConnectionsSuspense = <
+ TData = Common.ConnectionServiceGetConnectionsDefaultResponse,
+ TError = unknown,
+ TQueryKey extends Array<unknown> = unknown[],
+>(
+ {
+ limit,
+ offset,
+ orderBy,
+ }: {
+ limit?: number;
+ offset?: number;
+ orderBy?: string;
+ } = {},
+ queryKey?: TQueryKey,
+ options?: Omit<UseQueryOptions<TData, TError>, "queryKey" | "queryFn">,
+) =>
+ useSuspenseQuery<TData, TError>({
+ queryKey: Common.UseConnectionServiceGetConnectionsKeyFn(
+ { limit, offset, orderBy },
+ queryKey,
+ ),
+ queryFn: () =>
+ ConnectionService.getConnections({ limit, offset, orderBy }) as TData,
+ ...options,
+ });
/**
* Get Variable
* Get a variable entry.
diff --git a/airflow/ui/openapi-gen/requests/schemas.gen.ts
b/airflow/ui/openapi-gen/requests/schemas.gen.ts
index e42a3f6572c..44a29d3899d 100644
--- a/airflow/ui/openapi-gen/requests/schemas.gen.ts
+++ b/airflow/ui/openapi-gen/requests/schemas.gen.ts
@@ -1,10 +1,30 @@
// This file is auto-generated by @hey-api/openapi-ts
+export const $ConnectionCollectionResponse = {
+ properties: {
+ connections: {
+ items: {
+ $ref: "#/components/schemas/ConnectionResponse",
+ },
+ type: "array",
+ title: "Connections",
+ },
+ total_entries: {
+ type: "integer",
+ title: "Total Entries",
+ },
+ },
+ type: "object",
+ required: ["connections", "total_entries"],
+ title: "ConnectionCollectionResponse",
+ description: "DAG Collection serializer for responses.",
+} as const;
+
export const $ConnectionResponse = {
properties: {
- conn_id: {
+ connection_id: {
type: "string",
- title: "Conn Id",
+ title: "Connection Id",
},
conn_type: {
type: "string",
@@ -79,7 +99,7 @@ export const $ConnectionResponse = {
},
type: "object",
required: [
- "conn_id",
+ "connection_id",
"conn_type",
"description",
"host",
diff --git a/airflow/ui/openapi-gen/requests/services.gen.ts
b/airflow/ui/openapi-gen/requests/services.gen.ts
index 8d7f0cee2b2..6f9b0aa5cb8 100644
--- a/airflow/ui/openapi-gen/requests/services.gen.ts
+++ b/airflow/ui/openapi-gen/requests/services.gen.ts
@@ -23,6 +23,8 @@ import type {
DeleteConnectionResponse,
GetConnectionData,
GetConnectionResponse,
+ GetConnectionsData,
+ GetConnectionsResponse,
DeleteVariableData,
DeleteVariableResponse,
GetVariableData,
@@ -343,6 +345,36 @@ export class ConnectionService {
},
});
}
+
+ /**
+ * Get Connections
+ * Get all connection entries.
+ * @param data The data for the request.
+ * @param data.limit
+ * @param data.offset
+ * @param data.orderBy
+ * @returns ConnectionCollectionResponse Successful Response
+ * @throws ApiError
+ */
+ public static getConnections(
+ data: GetConnectionsData = {},
+ ): CancelablePromise<GetConnectionsResponse> {
+ return __request(OpenAPI, {
+ method: "GET",
+ url: "/public/connections/",
+ query: {
+ limit: data.limit,
+ offset: data.offset,
+ order_by: data.orderBy,
+ },
+ errors: {
+ 401: "Unauthorized",
+ 403: "Forbidden",
+ 404: "Not Found",
+ 422: "Validation Error",
+ },
+ });
+ }
}
export class VariableService {
diff --git a/airflow/ui/openapi-gen/requests/types.gen.ts
b/airflow/ui/openapi-gen/requests/types.gen.ts
index 7f603a1adb4..2fae25c7660 100644
--- a/airflow/ui/openapi-gen/requests/types.gen.ts
+++ b/airflow/ui/openapi-gen/requests/types.gen.ts
@@ -1,10 +1,18 @@
// This file is auto-generated by @hey-api/openapi-ts
+/**
+ * DAG Collection serializer for responses.
+ */
+export type ConnectionCollectionResponse = {
+ connections: Array<ConnectionResponse>;
+ total_entries: number;
+};
+
/**
* Connection serializer for responses.
*/
export type ConnectionResponse = {
- conn_id: string;
+ connection_id: string;
conn_type: string;
description: string | null;
host: string | null;
@@ -351,6 +359,14 @@ export type GetConnectionData = {
export type GetConnectionResponse = ConnectionResponse;
+export type GetConnectionsData = {
+ limit?: number;
+ offset?: number;
+ orderBy?: string;
+};
+
+export type GetConnectionsResponse = ConnectionCollectionResponse;
+
export type DeleteVariableData = {
variableKey: string;
};
@@ -644,6 +660,33 @@ export type $OpenApiTs = {
};
};
};
+ "/public/connections/": {
+ get: {
+ req: GetConnectionsData;
+ res: {
+ /**
+ * Successful Response
+ */
+ 200: ConnectionCollectionResponse;
+ /**
+ * Unauthorized
+ */
+ 401: HTTPExceptionResponse;
+ /**
+ * Forbidden
+ */
+ 403: HTTPExceptionResponse;
+ /**
+ * Not Found
+ */
+ 404: HTTPExceptionResponse;
+ /**
+ * Validation Error
+ */
+ 422: HTTPValidationError;
+ };
+ };
+ };
"/public/variables/{variable_key}": {
delete: {
req: DeleteVariableData;
diff --git a/tests/api_fastapi/views/public/test_connections.py
b/tests/api_fastapi/views/public/test_connections.py
index a0da0d2b9b2..a80f2467b9b 100644
--- a/tests/api_fastapi/views/public/test_connections.py
+++ b/tests/api_fastapi/views/public/test_connections.py
@@ -26,14 +26,43 @@ pytestmark = pytest.mark.db_test
TEST_CONN_ID = "test_connection_id"
TEST_CONN_TYPE = "test_type"
+TEST_CONN_DESCRIPTION = "some_description_a"
+TEST_CONN_HOST = "some_host_a"
+TEST_CONN_PORT = 8080
+
+
+TEST_CONN_ID_2 = "test_connection_id_2"
+TEST_CONN_TYPE_2 = "test_type_2"
+TEST_CONN_DESCRIPTION_2 = "some_description_b"
+TEST_CONN_HOST_2 = "some_host_b"
+TEST_CONN_PORT_2 = 8081
@provide_session
def _create_connection(session) -> None:
- connection_model = Connection(conn_id=TEST_CONN_ID,
conn_type=TEST_CONN_TYPE)
+ connection_model = Connection(
+ conn_id=TEST_CONN_ID,
+ conn_type=TEST_CONN_TYPE,
+ description=TEST_CONN_DESCRIPTION,
+ host=TEST_CONN_HOST,
+ port=TEST_CONN_PORT,
+ )
session.add(connection_model)
+@provide_session
+def _create_connections(session) -> None:
+ _create_connection(session)
+ connection_model_2 = Connection(
+ conn_id=TEST_CONN_ID_2,
+ conn_type=TEST_CONN_TYPE_2,
+ description=TEST_CONN_DESCRIPTION_2,
+ host=TEST_CONN_HOST_2,
+ port=TEST_CONN_PORT_2,
+ )
+ session.add(connection_model_2)
+
+
class TestConnectionEndpoint:
@pytest.fixture(autouse=True)
def setup(self) -> None:
@@ -45,6 +74,9 @@ class TestConnectionEndpoint:
def create_connection(self):
_create_connection()
+ def create_connections(self):
+ _create_connections()
+
class TestDeleteConnection(TestConnectionEndpoint):
def test_delete_should_respond_204(self, test_client, session):
@@ -69,7 +101,7 @@ class TestGetConnection(TestConnectionEndpoint):
response = test_client.get(f"/public/connections/{TEST_CONN_ID}")
assert response.status_code == 200
body = response.json()
- assert body["conn_id"] == TEST_CONN_ID
+ assert body["connection_id"] == TEST_CONN_ID
assert body["conn_type"] == TEST_CONN_TYPE
def test_get_should_respond_404(self, test_client):
@@ -86,7 +118,7 @@ class TestGetConnection(TestConnectionEndpoint):
response = test_client.get(f"/public/connections/{TEST_CONN_ID}")
assert response.status_code == 200
body = response.json()
- assert body["conn_id"] == TEST_CONN_ID
+ assert body["connection_id"] == TEST_CONN_ID
assert body["conn_type"] == TEST_CONN_TYPE
assert body["extra"] == '{"extra_key": "extra_value"}'
@@ -99,6 +131,40 @@ class TestGetConnection(TestConnectionEndpoint):
response = test_client.get(f"/public/connections/{TEST_CONN_ID}")
assert response.status_code == 200
body = response.json()
- assert body["conn_id"] == TEST_CONN_ID
+ assert body["connection_id"] == TEST_CONN_ID
assert body["conn_type"] == TEST_CONN_TYPE
assert body["extra"] == '{"password": "***"}'
+
+
+class TestGetConnections(TestConnectionEndpoint):
+ @pytest.mark.parametrize(
+ "query_params, expected_total_entries, expected_ids",
+ [
+ # Filters
+ ({}, 2, [TEST_CONN_ID, TEST_CONN_ID_2]),
+ ({"limit": 1}, 2, [TEST_CONN_ID]),
+ ({"limit": 1, "offset": 1}, 2, [TEST_CONN_ID_2]),
+ # Sort
+ ({"order_by": "-connection_id"}, 2, [TEST_CONN_ID_2,
TEST_CONN_ID]),
+ ({"order_by": "conn_type"}, 2, [TEST_CONN_ID, TEST_CONN_ID_2]),
+ ({"order_by": "-conn_type"}, 2, [TEST_CONN_ID_2, TEST_CONN_ID]),
+ ({"order_by": "description"}, 2, [TEST_CONN_ID, TEST_CONN_ID_2]),
+ ({"order_by": "-description"}, 2, [TEST_CONN_ID_2, TEST_CONN_ID]),
+ ({"order_by": "host"}, 2, [TEST_CONN_ID, TEST_CONN_ID_2]),
+ ({"order_by": "-host"}, 2, [TEST_CONN_ID_2, TEST_CONN_ID]),
+ ({"order_by": "port"}, 2, [TEST_CONN_ID, TEST_CONN_ID_2]),
+ ({"order_by": "-port"}, 2, [TEST_CONN_ID_2, TEST_CONN_ID]),
+ ({"order_by": "id"}, 2, [TEST_CONN_ID, TEST_CONN_ID_2]),
+ ({"order_by": "-id"}, 2, [TEST_CONN_ID_2, TEST_CONN_ID]),
+ ],
+ )
+ def test_should_respond_200(
+ self, test_client, session, query_params, expected_total_entries,
expected_ids
+ ):
+ self.create_connections()
+ response = test_client.get("/public/connections/", params=query_params)
+ assert response.status_code == 200
+
+ body = response.json()
+ assert body["total_entries"] == expected_total_entries
+ assert [dag["connection_id"] for dag in body["connections"]] ==
expected_ids