This is an automated email from the ASF dual-hosted git repository.
vincbeck pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new d5bd1344b62 Set up JWT token authentication in Fast APIs (#42634)
d5bd1344b62 is described below
commit d5bd1344b626b0a407e651380c061c363e9cab5a
Author: Vincent <[email protected]>
AuthorDate: Tue Nov 19 14:57:14 2024 -0500
Set up JWT token authentication in Fast APIs (#42634)
---
airflow/api_fastapi/app.py | 45 ++++++++++
airflow/api_fastapi/core_api/security.py | 77 +++++++++++++++++
airflow/auth/managers/base_auth_manager.py | 36 +++++---
.../auth/managers/simple/simple_auth_manager.py | 69 ++++++++++-----
airflow/auth/managers/simple/views/auth.py | 15 +++-
airflow/config_templates/config.yml | 20 +++++
airflow/configuration.py | 1 +
docs/spelling_wordlist.txt | 1 +
.../amazon/aws/auth_manager/aws_auth_manager.py | 3 +-
.../providers/fab/auth_manager/fab_auth_manager.py | 32 +++++--
.../fab/auth_manager/test_fab_auth_manager.py | 22 ++++-
tests/api_fastapi/core_api/test_security.py | 99 ++++++++++++++++++++++
.../managers/simple/test_simple_auth_manager.py | 15 +++-
tests/auth/managers/simple/views/test_auth.py | 9 +-
tests/auth/managers/test_base_auth_manager.py | 12 ++-
tests/core/test_configuration.py | 1 +
16 files changed, 402 insertions(+), 55 deletions(-)
diff --git a/airflow/api_fastapi/app.py b/airflow/api_fastapi/app.py
index 9ddd97b6cbe..4bf6ae9f6b7 100644
--- a/airflow/api_fastapi/app.py
+++ b/airflow/api_fastapi/app.py
@@ -24,10 +24,14 @@ from starlette.routing import Mount
from airflow.api_fastapi.core_api.app import init_config, init_dag_bag,
init_plugins, init_views
from airflow.api_fastapi.execution_api.app import create_task_execution_api_app
+from airflow.auth.managers.base_auth_manager import BaseAuthManager
+from airflow.configuration import conf
+from airflow.exceptions import AirflowConfigException
log = logging.getLogger(__name__)
app: FastAPI | None = None
+auth_manager: BaseAuthManager | None = None
@asynccontextmanager
@@ -57,6 +61,7 @@ def create_app(apps: str = "all") -> FastAPI:
init_dag_bag(app)
init_views(app)
init_plugins(app)
+ init_auth_manager()
if "execution" in apps_list or "all" in apps_list:
task_exec_api_app = create_task_execution_api_app(app)
@@ -79,3 +84,43 @@ def purge_cached_app() -> None:
"""Remove the cached version of the app in global state."""
global app
app = None
+
+
+def get_auth_manager_cls() -> type[BaseAuthManager]:
+ """
+ Return just the auth manager class without initializing it.
+
+ Useful to save execution time if only static methods need to be called.
+ """
+ auth_manager_cls = conf.getimport(section="core", key="auth_manager")
+
+ if not auth_manager_cls:
+ raise AirflowConfigException(
+ "No auth manager defined in the config. "
+ "Please specify one using section/key [core/auth_manager]."
+ )
+
+ return auth_manager_cls
+
+
+def init_auth_manager() -> BaseAuthManager:
+ """
+ Initialize the auth manager.
+
+ Import the user manager class and instantiate it.
+ """
+ global auth_manager
+ auth_manager_cls = get_auth_manager_cls()
+ auth_manager = auth_manager_cls()
+ auth_manager.init()
+ return auth_manager
+
+
+def get_auth_manager() -> BaseAuthManager:
+ """Return the auth manager, provided it's been initialized before."""
+ if auth_manager is None:
+ raise RuntimeError(
+ "Auth Manager has not been initialized yet. "
+ "The `init_auth_manager` method needs to be called first."
+ )
+ return auth_manager
diff --git a/airflow/api_fastapi/core_api/security.py
b/airflow/api_fastapi/core_api/security.py
new file mode 100644
index 00000000000..ede628e04aa
--- /dev/null
+++ b/airflow/api_fastapi/core_api/security.py
@@ -0,0 +1,77 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from functools import cache
+from typing import Any, Callable
+
+from fastapi import Depends, HTTPException
+from fastapi.security import OAuth2PasswordBearer
+from jwt import InvalidTokenError
+from typing_extensions import Annotated
+
+from airflow.api_fastapi.app import get_auth_manager
+from airflow.auth.managers.base_auth_manager import ResourceMethod
+from airflow.auth.managers.models.base_user import BaseUser
+from airflow.auth.managers.models.resource_details import DagAccessEntity,
DagDetails
+from airflow.configuration import conf
+from airflow.utils.jwt_signer import JWTSigner
+
+oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
+
+
+@cache
+def get_signer() -> JWTSigner:
+ return JWTSigner(
+ secret_key=conf.get("api", "auth_jwt_secret"),
+ expiration_time_in_seconds=conf.getint("api",
"auth_jwt_expiration_time"),
+ audience="front-apis",
+ )
+
+
+def get_user(token_str: Annotated[str, Depends(oauth2_scheme)]) -> BaseUser:
+    """
+    Return the user identified by the bearer JWT token of the request.
+
+    The token is verified with the signer returned by ``get_signer`` and the decoded payload
+    is turned back into a user object by the configured auth manager.
+
+    :param token_str: the JWT token extracted from the ``Authorization: Bearer`` header
+    :raises HTTPException: 403 if the token is invalid, or if its payload cannot be mapped
+        to a user
+    """
+    try:
+        payload: dict[str, Any] = get_signer().verify_token(token_str)
+        user = get_auth_manager().deserialize_user(payload)
+    except (InvalidTokenError, KeyError):
+        # KeyError: the token is correctly signed but its payload misses a claim the auth
+        # manager reads (e.g. ``username`` for the simple auth manager). Report it as an
+        # authentication failure instead of letting it surface as a 500 error.
+        raise HTTPException(403, "Forbidden") from None
+    if user is None:
+        # An auth manager may return ``None`` (e.g. the FAB one when the user referenced by
+        # the token no longer exists). Passing ``None`` on would let ``is_authorized_*``
+        # fall back to the "current user" (the simple auth manager does
+        # ``user or self.get_user()``), so reject it here.
+        raise HTTPException(403, "Forbidden")
+    return user
+
+
+def requires_access_dag(method: ResourceMethod, access_entity: DagAccessEntity
| None = None) -> Callable:
+ def inner(
+ dag_id: str | None = None,
+ user: Annotated[BaseUser | None, Depends(get_user)] = None,
+ ) -> None:
+ def callback():
+ return get_auth_manager().is_authorized_dag(
+ method=method, access_entity=access_entity,
details=DagDetails(id=dag_id), user=user
+ )
+
+ _requires_access(
+ is_authorized_callback=callback,
+ )
+
+ return inner
+
+
+def _requires_access(
+ *,
+ is_authorized_callback: Callable[[], bool],
+) -> None:
+ if not is_authorized_callback():
+ raise HTTPException(403, "Forbidden")
diff --git a/airflow/auth/managers/base_auth_manager.py
b/airflow/auth/managers/base_auth_manager.py
index 69b5969c827..028d4dadb13 100644
--- a/airflow/auth/managers/base_auth_manager.py
+++ b/airflow/auth/managers/base_auth_manager.py
@@ -19,11 +19,12 @@ from __future__ import annotations
from abc import abstractmethod
from functools import cached_property
-from typing import TYPE_CHECKING, Container, Literal, Sequence
+from typing import TYPE_CHECKING, Any, Container, Generic, Literal, Sequence,
TypeVar
from flask_appbuilder.menu import MenuItem
from sqlalchemy import select
+from airflow.auth.managers.models.base_user import BaseUser
from airflow.auth.managers.models.resource_details import (
DagDetails,
)
@@ -37,7 +38,6 @@ if TYPE_CHECKING:
from flask import Blueprint
from sqlalchemy.orm import Session
- from airflow.auth.managers.models.base_user import BaseUser
from airflow.auth.managers.models.batch_apis import (
IsAuthorizedConnectionRequest,
IsAuthorizedDagRequest,
@@ -59,8 +59,10 @@ if TYPE_CHECKING:
ResourceMethod = Literal["GET", "POST", "PUT", "DELETE", "MENU"]
+T = TypeVar("T", bound=BaseUser)
-class BaseAuthManager(LoggingMixin):
+
+class BaseAuthManager(Generic[T], LoggingMixin):
"""
Class to derive in order to implement concrete auth managers.
@@ -69,7 +71,7 @@ class BaseAuthManager(LoggingMixin):
:param appbuilder: the flask app builder
"""
- def __init__(self, appbuilder: AirflowAppBuilder) -> None:
+ def __init__(self, appbuilder: AirflowAppBuilder | None = None) -> None:
super().__init__()
self.appbuilder = appbuilder
@@ -93,9 +95,17 @@ class BaseAuthManager(LoggingMixin):
return self.get_user_name()
@abstractmethod
- def get_user(self) -> BaseUser | None:
+ def get_user(self) -> T | None:
"""Return the user associated to the user in session."""
+ @abstractmethod
+ def deserialize_user(self, token: dict[str, Any]) -> T:
+ """Create a user object from dict."""
+
+ @abstractmethod
+ def serialize_user(self, user: T) -> dict[str, Any]:
+ """Create a dict from a user object."""
+
def get_user_id(self) -> str | None:
"""Return the user ID associated to the user in session."""
user = self.get_user()
@@ -132,7 +142,7 @@ class BaseAuthManager(LoggingMixin):
*,
method: ResourceMethod,
details: ConfigurationDetails | None = None,
- user: BaseUser | None = None,
+ user: T | None = None,
) -> bool:
"""
Return whether the user is authorized to perform a given action on
configuration.
@@ -148,7 +158,7 @@ class BaseAuthManager(LoggingMixin):
*,
method: ResourceMethod,
details: ConnectionDetails | None = None,
- user: BaseUser | None = None,
+ user: T | None = None,
) -> bool:
"""
Return whether the user is authorized to perform a given action on a
connection.
@@ -165,7 +175,7 @@ class BaseAuthManager(LoggingMixin):
method: ResourceMethod,
access_entity: DagAccessEntity | None = None,
details: DagDetails | None = None,
- user: BaseUser | None = None,
+ user: T | None = None,
) -> bool:
"""
Return whether the user is authorized to perform a given action on a
DAG.
@@ -183,7 +193,7 @@ class BaseAuthManager(LoggingMixin):
*,
method: ResourceMethod,
details: AssetDetails | None = None,
- user: BaseUser | None = None,
+ user: T | None = None,
) -> bool:
"""
Return whether the user is authorized to perform a given action on an
asset.
@@ -199,7 +209,7 @@ class BaseAuthManager(LoggingMixin):
*,
method: ResourceMethod,
details: PoolDetails | None = None,
- user: BaseUser | None = None,
+ user: T | None = None,
) -> bool:
"""
Return whether the user is authorized to perform a given action on a
pool.
@@ -215,7 +225,7 @@ class BaseAuthManager(LoggingMixin):
*,
method: ResourceMethod,
details: VariableDetails | None = None,
- user: BaseUser | None = None,
+ user: T | None = None,
) -> bool:
"""
Return whether the user is authorized to perform a given action on a
variable.
@@ -230,7 +240,7 @@ class BaseAuthManager(LoggingMixin):
self,
*,
access_view: AccessView,
- user: BaseUser | None = None,
+ user: T | None = None,
) -> bool:
"""
Return whether the user is authorized to access a read-only state of
the installation.
@@ -241,7 +251,7 @@ class BaseAuthManager(LoggingMixin):
@abstractmethod
def is_authorized_custom_view(
- self, *, method: ResourceMethod | str, resource_name: str, user:
BaseUser | None = None
+ self, *, method: ResourceMethod | str, resource_name: str, user: T |
None = None
):
"""
Return whether the user is authorized to perform a given action on a
custom view.
diff --git a/airflow/auth/managers/simple/simple_auth_manager.py
b/airflow/auth/managers/simple/simple_auth_manager.py
index 78dccf7c2a9..48baa02e7c7 100644
--- a/airflow/auth/managers/simple/simple_auth_manager.py
+++ b/airflow/auth/managers/simple/simple_auth_manager.py
@@ -22,7 +22,7 @@ import os
import random
from collections import namedtuple
from enum import Enum
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Any
from flask import session, url_for
from termcolor import colored
@@ -33,7 +33,6 @@ from airflow.auth.managers.simple.views.auth import
SimpleAuthManagerAuthenticat
from airflow.configuration import AIRFLOW_HOME
if TYPE_CHECKING:
- from airflow.auth.managers.models.base_user import BaseUser
from airflow.auth.managers.models.resource_details import (
AccessView,
AssetDetails,
@@ -68,7 +67,7 @@ class
SimpleAuthManagerRole(namedtuple("SimpleAuthManagerRole", "name order"), E
ADMIN = "ADMIN", 3
-class SimpleAuthManager(BaseAuthManager):
+class SimpleAuthManager(BaseAuthManager[SimpleAuthManagerUser]):
"""
Simple auth manager.
@@ -89,6 +88,8 @@ class SimpleAuthManager(BaseAuthManager):
)
def init(self) -> None:
+ if not self.appbuilder:
+ return
user_passwords_from_file = {}
# Read passwords from file
@@ -115,8 +116,9 @@ class SimpleAuthManager(BaseAuthManager):
file.write(json.dumps(self.passwords))
def is_logged_in(self) -> bool:
- return "user" in session or self.appbuilder.get_app.config.get(
- "SIMPLE_AUTH_MANAGER_ALL_ADMINS", False
+ return "user" in session or (
+ self.appbuilder is not None
+ and
self.appbuilder.get_app.config.get("SIMPLE_AUTH_MANAGER_ALL_ADMINS", False)
)
def get_url_login(self, **kwargs) -> str:
@@ -128,28 +130,34 @@ class SimpleAuthManager(BaseAuthManager):
def get_user(self) -> SimpleAuthManagerUser | None:
if not self.is_logged_in():
return None
- if
self.appbuilder.get_app.config.get("SIMPLE_AUTH_MANAGER_ALL_ADMINS", False):
+ if self.appbuilder and
self.appbuilder.get_app.config.get("SIMPLE_AUTH_MANAGER_ALL_ADMINS", False):
return SimpleAuthManagerUser(username="anonymous", role="admin")
else:
return session["user"]
+ def deserialize_user(self, token: dict[str, Any]) -> SimpleAuthManagerUser:
+ return SimpleAuthManagerUser(username=token["username"],
role=token["role"])
+
+ def serialize_user(self, user: SimpleAuthManagerUser) -> dict[str, Any]:
+ return {"username": user.username, "role": user.role}
+
def is_authorized_configuration(
self,
*,
method: ResourceMethod,
details: ConfigurationDetails | None = None,
- user: BaseUser | None = None,
+ user: SimpleAuthManagerUser | None = None,
) -> bool:
- return self._is_authorized(method=method,
allow_role=SimpleAuthManagerRole.OP)
+ return self._is_authorized(method=method,
allow_role=SimpleAuthManagerRole.OP, user=user)
def is_authorized_connection(
self,
*,
method: ResourceMethod,
details: ConnectionDetails | None = None,
- user: BaseUser | None = None,
+ user: SimpleAuthManagerUser | None = None,
) -> bool:
- return self._is_authorized(method=method,
allow_role=SimpleAuthManagerRole.OP)
+ return self._is_authorized(method=method,
allow_role=SimpleAuthManagerRole.OP, user=user)
def is_authorized_dag(
self,
@@ -157,46 +165,65 @@ class SimpleAuthManager(BaseAuthManager):
method: ResourceMethod,
access_entity: DagAccessEntity | None = None,
details: DagDetails | None = None,
- user: BaseUser | None = None,
+ user: SimpleAuthManagerUser | None = None,
) -> bool:
return self._is_authorized(
method=method,
allow_get_role=SimpleAuthManagerRole.VIEWER,
allow_role=SimpleAuthManagerRole.USER,
+ user=user,
)
def is_authorized_asset(
- self, *, method: ResourceMethod, details: AssetDetails | None = None,
user: BaseUser | None = None
+ self,
+ *,
+ method: ResourceMethod,
+ details: AssetDetails | None = None,
+ user: SimpleAuthManagerUser | None = None,
) -> bool:
return self._is_authorized(
method=method,
allow_get_role=SimpleAuthManagerRole.VIEWER,
allow_role=SimpleAuthManagerRole.OP,
+ user=user,
)
def is_authorized_pool(
- self, *, method: ResourceMethod, details: PoolDetails | None = None,
user: BaseUser | None = None
+ self,
+ *,
+ method: ResourceMethod,
+ details: PoolDetails | None = None,
+ user: SimpleAuthManagerUser | None = None,
) -> bool:
return self._is_authorized(
method=method,
allow_get_role=SimpleAuthManagerRole.VIEWER,
allow_role=SimpleAuthManagerRole.OP,
+ user=user,
)
def is_authorized_variable(
- self, *, method: ResourceMethod, details: VariableDetails | None =
None, user: BaseUser | None = None
+ self,
+ *,
+ method: ResourceMethod,
+ details: VariableDetails | None = None,
+ user: SimpleAuthManagerUser | None = None,
) -> bool:
- return self._is_authorized(method=method,
allow_role=SimpleAuthManagerRole.OP)
+ return self._is_authorized(method=method,
allow_role=SimpleAuthManagerRole.OP, user=user)
- def is_authorized_view(self, *, access_view: AccessView, user: BaseUser |
None = None) -> bool:
- return self._is_authorized(method="GET",
allow_role=SimpleAuthManagerRole.VIEWER)
+ def is_authorized_view(
+ self, *, access_view: AccessView, user: SimpleAuthManagerUser | None =
None
+ ) -> bool:
+ return self._is_authorized(method="GET",
allow_role=SimpleAuthManagerRole.VIEWER, user=user)
def is_authorized_custom_view(
- self, *, method: ResourceMethod | str, resource_name: str, user:
BaseUser | None = None
+ self, *, method: ResourceMethod | str, resource_name: str, user:
SimpleAuthManagerUser | None = None
):
- return self._is_authorized(method="GET",
allow_role=SimpleAuthManagerRole.VIEWER)
+ return self._is_authorized(method="GET",
allow_role=SimpleAuthManagerRole.VIEWER, user=user)
def register_views(self) -> None:
+ if not self.appbuilder:
+ return
self.appbuilder.add_view_no_menu(
SimpleAuthManagerAuthenticationViews(
users=self.appbuilder.get_app.config.get("SIMPLE_AUTH_MANAGER_USERS", []),
@@ -210,6 +237,7 @@ class SimpleAuthManager(BaseAuthManager):
method: ResourceMethod,
allow_role: SimpleAuthManagerRole,
allow_get_role: SimpleAuthManagerRole | None = None,
+ user: SimpleAuthManagerUser | None = None,
):
"""
Return whether the user is authorized to access a given resource.
@@ -219,8 +247,9 @@ class SimpleAuthManager(BaseAuthManager):
equal than this role, they have access
:param allow_get_role: minimal role giving access to the resource, if
the user's role is greater or
equal than this role, they have access. If not provided,
``allow_role`` is used
+ :param user: the user to check the authorization for. If not provided,
the current user is used
"""
- user = self.get_user()
+ user = user or self.get_user()
if not user:
return False
diff --git a/airflow/auth/managers/simple/views/auth.py
b/airflow/auth/managers/simple/views/auth.py
index 8ab02d0a015..6e4cf0c3994 100644
--- a/airflow/auth/managers/simple/views/auth.py
+++ b/airflow/auth/managers/simple/views/auth.py
@@ -23,8 +23,10 @@ from flask_appbuilder import expose
from airflow.auth.managers.simple.user import SimpleAuthManagerUser
from airflow.configuration import conf
+from airflow.utils.jwt_signer import JWTSigner
from airflow.utils.state import State
from airflow.www.app import csrf
+from airflow.www.extensions.init_auth_manager import get_auth_manager
from airflow.www.views import AirflowBaseView
logger = logging.getLogger(__name__)
@@ -80,9 +82,18 @@ class SimpleAuthManagerAuthenticationViews(AirflowBaseView):
if not username or not password or len(found_users) == 0:
return
redirect(url_for("SimpleAuthManagerAuthenticationViews.login", error=["1"]))
- session["user"] = SimpleAuthManagerUser(
+ user = SimpleAuthManagerUser(
username=username,
role=found_users[0]["role"],
)
+ # Will be removed once Airflow uses the new UI
+ session["user"] = user
- return redirect(url_for("Airflow.index"))
+ signer = JWTSigner(
+ secret_key=conf.get("api", "auth_jwt_secret"),
+ expiration_time_in_seconds=conf.getint("api",
"auth_jwt_expiration_time"),
+ audience="front-apis",
+ )
+ token =
signer.generate_signed_token(get_auth_manager().serialize_user(user))
+
+ return redirect(url_for("Airflow.index", token=token))
diff --git a/airflow/config_templates/config.yml
b/airflow/config_templates/config.yml
index 04ac0d802a9..4a7fe5f1897 100644
--- a/airflow/config_templates/config.yml
+++ b/airflow/config_templates/config.yml
@@ -1404,6 +1404,26 @@ api:
version_added: 2.7.0
example: ~
default: "False"
+ auth_jwt_secret:
+ description: |
+ Secret key used to encode and decode JWT tokens to authenticate to
public and private APIs.
+ It should be as random as possible. However, when running more than 1
instances of API services,
+ make sure all of them use the same ``jwt_secret`` otherwise calls will
fail on authentication.
+ version_added: 3.0.0
+ type: string
+ sensitive: true
+ example: ~
+ default: "{JWT_SECRET_KEY}"
+ auth_jwt_expiration_time:
+ description: |
+ Number of seconds until the JWT token used for authentication expires.
When the token expires,
+ all API calls using this token will fail on authentication.
+ Make sure that time on ALL the machines that you run airflow
components on is synchronized
+ (for example using ntpd) otherwise you might get "forbidden" errors.
+ version_added: 3.0.0
+ type: integer
+ example: ~
+ default: "86400"
lineage:
description: ~
options:
diff --git a/airflow/configuration.py b/airflow/configuration.py
index bc808d6bfc2..521af6cbe32 100644
--- a/airflow/configuration.py
+++ b/airflow/configuration.py
@@ -2143,6 +2143,7 @@ else:
TEST_PLUGINS_FOLDER = os.path.join(AIRFLOW_HOME, "plugins")
SECRET_KEY = b64encode(os.urandom(16)).decode("utf-8")
+JWT_SECRET_KEY = b64encode(os.urandom(16)).decode("utf-8")
FERNET_KEY = "" # Set only if needed when generating a new file
WEBSERVER_CONFIG = "" # Set by initialize_config
diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt
index 32787428cc5..b2df194de0d 100644
--- a/docs/spelling_wordlist.txt
+++ b/docs/spelling_wordlist.txt
@@ -898,6 +898,7 @@ Jupyter
jupyter
jupytercmd
JWT
+jwt
Kafka
kafka
Kalibrr
diff --git
a/providers/src/airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py
b/providers/src/airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py
index 414907961ce..8271aad3302 100644
---
a/providers/src/airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py
+++
b/providers/src/airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py
@@ -430,7 +430,8 @@ class AwsAuthManager(BaseAuthManager):
]
def register_views(self) -> None:
- self.appbuilder.add_view_no_menu(AwsAuthManagerAuthenticationViews())
+ if self.appbuilder:
+
self.appbuilder.add_view_no_menu(AwsAuthManagerAuthenticationViews())
@staticmethod
def _get_menu_item_request(resource_name: str) -> IsAuthorizedRequest:
diff --git
a/providers/src/airflow/providers/fab/auth_manager/fab_auth_manager.py
b/providers/src/airflow/providers/fab/auth_manager/fab_auth_manager.py
index e93e440f5dd..adebe30dd53 100644
--- a/providers/src/airflow/providers/fab/auth_manager/fab_auth_manager.py
+++ b/providers/src/airflow/providers/fab/auth_manager/fab_auth_manager.py
@@ -21,7 +21,7 @@ import argparse
import warnings
from functools import cached_property
from pathlib import Path
-from typing import TYPE_CHECKING, Container
+from typing import TYPE_CHECKING, Any, Container
import packaging.version
from connexion import FlaskApi
@@ -82,7 +82,7 @@ from airflow.security.permissions import (
RESOURCE_WEBSITE,
RESOURCE_XCOM,
)
-from airflow.utils.session import NEW_SESSION, provide_session
+from airflow.utils.session import NEW_SESSION, create_session, provide_session
from airflow.utils.yaml import safe_load
from airflow.version import version
from airflow.www.constants import SWAGGER_BUNDLE, SWAGGER_ENABLED
@@ -131,13 +131,18 @@ _MAP_ACCESS_VIEW_TO_FAB_RESOURCE_TYPE = {
}
-class FabAuthManager(BaseAuthManager):
+class FabAuthManager(BaseAuthManager[User]):
"""
Flask-AppBuilder auth manager.
This auth manager is responsible for providing a backward compatible user
management experience to users.
"""
+ def init(self) -> None:
+ """Run operations when Airflow is initializing."""
+ if self.appbuilder:
+ self._sync_appbuilder_roles()
+
@staticmethod
def get_cli_commands() -> list[CLICommand]:
"""Vends CLI commands to be included in Airflow CLI."""
@@ -199,9 +204,12 @@ class FabAuthManager(BaseAuthManager):
return current_user
- def init(self) -> None:
- """Run operations when Airflow is initializing."""
- self._sync_appbuilder_roles()
+    def deserialize_user(self, token: dict[str, Any]) -> User | None:
+        """
+        Return the user referenced by the ``id`` claim of a decoded JWT payload.
+
+        :param token: the decoded token payload, as produced by ``serialize_user``
+        :return: the matching ``User``, or ``None`` if no user with this id exists
+            (for instance because it was deleted after the token was issued)
+        """
+        with create_session() as session:
+            # NOTE(review): the returned object outlives the session it was loaded in;
+            # confirm that relationships read later (e.g. ``roles``) are eagerly loaded.
+            return session.get(User, token["id"])
+
+ def serialize_user(self, user: User) -> dict[str, Any]:
+ return {"id": user.id}
def is_logged_in(self) -> bool:
"""Return whether the user is logged in."""
@@ -209,8 +217,10 @@ class FabAuthManager(BaseAuthManager):
if Version(Version(version).base_version) < Version("3.0.0"):
return not user.is_anonymous and user.is_active
else:
- return self.appbuilder.get_app.config.get("AUTH_ROLE_PUBLIC",
None) or (
- not user.is_anonymous and user.is_active
+ return (
+ self.appbuilder
+ and self.appbuilder.get_app.config.get("AUTH_ROLE_PUBLIC",
None)
+ or (not user.is_anonymous and user.is_active)
)
def is_authorized_configuration(
@@ -376,6 +386,9 @@ class FabAuthManager(BaseAuthManager):
FabAirflowSecurityManagerOverride,
)
+ if not self.appbuilder:
+ raise AirflowException("AppBuilder is not initialized.")
+
sm_from_config =
self.appbuilder.get_app.config.get("SECURITY_MANAGER_CLASS")
if sm_from_config:
if not issubclass(sm_from_config,
FabAirflowSecurityManagerOverride):
@@ -547,6 +560,9 @@ class FabAuthManager(BaseAuthManager):
:meta private:
"""
+ if not self.appbuilder:
+ raise AirflowException("AppBuilder is not initialized.")
+
if "." in dag_id and hasattr(DagModel, "root_dag_id"):
return self.appbuilder.get_session.scalar(
select(DagModel.dag_id,
DagModel.root_dag_id).where(DagModel.dag_id == dag_id).limit(1)
diff --git a/providers/tests/fab/auth_manager/test_fab_auth_manager.py
b/providers/tests/fab/auth_manager/test_fab_auth_manager.py
index d298f7667ea..1994e910f9e 100644
--- a/providers/tests/fab/auth_manager/test_fab_auth_manager.py
+++ b/providers/tests/fab/auth_manager/test_fab_auth_manager.py
@@ -27,6 +27,8 @@ from flask import Flask, g
from airflow.exceptions import AirflowConfigException, AirflowException
+from providers.tests.fab.auth_manager.api_endpoints.api_connexion_utils import
create_user
+
try:
from airflow.auth.managers.models.resource_details import AccessView,
DagAccessEntity, DagDetails
except ImportError:
@@ -141,6 +143,16 @@ class TestFabAuthManager:
with user_set(minimal_app_for_auth_api, flask_g_user):
assert auth_manager.get_user() == flask_g_user
+    @pytest.mark.db_test
+    def test_deserialize_user(self, flask_app, auth_manager_with_appbuilder):
+        user = create_user(flask_app, "test")
+        result = auth_manager_with_appbuilder.deserialize_user({"id": user.id})
+        assert user == result
+
+    @pytest.mark.db_test
+    def test_deserialize_user_not_found(self, flask_app, auth_manager_with_appbuilder):
+        # ``session.get`` returns None for an unknown primary key
+        assert auth_manager_with_appbuilder.deserialize_user({"id": -1}) is None
+
+    @pytest.mark.db_test
+    def test_serialize_user(self, flask_app, auth_manager_with_appbuilder):
+        user = create_user(flask_app, "test")
+        result = auth_manager_with_appbuilder.serialize_user(user)
+        assert result == {"id": user.id}
+
@pytest.mark.db_test
@mock.patch.object(FabAuthManager, "get_user")
def test_is_logged_in(self, mock_get_user, auth_manager_with_appbuilder):
@@ -338,11 +350,17 @@ class TestFabAuthManager:
],
)
def test_is_authorized_dag(
- self, method, dag_access_entity, dag_details, user_permissions,
expected_result, auth_manager
+ self,
+ method,
+ dag_access_entity,
+ dag_details,
+ user_permissions,
+ expected_result,
+ auth_manager_with_appbuilder,
):
user = Mock()
user.perms = user_permissions
- result = auth_manager.is_authorized_dag(
+ result = auth_manager_with_appbuilder.is_authorized_dag(
method=method, access_entity=dag_access_entity,
details=dag_details, user=user
)
assert result == expected_result
diff --git a/tests/api_fastapi/core_api/test_security.py
b/tests/api_fastapi/core_api/test_security.py
new file mode 100644
index 00000000000..90bf3f647bb
--- /dev/null
+++ b/tests/api_fastapi/core_api/test_security.py
@@ -0,0 +1,99 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from unittest.mock import Mock, patch
+
+import pytest
+from fastapi import HTTPException
+from jwt import InvalidTokenError
+
+from airflow.api_fastapi.app import create_app
+from airflow.api_fastapi.core_api.security import get_user, requires_access_dag
+from airflow.auth.managers.models.resource_details import DagAccessEntity
+from airflow.auth.managers.simple.user import SimpleAuthManagerUser
+
+from tests_common.test_utils.config import conf_vars
+
+
+class TestFastApiSecurity:
+ @classmethod
+ def setup_class(cls):
+ with conf_vars(
+ {
+ (
+ "core",
+ "auth_manager",
+ ):
"airflow.auth.managers.simple.simple_auth_manager.SimpleAuthManager",
+ }
+ ):
+ create_app()
+
+ @patch("airflow.api_fastapi.core_api.security.get_signer")
+ @patch("airflow.api_fastapi.core_api.security.get_auth_manager")
+ def test_get_user(self, mock_get_auth_manager, mock_get_signer):
+ token_str = "test-token"
+ user_dict = {"user": "XXXXXXXXX"}
+ user = SimpleAuthManagerUser(username="username", role="admin")
+
+ auth_manager = Mock()
+ auth_manager.deserialize_user.return_value = user
+ mock_get_auth_manager.return_value = auth_manager
+
+ signer = Mock()
+ signer.verify_token.return_value = user_dict
+ mock_get_signer.return_value = signer
+
+ result = get_user(token_str)
+
+ signer.verify_token.assert_called_once_with(token_str)
+ auth_manager.deserialize_user.assert_called_once_with(user_dict)
+ assert result == user
+
+ @patch("airflow.api_fastapi.core_api.security.get_signer")
+ def test_get_user_unsuccessful(self, mock_get_signer):
+ token_str = "test-token"
+
+ signer = Mock()
+ signer.verify_token.side_effect = InvalidTokenError()
+ mock_get_signer.return_value = signer
+
+ with pytest.raises(HTTPException, match="Forbidden"):
+ get_user(token_str)
+
+ signer.verify_token.assert_called_once_with(token_str)
+
+ @patch("airflow.api_fastapi.core_api.security.get_auth_manager")
+ def test_requires_access_dag_authorized(self, mock_get_auth_manager):
+ auth_manager = Mock()
+ auth_manager.is_authorized_dag.return_value = True
+ mock_get_auth_manager.return_value = auth_manager
+
+ requires_access_dag("GET", DagAccessEntity.CODE)("dag-id", Mock())
+
+ auth_manager.is_authorized_dag.assert_called_once()
+
+ @patch("airflow.api_fastapi.core_api.security.get_auth_manager")
+ def test_requires_access_dag_unauthorized(self, mock_get_auth_manager):
+ auth_manager = Mock()
+ auth_manager.is_authorized_dag.return_value = False
+ mock_get_auth_manager.return_value = auth_manager
+
+ with pytest.raises(HTTPException, match="Forbidden"):
+ requires_access_dag("GET", DagAccessEntity.CODE)("dag-id", Mock())
+
+ auth_manager.is_authorized_dag.assert_called_once()
diff --git a/tests/auth/managers/simple/test_simple_auth_manager.py
b/tests/auth/managers/simple/test_simple_auth_manager.py
index 434c0d60fcc..07289f6f002 100644
--- a/tests/auth/managers/simple/test_simple_auth_manager.py
+++ b/tests/auth/managers/simple/test_simple_auth_manager.py
@@ -132,6 +132,16 @@ class TestSimpleAuthManager:
assert result is None
+ def test_deserialize_user(self, auth_manager):
+ result = auth_manager.deserialize_user({"username": "test", "role":
"admin"})
+ assert result.username == "test"
+ assert result.role == "admin"
+
+ def test_serialize_user(self, auth_manager):
+ user = SimpleAuthManagerUser(username="test", role="admin")
+ result = auth_manager.serialize_user(user)
+ assert result == {"username": "test", "role": "admin"}
+
@pytest.mark.db_test
@patch.object(SimpleAuthManager, "is_logged_in")
@pytest.mark.parametrize(
@@ -280,10 +290,7 @@ class TestSimpleAuthManager:
assert getattr(auth_manager_with_appbuilder, api)(method=method)
is result
@pytest.mark.db_test
- @patch(
-
"airflow.providers.amazon.aws.auth_manager.views.auth.conf.get_mandatory_value",
return_value="test"
- )
- def test_register_views(self, _, auth_manager_with_appbuilder):
+ def test_register_views(self, auth_manager_with_appbuilder):
auth_manager_with_appbuilder.appbuilder.add_view_no_menu = Mock()
auth_manager_with_appbuilder.register_views()
auth_manager_with_appbuilder.appbuilder.add_view_no_menu.assert_called_once()
diff --git a/tests/auth/managers/simple/views/test_auth.py
b/tests/auth/managers/simple/views/test_auth.py
index cd4628c2d02..e3e7b29e2cc 100644
--- a/tests/auth/managers/simple/views/test_auth.py
+++ b/tests/auth/managers/simple/views/test_auth.py
@@ -17,6 +17,7 @@
from __future__ import annotations
import json
+from unittest.mock import Mock, patch
import pytest
from flask import session, url_for
@@ -67,11 +68,15 @@ class TestSimpleAuthManagerAuthenticationViews:
"username, password, is_successful",
[("test", "test", True), ("test", "test2", False), ("", "", False)],
)
- def test_login_submit(self, simple_app, username, password, is_successful):
+ @patch("airflow.auth.managers.simple.views.auth.JWTSigner")
+ def test_login_submit(self, mock_jwt_signer, simple_app, username,
password, is_successful):
+ signer = Mock()
+ signer.generate_signed_token.return_value = "token"
+ mock_jwt_signer.return_value = signer
with simple_app.test_client() as client:
response = client.post("/login_submit", data={"username":
username, "password": password})
assert response.status_code == 302
if is_successful:
- assert response.location == url_for("Airflow.index")
+ assert response.location == url_for("Airflow.index",
token="token")
else:
assert response.location ==
url_for("SimpleAuthManagerAuthenticationViews.login", error=["1"])
diff --git a/tests/auth/managers/test_base_auth_manager.py
b/tests/auth/managers/test_base_auth_manager.py
index 82efe20048b..c62076a4654 100644
--- a/tests/auth/managers/test_base_auth_manager.py
+++ b/tests/auth/managers/test_base_auth_manager.py
@@ -16,13 +16,14 @@
# under the License.
from __future__ import annotations
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Any
from unittest.mock import MagicMock, Mock, patch
import pytest
from flask_appbuilder.menu import Menu
from airflow.auth.managers.base_auth_manager import BaseAuthManager,
ResourceMethod
+from airflow.auth.managers.models.base_user import BaseUser
from airflow.auth.managers.models.resource_details import (
ConnectionDetails,
DagDetails,
@@ -32,7 +33,6 @@ from airflow.auth.managers.models.resource_details import (
from airflow.exceptions import AirflowException
if TYPE_CHECKING:
- from airflow.auth.managers.models.base_user import BaseUser
from airflow.auth.managers.models.resource_details import (
AccessView,
AssetDetails,
@@ -41,10 +41,16 @@ if TYPE_CHECKING:
)
-class EmptyAuthManager(BaseAuthManager):
+class EmptyAuthManager(BaseAuthManager[BaseUser]):
def get_user(self) -> BaseUser:
raise NotImplementedError()
+ def deserialize_user(self, token: dict[str, Any]) -> BaseUser:
+ raise NotImplementedError()
+
+ def serialize_user(self, user: BaseUser) -> dict[str, Any]:
+ raise NotImplementedError()
+
def is_authorized_configuration(
self,
*,
diff --git a/tests/core/test_configuration.py b/tests/core/test_configuration.py
index 2167d71ed5c..1fb3b270573 100644
--- a/tests/core/test_configuration.py
+++ b/tests/core/test_configuration.py
@@ -1584,6 +1584,7 @@ def test_sensitive_values():
("database", "sql_alchemy_conn"),
("core", "fernet_key"),
("core", "internal_api_secret_key"),
+ ("api", "auth_jwt_secret"),
("webserver", "secret_key"),
("secrets", "backend_kwargs"),
("sentry", "sentry_dsn"),