This is an automated email from the ASF dual-hosted git repository.
vincbeck pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new a5313b32444 Add `get_user_from_token` method in auth manager interface (#46135)
a5313b32444 is described below
commit a5313b324445b552597a3659faac8f02fbcdca08
Author: Vincent <[email protected]>
AuthorDate: Mon Jan 27 13:00:16 2025 -0500
Add `get_user_from_token` method in auth manager interface (#46135)
---
airflow/api_fastapi/core_api/security.py | 6 ++---
airflow/auth/managers/base_auth_manager.py | 32 ++++++++++++++++++++++-----
tests/api_fastapi/core_api/test_security.py | 25 ++++++++-------------
tests/auth/managers/test_base_auth_manager.py | 17 ++++++++++++++
4 files changed, 54 insertions(+), 26 deletions(-)
diff --git a/airflow/api_fastapi/core_api/security.py b/airflow/api_fastapi/core_api/security.py
index 695a01ad478..50c1337c95c 100644
--- a/airflow/api_fastapi/core_api/security.py
+++ b/airflow/api_fastapi/core_api/security.py
@@ -17,7 +17,7 @@
from __future__ import annotations
from functools import cache
-from typing import TYPE_CHECKING, Annotated, Any, Callable
+from typing import TYPE_CHECKING, Annotated, Callable
from fastapi import Depends, HTTPException, Request, status
from fastapi.security import OAuth2PasswordBearer
@@ -46,9 +46,7 @@ def get_signer() -> JWTSigner:
def get_user(token_str: Annotated[str, Depends(oauth2_scheme)]) -> BaseUser:
try:
- signer = get_signer()
- payload: dict[str, Any] = signer.verify_token(token_str)
- return get_auth_manager().deserialize_user(payload)
+ return get_auth_manager().get_user_from_token(token_str)
except InvalidTokenError:
raise HTTPException(status.HTTP_403_FORBIDDEN, "Forbidden")
diff --git a/airflow/auth/managers/base_auth_manager.py b/airflow/auth/managers/base_auth_manager.py
index 2f83a1c6c06..ebe10d282c6 100644
--- a/airflow/auth/managers/base_auth_manager.py
+++ b/airflow/auth/managers/base_auth_manager.py
@@ -17,9 +17,11 @@
# under the License.
from __future__ import annotations
+import logging
from abc import abstractmethod
from typing import TYPE_CHECKING, Any, Generic, TypeVar
+from jwt import InvalidTokenError
from sqlalchemy import select
from airflow.auth.managers.models.base_user import BaseUser
@@ -61,6 +63,7 @@ if TYPE_CHECKING:
# TODO: Move this inside once all providers drop Airflow 2.x support.
ResourceMethod = Literal["GET", "POST", "PUT", "DELETE", "MENU"]
+log = logging.getLogger(__name__)
T = TypeVar("T", bound=BaseUser)
@@ -102,14 +105,18 @@ class BaseAuthManager(Generic[T], LoggingMixin):
def serialize_user(self, user: T) -> dict[str, Any]:
"""Create a dict from a user object."""
+ def get_user_from_token(self, token: str) -> BaseUser:
+ """Verify the JWT token is valid and create a user object from it if
valid."""
+ try:
+ payload: dict[str, Any] =
self._get_token_signer().verify_token(token)
+ return self.deserialize_user(payload)
+ except InvalidTokenError as e:
+ log.error("JWT token is not valid")
+ raise e
+
def get_jwt_token(self, user: T) -> str:
"""Return the JWT token from a user object."""
- signer = JWTSigner(
- secret_key=conf.get("api", "auth_jwt_secret"),
- expiration_time_in_seconds=conf.getint("api",
"auth_jwt_expiration_time"),
- audience="front-apis",
- )
- return signer.generate_signed_token(self.serialize_user(user))
+ return
self._get_token_signer().generate_signed_token(self.serialize_user(user))
def get_user_id(self) -> str | None:
"""Return the user ID associated to the user in session."""
@@ -448,3 +455,16 @@ class BaseAuthManager(Generic[T], LoggingMixin):
def register_views(self) -> None:
"""Register views specific to the auth manager."""
+
+ @staticmethod
+ def _get_token_signer():
+ """
+ Return the signer used to sign JWT token.
+
+ :meta private:
+ """
+ return JWTSigner(
+ secret_key=conf.get("api", "auth_jwt_secret"),
+ expiration_time_in_seconds=conf.getint("api",
"auth_jwt_expiration_time"),
+ audience="front-apis",
+ )
diff --git a/tests/api_fastapi/core_api/test_security.py b/tests/api_fastapi/core_api/test_security.py
index 90bf3f647bb..b9e1c58aa20 100644
--- a/tests/api_fastapi/core_api/test_security.py
+++ b/tests/api_fastapi/core_api/test_security.py
@@ -43,39 +43,32 @@ class TestFastApiSecurity:
):
create_app()
- @patch("airflow.api_fastapi.core_api.security.get_signer")
@patch("airflow.api_fastapi.core_api.security.get_auth_manager")
- def test_get_user(self, mock_get_auth_manager, mock_get_signer):
+ def test_get_user(self, mock_get_auth_manager):
token_str = "test-token"
- user_dict = {"user": "XXXXXXXXX"}
user = SimpleAuthManagerUser(username="username", role="admin")
auth_manager = Mock()
- auth_manager.deserialize_user.return_value = user
+ auth_manager.get_user_from_token.return_value = user
mock_get_auth_manager.return_value = auth_manager
- signer = Mock()
- signer.verify_token.return_value = user_dict
- mock_get_signer.return_value = signer
-
result = get_user(token_str)
- signer.verify_token.assert_called_once_with(token_str)
- auth_manager.deserialize_user.assert_called_once_with(user_dict)
+ auth_manager.get_user_from_token.assert_called_once_with(token_str)
assert result == user
- @patch("airflow.api_fastapi.core_api.security.get_signer")
- def test_get_user_unsuccessful(self, mock_get_signer):
+ @patch("airflow.api_fastapi.core_api.security.get_auth_manager")
+ def test_get_user_unsuccessful(self, mock_get_auth_manager):
token_str = "test-token"
- signer = Mock()
- signer.verify_token.side_effect = InvalidTokenError()
- mock_get_signer.return_value = signer
+ auth_manager = Mock()
+ auth_manager.get_user_from_token.side_effect = InvalidTokenError()
+ mock_get_auth_manager.return_value = auth_manager
with pytest.raises(HTTPException, match="Forbidden"):
get_user(token_str)
- signer.verify_token.assert_called_once_with(token_str)
+ auth_manager.get_user_from_token.assert_called_once_with(token_str)
@patch("airflow.api_fastapi.core_api.security.get_auth_manager")
def test_requires_access_dag_authorized(self, mock_get_auth_manager):
diff --git a/tests/auth/managers/test_base_auth_manager.py b/tests/auth/managers/test_base_auth_manager.py
index d8d24a93f1d..b887382bfda 100644
--- a/tests/auth/managers/test_base_auth_manager.py
+++ b/tests/auth/managers/test_base_auth_manager.py
@@ -190,6 +190,23 @@ class TestBaseAuthManager:
def test_get_url_user_profile_return_none(self, auth_manager):
assert auth_manager.get_url_user_profile() is None
+ @patch("airflow.auth.managers.base_auth_manager.JWTSigner")
+ @patch.object(EmptyAuthManager, "deserialize_user")
+ def test_get_user_from_token(self, mock_deserialize_user, mock_jwt_signer,
auth_manager):
+ token = "token"
+ payload = {}
+ user = BaseAuthManagerUserTest(name="test")
+ signer = Mock()
+ signer.verify_token.return_value = payload
+ mock_jwt_signer.return_value = signer
+ mock_deserialize_user.return_value = user
+
+ result = auth_manager.get_user_from_token(token)
+
+ mock_deserialize_user.assert_called_once_with(payload)
+ signer.verify_token.assert_called_once_with(token)
+ assert result == user
+
@patch("airflow.auth.managers.base_auth_manager.JWTSigner")
@patch.object(EmptyAuthManager, "serialize_user")
def test_get_jwt_token(self, mock_serialize_user, mock_jwt_signer,
auth_manager):