This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch v3-1-test
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/v3-1-test by this push:
new 6b9f3cd690c Update refresh token flow (#55506) (#58649)
6b9f3cd690c is described below
commit 6b9f3cd690cdfe3c1cf657dfdca3f655e7fcb556
Author: Vincent <[email protected]>
AuthorDate: Tue Nov 25 07:40:09 2025 -0500
Update refresh token flow (#55506) (#58649)
---
airflow-core/src/airflow/api_fastapi/app.py | 2 +
.../api_fastapi/auth/managers/base_auth_manager.py | 11 ++-
.../api_fastapi/auth/middlewares/__init__.py | 17 ++++
.../api_fastapi/auth/middlewares/refresh_token.py | 68 +++++++++++++
.../src/airflow/api_fastapi/core_api/app.py | 6 ++
.../core_api/openapi/v2-rest-api-generated.yaml | 34 -------
.../api_fastapi/core_api/routes/public/auth.py | 20 ----
.../src/airflow/ui/openapi-gen/queries/common.ts | 6 --
.../ui/openapi-gen/queries/ensureQueryData.ts | 11 ---
.../src/airflow/ui/openapi-gen/queries/prefetch.ts | 11 ---
.../src/airflow/ui/openapi-gen/queries/queries.ts | 11 ---
.../src/airflow/ui/openapi-gen/queries/suspense.ts | 11 ---
.../ui/openapi-gen/requests/services.gen.ts | 24 +----
.../airflow/ui/openapi-gen/requests/types.gen.ts | 25 -----
.../auth/managers/test_base_auth_manager.py | 3 +
.../unit/api_fastapi/auth/middlewares/__init__.py | 17 ++++
.../auth/middlewares/test_refresh_token.py | 106 +++++++++++++++++++++
.../core_api/routes/public/test_auth.py | 61 ------------
.../api_fastapi/core_api/routes/test_routes.py | 1 -
.../keycloak/auth_manager/keycloak_auth_manager.py | 28 +++++-
.../keycloak/auth_manager/routes/test_login.py | 77 +--------------
.../auth_manager/test_keycloak_auth_manager.py | 39 ++++++++
22 files changed, 292 insertions(+), 297 deletions(-)
diff --git a/airflow-core/src/airflow/api_fastapi/app.py
b/airflow-core/src/airflow/api_fastapi/app.py
index 58cfb157083..7c05295807e 100644
--- a/airflow-core/src/airflow/api_fastapi/app.py
+++ b/airflow-core/src/airflow/api_fastapi/app.py
@@ -29,6 +29,7 @@ from airflow.api_fastapi.core_api.app import (
init_config,
init_error_handlers,
init_flask_plugins,
+ init_middlewares,
init_ui_plugins,
init_views,
)
@@ -99,6 +100,7 @@ def create_app(apps: str = "all") -> FastAPI:
init_ui_plugins(app)
init_views(app) # Core views need to be the last routes added - it
has a catch all route
init_error_handlers(app)
+ init_middlewares(app)
init_config(app)
diff --git
a/airflow-core/src/airflow/api_fastapi/auth/managers/base_auth_manager.py
b/airflow-core/src/airflow/api_fastapi/auth/managers/base_auth_manager.py
index d57dd3cdc39..b8656cd068f 100644
--- a/airflow-core/src/airflow/api_fastapi/auth/managers/base_auth_manager.py
+++ b/airflow-core/src/airflow/api_fastapi/auth/managers/base_auth_manager.py
@@ -135,12 +135,15 @@ class BaseAuthManager(Generic[T], LoggingMixin,
metaclass=ABCMeta):
"""
return None
- def get_url_refresh(self) -> str | None:
+ def refresh_user(self, *, user: T) -> T | None:
"""
- Return the URL to refresh the authentication token.
+ Refresh the user if needed.
- This is used to refresh the authentication token when it expires.
- The default implementation returns None, which means that the auth
manager does not support refresh token.
+ By default, does nothing. Some auth managers might need to refresh the
user to, for instance,
+ refresh some tokens that are needed to communicate with a service/tool.
+
+ This method is called on every single request, so it must be lightweight;
otherwise the overall API
+ server latency will increase.
"""
return None
diff --git a/airflow-core/src/airflow/api_fastapi/auth/middlewares/__init__.py
b/airflow-core/src/airflow/api_fastapi/auth/middlewares/__init__.py
new file mode 100644
index 00000000000..217e5db9607
--- /dev/null
+++ b/airflow-core/src/airflow/api_fastapi/auth/middlewares/__init__.py
@@ -0,0 +1,17 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git
a/airflow-core/src/airflow/api_fastapi/auth/middlewares/refresh_token.py
b/airflow-core/src/airflow/api_fastapi/auth/middlewares/refresh_token.py
new file mode 100644
index 00000000000..f304eb9517f
--- /dev/null
+++ b/airflow-core/src/airflow/api_fastapi/auth/middlewares/refresh_token.py
@@ -0,0 +1,68 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from fastapi import Request
+from starlette.middleware.base import BaseHTTPMiddleware
+
+from airflow.api_fastapi.app import get_auth_manager
+from airflow.api_fastapi.auth.managers.base_auth_manager import
COOKIE_NAME_JWT_TOKEN
+from airflow.api_fastapi.auth.managers.models.base_user import BaseUser
+from airflow.api_fastapi.core_api.security import resolve_user_from_token
+from airflow.configuration import conf
+
+
+class JWTRefreshMiddleware(BaseHTTPMiddleware):
+ """
+ Middleware to handle JWT token refresh.
+
+ This middleware:
+ 1. Extracts the JWT token from cookies and builds the user from the token
+ 2. Calls the ``refresh_user`` method of the auth manager with the user
+ 3. If ``refresh_user`` returns a user, generates a JWT token based upon
this user and sends it in the
+ response as a cookie
+ """
+
+ async def dispatch(self, request: Request, call_next):
+ new_user = None
+ current_token = request.cookies.get(COOKIE_NAME_JWT_TOKEN)
+ if current_token:
+ new_user = await self._refresh_user(current_token)
+ if new_user:
+ request.state.user = new_user
+
+ response = await call_next(request)
+
+ if new_user:
+ # If we created a new user, serialize it and set it as a cookie
+ new_token = get_auth_manager().generate_jwt(new_user)
+ secure = bool(conf.get("api", "ssl_cert", fallback=""))
+ response.set_cookie(
+ COOKIE_NAME_JWT_TOKEN,
+ new_token,
+ httponly=True,
+ secure=secure,
+ samesite="lax",
+ )
+
+ return response
+
+ @staticmethod
+ async def _refresh_user(current_token: str) -> BaseUser | None:
+ user = await resolve_user_from_token(current_token)
+ return get_auth_manager().refresh_user(user=user)
diff --git a/airflow-core/src/airflow/api_fastapi/core_api/app.py
b/airflow-core/src/airflow/api_fastapi/core_api/app.py
index cd34b28f3cb..8db1fa66680 100644
--- a/airflow-core/src/airflow/api_fastapi/core_api/app.py
+++ b/airflow-core/src/airflow/api_fastapi/core_api/app.py
@@ -181,6 +181,12 @@ def init_error_handlers(app: FastAPI) -> None:
app.add_exception_handler(handler.exception_cls,
handler.exception_handler)
+def init_middlewares(app: FastAPI) -> None:
+ from airflow.api_fastapi.auth.middlewares.refresh_token import
JWTRefreshMiddleware
+
+ app.add_middleware(JWTRefreshMiddleware)
+
+
def init_ui_plugins(app: FastAPI) -> None:
"""Initialize UI plugins."""
from airflow import plugins_manager
diff --git
a/airflow-core/src/airflow/api_fastapi/core_api/openapi/v2-rest-api-generated.yaml
b/airflow-core/src/airflow/api_fastapi/core_api/openapi/v2-rest-api-generated.yaml
index bef548133c6..ceaf90b60f8 100644
---
a/airflow-core/src/airflow/api_fastapi/core_api/openapi/v2-rest-api-generated.yaml
+++
b/airflow-core/src/airflow/api_fastapi/core_api/openapi/v2-rest-api-generated.yaml
@@ -8478,40 +8478,6 @@ paths:
application/json:
schema:
$ref: '#/components/schemas/HTTPExceptionResponse'
- /api/v2/auth/refresh:
- get:
- tags:
- - Login
- summary: Refresh
- description: Refresh the authentication token.
- operationId: refresh
- parameters:
- - name: next
- in: query
- required: false
- schema:
- anyOf:
- - type: string
- - type: 'null'
- title: Next
- responses:
- '200':
- description: Successful Response
- content:
- application/json:
- schema: {}
- '307':
- content:
- application/json:
- schema:
- $ref: '#/components/schemas/HTTPExceptionResponse'
- description: Temporary Redirect
- '422':
- description: Validation Error
- content:
- application/json:
- schema:
- $ref: '#/components/schemas/HTTPValidationError'
components:
schemas:
AppBuilderMenuItemResponse:
diff --git
a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/auth.py
b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/auth.py
index d1e770b9277..a97b7fd9972 100644
--- a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/auth.py
+++ b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/auth.py
@@ -65,23 +65,3 @@ def logout(request: Request) -> RedirectResponse:
)
return response
-
-
-@auth_router.get(
- "/refresh",
-
responses=create_openapi_http_exception_doc([status.HTTP_307_TEMPORARY_REDIRECT]),
-)
-def refresh(request: Request, next: None | str = None) -> RedirectResponse:
- """Refresh the authentication token."""
- refresh_url = request.app.state.auth_manager.get_url_refresh()
-
- if not refresh_url:
- return RedirectResponse(f"{conf.get('api', 'base_url',
fallback='/')}auth/logout")
-
- if next and not is_safe_url(next, request=request):
- raise HTTPException(status_code=400, detail="Invalid or unsafe next
URL")
-
- if next:
- refresh_url += f"?next={next}"
-
- return RedirectResponse(refresh_url)
diff --git a/airflow-core/src/airflow/ui/openapi-gen/queries/common.ts
b/airflow-core/src/airflow/ui/openapi-gen/queries/common.ts
index de2bbc9647e..994fe2d91c1 100644
--- a/airflow-core/src/airflow/ui/openapi-gen/queries/common.ts
+++ b/airflow-core/src/airflow/ui/openapi-gen/queries/common.ts
@@ -753,12 +753,6 @@ export type LoginServiceLogoutDefaultResponse =
Awaited<ReturnType<typeof LoginS
export type LoginServiceLogoutQueryResult<TData =
LoginServiceLogoutDefaultResponse, TError = unknown> = UseQueryResult<TData,
TError>;
export const useLoginServiceLogoutKey = "LoginServiceLogout";
export const UseLoginServiceLogoutKeyFn = (queryKey?: Array<unknown>) =>
[useLoginServiceLogoutKey, ...(queryKey ?? [])];
-export type LoginServiceRefreshDefaultResponse = Awaited<ReturnType<typeof
LoginService.refresh>>;
-export type LoginServiceRefreshQueryResult<TData =
LoginServiceRefreshDefaultResponse, TError = unknown> = UseQueryResult<TData,
TError>;
-export const useLoginServiceRefreshKey = "LoginServiceRefresh";
-export const UseLoginServiceRefreshKeyFn = ({ next }: {
- next?: string;
-} = {}, queryKey?: Array<unknown>) => [useLoginServiceRefreshKey, ...(queryKey
?? [{ next }])];
export type AuthLinksServiceGetAuthMenusDefaultResponse =
Awaited<ReturnType<typeof AuthLinksService.getAuthMenus>>;
export type AuthLinksServiceGetAuthMenusQueryResult<TData =
AuthLinksServiceGetAuthMenusDefaultResponse, TError = unknown> =
UseQueryResult<TData, TError>;
export const useAuthLinksServiceGetAuthMenusKey =
"AuthLinksServiceGetAuthMenus";
diff --git a/airflow-core/src/airflow/ui/openapi-gen/queries/ensureQueryData.ts
b/airflow-core/src/airflow/ui/openapi-gen/queries/ensureQueryData.ts
index c38c3958500..41fe0005dcd 100644
--- a/airflow-core/src/airflow/ui/openapi-gen/queries/ensureQueryData.ts
+++ b/airflow-core/src/airflow/ui/openapi-gen/queries/ensureQueryData.ts
@@ -1431,17 +1431,6 @@ export const ensureUseLoginServiceLoginData =
(queryClient: QueryClient, { next
*/
export const ensureUseLoginServiceLogoutData = (queryClient: QueryClient) =>
queryClient.ensureQueryData({ queryKey: Common.UseLoginServiceLogoutKeyFn(),
queryFn: () => LoginService.logout() });
/**
-* Refresh
-* Refresh the authentication token.
-* @param data The data for the request.
-* @param data.next
-* @returns unknown Successful Response
-* @throws ApiError
-*/
-export const ensureUseLoginServiceRefreshData = (queryClient: QueryClient, {
next }: {
- next?: string;
-} = {}) => queryClient.ensureQueryData({ queryKey:
Common.UseLoginServiceRefreshKeyFn({ next }), queryFn: () =>
LoginService.refresh({ next }) });
-/**
* Get Auth Menus
* @returns MenuItemCollectionResponse Successful Response
* @throws ApiError
diff --git a/airflow-core/src/airflow/ui/openapi-gen/queries/prefetch.ts
b/airflow-core/src/airflow/ui/openapi-gen/queries/prefetch.ts
index ec8e3471a4c..fa6162ec588 100644
--- a/airflow-core/src/airflow/ui/openapi-gen/queries/prefetch.ts
+++ b/airflow-core/src/airflow/ui/openapi-gen/queries/prefetch.ts
@@ -1431,17 +1431,6 @@ export const prefetchUseLoginServiceLogin =
(queryClient: QueryClient, { next }:
*/
export const prefetchUseLoginServiceLogout = (queryClient: QueryClient) =>
queryClient.prefetchQuery({ queryKey: Common.UseLoginServiceLogoutKeyFn(),
queryFn: () => LoginService.logout() });
/**
-* Refresh
-* Refresh the authentication token.
-* @param data The data for the request.
-* @param data.next
-* @returns unknown Successful Response
-* @throws ApiError
-*/
-export const prefetchUseLoginServiceRefresh = (queryClient: QueryClient, {
next }: {
- next?: string;
-} = {}) => queryClient.prefetchQuery({ queryKey:
Common.UseLoginServiceRefreshKeyFn({ next }), queryFn: () =>
LoginService.refresh({ next }) });
-/**
* Get Auth Menus
* @returns MenuItemCollectionResponse Successful Response
* @throws ApiError
diff --git a/airflow-core/src/airflow/ui/openapi-gen/queries/queries.ts
b/airflow-core/src/airflow/ui/openapi-gen/queries/queries.ts
index 5211c77b349..955e5049e60 100644
--- a/airflow-core/src/airflow/ui/openapi-gen/queries/queries.ts
+++ b/airflow-core/src/airflow/ui/openapi-gen/queries/queries.ts
@@ -1431,17 +1431,6 @@ export const useLoginServiceLogin = <TData =
Common.LoginServiceLoginDefaultResp
*/
export const useLoginServiceLogout = <TData =
Common.LoginServiceLogoutDefaultResponse, TError = unknown, TQueryKey extends
Array<unknown> = unknown[]>(queryKey?: TQueryKey, options?:
Omit<UseQueryOptions<TData, TError>, "queryKey" | "queryFn">) =>
useQuery<TData, TError>({ queryKey:
Common.UseLoginServiceLogoutKeyFn(queryKey), queryFn: () =>
LoginService.logout() as TData, ...options });
/**
-* Refresh
-* Refresh the authentication token.
-* @param data The data for the request.
-* @param data.next
-* @returns unknown Successful Response
-* @throws ApiError
-*/
-export const useLoginServiceRefresh = <TData =
Common.LoginServiceRefreshDefaultResponse, TError = unknown, TQueryKey extends
Array<unknown> = unknown[]>({ next }: {
- next?: string;
-} = {}, queryKey?: TQueryKey, options?: Omit<UseQueryOptions<TData, TError>,
"queryKey" | "queryFn">) => useQuery<TData, TError>({ queryKey:
Common.UseLoginServiceRefreshKeyFn({ next }, queryKey), queryFn: () =>
LoginService.refresh({ next }) as TData, ...options });
-/**
* Get Auth Menus
* @returns MenuItemCollectionResponse Successful Response
* @throws ApiError
diff --git a/airflow-core/src/airflow/ui/openapi-gen/queries/suspense.ts
b/airflow-core/src/airflow/ui/openapi-gen/queries/suspense.ts
index 9b980d1cbec..aafe12ed9bc 100644
--- a/airflow-core/src/airflow/ui/openapi-gen/queries/suspense.ts
+++ b/airflow-core/src/airflow/ui/openapi-gen/queries/suspense.ts
@@ -1431,17 +1431,6 @@ export const useLoginServiceLoginSuspense = <TData =
Common.LoginServiceLoginDef
*/
export const useLoginServiceLogoutSuspense = <TData =
Common.LoginServiceLogoutDefaultResponse, TError = unknown, TQueryKey extends
Array<unknown> = unknown[]>(queryKey?: TQueryKey, options?:
Omit<UseQueryOptions<TData, TError>, "queryKey" | "queryFn">) =>
useSuspenseQuery<TData, TError>({ queryKey:
Common.UseLoginServiceLogoutKeyFn(queryKey), queryFn: () =>
LoginService.logout() as TData, ...options });
/**
-* Refresh
-* Refresh the authentication token.
-* @param data The data for the request.
-* @param data.next
-* @returns unknown Successful Response
-* @throws ApiError
-*/
-export const useLoginServiceRefreshSuspense = <TData =
Common.LoginServiceRefreshDefaultResponse, TError = unknown, TQueryKey extends
Array<unknown> = unknown[]>({ next }: {
- next?: string;
-} = {}, queryKey?: TQueryKey, options?: Omit<UseQueryOptions<TData, TError>,
"queryKey" | "queryFn">) => useSuspenseQuery<TData, TError>({ queryKey:
Common.UseLoginServiceRefreshKeyFn({ next }, queryKey), queryFn: () =>
LoginService.refresh({ next }) as TData, ...options });
-/**
* Get Auth Menus
* @returns MenuItemCollectionResponse Successful Response
* @throws ApiError
diff --git a/airflow-core/src/airflow/ui/openapi-gen/requests/services.gen.ts
b/airflow-core/src/airflow/ui/openapi-gen/requests/services.gen.ts
index c9cb3594cae..91ded2463e7 100644
--- a/airflow-core/src/airflow/ui/openapi-gen/requests/services.gen.ts
+++ b/airflow-core/src/airflow/ui/openapi-gen/requests/services.gen.ts
@@ -3,7 +3,7 @@
import type { CancelablePromise } from './core/CancelablePromise';
import { OpenAPI } from './core/OpenAPI';
import { request as __request } from './core/request';
-import type { GetAssetsData, GetAssetsResponse, GetAssetAliasesData,
GetAssetAliasesResponse, GetAssetAliasData, GetAssetAliasResponse,
GetAssetEventsData, GetAssetEventsResponse, CreateAssetEventData,
CreateAssetEventResponse, MaterializeAssetData, MaterializeAssetResponse,
GetAssetQueuedEventsData, GetAssetQueuedEventsResponse,
DeleteAssetQueuedEventsData, DeleteAssetQueuedEventsResponse, GetAssetData,
GetAssetResponse, GetDagAssetQueuedEventsData, GetDagAssetQueuedEventsResponse,
Dele [...]
+import type { GetAssetsData, GetAssetsResponse, GetAssetAliasesData,
GetAssetAliasesResponse, GetAssetAliasData, GetAssetAliasResponse,
GetAssetEventsData, GetAssetEventsResponse, CreateAssetEventData,
CreateAssetEventResponse, MaterializeAssetData, MaterializeAssetResponse,
GetAssetQueuedEventsData, GetAssetQueuedEventsResponse,
DeleteAssetQueuedEventsData, DeleteAssetQueuedEventsResponse, GetAssetData,
GetAssetResponse, GetDagAssetQueuedEventsData, GetDagAssetQueuedEventsResponse,
Dele [...]
export class AssetService {
/**
@@ -3714,28 +3714,6 @@ export class LoginService {
});
}
- /**
- * Refresh
- * Refresh the authentication token.
- * @param data The data for the request.
- * @param data.next
- * @returns unknown Successful Response
- * @throws ApiError
- */
- public static refresh(data: RefreshData = {}):
CancelablePromise<RefreshResponse> {
- return __request(OpenAPI, {
- method: 'GET',
- url: '/api/v2/auth/refresh',
- query: {
- next: data.next
- },
- errors: {
- 307: 'Temporary Redirect',
- 422: 'Validation Error'
- }
- });
- }
-
}
export class AuthLinksService {
diff --git a/airflow-core/src/airflow/ui/openapi-gen/requests/types.gen.ts
b/airflow-core/src/airflow/ui/openapi-gen/requests/types.gen.ts
index 18efcbc82ba..006f33286fd 100644
--- a/airflow-core/src/airflow/ui/openapi-gen/requests/types.gen.ts
+++ b/airflow-core/src/airflow/ui/openapi-gen/requests/types.gen.ts
@@ -3208,12 +3208,6 @@ export type LoginResponse = unknown;
export type LogoutResponse = unknown;
-export type RefreshData = {
- next?: string | null;
-};
-
-export type RefreshResponse = unknown;
-
export type GetAuthMenusResponse = MenuItemCollectionResponse;
export type GetDependenciesData = {
@@ -6284,25 +6278,6 @@ export type $OpenApiTs = {
};
};
};
- '/api/v2/auth/refresh': {
- get: {
- req: RefreshData;
- res: {
- /**
- * Successful Response
- */
- 200: unknown;
- /**
- * Temporary Redirect
- */
- 307: HTTPExceptionResponse;
- /**
- * Validation Error
- */
- 422: HTTPValidationError;
- };
- };
- };
'/ui/auth/menus': {
get: {
res: {
diff --git
a/airflow-core/tests/unit/api_fastapi/auth/managers/test_base_auth_manager.py
b/airflow-core/tests/unit/api_fastapi/auth/managers/test_base_auth_manager.py
index fde84665609..98217cef115 100644
---
a/airflow-core/tests/unit/api_fastapi/auth/managers/test_base_auth_manager.py
+++
b/airflow-core/tests/unit/api_fastapi/auth/managers/test_base_auth_manager.py
@@ -165,6 +165,9 @@ class TestBaseAuthManager:
def test_get_fastapi_app_return_none(self, auth_manager):
assert auth_manager.get_fastapi_app() is None
+ def test_refresh_user_default_returns_none(self, auth_manager):
+ assert
auth_manager.refresh_user(user=BaseAuthManagerUserTest(name="test")) is None
+
def test_get_url_logout_return_none(self, auth_manager):
assert auth_manager.get_url_logout() is None
diff --git a/airflow-core/tests/unit/api_fastapi/auth/middlewares/__init__.py
b/airflow-core/tests/unit/api_fastapi/auth/middlewares/__init__.py
new file mode 100644
index 00000000000..217e5db9607
--- /dev/null
+++ b/airflow-core/tests/unit/api_fastapi/auth/middlewares/__init__.py
@@ -0,0 +1,17 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git
a/airflow-core/tests/unit/api_fastapi/auth/middlewares/test_refresh_token.py
b/airflow-core/tests/unit/api_fastapi/auth/middlewares/test_refresh_token.py
new file mode 100644
index 00000000000..87648a2be2b
--- /dev/null
+++ b/airflow-core/tests/unit/api_fastapi/auth/middlewares/test_refresh_token.py
@@ -0,0 +1,106 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from unittest.mock import AsyncMock, MagicMock, patch
+
+import pytest
+from fastapi import Request, Response
+
+from airflow.api_fastapi.auth.managers.base_auth_manager import
COOKIE_NAME_JWT_TOKEN
+from airflow.api_fastapi.auth.managers.models.base_user import BaseUser
+from airflow.api_fastapi.auth.middlewares.refresh_token import
JWTRefreshMiddleware
+
+
+class TestJWTRefreshMiddleware:
+ @pytest.fixture
+ def middleware(self):
+ return JWTRefreshMiddleware(app=MagicMock())
+
+ @pytest.fixture
+ def mock_request(self):
+ request = MagicMock(spec=Request)
+ request.cookies = {}
+ request.state = MagicMock()
+ return request
+
+ @pytest.fixture
+ def mock_user(self):
+ return MagicMock(spec=BaseUser)
+
+ @patch.object(JWTRefreshMiddleware, "_refresh_user")
+ @pytest.mark.asyncio
+ async def test_dispatch_no_token(self, mock_refresh_user, middleware,
mock_request):
+ call_next = AsyncMock(return_value=Response())
+
+ await middleware.dispatch(mock_request, call_next)
+
+ call_next.assert_called_once_with(mock_request)
+ mock_refresh_user.assert_not_called()
+
+
@patch("airflow.api_fastapi.auth.middlewares.refresh_token.get_auth_manager")
+
@patch("airflow.api_fastapi.auth.middlewares.refresh_token.resolve_user_from_token")
+ @pytest.mark.asyncio
+ async def test_dispatch_no_refreshed_token(
+ self, mock_resolve_user_from_token, mock_get_auth_manager, middleware,
mock_request, mock_user
+ ):
+ mock_request.cookies = {COOKIE_NAME_JWT_TOKEN: "valid_token"}
+ mock_resolve_user_from_token.return_value = mock_user
+ mock_auth_manager = MagicMock()
+ mock_get_auth_manager.return_value = mock_auth_manager
+ mock_auth_manager.refresh_user.return_value = None
+
+ call_next = AsyncMock(return_value=Response())
+ await middleware.dispatch(mock_request, call_next)
+
+ call_next.assert_called_once_with(mock_request)
+ mock_resolve_user_from_token.assert_called_once_with("valid_token")
+ mock_auth_manager.generate_jwt.assert_not_called()
+
+ @pytest.mark.asyncio
+
@patch("airflow.api_fastapi.auth.middlewares.refresh_token.get_auth_manager")
+
@patch("airflow.api_fastapi.auth.middlewares.refresh_token.resolve_user_from_token")
+ @patch("airflow.api_fastapi.auth.middlewares.refresh_token.conf")
+ async def test_dispatch_with_refreshed_user(
+ self,
+ mock_conf,
+ mock_resolve_user_from_token,
+ mock_get_auth_manager,
+ middleware,
+ mock_request,
+ mock_user,
+ ):
+ refreshed_user = MagicMock(spec=BaseUser)
+ mock_request.cookies = {COOKIE_NAME_JWT_TOKEN: "valid_token"}
+ mock_resolve_user_from_token.return_value = mock_user
+ mock_auth_manager = MagicMock()
+ mock_get_auth_manager.return_value = mock_auth_manager
+ mock_auth_manager.refresh_user.return_value = refreshed_user
+ mock_auth_manager.generate_jwt.return_value = "new_token"
+ mock_conf.get.return_value = ""
+
+ call_next = AsyncMock(return_value=Response())
+ response = await middleware.dispatch(mock_request, call_next)
+
+ assert mock_request.state.user == refreshed_user
+ call_next.assert_called_once_with(mock_request)
+ mock_resolve_user_from_token.assert_called_once_with("valid_token")
+ mock_auth_manager.refresh_user.assert_called_once_with(user=mock_user)
+ mock_auth_manager.generate_jwt.assert_called_once_with(refreshed_user)
+ set_cookie_headers = response.headers.get("set-cookie", "")
+ assert f"{COOKIE_NAME_JWT_TOKEN}=new_token" in set_cookie_headers
diff --git
a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_auth.py
b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_auth.py
index 20a8e329a2a..d4a5e5869e3 100644
--- a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_auth.py
+++ b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_auth.py
@@ -16,17 +16,13 @@
# under the License.
from __future__ import annotations
-import os
from unittest.mock import MagicMock, patch
import pytest
-from airflow.api_fastapi.auth.managers.base_auth_manager import
COOKIE_NAME_JWT_TOKEN
-
from tests_common.test_utils.config import conf_vars
AUTH_MANAGER_LOGIN_URL = "http://some_login_url"
-AUTH_MANAGER_REFRESH_URL = "http://some_refresh_url"
AUTH_MANAGER_LOGOUT_URL = "http://some_logout_url"
pytestmark = pytest.mark.db_test
@@ -37,7 +33,6 @@ class TestAuthEndpoint:
def setup(self, test_client) -> None:
auth_manager_mock = MagicMock()
auth_manager_mock.get_url_login.return_value = AUTH_MANAGER_LOGIN_URL
- auth_manager_mock.get_url_refresh.return_value =
AUTH_MANAGER_REFRESH_URL
auth_manager_mock.get_url_logout.return_value = AUTH_MANAGER_LOGOUT_URL
test_client.app.state.auth_manager = auth_manager_mock
@@ -99,59 +94,3 @@ class TestLogout(TestAuthEndpoint):
assert response.status_code == 307
assert response.headers["location"] == expected_redirection
- if delete_cookies:
- cookies = response.headers.get_list("set-cookie")
- assert any(f"{COOKIE_NAME_JWT_TOKEN}=" in c for c in cookies)
-
-
-class TestRefresh(TestAuthEndpoint):
- @pytest.mark.parametrize(
- "params",
- [
- {},
- {"next": None},
- {"next": "http://localhost:8080"},
- {"next": "http://localhost:8080", "other_param": "something_else"},
- ],
- )
- @patch("airflow.api_fastapi.core_api.routes.public.auth.is_safe_url",
return_value=True)
- def test_should_respond_307(self, mock_is_safe_url, test_client, params):
- response = test_client.get("/auth/refresh", follow_redirects=False,
params=params)
-
- assert response.status_code == 307
- assert (
- response.headers["location"] ==
f"{AUTH_MANAGER_REFRESH_URL}?next={params.get('next')}"
- if params.get("next")
- else AUTH_MANAGER_REFRESH_URL
- )
-
- @patch.dict(os.environ, {"AIRFLOW__API__BASE_URL":
"http://localhost:8080/"})
- @pytest.mark.parametrize(
- "params",
- [
- {},
- {"next": None},
- {"next": "http://localhost:8080"},
- {"next": "http://localhost:8080", "other_param": "something_else"},
- ],
- )
- @patch("airflow.api_fastapi.core_api.routes.public.auth.is_safe_url",
return_value=True)
- def test_refresh_url_is_none(self, mock_is_safe_url, test_client, params):
- test_client.app.state.auth_manager.get_url_refresh.return_value = None
- response = test_client.get("/auth/refresh", follow_redirects=False,
params=params)
-
- assert response.status_code == 307
- assert response.headers["location"] ==
"http://localhost:8080/auth/logout"
-
- @pytest.mark.parametrize(
- "params",
- [
- {"next": "http://fake_domain.com:8080"},
- {"next": "http://localhost:8080/../../up"},
- ],
- )
- @conf_vars({("api", "base_url"): "http://localhost:8080/prefix"})
- def test_should_respond_400(self, test_client, params):
- response = test_client.get("/auth/refresh", follow_redirects=False,
params=params)
-
- assert response.status_code == 400
diff --git a/airflow-core/tests/unit/api_fastapi/core_api/routes/test_routes.py
b/airflow-core/tests/unit/api_fastapi/core_api/routes/test_routes.py
index 8fe41ede186..95734bdaa37 100644
--- a/airflow-core/tests/unit/api_fastapi/core_api/routes/test_routes.py
+++ b/airflow-core/tests/unit/api_fastapi/core_api/routes/test_routes.py
@@ -22,7 +22,6 @@ from airflow.api_fastapi.core_api.routes.public import
authenticated_router, pub
NO_AUTH_PATHS = {
"/api/v2/auth/login",
"/api/v2/auth/logout",
- "/api/v2/auth/refresh",
"/api/v2/version",
"/api/v2/monitor/health",
}
diff --git
a/providers/keycloak/src/airflow/providers/keycloak/auth_manager/keycloak_auth_manager.py
b/providers/keycloak/src/airflow/providers/keycloak/auth_manager/keycloak_auth_manager.py
index 8fdaa0b03e9..5d42c8bb160 100644
---
a/providers/keycloak/src/airflow/providers/keycloak/auth_manager/keycloak_auth_manager.py
+++
b/providers/keycloak/src/airflow/providers/keycloak/auth_manager/keycloak_auth_manager.py
@@ -19,6 +19,8 @@ from __future__ import annotations
import argparse
import json
import logging
+import time
+from base64 import urlsafe_b64decode
from typing import TYPE_CHECKING, Any
from urllib.parse import urljoin
@@ -108,9 +110,15 @@ class
KeycloakAuthManager(BaseAuthManager[KeycloakAuthManagerUser]):
base_url = conf.get("api", "base_url", fallback="/")
return urljoin(base_url, f"{AUTH_MANAGER_FASTAPI_APP_PREFIX}/login")
- def get_url_refresh(self) -> str | None:
- base_url = conf.get("api", "base_url", fallback="/")
- return urljoin(base_url, f"{AUTH_MANAGER_FASTAPI_APP_PREFIX}/refresh")
+ def refresh_user(self, *, user: KeycloakAuthManagerUser) ->
KeycloakAuthManagerUser | None:
+ if self._token_expired(user.access_token):
+ client = self.get_keycloak_client()
+ tokens = client.refresh_token(user.refresh_token)
+ user.refresh_token = tokens["refresh_token"]
+ user.access_token = tokens["access_token"]
+ return user
+
+ return None
def is_authorized_configuration(
self,
@@ -366,3 +374,17 @@ class
KeycloakAuthManager(BaseAuthManager[KeycloakAuthManagerUser]):
"Authorization": f"Bearer {access_token}",
"Content-Type": "application/x-www-form-urlencoded",
}
+
+ @staticmethod
+ def _token_expired(token: str) -> bool:
+ """
+ Check whether a JWT token is expired.
+
+ Only the payload (the second dot-separated segment) is decoded and
+ read; the token signature is not verified here.
+
+ :meta private:
+
+ :param token: the token
+ :return: ``True`` when the ``exp`` claim is lower than the current
+ time truncated to whole seconds, ``False`` otherwise
+ """
+ # NOTE(review): a token without a second segment, with a non-JSON
+ # payload or without an "exp" claim raises here -- confirm callers only
+ # pass well-formed tokens.
+ # Append padding so the segment decodes even when it was stripped
+ # (presumably the case for JWT segments).
+ payload_b64 = token.split(".")[1] + "=="
+ payload_bytes = urlsafe_b64decode(payload_b64)
+ payload = json.loads(payload_bytes)
+ return payload["exp"] < int(time.time())
diff --git
a/providers/keycloak/tests/unit/keycloak/auth_manager/routes/test_login.py
b/providers/keycloak/tests/unit/keycloak/auth_manager/routes/test_login.py
index 9cfc1984e71..2604da76581 100644
--- a/providers/keycloak/tests/unit/keycloak/auth_manager/routes/test_login.py
+++ b/providers/keycloak/tests/unit/keycloak/auth_manager/routes/test_login.py
@@ -16,9 +16,7 @@
# under the License.
from __future__ import annotations
-from unittest.mock import ANY, AsyncMock, Mock, patch
-
-import pytest
+from unittest.mock import ANY, Mock, patch
from airflow.api_fastapi.app import AUTH_MANAGER_FASTAPI_APP_PREFIX
@@ -76,76 +74,3 @@ class TestLoginRouter:
def test_login_callback_without_code(self, client):
response = client.get(AUTH_MANAGER_FASTAPI_APP_PREFIX +
"/login_callback")
assert response.status_code == 400
-
- @patch("airflow.api_fastapi.core_api.security.get_user",
new_callable=AsyncMock)
- @patch("airflow.api_fastapi.core_api.security.get_auth_manager")
-
@patch("airflow.providers.keycloak.auth_manager.routes.login.KeycloakAuthManager.get_keycloak_client")
-
@patch("airflow.providers.keycloak.auth_manager.routes.login.get_auth_manager")
- @pytest.mark.asyncio
- async def test_refresh(
- self,
- mock_get_auth_manager,
- mock_get_keycloak_client,
- mock_sec_get_auth_manager,
- mock_get_user,
- client,
- ):
- mock_user = Mock()
- mock_get_user.return_value = mock_user
- mock_auth_manager_sec = Mock()
- mock_sec_get_auth_manager.return_value = mock_auth_manager_sec
- mock_auth_manager_sec.get_user_from_token =
AsyncMock(return_value=mock_user)
- mock_get_keycloak_client.refresh_token.return_value = {
- "access_token": "new_access_token",
- "refresh_token": "new_refresh_token",
- }
-
- mock_auth_manager = Mock()
- mock_get_auth_manager.return_value = mock_auth_manager
- mock_auth_manager.generate_jwt.return_value = "new_token"
-
- next_url = "http://localhost:8080"
- response = client.get(
- AUTH_MANAGER_FASTAPI_APP_PREFIX + "/refresh",
- headers={"Authorization": "Bearer refresh_token"},
- follow_redirects=False,
- params={"next": next_url},
- )
-
- assert response.status_code == 303
- assert "_token" in response.cookies
-
- assert "location" in response.headers
- assert response.headers["location"] == next_url
-
- # Test when user is None or refresh_token is not set
- @patch("airflow.api_fastapi.core_api.security.get_user",
new_callable=AsyncMock)
- @patch("airflow.api_fastapi.core_api.security.get_auth_manager")
-
@patch("airflow.providers.keycloak.auth_manager.routes.login.KeycloakAuthManager.get_keycloak_client")
-
@patch("airflow.providers.keycloak.auth_manager.routes.login.get_auth_manager")
- @pytest.mark.asyncio
- async def test_refresh_user_none(
- self,
- mock_get_auth_manager,
- mock_get_keycloak_client,
- mock_sec_get_auth_manager,
- mock_get_user,
- client,
- ):
- mock_user = None
- mock_get_user.return_value = mock_user
- mock_auth_manager_sec = Mock()
- mock_sec_get_auth_manager.return_value = mock_auth_manager_sec
- mock_auth_manager_sec.get_user_from_token =
AsyncMock(return_value=mock_user)
-
- next_url = "http://localhost:8080"
- response = client.get(
- AUTH_MANAGER_FASTAPI_APP_PREFIX + "/refresh",
- headers={"Authorization": "Bearer refresh_token"},
- follow_redirects=False,
- params={"next": next_url},
- )
-
- assert response.status_code == 400
- assert "_token" not in response.cookies
- assert "location" not in response.headers
diff --git
a/providers/keycloak/tests/unit/keycloak/auth_manager/test_keycloak_auth_manager.py
b/providers/keycloak/tests/unit/keycloak/auth_manager/test_keycloak_auth_manager.py
index 2e7478b2991..7c00c74f1ab 100644
---
a/providers/keycloak/tests/unit/keycloak/auth_manager/test_keycloak_auth_manager.py
+++
b/providers/keycloak/tests/unit/keycloak/auth_manager/test_keycloak_auth_manager.py
@@ -67,6 +67,7 @@ def auth_manager():
def user():
user = Mock()
user.access_token = "access_token"
+ user.refresh_token = "refresh_token"
return user
@@ -102,6 +103,32 @@ class TestKeycloakAuthManager:
result = auth_manager.get_url_login()
assert result == f"{AUTH_MANAGER_FASTAPI_APP_PREFIX}/login"
+ @patch.object(KeycloakAuthManager, "_token_expired")
+ def test_refresh_user_not_expired(self, mock_token_expired, auth_manager):
+ """``refresh_user`` returns ``None`` when the token is not expired."""
+ mock_token_expired.return_value = False
+
+ result = auth_manager.refresh_user(user=Mock())
+
+ assert result is None
+
+ @patch.object(KeycloakAuthManager, "get_keycloak_client")
+ @patch.object(KeycloakAuthManager, "_token_expired")
+ def test_refresh_user_expired(self, mock_token_expired,
mock_get_keycloak_client, auth_manager, user):
+ """``refresh_user`` swaps in the new token pair when the token is expired."""
+ mock_token_expired.return_value = True
+ keycloak_client = Mock()
+ keycloak_client.refresh_token.return_value = {
+ "access_token": "new_access_token",
+ "refresh_token": "new_refresh_token",
+ }
+
+ mock_get_keycloak_client.return_value = keycloak_client
+
+ result = auth_manager.refresh_user(user=user)
+
+ # The refresh token set on the ``user`` fixture must reach the client.
+ keycloak_client.refresh_token.assert_called_with("refresh_token")
+ assert result.access_token == "new_access_token"
+ assert result.refresh_token == "new_refresh_token"
+
@pytest.mark.parametrize(
"function, method, details, permission, attributes",
[
@@ -426,3 +453,15 @@ class TestKeycloakAuthManager:
def test_get_cli_commands_return_cli_commands(self, auth_manager):
assert len(auth_manager.get_cli_commands()) == 1
+
+ # Each case is (expiration passed to the token signer, expected result);
+ # presumably a negative expiration yields an ``exp`` claim in the past.
+ @pytest.mark.parametrize(
+ "expiration, expected",
+ [
+ (-30, True),
+ (30, False),
+ ],
+ )
+ def test_token_expired(self, auth_manager, expiration, expected):
+ """``_token_expired`` reports expiry for tokens built by the signer."""
+ token =
auth_manager._get_token_signer(expiration_time_in_seconds=expiration).generate({})
+
+ assert KeycloakAuthManager._token_expired(token) is expected