This is an automated email from the ASF dual-hosted git repository.

vincbeck pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new a5313b32444 Add `get_user_from_token` method in auth manager interface (#46135)
a5313b32444 is described below

commit a5313b324445b552597a3659faac8f02fbcdca08
Author: Vincent <[email protected]>
AuthorDate: Mon Jan 27 13:00:16 2025 -0500

    Add `get_user_from_token` method in auth manager interface (#46135)
---
 airflow/api_fastapi/core_api/security.py      |  6 ++---
 airflow/auth/managers/base_auth_manager.py    | 32 ++++++++++++++++++++++-----
 tests/api_fastapi/core_api/test_security.py   | 25 ++++++++-------------
 tests/auth/managers/test_base_auth_manager.py | 17 ++++++++++++++
 4 files changed, 54 insertions(+), 26 deletions(-)

diff --git a/airflow/api_fastapi/core_api/security.py b/airflow/api_fastapi/core_api/security.py
index 695a01ad478..50c1337c95c 100644
--- a/airflow/api_fastapi/core_api/security.py
+++ b/airflow/api_fastapi/core_api/security.py
@@ -17,7 +17,7 @@
 from __future__ import annotations
 
 from functools import cache
-from typing import TYPE_CHECKING, Annotated, Any, Callable
+from typing import TYPE_CHECKING, Annotated, Callable
 
 from fastapi import Depends, HTTPException, Request, status
 from fastapi.security import OAuth2PasswordBearer
@@ -46,9 +46,7 @@ def get_signer() -> JWTSigner:
 
 def get_user(token_str: Annotated[str, Depends(oauth2_scheme)]) -> BaseUser:
     try:
-        signer = get_signer()
-        payload: dict[str, Any] = signer.verify_token(token_str)
-        return get_auth_manager().deserialize_user(payload)
+        return get_auth_manager().get_user_from_token(token_str)
     except InvalidTokenError:
         raise HTTPException(status.HTTP_403_FORBIDDEN, "Forbidden")
 
diff --git a/airflow/auth/managers/base_auth_manager.py b/airflow/auth/managers/base_auth_manager.py
index 2f83a1c6c06..ebe10d282c6 100644
--- a/airflow/auth/managers/base_auth_manager.py
+++ b/airflow/auth/managers/base_auth_manager.py
@@ -17,9 +17,11 @@
 # under the License.
 from __future__ import annotations
 
+import logging
 from abc import abstractmethod
 from typing import TYPE_CHECKING, Any, Generic, TypeVar
 
+from jwt import InvalidTokenError
 from sqlalchemy import select
 
 from airflow.auth.managers.models.base_user import BaseUser
@@ -61,6 +63,7 @@ if TYPE_CHECKING:
 # TODO: Move this inside once all providers drop Airflow 2.x support.
 ResourceMethod = Literal["GET", "POST", "PUT", "DELETE", "MENU"]
 
+log = logging.getLogger(__name__)
 T = TypeVar("T", bound=BaseUser)
 
 
@@ -102,14 +105,18 @@ class BaseAuthManager(Generic[T], LoggingMixin):
     def serialize_user(self, user: T) -> dict[str, Any]:
         """Create a dict from a user object."""
 
+    def get_user_from_token(self, token: str) -> BaseUser:
+        """Verify the JWT token is valid and create a user object from it if valid."""
+        try:
+            payload: dict[str, Any] = self._get_token_signer().verify_token(token)
+            return self.deserialize_user(payload)
+        except InvalidTokenError as e:
+            log.error("JWT token is not valid")
+            raise e
+
     def get_jwt_token(self, user: T) -> str:
         """Return the JWT token from a user object."""
-        signer = JWTSigner(
-            secret_key=conf.get("api", "auth_jwt_secret"),
-            expiration_time_in_seconds=conf.getint("api", "auth_jwt_expiration_time"),
-            audience="front-apis",
-        )
-        return signer.generate_signed_token(self.serialize_user(user))
+        return self._get_token_signer().generate_signed_token(self.serialize_user(user))
 
     def get_user_id(self) -> str | None:
         """Return the user ID associated to the user in session."""
@@ -448,3 +455,16 @@ class BaseAuthManager(Generic[T], LoggingMixin):
 
     def register_views(self) -> None:
         """Register views specific to the auth manager."""
+
+    @staticmethod
+    def _get_token_signer():
+        """
+        Return the signer used to sign JWT token.
+
+        :meta private:
+        """
+        return JWTSigner(
+            secret_key=conf.get("api", "auth_jwt_secret"),
+            expiration_time_in_seconds=conf.getint("api", "auth_jwt_expiration_time"),
+            audience="front-apis",
+        )
diff --git a/tests/api_fastapi/core_api/test_security.py b/tests/api_fastapi/core_api/test_security.py
index 90bf3f647bb..b9e1c58aa20 100644
--- a/tests/api_fastapi/core_api/test_security.py
+++ b/tests/api_fastapi/core_api/test_security.py
@@ -43,39 +43,32 @@ class TestFastApiSecurity:
         ):
             create_app()
 
-    @patch("airflow.api_fastapi.core_api.security.get_signer")
     @patch("airflow.api_fastapi.core_api.security.get_auth_manager")
-    def test_get_user(self, mock_get_auth_manager, mock_get_signer):
+    def test_get_user(self, mock_get_auth_manager):
         token_str = "test-token"
-        user_dict = {"user": "XXXXXXXXX"}
         user = SimpleAuthManagerUser(username="username", role="admin")
 
         auth_manager = Mock()
-        auth_manager.deserialize_user.return_value = user
+        auth_manager.get_user_from_token.return_value = user
         mock_get_auth_manager.return_value = auth_manager
 
-        signer = Mock()
-        signer.verify_token.return_value = user_dict
-        mock_get_signer.return_value = signer
-
         result = get_user(token_str)
 
-        signer.verify_token.assert_called_once_with(token_str)
-        auth_manager.deserialize_user.assert_called_once_with(user_dict)
+        auth_manager.get_user_from_token.assert_called_once_with(token_str)
         assert result == user
 
-    @patch("airflow.api_fastapi.core_api.security.get_signer")
-    def test_get_user_unsuccessful(self, mock_get_signer):
+    @patch("airflow.api_fastapi.core_api.security.get_auth_manager")
+    def test_get_user_unsuccessful(self, mock_get_auth_manager):
         token_str = "test-token"
 
-        signer = Mock()
-        signer.verify_token.side_effect = InvalidTokenError()
-        mock_get_signer.return_value = signer
+        auth_manager = Mock()
+        auth_manager.get_user_from_token.side_effect = InvalidTokenError()
+        mock_get_auth_manager.return_value = auth_manager
 
         with pytest.raises(HTTPException, match="Forbidden"):
             get_user(token_str)
 
-        signer.verify_token.assert_called_once_with(token_str)
+        auth_manager.get_user_from_token.assert_called_once_with(token_str)
 
     @patch("airflow.api_fastapi.core_api.security.get_auth_manager")
     def test_requires_access_dag_authorized(self, mock_get_auth_manager):
diff --git a/tests/auth/managers/test_base_auth_manager.py b/tests/auth/managers/test_base_auth_manager.py
index d8d24a93f1d..b887382bfda 100644
--- a/tests/auth/managers/test_base_auth_manager.py
+++ b/tests/auth/managers/test_base_auth_manager.py
@@ -190,6 +190,23 @@ class TestBaseAuthManager:
     def test_get_url_user_profile_return_none(self, auth_manager):
         assert auth_manager.get_url_user_profile() is None
 
+    @patch("airflow.auth.managers.base_auth_manager.JWTSigner")
+    @patch.object(EmptyAuthManager, "deserialize_user")
+    def test_get_user_from_token(self, mock_deserialize_user, mock_jwt_signer, auth_manager):
+        token = "token"
+        payload = {}
+        user = BaseAuthManagerUserTest(name="test")
+        signer = Mock()
+        signer.verify_token.return_value = payload
+        mock_jwt_signer.return_value = signer
+        mock_deserialize_user.return_value = user
+
+        result = auth_manager.get_user_from_token(token)
+
+        mock_deserialize_user.assert_called_once_with(payload)
+        signer.verify_token.assert_called_once_with(token)
+        assert result == user
+
     @patch("airflow.auth.managers.base_auth_manager.JWTSigner")
     @patch.object(EmptyAuthManager, "serialize_user")
     def test_get_jwt_token(self, mock_serialize_user, mock_jwt_signer, auth_manager):

Reply via email to