This is an automated email from the ASF dual-hosted git repository. ephraimanierobi pushed a commit to branch v3-1-test in repository https://gitbox.apache.org/repos/asf/airflow.git
commit 9aabbe5e06fd7f761c065d0e077cff20fd61c23c Author: Ephraim Anierobi <[email protected]> AuthorDate: Mon Dec 8 09:04:34 2025 +0100 Revert "Update refresh token flow (#55506) (#58649)" This reverts commit 147eca049d0741fe7f7e945b26526d92e09cafaf. --- airflow-core/src/airflow/api_fastapi/app.py | 2 - .../api_fastapi/auth/managers/base_auth_manager.py | 11 +-- .../api_fastapi/auth/middlewares/__init__.py | 17 ---- .../api_fastapi/auth/middlewares/refresh_token.py | 68 ------------- .../src/airflow/api_fastapi/core_api/app.py | 6 -- .../core_api/openapi/v2-rest-api-generated.yaml | 34 +++++++ .../api_fastapi/core_api/routes/public/auth.py | 20 ++++ .../src/airflow/ui/openapi-gen/queries/common.ts | 6 ++ .../ui/openapi-gen/queries/ensureQueryData.ts | 11 +++ .../src/airflow/ui/openapi-gen/queries/prefetch.ts | 11 +++ .../src/airflow/ui/openapi-gen/queries/queries.ts | 11 +++ .../src/airflow/ui/openapi-gen/queries/suspense.ts | 11 +++ .../ui/openapi-gen/requests/services.gen.ts | 24 ++++- .../airflow/ui/openapi-gen/requests/types.gen.ts | 25 +++++ .../auth/managers/test_base_auth_manager.py | 3 - .../unit/api_fastapi/auth/middlewares/__init__.py | 17 ---- .../auth/middlewares/test_refresh_token.py | 106 --------------------- .../core_api/routes/public/test_auth.py | 61 ++++++++++++ .../api_fastapi/core_api/routes/test_routes.py | 1 + .../keycloak/auth_manager/keycloak_auth_manager.py | 28 +----- .../keycloak/auth_manager/routes/test_login.py | 77 ++++++++++++++- .../auth_manager/test_keycloak_auth_manager.py | 39 -------- 22 files changed, 297 insertions(+), 292 deletions(-) diff --git a/airflow-core/src/airflow/api_fastapi/app.py b/airflow-core/src/airflow/api_fastapi/app.py index 7c05295807e..58cfb157083 100644 --- a/airflow-core/src/airflow/api_fastapi/app.py +++ b/airflow-core/src/airflow/api_fastapi/app.py @@ -29,7 +29,6 @@ from airflow.api_fastapi.core_api.app import ( init_config, init_error_handlers, init_flask_plugins, - init_middlewares, init_ui_plugins, init_views, ) @@ -100,7 +99,6 @@ def create_app(apps: str = "all") -> FastAPI: init_ui_plugins(app) init_views(app) # Core views need to be the last routes added - it has a catch all route init_error_handlers(app) - init_middlewares(app) init_config(app) diff --git a/airflow-core/src/airflow/api_fastapi/auth/managers/base_auth_manager.py b/airflow-core/src/airflow/api_fastapi/auth/managers/base_auth_manager.py index b8656cd068f..d57dd3cdc39 100644 --- a/airflow-core/src/airflow/api_fastapi/auth/managers/base_auth_manager.py +++ b/airflow-core/src/airflow/api_fastapi/auth/managers/base_auth_manager.py @@ -135,15 +135,12 @@ class BaseAuthManager(Generic[T], LoggingMixin, metaclass=ABCMeta): """ return None - def refresh_user(self, *, user: T) -> T | None: + def get_url_refresh(self) -> str | None: """ - Refresh the user if needed. + Return the URL to refresh the authentication token. - By default, does nothing. Some auth managers might need to refresh the user to, for instance, - refresh some tokens that are needed to communicate with a service/tool. - - This method is called by every single request, it must be lightweight otherwise the overall API - server latency will increase. + This is used to refresh the authentication token when it expires. + The default implementation returns None, which means that the auth manager does not support refresh token. """ return None diff --git a/airflow-core/src/airflow/api_fastapi/auth/middlewares/__init__.py b/airflow-core/src/airflow/api_fastapi/auth/middlewares/__init__.py deleted file mode 100644 index 217e5db9607..00000000000 --- a/airflow-core/src/airflow/api_fastapi/auth/middlewares/__init__.py +++ /dev/null @@ -1,17 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. diff --git a/airflow-core/src/airflow/api_fastapi/auth/middlewares/refresh_token.py b/airflow-core/src/airflow/api_fastapi/auth/middlewares/refresh_token.py deleted file mode 100644 index f304eb9517f..00000000000 --- a/airflow-core/src/airflow/api_fastapi/auth/middlewares/refresh_token.py +++ /dev/null @@ -1,68 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -from __future__ import annotations - -from fastapi import Request -from starlette.middleware.base import BaseHTTPMiddleware - -from airflow.api_fastapi.app import get_auth_manager -from airflow.api_fastapi.auth.managers.base_auth_manager import COOKIE_NAME_JWT_TOKEN -from airflow.api_fastapi.auth.managers.models.base_user import BaseUser -from airflow.api_fastapi.core_api.security import resolve_user_from_token -from airflow.configuration import conf - - -class JWTRefreshMiddleware(BaseHTTPMiddleware): - """ - Middleware to handle JWT token refresh. - - This middleware: - 1. Extracts JWT token from cookies and build the user from the token - 2. Calls ``refresh_user`` method from auth manager with the user - 3. If ``refresh_user`` returns a user, generate a JWT token based upon this user and send it in the - response as cookie - """ - - async def dispatch(self, request: Request, call_next): - new_user = None - current_token = request.cookies.get(COOKIE_NAME_JWT_TOKEN) - if current_token: - new_user = await self._refresh_user(current_token) - if new_user: - request.state.user = new_user - - response = await call_next(request) - - if new_user: - # If we created a new user, serialize it and set it as a cookie - new_token = get_auth_manager().generate_jwt(new_user) - secure = bool(conf.get("api", "ssl_cert", fallback="")) - response.set_cookie( - COOKIE_NAME_JWT_TOKEN, - new_token, - httponly=True, - secure=secure, - samesite="lax", - ) - - return response - - @staticmethod - async def _refresh_user(current_token: str) -> BaseUser | None: - user = await resolve_user_from_token(current_token) - return get_auth_manager().refresh_user(user=user) diff --git a/airflow-core/src/airflow/api_fastapi/core_api/app.py b/airflow-core/src/airflow/api_fastapi/core_api/app.py index 8db1fa66680..cd34b28f3cb 100644 --- a/airflow-core/src/airflow/api_fastapi/core_api/app.py +++ b/airflow-core/src/airflow/api_fastapi/core_api/app.py @@ -181,12 +181,6 @@ def init_error_handlers(app: FastAPI) -> None: app.add_exception_handler(handler.exception_cls, handler.exception_handler) -def init_middlewares(app: FastAPI) -> None: - from airflow.api_fastapi.auth.middlewares.refresh_token import JWTRefreshMiddleware - - app.add_middleware(JWTRefreshMiddleware) - - def init_ui_plugins(app: FastAPI) -> None: """Initialize UI plugins.""" from airflow import plugins_manager diff --git a/airflow-core/src/airflow/api_fastapi/core_api/openapi/v2-rest-api-generated.yaml b/airflow-core/src/airflow/api_fastapi/core_api/openapi/v2-rest-api-generated.yaml index ceaf90b60f8..bef548133c6 100644 --- a/airflow-core/src/airflow/api_fastapi/core_api/openapi/v2-rest-api-generated.yaml +++ b/airflow-core/src/airflow/api_fastapi/core_api/openapi/v2-rest-api-generated.yaml @@ -8478,6 +8478,40 @@ paths: application/json: schema: $ref: '#/components/schemas/HTTPExceptionResponse' + /api/v2/auth/refresh: + get: + tags: + - Login + summary: Refresh + description: Refresh the authentication token. + operationId: refresh + parameters: + - name: next + in: query + required: false + schema: + anyOf: + - type: string + - type: 'null' + title: Next + responses: + '200': + description: Successful Response + content: + application/json: + schema: {} + '307': + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPExceptionResponse' + description: Temporary Redirect + '422': + description: Validation Error + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPValidationError' components: schemas: AppBuilderMenuItemResponse: diff --git a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/auth.py b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/auth.py index a97b7fd9972..d1e770b9277 100644 --- a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/auth.py +++ b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/auth.py @@ -65,3 +65,23 @@ def logout(request: Request) -> RedirectResponse: ) return response + + +@auth_router.get( + "/refresh", + responses=create_openapi_http_exception_doc([status.HTTP_307_TEMPORARY_REDIRECT]), +) +def refresh(request: Request, next: None | str = None) -> RedirectResponse: + """Refresh the authentication token.""" + refresh_url = request.app.state.auth_manager.get_url_refresh() + + if not refresh_url: + return RedirectResponse(f"{conf.get('api', 'base_url', fallback='/')}auth/logout") + + if next and not is_safe_url(next, request=request): + raise HTTPException(status_code=400, detail="Invalid or unsafe next URL") + + if next: + refresh_url += f"?next={next}" + + return RedirectResponse(refresh_url) diff --git a/airflow-core/src/airflow/ui/openapi-gen/queries/common.ts b/airflow-core/src/airflow/ui/openapi-gen/queries/common.ts index 994fe2d91c1..de2bbc9647e 100644 --- a/airflow-core/src/airflow/ui/openapi-gen/queries/common.ts +++ b/airflow-core/src/airflow/ui/openapi-gen/queries/common.ts @@ -753,6 +753,12 @@ export type LoginServiceLogoutDefaultResponse = Awaited<ReturnType<typeof LoginS export type LoginServiceLogoutQueryResult<TData = LoginServiceLogoutDefaultResponse, TError = unknown> = UseQueryResult<TData, TError>; export const useLoginServiceLogoutKey = "LoginServiceLogout"; export const UseLoginServiceLogoutKeyFn = (queryKey?: Array<unknown>) => [useLoginServiceLogoutKey, ...(queryKey ?? [])]; +export type LoginServiceRefreshDefaultResponse = Awaited<ReturnType<typeof LoginService.refresh>>; +export type LoginServiceRefreshQueryResult<TData = LoginServiceRefreshDefaultResponse, TError = unknown> = UseQueryResult<TData, TError>; +export const useLoginServiceRefreshKey = "LoginServiceRefresh"; +export const UseLoginServiceRefreshKeyFn = ({ next }: { + next?: string; +} = {}, queryKey?: Array<unknown>) => [useLoginServiceRefreshKey, ...(queryKey ?? [{ next }])]; export type AuthLinksServiceGetAuthMenusDefaultResponse = Awaited<ReturnType<typeof AuthLinksService.getAuthMenus>>; export type AuthLinksServiceGetAuthMenusQueryResult<TData = AuthLinksServiceGetAuthMenusDefaultResponse, TError = unknown> = UseQueryResult<TData, TError>; export const useAuthLinksServiceGetAuthMenusKey = "AuthLinksServiceGetAuthMenus"; diff --git a/airflow-core/src/airflow/ui/openapi-gen/queries/ensureQueryData.ts b/airflow-core/src/airflow/ui/openapi-gen/queries/ensureQueryData.ts index 41fe0005dcd..c38c3958500 100644 --- a/airflow-core/src/airflow/ui/openapi-gen/queries/ensureQueryData.ts +++ b/airflow-core/src/airflow/ui/openapi-gen/queries/ensureQueryData.ts @@ -1431,6 +1431,17 @@ export const ensureUseLoginServiceLoginData = (queryClient: QueryClient, { next */ export const ensureUseLoginServiceLogoutData = (queryClient: QueryClient) => queryClient.ensureQueryData({ queryKey: Common.UseLoginServiceLogoutKeyFn(), queryFn: () => LoginService.logout() }); /** +* Refresh +* Refresh the authentication token. +* @param data The data for the request. +* @param data.next +* @returns unknown Successful Response +* @throws ApiError +*/ +export const ensureUseLoginServiceRefreshData = (queryClient: QueryClient, { next }: { + next?: string; +} = {}) => queryClient.ensureQueryData({ queryKey: Common.UseLoginServiceRefreshKeyFn({ next }), queryFn: () => LoginService.refresh({ next }) }); +/** * Get Auth Menus * @returns MenuItemCollectionResponse Successful Response * @throws ApiError diff --git a/airflow-core/src/airflow/ui/openapi-gen/queries/prefetch.ts b/airflow-core/src/airflow/ui/openapi-gen/queries/prefetch.ts index fa6162ec588..ec8e3471a4c 100644 --- a/airflow-core/src/airflow/ui/openapi-gen/queries/prefetch.ts +++ b/airflow-core/src/airflow/ui/openapi-gen/queries/prefetch.ts @@ -1431,6 +1431,17 @@ export const prefetchUseLoginServiceLogin = (queryClient: QueryClient, { next }: */ export const prefetchUseLoginServiceLogout = (queryClient: QueryClient) => queryClient.prefetchQuery({ queryKey: Common.UseLoginServiceLogoutKeyFn(), queryFn: () => LoginService.logout() }); /** +* Refresh +* Refresh the authentication token. +* @param data The data for the request. +* @param data.next +* @returns unknown Successful Response +* @throws ApiError +*/ +export const prefetchUseLoginServiceRefresh = (queryClient: QueryClient, { next }: { + next?: string; +} = {}) => queryClient.prefetchQuery({ queryKey: Common.UseLoginServiceRefreshKeyFn({ next }), queryFn: () => LoginService.refresh({ next }) }); +/** * Get Auth Menus * @returns MenuItemCollectionResponse Successful Response * @throws ApiError diff --git a/airflow-core/src/airflow/ui/openapi-gen/queries/queries.ts b/airflow-core/src/airflow/ui/openapi-gen/queries/queries.ts index 955e5049e60..5211c77b349 100644 --- a/airflow-core/src/airflow/ui/openapi-gen/queries/queries.ts +++ b/airflow-core/src/airflow/ui/openapi-gen/queries/queries.ts @@ -1431,6 +1431,17 @@ export const useLoginServiceLogin = <TData = Common.LoginServiceLoginDefaultResp */ export const useLoginServiceLogout = <TData = Common.LoginServiceLogoutDefaultResponse, TError = unknown, TQueryKey extends Array<unknown> = unknown[]>(queryKey?: TQueryKey, options?: Omit<UseQueryOptions<TData, TError>, "queryKey" | "queryFn">) => useQuery<TData, TError>({ queryKey: Common.UseLoginServiceLogoutKeyFn(queryKey), queryFn: () => LoginService.logout() as TData, ...options }); /** +* Refresh +* Refresh the authentication token. +* @param data The data for the request. +* @param data.next +* @returns unknown Successful Response +* @throws ApiError +*/ +export const useLoginServiceRefresh = <TData = Common.LoginServiceRefreshDefaultResponse, TError = unknown, TQueryKey extends Array<unknown> = unknown[]>({ next }: { + next?: string; +} = {}, queryKey?: TQueryKey, options?: Omit<UseQueryOptions<TData, TError>, "queryKey" | "queryFn">) => useQuery<TData, TError>({ queryKey: Common.UseLoginServiceRefreshKeyFn({ next }, queryKey), queryFn: () => LoginService.refresh({ next }) as TData, ...options }); +/** * Get Auth Menus * @returns MenuItemCollectionResponse Successful Response * @throws ApiError diff --git a/airflow-core/src/airflow/ui/openapi-gen/queries/suspense.ts b/airflow-core/src/airflow/ui/openapi-gen/queries/suspense.ts index aafe12ed9bc..9b980d1cbec 100644 --- a/airflow-core/src/airflow/ui/openapi-gen/queries/suspense.ts +++ b/airflow-core/src/airflow/ui/openapi-gen/queries/suspense.ts @@ -1431,6 +1431,17 @@ export const useLoginServiceLoginSuspense = <TData = Common.LoginServiceLoginDef */ export const useLoginServiceLogoutSuspense = <TData = Common.LoginServiceLogoutDefaultResponse, TError = unknown, TQueryKey extends Array<unknown> = unknown[]>(queryKey?: TQueryKey, options?: Omit<UseQueryOptions<TData, TError>, "queryKey" | "queryFn">) => useSuspenseQuery<TData, TError>({ queryKey: Common.UseLoginServiceLogoutKeyFn(queryKey), queryFn: () => LoginService.logout() as TData, ...options }); /** +* Refresh +* Refresh the authentication token. +* @param data The data for the request. +* @param data.next +* @returns unknown Successful Response +* @throws ApiError +*/ +export const useLoginServiceRefreshSuspense = <TData = Common.LoginServiceRefreshDefaultResponse, TError = unknown, TQueryKey extends Array<unknown> = unknown[]>({ next }: { + next?: string; +} = {}, queryKey?: TQueryKey, options?: Omit<UseQueryOptions<TData, TError>, "queryKey" | "queryFn">) => useSuspenseQuery<TData, TError>({ queryKey: Common.UseLoginServiceRefreshKeyFn({ next }, queryKey), queryFn: () => LoginService.refresh({ next }) as TData, ...options }); +/** * Get Auth Menus * @returns MenuItemCollectionResponse Successful Response * @throws ApiError diff --git a/airflow-core/src/airflow/ui/openapi-gen/requests/services.gen.ts b/airflow-core/src/airflow/ui/openapi-gen/requests/services.gen.ts index 91ded2463e7..c9cb3594cae 100644 --- a/airflow-core/src/airflow/ui/openapi-gen/requests/services.gen.ts +++ b/airflow-core/src/airflow/ui/openapi-gen/requests/services.gen.ts @@ -3,7 +3,7 @@ import type { CancelablePromise } from './core/CancelablePromise'; import { OpenAPI } from './core/OpenAPI'; import { request as __request } from './core/request'; -import type { GetAssetsData, GetAssetsResponse, GetAssetAliasesData, GetAssetAliasesResponse, GetAssetAliasData, GetAssetAliasResponse, GetAssetEventsData, GetAssetEventsResponse, CreateAssetEventData, CreateAssetEventResponse, MaterializeAssetData, MaterializeAssetResponse, GetAssetQueuedEventsData, GetAssetQueuedEventsResponse, DeleteAssetQueuedEventsData, DeleteAssetQueuedEventsResponse, GetAssetData, GetAssetResponse, GetDagAssetQueuedEventsData, GetDagAssetQueuedEventsResponse, Dele [...] +import type { GetAssetsData, GetAssetsResponse, GetAssetAliasesData, GetAssetAliasesResponse, GetAssetAliasData, GetAssetAliasResponse, GetAssetEventsData, GetAssetEventsResponse, CreateAssetEventData, CreateAssetEventResponse, MaterializeAssetData, MaterializeAssetResponse, GetAssetQueuedEventsData, GetAssetQueuedEventsResponse, DeleteAssetQueuedEventsData, DeleteAssetQueuedEventsResponse, GetAssetData, GetAssetResponse, GetDagAssetQueuedEventsData, GetDagAssetQueuedEventsResponse, Dele [...] export class AssetService { /** @@ -3714,6 +3714,28 @@ export class LoginService { }); } + /** + * Refresh + * Refresh the authentication token. + * @param data The data for the request. + * @param data.next + * @returns unknown Successful Response + * @throws ApiError + */ + public static refresh(data: RefreshData = {}): CancelablePromise<RefreshResponse> { + return __request(OpenAPI, { + method: 'GET', + url: '/api/v2/auth/refresh', + query: { + next: data.next + }, + errors: { + 307: 'Temporary Redirect', + 422: 'Validation Error' + } + }); + } + } export class AuthLinksService { diff --git a/airflow-core/src/airflow/ui/openapi-gen/requests/types.gen.ts b/airflow-core/src/airflow/ui/openapi-gen/requests/types.gen.ts index 006f33286fd..18efcbc82ba 100644 --- a/airflow-core/src/airflow/ui/openapi-gen/requests/types.gen.ts +++ b/airflow-core/src/airflow/ui/openapi-gen/requests/types.gen.ts @@ -3208,6 +3208,12 @@ export type LoginResponse = unknown; export type LogoutResponse = unknown; +export type RefreshData = { + next?: string | null; +}; + +export type RefreshResponse = unknown; + export type GetAuthMenusResponse = MenuItemCollectionResponse; export type GetDependenciesData = { @@ -6278,6 +6284,25 @@ export type $OpenApiTs = { }; }; }; + '/api/v2/auth/refresh': { + get: { + req: RefreshData; + res: { + /** + * Successful Response + */ + 200: unknown; + /** + * Temporary Redirect + */ + 307: HTTPExceptionResponse; + /** + * Validation Error + */ + 422: HTTPValidationError; + }; + }; + }; '/ui/auth/menus': { get: { res: { diff --git a/airflow-core/tests/unit/api_fastapi/auth/managers/test_base_auth_manager.py b/airflow-core/tests/unit/api_fastapi/auth/managers/test_base_auth_manager.py index 98217cef115..fde84665609 100644 --- a/airflow-core/tests/unit/api_fastapi/auth/managers/test_base_auth_manager.py +++ b/airflow-core/tests/unit/api_fastapi/auth/managers/test_base_auth_manager.py @@ -165,9 +165,6 @@ class TestBaseAuthManager: def test_get_fastapi_app_return_none(self, auth_manager): assert auth_manager.get_fastapi_app() is None - def test_refresh_user_default_returns_none(self, auth_manager): - assert auth_manager.refresh_user(user=BaseAuthManagerUserTest(name="test")) is None - def test_get_url_logout_return_none(self, auth_manager): assert auth_manager.get_url_logout() is None diff --git a/airflow-core/tests/unit/api_fastapi/auth/middlewares/__init__.py b/airflow-core/tests/unit/api_fastapi/auth/middlewares/__init__.py deleted file mode 100644 index 217e5db9607..00000000000 --- a/airflow-core/tests/unit/api_fastapi/auth/middlewares/__init__.py +++ /dev/null @@ -1,17 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. diff --git a/airflow-core/tests/unit/api_fastapi/auth/middlewares/test_refresh_token.py b/airflow-core/tests/unit/api_fastapi/auth/middlewares/test_refresh_token.py deleted file mode 100644 index 87648a2be2b..00000000000 --- a/airflow-core/tests/unit/api_fastapi/auth/middlewares/test_refresh_token.py +++ /dev/null @@ -1,106 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -from __future__ import annotations - -from unittest.mock import AsyncMock, MagicMock, patch - -import pytest -from fastapi import Request, Response - -from airflow.api_fastapi.auth.managers.base_auth_manager import COOKIE_NAME_JWT_TOKEN -from airflow.api_fastapi.auth.managers.models.base_user import BaseUser -from airflow.api_fastapi.auth.middlewares.refresh_token import JWTRefreshMiddleware - - -class TestJWTRefreshMiddleware: - @pytest.fixture - def middleware(self): - return JWTRefreshMiddleware(app=MagicMock()) - - @pytest.fixture - def mock_request(self): - request = MagicMock(spec=Request) - request.cookies = {} - request.state = MagicMock() - return request - - @pytest.fixture - def mock_user(self): - return MagicMock(spec=BaseUser) - - @patch.object(JWTRefreshMiddleware, "_refresh_user") - @pytest.mark.asyncio - async def test_dispatch_no_token(self, mock_refresh_user, middleware, mock_request): - call_next = AsyncMock(return_value=Response()) - - await middleware.dispatch(mock_request, call_next) - - call_next.assert_called_once_with(mock_request) - mock_refresh_user.assert_not_called() - - @patch("airflow.api_fastapi.auth.middlewares.refresh_token.get_auth_manager") - @patch("airflow.api_fastapi.auth.middlewares.refresh_token.resolve_user_from_token") - @pytest.mark.asyncio - async def test_dispatch_no_refreshed_token( - self, mock_resolve_user_from_token, mock_get_auth_manager, middleware, mock_request, mock_user - ): - mock_request.cookies = {COOKIE_NAME_JWT_TOKEN: "valid_token"} - mock_resolve_user_from_token.return_value = mock_user - mock_auth_manager = MagicMock() - mock_get_auth_manager.return_value = mock_auth_manager - mock_auth_manager.refresh_user.return_value = None - - call_next = AsyncMock(return_value=Response()) - await middleware.dispatch(mock_request, call_next) - - call_next.assert_called_once_with(mock_request) - mock_resolve_user_from_token.assert_called_once_with("valid_token") - mock_auth_manager.generate_jwt.assert_not_called() - - @pytest.mark.asyncio - @patch("airflow.api_fastapi.auth.middlewares.refresh_token.get_auth_manager") - @patch("airflow.api_fastapi.auth.middlewares.refresh_token.resolve_user_from_token") - @patch("airflow.api_fastapi.auth.middlewares.refresh_token.conf") - async def test_dispatch_with_refreshed_user( - self, - mock_conf, - mock_resolve_user_from_token, - mock_get_auth_manager, - middleware, - mock_request, - mock_user, - ): - refreshed_user = MagicMock(spec=BaseUser) - mock_request.cookies = {COOKIE_NAME_JWT_TOKEN: "valid_token"} - mock_resolve_user_from_token.return_value = mock_user - mock_auth_manager = MagicMock() - mock_get_auth_manager.return_value = mock_auth_manager - mock_auth_manager.refresh_user.return_value = refreshed_user - mock_auth_manager.generate_jwt.return_value = "new_token" - mock_conf.get.return_value = "" - - call_next = AsyncMock(return_value=Response()) - response = await middleware.dispatch(mock_request, call_next) - - assert mock_request.state.user == refreshed_user - call_next.assert_called_once_with(mock_request) - mock_resolve_user_from_token.assert_called_once_with("valid_token") - mock_auth_manager.refresh_user.assert_called_once_with(user=mock_user) - mock_auth_manager.generate_jwt.assert_called_once_with(refreshed_user) - set_cookie_headers = response.headers.get("set-cookie", "") - assert f"{COOKIE_NAME_JWT_TOKEN}=new_token" in set_cookie_headers diff --git a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_auth.py b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_auth.py index d4a5e5869e3..20a8e329a2a 100644 --- a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_auth.py +++ b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_auth.py @@ -16,13 +16,17 @@ # under the License. from __future__ import annotations +import os from unittest.mock import MagicMock, patch import pytest +from airflow.api_fastapi.auth.managers.base_auth_manager import COOKIE_NAME_JWT_TOKEN + from tests_common.test_utils.config import conf_vars AUTH_MANAGER_LOGIN_URL = "http://some_login_url" +AUTH_MANAGER_REFRESH_URL = "http://some_refresh_url" AUTH_MANAGER_LOGOUT_URL = "http://some_logout_url" pytestmark = pytest.mark.db_test @@ -33,6 +37,7 @@ class TestAuthEndpoint: def setup(self, test_client) -> None: auth_manager_mock = MagicMock() auth_manager_mock.get_url_login.return_value = AUTH_MANAGER_LOGIN_URL + auth_manager_mock.get_url_refresh.return_value = AUTH_MANAGER_REFRESH_URL auth_manager_mock.get_url_logout.return_value = AUTH_MANAGER_LOGOUT_URL test_client.app.state.auth_manager = auth_manager_mock @@ -94,3 +99,59 @@ class TestLogout(TestAuthEndpoint): assert response.status_code == 307 assert response.headers["location"] == expected_redirection + if delete_cookies: + cookies = response.headers.get_list("set-cookie") + assert any(f"{COOKIE_NAME_JWT_TOKEN}=" in c for c in cookies) + + +class TestRefresh(TestAuthEndpoint): + @pytest.mark.parametrize( + "params", + [ + {}, + {"next": None}, + {"next": "http://localhost:8080"}, + {"next": "http://localhost:8080", "other_param": "something_else"}, + ], + ) + @patch("airflow.api_fastapi.core_api.routes.public.auth.is_safe_url", return_value=True) + def test_should_respond_307(self, mock_is_safe_url, test_client, params): + response = test_client.get("/auth/refresh", follow_redirects=False, params=params) + + assert response.status_code == 307 + assert ( + response.headers["location"] == f"{AUTH_MANAGER_REFRESH_URL}?next={params.get('next')}" + if params.get("next") + else AUTH_MANAGER_REFRESH_URL + ) + + @patch.dict(os.environ, {"AIRFLOW__API__BASE_URL": "http://localhost:8080/"}) + @pytest.mark.parametrize( + "params", + [ + {}, + {"next": None}, + {"next": "http://localhost:8080"}, + {"next": "http://localhost:8080", "other_param": "something_else"}, + ], + ) + @patch("airflow.api_fastapi.core_api.routes.public.auth.is_safe_url", return_value=True) + def test_refresh_url_is_none(self, mock_is_safe_url, test_client, params): + test_client.app.state.auth_manager.get_url_refresh.return_value = None + response = test_client.get("/auth/refresh", follow_redirects=False, params=params) + + assert response.status_code == 307 + assert response.headers["location"] == "http://localhost:8080/auth/logout" + + @pytest.mark.parametrize( + "params", + [ + {"next": "http://fake_domain.com:8080"}, + {"next": "http://localhost:8080/../../up"}, + ], + ) + @conf_vars({("api", "base_url"): "http://localhost:8080/prefix"}) + def test_should_respond_400(self, test_client, params): + response = test_client.get("/auth/refresh", follow_redirects=False, params=params) + + assert response.status_code == 400 diff --git a/airflow-core/tests/unit/api_fastapi/core_api/routes/test_routes.py b/airflow-core/tests/unit/api_fastapi/core_api/routes/test_routes.py index 95734bdaa37..8fe41ede186 100644 --- a/airflow-core/tests/unit/api_fastapi/core_api/routes/test_routes.py +++ b/airflow-core/tests/unit/api_fastapi/core_api/routes/test_routes.py @@ -22,6 +22,7 @@ from airflow.api_fastapi.core_api.routes.public import authenticated_router, pub NO_AUTH_PATHS = { "/api/v2/auth/login", "/api/v2/auth/logout", + "/api/v2/auth/refresh", "/api/v2/version", "/api/v2/monitor/health", } diff --git a/providers/keycloak/src/airflow/providers/keycloak/auth_manager/keycloak_auth_manager.py b/providers/keycloak/src/airflow/providers/keycloak/auth_manager/keycloak_auth_manager.py index 5d42c8bb160..8fdaa0b03e9 100644 --- a/providers/keycloak/src/airflow/providers/keycloak/auth_manager/keycloak_auth_manager.py +++ b/providers/keycloak/src/airflow/providers/keycloak/auth_manager/keycloak_auth_manager.py @@ -19,8 +19,6 @@ from __future__ import annotations import argparse import json import logging -import time -from base64 import urlsafe_b64decode from typing import TYPE_CHECKING, Any from urllib.parse import urljoin @@ -110,15 +108,9 @@ class KeycloakAuthManager(BaseAuthManager[KeycloakAuthManagerUser]): base_url = conf.get("api", "base_url", fallback="/") return urljoin(base_url, f"{AUTH_MANAGER_FASTAPI_APP_PREFIX}/login") - def refresh_user(self, *, user: KeycloakAuthManagerUser) -> KeycloakAuthManagerUser | None: - if self._token_expired(user.access_token): - client = self.get_keycloak_client() - tokens = client.refresh_token(user.refresh_token) - user.refresh_token = tokens["refresh_token"] - user.access_token = tokens["access_token"] - return user - - return None + def get_url_refresh(self) -> str | None: + base_url = conf.get("api", "base_url", fallback="/") + return urljoin(base_url, f"{AUTH_MANAGER_FASTAPI_APP_PREFIX}/refresh") def is_authorized_configuration( self, @@ -374,17 +366,3 @@ class KeycloakAuthManager(BaseAuthManager[KeycloakAuthManagerUser]): "Authorization": f"Bearer {access_token}", "Content-Type": "application/x-www-form-urlencoded", } - - @staticmethod - def _token_expired(token: str) -> bool: - """ - Check whether a JWT token is expired. - - :meta private: - - :param token: the token - """ - payload_b64 = token.split(".")[1] + "==" - payload_bytes = urlsafe_b64decode(payload_b64) - payload = json.loads(payload_bytes) - return payload["exp"] < int(time.time()) diff --git a/providers/keycloak/tests/unit/keycloak/auth_manager/routes/test_login.py b/providers/keycloak/tests/unit/keycloak/auth_manager/routes/test_login.py index 2604da76581..9cfc1984e71 100644 --- a/providers/keycloak/tests/unit/keycloak/auth_manager/routes/test_login.py +++ b/providers/keycloak/tests/unit/keycloak/auth_manager/routes/test_login.py @@ -16,7 +16,9 @@ # under the License. from __future__ import annotations -from unittest.mock import ANY, Mock, patch +from unittest.mock import ANY, AsyncMock, Mock, patch + +import pytest from airflow.api_fastapi.app import AUTH_MANAGER_FASTAPI_APP_PREFIX @@ -74,3 +76,76 @@ class TestLoginRouter: def test_login_callback_without_code(self, client): response = client.get(AUTH_MANAGER_FASTAPI_APP_PREFIX + "/login_callback") assert response.status_code == 400 + + @patch("airflow.api_fastapi.core_api.security.get_user", new_callable=AsyncMock) + @patch("airflow.api_fastapi.core_api.security.get_auth_manager") + @patch("airflow.providers.keycloak.auth_manager.routes.login.KeycloakAuthManager.get_keycloak_client") + @patch("airflow.providers.keycloak.auth_manager.routes.login.get_auth_manager") + @pytest.mark.asyncio + async def test_refresh( + self, + mock_get_auth_manager, + mock_get_keycloak_client, + mock_sec_get_auth_manager, + mock_get_user, + client, + ): + mock_user = Mock() + mock_get_user.return_value = mock_user + mock_auth_manager_sec = Mock() + mock_sec_get_auth_manager.return_value = mock_auth_manager_sec + mock_auth_manager_sec.get_user_from_token = AsyncMock(return_value=mock_user) + mock_get_keycloak_client.refresh_token.return_value = { + "access_token": "new_access_token", + "refresh_token": "new_refresh_token", + } + + mock_auth_manager = Mock() + mock_get_auth_manager.return_value = mock_auth_manager + mock_auth_manager.generate_jwt.return_value = "new_token" + + next_url = "http://localhost:8080" + response = client.get( + AUTH_MANAGER_FASTAPI_APP_PREFIX + "/refresh", + headers={"Authorization": "Bearer refresh_token"}, + follow_redirects=False, + params={"next": next_url}, + ) + + assert response.status_code == 303 + assert "_token" in response.cookies + + assert "location" in response.headers + assert response.headers["location"] == next_url + + # Test when user is None or refresh_token is not set + @patch("airflow.api_fastapi.core_api.security.get_user", new_callable=AsyncMock) + @patch("airflow.api_fastapi.core_api.security.get_auth_manager") + @patch("airflow.providers.keycloak.auth_manager.routes.login.KeycloakAuthManager.get_keycloak_client") + @patch("airflow.providers.keycloak.auth_manager.routes.login.get_auth_manager") + @pytest.mark.asyncio + async def test_refresh_user_none( + self, + mock_get_auth_manager, + mock_get_keycloak_client, + mock_sec_get_auth_manager, + mock_get_user, + client, + ): + mock_user = None + mock_get_user.return_value = mock_user + mock_auth_manager_sec = Mock() + mock_sec_get_auth_manager.return_value = mock_auth_manager_sec + mock_auth_manager_sec.get_user_from_token = AsyncMock(return_value=mock_user) + + next_url = "http://localhost:8080" + response = client.get( + AUTH_MANAGER_FASTAPI_APP_PREFIX + "/refresh", + headers={"Authorization": "Bearer refresh_token"}, + follow_redirects=False, + params={"next": next_url}, + ) + + assert response.status_code == 400 + assert "_token" not in response.cookies + assert "location" not in response.headers diff --git a/providers/keycloak/tests/unit/keycloak/auth_manager/test_keycloak_auth_manager.py b/providers/keycloak/tests/unit/keycloak/auth_manager/test_keycloak_auth_manager.py index 7c00c74f1ab..2e7478b2991 100644 --- a/providers/keycloak/tests/unit/keycloak/auth_manager/test_keycloak_auth_manager.py +++ b/providers/keycloak/tests/unit/keycloak/auth_manager/test_keycloak_auth_manager.py @@ -67,7 +67,6 @@ def auth_manager(): def user(): user = Mock() user.access_token = "access_token" - user.refresh_token = "refresh_token" return user @@ -103,32 +102,6 @@ class TestKeycloakAuthManager: result = auth_manager.get_url_login() assert result == f"{AUTH_MANAGER_FASTAPI_APP_PREFIX}/login" - @patch.object(KeycloakAuthManager, "_token_expired") - def test_refresh_user_not_expired(self, mock_token_expired, auth_manager): - mock_token_expired.return_value = False - - result = auth_manager.refresh_user(user=Mock()) - - assert result is None - - @patch.object(KeycloakAuthManager, "get_keycloak_client") - @patch.object(KeycloakAuthManager, "_token_expired") - def test_refresh_user_expired(self, mock_token_expired, mock_get_keycloak_client, auth_manager, user): - mock_token_expired.return_value = True - keycloak_client = Mock() - keycloak_client.refresh_token.return_value = { - "access_token": "new_access_token", - "refresh_token": "new_refresh_token", - } - - mock_get_keycloak_client.return_value = keycloak_client - - result = auth_manager.refresh_user(user=user) - - keycloak_client.refresh_token.assert_called_with("refresh_token") - assert result.access_token == "new_access_token" - assert result.refresh_token == "new_refresh_token" - @pytest.mark.parametrize( "function, method, details, permission, attributes", [ @@ -453,15 +426,3 @@ class TestKeycloakAuthManager: def test_get_cli_commands_return_cli_commands(self, auth_manager): assert len(auth_manager.get_cli_commands()) == 1 - - @pytest.mark.parametrize( - "expiration, expected", - [ - (-30, True), - (30, False), - ], - ) - def test_token_expired(self, auth_manager, expiration, expected): - token = auth_manager._get_token_signer(expiration_time_in_seconds=expiration).generate({}) - - assert KeycloakAuthManager._token_expired(token) is expected
