This is an automated email from the ASF dual-hosted git repository.
pierrejeambrun pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 431608103a4 Handle token expired for UI to redirect to login page (#47308)
431608103a4 is described below
commit 431608103a4e622c9f9bf01a41dd253c23fb9367
Author: Pierre Jeambrun <[email protected]>
AuthorDate: Mon Mar 3 18:15:14 2025 +0100
Handle token expired for UI to redirect to login page (#47308)
---
airflow/api_fastapi/core_api/security.py | 17 +++++++----------
tests/api_fastapi/core_api/test_security.py | 17 +++++++++++++++--
2 files changed, 22 insertions(+), 12 deletions(-)
diff --git a/airflow/api_fastapi/core_api/security.py b/airflow/api_fastapi/core_api/security.py
index 27a91eb2c49..2eb40cef6fc 100644
--- a/airflow/api_fastapi/core_api/security.py
+++ b/airflow/api_fastapi/core_api/security.py
@@ -21,7 +21,7 @@ from typing import TYPE_CHECKING, Annotated, Callable
from fastapi import Depends, HTTPException, Request, status
from fastapi.security import OAuth2PasswordBearer
-from jwt import InvalidTokenError
+from jwt import ExpiredSignatureError, InvalidTokenError
from airflow.api_fastapi.app import get_auth_manager
from airflow.auth.managers.models.base_user import BaseUser
@@ -52,6 +52,8 @@ def get_signer() -> JWTSigner:
def get_user(token_str: Annotated[str, Depends(oauth2_scheme)]) -> BaseUser:
try:
return get_auth_manager().get_user_from_token(token_str)
+ except ExpiredSignatureError:
+ raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Token Expired")
except InvalidTokenError:
raise HTTPException(status.HTTP_403_FORBIDDEN, "Forbidden")
@@ -59,15 +61,10 @@ def get_user(token_str: Annotated[str, Depends(oauth2_scheme)]) -> BaseUser:
async def get_user_with_exception_handling(request: Request) -> BaseUser | None:
# Currently the UI does not support JWT authentication, this method defines a fallback if no token is provided by the UI.
# We can remove this method when issue https://github.com/apache/airflow/issues/44884 is done.
- try:
- token_str = await oauth2_scheme(request)
- if not token_str: # Handle None or empty token
- return None
- return get_user(token_str)
- except HTTPException as e:
- if e.status_code == status.HTTP_401_UNAUTHORIZED:
- return None
- raise e
+ token_str = await oauth2_scheme(request)
+ if not token_str: # Handle None or empty token
+ return None
+ return get_user(token_str)
def requires_access_dag(method: ResourceMethod, access_entity: DagAccessEntity | None = None) -> Callable:
diff --git a/tests/api_fastapi/core_api/test_security.py b/tests/api_fastapi/core_api/test_security.py
index b9e1c58aa20..7824ecd171b 100644
--- a/tests/api_fastapi/core_api/test_security.py
+++ b/tests/api_fastapi/core_api/test_security.py
@@ -20,7 +20,7 @@ from unittest.mock import Mock, patch
import pytest
from fastapi import HTTPException
-from jwt import InvalidTokenError
+from jwt import ExpiredSignatureError, InvalidTokenError
from airflow.api_fastapi.app import create_app
from airflow.api_fastapi.core_api.security import get_user, requires_access_dag
@@ -58,7 +58,7 @@ class TestFastApiSecurity:
assert result == user
@patch("airflow.api_fastapi.core_api.security.get_auth_manager")
- def test_get_user_unsuccessful(self, mock_get_auth_manager):
+ def test_get_user_wrong_token(self, mock_get_auth_manager):
token_str = "test-token"
auth_manager = Mock()
@@ -70,6 +70,19 @@ class TestFastApiSecurity:
auth_manager.get_user_from_token.assert_called_once_with(token_str)
+ @patch("airflow.api_fastapi.core_api.security.get_auth_manager")
+ def test_get_user_expired_token(self, mock_get_auth_manager):
+ token_str = "test-token"
+
+ auth_manager = Mock()
+ auth_manager.get_user_from_token.side_effect = ExpiredSignatureError()
+ mock_get_auth_manager.return_value = auth_manager
+
+ with pytest.raises(HTTPException, match="Token Expired"):
+ get_user(token_str)
+
+ auth_manager.get_user_from_token.assert_called_once_with(token_str)
+
@patch("airflow.api_fastapi.core_api.security.get_auth_manager")
def test_requires_access_dag_authorized(self, mock_get_auth_manager):
auth_manager = Mock()